Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_run_context_multi_run(live_mock_server, git_repo):
    """Run two consecutive ``with wandb.init(...)`` sessions and check the wandb/ dir.

    The second session passes ``reinit=True``; each session logs one dict of
    metrics. Afterwards exactly 4 entries are expected under ``wandb/``.
    """
    # NOTE(review): the port is hard-coded rather than taken from
    # live_mock_server -- presumably it matches the mock server; confirm.
    os.environ[env.BASE_URL] = "http://localhost:%i" % 8765
    os.environ["WANDB_API_KEY"] = "B" * 40
    sessions = (
        ({}, {"a": 1, "b": 2}),
        ({"reinit": True}, {"c": 3, "d": 4}),
    )
    for init_kwargs, metrics in sessions:
        with wandb.init(**init_kwargs) as run:
            run.log(metrics)
    assert len(glob.glob("wandb/*")) == 4
def global_wandb_settings(tmpdir):
    """Point wandb's config dir at ``tmpdir`` and yield an open settings file.

    Sets ``os.environ[env.CONFIG_DIR]`` to ``tmpdir.strpath``, opens
    ``<tmpdir>/settings`` in ``"w+"`` mode and yields the file object.

    On teardown the file is closed and ``env.CONFIG_DIR`` is restored to the
    value it had before (or removed if it was unset). This happens even if
    opening the file or the consuming code fails.
    """
    previous = os.environ.get(env.CONFIG_DIR)
    os.environ[env.CONFIG_DIR] = tmpdir.strpath
    try:
        with open(os.path.join(tmpdir.strpath, 'settings'), "w+") as f:
            yield f
    finally:
        # Restore instead of blindly deleting: an unconditional `del` wiped a
        # pre-existing value and raised KeyError if the var was already gone.
        if previous is None:
            os.environ.pop(env.CONFIG_DIR, None)
        else:
            os.environ[env.CONFIG_DIR] = previous
# List the projects that api.list_projects returns for `entity`.
# When `display` is true, echo a bold heading and then one line per project:
# the name in bold blue, " - ", and the first line of its description (a
# falsy description yields an empty string). Returns the listing unchanged.
# (No docstring on purpose: a click command may use __doc__ as its help text.)
@click.option("--entity", "-e", default=None, envvar=env.ENTITY, help="The entity to scope the listing to.")
@display_error
def projects(entity, display=True):
    found = api.list_projects(entity=entity)
    # The heading (and the len() call) is computed even when display is off.
    template = "No projects found for %s" if len(found) == 0 else 'Latest projects for "%s"'
    heading = template % entity
    if not display:
        return found
    click.echo(click.style(heading, bold=True))
    for proj in found:
        styled_name = click.style(proj['name'], fg="blue", bold=True)
        first_line = str(proj['description'] or "").split("\n")[0]
        click.echo(" - ".join((styled_name, first_line)))
    return found
