What is the recommended approach to implementing array behaviour/methods on irregular/inhomogeneous (ragged) data, which possesses some inherent dimensionality, within JAX?
Two principal options come to mind:
- make homogeneous and use a mask
- flatten and implement custom methods (e.g. broadcasting and reduction)
Clearly option 1 is favourable, as it requires less implementation overhead (and consequently less validation/testing). The concern is memory complexity: in situations where memory is paramount (to avoid having to distribute the array), is there a better alternative to option 2 that can still exploit the highly optimised array methods?
EDIT: The following implements a concrete example that contains sparsity.
import jax as jx
import jax.numpy as jnp
jx.config.update("jax_enable_x64", True)
# Problem specific variables (static)
n_vars = 3 # Number of variable sets
n_smps = 10 # Maximum number of set elements
p_smps = 0.2 # Representation of problem sparsity
# Each set contains a differing number of elements (binomial random for example)
n_lvls = jx.random.bernoulli(
    jx.random.PRNGKey(0),
    p_smps,
    (n_vars, n_smps)
).sum(axis=1, dtype='i4')
# Derived quantities depend on constant coefficients (uniform random for example)
a_vars = jx.random.uniform(jx.random.PRNGKey(1), (n_vars, ), dtype='f8')
b_vars = jx.random.uniform(jx.random.PRNGKey(2), (n_vars, ), dtype='f8')
b_vars = 10.0*b_vars
c_vars = jx.random.uniform(jx.random.PRNGKey(3), (n_vars, ), dtype='f8')
c_vars = 2.0*c_vars
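A quick sanity check of the ragged structure (the printed total should match the 7-element state mentioned below):
# Inspect the ragged structure
print(n_lvls)        # per-set element counts
print(n_lvls.sum())  # 7 active elements in total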
The problem is intrinsically represented with a 7-element state. What follows is one implementation of option 1:
### Homogeneous with mask ###
# Define the level index array
i_smps = jnp.arange(n_smps, dtype='i4')
mask = n_lvls[:,None]>i_smps[None,:]
# Generate an initial state that respects the unity axiom
x_vars = 1.0/(1.0+n_lvls[:,None]*i_smps[None,:]).astype('f8')
x_vars = jnp.where(mask, x_vars, 0.0)
x_vars = x_vars/x_vars.sum()
# Generate a coefficient tensor
P_vars = a_vars[:,None]+b_vars[:,None]*i_smps[None,:]
P_vars = jnp.where(mask, P_vars, 0.0)
# Determine a scalar moment
scalar_moment = (x_vars*c_vars[:,None]).sum()
# >>> DeviceArray(0.66574861, dtype=float64)
# Determine a transition tensor
trans_tens = (P_vars[:,:,None,None]-P_vars[None,None,:,:])
trans_tens = trans_tens*x_vars[None,None,:,:]*x_vars[:,:,None,None]
trans_tens.sum(axis=(2,3))
# >>> DeviceArray([[-0.37032842, 0.16153429, 0.22063015, 0.24335933, ...
Ensuring homogeneity increases the state size from 7 elements to 30. Furthermore, computing derived quantities involves numerous multiply-by-zero operations: the padded transition tensor holds 30*30 = 900 entries, of which only 7*7 = 49 are structurally non-zero.
One approach is to flatten and remove the sparsity using Boolean indexing, sketched below. There may be a more memory-efficient way of applying Boolean indexing to determine the index arrays i and j.
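A minimal sketch of such a flattened implementation, assuming the setup above (the intermediate names x_flat, P_flat, trans_flat and trans_rows are placeholders, and i and j are the Boolean-index arrays referred to above):

### Flattened with Boolean indexing (sketch) ###
# Index arrays of the active (set, element) pairs; output shapes are data-dependent
i, j = jnp.where(mask)
# Flat 7-element state and coefficients, with no padding
x_flat = x_vars[i, j]
P_flat = P_vars[i, j]
# Scalar moment: gather the per-set coefficient onto the flat axis
scalar_moment_flat = (x_flat*c_vars[i]).sum()
# >>> agrees with scalar_moment from the masked implementation
# Transition tensor over the flat axis only: 7*7 entries instead of 30*30
trans_flat = (P_flat[:,None]-P_flat[None,:])*x_flat[None,:]*x_flat[:,None]
row_moment = trans_flat.sum(axis=1)
# Scatter back into the padded layout to compare with trans_tens.sum(axis=(2,3))
trans_rows = jnp.zeros((n_vars, n_smps)).at[i, j].set(row_moment)

One caveat: jnp.where applied to a mask (equivalently jnp.nonzero) has a data-dependent output shape, so this section cannot be wrapped in jax.jit as written; under jit a static size would have to be supplied (e.g. via the size argument of jnp.nonzero), which reintroduces a form of padding.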