Finetune Sentence Transformer using Huggingface Trainer API


I want to fine-tune a Sentence Transformer (for example, MPNet) using contrastive learning. Is it possible to use the Hugging Face Trainer API for this? If so, how? Any suggestions would be appreciated.

I went through the Trainer API docs and examples but could not find any support for contrastive learning or for fine-tuning embeddings. Most of the documentation focuses on classification models.


There are 2 answers

Fahad Ebrahim On

The Hugging Face Trainer is geared toward standard Transformers models, such as classifiers. To train embeddings, refer to the training overview on the Sentence Transformers website:

https://www.sbert.net/docs/training/overview.html
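As a side note, recent sentence-transformers releases (v3.0 and later, to my knowledge) include a `SentenceTransformerTrainer` that subclasses the Hugging Face `Trainer`, so the Trainer workflow can be used for contrastive fine-tuning. Below is a minimal sketch, not a definitive recipe. The model name, the column names `sentence1`/`sentence2`, the hyperparameters, and the helper names `build_pair_columns` and `train_with_trainer` are my own assumptions for illustration; check them against the version you have installed.

```python
def build_pair_columns(pairs):
    """Turn (sentence1, sentence2, label) tuples into column lists.

    Hypothetical helper: the Trainer treats the "label" column as the
    target and the remaining columns, in order, as the model inputs.
    """
    return {
        "sentence1": [s1 for s1, _, _ in pairs],
        "sentence2": [s2 for _, s2, _ in pairs],
        "label": [int(label) for _, _, label in pairs],
    }


def train_with_trainer(pairs,
                       model_name="sentence-transformers/all-mpnet-base-v2",
                       output_dir="output_folder"):
    """Fine-tune with ContrastiveLoss through SentenceTransformerTrainer."""
    # Imported lazily so build_pair_columns works without these packages.
    from datasets import Dataset
    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainer,
        SentenceTransformerTrainingArguments,
        losses,
    )

    model = SentenceTransformer(model_name)  # assumed MPNet-based checkpoint
    train_dataset = Dataset.from_dict(build_pair_columns(pairs))

    # Assumed hyperparameters; these are ordinary TrainingArguments fields.
    args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=16,
        learning_rate=2e-5,
        warmup_ratio=0.1,
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=losses.ContrastiveLoss(model),  # labels: 1 = similar, 0 = dissimilar
    )
    trainer.train()
    model.save(f"{output_dir}/final")
    return model
```

Calling `train_with_trainer([("A man is eating.", "A man eats food.", 1), ("A man is eating.", "A plane takes off.", 0)])` would download the checkpoint and run one epoch; in practice you would pass many more pairs.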

Halim Saad-Rached On

Here's the documentation for the SBERT contrastive loss. I used this loss function to fine-tune a Hugging Face Sentence Transformer for semantic textual similarity on sentence pairs. Below is a code excerpt:

    from sentence_transformers import SentenceTransformer, SentencesDataset, losses, evaluation
    from sentence_transformers.readers import InputExample
    import torch
    from torch.utils.data import DataLoader

    #model=...
    #train_examples=...
    #val_examples=...

    batch_size = 64
    epochs = 10
    evaluation_steps = 32
    optimizer_params = {'lr': 2e-5}
    weight_decay = 0.05

    train_dataset = SentencesDataset(train_examples, model)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    # Warm up over 10% of the total training steps (batches per epoch * epochs),
    # not 10% of the number of examples.
    warmup_steps = int(len(train_dataloader) * epochs * 0.1)
    # ContrastiveLoss expects labels of 1 (similar pair) or 0 (dissimilar pair).
    train_loss = losses.ContrastiveLoss(model=model)
    evaluator = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(val_examples)
    optimizer = torch.optim.AdamW

    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=epochs,
        evaluator=evaluator,
        evaluation_steps=evaluation_steps,
        warmup_steps=warmup_steps,
        optimizer_class=optimizer,
        optimizer_params=optimizer_params,
        weight_decay=weight_decay,
        output_path='output_folder',
        save_best_model=True,
        show_progress_bar=True,
    )
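To make it clearer what the loss above optimizes, here is a small standard-library sketch of the contrastive objective as I understand it from the SBERT docs: similar pairs (label 1) are pulled together, and dissimilar pairs (label 0) are pushed apart until their distance exceeds a margin. It assumes cosine distance and a margin of 0.5, which I believe are the library defaults. The function names are hypothetical, and this is an illustration rather than the library's implementation.

```python
import math


def cosine_distance(u, v):
    """Return 1 - cosine similarity of two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def contrastive_loss(pairs, margin=0.5):
    """Mean contrastive loss over (embedding_a, embedding_b, label) triples.

    label 1: penalize the squared distance between the embeddings.
    label 0: penalize only pairs that are closer than `margin`.
    """
    total = 0.0
    for u, v, label in pairs:
        d = cosine_distance(u, v)
        if label == 1:
            total += 0.5 * d ** 2
        else:
            total += 0.5 * max(0.0, margin - d) ** 2
    return total / len(pairs)
```

With this definition, a dissimilar pair stops contributing to the loss once its cosine distance reaches the margin, so the margin controls how far apart negatives are pushed.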