How to give a custom activation function to Conv2dNormActivation

It seems that torchvision.ops.Conv2dNormActivation only accepts the activation functions defined under torch.nn, because the activation_layer argument is declared as Callable[..., torch.nn.Module] in the source code.
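For reference, the annotation Callable[..., torch.nn.Module] describes a factory: something that, when called, returns a module. It does not by itself restrict the argument to classes from torch.nn. The sketch below is a simplified re-implementation of how the activation layer gets built. `build_activation` is a hypothetical helper, not a torchvision function, and it assumes torchvision calls `activation_layer(inplace=inplace)`, or `activation_layer()` when `inplace=None`. Check the source of your installed version.

```python
from typing import Callable, Optional

from torch import nn


def build_activation(
    activation_layer: Callable[..., nn.Module],
    inplace: Optional[bool] = True,
) -> nn.Module:
    # Hypothetical helper, NOT a torchvision function: a simplified version of
    # how ConvNormActivation turns `activation_layer` into a layer.
    # `activation_layer` is called like a constructor, so it must be a class
    # (or other factory) that returns a module, not a module instance.
    params = {} if inplace is None else {"inplace": inplace}
    return activation_layer(**params)


relu = build_activation(nn.ReLU)                # builds nn.ReLU(inplace=True)
silu = build_activation(nn.SiLU, inplace=None)  # builds nn.SiLU()
```

Under this reading, any user-defined nn.Module subclass satisfies the same annotation, as long as its __init__ accepts whatever keyword arguments are passed (here, `inplace`).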

I tried defining a custom activation function, for example:

import torch
from torch import nn

class ExpExp(nn.Module):
    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super(ExpExp, self).__init__()
        self.inplace = inplace  # stored so extra_repr() works; forward() ignores it

    def forward(self, x):
        # Element-wise x * exp(exp(x))
        return x * torch.exp(torch.exp(x))

    def extra_repr(self):
        return 'inplace=True' if self.inplace else ''
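Before wiring the module into a larger block, you can check that it works on its own when it is called with a tensor. This is a minimal sketch that assumes torch is installed. The class is restated in trimmed form so the snippet runs by itself.

```python
import torch
from torch import nn


class ExpExp(nn.Module):
    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace  # kept for API compatibility; forward() ignores it

    def forward(self, x):
        # Element-wise x * exp(exp(x))
        return x * torch.exp(torch.exp(x))


act = ExpExp()
x = torch.tensor([-1.0, 0.0, 1.0])
y = act(x)  # calling the module with an input runs forward(x)
```

At x = 0 the output is exactly 0, and at x = 1 it is exp(e), roughly 15.154.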

and called it as:

Conv2dNormActivation(
    in_channels,
    out_channels,
    kernel_size=(h, w),
    stride=stride,
    padding=padding,
    norm_layer=norm_layer,
    activation_layer=ExpExp(),
    inplace=None,
)

This raised the error: TypeError: forward() missing 1 required positional argument: 'x'
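The same error can be reproduced without torchvision. Calling an nn.Module instance with no arguments runs its forward() with no input. The sketch assumes that torchvision builds the activation as `activation_layer()` when `inplace=None`, which is what its source does at the time of writing. Under that assumption, passing the instance `ExpExp()` instead of the class `ExpExp` leads to exactly this call.

```python
import torch
from torch import nn


class ExpExp(nn.Module):
    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return x * torch.exp(torch.exp(x))


factory = ExpExp     # the class: calling it constructs a new module
instance = ExpExp()  # an instance: calling it runs forward()

built = factory()    # fine: returns a fresh ExpExp module

message = ""
try:
    instance()       # no input, so forward() is missing `x`
except TypeError as err:
    message = str(err)
```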

I think I have to pass the input x to the ExpExp() activation function when it is called inside Conv2dNormActivation, but how do I do that? Is there a way to specify a custom activation function?
