def _parse_settings(settings):
"""settings could be json or comma seperated assignments."""
ret = {}
# TODO(jhr): merge with magic_impl:_parse_magic
if settings.find('=') > 0:
for item in settings.split(","):
kv = item.split("=")
            if len(kv) != 2:
                wandb.termwarn("Unable to parse sweep settings key value pair", repeat=False)
                continue  # skip malformed pairs instead of crashing on dict([kv]) below
            ret.update(dict([kv]))
return ret
wandb.termwarn("Unable to parse settings parameter", repeat=False)
return ret
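
# Hedged usage sketch for _parse_settings: values stay strings (there is no
# type conversion), and input without '=' warns and yields an empty dict.
# The example inputs are illustrative, not taken from the source.
assert _parse_settings("a=1,b=2") == {"a": "1", "b": "2"}
assert _parse_settings("not-assignments") == {}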
if not isinstance(sweep_id, str):
wandb.termerror('Expected string sweep_id')
return
sweep_split = sweep_id.split('/')
if len(sweep_split) == 1:
        pass  # bare sweep id; entity and project come from other arguments or the environment
elif len(sweep_split) == 2:
split_project, sweep_id = sweep_split
if project and split_project:
wandb.termwarn('Ignoring project commandline parameter')
project = split_project or project
elif len(sweep_split) == 3:
split_entity, split_project, sweep_id = sweep_split
if entity and split_entity:
wandb.termwarn('Ignoring entity commandline parameter')
if project and split_project:
wandb.termwarn('Ignoring project commandline parameter')
project = split_project or project
entity = split_entity or entity
else:
wandb.termerror('Expected sweep_id in form of sweep, project/sweep, or entity/project/sweep')
return
if entity:
env.set_entity(entity)
if project:
env.set_project(project)
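
# Hedged illustration of the three sweep_id forms accepted above; the
# identifiers are placeholders, not values from the source.
for sid in ("abc123", "my-project/abc123", "my-team/my-project/abc123"):
    print(sid, "->", sid.split('/'))
# abc123 -> ['abc123']
# my-project/abc123 -> ['my-project', 'abc123']
# my-team/my-project/abc123 -> ['my-team', 'my-project', 'abc123']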
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
log_level = logging.DEBUG
if in_jupyter:
    log_level = logging.ERROR  # assumed completion: quieter logging in notebooks
ch.setLevel(log_level)

if len(value.histo.bucket_limit) >= 3:
    # Extrapolate synthetic left/right bin edges from the neighboring limits.
    first = value.histo.bucket_limit[0] + \
        value.histo.bucket_limit[0] - value.histo.bucket_limit[1]
    last = value.histo.bucket_limit[-2] + \
        value.histo.bucket_limit[-2] - value.histo.bucket_limit[-3]
    np_histogram = (list(value.histo.bucket),
                    [first] + value.histo.bucket_limit[:-1] + [last])
try:
#TODO: we should just re-bin if there are too many buckets
values[tag] = wandb.Histogram(
np_histogram=np_histogram)
except ValueError:
wandb.termwarn("Not logging key \"{}\". Histograms must have fewer than {} bins".format(
tag, wandb.Histogram.MAX_LENGTH), repeat=False)
else:
#TODO: is there a case where we can render this?
wandb.termwarn("Not logging key \"{}\". Found a histogram with only 2 bins.".format(tag), repeat=False)
elif value.tag == "_hparams_/session_start_info":
if wandb.util.get_module("tensorboard.plugins.hparams"):
from tensorboard.plugins.hparams import plugin_data_pb2
plugin_data = plugin_data_pb2.HParamsPluginData()
plugin_data.ParseFromString(
value.metadata.plugin_data.content)
for key, param in six.iteritems(plugin_data.session_start_info.hparams):
            if wandb.run.config.get(key) is None:  # only fill in missing keys
                wandb.run.config[key] = param.number_value or param.string_value or param.bool_value
else:
wandb.termerror(
"Received hparams tf.summary, but could not import the hparams plugin from tensorboard")
return values
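
# Hedged companion example: wandb.Histogram accepts the same (counts, edges)
# pair that np.histogram returns, which is what the conversion above rebuilds
# from the TensorBoard proto. Project name and data are placeholders.
import numpy as np
import wandb

wandb.init(project="histogram-demo")
data = np.random.randn(1000)
wandb.log({"activations": wandb.Histogram(np_histogram=np.histogram(data, bins=64))})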
def set_wandb_attrs(cbk, val_data):
if isinstance(cbk, WandbCallback):
if is_generator_like(val_data):
cbk.generator = val_data
elif is_dataset(val_data):
if context.executing_eagerly():
cbk.generator = iter(val_data)
else:
            wandb.termwarn(
                "Found a validation dataset in graph mode; can't patch Keras.")
elif isinstance(val_data, tuple) and isinstance(val_data[0], tf.Tensor):
# Graph mode dataset generator
def gen():
while True:
yield K.get_session().run(val_data)
cbk.generator = gen()
else:
cbk.validation_data = val_data
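
# Hedged usage sketch: set_wandb_attrs patches validation data onto the
# callback; in user code the equivalent is passing validation_data together
# with WandbCallback to model.fit. The model and data are toy placeholders.
import numpy as np
import tensorflow as tf
import wandb
from wandb.keras import WandbCallback

wandb.init(project="keras-demo")
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(64, 4).astype("float32"), np.random.rand(64, 1).astype("float32")
model.fit(x, y, validation_data=(x, y), callbacks=[WandbCallback()], epochs=1)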
def image_segmentation_multiclass_dataframe(x, y_true, y_pred, labels, example_ids=None, class_colors=None):
np = util.get_module('numpy', required='dataframes require numpy')
pd = util.get_module('pandas', required='dataframes require pandas')
    x, y_true, y_pred = np.array(x), np.array(y_true), np.array(y_pred)
    if x.shape[0] != y_true.shape[0]:
        termwarn('Sample count mismatch: x(%d) != y_true(%d); skipping evaluation' % (x.shape[0], y_true.shape[0]))
        return
    if x.shape[0] != y_pred.shape[0]:
        termwarn('Sample count mismatch: x(%d) != y_pred(%d); skipping evaluation' % (x.shape[0], y_pred.shape[0]))
        return
    if class_colors is not None and len(class_colors) != y_true.shape[-1]:
        termwarn('Class color count mismatch: y_true(%d) != class_colors(%d); using generated colors' % (y_true.shape[-1], len(class_colors)))
        class_colors = None
class_count = y_true.shape[-1]
if class_colors is None:
class_colors = util.class_colors(class_count)
class_colors = np.array(class_colors)
y_true_class = np.argmax(y_true, axis=-1)
y_pred_class = np.argmax(y_pred, axis=-1)
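
# Hedged illustration of the argmax step above: one-hot (N, H, W, C) masks
# collapse to (N, H, W) integer class maps. Shapes are illustrative.
import numpy as np

y = np.zeros((1, 2, 2, 3))      # one sample, 2x2 mask, 3 classes, one-hot
y[0, :, :, 1] = 1               # every pixel belongs to class 1
print(np.argmax(y, axis=-1))    # -> array of shape (1, 2, 2) filled with 1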
def best_run(self, order=None):
"Returns the best run sorted by the metric defined in config or the order passed in"
if order is None:
order = self.order
else:
order = QueryGenerator.format_order_key(order)
if order is None:
wandb.termwarn("No order specified and couldn't find metric in sweep config, returning most recent run")
else:
wandb.termlog("Sorting runs by %s" % order)
filters = {"$and": [{"sweep": self.id}]}
try:
return Runs(self.client, self.entity, self.project, order=order, filters=filters, per_page=1)[0]
except IndexError:
return None
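
# Hedged usage sketch via the public API; the sweep path and metric name are
# placeholders, and the explicit-order syntax is an assumption.
import wandb

api = wandb.Api()
sweep = api.sweep("my-team/my-project/abc123")
best = sweep.best_run()  # or e.g. best_run(order="+summary_metrics.val_loss")
if best is not None:
    print(best.name)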
def _fix_step(run_id, metrics, step, timestamp):
"""Handle different steps by namespacing a new step counter with the first key
if global step decreases. Also auto increase step if we're it's not increasing
every LOG_FLUSH_MINIMUM seconds TODO: make this actually work"""
key = list(metrics)[0]
run_log = _get_run(run_id, False)
if step in (None, 0) and int(time.time() - run_log["last_log"]) > LOG_FLUSH_MINIMUM:
wandb.termwarn("Metric logged without a step, pass a step to log_metric.", repeat=False)
step = run_log["step"] + 1
    # The run logged multiple step sequences; keep a separate count under a namespaced key
if step < run_log["step"]:
metrics[key+"/step"] = step
step = run_log["step"]
if step != run_log["step"]:
run_log["step"] = step
run_log["last_log"] = time.time()
return metrics, step
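
# Hedged walk-through of the decreasing-step branch above, with a fabricated
# run_log dict standing in for _get_run's return value.
import time

run_log = {"step": 10, "last_log": time.time()}
metrics, step = {"loss": 0.5}, 3
key = list(metrics)[0]
if step < run_log["step"]:
    metrics[key + "/step"] = step   # the raw step moves under "loss/step"
    step = run_log["step"]          # the global step stays monotonic
print(metrics, step)                # {'loss': 0.5, 'loss/step': 3} 10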
def _setup_resume(self, resume_status):
# write the tail of the history file
try:
history_tail = json.loads(resume_status['historyTail'])
jsonlfile.write_jsonl_file(os.path.join(self._run.dir, wandb_run.HISTORY_FNAME),
history_tail)
except ValueError:
logger.error("Couldn't parse history")
wandb.termwarn("Couldn't load recent history, resuming may not function properly")
# write the tail of the events file
try:
events_tail = json.loads(resume_status['eventsTail'])
jsonlfile.write_jsonl_file(os.path.join(self._run.dir, wandb_run.EVENTS_FNAME),
events_tail)
except ValueError:
logger.error("Couldn't parse system metrics / events")
    # Load the previous run's summary to avoid losing it; the user process will need to load it
self._run.summary.update(json.loads(resume_status['summaryMetrics'] or "{}"))
    # Note: these calls need to happen after the files above are written, because
    # accessing self._run.events below triggers events to initialize, and the
    # previous events must already be on disk when that happens.
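
# Hedged sketch of the jsonl tail format assumed above: one JSON object per
# line, so a resumed run can keep appending without rewriting old history.
# The file name and rows are placeholders.
import json

history_tail = [{"_step": 41, "loss": 0.31}, {"_step": 42, "loss": 0.29}]
with open("wandb-history.jsonl", "w") as f:
    for row in history_tail:
        f.write(json.dumps(row) + "\n")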
# If we're not in an interactive environment, default to dry-run.
elif not isatty(sys.stdout) or not isatty(sys.stdin):
result = LOGIN_CHOICE_DRYRUN
else:
for i, choice in enumerate(choices):
wandb.termlog("(%i) %s" % (i + 1, choice))
def prompt_choice():
try:
return int(six.moves.input("%s: Enter your choice: " % wandb.core.LOG_STRING)) - 1
except ValueError:
return -1
    idx = -1
    while not 0 <= idx < len(choices):
        idx = prompt_choice()
        if not 0 <= idx < len(choices):
            wandb.termwarn("Invalid choice")
result = choices[idx]
wandb.termlog("You chose '%s'" % result)
if result == LOGIN_CHOICE_ANON:
key = api.create_anonymous_api_key()
set_api_key(api, key, anonymous=True)
return key
elif result == LOGIN_CHOICE_NEW:
key = browser_callback(signup=True) if browser_callback else None
if not key:
wandb.termlog('Create an account here: {}/authorize?signup=true'.format(api.app_url))
key = input_callback('%s: Paste an API key from your profile and hit enter' % wandb.core.LOG_STRING).strip()
set_api_key(api, key)
def backward_hook(module, input, output):
    for hook in hooks:
        hook.remove()
graph.loaded = True
if wandb.run:
wandb.run.summary["graph_%i" % graph_idx] = graph
else:
        wandb.termwarn(
            "wandb.watch was called without wandb.init; call wandb.init before wandb.watch", repeat=False)
# TODO: Keeping this here as a starting point for adding graph data
if not graph.loaded:
        # Note: `functions` deliberately defaults to a shared mutable list; the
        # recursion below appends to it and clears it with del functions[:].
        def traverse(node, functions=[]):
if hasattr(node, 'grad_fn'):
node = node.grad_fn
if hasattr(node, 'variable'):
node = graph.nodes_by_id.get(id(node.variable))
if node:
node.functions = list(functions)
del functions[:]
if hasattr(node, 'next_functions'):
functions.append(type(node).__name__)
                for f in node.next_functions:
                    if f[0] is not None:
                        traverse(f[0], functions)
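
# Hedged usage sketch of the ordering the warning above enforces: call
# wandb.init before wandb.watch. The model and project name are placeholders.
import torch.nn as nn
import wandb

wandb.init(project="torch-demo")
model = nn.Linear(4, 1)
wandb.watch(model, log="all")   # registers gradient/parameter hooks like the ones above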