# Return the wandb run mirroring the MLflow `run`, starting one on first use.
# wandb is configured through environment variables before wandb.init:
# silent mode, SYNC_MLFLOW (default "all"), resume, and group/job type for
# child runs.
# NOTE(review): indentation of this fragment was lost, so the nesting of some
# statements below cannot be recovered from here.
def _get_or_start_wandb_run(run, run_id=None, name=None):
try:
os.environ[env.SILENT] = "1"
# Default to syncing "all" unless the user already set SYNC_MLFLOW.
os.environ[env.SYNC_MLFLOW] = os.getenv(env.SYNC_MLFLOW, "all")
# run_id only switches resume mode on; wandb.init below is keyed on
# run.info.run_id, not on this argument.
if run_id:
os.environ[env.RESUME] = "allow" # TODO: must?
# Child MLflow run: group it under the parent's run id. If the parent's wandb
# run is tracked in RUNS and has no group yet, mark it as the group's "parent".
if run.data.tags.get("mlflow.parentRunId"):
parent = RUNS.get(run.data.tags["mlflow.parentRunId"], {"run": None})["run"]
if parent and parent.group is None:
parent.group = run.data.tags["mlflow.parentRunId"]
parent.job_type = "parent"
parent.save()
#TODO: maybe call save
os.environ[env.RUN_GROUP] = run.data.tags["mlflow.parentRunId"]
os.environ[env.JOB_TYPE] = "child"
# NOTE(review): the os.getenv default is evaluated eagerly, so
# client.get_experiment(...) is called even when env.PROJECT is set.
project = os.getenv(env.PROJECT, client.get_experiment(run.info.experiment_id).name)
# NOTE(review): `config` is the same object as run.data.tags, so the two
# assignments below also add keys to the MLflow run's tags -- confirm this is
# intended (a copy such as dict(run.data.tags) may be safer).
config = run.data.tags
config["mlflow.tracking_uri"] = mlflow.get_tracking_uri()
config["mlflow.experiment_id"] = run.info.experiment_id
# NOTE(review): RUNS stores dicts ({"step", "last_log", "run"}, see the
# assignment below and the ["run"] lookup above). On a cache hit `wandb_run` is
# therefore that dict, not a run object -- likely should be RUNS[...]["run"];
# confirm.
wandb_run = RUNS.get(run.info.run_id)
if wandb_run is None:
wandb_run = wandb.init(id=run.info.run_id, project=project,
name=name, config=config, reinit=True)
# Tell the user where syncing goes: the run URL up to "/runs/".
wandb.termlog("Syncing MLFlow metrics, params, and artifacts to: %s" %
wandb_run.get_url().split("/runs/")[0], repeat=False, force=True)
# Record the mlflow version in the run's config.
wandb_run.config._set_wandb('mlflow_version', mlflow.__version__)
RUNS[wandb_run.id] = {"step": 0, "last_log": time.time(), "run": wandb_run}
return wandb_run
# NOTE(review): the body of this except clause is not visible in this fragment.
except Exception as e:
def __init__(self, client, entity, project, run_id, attrs=None):
    """
    Run is always initialized by calling api.runs() where api is an instance of wandb.Api

    Args:
        client: client object, stored as ``self.client``.
        entity: entity name, stored as ``self._entity``.
        project: project name, stored as ``self.project``.
        run_id: run id, stored as ``self.id``.
        attrs: optional dict of already-fetched run attributes. It seeds the
            underlying dict and ``self.state`` (default ``"not found"``).
            When omitted or empty, ``load`` is called with ``force=True``.
    """
    # None instead of a shared mutable `{}` default; omitted still behaves like {}.
    attrs = {} if attrs is None else attrs
    super(Run, self).__init__(dict(attrs))
    self.client = client
    self._entity = entity
    self.project = project
    self._files = {}
    self._base_dir = env.get_dir(tempfile.gettempdir())
    self.id = run_id
    self.sweep = None
    # self.path is defined elsewhere; presumably built from entity/project/id
    # set above -- confirm.
    self.dir = os.path.join(self._base_dir, *self.path)
    try:
        os.makedirs(self.dir)
    except OSError:
        # Deliberately best-effort (e.g. the directory already exists).
        pass
    self._summary = None
    self.state = attrs.get("state", "not found")
    self.load(force=not attrs)
