# Source code for snsynth.transform.onehot

from snsynth.transform.definitions import ColumnType
from .base import ColumnTransformer
import numpy as np

class OneHotEncoder(ColumnTransformer):
    """Encode 0-based integer labels as one-hot vectors.

    During fit the encoder records the largest label it sees. Afterwards each
    label ``k`` is mapped to a tuple of ``max + 1`` bits in which only position
    ``k`` is 1. If fit only ever saw the label 0, the encoding is the scalar
    ``1`` rather than a one-element tuple.

    Inputs must already be 0-based integers; to encode raw categorical values,
    chain a LabelTransformer in front of this transformer.
    """

    # NOTE(review): presumably tells ColumnTransformer not to keep the values
    # seen during fit (only the running maximum is needed) — confirm in base.
    cache_fit = False

    def __init__(self):
        super().__init__()

    @property
    def output_type(self):
        """Column type of the encoded output (always categorical)."""
        return ColumnType.CATEGORICAL

    @property
    def cardinality(self):
        """One entry of 2 (a 0/1 bit) per label in ``0..max``."""
        n_categories = self.max + 1
        return [2] * n_categories

    def _fit(self, val):
        """Keep the largest label observed so far in ``self.max``."""
        self.max = max(self.max, val)

    def _fit_finish(self):
        """Fix the output width at one column per label in ``0..max``."""
        self.output_width = self.max + 1
        super()._fit_finish()

    def _clear_fit(self):
        """Reset fit state; ``max == -1`` means no label has been seen yet."""
        self._fit_complete = False
        self.max = -1

    def _transform(self, val):
        """Return the one-hot encoding of the integer label ``val``.

        Returns:
            A tuple of ``max + 1`` ints with a single 1 at index ``val``, or
            the scalar ``1`` when only one category was seen during fit.

        Raises:
            ValueError: if the encoder has not finished fitting, or if ``val``
                lies outside ``0..max``.
        """
        if self.max < 0 or not self._fit_complete:
            raise ValueError("OneHotEncoder has not been fit yet.")
        if val < 0 or val > self.max:
            raise ValueError(
                f"Provided integer-label {val} is invalid."
                " Please ensure that all inputs are 0-based and provided during data fit."
            )
        if self.max == 0:
            # A single category collapses to one always-on output column.
            return 1
        width = self.max + 1
        encoded = [0] * width
        # Plain indexing (not a comparison loop) so non-integer labels are
        # rejected by list indexing exactly as before.
        encoded[val] = 1
        return tuple(encoded)

    def _inverse_transform(self, val):
        """Decode a one-hot vector back to its integer label.

        ``np.argmax`` returns the first index holding the maximum value, so
        when several positions tie (e.g. multiple bits set, or none set) the
        lowest such index is chosen.
        """
        return np.argmax(val)