author     Christian C <cc@localhost>    2024-11-11 12:29:32 -0800
committer  Christian C <cc@localhost>    2024-11-11 12:29:32 -0800
commit     b85ee9d64a536937912544c7bbd5b98b635b7e8d (patch)
tree       cef7bc17d7b29f40fc6b1867d0ce0a742d5583d0 /code/sunlab/common/data
Initial commit
Diffstat (limited to 'code/sunlab/common/data')
-rw-r--r--   code/sunlab/common/data/__init__.py             6
-rw-r--r--   code/sunlab/common/data/basic.py                6
-rw-r--r--   code/sunlab/common/data/dataset.py            255
-rw-r--r--   code/sunlab/common/data/dataset_iterator.py    34
-rw-r--r--   code/sunlab/common/data/image_dataset.py       75
-rw-r--r--   code/sunlab/common/data/shape_dataset.py       57
-rw-r--r--   code/sunlab/common/data/utilities.py          119
7 files changed, 552 insertions, 0 deletions
diff --git a/code/sunlab/common/data/__init__.py b/code/sunlab/common/data/__init__.py
new file mode 100644
index 0000000..3e26874
--- /dev/null
+++ b/code/sunlab/common/data/__init__.py
@@ -0,0 +1,6 @@
+from .basic import *
+from .dataset import *
+from .dataset_iterator import *
+from .shape_dataset import *
+from .image_dataset import *
+from .utilities import *
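
Because the package __init__ re-exports every submodule, downstream code can pull the public names straight from sunlab.common.data. A minimal sketch (assuming the sibling sunlab.common.scaler package referenced by utilities.py is also importable):

    from sunlab.common.data import ShapeDataset, DatasetIterator, import_full_dataset
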
diff --git a/code/sunlab/common/data/basic.py b/code/sunlab/common/data/basic.py
new file mode 100644
index 0000000..bb2e912
--- /dev/null
+++ b/code/sunlab/common/data/basic.py
@@ -0,0 +1,6 @@
+import numpy
+
+
+numpy.load_dat = lambda *args, **kwargs: numpy.load(
+ *args, **kwargs, allow_pickle=True
+).item()
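
For reference, a small usage sketch of the patched loader; the file name and dictionary below are hypothetical. numpy.save wraps a dict in a 0-d object array, and load_dat unwraps it again via allow_pickle=True plus .item() (importing the module applies the patch; the package __init__ also pulls in pandas and the scaler package):

    import numpy
    import sunlab.common.data.basic  # registers numpy.load_dat

    numpy.save("example_dict.npy", {"a": 1, "b": [1, 2, 3]})  # hypothetical file
    restored = numpy.load_dat("example_dict.npy")             # back to a plain dict
    print(restored["a"], restored["b"])
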
diff --git a/code/sunlab/common/data/dataset.py b/code/sunlab/common/data/dataset.py
new file mode 100644
index 0000000..8589abf
--- /dev/null
+++ b/code/sunlab/common/data/dataset.py
@@ -0,0 +1,255 @@
+from .dataset_iterator import DatasetIterator
+
+
+class Dataset:
+ """# Dataset Superclass"""
+
+ base_scale = 10.0
+
+ def __init__(
+ self,
+ dataset_filename,
+ data_columns=[],
+ label_columns=[],
+ batch_size=None,
+ shuffle=False,
+ val_split=0.0,
+ scaler=None,
+ sort_columns=None,
+ random_seed=4332,
+ pre_scale=10.0,
+ **kwargs
+ ):
+ """# Initialize Dataset
+ self.dataset = dataset (N, ...)
+ self.labels = labels (N, ...)
+
+ Optional Arguments:
+ - prescale_function: The function that takes the ratio and transforms
+ the dataset by multiplying the prescale_function output
+ - sort_columns: The columns to sort the data by initially
+ - equal_split: If the classifications should be equally split in
+ training"""
+ from pandas import read_csv
+ from numpy import array, all
+ from numpy.random import seed
+
+        if random_seed is not None:
+ seed(random_seed)
+
+ # Basic Dataset Information
+ self.data_columns = data_columns
+ self.label_columns = label_columns
+ self.source = dataset_filename
+ self.dataframe = read_csv(self.source)
+
+ # Pre-scaling Transformation
+ prescale_ratio = self.base_scale / pre_scale
+ ratio = prescale_ratio
+ prescale_powers = array([2, 1, 1, 0, 2, 1, 0, 0, 1, 1, 1, 1, 1])
+ if "prescale_function" in kwargs.keys():
+ prescale_function = kwargs["prescale_function"]
+ else:
+
+ def prescale_function(x):
+ return x**prescale_powers
+
+ self.prescale_function = prescale_function
+ self.prescale_factor = self.prescale_function(ratio)
+ assert (
+ len(data_columns) == self.prescale_factor.shape[0]
+ ), "Column Mismatch on Prescale"
+ self.original_scale = pre_scale
+
+ # Scaling Transformation
+ self.scaled = scaler is not None
+ self.scaler = scaler
+
+ # Training Dataset Information
+        self.do_split = val_split != 0.0
+ self.validation_split = val_split
+ self.batch_size = batch_size
+ self.do_shuffle = shuffle
+ self.equal_split = False
+ if "equal_split" in kwargs.keys():
+ self.equal_split = kwargs["equal_split"]
+
+ # Classification Labels if they exist
+ self.dataset = self.dataframe[self.data_columns].to_numpy()
+ if len(self.label_columns) == 0:
+ self.labels = None
+ elif not all([column in self.dataframe.columns for column in label_columns]):
+ import warnings
+
+ warnings.warn(
+ "No classification labels found for the dataset", RuntimeWarning
+ )
+ self.labels = None
+ else:
+            self.labels = self.dataframe[self.label_columns].to_numpy().squeeze()
+
+ # Initialize the dataset
+ if "sort_columns" in kwargs.keys():
+ self.sort(kwargs["sort_columns"])
+ if self.do_shuffle:
+ self.shuffle()
+ if self.do_split:
+ self.split()
+ self.refresh_dataset()
+
+ def __len__(self):
+ """# Get how many cases are in the dataset"""
+ return self.dataset.shape[0]
+
+ def __getitem__(self, idx):
+ """# Make Dataset Sliceable"""
+ idx_slice = None
+ slice_stride = 1 if self.batch_size is None else self.batch_size
+ # If we pass a slice, return the slice
+ if type(idx) == slice:
+ idx_slice = idx
+ # If we pass an int, return a batch-size slice
+ else:
+ idx_slice = slice(
+ idx * slice_stride, min([len(self), (idx + 1) * slice_stride])
+ )
+ if self.labels is None:
+ return self.dataset[idx_slice, ...]
+ return self.dataset[idx_slice, ...], self.labels[idx_slice, ...]
+
+ def scale_data(self, data):
+ """# Scale dataset from scaling function"""
+ data = data * self.prescale_factor
+        if self.scaler is not None:
+ data = self.scaler(data)
+ return data
+
+ def scale(self):
+ """# Scale Dataset"""
+ self.dataset = self.scale_data(self.dataset)
+
+ def refresh_dataset(self, dataframe=None):
+ """# Refresh Dataset
+
+ Regenerate the dataset from a dataframe.
+ Primarily used after a sort or filter."""
+ if dataframe is None:
+ dataframe = self.dataframe
+ self.dataset = dataframe[self.data_columns].to_numpy()
+ if self.labels is not None:
+ self.labels = dataframe[self.label_columns].to_numpy().squeeze()
+ self.scale()
+
+ def sort_on(self, columns):
+ """# Sort Dataset on Column(s)"""
+ from numpy import all
+
+ if type(columns) == str:
+ columns = [columns]
+ if columns is not None:
+ assert all(
+ [column in self.dataframe.columns for column in columns]
+ ), "Dataframe does not contain some provided columns!"
+ self.dataframe = self.dataframe.sort_values(by=columns)
+ self.refresh_dataset()
+
+ def filter_on(self, column, value):
+ """# Filter Dataset on Column Value(s)"""
+ assert column in self.dataframe.columns, "Column DNE"
+ self.working_dataset = self.dataframe[self.dataframe[column].isin(value)]
+ self.refresh_dataset(self.working_dataset)
+
+ def filter_off(self):
+ """# Remove any filter on the dataset"""
+ self.refresh_dataset()
+
+ def unique(self, column):
+ """# Get unique values in a column(s)"""
+ assert column in self.dataframe.columns, "Column DNE"
+ from numpy import unique
+
+ return unique(self.dataframe[column])
+
+ def shuffle_data(self, data, labels=None):
+ """# Shuffle a dataset"""
+ from numpy.random import permutation
+
+ shuffled = permutation(data.shape[0])
+ if labels is not None:
+ assert (
+                labels.shape[0] == data.shape[0]
+ ), "Dataset and Label Shape Mismatch"
+ shuf_data = data[shuffled, ...]
+ shuf_labels = labels[shuffled]
+ if len(labels.shape) > 1:
+                shuf_labels = labels[shuffled, ...]
+ return shuf_data, shuf_labels
+ return data[shuffled, ...]
+
+ def shuffle(self):
+ """# Shuffle the dataset"""
+        if self.do_shuffle and self.labels is None:
+            self.dataset = self.shuffle_data(self.dataset)
+        elif self.do_shuffle:
+            self.dataset, self.labels = self.shuffle_data(self.dataset, self.labels)
+
+ def split(self):
+ """# Training/ Validation Splitting"""
+ from numpy import floor, unique, where, hstack, delete
+ from numpy.random import permutation
+
+ equal_classes = self.equal_split
+ if not self.do_split:
+ return
+ assert self.validation_split <= 1.0, "Too High"
+ assert self.validation_split > 0.0, "Too Low"
+ train_count = int(floor(self.dataset.shape[0] * (1 - self.validation_split)))
+ training_data = self.dataset[:train_count, ...]
+ training_labels = None
+ validation_data = self.dataset[train_count:, ...]
+ validation_labels = None
+ if self.labels is not None:
+ if equal_classes:
+ # Ensure the split balances the prevalence of each class
+ assert len(self.labels.shape) == 1, "1D Classification Only Currently"
+                # Cap the training count so every class can be represented equally
+                classes, class_counts = unique(self.labels, return_counts=True)
+                train_count = min(
+                    [
+                        train_count,
+                        classes.shape[0] * min(class_counts),
+                    ]
+                )
+                class_size = train_count // classes.shape[0]
+                class_indices = [
+                    permutation(where(self.labels == _class)[0])
+                    for _class in classes
+                ]
+                class_indices = [indexes[:class_size] for indexes in class_indices]
+                train_class_indices = hstack(class_indices).squeeze()
+                train_class_indices = permutation(train_class_indices)
+                training_data = self.dataset[train_class_indices, ...]
+                training_labels = self.labels[train_class_indices]
+                if len(self.labels.shape) > 1:
+                    training_labels = self.labels[train_class_indices, ...]
+                validation_data = delete(self.dataset, train_class_indices, axis=0)
+                validation_labels = delete(
+                    self.labels, train_class_indices, axis=0
+                ).squeeze()
+ else:
+ training_labels = self.labels[:train_count]
+ if len(training_labels.shape) > 1:
+ training_labels = self.labels[:train_count, ...]
+ validation_labels = self.labels[train_count:]
+ if len(validation_labels.shape) > 1:
+ validation_labels = self.labels[train_count:, ...]
+ self.training_data = training_data
+ self.validation_data = validation_data
+ self.training = DatasetIterator(training_data, training_labels, self.batch_size)
+ self.validation = DatasetIterator(
+ validation_data, validation_labels, self.batch_size
+ )
+
+ def reset_iterators(self):
+ """# Reset Train/ Validation Iterators"""
+ self.split()
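
A minimal usage sketch of the Dataset class. The CSV path, column names, and prescale function below are hypothetical; note that the built-in prescale powers assume 13 data columns, so a custom prescale_function is supplied for this two-column example:

    import numpy as np
    from sunlab.common.data.dataset import Dataset

    ds = Dataset(
        "cells.csv",                         # hypothetical CSV on disk
        data_columns=["Area", "Perimeter"],
        label_columns=["Class"],
        batch_size=32,
        shuffle=True,
        val_split=0.2,
        prescale_function=lambda r: np.array([r**2, r]),  # one factor per column
    )
    print(len(ds))                           # number of rows
    data, labels = ds[0]                     # first batch of 32 rows (scaled)
    for batch_data, batch_labels in ds.training:
        pass                                 # one pass over the training split
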
diff --git a/code/sunlab/common/data/dataset_iterator.py b/code/sunlab/common/data/dataset_iterator.py
new file mode 100644
index 0000000..7c91caa
--- /dev/null
+++ b/code/sunlab/common/data/dataset_iterator.py
@@ -0,0 +1,34 @@
+class DatasetIterator:
+ """# Dataset Iterator
+
+ Creates an iterator object on a dataset and labels"""
+
+ def __init__(self, dataset, labels=None, batch_size=None):
+ """# Initialize the iterator with the dataset and labels
+
+ - batch_size: How many to include in the iteration"""
+ self.dataset = dataset
+ self.labels = labels
+ self.current = 0
+ self.batch_size = (
+ batch_size if batch_size is not None else self.dataset.shape[0]
+ )
+
+ def __iter__(self):
+ """# Iterator Function"""
+ return self
+
+ def __next__(self):
+ """# Next Iteration
+
+ Slice the dataset and labels to provide"""
+ self.cur = self.current
+ self.current += 1
+ if self.cur * self.batch_size < self.dataset.shape[0]:
+ iterator_slice = slice(
+ self.cur * self.batch_size, (self.cur + 1) * self.batch_size
+ )
+ if self.labels is None:
+ return self.dataset[iterator_slice, ...]
+ return self.dataset[iterator_slice, ...], self.labels[iterator_slice, ...]
+ raise StopIteration
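
A standalone sketch of how DatasetIterator batches an array; the synthetic data below stand in for real features and labels. Note the iterator is single-pass: __iter__ never resets self.current, so Dataset.reset_iterators() rebuilds fresh iterators when another pass is needed:

    import numpy as np
    from sunlab.common.data.dataset_iterator import DatasetIterator

    data = np.arange(20).reshape(10, 2)   # 10 rows, 2 features
    labels = np.arange(10)                # one label per row
    it = DatasetIterator(data, labels, batch_size=4)
    for batch_data, batch_labels in it:   # yields batches of 4, 4, and 2 rows
        print(batch_data.shape, batch_labels.shape)
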
diff --git a/code/sunlab/common/data/image_dataset.py b/code/sunlab/common/data/image_dataset.py
new file mode 100644
index 0000000..46e77b6
--- /dev/null
+++ b/code/sunlab/common/data/image_dataset.py
@@ -0,0 +1,75 @@
+class ImageDataset:
+ def __init__(
+ self,
+ base_directory,
+ ext="png",
+ channels=[0],
+ batch_size=None,
+ shuffle=False,
+ rotate=False,
+ rotate_p=1.,
+ ):
+ """# Image Dataset
+
+ Load a directory of images"""
+ from glob import glob
+ from matplotlib.pyplot import imread
+ from numpy import newaxis, vstack
+ from numpy.random import permutation, rand
+
+ self.base_directory = base_directory
+ files = glob(self.base_directory + "*." + ext)
+ self.dataset = []
+ for file in files:
+ im = imread(file)[newaxis, :, :, channels].transpose(0, 3, 1, 2)
+ self.dataset.append(im)
+            # Also add flipped/transposed copies of the image to the dataset
+ if rotate:
+ if rand() < rotate_p:
+ self.dataset.append(im[:, :, ::-1, :])
+ if rand() < rotate_p:
+ self.dataset.append(im[:, :, :, ::-1])
+ if rand() < rotate_p:
+ self.dataset.append(im[:, :, ::-1, ::-1])
+ if rand() < rotate_p:
+ self.dataset.append(im.transpose(0, 1, 3, 2))
+ if rand() < rotate_p:
+ self.dataset.append(im.transpose(0, 1, 3, 2)[:, :, ::-1, :])
+ if rand() < rotate_p:
+ self.dataset.append(im.transpose(0, 1, 3, 2)[:, :, :, ::-1])
+ if rand() < rotate_p:
+ self.dataset.append(im.transpose(0, 1, 3, 2)[:, :, ::-1, ::-1])
+ self.dataset = vstack(self.dataset)
+ if shuffle:
+ self.dataset = self.dataset[permutation(self.dataset.shape[0]), ...]
+ self.batch_size = (
+ batch_size if batch_size is not None else self.dataset.shape[0]
+ )
+
+ def torch(self, device=None):
+ """# Cast to Torch Tensor"""
+ import torch
+
+ if device is None:
+ device = torch.device("cpu")
+ return torch.tensor(self.dataset).to(device)
+
+ def numpy(self):
+ """# Cast to Numpy Array"""
+ return self.dataset
+
+ def __len__(self):
+ """# Return Number of Cases
+
+ (or Number in each Batch)"""
+ return self.dataset.shape[0] // self.batch_size
+
+ def __getitem__(self, index):
+ """# Slice By Batch"""
+ if type(index) == tuple:
+ return self.dataset[index]
+ elif type(index) == int:
+ return self.dataset[
+ index * self.batch_size : (index + 1) * self.batch_size, ...
+ ]
+ return
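
A usage sketch for ImageDataset. The directory is hypothetical; every image is expected to load via matplotlib's imread as an (H, W, C) array of identical size, since the copies are stacked with vstack, and torch() assumes PyTorch is installed:

    from sunlab.common.data.image_dataset import ImageDataset

    # loads ./masks/*.png, keeps channel 0, and appends flipped/transposed copies
    images = ImageDataset("masks/", ext="png", channels=[0],
                          batch_size=16, shuffle=True, rotate=True)
    print(len(images))       # number of whole batches
    batch = images[0]        # array of shape (16, 1, H, W)
    tensor = images.torch()  # entire dataset as a CPU torch tensor
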
diff --git a/code/sunlab/common/data/shape_dataset.py b/code/sunlab/common/data/shape_dataset.py
new file mode 100644
index 0000000..5a68736
--- /dev/null
+++ b/code/sunlab/common/data/shape_dataset.py
@@ -0,0 +1,57 @@
+from .dataset import Dataset
+
+
+class ShapeDataset(Dataset):
+ """# Shape Dataset"""
+
+ def __init__(
+ self,
+ dataset_filename,
+ data_columns=[
+ "Area",
+ "MjrAxisLength",
+ "MnrAxisLength",
+ "Eccentricity",
+ "ConvexArea",
+ "EquivDiameter",
+ "Solidity",
+ "Extent",
+ "Perimeter",
+ "ConvexPerim",
+ "FibLen",
+ "InscribeR",
+ "BlebLen",
+ ],
+ label_columns=["Class"],
+ batch_size=None,
+ shuffle=False,
+ val_split=0.0,
+ scaler=None,
+ sort_columns=None,
+ random_seed=4332,
+ pre_scale=10,
+ **kwargs
+ ):
+ """# Initialize Dataset
+ self.dataset = dataset (N, ...)
+ self.labels = labels (N, ...)
+
+ Optional Arguments:
+ - prescale_function: The function that takes the ratio and transforms
+ the dataset by multiplying the prescale_function output
+ - sort_columns: The columns to sort the data by initially
+ - equal_split: If the classifications should be equally split in
+ training"""
+ super().__init__(
+ dataset_filename,
+ data_columns=data_columns,
+ label_columns=label_columns,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ val_split=val_split,
+ scaler=scaler,
+ sort_columns=sort_columns,
+ random_seed=random_seed,
+ pre_scale=pre_scale,
+ **kwargs
+ )
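
Because ShapeDataset pins the 13 morphology columns and the "Class" label, a typical load only needs the CSV path and the magnification it was measured at. A sketch with a hypothetical file:

    from sunlab.common.data.shape_dataset import ShapeDataset

    # 20x data: pre_scale=20 rescales the shape features toward the 10x baseline
    shapes = ShapeDataset("shapes_20x.csv", pre_scale=20,
                          batch_size=64, shuffle=True, val_split=0.1)
    train_iter = shapes.training    # DatasetIterator over the training split
    val_iter = shapes.validation    # DatasetIterator over the held-out split
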
diff --git a/code/sunlab/common/data/utilities.py b/code/sunlab/common/data/utilities.py
new file mode 100644
index 0000000..6b4e6f3
--- /dev/null
+++ b/code/sunlab/common/data/utilities.py
@@ -0,0 +1,119 @@
+from .shape_dataset import ShapeDataset
+from ..scaler.max_abs_scaler import MaxAbsScaler
+
+
+def import_10x(
+ filename,
+ magnification=10,
+ batch_size=None,
+ shuffle=False,
+ val_split=0.0,
+ scaler=None,
+ sort_columns=None,
+):
+ """# Import a 10x Image Dataset
+
+ Pixel-to-micron: ???"""
+ magnification = 10
+ dataset = ShapeDataset(
+ filename,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ pre_scale=magnification,
+ val_split=val_split,
+ scaler=scaler,
+ sort_columns=sort_columns,
+ )
+ return dataset
+
+
+def import_20x(
+ filename,
+    magnification=20,
+ batch_size=None,
+ shuffle=False,
+ val_split=0.0,
+ scaler=None,
+ sort_columns=None,
+):
+ """# Import a 20x Image Dataset
+
+ Pixel-to-micron: ???"""
+ magnification = 20
+ dataset = ShapeDataset(
+ filename,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ pre_scale=magnification,
+ val_split=val_split,
+ scaler=scaler,
+ sort_columns=sort_columns,
+ )
+ return dataset
+
+
+def import_dataset(
+ filename,
+ magnification,
+ batch_size=None,
+ shuffle=False,
+ val_split=0.0,
+ scaler=None,
+ sort_columns=None,
+):
+ """# Import a dataset
+
+    Requires a magnification to be specified"""
+ dataset = ShapeDataset(
+ filename,
+ pre_scale=magnification,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ val_split=val_split,
+ scaler=scaler,
+ sort_columns=sort_columns,
+ )
+ return dataset
+
+
+def import_full_dataset(fname, magnification=20, scaler=None):
+ """# Import a Full Dataset
+
+    If a matching classification file exists (a .txt with 'frame', 'cellnumber', and per-class columns), also import it"""
+ from os.path import isfile
+ import pandas as pd
+ import numpy as np
+
+ cfname = fname
+ tfname = cfname[:-3] + "txt"
+ columns = [
+ "frame",
+ "cellnumber",
+ "x-cent",
+ "y-cent",
+ "actinedge",
+ "filopodia",
+ "bleb",
+ "lamellipodia",
+ ]
+ if isfile(tfname):
+ dataset = import_dataset(cfname, magnification=magnification, scaler=scaler)
+ class_df = np.loadtxt(tfname, skiprows=1)
+ class_df = pd.DataFrame(class_df, columns=columns)
+ full_df = pd.merge(
+ dataset.dataframe,
+ class_df,
+ left_on=["Frames", "CellNum"],
+ right_on=["frame", "cellnumber"],
+ )
+ full_df["Class"] = np.argmax(
+        full_df[["actinedge", "filopodia", "bleb", "lamellipodia"]].to_numpy(),
+ axis=-1,
+ )
+ dataset.labels = full_df["Class"].to_numpy()
+ else:
+ dataset = import_dataset(cfname, magnification=magnification, scaler=scaler)
+ full_df = dataset.dataframe
+ dataset.dataframe = full_df
+ dataset.filter_off()
+ return dataset
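
As a rough end-to-end sketch, importing a dataset together with its classification file might look like the following. File names are hypothetical; the CSV needs the 13 shape columns plus 'Frames' and 'CellNum', and the .txt sitting next to it carries the per-class scores that get argmax-ed into a 'Class' label:

    from sunlab.common.data.utilities import import_full_dataset

    # looks for cells.txt beside cells.csv and merges its class assignments
    dataset = import_full_dataset("cells.csv", magnification=20)
    print(dataset.dataset.shape)
    print(None if dataset.labels is None else dataset.labels.shape)
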