I am implementing a version of FastGAN, and it seems (see the official repo) that they use spectral norm directly followed by batch norm. Doesn't the spectral norm get fully cancelled by the following batch norm?
I believe it does, and I have coded up a simple demonstration. Could someone please confirm that I am not missing something basic and that the authors indeed made a mistake?
The reasoning
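- Spectral normalization takes a linear operator $O$ and rescales it by its inverse spectral norm, $1/\sigma(O)$.
- Batch norm (without the learnable $\beta, \gamma$) shifts and scales the output of the operator it follows, with the mean and standard deviation taken over the batch: $\mathrm{BN}(O(x)) = \frac{O(x) - \mathbb{E}[O(x)]}{\mathrm{std}[O(x)]}$
- so $\mathrm{BN}(k\,O(x)) = \mathrm{BN}(O(x))$ for any input $x$ and any factor $k > 0$ (up to the small $\epsilon$ that batch norm adds to the variance),
- so $\mathrm{BN}(\mathrm{SN}(O)(x)) = \mathrm{BN}\big(O(x)/\sigma(O)\big) = \mathrm{BN}(O(x))$, i.e. the spectral normalization is cancelled.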
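As a quick numerical sanity check of the two facts used in the reasoning, here is a minimal sketch (the tensor sizes and the helper `bn` are illustrative assumptions, not taken from FastGAN). It computes $\sigma$ for the all-ones $1\times 3\times 3\times 3$ conv weight used in the demonstration, then checks that training-mode batch norm without affine parameters returns (nearly) the same output for $k\,z$ as for $z$ when $k > 0$, and the negated output when $k < 0$.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)


def bn(t: torch.Tensor) -> torch.Tensor:
    # Training-mode BatchNorm with no affine parameters and no running stats:
    # normalizes each channel with the batch mean/variance (default eps=1e-5)
    return F.batch_norm(t, running_mean=None, running_var=None, training=True)


# Spectral norm of a conv weight viewed as an (out_channels, in_channels*kh*kw)
# matrix; for the all-ones 1x3x3x3 weight this equals ||w||_2 = sqrt(27)
w = torch.ones(1, 3, 3, 3)
sigma = torch.linalg.matrix_norm(w.reshape(w.shape[0], -1), ord=2).item()
print(f"sigma = {sigma:.4f}, sqrt(27) = {math.sqrt(27):.4f}")

# Hypothetical stand-in for O(x): 16 samples, 4 channels, 5x5 feature maps
z = torch.randn(16, 4, 5, 5)

# Positive rescaling (including 1/sigma) leaves BN's output almost unchanged;
# the small mismatch comes from eps, hence the loose tolerance
scale_invariant = {
    k: torch.allclose(bn(k * z), bn(z), rtol=1e-3, atol=1e-5)
    for k in (1.0 / sigma, 0.5, 10.0)
}
print(scale_invariant)

# A negative factor flips the sign instead, which is why k > 0 is needed
sign_flip = torch.allclose(bn(-z), -bn(z))
print(sign_flip)
```

The tolerance is loose on purpose: because of `eps` in the denominator, $\mathrm{BN}(k z)$ equals $\mathrm{BN}(z)$ computed with an effective `eps / k**2`, so the match is approximate rather than exact.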
My simple demonstration:
```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# BatchNormed convolution
conv = nn.Conv2d(3, 1, 3, 1, 0, bias=False)
nn.init.ones_(conv.weight)
bn1 = nn.BatchNorm2d(1, affine=False)

# BatchNormed, SpectralNormed convolution (same all-ones initial weight)
conv2 = nn.Conv2d(3, 1, 3, 1, 0, bias=False)
nn.init.ones_(conv2.weight)
# Spectral normalize the 2nd conv (its weight is moved to conv2.weight_orig)
sn_conv = spectral_norm(conv2)
bn2 = nn.BatchNorm2d(1, affine=False)

# Now feed an image into both conv & sn_conv
im = torch.randn(size=(2, 3, 4, 4))
cim = conv(im)
sncim = sn_conv(im)

bn1.train()
bn2.train()
# Feed the conv & sn_conv outputs through their batchnorm 1000 times
# (in train mode BN normalizes with the current batch statistics; the
# running statistics updated here are only used in eval mode)
for i in range(1000):
    cimbn = bn1(cim)
    sncimbn = bn2(sncim)

# torch.isclose is elementwise, so reduce with allclose; the tolerance
# covers BatchNorm's eps, which makes the cancellation only approximate
assert torch.allclose(cimbn, sncimbn, rtol=1e-3, atol=1e-6)
```

The assertion passes, i.e. both give the same result.
I believe the only effective difference between the two, BN(SN(Conv)) and BN(Conv), is that they get initialized to slightly different functions. Is that right?
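One way to probe that last question a little further, sketched under assumptions (random instead of all-ones initial weights, arbitrary tensor sizes, and a random weighting `g` to get a non-trivial loss): start both branches from the same weight and compare not only the outputs but also the gradient reaching the raw weight, which for `torch.nn.utils.spectral_norm` is the `weight_orig` parameter. This covers a single training-mode step only; eval mode, where batch norm switches to running statistics, is not tested.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)

# Two convs starting from the same random weight (sizes chosen arbitrarily)
conv = nn.Conv2d(3, 1, 3, 1, 0, bias=False)
conv2 = nn.Conv2d(3, 1, 3, 1, 0, bias=False)
with torch.no_grad():
    conv2.weight.copy_(conv.weight)
# After this call the raw weight is the parameter conv2.weight_orig
sn_conv = spectral_norm(conv2)

bn1 = nn.BatchNorm2d(1, affine=False)
bn2 = nn.BatchNorm2d(1, affine=False)

x = torch.randn(8, 3, 8, 8)
# Hypothetical per-element weighting of the output: a plain .sum() of a BN
# output is constant (zero mean per channel), so its gradient would be ~0
g = torch.randn(8, 1, 6, 6)

loss1 = (bn1(conv(x)) * g).sum()
loss2 = (bn2(sn_conv(x)) * g).sum()
loss1.backward()
loss2.backward()

grad_plain = conv.weight.grad
grad_sn = sn_conv.weight_orig.grad

# Relative differences of the losses and of the raw-weight gradients
rel_loss_diff = ((loss1 - loss2).abs() / loss1.abs()).item()
rel_grad_diff = ((grad_plain - grad_sn).norm() / grad_plain.norm()).item()
print(f"relative loss difference:     {rel_loss_diff:.2e}")
print(f"relative gradient difference: {rel_grad_diff:.2e}")
```

Since $\mathrm{BN}(Wx/\sigma(W)) = \mathrm{BN}(Wx)$ as a function of the raw weight $W$ (up to `eps`), both differences would be expected to sit at the level of floating-point and `eps` noise; a large gap would point to something the reasoning above misses.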