Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_band_ret_task(self):
    """Run a pipeline whose ``t_output`` is a dict of per-type TTask outputs."""

    class TMultipleOutputsPipeline(PipelineTask):
        t_types = parameter.value([1, 2])
        t_output = output

        def band(self):
            # Map every configured type to the output of a TTask built for it.
            per_type_outputs = {}
            for t_type in self.t_types:
                per_type_outputs[t_type] = TTask(t_param=t_type).t_output
            self.t_output = per_type_outputs

    pipeline = TMultipleOutputsPipeline()
    assert_run_task(pipeline)
@task(result=(output(name="o_a").csv[List[str]], "o_b"))
def t_f(a=5):
    """Return both declared outputs keyed by name.

    ``o_a`` holds ``a`` converted to a string; ``o_b`` is a constant list.
    """
    outputs = dict(o_a=[str(a)], o_b=["2"])
    return outputs
def _pd_to(self, data, file_or_path, *args, **kwargs):
    """Write ``data.features`` and ``data.targets`` into one HDF5 store.

    ``kwargs`` are combined (via ``combine_mappings``) with a ``format="fixed"``
    default and forwarded to each ``HDFStore.put`` call. ``args`` is accepted
    but not used here.
    """
    put_options = combine_mappings({"format": "fixed"}, kwargs)
    with pd.HDFStore(file_or_path, "w") as store:
        # The store's mode is fixed to "w" above; presumably that is why a
        # caller-supplied "mode" is dropped instead of being passed to put().
        put_options.pop("mode", None)
        for key in ("features", "targets"):
            store.put(key, getattr(data, key), data_columns=True, **put_options)
# Register MyDataToHdf5 as the marshaller used for MyData in the hdf5 file format.
register_marshaller(MyData, FileFormat.hdf5, MyDataToHdf5())
# Register a custom parameter type for MyData, built from parameter.data.type(MyData),
# and expose it as MyDataParameter.
MyDataParameter = register_custom_parameter(MyData, parameter.data.type(MyData))
class MyDataReport(PythonTask):
    """Report the first row of the ``features`` frame of a MyData input."""

    my_data = parameter[MyData]
    report = output[DataFrame]

    def run(self):
        features = self.my_data.features
        self.report = features.head(1)
class BuildMyData(PythonTask):
    """Build a small MyData sample (features + targets) stored as HDF5."""

    my_data = output.hdf5[MyData]

    def run(self):
        feature_rows = [[1, 2], [2, 3]]
        target_rows = [[1, 22], [2, 33]]
        self.my_data = MyData(
            features=pd.DataFrame(data=feature_rows, columns=["Names", "Births"]),
            targets=pd.DataFrame(data=target_rows, columns=["Names", "Class"]),
        )
@task
def validate_my_data(my_data):
@task(result=output.hdf5)
def t_f_hdf5(i=1):
    # type:(int)->pd.DataFrame
    """Return a fixed two-row DataFrame of names and birth counts (``i`` is unused)."""
    rows = [("Bob", 968), ("Jessica", 155)]
    return pd.DataFrame(data=rows, columns=["Names", "Births"])
def test_single_task(self, tmpdir_factory):
    """Instantiate and run a single task, then check a parameter and its output."""

    class TestTask(Task):
        test_input = data
        p = parameter[str]
        d = parameter[date]
        param_from_config = parameter[date]

        a_output = output.data

        def run(self):
            self.a_output.write("ss")

    test_task = TestTask(test_input=__file__, p="333", d=date(2018, 3, 4))
    assert test_task.p == "333"

    test_task.dbnd_run()
    written = test_task.a_output.read()
    assert written == "ss"
def test_inject_dict(self):
    """Inject a dict of TTask outputs into a task that concatenates their contents."""

    class TTaskCombineInputs(PythonTask):
        t_inputs = parameter[Dict[int, Target]]
        t_output = output

        def run(self):
            # Only the targets are read; the dict keys are not needed.
            with self.t_output.open("w") as out_file:
                for t_target in six.itervalues(self.t_inputs):
                    out_file.write(t_target.read())

    class TMultipleInjectPipeline(PipelineTask):
        t_types = parameter.value([1, 2])
        t_output = output

        def band(self):
            t_inputs = {}
            for t_type in self.t_types:
                t_inputs[t_type] = TTask(t_param=t_type).t_output
            self.t_output = TTaskCombineInputs(t_inputs=t_inputs).t_output

    pipeline = TMultipleInjectPipeline()
    assert_run_task(pipeline)
    logger.error(pipeline.t_output.read())
@task(result=output.data_list_str)
def t_f_b(t_input2):
    # type: (DataList[str]) -> List[str]
    """Return each item of ``t_input2`` formatted with an ``s_`` prefix."""
    prefixed = []
    for item in t_input2:
        prefixed.append("s_%s" % item)
    return prefixed
import logging
from dbnd import PipelineTask, PythonTask, output, parameter
from test_dbnd.scenarios import data
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class SF_A(PythonTask):
    """Scenario task that emits a devices string and a tab-separated stats line."""

    my_filter = parameter[int]
    input_logs = data
    input_labels = data

    o_devices = output[str]
    o_stats = output.target

    def run(self):
        devices = "devices.."
        self.o_devices = devices
        stats_line = "stats\t1"
        self.o_stats.write(stats_line)
class SF_B(PythonTask):
combine_similar = parameter[bool]
input_devices = data
o_device_histogram = output
o_types = output
def run(self):
@task(result=output.txt())
def validate_model(model, validation_dataset):
    # type: (ElasticNet, pd.DataFrame) -> str
    """ Calculates metrics of wine prediction model (py27)

    Splits ``validation_dataset`` into features (every column except
    ``quality``) and the ``quality`` target, predicts with ``model``, then logs
    a prediction scatter-plot artifact and the rmse / mae / r2 metrics.
    """
    # Use the ``columns=`` keyword: DataFrame.drop's ``axis`` can no longer be
    # passed positionally (keyword-only since pandas 2.0).
    validation_x = validation_dataset.drop(columns=["quality"])
    validation_y = validation_dataset[["quality"]]

    prediction = model.predict(validation_x)
    (rmse, mae, r2) = calculate_metrics(validation_y, prediction)

    log_artifact(
        "prediction_scatter_plot", _create_scatter_plot(validation_y, prediction)
    )

    # Each metric is logged under its own name with its own value.
    log_metric("rmse", rmse)
    log_metric("mae", mae)
    log_metric("r2", r2)
    # NOTE(review): the task declares a txt result and a ``-> str`` type
    # comment, but nothing is returned here; confirm whether a ``return``
    # statement is missing.
@task(archive=output(output_ext=".tar.gz")[Path])
def export_db(
archive,
include_db=True,
include_logs=True,
task_version=utcnow().strftime("%Y%m%d_%H%M%S"),
):
# type: (Path, bool, bool, str)-> None
from dbnd._core.current import get_databand_context
logger.info("Compressing files to %s..." % archive)
with tarfile.open(str(archive), "w:gz") as tar:
if include_db:
assert_web_enabled(reason="dbnd_web is required for export db")