I'm working on a machine learning project in PyTorch where I need to optimize a model using the full batch gradient descent method. The key requirement is that the optimizer should use all the data points in the dataset for each update. My challenge with the existing torch.optim.SGD optimizer is that it doesn't inherently support using the entire dataset in a single update. This is crucial for my project as I need the optimization process to consider all data points to ensure the most accurate updates to the model parameters.
Additionally, I would like to retain the use of Nesterov momentum in the optimization process. I understand that one could set the batch size equal to the size of the entire dataset, simulating a full-batch update with the SGD optimizer. However, I'm interested in whether there's a more elegant or direct way to implement a true Gradient Descent optimizer in PyTorch that also supports Nesterov momentum.
Ideally, I'm looking for a solution or guidance on how to implement or configure an optimizer in PyTorch that meets the following criteria (a sketch of my current setup follows the list):
- Utilizes the entire dataset for each parameter update (true Gradient Descent behavior).
- Incorporates Nesterov momentum for more efficient convergence.
- Is compatible with the rest of the PyTorch ecosystem by subclassing torch.optim.Optimizer.
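For reference, this is roughly the optimizer setup I'm using now and would like to keep, apart from making each update full-batch (the model and hyperparameters below are just placeholders):

```python
import torch

model = torch.nn.Linear(10, 1)  # placeholder model

# Current setup: stock SGD with Nesterov momentum enabled.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)
```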
The PyTorch SGD implementation is actually independent of the batching! It only uses the gradients that were computed in the backward pass and stored in each parameter's .grad attribute, so the batch size used for the forward/backward computations and the batch size used for the optimizer step are decoupled. You can now either:
a) Put all your samples through your model as one big batch by setting the batch size to the dataset size, or
b) Accumulate the gradients over many smaller batches before doing a single optimizer step. Sketches of both options follow.
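Below is a minimal sketch of both options, assuming a tiny placeholder model, dataset, and loss (all names here are illustrative, not taken from your code). Both reuse the stock torch.optim.SGD with nesterov=True; in (b), each mini-batch loss is scaled so that the accumulated gradient equals the gradient of the mean loss over the whole dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder data, model, and loss -- substitute your own.
dataset = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))
model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()

# Stock SGD with Nesterov momentum; no custom optimizer is needed.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)

# --- Option (a): one batch that covers the whole dataset ---
full_loader = DataLoader(dataset, batch_size=len(dataset))
for epoch in range(10):
    for inputs, targets in full_loader:   # a single iteration per epoch
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()                  # full-batch update

# --- Option (b): accumulate gradients over smaller batches ---
small_loader = DataLoader(dataset, batch_size=64)
for epoch in range(10):
    optimizer.zero_grad()                 # clear .grad once per full pass
    for inputs, targets in small_loader:
        # Scale so the summed gradients equal the gradient of the
        # mean loss over the entire dataset.
        loss = criterion(model(inputs), targets) * inputs.size(0) / len(dataset)
        loss.backward()                   # gradients accumulate in .grad
    optimizer.step()                      # one update from the full-dataset gradient
```

Up to floating-point differences, both variants perform the same full-batch update per epoch; (b) is mainly useful when the whole dataset doesn't fit through the model in a single forward pass.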