| | | |
|---|---|---|
| author | Christian C <cc@localhost> | 2024-11-11 12:29:32 -0800 |
| committer | Christian C <cc@localhost> | 2024-11-11 12:29:32 -0800 |
| commit | b85ee9d64a536937912544c7bbd5b98b635b7e8d (patch) | |
| tree | cef7bc17d7b29f40fc6b1867d0ce0a742d5583d0 /code/sunlab/common/data | |
Initial commit
Diffstat (limited to 'code/sunlab/common/data')
| mode | path | lines |
|---|---|---|
| -rw-r--r-- | code/sunlab/common/data/__init__.py | 6 |
| -rw-r--r-- | code/sunlab/common/data/basic.py | 6 |
| -rw-r--r-- | code/sunlab/common/data/dataset.py | 255 |
| -rw-r--r-- | code/sunlab/common/data/dataset_iterator.py | 34 |
| -rw-r--r-- | code/sunlab/common/data/image_dataset.py | 75 |
| -rw-r--r-- | code/sunlab/common/data/shape_dataset.py | 57 |
| -rw-r--r-- | code/sunlab/common/data/utilities.py | 119 |
7 files changed, 552 insertions, 0 deletions
diff --git a/code/sunlab/common/data/__init__.py b/code/sunlab/common/data/__init__.py
new file mode 100644
index 0000000..3e26874
--- /dev/null
+++ b/code/sunlab/common/data/__init__.py
@@ -0,0 +1,6 @@
+from .basic import *
+from .dataset import *
+from .dataset_iterator import *
+from .shape_dataset import *
+from .image_dataset import *
+from .utilities import *
diff --git a/code/sunlab/common/data/basic.py b/code/sunlab/common/data/basic.py
new file mode 100644
index 0000000..bb2e912
--- /dev/null
+++ b/code/sunlab/common/data/basic.py
@@ -0,0 +1,6 @@
+import numpy
+
+
+numpy.load_dat = lambda *args, **kwargs: numpy.load(
+    *args, **kwargs, allow_pickle=True
+).item()
diff --git a/code/sunlab/common/data/dataset.py b/code/sunlab/common/data/dataset.py
new file mode 100644
index 0000000..8589abf
--- /dev/null
+++ b/code/sunlab/common/data/dataset.py
@@ -0,0 +1,255 @@
+from .dataset_iterator import DatasetIterator
+
+
+class Dataset:
+    """# Dataset Superclass"""
+
+    base_scale = 10.0
+
+    def __init__(
+        self,
+        dataset_filename,
+        data_columns=[],
+        label_columns=[],
+        batch_size=None,
+        shuffle=False,
+        val_split=0.0,
+        scaler=None,
+        sort_columns=None,
+        random_seed=4332,
+        pre_scale=10.0,
+        **kwargs
+    ):
+        """# Initialize Dataset
+        self.dataset = dataset (N, ...)
+        self.labels = labels (N, ...)
+
+        Optional Arguments:
+        - prescale_function: The function that takes the ratio and transforms
+        the dataset by multiplying the prescale_function output
+        - sort_columns: The columns to sort the data by initially
+        - equal_split: If the classifications should be equally split in
+        training"""
+        from pandas import read_csv
+        from numpy import array, all
+        from numpy.random import seed
+
+        if seed is not None:
+            seed(random_seed)
+
+        # Basic Dataset Information
+        self.data_columns = data_columns
+        self.label_columns = label_columns
+        self.source = dataset_filename
+        self.dataframe = read_csv(self.source)
+
+        # Pre-scaling Transformation
+        prescale_ratio = self.base_scale / pre_scale
+        ratio = prescale_ratio
+        prescale_powers = array([2, 1, 1, 0, 2, 1, 0, 0, 1, 1, 1, 1, 1])
+        if "prescale_function" in kwargs.keys():
+            prescale_function = kwargs["prescale_function"]
+        else:
+
+            def prescale_function(x):
+                return x**prescale_powers
+
+        self.prescale_function = prescale_function
+        self.prescale_factor = self.prescale_function(ratio)
+        assert (
+            len(data_columns) == self.prescale_factor.shape[0]
+        ), "Column Mismatch on Prescale"
+        self.original_scale = pre_scale
+
+        # Scaling Transformation
+        self.scaled = scaler is not None
+        self.scaler = scaler
+
+        # Training Dataset Information
+        self.do_split = False if val_split == 0.0 else True
+        self.validation_split = val_split
+        self.batch_size = batch_size
+        self.do_shuffle = shuffle
+        self.equal_split = False
+        if "equal_split" in kwargs.keys():
+            self.equal_split = kwargs["equal_split"]
+
+        # Classification Labels if they exist
+        self.dataset = self.dataframe[self.data_columns].to_numpy()
+        if len(self.label_columns) == 0:
+            self.labels = None
+        elif not all([column in self.dataframe.columns for column in label_columns]):
+            import warnings
+
+            warnings.warn(
+                "No classification labels found for the dataset", RuntimeWarning
+            )
+            self.labels = None
+        else:
+            self.labels = self.dataframe[self.label_columns].squeeze()
+
+        # Initialize the dataset
+        if "sort_columns" in kwargs.keys():
+            self.sort(kwargs["sort_columns"])
+        if self.do_shuffle:
+            self.shuffle()
+        if self.do_split:
+            self.split()
+        self.refresh_dataset()
+
+    def __len__(self):
+        """# Get how many cases are in the dataset"""
+        return self.dataset.shape[0]
+
+    def __getitem__(self, idx):
+        """# Make Dataset Sliceable"""
+        idx_slice = None
+        slice_stride = 1 if self.batch_size is None else self.batch_size
+        # If we pass a slice, return the slice
+        if type(idx) == slice:
+            idx_slice = idx
+        # If we pass an int, return a batch-size slice
+        else:
+            idx_slice = slice(
+                idx * slice_stride, min([len(self), (idx + 1) * slice_stride])
+            )
+        if self.labels is None:
+            return self.dataset[idx_slice, ...]
+        return self.dataset[idx_slice, ...], self.labels[idx_slice, ...]
+
+    def scale_data(self, data):
+        """# Scale dataset from scaling function"""
+        data = data * self.prescale_factor
+        if not (self.scaler is None):
+            data = self.scaler(data)
+        return data
+
+    def scale(self):
+        """# Scale Dataset"""
+        self.dataset = self.scale_data(self.dataset)
+
+    def refresh_dataset(self, dataframe=None):
+        """# Refresh Dataset
+
+        Regenerate the dataset from a dataframe.
+        Primarily used after a sort or filter."""
+        if dataframe is None:
+            dataframe = self.dataframe
+        self.dataset = dataframe[self.data_columns].to_numpy()
+        if self.labels is not None:
+            self.labels = dataframe[self.label_columns].to_numpy().squeeze()
+        self.scale()
+
+    def sort_on(self, columns):
+        """# Sort Dataset on Column(s)"""
+        from numpy import all
+
+        if type(columns) == str:
+            columns = [columns]
+        if columns is not None:
+            assert all(
+                [column in self.dataframe.columns for column in columns]
+            ), "Dataframe does not contain some provided columns!"
+            self.dataframe = self.dataframe.sort_values(by=columns)
+            self.refresh_dataset()
+
+    def filter_on(self, column, value):
+        """# Filter Dataset on Column Value(s)"""
+        assert column in self.dataframe.columns, "Column DNE"
+        self.working_dataset = self.dataframe[self.dataframe[column].isin(value)]
+        self.refresh_dataset(self.working_dataset)
+
+    def filter_off(self):
+        """# Remove any filter on the dataset"""
+        self.refresh_dataset()
+
+    def unique(self, column):
+        """# Get unique values in a column(s)"""
+        assert column in self.dataframe.columns, "Column DNE"
+        from numpy import unique
+
+        return unique(self.dataframe[column])
+
+    def shuffle_data(self, data, labels=None):
+        """# Shuffle a dataset"""
+        from numpy.random import permutation
+
+        shuffled = permutation(data.shape[0])
+        if labels is not None:
+            assert (
+                self.labels.shape[0] == self.dataset.shape[0]
+            ), "Dataset and Label Shape Mismatch"
+            shuf_data = data[shuffled, ...]
+            shuf_labels = labels[shuffled]
+            if len(labels.shape) > 1:
+                shuf_labels = labels[shuffled, ...]
+            return shuf_data, shuf_labels
+        return data[shuffled, ...]
+
+    def shuffle(self):
+        """# Shuffle the dataset"""
+        if self.do_shuffle:
+            if self.labels is None:
+                self.dataset = self.shuffle_data(self.dataset)
+            self.dataset, self.labels = self.shuffle_data(self.dataset, self.labels)
+
+    def split(self):
+        """# Training/ Validation Splitting"""
+        from numpy import floor, unique, where, hstack, delete
+        from numpy.random import permutation
+
+        equal_classes = self.equal_split
+        if not self.do_split:
+            return
+        assert self.validation_split <= 1.0, "Too High"
+        assert self.validation_split > 0.0, "Too Low"
+        train_count = int(floor(self.dataset.shape[0] * (1 - self.validation_split)))
+        training_data = self.dataset[:train_count, ...]
+        training_labels = None
+        validation_data = self.dataset[train_count:, ...]
+        validation_labels = None
+        if self.labels is not None:
+            if equal_classes:
+                # Ensure the split balances the prevalence of each class
+                assert len(self.labels.shape) == 1, "1D Classification Only Currently"
+                classification_breakdown = unique(self.labels, return_counts=True)
+                train_count = min(
+                    [
+                        train_count,
+                        classification_breakdown.shape[0]
+                        * min(classification_breakdown[1]),
+                    ]
+                )
+                class_size = train_count / classification_breakdown.shape[0]
+                class_indicies = [
+                    permutation(where(self.labels == _class)[0])
+                    for _class in classification_breakdown[0]
+                ]
+                class_indicies = [indexes[:class_size] for indexes in class_indicies]
+                train_class_indicies = hstack(class_indicies).squeeze()
+                train_class_indicies = permutation(train_class_indicies)
+                training_data = self.dataset[train_class_indicies, ...]
+                training_labels = self.labels[train_class_indicies]
+                if len(self.labels.shape) > 1:
+                    training_labels = self.labels[train_class_indicies, ...]
+                validation_data = delete(self.dataset, train_class_indicies, axis=0)
+                validation_labels = delete(
+                    self.labels, train_class_indicies, axis=0
+                ).squeeze()
+            else:
+                training_labels = self.labels[:train_count]
+                if len(training_labels.shape) > 1:
+                    training_labels = self.labels[:train_count, ...]
+                validation_labels = self.labels[train_count:]
+                if len(validation_labels.shape) > 1:
+                    validation_labels = self.labels[train_count:, ...]
+        self.training_data = training_data
+        self.validation_data = validation_data
+        self.training = DatasetIterator(training_data, training_labels, self.batch_size)
+        self.validation = DatasetIterator(
+            validation_data, validation_labels, self.batch_size
+        )
+
+    def reset_iterators(self):
+        """# Reset Train/ Validation Iterators"""
+        self.split()
diff --git a/code/sunlab/common/data/dataset_iterator.py b/code/sunlab/common/data/dataset_iterator.py
new file mode 100644
index 0000000..7c91caa
--- /dev/null
+++ b/code/sunlab/common/data/dataset_iterator.py
@@ -0,0 +1,34 @@
+class DatasetIterator:
+    """# Dataset Iterator
+
+    Creates an iterator object on a dataset and labels"""
+
+    def __init__(self, dataset, labels=None, batch_size=None):
+        """# Initialize the iterator with the dataset and labels
+
+        - batch_size: How many to include in the iteration"""
+        self.dataset = dataset
+        self.labels = labels
+        self.current = 0
+        self.batch_size = (
+            batch_size if batch_size is not None else self.dataset.shape[0]
+        )
+
+    def __iter__(self):
+        """# Iterator Function"""
+        return self
+
+    def __next__(self):
+        """# Next Iteration
+
+        Slice the dataset and labels to provide"""
+        self.cur = self.current
+        self.current += 1
+        if self.cur * self.batch_size < self.dataset.shape[0]:
+            iterator_slice = slice(
+                self.cur * self.batch_size, (self.cur + 1) * self.batch_size
+            )
+            if self.labels is None:
+                return self.dataset[iterator_slice, ...]
+            return self.dataset[iterator_slice, ...], self.labels[iterator_slice, ...]
+        raise StopIteration
diff --git a/code/sunlab/common/data/image_dataset.py b/code/sunlab/common/data/image_dataset.py
new file mode 100644
index 0000000..46e77b6
--- /dev/null
+++ b/code/sunlab/common/data/image_dataset.py
@@ -0,0 +1,75 @@
+class ImageDataset:
+    def __init__(
+        self,
+        base_directory,
+        ext="png",
+        channels=[0],
+        batch_size=None,
+        shuffle=False,
+        rotate=False,
+        rotate_p=1.,
+    ):
+        """# Image Dataset
+
+        Load a directory of images"""
+        from glob import glob
+        from matplotlib.pyplot import imread
+        from numpy import newaxis, vstack
+        from numpy.random import permutation, rand
+
+        self.base_directory = base_directory
+        files = glob(self.base_directory + "*." + ext)
+        self.dataset = []
+        for file in files:
+            im = imread(file)[newaxis, :, :, channels].transpose(0, 3, 1, 2)
+            self.dataset.append(im)
+            # Also add rotations of the image to the dataset
+            if rotate:
+                if rand() < rotate_p:
+                    self.dataset.append(im[:, :, ::-1, :])
+                if rand() < rotate_p:
+                    self.dataset.append(im[:, :, :, ::-1])
+                if rand() < rotate_p:
+                    self.dataset.append(im[:, :, ::-1, ::-1])
+                if rand() < rotate_p:
+                    self.dataset.append(im.transpose(0, 1, 3, 2))
+                if rand() < rotate_p:
+                    self.dataset.append(im.transpose(0, 1, 3, 2)[:, :, ::-1, :])
+                if rand() < rotate_p:
+                    self.dataset.append(im.transpose(0, 1, 3, 2)[:, :, :, ::-1])
+                if rand() < rotate_p:
+                    self.dataset.append(im.transpose(0, 1, 3, 2)[:, :, ::-1, ::-1])
+        self.dataset = vstack(self.dataset)
+        if shuffle:
+            self.dataset = self.dataset[permutation(self.dataset.shape[0]), ...]
+        self.batch_size = (
+            batch_size if batch_size is not None else self.dataset.shape[0]
+        )
+
+    def torch(self, device=None):
+        """# Cast to Torch Tensor"""
+        import torch
+
+        if device is None:
+            device = torch.device("cpu")
+        return torch.tensor(self.dataset).to(device)
+
+    def numpy(self):
+        """# Cast to Numpy Array"""
+        return self.dataset
+
+    def __len__(self):
+        """# Return Number of Cases
+
+        (or Number in each Batch)"""
+        return self.dataset.shape[0] // self.batch_size
+
+    def __getitem__(self, index):
+        """# Slice By Batch"""
+        if type(index) == tuple:
+            return self.dataset[index]
+        elif type(index) == int:
+            return self.dataset[
+                index * self.batch_size : (index + 1) * self.batch_size, ...
+            ]
+        return
diff --git a/code/sunlab/common/data/shape_dataset.py b/code/sunlab/common/data/shape_dataset.py
new file mode 100644
index 0000000..5a68736
--- /dev/null
+++ b/code/sunlab/common/data/shape_dataset.py
@@ -0,0 +1,57 @@
+from .dataset import Dataset
+
+
+class ShapeDataset(Dataset):
+    """# Shape Dataset"""
+
+    def __init__(
+        self,
+        dataset_filename,
+        data_columns=[
+            "Area",
+            "MjrAxisLength",
+            "MnrAxisLength",
+            "Eccentricity",
+            "ConvexArea",
+            "EquivDiameter",
+            "Solidity",
+            "Extent",
+            "Perimeter",
+            "ConvexPerim",
+            "FibLen",
+            "InscribeR",
+            "BlebLen",
+        ],
+        label_columns=["Class"],
+        batch_size=None,
+        shuffle=False,
+        val_split=0.0,
+        scaler=None,
+        sort_columns=None,
+        random_seed=4332,
+        pre_scale=10,
+        **kwargs
+    ):
+        """# Initialize Dataset
+        self.dataset = dataset (N, ...)
+        self.labels = labels (N, ...)
+
+        Optional Arguments:
+        - prescale_function: The function that takes the ratio and transforms
+        the dataset by multiplying the prescale_function output
+        - sort_columns: The columns to sort the data by initially
+        - equal_split: If the classifications should be equally split in
+        training"""
+        super().__init__(
+            dataset_filename,
+            data_columns=data_columns,
+            label_columns=label_columns,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            val_split=val_split,
+            scaler=scaler,
+            sort_columns=sort_columns,
+            random_seed=random_seed,
+            pre_scale=pre_scale,
+            **kwargs
+        )
diff --git a/code/sunlab/common/data/utilities.py b/code/sunlab/common/data/utilities.py
new file mode 100644
index 0000000..6b4e6f3
--- /dev/null
+++ b/code/sunlab/common/data/utilities.py
@@ -0,0 +1,119 @@
+from .shape_dataset import ShapeDataset
+from ..scaler.max_abs_scaler import MaxAbsScaler
+
+
+def import_10x(
+    filename,
+    magnification=10,
+    batch_size=None,
+    shuffle=False,
+    val_split=0.0,
+    scaler=None,
+    sort_columns=None,
+):
+    """# Import a 10x Image Dataset
+
+    Pixel-to-micron: ???"""
+    magnification = 10
+    dataset = ShapeDataset(
+        filename,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        pre_scale=magnification,
+        val_split=val_split,
+        scaler=scaler,
+        sort_columns=sort_columns,
+    )
+    return dataset
+
+
+def import_20x(
+    filename,
+    magnification=10,
+    batch_size=None,
+    shuffle=False,
+    val_split=0.0,
+    scaler=None,
+    sort_columns=None,
+):
+    """# Import a 20x Image Dataset
+
+    Pixel-to-micron: ???"""
+    magnification = 20
+    dataset = ShapeDataset(
+        filename,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        pre_scale=magnification,
+        val_split=val_split,
+        scaler=scaler,
+        sort_columns=sort_columns,
+    )
+    return dataset
+
+
+def import_dataset(
+    filename,
+    magnification,
+    batch_size=None,
+    shuffle=False,
+    val_split=0.0,
+    scaler=None,
+    sort_columns=None,
+):
+    """# Import a dataset
+
+    Requires a magnification to be specified"""
+    dataset = ShapeDataset(
+        filename,
+        pre_scale=magnification,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        val_split=val_split,
+        scaler=scaler,
+        sort_columns=sort_columns,
+    )
+    return dataset
+
+
+def import_full_dataset(fname, magnification=20, scaler=None):
+    """# Import a Full Dataset
+
+    If a classification file exists (.txt with a 'Class' header and 'frame', 'cellnumber' headers), also import it"""
+    from os.path import isfile
+    import pandas as pd
+    import numpy as np
+
+    cfname = fname
+    tfname = cfname[:-3] + "txt"
+    columns = [
+        "frame",
+        "cellnumber",
+        "x-cent",
+        "y-cent",
+        "actinedge",
+        "filopodia",
+        "bleb",
+        "lamellipodia",
+    ]
+    if isfile(tfname):
+        dataset = import_dataset(cfname, magnification=magnification, scaler=scaler)
+        class_df = np.loadtxt(tfname, skiprows=1)
+        class_df = pd.DataFrame(class_df, columns=columns)
+        full_df = pd.merge(
+            dataset.dataframe,
+            class_df,
+            left_on=["Frames", "CellNum"],
+            right_on=["frame", "cellnumber"],
+        )
+        full_df["Class"] = np.argmax(
+            class_df[["actinedge", "filopodia", "bleb", "lamellipodia"]].to_numpy(),
+            axis=-1,
+        )
+        dataset.labels = full_df["Class"].to_numpy()
+    else:
+        dataset = import_dataset(cfname, magnification=magnification, scaler=scaler)
+        full_df = dataset.dataframe
+    dataset.dataframe = full_df
+    dataset.filter_off()
+    return dataset
