I'm currently trying to run a distributed training job on AWS SageMaker. I'm using the AWS SageMaker model parallelism (SMP) library to parallelize my model. For the training I created the following script:
import datetime

import torch.distributed as dist

# Initialize the NCCL process group (with a generous timeout) before bringing
# up the SageMaker model parallelism (SMP) runtime below.
dist.init_process_group("nccl", timeout=datetime.timedelta(seconds=7200))

import torch.sagemaker as tsm

tsm.init()
from torch.utils.data import Dataset
import torch
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
)
from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import argparse
import sys
import os
import logging
import matplotlib.pyplot as plt
import pandas as pd
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--per_device_train_batch_size", type=int, default=32)
    parser.add_argument("--model_name", type=str, default="codellama/CodeLlama-7b-hf")
    parser.add_argument("--learn_rate", type=str, default="3e-4")
    parser.add_argument("--warmup_steps", type=int, default=400)
    parser.add_argument("--epochs", type=int, default=1)

    # Data, model and output directories
    parser.add_argument("--output_data_dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])
    parser.add_argument("--training_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
    parser.add_argument("--test_dir", type=str, default=os.environ["SM_CHANNEL_TEST"])

    args, _ = parser.parse_known_args()

    # Set up logging
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        load_in_8bit=True,
        torch_dtype=torch.float16,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.add_eos_token = True
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"
    class MyDataset(Dataset):
        def __init__(self, train=True):
            if train:
                self.dataset = load_from_disk(dataset_path=args.training_dir)
            else:
                self.dataset = load_from_disk(dataset_path=args.test_dir)

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            return self.dataset[idx]

    train_dataset = MyDataset(train=True)
    val_dataset = MyDataset(train=False)
    model.train()  # put model back into training mode
    model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=16,
        lora_alpha=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    if torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 GPU is available
        model.is_parallelizable = True
        model.model_parallel = True
    training_args = TrainingArguments(
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_train_batch_size,
        gradient_checkpointing=True,
        warmup_steps=args.warmup_steps,
        learning_rate=float(args.learn_rate),
        bf16=True,
        logging_steps=1,
        optim="adamw_torch",
        evaluation_strategy="no",
        save_strategy="steps",
        # eval_steps=20,
        save_steps=20,
        output_dir=args.model_dir,
        load_best_model_at_end=False,
        group_by_length=True,  # group sequences of roughly the same length together to speed up training
        log_level="debug",
        logging_dir=f"{args.output_data_dir}/logs",
        ddp_find_unused_parameters=False,
        num_train_epochs=args.epochs,
    )
    def compute_perplexity(pred):
        # pred.predictions and pred.label_ids arrive as numpy arrays, so convert
        # them to tensors before flattening.
        logits = torch.as_tensor(pred.predictions)
        labels = torch.as_tensor(pred.label_ids)
        # Flatten the logits and labels to compute cross-entropy loss,
        # ignoring the -100 padding labels inserted by the data collator.
        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1)
        loss = torch.nn.functional.cross_entropy(logits, labels, ignore_index=-100)
        # Perplexity is the exponential of the cross-entropy loss
        perplexity = torch.exp(loss)
        return {"perplexity": perplexity.item()}
    trainer = Trainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        args=training_args,
        compute_metrics=compute_perplexity,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )

    trainer.train()
The command that SageMaker executes for this training job is the following:
torchrun --nnodes 1 --nproc_per_node 8 train.py --epochs 1 --learn_rate 0.0003 --model_name codellama/CodeLlama-7b-hf --mp_parameters activation_loading_horizon=2,hybrid_shard_degree=8,sm_activation_offloading=False,tensor_parallel_degree=1 --per_device_train_batch_size 4 --warmup_steps 100
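In case the launch configuration matters: the job is started from a SageMaker PyTorch estimator with the SMP v2 distribution settings that produce the mp_parameters above. The snippet below is only a simplified sketch of that setup; the role, instance type, framework/Python versions and S3 URIs are placeholders rather than my exact values:

from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="train.py",
    role="<execution-role-arn>",      # placeholder
    instance_type="ml.p4d.24xlarge",  # placeholder: a single 8-GPU instance
    instance_count=1,
    framework_version="2.0.1",        # placeholder
    py_version="py310",               # placeholder
    hyperparameters={
        "epochs": 1,
        "learn_rate": 0.0003,
        "model_name": "codellama/CodeLlama-7b-hf",
        "per_device_train_batch_size": 4,
        "warmup_steps": 100,
    },
    distribution={
        "torch_distributed": {"enabled": True},
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                "parameters": {
                    "hybrid_shard_degree": 8,
                    "tensor_parallel_degree": 1,
                    "activation_loading_horizon": 2,
                    "sm_activation_offloading": False,
                },
            },
        },
    },
)

# "train" and "test" channels map to SM_CHANNEL_TRAIN / SM_CHANNEL_TEST in the script
estimator.fit({"train": "<s3-train-uri>", "test": "<s3-test-uri>"})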
But for some reason, when trainer.train() is called, the training starts TWICE simultaneously. Do you know how I can change my script so that it runs only ONCE?