import numpy as np


def batchnorm_forward(x, gamma, beta, bn_param):
    """
    Forward pass for batch normalization.

    During training the sample mean and (uncorrected) sample variance are
    computed from minibatch statistics and used to normalize the incoming data.
    During training we also keep an exponentially decaying running mean of the
    mean and variance of each feature, and these averages are used to normalize
    data at test-time.

    At each timestep we update the running averages for mean and variance using
    an exponential decay based on the momentum parameter:

    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var

    Note that the batch normalization paper suggests a different test-time
    behavior: they compute sample mean and variance for each feature using a
    large number of training images rather than using a running average. For
    this implementation we have chosen to use running averages instead since
    they do not require an additional estimation step; the torch7 implementation
    of batch normalization also uses running averages.

    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift parameter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability
      - momentum: Constant for running mean / variance.
      - running_mean: Array of shape (D,) giving running mean of features
      - running_var: Array of shape (D,) giving running variance of features

    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass (None in test mode)
    """
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    # Running statistics start at zero the first time the layer is called.
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))

    out, cache = None, None
    if mode == 'train':
        # Normalize with the statistics of the current minibatch.
        sample_mean = np.mean(x, axis=0)
        sample_var = np.var(x, axis=0)
        x_hat = (x - sample_mean) / np.sqrt(sample_var + eps)
        out = gamma * x_hat + beta
        cache = (gamma, x, sample_mean, sample_var, eps, x_hat)

        # Exponentially decaying averages, used later at test time.
        running_mean = momentum * running_mean + (1 - momentum) * sample_mean
        running_var = momentum * running_var + (1 - momentum) * sample_var
    elif mode == 'test':
        # Fold normalization and the affine transform into one scale and shift.
        scale = gamma / np.sqrt(running_var + eps)
        out = x * scale + (beta - running_mean * scale)
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    # Store the updated running statistics back into bn_param.
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return out, cache
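The two mechanisms in the function above, the exponentially decaying running averages and the folded scale/shift used in test mode, can be checked in isolation. The following is a minimal sketch, not part of the original file. The synthetic data stream and the names `true_mean`, `true_std`, `x_test`, `out_folded` and `out_explicit` are assumptions introduced purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
D, momentum, eps = 3, 0.9, 1e-5

# Hypothetical per-feature statistics of a synthetic data stream (illustrative only).
true_mean = np.array([1.0, -2.0, 0.5])
true_std = np.array([2.0, 0.5, 1.0])

# Same zero initialization as the defaults in batchnorm_forward.
running_mean = np.zeros(D)
running_var = np.zeros(D)

# Apply the training-time running-average update over many minibatches.
for _ in range(200):
    x = true_mean + true_std * rng.standard_normal((256, D))
    running_mean = momentum * running_mean + (1 - momentum) * x.mean(axis=0)
    running_var = momentum * running_var + (1 - momentum) * x.var(axis=0)

# Test-time output computed two ways: the folded scale/shift form used in the
# function, and the explicit normalize-then-affine form.
gamma = np.array([1.5, 1.0, 0.5])
beta = np.array([0.0, 1.0, -1.0])
x_test = true_mean + true_std * rng.standard_normal((4, D))

scale = gamma / np.sqrt(running_var + eps)
out_folded = x_test * scale + (beta - running_mean * scale)
out_explicit = gamma * (x_test - running_mean) / np.sqrt(running_var + eps) + beta

print("running_mean:", running_mean)
print("running_var: ", running_var)
print("forms agree: ", np.allclose(out_folded, out_explicit))
```

Under these assumptions the running statistics settle near the stream's mean and variance after enough updates, and the two test-time expressions agree up to floating-point error. The folded form needs only one multiply and one add per element once `scale` has been computed.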