Source code for snsynth.transform.onehot
from snsynth.transform.definitions import ColumnType
from .base import ColumnTransformer
import numpy as np
class OneHotEncoder(ColumnTransformer):
    """Encode 0-based integer labels as one-hot vectors.

    Each label ``k`` becomes a tuple whose ``k``-th entry is 1 and whose
    other entries are 0; the tuple length is the largest fitted label plus
    one. Inputs are assumed to be 0-based integers. To encode unstructured
    categorical data, chain with LabelTransformer first.
    """

    cache_fit = False

    def __init__(self):
        super().__init__()

    @property
    def output_type(self):
        """Column type of the encoded output (always categorical)."""
        return ColumnType.CATEGORICAL

    @property
    def cardinality(self):
        """One entry of 2 per output position (each position is 0 or 1)."""
        n_positions = self.max + 1
        return [2] * n_positions

    def _fit(self, val):
        """Record ``val`` if it exceeds the largest label seen so far."""
        # presumably invoked once per input value by the base class -- confirm
        if val > self.max:
            self.max = val

    def _fit_finish(self):
        """Set the output width from the largest label, then finish the fit."""
        # Labels are 0-based, so the number of positions is max + 1.
        self.output_width = self.max + 1
        super()._fit_finish()

    def _clear_fit(self):
        """Reset fit state; ``max == -1`` means no label has been seen."""
        self._fit_complete = False
        self.max = -1

    def _transform(self, val):
        """Return the one-hot encoding of the integer label ``val``.

        Returns the scalar ``1`` when label 0 is the only fitted label,
        otherwise a tuple of ``self.max + 1`` zeros and a single one.

        Raises:
            ValueError: if the encoder has not been fit, or if ``val`` is
                negative or larger than the largest fitted label.
        """
        if self.max < 0 or not self._fit_complete:
            raise ValueError("OneHotEncoder has not been fit yet.")
        if val < 0 or val > self.max:
            raise ValueError(
                f"Provided integer-label {val} is invalid."
                " Please ensure that all inputs are 0-based and provided during data fit."
            )
        if self.max == 0:
            # Only one category exists: emit a bare scalar, not a 1-tuple.
            return 1
        encoded = [0] * (self.max + 1)
        encoded[val] = 1
        return tuple(encoded)

    def _inverse_transform(self, val):
        """Return the index of the largest entry of the encoded ``val``."""
        # argmax picks the first position when several entries are set.
        return np.argmax(val)