GPU allocation explodes when logging a scalar to Flax's TensorBoard


I noticed that when I use Flax's TensorBoard writer (from flax.metrics import tensorboard) to log the loss, the GPU memory allocation explodes.
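For reference, this is roughly how I set up the writer (the log directory is just a placeholder):

from flax.metrics import tensorboard

workdir = "./logs"  # placeholder log directory
writer = tensorboard.SummaryWriter(workdir)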

To compute the loss metrics I use the has_aux option of jax.grad, as explained here.
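As a minimal illustration of that pattern (a toy loss function, not my actual model):

import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    pred = w * x
    loss = jnp.mean((pred - y) ** 2)
    aux = {"rec": loss}  # auxiliary values returned alongside the loss
    return loss, aux

# has_aux=True makes jax.grad return (gradients, aux) instead of just gradients
grads, aux = jax.grad(loss_fn, has_aux=True)(2.0, jnp.ones(3), jnp.zeros(3))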

These are the functions I am using. In the model class:

from functools import partial
import jax
import jax.numpy as jnp
from jax import jit

@partial(jit, static_argnums=(0,))
def compute_losses(self, params, x_train, y_train) -> tuple:
    # data-fit loss
    y_pred = self.forward(params, x_train)
    rec_loss = jnp.mean(jnp.abs(y_pred - y_train) ** 2)
    # regularizations
    # ...
    loss_dict = {"rec": rec_loss}
    # total loss is the sum of all terms (only the data-fit term is shown here)
    loss = sum(loss_dict.values())
    return loss, loss_dict

@partial(jit, static_argnums=(0,))
def grad_loss(self, params, x_rec, y_rec):
    # has_aux=True: compute_losses returns (loss, aux), so grad returns (grads, aux)
    return jax.grad(self.compute_losses, has_aux=True)(params, x_rec, y_rec)

In the main training loop:

# one optimization step
grads, loss_dict = model.grad_loss(params, input, target)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
# log loss
if step % config.logging.log_every_steps == 0 and config.logging.log_loss:
    # recompute the individual loss terms (only the aux dict is needed here)
    _, loss_dict = model.compute_losses(params, input, target)
    for term, value in loss_dict.items():
        writer.scalar(f'loss/{term}', value, step)
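One detail I am unsure about: in the loop above, writer.scalar receives the JAX device arrays directly. Here is a sketch of the same logging step with the values explicitly pulled to the host first (using jax.device_get; I have not verified whether this changes anything):

# variant: pull the scalars to host memory before handing them to the writer
if step % config.logging.log_every_steps == 0 and config.logging.log_loss:
    _, loss_dict = model.compute_losses(params, input, target)
    host_losses = jax.device_get(loss_dict)  # device arrays -> NumPy values on host
    for term, value in host_losses.items():
        writer.scalar(f'loss/{term}', float(value), step)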

Without the writer, GPU consumption stays around 25%. With logging enabled, it goes up to 100%. Why is this happening?

EDIT: This does not happen with from torch.utils import tensorboard. You can see the GPU consumption here.
