I am implementing the dilated k-nearest neighbours algorithm. My implementation relies on three nested Python loops, and these loops severely slow down execution.
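import torch
dilation = 3
nbd_size = 5
# Random toy data: 64 x 12 x 198 points, each with 100 neighbour indices
knn_key = torch.randint(0, 30, (64, 12, 198, 100))
# Integer dtype so the stored neighbour indices are not cast to float
dilated_keys = torch.zeros([knn_key.shape[0], knn_key.shape[1], knn_key.shape[2], nbd_size], dtype=torch.long)
for i in range(knn_key.shape[0]):
    for j in range(knn_key.shape[1]):
        for k in range(knn_key.shape[2]):
            list_indices = []
            # Rescans the row from the start if fewer than nbd_size matches were found
            # (never terminates if no neighbour passes the filter)
            while len(list_indices) < nbd_size:
                for l in range(knn_key.shape[3]):
                    # Keep neighbours whose index has the same residue mod dilation as k
                    if knn_key[i][j][k][l] % dilation == k % dilation:
                        list_indices.append(knn_key[i][j][k][l])
                        if len(list_indices) >= nbd_size:
                            break
            list_indices_tensor = torch.tensor(list_indices)
            dilated_keys[i][j][k] = list_indices_tensor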
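For a single row `knn_key[i][j][k]`, the `while`/`for` pair above can be restated with a boolean mask. This also makes one subtle behaviour explicit: when fewer than `nbd_size` neighbours pass the filter, the `while` loop rescans the row and repeats earlier matches. A minimal sketch, using a hypothetical helper `select_row` that is not part of the original code:

```python
import torch

def select_row(row: torch.Tensor, k: int, dilation: int, nbd_size: int) -> torch.Tensor:
    # Keep neighbour indices whose residue mod `dilation` equals that of the index k
    matches = row[row % dilation == k % dilation]
    if matches.numel() == 0:
        # The original while loop would never terminate on such a row
        raise ValueError("no neighbour passes the dilation filter")
    # The while loop restarts the scan when there are too few matches,
    # so output slot p receives match number p % len(matches)
    return matches[torch.arange(nbd_size) % matches.numel()]

row = torch.tensor([4, 9, 7, 10, 2, 13])
print(select_row(row, k=1, dilation=3, nbd_size=5).tolist())  # [4, 7, 10, 13, 4]
```

Here only 4, 7, 10 and 13 satisfy `value % 3 == 1 % 3`, so the fifth slot wraps around to 4.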
The variable knn_key stores, for each point (indexed by k), the indices of its 100 nearest neighbours among the 1000 data points originally available. dilated_keys stores the nbd_size=5 neighbour indices that are kept after applying the dilation filter. Any help replacing the three nested loops with a broadcasting solution would be much appreciated.
You can reduce the inner loops and conditions as follows:
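One possible fully vectorised sketch builds the dilation mask by broadcasting the index `k` against `knn_key`, then uses a stable sort and `torch.gather` to pick the first `nbd_size` matches in each row. `dilated_knn` is a hypothetical helper name. The sketch assumes PyTorch 1.9 or later (for `stable=True` in `torch.sort`), reproduces the wrap-around of the original `while` loop when a row has fewer than `nbd_size` matches, and fills rows with no match at all with -1 (an arbitrary sentinel; the loop version never terminates on such rows).

```python
import torch

def dilated_knn(knn_key: torch.Tensor, dilation: int, nbd_size: int) -> torch.Tensor:
    """Vectorised dilation filter (hypothetical helper, a sketch).

    knn_key: integer tensor of shape (B, H, N, K); dim 2 is the index k.
    Returns a long tensor of shape (B, H, N, nbd_size).
    """
    N = knn_key.shape[2]
    device = knn_key.device
    # Index k, shaped (1, 1, N, 1) so it broadcasts over B, H and K
    k_idx = torch.arange(N, device=device).view(1, 1, N, 1)
    # True where the neighbour index has the same residue mod dilation as k
    mask = (knn_key % dilation) == (k_idx % dilation)
    # Stable sort on key 0 (match) / 1 (no match): matching positions come
    # first and keep their original left-to-right order
    order = torch.sort((~mask).long(), dim=-1, stable=True).indices
    # Number of matches per row, shape (B, H, N, 1)
    counts = mask.sum(dim=-1, keepdim=True)
    # Output slot p takes match number p % count, mimicking the while loop's rescans
    slots = torch.arange(nbd_size, device=device).view(1, 1, 1, nbd_size)
    rank = slots % counts.clamp(min=1)
    # Match rank -> position within the row -> neighbour index
    positions = torch.gather(order, -1, rank)
    out = torch.gather(knn_key.long(), -1, positions)
    # Rows with no match at all get the sentinel -1
    return out.masked_fill(counts == 0, -1)

torch.manual_seed(0)
knn_key = torch.randint(0, 30, (64, 12, 198, 100))
dilated_keys = dilated_knn(knn_key, dilation=3, nbd_size=5)
print(dilated_keys.shape)  # torch.Size([64, 12, 198, 5])
```

This trades memory for speed: the mask, the sort keys and the sort indices are temporaries with the same shape as `knn_key` (roughly 120 MB each as int64 for the question's tensor), but no Python-level loop remains.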