diff options
author | Christian C <cc@localhost> | 2024-11-11 12:29:32 -0800 |
---|---|---|
committer | Christian C <cc@localhost> | 2024-11-11 12:29:32 -0800 |
commit | b85ee9d64a536937912544c7bbd5b98b635b7e8d (patch) | |
tree | cef7bc17d7b29f40fc6b1867d0ce0a742d5583d0 /code/sunlab/common |
Initial commit
Diffstat (limited to 'code/sunlab/common')
28 files changed, 1473 insertions, 0 deletions
diff --git a/code/sunlab/common/__init__.py b/code/sunlab/common/__init__.py new file mode 100644 index 0000000..cb6716c --- /dev/null +++ b/code/sunlab/common/__init__.py @@ -0,0 +1,5 @@ +from .data import * +from .distribution import * +from .scaler import * +from .mathlib import * +from .plotting import * diff --git a/code/sunlab/common/data/__init__.py b/code/sunlab/common/data/__init__.py new file mode 100644 index 0000000..3e26874 --- /dev/null +++ b/code/sunlab/common/data/__init__.py @@ -0,0 +1,6 @@ +from .basic import * +from .dataset import * +from .dataset_iterator import * +from .shape_dataset import * +from .image_dataset import * +from .utilities import * diff --git a/code/sunlab/common/data/basic.py b/code/sunlab/common/data/basic.py new file mode 100644 index 0000000..bb2e912 --- /dev/null +++ b/code/sunlab/common/data/basic.py @@ -0,0 +1,6 @@ +import numpy + + +numpy.load_dat = lambda *args, **kwargs: numpy.load( + *args, **kwargs, allow_pickle=True +).item() diff --git a/code/sunlab/common/data/dataset.py b/code/sunlab/common/data/dataset.py new file mode 100644 index 0000000..8589abf --- /dev/null +++ b/code/sunlab/common/data/dataset.py @@ -0,0 +1,255 @@ +from .dataset_iterator import DatasetIterator + + +class Dataset: + """# Dataset Superclass""" + + base_scale = 10.0 + + def __init__( + self, + dataset_filename, + data_columns=[], + label_columns=[], + batch_size=None, + shuffle=False, + val_split=0.0, + scaler=None, + sort_columns=None, + random_seed=4332, + pre_scale=10.0, + **kwargs + ): + """# Initialize Dataset + self.dataset = dataset (N, ...) + self.labels = labels (N, ...) + + Optional Arguments: + - prescale_function: The function that takes the ratio and transforms + the dataset by multiplying the prescale_function output + - sort_columns: The columns to sort the data by initially + - equal_split: If the classifications should be equally split in + training""" + from pandas import read_csv + from numpy import array, all + from numpy.random import seed + + if seed is not None: + seed(random_seed) + + # Basic Dataset Information + self.data_columns = data_columns + self.label_columns = label_columns + self.source = dataset_filename + self.dataframe = read_csv(self.source) + + # Pre-scaling Transformation + prescale_ratio = self.base_scale / pre_scale + ratio = prescale_ratio + prescale_powers = array([2, 1, 1, 0, 2, 1, 0, 0, 1, 1, 1, 1, 1]) + if "prescale_function" in kwargs.keys(): + prescale_function = kwargs["prescale_function"] + else: + + def prescale_function(x): + return x**prescale_powers + + self.prescale_function = prescale_function + self.prescale_factor = self.prescale_function(ratio) + assert ( + len(data_columns) == self.prescale_factor.shape[0] + ), "Column Mismatch on Prescale" + self.original_scale = pre_scale + + # Scaling Transformation + self.scaled = scaler is not None + self.scaler = scaler + + # Training Dataset Information + self.do_split = False if val_split == 0.0 else True + self.validation_split = val_split + self.batch_size = batch_size + self.do_shuffle = shuffle + self.equal_split = False + if "equal_split" in kwargs.keys(): + self.equal_split = kwargs["equal_split"] + + # Classification Labels if they exist + self.dataset = self.dataframe[self.data_columns].to_numpy() + if len(self.label_columns) == 0: + self.labels = None + elif not all([column in self.dataframe.columns for column in label_columns]): + import warnings + + warnings.warn( + "No classification labels found for the dataset", RuntimeWarning + ) + self.labels = None + else: + self.labels = self.dataframe[self.label_columns].squeeze() + + # Initialize the dataset + if "sort_columns" in kwargs.keys(): + self.sort(kwargs["sort_columns"]) + if self.do_shuffle: + self.shuffle() + if self.do_split: + self.split() + self.refresh_dataset() + + def __len__(self): + """# Get how many cases are in the dataset""" + return self.dataset.shape[0] + + def __getitem__(self, idx): + """# Make Dataset Sliceable""" + idx_slice = None + slice_stride = 1 if self.batch_size is None else self.batch_size + # If we pass a slice, return the slice + if type(idx) == slice: + idx_slice = idx + # If we pass an int, return a batch-size slice + else: + idx_slice = slice( + idx * slice_stride, min([len(self), (idx + 1) * slice_stride]) + ) + if self.labels is None: + return self.dataset[idx_slice, ...] + return self.dataset[idx_slice, ...], self.labels[idx_slice, ...] + + def scale_data(self, data): + """# Scale dataset from scaling function""" + data = data * self.prescale_factor + if not (self.scaler is None): + data = self.scaler(data) + return data + + def scale(self): + """# Scale Dataset""" + self.dataset = self.scale_data(self.dataset) + + def refresh_dataset(self, dataframe=None): + """# Refresh Dataset + + Regenerate the dataset from a dataframe. + Primarily used after a sort or filter.""" + if dataframe is None: + dataframe = self.dataframe + self.dataset = dataframe[self.data_columns].to_numpy() + if self.labels is not None: + self.labels = dataframe[self.label_columns].to_numpy().squeeze() + self.scale() + + def sort_on(self, columns): + """# Sort Dataset on Column(s)""" + from numpy import all + + if type(columns) == str: + columns = [columns] + if columns is not None: + assert all( + [column in self.dataframe.columns for column in columns] + ), "Dataframe does not contain some provided columns!" + self.dataframe = self.dataframe.sort_values(by=columns) + self.refresh_dataset() + + def filter_on(self, column, value): + """# Filter Dataset on Column Value(s)""" + assert column in self.dataframe.columns, "Column DNE" + self.working_dataset = self.dataframe[self.dataframe[column].isin(value)] + self.refresh_dataset(self.working_dataset) + + def filter_off(self): + """# Remove any filter on the dataset""" + self.refresh_dataset() + + def unique(self, column): + """# Get unique values in a column(s)""" + assert column in self.dataframe.columns, "Column DNE" + from numpy import unique + + return unique(self.dataframe[column]) + + def shuffle_data(self, data, labels=None): + """# Shuffle a dataset""" + from numpy.random import permutation + + shuffled = permutation(data.shape[0]) + if labels is not None: + assert ( + self.labels.shape[0] == self.dataset.shape[0] + ), "Dataset and Label Shape Mismatch" + shuf_data = data[shuffled, ...] + shuf_labels = labels[shuffled] + if len(labels.shape) > 1: + shuf_labels = labels[shuffled,...] + return shuf_data, shuf_labels + return data[shuffled, ...] + + def shuffle(self): + """# Shuffle the dataset""" + if self.do_shuffle: + if self.labels is None: + self.dataset = self.shuffle_data(self.dataset) + self.dataset, self.labels = self.shuffle_data(self.dataset, self.labels) + + def split(self): + """# Training/ Validation Splitting""" + from numpy import floor, unique, where, hstack, delete + from numpy.random import permutation + + equal_classes = self.equal_split + if not self.do_split: + return + assert self.validation_split <= 1.0, "Too High" + assert self.validation_split > 0.0, "Too Low" + train_count = int(floor(self.dataset.shape[0] * (1 - self.validation_split))) + training_data = self.dataset[:train_count, ...] + training_labels = None + validation_data = self.dataset[train_count:, ...] + validation_labels = None + if self.labels is not None: + if equal_classes: + # Ensure the split balances the prevalence of each class + assert len(self.labels.shape) == 1, "1D Classification Only Currently" + classification_breakdown = unique(self.labels, return_counts=True) + train_count = min( + [ + train_count, + classification_breakdown.shape[0] + * min(classification_breakdown[1]), + ] + ) + class_size = train_count / classification_breakdown.shape[0] + class_indicies = [ + permutation(where(self.labels == _class)[0]) + for _class in classification_breakdown[0] + ] + class_indicies = [indexes[:class_size] for indexes in class_indicies] + train_class_indicies = hstack(class_indicies).squeeze() + train_class_indicies = permutation(train_class_indicies) + training_data = self.dataset[train_class_indicies, ...] + training_labels = self.labels[train_class_indicies] + if len(self.labels.shape) > 1: + training_labels = self.labels[train_class_indicies,...] + validation_data = delete(self.dataset, train_class_indicies, axis=0) + validation_labels = delete( + self.labels, train_class_indicies, axis=0 + ).squeeze() + else: + training_labels = self.labels[:train_count] + if len(training_labels.shape) > 1: + training_labels = self.labels[:train_count, ...] + validation_labels = self.labels[train_count:] + if len(validation_labels.shape) > 1: + validation_labels = self.labels[train_count:, ...] + self.training_data = training_data + self.validation_data = validation_data + self.training = DatasetIterator(training_data, training_labels, self.batch_size) + self.validation = DatasetIterator( + validation_data, validation_labels, self.batch_size + ) + + def reset_iterators(self): + """# Reset Train/ Validation Iterators""" + self.split() diff --git a/code/sunlab/common/data/dataset_iterator.py b/code/sunlab/common/data/dataset_iterator.py new file mode 100644 index 0000000..7c91caa --- /dev/null +++ b/code/sunlab/common/data/dataset_iterator.py @@ -0,0 +1,34 @@ +class DatasetIterator: + """# Dataset Iterator + + Creates an iterator object on a dataset and labels""" + + def __init__(self, dataset, labels=None, batch_size=None): + """# Initialize the iterator with the dataset and labels + + - batch_size: How many to include in the iteration""" + self.dataset = dataset + self.labels = labels + self.current = 0 + self.batch_size = ( + batch_size if batch_size is not None else self.dataset.shape[0] + ) + + def __iter__(self): + """# Iterator Function""" + return self + + def __next__(self): + """# Next Iteration + + Slice the dataset and labels to provide""" + self.cur = self.current + self.current += 1 + if self.cur * self.batch_size < self.dataset.shape[0]: + iterator_slice = slice( + self.cur * self.batch_size, (self.cur + 1) * self.batch_size + ) + if self.labels is None: + return self.dataset[iterator_slice, ...] + return self.dataset[iterator_slice, ...], self.labels[iterator_slice, ...] + raise StopIteration diff --git a/code/sunlab/common/data/image_dataset.py b/code/sunlab/common/data/image_dataset.py new file mode 100644 index 0000000..46e77b6 --- /dev/null +++ b/code/sunlab/common/data/image_dataset.py @@ -0,0 +1,75 @@ +class ImageDataset: + def __init__( + self, + base_directory, + ext="png", + channels=[0], + batch_size=None, + shuffle=False, + rotate=False, + rotate_p=1., + ): + """# Image Dataset + + Load a directory of images""" + from glob import glob + from matplotlib.pyplot import imread + from numpy import newaxis, vstack + from numpy.random import permutation, rand + + self.base_directory = base_directory + files = glob(self.base_directory + "*." + ext) + self.dataset = [] + for file in files: + im = imread(file)[newaxis, :, :, channels].transpose(0, 3, 1, 2) + self.dataset.append(im) + # Also add rotations of the image to the dataset + if rotate: + if rand() < rotate_p: + self.dataset.append(im[:, :, ::-1, :]) + if rand() < rotate_p: + self.dataset.append(im[:, :, :, ::-1]) + if rand() < rotate_p: + self.dataset.append(im[:, :, ::-1, ::-1]) + if rand() < rotate_p: + self.dataset.append(im.transpose(0, 1, 3, 2)) + if rand() < rotate_p: + self.dataset.append(im.transpose(0, 1, 3, 2)[:, :, ::-1, :]) + if rand() < rotate_p: + self.dataset.append(im.transpose(0, 1, 3, 2)[:, :, :, ::-1]) + if rand() < rotate_p: + self.dataset.append(im.transpose(0, 1, 3, 2)[:, :, ::-1, ::-1]) + self.dataset = vstack(self.dataset) + if shuffle: + self.dataset = self.dataset[permutation(self.dataset.shape[0]), ...] + self.batch_size = ( + batch_size if batch_size is not None else self.dataset.shape[0] + ) + + def torch(self, device=None): + """# Cast to Torch Tensor""" + import torch + + if device is None: + device = torch.device("cpu") + return torch.tensor(self.dataset).to(device) + + def numpy(self): + """# Cast to Numpy Array""" + return self.dataset + + def __len__(self): + """# Return Number of Cases + + (or Number in each Batch)""" + return self.dataset.shape[0] // self.batch_size + + def __getitem__(self, index): + """# Slice By Batch""" + if type(index) == tuple: + return self.dataset[index] + elif type(index) == int: + return self.dataset[ + index * self.batch_size : (index + 1) * self.batch_size, ... + ] + return diff --git a/code/sunlab/common/data/shape_dataset.py b/code/sunlab/common/data/shape_dataset.py new file mode 100644 index 0000000..5a68736 --- /dev/null +++ b/code/sunlab/common/data/shape_dataset.py @@ -0,0 +1,57 @@ +from .dataset import Dataset + + +class ShapeDataset(Dataset): + """# Shape Dataset""" + + def __init__( + self, + dataset_filename, + data_columns=[ + "Area", + "MjrAxisLength", + "MnrAxisLength", + "Eccentricity", + "ConvexArea", + "EquivDiameter", + "Solidity", + "Extent", + "Perimeter", + "ConvexPerim", + "FibLen", + "InscribeR", + "BlebLen", + ], + label_columns=["Class"], + batch_size=None, + shuffle=False, + val_split=0.0, + scaler=None, + sort_columns=None, + random_seed=4332, + pre_scale=10, + **kwargs + ): + """# Initialize Dataset + self.dataset = dataset (N, ...) + self.labels = labels (N, ...) + + Optional Arguments: + - prescale_function: The function that takes the ratio and transforms + the dataset by multiplying the prescale_function output + - sort_columns: The columns to sort the data by initially + - equal_split: If the classifications should be equally split in + training""" + super().__init__( + dataset_filename, + data_columns=data_columns, + label_columns=label_columns, + batch_size=batch_size, + shuffle=shuffle, + val_split=val_split, + scaler=scaler, + sort_columns=sort_columns, + random_seed=random_seed, + pre_scale=pre_scale, + **kwargs + ) diff --git a/code/sunlab/common/data/utilities.py b/code/sunlab/common/data/utilities.py new file mode 100644 index 0000000..6b4e6f3 --- /dev/null +++ b/code/sunlab/common/data/utilities.py @@ -0,0 +1,119 @@ +from .shape_dataset import ShapeDataset +from ..scaler.max_abs_scaler import MaxAbsScaler + + +def import_10x( + filename, + magnification=10, + batch_size=None, + shuffle=False, + val_split=0.0, + scaler=None, + sort_columns=None, +): + """# Import a 10x Image Dataset + + Pixel-to-micron: ???""" + magnification = 10 + dataset = ShapeDataset( + filename, + batch_size=batch_size, + shuffle=shuffle, + pre_scale=magnification, + val_split=val_split, + scaler=scaler, + sort_columns=sort_columns, + ) + return dataset + + +def import_20x( + filename, + magnification=10, + batch_size=None, + shuffle=False, + val_split=0.0, + scaler=None, + sort_columns=None, +): + """# Import a 20x Image Dataset + + Pixel-to-micron: ???""" + magnification = 20 + dataset = ShapeDataset( + filename, + batch_size=batch_size, + shuffle=shuffle, + pre_scale=magnification, + val_split=val_split, + scaler=scaler, + sort_columns=sort_columns, + ) + return dataset + + +def import_dataset( + filename, + magnification, + batch_size=None, + shuffle=False, + val_split=0.0, + scaler=None, + sort_columns=None, +): + """# Import a dataset + + Requires a magnificaiton to be specified""" + dataset = ShapeDataset( + filename, + pre_scale=magnification, + batch_size=batch_size, + shuffle=shuffle, + val_split=val_split, + scaler=scaler, + sort_columns=sort_columns, + ) + return dataset + + +def import_full_dataset(fname, magnification=20, scaler=None): + """# Import a Full Dataset + + If a classification file exists(.txt with a 'Class' header and 'frame','cellnumber' headers), also import it""" + from os.path import isfile + import pandas as pd + import numpy as np + + cfname = fname + tfname = cfname[:-3] + "txt" + columns = [ + "frame", + "cellnumber", + "x-cent", + "y-cent", + "actinedge", + "filopodia", + "bleb", + "lamellipodia", + ] + if isfile(tfname): + dataset = import_dataset(cfname, magnification=magnification, scaler=scaler) + class_df = np.loadtxt(tfname, skiprows=1) + class_df = pd.DataFrame(class_df, columns=columns) + full_df = pd.merge( + dataset.dataframe, + class_df, + left_on=["Frames", "CellNum"], + right_on=["frame", "cellnumber"], + ) + full_df["Class"] = np.argmax( + class_df[["actinedge", "filopodia", "bleb", "lamellipodia"]].to_numpy(), + axis=-1, + ) + dataset.labels = full_df["Class"].to_numpy() + else: + dataset = import_dataset(cfname, magnification=magnification, scaler=scaler) + full_df = dataset.dataframe + dataset.dataframe = full_df + dataset.filter_off() + return dataset diff --git a/code/sunlab/common/distribution/__init__.py b/code/sunlab/common/distribution/__init__.py new file mode 100644 index 0000000..a23cb0c --- /dev/null +++ b/code/sunlab/common/distribution/__init__.py @@ -0,0 +1,7 @@ +from .gaussian_distribution import * +from .x_gaussian_distribution import * +from .o_gaussian_distribution import * +from .s_gaussian_distribution import * +from .uniform_distribution import * +from .symmetric_uniform_distribution import * +from .swiss_roll_distribution import * diff --git a/code/sunlab/common/distribution/adversarial_distribution.py b/code/sunlab/common/distribution/adversarial_distribution.py new file mode 100644 index 0000000..675c00e --- /dev/null +++ b/code/sunlab/common/distribution/adversarial_distribution.py @@ -0,0 +1,35 @@ +class AdversarialDistribution: + """# Distribution Class to use in Adversarial Autoencoder + + For any distribution to be implemented, make sure to ensure each of the + methods are implemented""" + + def __init__(self, N): + """# Initialize the distribution for N-dimensions""" + self.dims = N + return + + def get_full_name(self): + """# Return a human-readable name of the distribution""" + return self.full_name + + def get_name(self): + """# Return a shortened name of the distribution + + Preferrably, the name should include characters that can be included in + a file name""" + return self.name + + def __str__(self): + """# Returns the short name""" + return self.get_name() + + def __repr__(self): + """# Returns the short name""" + return self.get_name() + + def __call__(self, *args): + """# Magic method when calling the distribution + + This method is going to be called when you use `dist(...)`""" + raise NotImplementedError("This distribution has not been implemented yet") diff --git a/code/sunlab/common/distribution/gaussian_distribution.py b/code/sunlab/common/distribution/gaussian_distribution.py new file mode 100644 index 0000000..e478ab6 --- /dev/null +++ b/code/sunlab/common/distribution/gaussian_distribution.py @@ -0,0 +1,23 @@ +from .adversarial_distribution import * + + +class GaussianDistribution(AdversarialDistribution): + """# Gaussian Distribution""" + + def __init__(self, N): + """# Gaussian Distribution Initialization + + Initializes the name and dimensions""" + super().__init__(N) + self.full_name = f"{N}-Dimensional Gaussian Distribution" + self.name = "G" + + def __call__(self, *args): + """# Magic method when calling the distribution + + This method is going to be called when you use gauss(N1,...,Nm)""" + import numpy as np + + return np.random.multivariate_normal( + mean=np.zeros(self.dims), cov=np.eye(self.dims), size=[*args] + ) diff --git a/code/sunlab/common/distribution/o_gaussian_distribution.py b/code/sunlab/common/distribution/o_gaussian_distribution.py new file mode 100644 index 0000000..1222ca1 --- /dev/null +++ b/code/sunlab/common/distribution/o_gaussian_distribution.py @@ -0,0 +1,38 @@ +from .adversarial_distribution import * + + +class OGaussianDistribution(AdversarialDistribution): + """# O Gaussian Distribution""" + + def __init__(self, N): + """# O Gaussian Distribution Initialization + + Initializes the name and dimensions""" + super().__init__(N) + assert self.dims == 2, "This Distribution only Supports 2-Dimensions" + self.full_name = "2-Dimensional O-Gaussian Distribution" + self.name = "OG" + + def __call__(self, *args): + """# Magic method when calling the distribution + + This method is going to be called when you use xgauss(case_count)""" + import numpy as np + + assert len(args) == 1, "Only 1 argument supported" + N = args[0] + sample_base = np.zeros((4 * N, 2)) + sample_base[0 * N : (0 + 1) * N, :] = np.random.multivariate_normal( + mean=[1, 1], cov=[[1, 0], [0, 1]], size=[N] + ) + sample_base[1 * N : (1 + 1) * N, :] = np.random.multivariate_normal( + mean=[-1, -1], cov=[[1, 0], [0, 1]], size=[N] + ) + sample_base[2 * N : (2 + 1) * N, :] = np.random.multivariate_normal( + mean=[-1, 1], cov=[[1, 0], [0, 1]], size=[N] + ) + sample_base[3 * N : (3 + 1) * N, :] = np.random.multivariate_normal( + mean=[1, -1], cov=[[1, 0], [0, 1]], size=[N] + ) + np.random.shuffle(sample_base) + return sample_base[:N, :] diff --git a/code/sunlab/common/distribution/s_gaussian_distribution.py b/code/sunlab/common/distribution/s_gaussian_distribution.py new file mode 100644 index 0000000..cace57f --- /dev/null +++ b/code/sunlab/common/distribution/s_gaussian_distribution.py @@ -0,0 +1,40 @@ +from .adversarial_distribution import * + + +class SGaussianDistribution(AdversarialDistribution): + """# S Gaussian Distribution""" + + def __init__(self, N, scale=0): + """# S Gaussian Distribution Initialization + + Initializes the name and dimensions""" + super().__init__(N) + assert self.dims == 2, "This Distribution only Supports 2-Dimensions" + self.full_name = "2-Dimensional S-Gaussian Distribution" + self.name = "SG" + self.scale = scale + + def __call__(self, *args): + """# Magic method when calling the distribution + + This method is going to be called when you use xgauss(case_count)""" + import numpy as np + + assert len(args) == 1, "Only 1 argument supported" + N = args[0] + sample_base = np.zeros((4 * N, 2)) + scale = self.scale + sample_base[0 * N : (0 + 1) * N, :] = np.random.multivariate_normal( + mean=[1, 1], cov=[[1, scale], [scale, 1]], size=[N] + ) + sample_base[1 * N : (1 + 1) * N, :] = np.random.multivariate_normal( + mean=[-1, -1], cov=[[1, scale], [scale, 1]], size=[N] + ) + sample_base[2 * N : (2 + 1) * N, :] = np.random.multivariate_normal( + mean=[-1, 1], cov=[[1, -scale], [-scale, 1]], size=[N] + ) + sample_base[3 * N : (3 + 1) * N, :] = np.random.multivariate_normal( + mean=[1, -1], cov=[[1, -scale], [-scale, 1]], size=[N] + ) + np.random.shuffle(sample_base) + return sample_base[:N, :] diff --git a/code/sunlab/common/distribution/swiss_roll_distribution.py b/code/sunlab/common/distribution/swiss_roll_distribution.py new file mode 100644 index 0000000..613bfc5 --- /dev/null +++ b/code/sunlab/common/distribution/swiss_roll_distribution.py @@ -0,0 +1,42 @@ +from .adversarial_distribution import * + + +class SwissRollDistribution(AdversarialDistribution): + """# Swiss Roll Distribution""" + + def __init__(self, N, scaling_factor=0.25, noise_level=0.15): + """# Swiss Roll Distribution Initialization + + Initializes the name and dimensions""" + super().__init__(N) + assert (self.dims == 2) or ( + self.dims == 3 + ), "This Distribution only Supports 2,3-Dimensions" + self.full_name = f"{self.dims}-Dimensional Swiss Roll Distribution Distribution" + self.name = f"SR{self.dims}" + self.noise_level = noise_level + self.scale = scaling_factor + + def __call__(self, *args): + """# Magic method when calling the distribution + + This method is going to be called when you use xgauss(case_count)""" + import numpy as np + + assert len(args) == 1, "Only 1 argument supported" + N = args[0] + noise = self.noise_level + scaling_factor = self.scale + + t = 3 * np.pi / 2 * (1 + 2 * np.random.rand(1, N)) + h = 21 * np.random.rand(1, N) + RANDOM = np.random.randn(3, N) * noise + data = ( + np.concatenate( + (scaling_factor * t * np.cos(t), h, scaling_factor * t * np.sin(t)) + ) + + RANDOM + ) + if self.dims == 2: + return data.T[:, [0, 2]] + return data.T[:, [0, 2, 1]] diff --git a/code/sunlab/common/distribution/symmetric_uniform_distribution.py b/code/sunlab/common/distribution/symmetric_uniform_distribution.py new file mode 100644 index 0000000..c3a4db0 --- /dev/null +++ b/code/sunlab/common/distribution/symmetric_uniform_distribution.py @@ -0,0 +1,21 @@ +from .adversarial_distribution import * + + +class SymmetricUniformDistribution(AdversarialDistribution): + """# Symmetric Uniform Distribution on [-1, 1)""" + + def __init__(self, N): + """# Symmetric Uniform Distribution Initialization + + Initializes the name and dimensions""" + super().__init__(N) + self.full_name = f"{N}-Dimensional Symmetric Uniform Distribution" + self.name = "SU" + + def __call__(self, *args): + """# Magic method when calling the distribution + + This method is going to be called when you use suniform(N1,...,Nm)""" + import numpy as np + + return np.random.rand(*args, self.dims) * 2.0 - 1.0 diff --git a/code/sunlab/common/distribution/uniform_distribution.py b/code/sunlab/common/distribution/uniform_distribution.py new file mode 100644 index 0000000..3e23e67 --- /dev/null +++ b/code/sunlab/common/distribution/uniform_distribution.py @@ -0,0 +1,21 @@ +from .adversarial_distribution import * + + +class UniformDistribution(AdversarialDistribution): + """# Uniform Distribution on [0, 1)""" + + def __init__(self, N): + """# Uniform Distribution Initialization + + Initializes the name and dimensions""" + super().__init__(N) + self.full_name = f"{N}-Dimensional Uniform Distribution" + self.name = "U" + + def __call__(self, *args): + """# Magic method when calling the distribution + + This method is going to be called when you use uniform(N1,...,Nm)""" + import numpy as np + + return np.random.rand(*args, self.dims) diff --git a/code/sunlab/common/distribution/x_gaussian_distribution.py b/code/sunlab/common/distribution/x_gaussian_distribution.py new file mode 100644 index 0000000..b4330aa --- /dev/null +++ b/code/sunlab/common/distribution/x_gaussian_distribution.py @@ -0,0 +1,38 @@ +from .adversarial_distribution import * + + +class XGaussianDistribution(AdversarialDistribution): + """# X Gaussian Distribution""" + + def __init__(self, N): + """# X Gaussian Distribution Initialization + + Initializes the name and dimensions""" + super().__init__(N) + assert self.dims == 2, "This Distribution only Supports 2-Dimensions" + self.full_name = "2-Dimensional X-Gaussian Distribution" + self.name = "XG" + + def __call__(self, *args): + """# Magic method when calling the distribution + + This method is going to be called when you use xgauss(case_count)""" + import numpy as np + + assert len(args) == 1, "Only 1 argument supported" + N = args[0] + sample_base = np.zeros((4 * N, 2)) + sample_base[0 * N : (0 + 1) * N, :] = np.random.multivariate_normal( + mean=[1, 1], cov=[[1, 0.7], [0.7, 1]], size=[N] + ) + sample_base[1 * N : (1 + 1) * N, :] = np.random.multivariate_normal( + mean=[-1, -1], cov=[[1, 0.7], [0.7, 1]], size=[N] + ) + sample_base[2 * N : (2 + 1) * N, :] = np.random.multivariate_normal( + mean=[-1, 1], cov=[[1, -0.7], [-0.7, 1]], size=[N] + ) + sample_base[3 * N : (3 + 1) * N, :] = np.random.multivariate_normal( + mean=[1, -1], cov=[[1, -0.7], [-0.7, 1]], size=[N] + ) + np.random.shuffle(sample_base) + return sample_base[:N, :] diff --git a/code/sunlab/common/mathlib/__init__.py b/code/sunlab/common/mathlib/__init__.py new file mode 100644 index 0000000..9b5ed21 --- /dev/null +++ b/code/sunlab/common/mathlib/__init__.py @@ -0,0 +1 @@ +from .base import * diff --git a/code/sunlab/common/mathlib/base.py b/code/sunlab/common/mathlib/base.py new file mode 100644 index 0000000..38ab14c --- /dev/null +++ b/code/sunlab/common/mathlib/base.py @@ -0,0 +1,57 @@ +import numpy as np + + +def angle(a, b): + """# Get Angle Between Row Vectors""" + from numpy import arctan2, pi + + theta_a = arctan2(a[:, 1], a[:, 0]) + theta_b = arctan2(b[:, 1], b[:, 0]) + d_theta = theta_b - theta_a + assert (-pi <= d_theta) and (d_theta <= pi), "Theta difference outside of [-π,π]" + return d_theta + + +def normalize(column): + """# Normalize Column Vector""" + from numpy.linalg import norm + + return column / norm(column, axis=0) + + +def winding(xy_grid, trajectory_start, trajectory_end): + """# Get Winding Number on Grid""" + from numpy import zeros, cross, clip, arcsin + + trajectories = trajectory_end - trajectory_start + winding = zeros((xy_grid.shape[0])) + for idx, trajectory in enumerate(trajectories): + r = xy_grid - trajectory_start[idx] + cross = cross(normalize(trajectory), normalize(r)) + cross = clip(cross, -1, 1) + theta = arcsin(cross) + winding += theta + return winding + + +def vorticity(xy_grid, trajectory_start, trajectory_end): + """# Get Vorticity Number on Grid""" + from numpy import zeros, cross + + trajectories = trajectory_end - trajectory_start + vorticity = zeros((xy_grid.shape[0])) + for idx, trajectory in enumerate(trajectories): + r = xy_grid - trajectory_start[idx] + vorticity += cross(normalize(trajectory), normalize(r)) + return vorticity + + +def data_range(data): + """# Get the range of values for each row""" + from numpy import min, max + + return min(data, axis=0), max(data, axis=0) + + +np.normalize = normalize +np.range = data_range diff --git a/code/sunlab/common/mathlib/lyapunov.py b/code/sunlab/common/mathlib/lyapunov.py new file mode 100644 index 0000000..3c747f1 --- /dev/null +++ b/code/sunlab/common/mathlib/lyapunov.py @@ -0,0 +1,54 @@ +def trajectory_to_distances(x): + """X: [N,N_t,N_d] + ret [N,N_t]""" + from numpy import zeros + from numpy.linalg import norm + from itertools import product, combinations + + x = [x[idx, ...] for idx in range(x.shape[0])] + pairwise_trajectories = combinations(x, 2) + _N_COMB = len(list(pairwise_trajectories)) + N_max = x[0].shape[0] + distances = zeros((_N_COMB, N_max)) + pairwise_trajectories = combinations(x, 2) + for idx, (a_t, b_t) in enumerate(pairwise_trajectories): + distances[idx, :] = norm(a_t[:N_max, :] - b_t[:N_max, :], axis=-1) + return distances + + +def Lyapunov_d(X): + """X: [N,N_t] + λ_n = ln(|dX_n|/|dX_0|)/n; n = [1,2,...]""" + from numpy import zeros, log, repeat + + Y = zeros((X.shape[0], X.shape[1] - 1)) + Y = log(X[:, 1:] / repeat([X[:, 0]], Y.shape[1], axis=0).T) / ( + repeat([range(Y.shape[1])], Y.shape[0], axis=0) + 1 + ) + return Y + + +def Lyapunov_t(X): + """X: [N,N_t,N_d]""" + return Lyapunov_d(trajectory_to_distances(X)) + + +Lyapunov = Lyapunov_d + + +def RelativeDistance_d(X): + """X: [N,N_t] + λ_n = ln(|dX_n|/|dX_0|)/n; n = [1,2,...]""" + from numpy import zeros, log, repeat + + Y = zeros((X.shape[0], X.shape[1] - 1)) + Y = log(X[:, 1:] / repeat([X[:, 0]], Y.shape[1], axis=0).T) + return Y + + +def RelativeDistance_t(X): + """X: [N,N_t,N_d]""" + return RelativeDistance_d(trajectory_to_distances(X)) + + +RelativeDistance = RelativeDistance_d diff --git a/code/sunlab/common/mathlib/random_walks.py b/code/sunlab/common/mathlib/random_walks.py new file mode 100644 index 0000000..3aa3bcb --- /dev/null +++ b/code/sunlab/common/mathlib/random_walks.py @@ -0,0 +1,83 @@ +def get_levy_flight(T=50, D=2, t0=0.1, alpha=3, periodic=False): + from numpy import vstack + from mistree import get_levy_flight as get_flight + + if D == 2: + x, y = get_flight(T, mode="2D", periodic=periodic, t_0=t0, alpha=alpha) + xy = vstack([x, y]).T + elif D == 3: + x, y, z = get_flight(T, mode="3D", periodic=periodic, t_0=t0, alpha=alpha) + xy = vstack([x, y, z]).T + else: + raise ValueError(f"Dimension {D} not supported!") + return xy + + +def get_levy_flights(N=10, T=50, D=2, t0=0.1, alpha=3, periodic=False): + from numpy import moveaxis, array + + trajectories = [] + for _ in range(N): + xy = get_levy_flight(T=T, D=D, t0=t0, alpha=alpha, periodic=periodic) + trajectories.append(xy) + return moveaxis(array(trajectories), 0, 1) + + +def get_jitter_levy_flights( + N=10, T=50, D=2, t0=0.1, alpha=3, periodic=False, noise=5e-2 +): + from numpy.random import randn + + trajectories = get_levy_flights( + N=N, T=T, D=D, t0=t0, alpha=alpha, periodic=periodic + ) + return trajectories + randn(*trajectories.shape) * noise + + +def get_gaussian_random_walk(T=50, D=2, R=5, step_size=0.5, soft=None): + from numpy import array, sin, cos, exp, zeros, pi + from numpy.random import randn, uniform, rand + from numpy.linalg import norm + + def is_in(x, R=1): + from numpy.linalg import norm + + return norm(x) < R + + X = zeros((T, D)) + for t in range(1, T): + while True: + if D == 2: + angle = uniform(0, pi * 2) + step = randn(1) * step_size + X[t, :] = X[t - 1, :] + array([cos(angle), sin(angle)]) * step + else: + X[t, :] = X[t - 1, :] + randn(D) / D * step_size + if soft is None: + if is_in(X[t, :], R): + break + elif rand() < exp(-(norm(X[t, :]) - R) * soft): + break + return X + + +def get_gaussian_random_walks(N=10, T=50, D=2, R=5, step_size=0.5, soft=None): + from numpy import moveaxis, array + + trajectories = [] + for _ in range(N): + xy = get_gaussian_random_walk(T=T, D=D, R=R, step_size=step_size, soft=soft) + trajectories.append(xy) + return moveaxis(array(trajectories), 0, 1) + + +def get_gaussian_sample(T=50, D=2): + from numpy.random import randn + + return randn(T, D) + + +def get_gaussian_samples(N=10, T=50, D=2, R=5, step_size=0.5): + from numpy.random import randn + + return randn(T, N, D) diff --git a/code/sunlab/common/plotting/__init__.py b/code/sunlab/common/plotting/__init__.py new file mode 100644 index 0000000..d6873aa --- /dev/null +++ b/code/sunlab/common/plotting/__init__.py @@ -0,0 +1,2 @@ +from .colors import * +from .base import * diff --git a/code/sunlab/common/plotting/base.py b/code/sunlab/common/plotting/base.py new file mode 100644 index 0000000..aaf4a94 --- /dev/null +++ b/code/sunlab/common/plotting/base.py @@ -0,0 +1,270 @@ +from matplotlib import pyplot as plt + + +def blank_plot(_plt=None, _xticks=False, _yticks=False): + """# Remove Plot Labels""" + if _plt is None: + _plt = plt + _plt.xlabel("") + _plt.ylabel("") + _plt.title("") + tick_params = { + "which": "both", + "bottom": _xticks, + "left": _yticks, + "right": False, + "labelleft": False, + "labelbottom": False, + } + _plt.tick_params(**tick_params) + for child in plt.gcf().get_children(): + if child._label == "<colorbar>": + child.set_yticks([]) + axs = _plt.gcf().get_axes() + try: + axs = axs.flatten() + except: + ... + for ax in axs: + ax.set_xlabel("") + ax.set_ylabel("") + ax.set_title("") + ax.tick_params(**tick_params) + + +def save_plot(name, _plt=None, _xticks=False, _yticks=False, tighten=True): + """# Save Plot in Multiple Formats""" + assert type(name) == str, "Name must be string" + from os.path import dirname + from os import makedirs + + makedirs(dirname(name), exist_ok=True) + if _plt is None: + from matplotlib import pyplot as plt + _plt = plt + _plt.savefig(name + ".png", dpi=1000) + blank_plot(_plt, _xticks=_xticks, _yticks=_yticks) + if tighten: + from matplotlib import pyplot as plt + plt.tight_layout() + _plt.savefig(name + ".pdf") + _plt.savefig(name + ".svg") + + +def scatter_2d(data_2d, labels=None, _plt=None, **matplotlib_kwargs): + """# Scatter 2d Data + + - data_2d: 2d-dataset to plot + - labels: labels for each case + - _plt: Optional specific plot to transform""" + from .colors import Pcolor + + if _plt is None: + _plt = plt + + def _filter(data, labels=None, _filter_on=None): + if labels is None: + return data, False + else: + return data[labels == _filter_on], True + + for _class in [2, 3, 1, 0]: + local_data, has_color = _filter(data_2d, labels, _class) + if has_color: + _plt.scatter( + local_data[:, 0], + local_data[:, 1], + color=Pcolor[_class], + **matplotlib_kwargs + ) + else: + _plt.scatter(local_data[:, 0], local_data[:, 1], **matplotlib_kwargs) + break + return _plt + + +def plot_contour(two_d_mask, color="black", color_map=None, raise_error=False): + """# Plot Contour of Mask""" + from matplotlib.pyplot import contour + from numpy import mgrid + + z = two_d_mask + x, y = mgrid[: z.shape[1], : z.shape[0]] + if color_map is not None: + try: + color = color_map(color) + except Exception as e: + if raise_error: + raise e + try: + contour(x, y, z.T, colors=color, linewidth=0.5) + except Exception as e: + if raise_error: + raise e + + +def gaussian_smooth_plot( + xy, + sigma=0.1, + extent=[-1, 1, -1, 1], + grid_n=100, + grid=None, + do_normalize=True, +): + """# Plot Data with Gaussian Smoothening around point""" + from numpy import array, ndarray, linspace, meshgrid, zeros_like + from numpy import pi, sqrt, exp + from numpy.linalg import norm + + if not type(xy) == ndarray: + xy = array(xy) + if grid is not None: + XY = grid + else: + X = linspace(extent[0], extent[1], grid_n) + Y = linspace(extent[2], extent[3], grid_n) + XY = array(meshgrid(X, Y)).T + smoothed = zeros_like(XY[:, :, 0]) + factor = 1 + if do_normalize: + factor = (sqrt(2 * pi) * sigma) ** 2. + if len(xy.shape) == 1: + smoothed = exp(-((norm(xy - XY, axis=-1) / (sqrt(2) * sigma)) ** 2.0)) / factor + else: + try: + from tqdm.notebook import tqdm + except Exception: + + def tqdm(x): + return x + + for i in tqdm(range(xy.shape[0])): + if xy.shape[-1] == 2: + smoothed += ( + exp(-((norm((xy[i, :] - XY), axis=-1) / (sqrt(2) * sigma)) ** 2.0)) + / factor + ) + elif xy.shape[-1] == 3: + smoothed += ( + exp(-((norm((xy[i, :2] - XY), axis=-1) / (sqrt(2) * sigma)) ** 2.0)) + / factor + * xy[i, 2] + ) + return smoothed, XY + + +def plot_grid_data(xy_grid, data_grid, cbar=False, _plt=None, _cmap="RdBu", grid_min=1): + """# Plot Gridded Data""" + from numpy import nanmin, nanmax + from matplotlib.colors import TwoSlopeNorm + + if _plt is None: + _plt = plt + norm = TwoSlopeNorm( + vmin=nanmin([-grid_min, nanmin(data_grid)]), + vcenter=0, + vmax=nanmax([grid_min, nanmax(data_grid)]), + ) + _plt.pcolor(xy_grid[:, :, 0], xy_grid[:, :, 1], data_grid, cmap="RdBu", norm=norm) + if cbar: + _plt.colorbar() + + +def plot_winding(xy_grid, winding_grid, cbar=False, _plt=None): + plot_grid_data(xy_grid, winding_grid, cbar=cbar, _plt=_plt, grid_min=1e-5) + + +def plot_vorticity(xy_grid, vorticity_grid, cbar=False, save=False, _plt=None): + plot_grid_data(xy_grid, vorticity_grid, cbar=cbar, _plt=_plt, grid_min=1e-1) + + +plt.blank = lambda: blank_plot(plt) +plt.scatter2d = lambda data, labels=None, **matplotlib_kwargs: scatter_2d( + data, labels, plt, **matplotlib_kwargs +) +plt.save = save_plot + + +def interpolate_points(df, columns=["Latent-0", "Latent-1"], kind="quadratic", N=500): + """# Interpolate points""" + from scipy.interpolate import interp1d + import numpy as np + + points = df[columns].to_numpy() + distance = np.cumsum(np.sqrt(np.sum(np.diff(points, axis=0) ** 2, axis=1))) + distance = np.insert(distance, 0, 0) / distance[-1] + interpolator = interp1d(distance, points, kind=kind, axis=0) + alpha = np.linspace(0, 1, N) + interpolated_points = interpolator(alpha) + return interpolated_points.reshape(-1, 1, 2) + + +def plot_trajectory( + df, + Fm=24, + FM=96, + interpolate=False, + interpolation_kind="quadratic", + interpolation_N=500, + columns=["Latent-0", "Latent-1"], + frame_column="Frames", + alpha=0.8, + lw=4, + _plt=None, + _z=None, +): + """# Plot Trajectories + + Interpolation possible""" + import numpy as np + from matplotlib.collections import LineCollection + import matplotlib as mpl + + if _plt is None: + _plt = plt + if type(_plt) == mpl.axes._axes.Axes: + _ca = _plt + else: + try: + _ca = _plt.gca() + except: + _ca = _plt + X = df[columns[0]] + Y = df[columns[1]] + fm, fM = np.min(df[frame_column]), np.max(df[frame_column]) + + if interpolate: + if interpolation_kind == "cubic": + if len(df) <= 3: + return + elif interpolation_kind == "quadratic": + if len(df) <= 2: + return + else: + if len(df) <= 1: + return + points = interpolate_points( + df, kind=interpolation_kind, columns=columns, N=interpolation_N + ) + else: + points = np.array([X, Y]).T.reshape(-1, 1, 2) + + segments = np.concatenate([points[:-1], points[1:]], axis=1) + lc = LineCollection( + segments, + cmap=plt.get_cmap("plasma"), + norm=mpl.colors.Normalize(Fm, FM), + ) + if _z is not None: + from mpl_toolkits.mplot3d.art3d import line_collection_2d_to_3d + + if interpolate: + _z = _z # TODO: Interpolate + line_collection_2d_to_3d(lc, _z) + lc.set_array(np.linspace(fm, fM, points.shape[0])) + lc.set_linewidth(lw) + lc.set_alpha(alpha) + _ca.add_collection(lc) + _ca.autoscale() + _ca.margins(0.04) + return lc diff --git a/code/sunlab/common/plotting/colors.py b/code/sunlab/common/plotting/colors.py new file mode 100644 index 0000000..c4fc727 --- /dev/null +++ b/code/sunlab/common/plotting/colors.py @@ -0,0 +1,38 @@ +class PhenotypeColors: + """# Phenotype Colorings + + Standardization for the different phenotype colors""" + + def __init__(self): + """# Empty Construtor""" + pass + + def get_basic_colors(self, transition=False): + """# Return the Color Names + + - transition: Returns the color for the transition class too""" + if transition: + return ["yellow", "purple", "green", "blue", "cyan"] + return ["yellow", "purple", "green", "blue"] + + def get_colors(self, transition=False): + """# Return the Color Names + + - transition: Returns the color for the transition class too""" + if transition: + return ["#ffff00", "#ff3cfa", "#11f309", "#213ff0", "cyan"] + return ["#ffff00", "#ff3cfa", "#11fe09", "#213ff0"] + + def get_colormap(self, transition=False): + """# Return the Matplotlib Colormap + + - transition: Returns the color for the transition class too""" + from matplotlib.colors import ListedColormap as LC + + return LC(self.get_colors(transition)) + + +# Basic Exports +Pcolor = PhenotypeColors().get_colors() +Pmap = PhenotypeColors().get_colormap() +Pmapx = PhenotypeColors().get_colormap(True) diff --git a/code/sunlab/common/scaler/__init__.py b/code/sunlab/common/scaler/__init__.py new file mode 100644 index 0000000..2a2281a --- /dev/null +++ b/code/sunlab/common/scaler/__init__.py @@ -0,0 +1,2 @@ +from .max_abs_scaler import * +from .quantile_scaler import * diff --git a/code/sunlab/common/scaler/adversarial_scaler.py b/code/sunlab/common/scaler/adversarial_scaler.py new file mode 100644 index 0000000..7f61725 --- /dev/null +++ b/code/sunlab/common/scaler/adversarial_scaler.py @@ -0,0 +1,44 @@ +class AdversarialScaler: + """# Scaler Class to use in Adversarial Autoencoder + + For any scaler to be implemented, make sure to ensure each of the methods + are implemented: + - __init__ (optional) + - init + - load + - save + - __call__""" + + def __init__(self, base_directory): + """# Scaler initialization + + - Initialize the base directory of the model where it will live""" + self.base_directory = base_directory + + def init(self, data): + """# Scaler initialization + + Initialize the scaler transformation with the data + Should always return self in subclasses""" + raise NotImplementedError("Scaler initialization has not been implemented yet") + + def load(self): + """# Scaler loading + + Load the data scaler model from a file + Should always return self in subclasses""" + raise NotImplementedError("Scaler loading has not been implemented yet") + + def save(self): + """# Scaler saving + + Save the data scaler model""" + raise NotImplementedError("Scaler saving has not been implemented yet") + + def transform(self, *args, **kwargs): + """# Scale the given data""" + return self.__call__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + """# Scale the given data""" + raise NotImplementedError("Scaler has not been implemented yet") diff --git a/code/sunlab/common/scaler/max_abs_scaler.py b/code/sunlab/common/scaler/max_abs_scaler.py new file mode 100644 index 0000000..56ea589 --- /dev/null +++ b/code/sunlab/common/scaler/max_abs_scaler.py @@ -0,0 +1,48 @@ +from .adversarial_scaler import AdversarialScaler + + +class MaxAbsScaler(AdversarialScaler): + """# MaxAbsScaler + + Scale the data based on the maximum-absolute value in each column""" + + def __init__(self, base_directory): + """# MaxAbsScaler initialization + + - Initialize the base directory of the model where it will live + - Initialize the scaler model""" + super().__init__(base_directory) + from sklearn.preprocessing import MaxAbsScaler as MAS + + self.scaler_base = MAS() + self.scaler = None + + def init(self, data): + """# Scaler initialization + + Initialize the scaler transformation with the data""" + self.scaler = self.scaler_base.fit(data) + return self + + def load(self): + """# Scaler loading + + Load the data scaler model from a file""" + from pickle import load + + with open(f"{self.base_directory}/portable/maxabs_scaler.pkl", "rb") as fhandle: + self.scaler = load(fhandle) + return self + + def save(self): + """# Scaler saving + + Save the data scaler model""" + from pickle import dump + + with open(f"{self.base_directory}/portable/maxabs_scaler.pkl", "wb") as fhandle: + dump(self.scaler, fhandle) + + def __call__(self, *args, **kwargs): + """# Scale the given data""" + return self.scaler.transform(*args, **kwargs) diff --git a/code/sunlab/common/scaler/quantile_scaler.py b/code/sunlab/common/scaler/quantile_scaler.py new file mode 100644 index 0000000..a0f53fd --- /dev/null +++ b/code/sunlab/common/scaler/quantile_scaler.py @@ -0,0 +1,52 @@ +from .adversarial_scaler import AdversarialScaler + + +class QuantileScaler(AdversarialScaler): + """# QuantileScaler + + Scale the data based on the quantile distributions of each column""" + + def __init__(self, base_directory): + """# QuantileScaler initialization + + - Initialize the base directory of the model where it will live + - Initialize the scaler model""" + super().__init__(base_directory) + from sklearn.preprocessing import QuantileTransformer as QS + + self.scaler_base = QS() + self.scaler = None + + def init(self, data): + """# Scaler initialization + + Initialize the scaler transformation with the data""" + self.scaler = self.scaler_base.fit(data) + return self + + def load(self): + """# Scaler loading + + Load the data scaler model from a file""" + from pickle import load + + with open( + f"{self.base_directory}/portable/quantile_scaler.pkl", "rb" + ) as fhandle: + self.scaler = load(fhandle) + return self + + def save(self): + """# Scaler saving + + Save the data scaler model""" + from pickle import dump + + with open( + f"{self.base_directory}/portable/quantile_scaler.pkl", "wb" + ) as fhandle: + dump(self.scaler, fhandle) + + def __call__(self, *args, **kwargs): + """# Scale the given data""" + return self.scaler.transform(*args, **kwargs) |