Generate a prediction sequence with a transformer model built from scratch


I'm building a basic transformer model from scratch in PyTorch (with simple tokenization and no masking). I'm using 'The Mysterious Island' by Jules Verne as the training set, so to run the code below, download it from Project Gutenberg and place it in the appropriate local directory.

I'm having trouble conceptualizing and coding how I would predict on new text. It's possible that my encoder/decoder input shapes are incorrect. The way I have it set up, input_encoder is a sequence of 30 tokens (tokens 1-30 of the corpus for the first example), and input_decoder is the same sequence shifted by one (tokens 2-31 for the first example).
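
To make those shapes concrete, here is a minimal sketch of how one (input_encoder, input_decoder) pair is built (mirroring the TextDataset further down, with a toy corpus standing in for the real token ids):

import torch

CONTEXT_SIZE = 30
toy_corpus = torch.arange(100)            # stand-in for the encoded corpus

chunk = toy_corpus[0:CONTEXT_SIZE + 1]    # tokens 1-31
input_encoder = chunk[:-1]                # tokens 1-30
input_decoder = chunk[1:]                 # tokens 2-31 (shifted by one)
print(input_encoder.shape, input_decoder.shape)  # torch.Size([30]) torch.Size([30])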

My output has the same sequence length as the input, with shape (1, context_size, vocab_size). When I predict, I feed in the start sequence, padded with enough tokens to reach the context size, as input_encoder. The input_decoder is a start-of-sequence token followed by padding tokens up to the context size.

When I call the forward method during prediction, I'm under the impression that the model should be autoregressive. By that logic, each pass should give me just one new token to append to input_decoder. However, the output of a single forward pass has n positions (n = context size).

Perhaps the latter part of my setup is incorrect? If all 30 output tokens are predicted at once, maybe I don't need an autoregressive approach at all?
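
To illustrate what I mean by one token per pass, this is roughly the single decoding step I have in mind (greedy argmax at the position currently being filled, using the model, special tokens, and encoder_tokens defined in the code below):

num_generated = 1                               # only the SOS token so far
decoder_tokens = torch.tensor(
    [SOS_token] + [PAD_token] * (CONTEXT_SIZE - 1)).unsqueeze(0)  # (1, CONTEXT_SIZE)

output = model(encoder_tokens, decoder_tokens)  # (1, CONTEXT_SIZE, VOCAB_SIZE)
next_token = output[0, num_generated - 1].argmax().item()  # read just one position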

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class PositionalEncoding(nn.Module):

    def __init__(self, context_size, d_model):
        super().__init__()

        self.encoding = torch.zeros(context_size, d_model)

        pos = torch.arange(0, context_size).unsqueeze(dim=1)
        # dim holds the even embedding indices 2i from the positional encoding formula:
        # PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        dim = torch.arange(0, d_model, 2)
        self.encoding[:, 0::2] = torch.sin(pos / (10000**(dim / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000**(dim / d_model)))

    def forward(self, x):
        seq_len = x.size(1)
        return self.encoding[:seq_len, :]

class PositionwiseFeedForward(nn.Module):

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        # batch_first=True so attention inputs are (batch, seq, d_model)
        self.self_attn = nn.MultiheadAttention(d_model, num_heads,
                                               batch_first=True)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        hidden_states, _ = self.self_attn(query=x, key=x, value=x)
        x = self.norm1(x + hidden_states)
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)
        return x

class Encoder(nn.Module):
    # input_size - # rows in token embedding
    # context size - # rows in positional embedding
    # d_ff - internal dimension of the FF network
    # num encoder blocks
    def __init__(self, input_size, context_size, d_model, d_ff, num_heads,
                 n_blocks):
        super().__init__()

        self.embedding = nn.Embedding(input_size, d_model)
        self.pos_embedding = PositionalEncoding(context_size, d_model)

        self.blocks = nn.ModuleList([
            EncoderBlock(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
            ) for _ in range(n_blocks)
        ])

    def forward(self, x):
        x = self.embedding(x) + self.pos_embedding(x)
        for block in self.blocks:
            x = block(x)
        return x

class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        # batch_first=True so attention inputs are (batch, seq, d_model)
        self.self_attn = nn.MultiheadAttention(d_model, num_heads,
                                               batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads,
                                                batch_first=True)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_output):
        hidden_states, _ = self.self_attn(x, x, x)
        x = self.norm1(x + hidden_states)
        hidden_states, _ = self.cross_attn(
                             query=x, key=enc_output, value=enc_output)
        x = self.norm2(x + hidden_states)
        ff_output = self.feed_forward(x)
        x = self.norm3(x + ff_output)
        return x

class Decoder(nn.Module):
    def __init__(self, output_size, context_size,
                 d_model, d_ff, num_heads, n_blocks):
        super().__init__()
        self.embedding = nn.Embedding(output_size, d_model)
        self.pos_embedding = PositionalEncoding(context_size, d_model)

        self.blocks = nn.ModuleList([
            DecoderBlock(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
            )
            for _ in range(n_blocks)
        ])

        self.out = nn.Linear(d_model, output_size)

    def forward(self, x, enc_output):
        x = self.embedding(x) + self.pos_embedding(x)

        for block in self.blocks:
            x = block(x, enc_output)

        output = self.out(x)
        return output

class Transformer(nn.Module):

    def __init__(self, vocab_size, context_size,
                 d_model, d_ff, num_heads, n_blocks):
        super().__init__()

        self.encoder = Encoder(
            vocab_size,
            context_size,
            d_model,
            d_ff,
            num_heads,
            n_blocks
        )

        self.decoder = Decoder(
            vocab_size,
            context_size,
            d_model,
            d_ff,
            num_heads,
            n_blocks
        )

    def forward(self, input_encoder, input_decoder):
        enc_output = self.encoder(input_encoder)          # (batch, context_size, d_model)
        output = self.decoder(input_decoder, enc_output)  # input_decoder: (batch, context_size)
        return output

## Prep data
with open('your_directory/1268-0.txt', 'r', encoding="utf8") as fp:
    text=fp.read()

start_indx = text.find('THE MYSTERIOUS ISLAND')
end_indx = text.find('End of the Project Gutenberg')

text = text[start_indx:end_indx]

SOS_token = 0
EOS_token = 1
PAD_token = 2   # Need to have padding so that the input & output sentences
# are the same length - required for the cross-attention computation

index2words = {SOS_token: 'SOS', EOS_token: 'EOS', PAD_token: 'PAD'}

words_list = sorted(set(text.split(' ')))  # sort for a reproducible vocabulary

for word in words_list:
    index2words[len(index2words)] = word

words2index = {w: i for i, w in index2words.items()}

text_encoded = np.array(
    [words2index[word] for word in text.split(' ')],
    dtype=np.int32)

CONTEXT_SIZE = 30
chunk_size = CONTEXT_SIZE + 1  # each chunk yields a (src, tgt) pair shifted by one word

# n chunks where each next chunk is 1 word offset from the previous chunk
token_chunks = [
    text_encoded[i:i + chunk_size]
    for i in range(len(text_encoded) - chunk_size + 1)
]

class TextDataset(Dataset):
    def __init__(self, text_chunks):
        self.text_chunks = text_chunks

    def __len__(self):
        return len(self.text_chunks)

    def __getitem__(self, idx):
        text_chunk = self.text_chunks[idx]
        return text_chunk[:-1].long(), text_chunk[1:].long()

# stack the numpy chunks first; torch.tensor on a list of arrays is slow and warns
seq_dataset = TextDataset(torch.tensor(np.array(token_chunks)))

BATCH_SIZE = 50
seq_dl = DataLoader(seq_dataset,
                    batch_size=BATCH_SIZE,
                    shuffle=True,
                    drop_last=True)
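
# Optional sanity check: each batch of src/tgt should come out as (BATCH_SIZE, CONTEXT_SIZE)
src_batch, tgt_batch = next(iter(seq_dl))
print(src_batch.shape, tgt_batch.shape)  # expect torch.Size([50, 30]) twice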

## Build model
EPOCHS = 10
VOCAB_SIZE = len(words2index)
D_MODEL = 10
D_FF = 20
NUM_HEADS = 2
N_BLOCKS = 12

model = Transformer(
    vocab_size=VOCAB_SIZE,
    context_size=CONTEXT_SIZE,
    d_model=D_MODEL,
    d_ff=D_FF,  # internal dimension of the feed forward network
    num_heads=NUM_HEADS,
    n_blocks=N_BLOCKS)

criterion = nn.CrossEntropyLoss(ignore_index=PAD_token)  # don't score padded positions
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

model.train()

for epoch in range(EPOCHS):
    for src_data, tgt_data in seq_dl:
        output = model(src_data, tgt_data)
        loss = criterion(
            output.view(-1, VOCAB_SIZE),   # logits flattened to (batch * context, vocab)
            tgt_data.view(-1))             # targets flattened to (batch * context,)
        # CrossEntropyLoss expects logits of shape (N, num_classes) and targets of
        # shape (N,), so both are flattened across the batch and sequence dimensions
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

## Predict sequence
prediction_start = "And then they set sail again for another land"

model.eval()
with torch.no_grad():
    # Assuming convert2tensors tokenizes the prompt and pads it to CONTEXT_SIZE,
    # returning a (1, CONTEXT_SIZE) LongTensor for the encoder
    encoder_tokens = convert2tensors(prediction_start, CONTEXT_SIZE)

    # Start the decoder input with just the start-of-sequence token
    generated = [SOS_token]

    # Generate tokens one at a time until EOS is produced or the context size is reached
    while len(generated) < CONTEXT_SIZE:
        # Pad the partial sequence up to the context size and add a batch dimension
        decoder_tokens = torch.tensor(
            generated + [PAD_token] * (CONTEXT_SIZE - len(generated))
        ).unsqueeze(0)

        # One forward pass returns logits for every position: (1, CONTEXT_SIZE, VOCAB_SIZE)
        output = model(encoder_tokens, decoder_tokens)

        # Greedily take the most likely token at the last filled position
        next_token = output[0, len(generated) - 1].argmax().item()
        generated.append(next_token)

        # Stop if the end-of-sequence token is generated
        if next_token == EOS_token:
            break

    print(' '.join(index2words[t] for t in generated[1:]))
