Is there any way to implement max pooling according to the norm of sub-vectors in a group in PyTorch? Specifically, this is what I want to implement:
Input:
x: a 2-D float tensor, shape #Nodes * dim
cluster: a 1-D long tensor, shape #Nodes
Output:
y, a 2-D float tensor, and:
y[i]=x[k] where k=argmax_{cluster[k]=i}(torch.norm(x[k],p=2)).
I tried torch.scatter with reduce="max", but this only works for dim=1 and x[i]>0.
Can someone help me to solve the problem?
I don't think there's any built-in function to do what you want. Basically this would be some form of scatter_reduce on the norm of x, but instead of selecting the max norm you want to select the row corresponding to the max norm.
A straightforward implementation may look something like this:
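A minimal sketch of such a mask-based version, under a few assumptions: the function name max_norm_pool_loop is made up for illustration, cluster ids are taken to run from 0 to cluster.max(), and rows of y for cluster ids that never occur are left as zeros.

```python
import torch


def max_norm_pool_loop(x: torch.Tensor, cluster: torch.Tensor) -> torch.Tensor:
    """For each cluster id i, copy the row of x with the largest L2 norm into y[i].

    Hypothetical helper: assumes cluster is non-empty and its ids are
    non-negative integers in [0, cluster.max()].
    """
    num_clusters = int(cluster.max()) + 1
    norms = x.norm(p=2, dim=1)  # one L2 norm per node
    # Assumption: rows for cluster ids with no members stay zero.
    y = x.new_zeros(num_clusters, x.shape[1])
    for i in range(num_clusters):
        # Indices of the nodes that belong to cluster i (one mask per cluster).
        idx = (cluster == i).nonzero(as_tuple=True)[0]
        if idx.numel() == 0:
            continue
        # Node in this cluster with the largest norm.
        k = idx[norms[idx].argmax()]
        y[i] = x[k]
    return y
```

Because whole rows are selected by norm, negative entries in x are handled the same way as positive ones.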
This should work just fine if cluster.max() is relatively small. If there are many clusters, though, this approach has to unnecessarily create a mask over cluster for every unique cluster id. To avoid this you can make use of argsort. The best I could come up with in pure Python was the following.
For example, in random tests with NODES = 60000, DIMS = 512, cluster.max() = 6000, the first version takes about 620 ms while the second version takes about 78 ms.
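The argsort idea can be sketched roughly as below. This is one possible way to do it, not necessarily the exact code that was timed above: the name max_norm_pool_sorted is hypothetical, cluster ids are assumed to lie in 0..cluster.max(), and rows for unused cluster ids are left as zeros. It sorts the nodes by descending norm, then stably by cluster id, so the first node of every cluster group is the one with the largest norm.

```python
import torch


def max_norm_pool_sorted(x: torch.Tensor, cluster: torch.Tensor) -> torch.Tensor:
    """Select, per cluster, the row of x with the largest L2 norm, without per-cluster masks.

    Hypothetical helper: assumes cluster is non-empty and its ids are
    non-negative integers in [0, cluster.max()].
    """
    norms = x.norm(p=2, dim=1)  # one L2 norm per node
    # Step 1: node indices ordered by descending norm.
    by_norm = torch.argsort(norms, descending=True)
    # Step 2: stable sort by cluster id keeps the descending-norm order
    # inside each cluster group.
    _, by_cluster = torch.sort(cluster[by_norm], stable=True)
    order = by_norm[by_cluster]
    sorted_cluster = cluster[order]
    # Step 3: mark the first position of every cluster group.
    is_first = torch.ones_like(sorted_cluster, dtype=torch.bool)
    is_first[1:] = sorted_cluster[1:] != sorted_cluster[:-1]
    # Step 4: copy the winning rows into the output.
    num_clusters = int(cluster.max()) + 1
    y = x.new_zeros(num_clusters, x.shape[1])  # unused cluster ids stay zero (assumption)
    y[sorted_cluster[is_first]] = x[order[is_first]]
    return y
```

If two nodes in the same cluster have exactly the same norm, which of them is picked is not specified in this sketch, because the first argsort is not stable.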