Leverage broadcasting to make this subtraction more efficient


I have an array x of shape (N, T, d). I have two functions f and g which both take an array of shape (some_dimension, d) and return an array of shape (some_dimension, ).

I would like to compute f on all of x. This is simple: f(x.reshape(-1, d)).
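The reshape trick works because `reshape(-1, d)` merges the N and T axes into a single axis of length N*T (in C order). Each row of the flattened array is then one `x[n, t, :]`, and reshaping the result back to `(N, T)` restores the original indexing. Below is a minimal sketch that checks this against an explicit double loop. It uses the toy `f` from the MWE below, and the loop is only there for verification, not as a suggested implementation:

```python
import numpy as np

def f(x):
    # Toy f from the question: maps an (M, d) array to an (M,) array
    return x[:, 0] + x[:, 1]

rng = np.random.default_rng(1234)
N, T, d = 4, 3, 2
x = rng.normal(size=(N, T, d))

# Vectorized: flatten (N, T) into one axis, apply f once, restore (N, T)
fx = f(x.reshape(-1, d)).reshape(N, T)

# Reference: evaluate f on one (n, t) pair at a time, as a (1, d) array
fx_loop = np.empty((N, T))
for n in range(N):
    for t in range(T):
        fx_loop[n, t] = f(x[n, t][None, :])[0]

print(np.allclose(fx, fx_loop))  # True
```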

I would then like to compute g only on the first slice of the second dimension, i.e. g(x[:, 0, :]), and subtract it from the evaluation of f at every index along the second dimension. This is shown in the code below:
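Written out as explicit loops, the intended result is diff[n, t] = f(x[n, t]) - g(x[n, 0]). The g term depends only on n, not on t. The following hypothetical reference implementation spells this out. `diff_loop` is an illustrative name, and the toy f and g are those from the MWE:

```python
import numpy as np

def f(x):
    return x[:, 0] + x[:, 1]

def g(x):
    return x[:, 0]**2 - x[:, 1]**2

def diff_loop(x):
    """Hypothetical slow reference: diff[n, t] = f(x[n, t]) - g(x[n, 0])."""
    N, T, d = x.shape
    out = np.empty((N, T))
    for n in range(N):
        g_n = g(x[n, 0][None, :])[0]  # g evaluated once per n, on t = 0 only
        for t in range(T):
            out[n, t] = f(x[n, t][None, :])[0] - g_n
    return out

x = np.arange(40).reshape(5, 4, 2)
print(diff_loop(x)[0])  # first row is 2, 6, 10, 14
```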

MWE - Inefficient Way

import numpy as np

# Reproducibility
seed = 1234
rng = np.random.default_rng(seed=seed)

# Generate x
N = 100
T = 10
d = 2
x = rng.normal(loc=0.0, scale=1.0, size=(N, T, d))

# In practice the functions are not this simple
def f(x):
    return x[:, 0] + x[:, 1]

def g(x):
    return x[:, 0]**2 - x[:, 1]**2

# Compute f on all the (flattened) array
fx = f(x.reshape(-1, d)).reshape(N, T)

# Compute g only on the first slice of second dimension. Here are two ways of doing so
gx = np.tile(g(x[:, 0])[:, None], reps=(1, T))
gx = np.repeat(g(x[:, 0]), axis=0, repeats=T).reshape(N, T)

# Finally compute what I really want to compute
diff = fx - gx

Is there a more efficient way? I feel there must be one using broadcasting, but I cannot figure it out.

Answer by hpaulj (best answer):

Reducing the size of the example so we can examine (5,4) arrays:

In [138]: 
     ...: # Generate x
     ...: N = 5
     ...: T = 4
     ...: d = 2
     ...: x = np.arange(40).reshape(N,T,d)  # instead of rng.normal(loc=0.0, scale=1.0, size=(N, T, d))
     ...: 
     ...: # In practice the functions are not this simple
     ...: def f(x):
     ...:     return x[:, 0] + x[:, 1]
     ...: 
     ...: def g(x):
     ...:     return x[:, 0]**2 - x[:, 1]**2
     ...: 
     ...: # Compute f on all the (flattened) array
     ...: fx = f(x.reshape(-1, d)).reshape(N, T)
     ...: 
     ...: # Compute g only on the first slice of second dimension. Here are two ways of doing so
     ...: gx1 = np.tile(g(x[:, 0])[:, None], reps=(1, T))
     ...: gx2 = np.repeat(g(x[:, 0]), axis=0, repeats=T).reshape(N, T)

In [139]: fx.shape,gx1.shape,gx2.shape
Out[139]: ((5, 4), (5, 4), (5, 4))

All the elements of fx differ, so fx itself cannot be replaced by a smaller array that broadcasts. Any saving has to come from gx.

In [140]: fx
Out[140]: 
array([[ 1,  5,  9, 13],
       [17, 21, 25, 29],
       [33, 37, 41, 45],
       [49, 53, 57, 61],
       [65, 69, 73, 77]])
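Here is a quick way to confirm that claim for the toy data. If every entry of `fx` is distinct, then no axis of `fx` is constant, so `fx` cannot be stored as a smaller array that broadcasts back to (5, 4). The following is a minimal check that reuses the toy `x` and `f` from the session:

```python
import numpy as np

def f(x):
    return x[:, 0] + x[:, 1]

N, T, d = 5, 4, 2
x = np.arange(40).reshape(N, T, d)
fx = f(x.reshape(-1, d)).reshape(N, T)

# Number of distinct values equals the number of entries, so all entries differ
print(np.unique(fx).size == fx.size)  # True

# An array constant along axis 1 (as gx is) would pass this test; fx does not
print(np.all(fx == fx[:, :1]))  # False
```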

Your tile and repeat versions do the same thing. np.tile is implemented with repeat, so it doesn't add anything:

In [141]: gx1
Out[141]: 
array([[ -1,  -1,  -1,  -1],
       [-17, -17, -17, -17],
       [-33, -33, -33, -33],
       [-49, -49, -49, -49],
       [-65, -65, -65, -65]])

In [142]: gx2
Out[142]: 
array([[ -1,  -1,  -1,  -1],
       [-17, -17, -17, -17],
       [-33, -33, -33, -33],
       [-49, -49, -49, -49],
       [-65, -65, -65, -65]])
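Both constructions materialize a full (N, T) array holding the N values of `g(x[:, 0])`. The sketch below shows that they agree, and that `np.repeat` along axis 1 of the (N, 1) column gives the same array. That third version is an equivalent spelling added here for comparison, not something from the question:

```python
import numpy as np

g0 = np.array([-1, -17, -33, -49, -65])  # the five g values for the toy data
N, T = g0.size, 4

gx1 = np.tile(g0[:, None], reps=(1, T))
gx2 = np.repeat(g0, axis=0, repeats=T).reshape(N, T)
gx3 = np.repeat(g0[:, None], T, axis=1)

print(np.array_equal(gx1, gx2), np.array_equal(gx2, gx3))  # True True

# Every value is physically stored T times
print(gx1.nbytes == N * T * g0.itemsize)  # True
```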

gx just repeats the 5 values of g(x[:, 0]) 4 times along the second axis:

In [143]: g(x[:, 0])
Out[143]: array([ -1, -17, -33, -49, -65])
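Every row of `gx` is a single scalar repeated, so NumPy can represent it without copying. `np.broadcast_to` returns a read-only view whose stride along the T axis is 0, which is the same stride-0 idea that arithmetic broadcasting relies on. Here is a minimal sketch:

```python
import numpy as np

g0 = np.array([-1, -17, -33, -49, -65])  # g(x[:, 0]) for the toy data
T = 4

gx_view = np.broadcast_to(g0[:, None], (g0.size, T))  # (5, 4) view, no copy
gx_copy = np.tile(g0[:, None], reps=(1, T))            # (5, 4) new array

print(np.array_equal(gx_view, gx_copy))  # True
print(gx_view.strides[1])                # 0: moving along T reuses the same element
print(np.shares_memory(gx_view, g0))     # True: the view reads g0's data
```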

In [144]: fx-gx1
Out[144]: 
array([[  2,   6,  10,  14],
       [ 34,  38,  42,  46],
       [ 66,  70,  74,  78],
       [ 98, 102, 106, 110],
       [130, 134, 138, 142]])

So gx can be replaced with a (5,1) array, which broadcasts with the (5,4) fx:

In [145]: fx-g(x[:,0])[:,None]
Out[145]: 
array([[  2,   6,  10,  14],
       [ 34,  38,  42,  46],
       [ 66,  70,  74,  78],
       [ 98, 102, 106, 110],
       [130, 134, 138, 142]])
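The sketch below applies this to the question's original random setup and checks it against the tile-based result. The name `g0` is illustrative, and `np.broadcast_shapes` requires NumPy 1.20 or later:

```python
import numpy as np

rng = np.random.default_rng(seed=1234)
N, T, d = 100, 10, 2
x = rng.normal(loc=0.0, scale=1.0, size=(N, T, d))

def f(x):
    return x[:, 0] + x[:, 1]

def g(x):
    return x[:, 0]**2 - x[:, 1]**2

fx = f(x.reshape(-1, d)).reshape(N, T)  # (N, T)
g0 = g(x[:, 0])                         # (N,): g evaluated only at t = 0

# Original approach: build an (N, T) copy of g0 before subtracting
diff_tiled = fx - np.tile(g0[:, None], reps=(1, T))

# Broadcast approach: (N, T) - (N, 1) -> (N, T), no (N, T) copy of g0
diff = fx - g0[:, None]

print(np.broadcast_shapes(fx.shape, g0[:, None].shape))  # (100, 10)
print(np.array_equal(diff, diff_tiled))                   # True
```

`g0[:, np.newaxis]` and `g0.reshape(-1, 1)` are equivalent spellings of `g0[:, None]`. Both versions call g on only N rows. The broadcast version saves just the (N, T) temporary, so any speedup will likely be small compared with the cost of non-trivial f and g.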

I haven't tried to make more sense of the T versus d dimensions that I commented on.

This answer may be too wordy, but it illustrates the way I visualized and discovered a broadcasting fix.