NLP Transformers Machine Translation


I am training a Transformer for machine translation in PyTorch. The training loss decreases, but the model fails at validation / inference. I took only one example to check whether the model architecture is sound, and I have also tested on a larger dataset with the same result.

I suspect the problem is in the decoder's greedy decoding or in the tgt mask.
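For reference, here is a minimal sketch of how I understand a causal (look-ahead) tgt mask is normally constructed in PyTorch. It is separate from my model code below; the variable names are only illustrative, and 20 is just the length of my tokenized target:

import torch

# causal mask: -inf strictly above the diagonal, 0 elsewhere,
# so position i can only attend to positions <= i
tgt_len = 20  # illustrative; matches the length of my tokenized target below
float_mask = torch.triu(torch.full((tgt_len, tgt_len), float('-inf')), diagonal=1)

# equivalent boolean form: True marks positions that must not be attended to
bool_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)

# either form can be passed as the tgt_mask argument of nn.TransformerDecoder
# (recent PyTorch versions also provide nn.Transformer.generate_square_subsequent_mask)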

import torch
import torch.nn as nn
import random
import math
random.seed(1)
eng = 'this is transformer'
hin = 'यह ट्रांसफार्मर है'

# character-level vocabularies; indices start at 1 so that 0 stays free for the sos / ignore index
eng_dict = {key:val for val,key in enumerate(set(eng),start=1)}
hin_dict = {key:val for val,key in enumerate(set(hin),start=1)}

# fixed sinusoidal positional encoding
class PositionalEncoding1D(nn.Module):

    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros((max_len, d_model), requires_grad=False)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        _, T, _ = x.shape
        return x + self.pe[:, :T]

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.vocab_size = len(eng_dict) + 1  # +1 for the reserved index 0
        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(128,4,512,0.1),6)
        self.embeddings = nn.Embedding(self.vocab_size,embedding_dim = 128)
        self.posEmb = PositionalEncoding1D(128,64)

    def forward(self,src):
        src = self.posEmb(self.embeddings(src))
        x = self.encoder(src)
        return x
        

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.vocab_size = len(hin_dict) + 2  # +2 for the sos (0) and eos (13) ids
        self.decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(128,4,512,0.1,batch_first=True),6)
        self.embeddings = nn.Embedding(self.vocab_size,embedding_dim = 128)
        self.posEmb = PositionalEncoding1D(128,64)
        self.lin = nn.Linear(128,self.vocab_size)

    def forward_train(self,memory,tgt):
        tgt = self.posEmb(self.embeddings(tgt))
        tgt_mask = torch.triu(torch.ones(20,20),diagonal=1)  # upper-triangular 0/1 float mask over the 20 target positions
        x = self.decoder(tgt,memory,tgt_mask)
        x = self.lin(x)
        return x
    
    def forward_infer(self,memory):
        tgt = self.posEmb(self.embeddings(torch.tensor([[0]],dtype=torch.int32)))  # start from the sos token (id 0)
        preds = []
        for i in range(20):
            x = self.decoder(tgt,memory)
            x = self.lin(x[:,-1,:]) # greedy
            _,pr = torch.max(x,dim=1)
            preds.append(pr)
            # append the new token's scaled embedding plus its positional encoding
            tgt = torch.cat([tgt,
                             self.posEmb.pe[:, len(preds)] + (self.embeddings(pr) * math.sqrt(128)).unsqueeze(0)],dim=1)
            
        return torch.tensor(preds)
    
    def forward(self,memory,tgt):
        if tgt is not None:
            x = self.forward_train(memory,torch.tensor(tgt,dtype=torch.int).unsqueeze(0))
        else:
            x = self.forward_infer(memory=memory)
        return x

class EncDec(nn.Module):

    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.dec = Decoder()

    def forward(self,src,target=None):
        x = self.enc(torch.tensor(src,dtype=torch.int).unsqueeze(0))
        y = self.dec.forward(x,target)
        return y
    
def tokenize(text,lang):
    if 'hin' in lang:
        lst = [hin_dict[i] for i in text]
        lst = [0] + lst + [13]  # prepend sos (0) and append eos (13)
    else:
        lst = [eng_dict[i] for i in text]

    return lst


# model instance and tokenized example pair used by the training loop below
model = EncDec()
tokenized_eng = tokenize(eng, 'eng')
tokenized_hin = tokenize(hin, 'hin')

loss = torch.nn.CrossEntropyLoss(ignore_index=0)
optim = torch.optim.AdamW(model.parameters(), lr=0.0001)
for epoch in range(2000):
    for steps in range(1):
        out = model(tokenized_eng,tokenized_hin)
        optim.zero_grad()
        criterion = loss(out.reshape(-1,14),torch.tensor(tokenized_hin,dtype=torch.long))  # 14 = decoder vocab size
        criterion.backward()
        optim.step()
        if epoch % 100==0:
            print(criterion)

model.eval()
x = model(tokenized_eng)  # no target given, so greedy decoding is used
print('The output is  = ')
print(x)
Output:

The output is  = tensor([1, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
tensor(2.7242, grad_fn=<NllLossBackward0>)
tensor(0.0548, grad_fn=<NllLossBackward0>)
tensor(0.0243, grad_fn=<NllLossBackward0>)
tensor(0.0168, grad_fn=<NllLossBackward0>)
tensor(0.0122, grad_fn=<NllLossBackward0>)
tensor(0.0098, grad_fn=<NllLossBackward0>)
tensor(0.0077, grad_fn=<NllLossBackward0>)
tensor(0.0066, grad_fn=<NllLossBackward0>)
tensor(0.0053, grad_fn=<NllLossBackward0>)
tensor(0.0046, grad_fn=<NllLossBackward0>)
tensor(0.0042, grad_fn=<NllLossBackward0>)
tensor(0.0036, grad_fn=<NllLossBackward0>)
tensor(0.0034, grad_fn=<NllLossBackward0>)
tensor(0.0029, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0023, grad_fn=<NllLossBackward0>)
tensor(0.0020, grad_fn=<NllLossBackward0>)
tensor(0.0019, grad_fn=<NllLossBackward0>)
tensor(0.0017, grad_fn=<NllLossBackward0>)
tensor(0.0016, grad_fn=<NllLossBackward0>)
The output is  = tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

There are 0 answers