How to do a per-group argmax in PyTorch?


Is there any way to implement max pooling according to the norm of the sub-vectors within each group in PyTorch? Specifically, this is what I want to implement:

Input:

x: a 2-D float tensor, shape #Nodes * dim

cluster: a 1-D long tensor, shape #Nodes

Output:

y: a 2-D float tensor, where

y[i]=x[k] where k=argmax_{cluster[k]=i}(torch.norm(x[k],p=2)).

I tried torch.scatter with reduce="max", but this only works for dim=1 and x[i]>0.

Can someone help me to solve the problem?

jodag

I don't think there's any built-in function to do what you want. Basically this would be some form of scatter_reduce on the norm of x, but instead of selecting the max norm you want to select the row corresponding to the max norm.
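To make the gap concrete, here is a minimal sketch, assuming PyTorch 1.12 or later (where Tensor.scatter_reduce is available). The toy values are made up for illustration. It gives the largest norm per cluster, but nothing in the result tells you which row produced it:

```python
import torch

# Toy data, purely illustrative: 4 nodes, 2 dims, 2 clusters.
x = torch.tensor([[3.0, 4.0], [1.0, 0.0], [0.0, 2.0], [6.0, 8.0]])
cluster = torch.tensor([0, 0, 1, 1])

norms = torch.norm(x, dim=1)  # per-row L2 norm
num_clusters = int(cluster.max().item()) + 1

# Per-cluster maximum of the norms. Norms are non-negative, so starting
# the reduction from zeros does not change the result here.
max_norm = torch.zeros(num_clusters).scatter_reduce(0, cluster, norms, reduce="amax")
```

So the reduction itself is easy; the missing piece is carrying the index of the winning row along with it.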

A straightforward implementation may look something like this:

import torch

"""
input
    x: float tensor of size [NODES, DIMS]
    cluster: long tensor of size [NODES]
output
    float tensor of size [cluster.max()+1, DIMS]
"""

num_clusters = cluster.max().item() + 1
# Rows for cluster ids that never occur in `cluster` stay zero.
y = torch.zeros((num_clusters, x.shape[1]), dtype=x.dtype, device=x.device)
for cluster_id in torch.unique(cluster):
    # Boolean mask selects the rows that belong to this cluster.
    x_cluster = x[cluster == cluster_id]
    # Keep the row with the largest L2 norm.
    y[cluster_id] = x_cluster[torch.argmax(torch.norm(x_cluster, dim=1), dim=0)]

This should work just fine if cluster.max() is relatively small. If there are many clusters, though, this approach has to build a mask over cluster for every unique cluster id, which is unnecessary work. To avoid this you can make use of argsort. The best I could come up with in pure Python was the following.

num_clusters = cluster.max().item() + 1
x_norm = torch.norm(x, dim=1)

# Sort node indices by cluster id so each cluster occupies a contiguous slice.
cluster_sortidx = torch.argsort(cluster)
cluster_ids, cluster_counts = torch.unique_consecutive(cluster[cluster_sortidx], return_counts=True)

# Slice boundaries [a, b) of each cluster within the sorted order.
end_indices = torch.cumsum(cluster_counts, dim=0).cpu().tolist()
start_indices = [0] + end_indices[:-1]

y = torch.zeros((num_clusters, x.shape[1]), dtype=x.dtype, device=x.device)
for cluster_id, a, b in zip(cluster_ids, start_indices, end_indices):
    # Original row indices of the nodes in this cluster.
    indices = cluster_sortidx[a:b]
    y[cluster_id] = x[indices[torch.argmax(x_norm[indices], dim=0)]]

For example, in random tests with NODES = 60000, DIMS = 512, and cluster.max() = 6000, the first version takes about 620 ms while the second version takes about 78 ms.
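
If the remaining Python loop is still a bottleneck, the same sorting idea can be pushed one step further. The following is only a sketch, not something I have benchmarked against the timings above: the function name max_norm_row_per_cluster is made up for illustration, and it assumes a PyTorch version where torch.sort accepts stable=True. It sorts the rows by norm, then stably by cluster id, so the last row of each cluster's run is the one with the largest norm.

```python
import torch


def max_norm_row_per_cluster(x, cluster):
    """Hypothetical helper: for each cluster id, return the row of x with the largest L2 norm."""
    num_clusters = int(cluster.max().item()) + 1
    norms = torch.norm(x, dim=1)

    # Sort rows by ascending norm, then stably by cluster id. Each cluster
    # becomes a contiguous run whose last element has the largest norm.
    by_norm = torch.sort(norms, stable=True).indices
    sorted_cluster, by_cluster = torch.sort(cluster[by_norm], stable=True)
    order = by_norm[by_cluster]

    cluster_ids, counts = torch.unique_consecutive(sorted_cluster, return_counts=True)
    last = torch.cumsum(counts, dim=0) - 1  # position of the last row in each run

    # Cluster ids that never occur keep an all-zero row, as in the loops above.
    y = torch.zeros((num_clusters, x.shape[1]), dtype=x.dtype, device=x.device)
    y[cluster_ids] = x[order[last]]
    return y
```

Calling y = max_norm_row_per_cluster(x, cluster) returns a tensor of size [cluster.max()+1, DIMS]. When several rows in a cluster share the same maximal norm, this sketch picks the one with the highest original index, which may differ from what torch.argmax picks in the loop versions.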