Source code for snsynth.transform.chain
from .base import CachingColumnTransformer
import warnings
[docs]class ChainTransformer(CachingColumnTransformer):
"""Sequentially process a column through multiple transforms. When reversed,
the inverse transforms are applied in reverse order.
:param transforms: A list of ColumnTransformers to apply sequentially.
"""
def __init__(self, transformers):
self.transformers = transformers
super().__init__()
@property
def output_type(self):
return self.transformers[-1].output_type
@property
def needs_epsilon(self):
return any(transformer.needs_epsilon for transformer in self.transformers)
@property
def cardinality(self):
cards = []
for transformer in self.transformers:
for c in transformer.cardinality:
cards.append(c)
return cards
@property
def fit_complete(self):
return all([t.fit_complete for t in self.transformers])
def allocate_privacy_budget(self, epsilon, odometer):
n_with_epsilon = sum([1 for t in self.transformers if t.needs_epsilon])
if n_with_epsilon == 0:
return
elif n_with_epsilon > 1:
warnings.warn(f"Multiple transformers in chain need epsilon, which is likely wasteful.")
else:
for transformer in self.transformers:
if transformer.needs_epsilon:
transformer.allocate_privacy_budget(epsilon / n_with_epsilon, odometer)
def _fit_finish(self):
vals = self._fit_vals
for transformer in self.transformers:
vals = transformer.fit_transform(vals)
self._fit_vals = []
self.output_width = self.transformers[-1].output_width
def _clear_fit(self):
for transformer in self.transformers:
transformer._clear_fit()
if self.fit_complete:
self.output_width = self.transformers[-1].output_width
def _transform(self, val):
for transformer in self.transformers:
val = transformer._transform(val)
return val
def _inverse_transform(self, val):
for transformer in reversed(self.transformers):
val = transformer._inverse_transform(val)
return val