I am trying to load a dataset from files and train an AI model on it.
For some reason, when I loop over the DataLoader in my main script (for images, targets in dataloader:), the targets come out as:
    {
        'image_id': [all image ids],
        'keypoints': [all keypoint lists],
        'labels': [all label lists],
        'boxes': [all bboxes]
    }

instead of

    [
        {
            'image_id': image_id of first sample,
            'keypoints': [list of keypoints of first sample],
            'labels': [list of labels of first sample],
            'boxes': bbox of first sample
        },
        ...
        {
            'image_id': image_id of fourth sample,
            'keypoints': [list of keypoints of fourth sample],
            'labels': [list of labels of fourth sample],
            'boxes': bbox of fourth sample
        }
    ]
This is my __getitem__ function:
    import numpy as np
    import torch
    from PIL import Image

    # kp_num: dict mapping keypoint name -> integer class id, defined elsewhere
    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        image_id = annotation['image_id']
        file_name = annotation['file_name']
        image_path = f"{self.images_dir}/{file_name}"
        image = Image.open(image_path).convert("RGB")
        bbox = np.array(annotation['bbox'])
        # One (x, y) pair and one class label per annotated keypoint
        keypoints = np.array([[ann["x"], ann["y"]] for ann in annotation["keypoints"]])
        labels = np.array([kp_num[ann["name"]] for ann in annotation["keypoints"]])
        target = {
            "image_id": image_id,
            "keypoints": torch.tensor(keypoints, dtype=torch.float32),
            "labels": torch.tensor(labels, dtype=torch.int64),
            "boxes": torch.tensor(bbox, dtype=torch.int),
        }
        if self.transform:
            image = self.transform(image)
        return image, target
I expected it to return a list of targets, but it returns a dict of lists. I tried wrapping target in a one-element list in the return statement, but then I just got a list with a single entry holding all the info, instead of a list with batch_size targets. I am using the DataLoader class from torch.utils.data.
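To see where the dict of lists comes from, here is a minimal sketch of what the DataLoader's default collation does with a batch of (image, target) pairs. ToyDataset and its fake tensors are hypothetical stand-ins for the real dataset; every sample has the same number of keypoints so the values can be stacked.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in: 8 samples, each with 3 keypoints."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        image = torch.zeros(3, 16, 16)  # fake image, already a tensor
        target = {
            "image_id": idx,
            "keypoints": torch.full((3, 2), float(idx)),  # 3 (x, y) pairs
            "labels": torch.tensor([0, 1, 2], dtype=torch.int64),
            "boxes": torch.tensor([0, 0, 10, 10], dtype=torch.int),
        }
        return image, target


# No collate_fn given, so the default one is used
loader = DataLoader(ToyDataset(), batch_size=4)
images, targets = next(iter(loader))

# The 4 target dicts were merged key by key into ONE dict of batched values
print(type(targets))               # <class 'dict'>
print(targets["image_id"])         # tensor([0, 1, 2, 3])
print(targets["keypoints"].shape)  # torch.Size([4, 3, 2])
print(images.shape)                # torch.Size([4, 3, 16, 16])
```

If the number of keypoints differs between images, torch.stack cannot combine them and the default collation raises an error instead of returning a batch.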
EDIT: Solved it by implementing a custom collate function and passing it to the DataLoader via collate_fn=custom_collate_fn:
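    def custom_collate_fn(batch):
        # batch is a list of (image, target) tuples, one per sample;
        # keep both as plain lists instead of stacking them
        images = [item[0] for item in batch]
        targets = [item[1] for item in batch]
        return images, targets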
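A minimal sketch of plugging such a collate function into the DataLoader. VariableKeypointDataset is a hypothetical toy dataset whose samples have different numbers of keypoints, a case the default collation could not stack.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def custom_collate_fn(batch):
    # batch is a list of (image, target) tuples, one per sample
    images = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    return images, targets


class VariableKeypointDataset(Dataset):
    """Hypothetical dataset: sample idx has idx + 1 keypoints."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        image = torch.zeros(3, 16, 16)
        target = {
            "image_id": idx,
            "keypoints": torch.rand(idx + 1, 2),
            "labels": torch.arange(idx + 1, dtype=torch.int64),
        }
        return image, target


loader = DataLoader(VariableKeypointDataset(), batch_size=4,
                    collate_fn=custom_collate_fn)

for images, targets in loader:
    # targets is a list with one dict per sample in the batch
    print(len(targets), [t["keypoints"].shape[0] for t in targets])
# 4 [1, 2, 3, 4]
# 4 [5, 6, 7, 8]
```

The images also stay a Python list here. If your model needs a single batched tensor instead, you could call torch.stack(images) yourself, as long as all images share the same size.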
The default collate_fn stacks the values of your dictionary key by key. Assuming batch_size=N, each target['key'] becomes a batch of shape [N, ...]. (This is the recommended behavior for future compatibility and efficiency: one stacked tensor per key means fewer memory allocations and faster iteration over the dataset.) If you really need the values popped out into one dictionary per sample, you can split the batched dictionary yourself, or pass a custom collate_fn like the one in the edit above.
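One way to keep the default collate_fn and still get per-sample dictionaries when needed is to split the batched dictionary afterwards. This is a minimal sketch, assuming every value in the dict is a tensor whose first dimension is the batch size; unbatch_targets is a hypothetical helper name and the sample values are made up.

```python
import torch


def unbatch_targets(targets):
    """Hypothetical helper: split a default-collated dict of batched
    values into a list with one dict per sample."""
    # Every value shares the same first dimension N (the batch size)
    batch_size = len(next(iter(targets.values())))
    return [{key: value[i] for key, value in targets.items()}
            for i in range(batch_size)]


# A batch of N=2 targets shaped the way the default collate_fn produces them
batched = {
    "image_id": torch.tensor([7, 9]),
    "keypoints": torch.rand(2, 3, 2),
    "labels": torch.tensor([[0, 1, 2], [0, 1, 2]]),
    "boxes": torch.tensor([[0, 0, 10, 10], [5, 5, 20, 20]], dtype=torch.int),
}

per_sample = unbatch_targets(batched)
print(len(per_sample))                   # 2
print(per_sample[1]["image_id"])         # tensor(9)
print(per_sample[1]["keypoints"].shape)  # torch.Size([3, 2])
```

Indexing with value[i] returns views into the batched tensors, so no data is copied.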
BTW, for the reasons mentioned above, I suggest using the default collate_fn instead.