Is this iter() call redundant in the one-line return of the cs231n code?


This is code from cs231n. In the return line of the __iter__ method, can we replace return iter((self.X[i:i+B], self.y[i:i+B]) for i in range(0, N, B)) with return ((self.X[i:i+B], self.y[i:i+B]) for i in range(0, N, B))? After all, ((self.X[i:i+B], self.y[i:i+B]) for i in range(0, N, B)) is already a generator object. I don't really understand what iter() does here. Does it just call the generator object's __iter__ method, which returns the generator itself? I am confused.

import numpy as np

class Dataset(object):
    def __init__(self, X, y, batch_size, shuffle=False):
        """
        Construct a Dataset object to iterate over data X and labels y

        Inputs:
        - X: Numpy array of data, of any shape
        - y: Numpy array of labels, of any shape but with y.shape[0] == X.shape[0]
        - batch_size: Integer giving number of elements per minibatch
        - shuffle: (optional) Boolean, whether to shuffle the data on each epoch
        """
        assert X.shape[0] == y.shape[0], 'Got different numbers of data and labels'
        self.X, self.y = X, y
        self.batch_size, self.shuffle = batch_size, shuffle

    def __iter__(self):
        N, B = self.X.shape[0], self.batch_size
        idxs = np.arange(N)
        if self.shuffle:
            np.random.shuffle(idxs)
        # Note: idxs is never used below, so shuffle does not change the batches
        return iter((self.X[i:i+B], self.y[i:i+B]) for i in range(0, N, B))

train_dset = Dataset(X_train, y_train, batch_size=64, shuffle=True)
val_dset = Dataset(X_val, y_val, batch_size=64, shuffle=False)
test_dset = Dataset(X_test, y_test, batch_size=64)
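
To test the question directly, here is a minimal sketch comparing three ways of writing __iter__: the original iter(...) form, the bare generator expression, and a generator function using yield. Plain lists stand in for the NumPy arrays (an assumption so the sketch runs without NumPy, hence len() instead of .shape[0]), and the class names WithIter, WithoutIter and WithYield are hypothetical.

```python
# Sketch only: plain lists replace the NumPy arrays X and y (assumption),
# so len() is used instead of .shape[0]. Class names are hypothetical.

class WithIter:
    def __init__(self, X, y, batch_size):
        self.X, self.y, self.batch_size = X, y, batch_size

    def __iter__(self):
        N, B = len(self.X), self.batch_size
        # Original form: iter() wrapped around a generator expression
        return iter((self.X[i:i+B], self.y[i:i+B]) for i in range(0, N, B))


class WithoutIter(WithIter):
    def __iter__(self):
        N, B = len(self.X), self.batch_size
        # Same generator expression, returned directly without iter()
        return ((self.X[i:i+B], self.y[i:i+B]) for i in range(0, N, B))


class WithYield(WithIter):
    def __iter__(self):
        # __iter__ as a generator function: calling it returns a generator
        N, B = len(self.X), self.batch_size
        for i in range(0, N, B):
            yield self.X[i:i+B], self.y[i:i+B]


X = list(range(10))
y = [v * 10 for v in X]

batches = [list(cls(X, y, batch_size=4)) for cls in (WithIter, WithoutIter, WithYield)]
print(batches[0] == batches[1] == batches[2])  # True
print(batches[0][-1])                          # ([8, 9], [80, 90])
```

All three versions yield identical batches. The yield version is a common alternative because calling a generator function already returns an iterator, so no explicit iter() or generator expression is needed.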

Answer by Brian61354270:

Yes, the call to iter is redundant. You can check in a Python shell that calling iter on a generator object simply returns that same generator object:

>>> gen = (x for x in [])
>>> gen is iter(gen)
True
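
The same check can be extended to show where iter() does matter. This is a small illustrative sketch (the variable names are arbitrary): on an iterable that is not itself an iterator, such as a list, iter() creates a new iterator object, while on a generator it returns the object unchanged.

```python
# Compare what iter() does for a generator versus a list.
gen = (x * x for x in range(3))
print(iter(gen) is gen)     # True: a generator is its own iterator

nums = [0, 1, 4]
it = iter(nums)
print(it is nums)           # False: a list is iterable but not an iterator
print(type(it).__name__)    # list_iterator
print(iter(it) is it)       # True: the list_iterator returns itself

# A generator has only one iterator (itself), so it can be consumed once
print(list(gen))            # [0, 1, 4]
print(list(gen))            # []
```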

More formally, from the Python glossary you can verify that generator expressions return iterators, and that all iterators return themselves when their __iter__ method is invoked:

Iterators are required to have an __iter__() method that returns the iterator object itself so every iterator is also iterable and may be used in most places where other iterables are accepted.
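
A small hand-written iterator makes the glossary rule concrete. The Countdown class below is a hypothetical example, not part of cs231n: it implements __iter__ by returning self, which is the behavior generators provide automatically, so calling iter() on either one is a no-op.

```python
class Countdown:
    """Hypothetical hand-written iterator: __iter__ returns self, __next__ yields values."""

    def __init__(self, start):
        self.current = start

    def __iter__(self):
        # Required by the iterator protocol: an iterator returns itself
        return self

    def __next__(self):
        if self.current <= 0:
            raise StopIteration
        self.current -= 1
        return self.current + 1


cd = Countdown(3)
print(iter(cd) is cd)          # True, just like a generator
print(list(cd))                # [3, 2, 1]

gen = (n for n in range(3))
print(gen.__iter__() is gen)   # True: generators follow the same rule
```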