PyTorch: redundancy between map_location and .to(device)?


The PyTorch documentation on torch.save and torch.load contains the following section:

Save on CPU, Load on GPU

Save:

torch.save(model.state_dict(), PATH)

Load:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

When loading a model on a GPU that was trained and saved on CPU, set the map_location argument in the torch.load() function to cuda:device_id. This loads the model to a given GPU device. Next, be sure to call model.to(torch.device('cuda')) to convert the model’s parameter tensors to CUDA tensors.

Isn't calling .to(device) in this example redundant, since loading with map_location should already have placed the tensors on the GPU?

Moreover, the text says "Next, be sure to call model.to(torch.device('cuda')) to convert the model’s parameter tensors to CUDA tensors". But wasn't that conversion already done during load with map_location?


1 Answer

Answered by Ivan (score 3):

Using torch.load with map_location set to a CUDA device will load the state dict's tensors onto the desired device. However, when load_state_dict is called, those tensors are copied into the model's existing parameter tensors, which remain on whatever device the model is already on (in this case, the CPU). If you test your code without the model.to(device), you will notice your model is, in fact, not on the GPU: the loaded state was on the GPU, but the model never was. You can read more here.
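The copy-in-place behavior can be illustrated with a torch-free sketch (plain-Python stand-ins for tensors and modules; FakeTensor, FakeModule, and the device strings are hypothetical, not real PyTorch objects): like torch.Tensor.copy_, loading a state dict copies the destination's values but never changes its device.

```python
# Conceptual sketch of why load_state_dict does not move the model:
# values are copied *into* the model's existing parameter storage,
# so each parameter keeps whatever device it was created on.

class FakeTensor:
    def __init__(self, data, device):
        self.data = list(data)
        self.device = device

    def copy_(self, other):
        # Copies values only; the destination's device is unchanged,
        # mirroring torch.Tensor.copy_ across devices.
        self.data = list(other.data)
        return self

class FakeModule:
    def __init__(self):
        # Freshly constructed parameters live on the CPU.
        self.params = {"weight": FakeTensor([0.0, 0.0], device="cpu")}

    def load_state_dict(self, state):
        for name, tensor in state.items():
            self.params[name].copy_(tensor)

# A state dict loaded with map_location="cuda:0" would hold GPU tensors...
state = {"weight": FakeTensor([1.5, -2.0], device="cuda:0")}

model = FakeModule()
model.load_state_dict(state)

print(model.params["weight"].data)    # values were copied in: [1.5, -2.0]
print(model.params["weight"].device)  # ...but the model is still on "cpu"
```

This is why an explicit model.to(device) is still needed after loading.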

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

This means you are still required to call model.to(device).

You can avoid that by first transferring the model to the CUDA device and then loading your state. But the overall number of transfers to the GPU is the same, two: first the initialized model, then the loaded state.

model = TheModelClass(*args, **kwargs).to(device)
model.load_state_dict(torch.load(PATH, map_location=device))

You can reduce the number of transfers to one by first loading the state on the CPU and only then transferring the model to the GPU:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location='cpu'))
model.to(device)