def test_standard_store_registry_with_mocked_entrypoint():
    mock_entrypoint = mock.Mock()
    mock_entrypoint.name = "mock-scheme"
    with mock.patch(
        "entrypoints.get_group_all", return_value=[mock_entrypoint]
    ):
        # Entrypoints are registered at import time, so we need to reload the
        # module to register the entrypoint given by the mocked
        # entrypoints.get_group_all
        reload(mlflow.tracking._tracking_service.utils)
        expected_standard_registry = {
            '',
            'file',
            'http',
            'https',
            'postgresql',
            'mysql',
            'sqlite',
            'mssql',
            'databricks',
            'mock-scheme'
        }
        assert expected_standard_registry.issubset(
            mlflow.tracking._tracking_service.utils._tracking_store_registry._registry.keys()
        )
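# Hedged example, added for illustration (not from the original test module):
# the mocked entrypoint above stands in for MLflow's tracking-store plugin
# mechanism, where a third-party package advertises a URI scheme through a
# setuptools entry point. The group name "mlflow.tracking_store", the package
# name, and the store class below are assumptions.
from setuptools import setup

setup(
    name="my-tracking-plugin",
    entry_points={
        "mlflow.tracking_store": [
            # Maps the URI scheme "my-scheme" to a store class in the plugin.
            "my-scheme=my_plugin.store:PluginTrackingStore",
        ]
    },
)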
def test_gbt():
    old_uri = tracking.get_tracking_uri()
    with TempDir(chdr=False, remove_on_exit=True) as tmp:
        try:
            diamonds = tmp.path("diamonds")
            artifacts = tmp.path("artifacts")
            os.mkdir(diamonds)
            os.mkdir(artifacts)
            tracking.set_tracking_uri(artifacts)
            mlflow.set_experiment("test-experiment")
            # Download the diamonds dataset via mlflow run
            run(".", entry_point="main", version=None,
                parameters={"dest-dir": diamonds},
                mode="local", cluster_spec=None, git_username=None, git_password=None,
                use_conda=True, storage_dir=None)
            # Run the main gbt app via mlflow
            submitted_run = run(
def _currently_registered_run_context_provider_classes():
    return {
        provider.__class__
        for provider in mlflow.tracking.context._run_context_provider_registry
    }
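# Hedged sketch, added for illustration: the registry inspected above holds
# run-context providers. A provider exposes in_context() and tags() (assuming
# the RunContextProvider interface from mlflow.tracking.context); the class
# below is a hypothetical example of what could appear in that registry.
class HypotheticalContextProvider:
    def in_context(self):
        # Only contribute tags when this execution environment applies.
        return True

    def tags(self):
        # Extra tags merged into runs started in this context.
        return {"example.source": "hypothetical-environment"}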
log_param = mlflow.tracking.fluent.log_param
log_metric = mlflow.tracking.fluent.log_metric
set_tag = mlflow.tracking.fluent.set_tag
delete_tag = mlflow.tracking.fluent.delete_tag
log_artifacts = mlflow.tracking.fluent.log_artifacts
log_artifact = mlflow.tracking.fluent.log_artifact
active_run = mlflow.tracking.fluent.active_run
get_run = mlflow.tracking.fluent.get_run
start_run = mlflow.tracking.fluent.start_run
end_run = mlflow.tracking.fluent.end_run
search_runs = mlflow.tracking.fluent.search_runs
get_artifact_uri = mlflow.tracking.fluent.get_artifact_uri
set_tracking_uri = tracking.set_tracking_uri
get_experiment = mlflow.tracking.fluent.get_experiment
get_experiment_by_name = mlflow.tracking.fluent.get_experiment_by_name
get_tracking_uri = tracking.get_tracking_uri
create_experiment = mlflow.tracking.fluent.create_experiment
set_experiment = mlflow.tracking.fluent.set_experiment
log_params = mlflow.tracking.fluent.log_params
log_metrics = mlflow.tracking.fluent.log_metrics
set_tags = mlflow.tracking.fluent.set_tags
delete_experiment = mlflow.tracking.fluent.delete_experiment
delete_run = mlflow.tracking.fluent.delete_run
register_model = mlflow.tracking._model_registry.fluent.register_model
run = projects.run
__all__ = ["ActiveRun", "log_param", "log_params", "log_metric", "log_metrics", "set_tag",
           "set_tags", "delete_tag", "log_artifacts", "log_artifact", "active_run", "start_run",
           "end_run", "search_runs", "get_artifact_uri", "set_tracking_uri", "create_experiment",
def get_host_(tracking_uri):
    host = os.environ.get('DATABRICKS_HOST', None)
    if host is not None:
        return host
    try:
        db_profile = mlflow.tracking.utils.get_db_profile_from_uri(tracking_uri)
        config = databricks_utils.get_databricks_host_creds(db_profile)
        return config.host
    except Exception:
        return None
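# Hedged usage sketch, added for illustration: DATABRICKS_HOST takes precedence;
# otherwise the host comes from the Databricks CLI profile embedded in the
# tracking URI. The profile name below is an example.
host = get_host_("databricks://my-profile")
if host is None:
    print("No Databricks host configured for this tracking URI")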
def __init__(self, mlflow_client=None, spark=None, logmod=20):
    self.logmod = logmod
    self.mlflow_client = mlflow_client
    self.spark = spark
    if mlflow_client is None:
        self.mlflow_client = mlflow.tracking.MlflowClient()
    mlflow_utils.dump_mlflow_info()
    if spark is None:
        self.spark = SparkSession.builder.appName("mlflow_metrics").enableHiveSupport().getOrCreate()
    print("logmod:", logmod)
def run_databricks(remote_run, uri, entry_point, work_dir, parameters, experiment_id, cluster_spec):
    """
    Run the project at the specified URI on Databricks, returning a ``SubmittedRun`` that can be
    used to query the run's status or wait for the resulting Databricks Job run to terminate.
    """
    profile = get_db_profile_from_uri(tracking.get_tracking_uri())
    run_id = remote_run.info.run_id
    db_job_runner = DatabricksJobRunner(databricks_profile=profile)
    db_run_id = db_job_runner.run_databricks(
        uri, entry_point, work_dir, parameters, experiment_id, cluster_spec, run_id)
    submitted_run = DatabricksSubmittedRun(db_run_id, run_id, db_job_runner)
    submitted_run._print_description_and_log_tags()
    return submitted_run
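# Hedged usage sketch, added for illustration: the returned object follows the
# public SubmittedRun interface, so a caller can block until the Databricks job
# finishes and check the outcome. The arguments are assumed to be in scope.
submitted = run_databricks(remote_run, uri, entry_point, work_dir,
                           parameters, experiment_id, cluster_spec)
if submitted.wait():
    print("Databricks run %s succeeded" % submitted.run_id)
else:
    print("Databricks run %s failed" % submitted.run_id)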
parser.add_argument("--threshold", type=int, default=3200)
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--pretrain_unit", type=str, default="ERCF", choices=["N", "E", "R", "C", "F", "ERCF"])
args = parser.parse_args()
USE_POI = (args.use_poi == 1)
device = torch.device("cuda:" + args.gpu)
mlflow.set_tracking_uri("/data1/output")
experiment_name = "Default"
experiment_ID = 0
try:
    experiment_ID = mlflow.create_experiment(name=experiment_name)
    print("Initial Create!")
except Exception:
    service = mlflow.tracking.get_service()
    experiments = service.list_experiments()
    for exp in experiments:
        if exp.name == experiment_name:
            experiment_ID = exp.experiment_id
            print("Experiment Exists!")
            break
setproctitle.setproctitle('DPLink')
thre = args.threshold
rnn_unit = 'GRU'
attn_unit = 'dot'
test_pretrain = False # test the effect of different pretrain degree, working with run_pretrain
pre_path, rank_pre2, hit_pre2 = None, None, None
for run_id in range(args.repeat):
    with mlflow.start_run(experiment_id=experiment_ID):
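# Hedged alternative, added for illustration, to the bare try/except used above
# for creating or looking up the experiment; it assumes a newer MLflow that
# provides mlflow.get_experiment_by_name.
existing = mlflow.get_experiment_by_name(experiment_name)
if existing is None:
    experiment_ID = mlflow.create_experiment(name=experiment_name)
else:
    experiment_ID = existing.experiment_id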
def run_databricks(self, uri, entry_point, work_dir, parameters, experiment_id, cluster_spec,
                   run_id):
    tracking_uri = _get_tracking_uri_for_run()
    dbfs_fuse_uri = self._upload_project_to_dbfs(work_dir, experiment_id)
    env_vars = {
        tracking._TRACKING_URI_ENV_VAR: tracking_uri,
        tracking._EXPERIMENT_ID_ENV_VAR: experiment_id,
    }
    _logger.info("=== Running entry point %s of project %s on Databricks ===", entry_point, uri)
    # Launch run on Databricks
    command = _get_databricks_run_cmd(dbfs_fuse_uri, run_id, entry_point, parameters)
    return self._run_shell_command_job(uri, command, env_vars, cluster_spec)
@click.option("--experiment-id", envvar=mlflow.tracking._EXPERIMENT_ID_ENV_VAR, type=click.STRING,
help="Specify the experiment ID for list of runs.", required=True)
@click.option("--view", "-v", default="active_only",
help="Select view type for list experiments. Valid view types are "
"'active_only' (default), 'deleted_only', and 'all'.")
def list_run(experiment_id, view):
"""
List all runs of the specified experiment in the configured tracking server.
"""
store = _get_store()
view_type = ViewType.from_string(view) if view else ViewType.ACTIVE_ONLY
runs = store.search_runs([experiment_id], None, view_type)
table = []
for run in runs:
tags = {k: v for k, v in run.data.tags.items()}
run_name = tags.get(MLFLOW_RUN_NAME, "")
table.append([conv_longdate_to_str(run.info.start_time), run_name, run.info.run_id])
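# Hedged invocation sketch, added for illustration: the command above can be
# exercised with click's test runner, assuming list_run is also wrapped in a
# @click.command() decorator above the options shown. Option values are examples.
from click.testing import CliRunner

runner = CliRunner()
result = runner.invoke(list_run, ["--experiment-id", "0", "--view", "all"])
print(result.output)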