from datetime import date, datetime, timedelta
import pytz
from databand.parameters import DateHourParameter, TimeDeltaParameter
from dbnd import data, output, parameter
from dbnd._core.task import base_task
from dbnd._vendor.cloudpickle import cloudpickle
from dbnd.tasks import DataSourceTask, Task
from dbnd_test_scenarios.test_common.task.factories import TTask
from targets import target
class DefaultInsignificantParamTask(TTask):
insignificant_param = parameter.value(significant=False, default="value")
necessary_param = parameter.value(significant=False)[str]
class MyExternalTask(DataSourceTask):
some_outputs = output
def band(self):
self.some_outputs = target("/tmp")
class TestTaskObject(object):
def test_task_deepcopy(self, tmpdir_factory):
class TestTask(Task):
test_input = data
p = parameter[str]
d = parameter[date]
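        # Hedged sketch of the elided assertions: dbnd tasks are plain Python
        # objects, so deepcopy should preserve parameter values (the test_input
        # path and the literals below are illustrative, not from the source).
        import copy

        t = TestTask(test_input=__file__, p="my_param", d=date(2020, 1, 1))
        t_copy = copy.deepcopy(t)
        assert t_copy.p == t.p
        assert t_copy.d == t.d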
import logging
import dbnd
from dbnd import config, dbnd_run_cmd, output, parameter
from dbnd.testing.helpers_pytest import skip_on_windows
from dbnd_test_scenarios.dbnd_scenarios import scenario_path
from dbnd_test_scenarios.test_common.task.factories import TTask
logger = logging.getLogger(__name__)
class MyConfig(dbnd.Config):
mc_p = parameter[int]
mc_q = parameter.value(73)
class MyConfigTester(dbnd.Task):
t_output = output.json[object]
def run(self):
config.log_current_config(sections=["MyConfig"], as_table=True)
c = MyConfig()
self.t_output = [c.mc_p, c.mc_q]
class TestTaskCliSetConfig(object):
def test_from_extra_config(self):
class MyTaskWithConfg(TTask):
parameter_with_config = parameter[str]
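        # Hedged sketch of the elided body: run the task through dbnd's CLI
        # entry point, overriding the parameter via the -s flag (the value
        # "from_config" is illustrative).
        dbnd_run_cmd(
            ["MyTaskWithConfg", "-s", "MyTaskWithConfg.parameter_with_config=from_config"]
        )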
import time

from dbnd import PipelineTask, parameter
from dbnd._core.inline import run_task
from dbnd.tasks.basics import SimplestTask
from dbnd.testing import assert_run_task
from dbnd.testing.helpers_pytest import skip_on_windows
class SleepyTask(SimplestTask):
sleep_time = parameter.value(0.1, significant=False)
def run(self):
if self.sleep_time:
time.sleep(self.sleep_time)
super(SleepyTask, self).run()
class ParallelTasksPipeline(PipelineTask):
num_of_tasks = parameter.value(3)
def band(self):
tasks = []
for i in range(self.num_of_tasks):
tasks.append(SleepyTask(simplest_param=str(i)))
return tasks
class TestTasksParallelExample(object):
def test_parallel_simple_executor(self):
target = ParallelTasksPipeline(num_of_tasks=2)
run_task(target)
assert target._complete()
# @with_context(conf={'executor': {'local': 'true'},
# 'databand': {'module': ParallelTasksPipeline.__module__}})
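# Hedged sketch of what the commented-out @with_context above set up: the
# same executor override can be applied with dbnd's config context (section
# and key names are taken from that comment, not verified against dbnd).
from dbnd import config


def run_with_local_executor():
    with config({"executor": {"local": "true"}}):
        run_task(ParallelTasksPipeline(num_of_tasks=2))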
import logging
from functools import partial
from dbnd import output, parameter, task
from dbnd._core.decorator.decorated_task import DecoratedPythonTask
from dbnd.tasks import PythonTask
from dbnd.testing.helpers_pytest import assert_run_task
from dbnd_test_scenarios.test_common.targets.target_test_base import TargetTestBase
class MyExpTask(DecoratedPythonTask):
custom_name = parameter.value("aa")
previous_exp = parameter.value(1)
score_card = output.csv.data
my_ratio = output.csv.data
def run(self):
# wrapping code
score = self._invoke_func()
self.score_card.write(str(score))
self.my_ratio.write_pickle(self.previous_exp + 1)
my_experiment = partial(task, _task_type=MyExpTask)
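# Hedged usage sketch: since `my_experiment` is `task` partially applied with
# _task_type=MyExpTask, it should work as a decorator that builds MyExpTask
# instances, so MyExpTask.run() writes score_card and my_ratio around the
# wrapped function. The function below is illustrative only.
@my_experiment
def my_train_step(alpha=0.5):
    # the returned value becomes the "score" written to score_card
    return alpha * 2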
class TestGeneratedOutputs(TargetTestBase):  # enclosing test class reconstructed; name illustrative
    def test_generated_output_dict(self):
def _get_all_splits(task, task_output): # type: (Task, ParameterBase) -> dict
result = {}
target = task_output.build_target(task)
for i in range(task.parts):
name = "part_%s" % i
result[name] = (
target.partition(name="train_%s" % name),
target.partition(name="test_%s" % name),
)
return result
class TGeneratedOutputs(PythonTask):
parts = parameter.value(3)
splits = output.csv.folder(output_factory=_get_all_splits)
def run(self):
for key, split in self.splits.items():
train, test = split
train.write(key)
test.write(key)
assert_run_task(TGeneratedOutputs())
from __future__ import absolute_import
import pandas as pd
import dbnd
from dbnd import Config, data, output, parameter, task
from dbnd._core.current import current_task_run
from dbnd.tasks import PythonTask
class TTask(PythonTask):
t_param = parameter.value("1")
t_output = output.data
def run(self):
self.t_output.write("%s" % self.t_param)
class TTaskWithInput(TTask):
t_input = data
class TTaskThatFails(TTask):
def run(self):
raise ValueError()
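# Hedged sketch: running TTaskThatFails through dbnd should surface the
# ValueError as a run failure; the wrapping exception type is assumed
# from dbnd's error module.
import pytest
from dbnd._core.errors.base import DatabandRunError
from dbnd.testing.helpers_pytest import assert_run_task


def test_failing_task_raises():
    with pytest.raises(DatabandRunError):
        assert_run_task(TTaskThatFails())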
class CaseSensitiveParameterTask(PythonTask):
    # parameter definitions elided in the source snippet
    pass
import logging
import sys
from dbnd import PipelineTask, PythonTask, output, parameter
logger = logging.getLogger(__name__)
class T1(PythonTask):
p1 = parameter.value("somep")
o_1 = output[str]
def run(self):
self.o_1 = self.p1
class T2(PythonTask):
p1 = parameter.value("somep")
o_1 = output[str]
def run(self):
raise Exception()
# self.o_1 = self.p1
class TPipe(PipelineTask):
    # band() body elided in the source snippet
    pass


# The parameter definitions below are an excerpt from a Kubernetes-style
# engine config; the enclosing class header and its imports were lost in
# extraction, so both are reconstructed here for the excerpt to parse.
from typing import Dict, List

from dbnd import Config


class KubernetesEngineConfigExcerpt(Config):
    system_secrets = parameter(empty_default=True).help(
        "System secrets (used by Databand Framework)"
    )[List]
env_vars = parameter(empty_default=True)[Dict]
node_selectors = parameter(empty_default=True)[Dict]
annotations = parameter(empty_default=True)[Dict]
pods_creation_batch_size = parameter.value(10)[int]
service_account_name = parameter.none()[str]
gcp_service_account_keys = parameter.none()[
str
] # it's actually dict, but KubeConf expects str
affinity = parameter(empty_default=True)[Dict]
tolerations = parameter(empty_default=True)[List]
hostnetwork = parameter.value(False)
configmaps = parameter(empty_default=True)[List[str]]
volumes = parameter.none()[List[str]]
volume_mounts = parameter.none()[List[str]]
security_context = parameter.none()[List[str]]
labels = parameter.none()[Dict]
request_memory = parameter.none()[str]
request_cpu = parameter.none()[str]
limit_memory = parameter.none()[str]
limit_cpu = parameter.none()[str]
requests = parameter.none()[Dict]
limits = parameter.none()[Dict]
    pod_exit_code_to_retry_count = parameter(empty_default=True).help(
        "..."  # help text truncated in the source snippet
    )
import logging
import typing

from typing import Dict

from dbnd import Config, parameter
from dbnd._core.settings import EngineConfig  # import path assumed for this excerpt
from dbnd_gcp.apache_beam import ApacheBeamJobCtrl
if typing.TYPE_CHECKING:
from dbnd_gcp.dataflow.dataflow_config import DataflowConfig
logger = logging.getLogger(__name__)
class ApacheBeamConfig(Config):
"""Apache Beam (-s [TASK].spark.[PARAM]=[VAL] for specific tasks)"""
# we don't want spark class to inherit from this one, as it should has Config behaviour
_conf__task_family = "beam"
jar = parameter.value(None, description="Main application jar")[str]
verbose = parameter.value(
False,
description="Whether to pass the verbose flag to spark-submit process for debugging",
)
options = parameter(empty_default=True)[Dict[str, str]]
class LocalBeamEngineConfig(EngineConfig):
def get_beam_ctrl(self, task_run):
from dbnd_gcp.apache_beam.local_apache_beam import LocalApacheBeamJobCtrl
return LocalApacheBeamJobCtrl(task_run)
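# Hedged usage sketch: like any dbnd Config, ApacheBeamConfig can be
# instantiated in code or overridden from the CLI via -s, as its docstring
# shows; the option values below are illustrative.
beam_conf = ApacheBeamConfig(options={"runner": "DirectRunner", "project": "my-project"})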