How to release CPU memory in pytorch? (for large-scale inference)

I'm trying to do large-scale inference with a pretrained BERT model on a single machine and I'm running into CPU out-of-memory errors. Since the dataset is too big to score in one pass, I run inference in batches, store the per-batch results in a list, and concatenate those tensors at the end. I understand that storing tensors in lists can quickly use up large amounts of CPU memory.

However, I can't figure out how to release this memory after the tensors are concatenated, so I'm running into OOM errors downstream.
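Conceptually, the pattern I'm describing is roughly this (a simplified sketch that only keeps the last hidden state; it reuses the model and tokens_dataloader defined in the MRE below):

import torch

# Score batch by batch, keep each result on CPU, concatenate once at the end.
cpu_outputs = []
with torch.inference_mode():
    for batch in tokens_dataloader:
        out = model(batch)                               # batch is already a CUDA tensor here
        cpu_outputs.append(out.last_hidden_state.cpu())  # move each batch's output to CPU
all_outputs = torch.cat(cpu_outputs, dim=0)              # this accumulated list is what eats CPU memory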

Minimal Reproducible Example:

import gc, time, torch, pytorch_lightning as pl
from transformers import BertTokenizer, BertModel
from torch.utils.data import DataLoader

class EncoderModelPL(pl.LightningModule):
    def __init__(
        self,
        model: BertModel,
    ):
        super(EncoderModelPL, self).__init__()
        self.model: BertModel = model

    def forward(self, x):
        return self.model(x, output_hidden_states=True) # yes, I know this uses a ton of memory but my downstream application requires intermediate hidden states
    
MODEL_ID = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(MODEL_ID)
model = EncoderModelPL(BertModel.from_pretrained(MODEL_ID)).to("cuda")

dataset = torch.randint(low=0, high=30000, size=(800, 200), device="cuda")
tokens_dataloader = DataLoader(dataset, batch_size=32, shuffle=False) 

trainer = pl.Trainer(accelerator="gpu")
bert_outputs_per_batch: list = trainer.predict(
    model=model, dataloaders=tokens_dataloader
) # CPU memory steadily increases here; this returns a list (one entry per batch) of BERT outputs, stored on CPU
del bert_outputs_per_batch
gc.collect()

...<more downstream stuff>...

A few additional notes on what I've tried (sketched after this list):

  • As shown in the code, I've tried deleting the list of outputs and then garbage collecting.
  • I also tried sleeping for a few seconds after that in case it helps the garbage collection (it doesn't).
  • I don't believe this is caused by tracking the computation graph; I see the same behavior if I add .detach() to dataset.
  • I've tried pin_memory=False in the DataLoader; it appears to make no difference.
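For reference, roughly what those attempts looked like (a sketch, not my exact code; it reuses the names from the MRE above):

import gc, time
from torch.utils.data import DataLoader

# 1) delete the list of outputs, force garbage collection, then wait a bit
del bert_outputs_per_batch
gc.collect()
time.sleep(5)

# 2) detach the input tensor so no autograd graph can be retained
dataset = dataset.detach()

# 3) rebuild the DataLoader with pinned host memory explicitly disabled
tokens_dataloader = DataLoader(dataset, batch_size=32, shuffle=False, pin_memory=False)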

In summary, my questions are:

  • How can I force the release of tensors from CPU memory?
  • More broadly, is there a better (e.g. more memory-efficient) approach to large-scale inference than batching?

Thanks for the help!

1 Answer

blue_lama:

The memory used is a constant plus something proportional to the batch size:
memory = a + (b * batch_size)

First try with batch_size=1; if that works, try 2, 4, 8.
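For example, you could sweep batch sizes and watch the process's resident CPU memory (a rough sketch; psutil is just an assumed way to read RSS, and dataset/model come from the question):

import gc
import psutil
import pytorch_lightning as pl
from torch.utils.data import DataLoader

def rss_mb():
    # resident CPU memory of this process, in MB
    return psutil.Process().memory_info().rss / 1e6

for batch_size in (1, 2, 4, 8):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    trainer = pl.Trainer(accelerator="gpu")
    outputs = trainer.predict(model=model, dataloaders=loader)
    print(f"batch_size={batch_size}: RSS after predict = {rss_mb():.0f} MB")
    del outputs
    gc.collect()
    print(f"batch_size={batch_size}: RSS after cleanup = {rss_mb():.0f} MB")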