use custom dataset and augmentation #14

Open
ramdhan1989 opened this issue May 5, 2022 · 1 comment

Comments

@ramdhan1989

Hi there! Is there an easy way to run this on a custom dataset with custom augmentation? If that is not supported yet, could you advise which parts of the code need to be modified?

Thank you.

@Tieck-IT

This is my code for using a custom dataset.

utils.py

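# Added to utils.py, next to the existing DatasetBase and imagenet_default_transform.
import os
from typing import Any, Callable

import attr
import torch

from custom import CustomDataset  # the new file below; adjust the import path to your layout

@attr.s(auto_attribs=True, slots=True)
class FashionDataset(DatasetBase):
    # Pass your own callables here to apply custom augmentation.
    transform_train: Callable[[Any], torch.Tensor] = imagenet_default_transform
    transform_test: Callable[[Any], torch.Tensor] = imagenet_default_transform

    def configure_train(self):
        # data_path must point to the CSV file that CustomDataset reads.
        assert os.path.exists(self.data_path)
        return CustomDataset(self.data_path, split="train", transform=self.transform_train)

    def configure_validation(self):
        assert os.path.exists(self.data_path)
        return CustomDataset(self.data_path, split="val", transform=self.transform_test)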

custom.py (new file)

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm


class CustomDataset(Dataset):
    def __init__(self, csv_file, split="train", transform=None, image2ram=True):
        # The CSV needs three columns: split, image_path and label.
        data = pd.read_csv(csv_file)
        data = data[data["split"] == split]
        if split == "val":
            # Shuffle the validation rows once, with a fixed seed.
            data = data.sample(frac=1, random_state=42)
        self.transform = transform
        self.image_paths = data["image_path"].to_numpy()
        self.labels = data["label"].to_numpy()
        # Cache decoded images in RAM; pass image2ram=False if they do not fit in memory.
        self.image2ram = image2ram
        if self.image2ram:
            self.images = [
                self._load(path)
                for path in tqdm(self.image_paths, desc="Loading images to RAM")
            ]

    @staticmethod
    def _load(img_path):
        # convert() reads the pixel data, so the file can be closed right away.
        with Image.open(img_path) as img:
            return img.convert("RGB")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Take the image from the RAM cache, or read it from disk
        image = self.images[idx] if self.image2ram else self._load(self.image_paths[idx])

        # Get label
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        # Apply transforms if any
        if self.transform:
            image = self.transform(image)

        return image, label
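
CustomDataset reads a CSV with the columns split, image_path and label. If your images are not indexed that way yet, something like the following could generate the file. This is only a sketch: the build_manifest helper, the folder layout root/<split>/<class_name>/<image file> and the list of image suffixes are assumptions for illustration, not part of the code above.

```python
import csv
from pathlib import Path

# Assumed set of image extensions; extend it to match your data.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def build_manifest(root, csv_path):
    """Write a CSV with the split, image_path and label columns.

    Assumes a hypothetical layout: root/<split>/<class_name>/<image file>,
    with the splits named "train" and "val".
    """
    root = Path(root)
    # Class names are taken from the train split and sorted,
    # so the integer label ids stay the same between runs.
    classes = sorted(p.name for p in (root / "train").iterdir() if p.is_dir())
    class_to_label = {name: i for i, name in enumerate(classes)}

    rows = []
    for split in ("train", "val"):
        for class_name, label in class_to_label.items():
            class_dir = root / split / class_name
            if not class_dir.is_dir():
                continue  # this class has no images in this split
            for path in sorted(class_dir.iterdir()):
                if path.suffix.lower() in IMAGE_SUFFIXES:
                    rows.append(
                        {"split": split, "image_path": str(path), "label": label}
                    )

    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["split", "image_path", "label"])
        writer.writeheader()
        writer.writerows(rows)
    return class_to_label
```

The path of the resulting CSV is what data_path should point to, since FashionDataset passes it straight to CustomDataset as csv_file.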
