aboutsummaryrefslogtreecommitdiff
path: root/code/sunlab/common/scaler/max_abs_scaler.py
diff options
context:
space:
mode:
authorChristian C <cc@localhost>2024-11-11 12:29:32 -0800
committerChristian C <cc@localhost>2024-11-11 12:29:32 -0800
commitb85ee9d64a536937912544c7bbd5b98b635b7e8d (patch)
treecef7bc17d7b29f40fc6b1867d0ce0a742d5583d0 /code/sunlab/common/scaler/max_abs_scaler.py
Initial commit
Diffstat (limited to 'code/sunlab/common/scaler/max_abs_scaler.py')
-rw-r--r--code/sunlab/common/scaler/max_abs_scaler.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/code/sunlab/common/scaler/max_abs_scaler.py b/code/sunlab/common/scaler/max_abs_scaler.py
new file mode 100644
index 0000000..56ea589
--- /dev/null
+++ b/code/sunlab/common/scaler/max_abs_scaler.py
@@ -0,0 +1,48 @@
+from .adversarial_scaler import AdversarialScaler
+
+
+class MaxAbsScaler(AdversarialScaler):
+ """# MaxAbsScaler
+
+ Scale the data based on the maximum-absolute value in each column"""
+
+ def __init__(self, base_directory):
+ """# MaxAbsScaler initialization
+
+ - Initialize the base directory of the model where it will live
+ - Initialize the scaler model"""
+ super().__init__(base_directory)
+ from sklearn.preprocessing import MaxAbsScaler as MAS
+
+ self.scaler_base = MAS()
+ self.scaler = None
+
+ def init(self, data):
+ """# Scaler initialization
+
+ Initialize the scaler transformation with the data"""
+ self.scaler = self.scaler_base.fit(data)
+ return self
+
+ def load(self):
+ """# Scaler loading
+
+ Load the data scaler model from a file"""
+ from pickle import load
+
+ with open(f"{self.base_directory}/portable/maxabs_scaler.pkl", "rb") as fhandle:
+ self.scaler = load(fhandle)
+ return self
+
+ def save(self):
+ """# Scaler saving
+
+ Save the data scaler model"""
+ from pickle import dump
+
+ with open(f"{self.base_directory}/portable/maxabs_scaler.pkl", "wb") as fhandle:
+ dump(self.scaler, fhandle)
+
+ def __call__(self, *args, **kwargs):
+ """# Scale the given data"""
+ return self.scaler.transform(*args, **kwargs)