Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_encode_fit():
    """Fitting alone leaves labels raw; a later transform applies the encoding."""
    df = _some_df()
    stage = Encode()
    # Transforming before any fit must be rejected.
    with pytest.raises(UnfittedPipelineStageError):
        stage.transform(df)
    # The frame returned by fit() still carries the raw (unencoded) labels.
    fitted_df = stage.fit(df)
    assert 'lbl' in fitted_df.columns
    for row, raw_label in ((1, 'acd'), (2, 'alk'), (3, 'alk')):
        assert fitted_df['lbl'][row] == raw_label
    # An already-fitted stage only transforms new data (no refit).
    new_df = _some_df2()
    encoded_df = stage.transform(new_df)
    assert 'lbl' in encoded_df.columns
    for row, code in ((1, 1), (2, 0)):
        assert encoded_df['lbl'][row] == code
def test_encode_in_pipelin_fit_n_transform():
    """Fit a ColDrop + Encode pipeline explicitly, then transform with it."""
    drop_name = pdp.ColDrop('name')
    encode_stage = Encode()
    pline = drop_name + encode_stage
    df = _some_df()
    # An unfitted pipeline refuses to transform.
    with pytest.raises(UnfittedPipelineStageError):
        pline.transform(df)
    # fit() hands back the input untouched: 'name' is still present and
    # the labels are raw strings.
    res_df = pline.fit(df)
    assert 'lbl' in res_df.columns
    assert 'name' in res_df.columns
    assert res_df['lbl'][1] == 'acd'
    assert res_df['lbl'][2] == 'alk'
    assert res_df['lbl'][3] == 'alk'
    # Once fitted, transform() must apply both stages. Checking only that
    # 'lbl' exists would also pass for a no-op transform, since the raw
    # frame already has that column.
    res_df = pline.transform(df)
    assert 'lbl' in res_df.columns
    assert 'name' not in res_df.columns
    assert res_df['lbl'][1] == 0
    assert res_df['lbl'][2] == 1
    assert res_df['lbl'][3] == 1
def test_encode_with_args():
    """Encode one named column with drop=False, keeping the source column."""
    df = _some_df()
    encode_stage = Encode("lbl", drop=False)
    res_df = encode_stage(df, verbose=True)
    # With drop=False the codes land in 'lbl_enc' and 'lbl' is kept.
    assert 'lbl' in res_df.columns
    assert res_df['lbl_enc'][1] == 0
    assert res_df['lbl_enc'][2] == 1
    assert res_df['lbl_enc'][3] == 1
    # see only transform (no fit) when already fitted
    df2 = _some_df2()
    res_df2 = encode_stage(df2)
    # Check the second result, not the first one again.
    assert 'lbl' in res_df2.columns
    assert res_df2['lbl_enc'][1] == 1
    assert res_df2['lbl_enc'][2] == 0
    assert res_df2['lbl_enc'][3] == 1
    # check fit_transform when already fitted: it refits on the new data.
    df3 = _some_df2()
    res_df3 = encode_stage.fit_transform(df3)
    assert 'lbl' in res_df3.columns
    assert 'lbl_enc' in res_df3.columns
    # Under the first fit, rows 1 and 3 of this frame got the same code and
    # row 2 a different one, so they share / differ in label. Any refit
    # encoding must preserve that grouping, whatever codes it assigns.
    assert res_df3['lbl_enc'][1] == res_df3['lbl_enc'][3]
    assert res_df3['lbl_enc'][1] != res_df3['lbl_enc'][2]
def test_encode_in_pipeline():
    """Calling a ColDrop + Encode pipeline drops 'name' and encodes 'lbl'."""
    pline = pdp.ColDrop('name') + Encode()
    df = _some_df()
    first_out = pline(df)
    assert 'lbl' in first_out.columns
    assert 'name' not in first_out.columns
    for row, code in {1: 0, 2: 1, 3: 1}.items():
        assert first_out['lbl'][row] == code
    # The pipeline is now fitted; a second call reuses the encoding
    # learned from the first frame.
    second_out = pline(_some_df2())
    assert 'lbl' in second_out.columns
    for row, code in {1: 1, 2: 0}.items():
        assert second_out['lbl'][row] == code
def test_encode_with_exclude():
    """Encode 'lbl' while passing 'name' as an excluded column."""
    frame = _some_df()
    stage = Encode("lbl", exclude_columns="name")
    # NOTE(review): columns="lbl" already limits encoding to 'lbl', so the
    # exclusion of 'name' is not really exercised here.
    out = stage(frame)
    assert 'lbl' in out.columns
    for row, code in ((1, 0), (2, 1), (3, 1)):
        assert out['lbl'][row] == code
def test_encode():
    """With no arguments, Encode encodes both 'lbl' and 'name'."""
    df = _some_df()
    stage = Encode()
    out = stage(df)
    assert 'lbl' in out.columns
    expected_first = {
        'lbl': ((1, 0), (2, 1), (3, 1)),
        'name': ((1, 0), (2, 1), (3, 2)),
    }
    for col, pairs in expected_first.items():
        for row, code in pairs:
            assert out[col][row] == code
    # A second call on the fitted stage only transforms (no refit).
    out2 = stage(_some_df2())
    assert 'lbl' in out2.columns
    for row, code in ((1, 1), (2, 0), (3, 1)):
        assert out2['lbl'][row] == code
self, columns=None, exclude_columns=None, drop=True, **kwargs
):
if columns is None:
self._columns = None
else:
self._columns = _interpret_columns_param(columns)
if exclude_columns is None:
self._exclude_columns = []
else:
self._exclude_columns = _interpret_columns_param(exclude_columns)
self._drop = drop
self.encoders = {}
col_str = _list_str(self._columns)
super_kwargs = {
"exmsg": Encode._DEF_ENCODE_EXC_MSG.format(col_str),
"appmsg": Encode._DEF_ENCODE_APP_MSG.format(col_str),
"desc": "Encode {}".format(col_str or "all categorical columns"),
}
super_kwargs.update(**kwargs)
super().__init__(**super_kwargs)
def __init__(
    self, columns=None, exclude_columns=None, drop=True, **kwargs
):
    """Configure a stage that encodes categorical column values.

    Parameters
    ----------
    columns : optional
        Column(s) to encode, normalised through ``_interpret_columns_param``.
        When None, no explicit selection is stored and the stage description
        reads "all categorical columns".
    exclude_columns : optional
        Column(s) to leave out, normalised the same way; an empty list when
        None.
    drop : bool, default True
        Stored as ``self._drop``. NOTE(review): appears to control whether the
        source column is replaced by its codes or kept alongside a
        ``<col>_enc`` column. Confirm in the transform logic.
    **kwargs
        Forwarded to the parent constructor. Any ``exmsg``, ``appmsg`` or
        ``desc`` given here overrides the defaults built below.
    """
    self._columns = (
        None if columns is None else _interpret_columns_param(columns)
    )
    self._exclude_columns = (
        [] if exclude_columns is None
        else _interpret_columns_param(exclude_columns)
    )
    self._drop = drop
    # Presumably maps column name -> fitted encoder; not filled in here.
    self.encoders = {}
    col_str = _list_str(self._columns)
    defaults = {
        "exmsg": Encode._DEF_ENCODE_EXC_MSG.format(col_str),
        "appmsg": Encode._DEF_ENCODE_APP_MSG.format(col_str),
        "desc": "Encode {}".format(col_str or "all categorical columns"),
    }
    # Caller-supplied keyword arguments take precedence over the defaults.
    super().__init__(**{**defaults, **kwargs})