Issues between PyTorch DataLoader and Matplotlib's imshow for an Image Classification Task


I am currently working on a binary classification task involving image data. Before training, I need to inspect my dataset, but I have run into an issue with the DataLoader.

The official PyTorch website shows the following code:

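import torch
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download FashionMNIST; ToTensor() turns each PIL image into a float tensor of shape (1, 28, 28)
training_data = datasets.FashionMNIST(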
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

When loading the training data, they convert each image to a tensor with ToTensor(), and then pass it straight to matplotlib's imshow. But when I try the same process on my own data, I get this error: TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>

When I asked GPT-4 about this, it said, "PyTorch and matplotlib are compatible." However, when I asked again and provided my code, it said, "You need to convert the PyTorch tensor to a NumPy array before using imshow." Which statement is accurate?
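The difference between the two statements can be checked directly. Matplotlib converts the image it receives into a NumPy array, and NumPy can read a plain CPU tensor, which is likely why the tutorial's plt.imshow(img.squeeze(), cmap="gray") runs without an explicit conversion. The sketch below assumes a random CPU tensor shaped like one FashionMNIST sample; the tensor names are illustrative. It also shows the case where explicit handling is required: a tensor that requires gradients must be detached first, and a GPU tensor would additionally need .cpu().

```python
import numpy as np
import torch

# Stand-in for one FashionMNIST sample after ToTensor(): a CPU float tensor of shape (1, 28, 28).
img = torch.rand(1, 28, 28)

# Implicit conversion: NumPy reads a plain CPU tensor directly.
implicit = np.asarray(img.squeeze())

# Explicit conversion, as recommended in the answer.
explicit = img.squeeze().numpy()

# A tensor that requires gradients refuses conversion until it is detached.
activations = torch.rand(28, 28, requires_grad=True)
try:
    np.asarray(activations)
    implicit_ok = True
except RuntimeError:
    implicit_ok = False

# detach() drops autograd history; cpu() is a no-op here but required for GPU tensors.
safe = activations.detach().cpu().numpy()
```

Converting explicitly is therefore the safer habit, even though a CPU tensor without gradients usually works as-is.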

Answer by Minh Nguyen Hoang

The second statement should be the accurate one. Change the imshow call to this:

plt.imshow(img.squeeze().numpy(), cmap="gray")
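The one-line fix above covers a single-channel image on the CPU. For a dataset of colour images, a hypothetical helper such as to_imshow_array can also reorder channels, since ToTensor produces the channel-first layout (C, H, W) while imshow expects (H, W) for grayscale or (H, W, 3) / (H, W, 4) for RGB / RGBA. This is a sketch under the assumption that any 3-D input whose first dimension is 1, 3, or 4 is channel-first; it accepts either a torch.Tensor or a NumPy array.

```python
import numpy as np


def to_imshow_array(img):
    """Hypothetical helper: turn a (C, H, W) tensor or array into an imshow-ready NumPy array.

    Returns (H, W) for one channel and (H, W, C) for RGB/RGBA.
    """
    if hasattr(img, "detach"):  # torch.Tensor: drop autograd history, move to CPU, convert
        img = img.detach().cpu().numpy()
    arr = np.asarray(img)
    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4):  # assumed channel-first, as from ToTensor
        arr = np.transpose(arr, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    if arr.ndim == 3 and arr.shape[-1] == 1:  # single channel: imshow wants (H, W)
        arr = arr[..., 0]
    return arr
```

In the plotting loop it would replace the conversion as plt.imshow(to_imshow_array(img), cmap="gray"); matplotlib ignores cmap for RGB input. An image that is genuinely (H, W, C) with a height of 1, 3, or 4 pixels would be misread by the channel-first check, which is the trade-off of guessing the layout from the shape.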