Based on this answer, I am trying to make a class jit-compatible by registering it as a pytree node, but I get:
TypeError: Cannot interpret value of type <class '__main__.TestModel'> as an abstract array; it does not have a dtype attribute
The error is raised inside the fit method, on the call to self.step.
Is there anything wrong with my implementation?
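For context, my understanding of the pattern from the linked answer is that registering a class with register_pytree_node_class lets jit flatten and rebuild instances of it. A toy container like this (a minimal sketch just to show the pattern I'm following, not my actual model) does work under jit:

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Box:
    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # dynamic leaves (traced arrays) and static aux_data
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@jax.jit
def double(box):
    return Box(box.value * 2)

print(double(Box(jnp.ones(3))).value)  # fine: Box is a registered pytree

My attempt to do the same thing on top of flax.linen.Module is below.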
import jax
import flax.linen as nn
import optax
from jax.tree_util import register_pytree_node_class
from dataclasses import dataclass
from typing import Callable
def data_loader(X, Y, batch_size):
    for i in range(0, len(X), batch_size):
        yield X[i : i + batch_size], Y[i : i + batch_size]
@register_pytree_node_class
@dataclass
class Parent(nn.Module):
    key: jax.random.PRNGKey
    params: dict = None

    @jax.jit
    def step(self, loss_fn, optimizer, opt_state, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(y, self.predict(x))
        opt_grads, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(self.params, opt_grads)
        return params, opt_state, loss

    @jax.jit
    def predict(self, x):
        return self.apply(self.params, x)

    def fit(
        self,
        X,
        Y,
        optimizer: Callable,
        loss: Callable,
        batch_size=32,
        epochs=10,
        verbose=True,
    ):
        opt_state = optimizer.init(self.params)
        self.params = self.init(self.key, X)
        history = []
        for i in range(epochs):
            epoch_loss = 0
            for x, y in data_loader(X, Y, batch_size):
                self.params, opt_state, loss_value = self.step(
                    loss, optimizer, opt_state, x, y
                )
                epoch_loss += loss_value
            history.append(epoch_loss / (len(X) // batch_size))
            if verbose:
                print(f"Epoch {i+1}/{epochs} - loss: {history[-1]}")
        return history

    def tree_flatten(self):
        return (self.params,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, aux_data)
class TestModel(Parent):
    d_hidden: int = 64
    d_out: int = 1

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.d_hidden)(x)
        x = nn.relu(x)
        x = nn.Dense(self.d_out)(x)
        x = nn.sigmoid(x)
        return x
x_train = jax.random.normal(jax.random.PRNGKey(0), (209, 12288))
y_train = jax.random.randint(jax.random.PRNGKey(0), (209, 1), 0, 2)
model = TestModel(key=jax.random.PRNGKey(0))
model.fit(
    x_train,
    y_train,
    optimizer=optax.adam(1e-3),
    loss=optax.sigmoid_binary_cross_entropy,
)