aboutsummaryrefslogtreecommitdiff
path: root/code/sunlab/common/data/image_dataset.py
blob: 46e77b61d365d83d3ae40b16954ca59fc8ae1bf9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class ImageDataset:
    """# Image Dataset

    In-memory stack of images read from disk, addressable in batches.

    After construction `self.dataset` is a single NumPy array whose first
    axis enumerates images (originals plus any augmented copies) and
    `self.batch_size` is the number of images returned per integer index."""

    def __init__(
        self,
        base_directory,
        ext="png",
        channels=None,
        batch_size=None,
        shuffle=False,
        rotate=False,
        rotate_p=1.,
    ):
        """# Image Dataset

        Load a directory of images

        Arguments:
        - base_directory: string prefix the glob pattern is built from. It is
          concatenated directly with `"*." + ext`, so include the trailing
          path separator when naming a directory.
        - ext: file extension (without the dot) of the images to load.
        - channels: channel indices kept from each image; defaults to `[0]`.
        - batch_size: images per batch; `None` means one batch holding the
          whole dataset.
        - shuffle: if true, permute the images once after loading.
        - rotate: if true, also add flipped/transposed copies of each image.
          The transposed copies swap height and width, so stacking only
          succeeds for square images.
        - rotate_p: probability with which each of the seven augmented
          copies is (independently) kept.

        Raises `ValueError` if `batch_size` is smaller than 1 or if no file
        matches the pattern.

        NOTE(review): images are assumed to load as (height, width, channel)
        arrays; files are taken in whatever order `glob` returns them."""
        # Validate first so a bad argument fails before any file is read.
        if batch_size is not None and batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size!r}")

        from glob import glob
        from matplotlib.pyplot import imread
        from numpy import newaxis, vstack
        from numpy.random import permutation, rand

        # `None` stands in for the default to avoid a shared mutable default.
        if channels is None:
            channels = [0]

        self.base_directory = base_directory
        pattern = self.base_directory + "*." + ext
        images = []
        for file in glob(pattern):
            # (H, W, C) -> (1, len(channels), H, W)
            im = imread(file)[newaxis, :, :, channels].transpose(0, 3, 1, 2)
            images.append(im)
            # Also add flipped / transposed copies of the image to the dataset
            if rotate:
                swapped = im.transpose(0, 1, 3, 2)
                variants = (
                    im[:, :, ::-1, :],
                    im[:, :, :, ::-1],
                    im[:, :, ::-1, ::-1],
                    swapped,
                    swapped[:, :, ::-1, :],
                    swapped[:, :, :, ::-1],
                    swapped[:, :, ::-1, ::-1],
                )
                # One random draw per variant, in a fixed order.
                for variant in variants:
                    if rand() < rotate_p:
                        images.append(variant)
        if not images:
            raise ValueError(f"No image files match the pattern {pattern!r}")
        self.dataset = vstack(images)
        if shuffle:
            self.dataset = self.dataset[permutation(self.dataset.shape[0]), ...]
        self.batch_size = (
            batch_size if batch_size is not None else self.dataset.shape[0]
        )

    def torch(self, device=None):
        """# Cast to Torch Tensor

        Copy the whole dataset into a tensor on `device` (CPU if omitted)."""
        import torch

        if device is None:
            device = torch.device("cpu")
        return torch.tensor(self.dataset).to(device)

    def numpy(self):
        """# Cast to Numpy Array

        Return the underlying array itself (not a copy)."""
        return self.dataset

    def __len__(self):
        """# Return Number of Full Batches

        Uses floor division, so a trailing partial batch is not counted even
        though it can still be fetched by index."""
        return self.dataset.shape[0] // self.batch_size

    def __getitem__(self, index):
        """# Slice By Batch

        - tuple: forwarded unchanged to the underlying array.
        - integer (Python or NumPy): the `index`-th batch of `batch_size`
          images; the last batch may be shorter. Negative values count
          batches back from the end of the array.

        Raises `IndexError` when a non-negative batch index starts past the
        end of the data (this is what lets `for batch in dataset` stop) and
        `TypeError` for any other kind of index."""
        from numbers import Integral

        if isinstance(index, tuple):
            return self.dataset[index]
        if isinstance(index, Integral):
            start = index * self.batch_size
            if index >= 0 and start >= self.dataset.shape[0]:
                raise IndexError("batch index out of range")
            # For the last batch the naive stop would be 0, which yields an
            # empty slice; an open-ended slice is what is meant.
            stop = None if index == -1 else start + self.batch_size
            return self.dataset[start:stop, ...]
        raise TypeError(
            "ImageDataset indices must be integers or tuples, "
            f"not {type(index).__name__}"
        )