SpectralNorm gets cancelled when followed by BatchNorm?

I am implementing a version of FastGAN, and it seems (see the official repo) that they use spectral norm directly followed by batch norm. Doesn't the spectral norm get fully cancelled by the following batch norm?

I believe it does, and I have coded up a simple demonstration below. Could someone please confirm that I am not missing something basic and that the authors indeed made a mistake?

The reasoning

  • Spectral normalization takes a linear operator O and rescales it by its inverse spectral norm, i.e. it replaces O with O / sigma(O), where sigma(O) is the largest singular value of O
  • Batch norm (without learnable beta, gamma) shifts and scales the operator after which it is applied: BN(O(x)) = (O(x) - E[O(x)]) / std[O(x)]
    • so BN(k * O(x)) = BN(O(x)) for any inputs x and any constant k > 0, since both the mean and the standard deviation absorb the factor k (see the sketch right after this list)
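
A minimal sketch of both bullets (my own snippet, not from the FastGAN repo), using a plain weight matrix instead of a conv kernel; torch.linalg.svdvals returns the singular values in descending order, so the first entry is sigma:

import torch
import torch.nn as nn

torch.manual_seed(0)

# Bullet 1: spectral normalization divides W by its largest singular value
W = torch.randn(8, 8)
sigma = torch.linalg.svdvals(W)[0]  # sigma(W), the spectral norm
W_sn = W / sigma                    # the spectrally normalized weight
assert torch.allclose(torch.linalg.svdvals(W_sn)[0], torch.tensor(1.0))

# Bullet 2: batch norm (affine=False) is invariant to positive rescaling
bn = nn.BatchNorm1d(8, affine=False)
bn.train()
x = torch.randn(32, 8)
k = 3.7
# equal up to BatchNorm's eps term in the denominator
assert torch.allclose(bn(k * x), bn(x), atol=1e-4)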

My simple demonstration:

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
# BatchNormed Convolution
conv = nn.Conv2d(3,1,3,1,0, bias=False)
c_dict = dict(conv.named_parameters())
c_dict["weight"].data = torch.ones_like(c_dict["weight"].data)
bn1 = nn.BatchNorm2d(1, affine=False)

# BatchNormed SpectralNormed
conv2 = nn.Conv2d(3,1,3,1,0, bias=False)
c_dict = dict(conv2.named_parameters())
c_dict["weight"].data = torch.ones_like(c_dict["weight"].data)

# Spectral Normalize the 2nd conv
sn_conv = spectral_norm(conv2)

bn2 = nn.BatchNorm2d(1, affine=False)

# Now feed an image into both conv & sn_conv
im = torch.randn(size=(2,3,4,4))
cim = conv(im)
sncim = sn_conv(im)

bn1.train()
bn2.train()

# Feed the conv & sn_conv outputs through their batchnorm 1000 times
# (in train mode every call normalizes with the current batch statistics,
# so the loop only refreshes the running stats; a single pass would do)
for i in range(1000):
    cimbn = bn1(cim)
    sncimbn = bn2(sncim)

# torch.isclose returns an elementwise tensor, which cannot be asserted
# directly; allclose reduces to a single boolean (atol absorbs the tiny
# discrepancy from BatchNorm's eps term)
assert torch.allclose(cimbn, sncimbn, atol=1e-5)
# The assert passes: both batch-normalized outputs are identical

I believe the only effective difference between the two compositions, BN(SN(Conv)) and BN(Conv), is that they get initialized to slightly different functions. Is that right? (A quick proportionality check follows.)
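
As a quick check (my own addition, reusing cim and sncim from the demonstration above): the pre-BN outputs are proportional, and the recovered scale is exactly the positive constant sigma that spectral norm divided out, which is precisely the kind of factor batch norm is invariant to:

# sncim should equal cim / sigma for a single positive constant sigma
scale = cim.norm() / sncim.norm()  # recovers sigma: sqrt(27) ~= 5.196 for the all-ones 1x27 kernel matrix
assert torch.allclose(sncim * scale, cim, atol=1e-5)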