# Initialise the public API object. This sets default settings, logs in when
# no API key is available, applies caller `overrides`, creates empty caches,
# and starts building the GraphQL client (the Client(...) call is cut off in
# this fragment).
# NOTE(review): `overrides={}` is a mutable default; in the visible code it is
# only read (update / membership / lookup), never mutated.
def __init__(self, overrides={}):
# base_url: env.get_base_url gets "https://api.wandb.ai", presumably as the
# fallback when no override is configured -- confirm.
self.settings = {
'entity': None,
'project': None,
'run': "latest",
'base_url': env.get_base_url("https://api.wandb.ai")
}
# NOTE(review): the api_key check runs before `overrides` are merged into
# self.settings -- confirm api_key does not depend on overridden settings.
if self.api_key is None:
wandb.login()
self.settings.update(overrides)
# Backward compatibility: "username" is accepted as an alias for "entity"
# (with a deprecation warning) unless "entity" is given explicitly.
if 'username' in overrides and 'entity' not in overrides:
wandb.termwarn('Passing "username" to Api is deprecated. please use "entity" instead.')
self.settings['entity'] = overrides['username']
# Empty per-instance caches; their key scheme is not visible here.
self._projects = {}
self._runs = {}
self._sweeps = {}
self._reports = {}
self._base_client = Client(
transport=RequestsHTTPTransport(
headers={'User-Agent': self.user_agent, 'Use-Admin-Privileges': "true"},
use_json=True,
# this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
The purpose of this module is to break circular imports.
"""
import os
import string
import sys
import time
import click
from . import env
from . import io_wrap
# Stage directory under env.get_dir(os.getcwd()): prefer an existing hidden
# ".wandb" directory, then "wandb"; None when neither exists yet.
__stage_dir__ = next(
    (candidate + os.sep
     for candidate in ('.wandb', 'wandb')
     if os.path.exists(os.path.join(env.get_dir(os.getcwd()), candidate))),
    None,
)

# Captured once at import time.
SCRIPT_PATH = os.path.abspath(sys.argv[0])
START_TIME = time.time()
LIB_ROOT = os.path.join(os.path.dirname(__file__), '..')
IS_GIT = os.path.exists(os.path.join(LIB_ROOT, '.git'))
def wandb_dir():
    """Return the wandb stage directory path under ``env.get_dir(os.getcwd())``.

    Uses ``__stage_dir__`` when it is set (truthy); otherwise falls back to
    ``"wandb" + os.sep``.
    """
    base = env.get_dir(os.getcwd())
    stage = __stage_dir__ if __stage_dir__ else "wandb" + os.sep
    return os.path.join(base, stage)
# NOTE(review): fragment -- the enclosing function, the `try`, and the earlier
# handler that binds `err` (with .last_exception) and `message` start outside
# this view.
# Debug mode re-raises the underlying exception as-is; otherwise it is wrapped
# in CommError(message, ...) with the current traceback.
if wandb.env.is_debug():
six.reraise(type(err.last_exception), err.last_exception, sys.exc_info()[2])
else:
six.reraise(CommError, CommError(
message, err.last_exception), sys.exc_info()[2])
except Exception as err:
# gql raises server errors as stringified dicts...
if len(err.args) > 0:
payload = err.args[0]
else:
payload = err
# NOTE(review): ast.literal_eval does not execute code, but it raises
# ValueError/SyntaxError on dict-looking payloads that are not Python
# literals (e.g. JSON null/true/false). The ["message"] lookup raises KeyError
# when the key is absent. Either exception would escape this handler instead
# of CommError -- confirm.
if str(payload).startswith("{"):
message = ast.literal_eval(str(payload))["message"]
else:
message = str(err)
# Same debug / non-debug split as above, for the generic exception.
if wandb.env.is_debug():
six.reraise(*sys.exc_info())
else:
six.reraise(CommError, CommError(
message, err), sys.exc_info()[2])
def _wandb_join(exit_code=None):
    """Finish the current run from user code.

    Stops the async log thread, closes the run's files, and records
    ``exit_code`` on ``hooks`` when one is given. Then calls
    ``_user_process_finished`` and pops the top of ``_global_run_stack`` when
    the stack is non-empty. The call order is significant and kept as-is.
    """
    global _global_run_stack
    shutdown_async_log_thread()
    run.close_files()
    if exit_code is not None:
        hooks.exit_code = exit_code
    _user_process_finished(
        server,
        hooks,
        wandb_process,
        stdout_redirector,
        stderr_redirector,
    )
    stack = _global_run_stack
    if stack:
        stack.pop()


# Public alias.
join = _wandb_join
# Fragment of the enclosing function: reset the "finished" flag (presumably
# consulted to avoid running the finish logic twice -- confirm), then start
# output redirection.
_user_process_finished_called = False
# redirect output last of all so we don't miss out on error messages
stdout_redirector.redirect()
# stderr is left unredirected when env.is_debug() is true.
if not env.is_debug():
stderr_redirector.redirect()
def set_setting(self, key, value, globally=False):
    """Store ``key=value`` in the default settings section.

    ``globally`` is forwarded to ``self._settings.set``. For the keys
    ``'entity'`` and ``'project'``, the value is also passed to
    ``env.set_entity`` / ``env.set_project`` together with ``self._environ``.
    """
    self._settings.set(Settings.DEFAULT_SECTION, key, value, globally=globally)
    if key == 'entity':
        mirror = env.set_entity
    elif key == 'project':
        mirror = env.set_project
    else:
        return
    mirror(value, env=self._environ)