pytorch split array by list of indices

I want to split a torch array by a list of indices.

For example, say my input array is torch.arange(20):

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])

and my list of indices is splits = [1,2,5,10]

Then my result would be:

(tensor([0]),
 tensor([1, 2]),
 tensor([3, 4, 5, 6, 7]),
 tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17]))

Assume my input array is always longer than the sum of my list of indices.

There are 4 answers

mozway On BEST ANSWER

You could use tensor_split on the cumulative sum of the splits (e.g. with np.cumsum) and drop the last chunk, which holds the leftover elements:

import torch
import numpy as np

t = torch.arange(20)
splits = [1,2,5,10]

t.tensor_split(np.cumsum(splits).tolist())[:-1]

Output:

(tensor([0]),
 tensor([1, 2]),
 tensor([3, 4, 5, 6, 7]),
 tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17]))
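
To make the cumulative-sum step explicit, here is a minimal sketch in plain Python. The helper name split_by_sizes is made up for illustration and is not part of PyTorch or NumPy. It turns the sizes into start/end boundaries with itertools.accumulate and slices between them. It is shown on a list, but it should behave the same way on any object that supports len() and slicing, which includes tensors.

```python
from itertools import accumulate


def split_by_sizes(seq, sizes):
    """Split a sliceable sequence into consecutive chunks of the given sizes.

    Hypothetical helper for illustration; leftover elements are dropped.
    """
    # Running totals mark where each chunk ends: [1, 2, 5, 10] -> [1, 3, 8, 18]
    ends = list(accumulate(sizes))
    if ends and ends[-1] > len(seq):
        raise ValueError("sizes sum to more than len(seq)")
    # Each chunk starts where the previous one ended: [0, 1, 3, 8]
    starts = [0] + ends[:-1]
    return tuple(seq[a:b] for a, b in zip(starts, ends))


chunks = split_by_sizes(list(range(20)), [1, 2, 5, 10])
print(chunks)
```

The boundaries [1, 3, 8, 18] are the same values that np.cumsum(splits) produces above. Elements 18 and 19 are left over, which is why the tensor_split version needs [:-1].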
Anthony Tac On

I used numpy for my example, but the same approach might work with tensors as well.

import numpy as np

array = np.array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])

def split(array, split_list: list[int]):
    if sum(split_list) > len(array):
        return None
    tup = ()
    for i in split_list:
        tup = tup + (array[:i].tolist(),)
        array = array[i:]
    return tup

print(split(array, [1, 2, 5]))

And I got this as output:

([0], [1, 2], [3, 4, 5, 6, 7])

I hope this is what you meant?

Timeless On

Another possible option would be to slice the tensor first, then split it:

import torch

t = torch.arange(20)

splits = [1, 2, 5, 10]

out = torch.split(t[: sum(splits)], splits)

Output:

(tensor([0]),
 tensor([1, 2]),
 tensor([3, 4, 5, 6, 7]),
 tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17]))
Siwar Ayachi On

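You can achieve this with PyTorch's split method by passing the whole list of sizes at once. The sizes must add up to the length of the tensor, so slice off the leftover elements first.

import torch

input_array = torch.arange(20)
splits = [1, 2, 5, 10]

# Calling input_array.split(s) once per size s would instead cut the whole
# tensor into equal chunks of size s, which is not the result asked for.
result = input_array[: sum(splits)].split(splits)
print(result)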