Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_retry_with_noauth_401(capsys):
    """An HTTP 401 raised through retry.Retry ends in a CommError.

    The wrapped callable always raises a TransientException that carries a
    requests.HTTPError whose response has status 401. Retry is configured
    with util.no_retry_auth as its retry check. The test asserts that calling
    the wrapper raises CommError with the api_key guidance message.
    """
    def raise_unauthorized():
        unauthorized = requests.Response()
        unauthorized.status_code = 401
        http_error = requests.HTTPError(response=unauthorized)
        raise retry.TransientException(exc=http_error)

    retrying_call = retry.Retry(raise_unauthorized,
                                check_retry_fn=util.no_retry_auth)
    with pytest.raises(CommError) as excinfo:
        retrying_call()
    expected = 'Invalid or missing api_key. Run wandb login'
    assert excinfo.value.message == expected
newline (bool, optional): Print a newline at the end of the string
repeat (bool, optional): If set to False only prints the string once per process
"""
if string:
line = '\n'.join(['{}: {}'.format(LOG_STRING, s)
for s in string.split('\n')])
else:
line = ''
if not repeat and line in PRINTED_MESSAGES:
return
# Repeated line tracking limited to 1k messages
if len(PRINTED_MESSAGES) < 1000:
PRINTED_MESSAGES.add(line)
if os.getenv(env.SILENT):
from wandb import util
util.mkdir_exists_ok(os.path.dirname(util.get_log_file_path()))
with open(util.get_log_file_path(), 'w') as log:
click.echo(line, file=log, nl=newline)
else:
click.echo(line, file=sys.stderr, nl=newline)
if isinstance(v, six.string_types):
if len(v) >= 20:
v = v[:20] + '...'
wandb.termlog(format_str.format(k, v))
elif isinstance(v, numbers.Number):
wandb.termlog(format_str.format(k, v))
self._run.history.load()
history_keys = self._run.history.keys()
# Only print sparklines if the terminal is utf-8
if len(history_keys) and sys.stdout.encoding == "UTF_8":
logger.info("rendering history")
wandb.termlog('Run history:')
max_len = max([len(k) for k in history_keys])
for key in history_keys:
vals = util.downsample(self._run.history.column(key), 40)
if any((not isinstance(v, numbers.Number) for v in vals)):
continue
line = sparkline.sparkify(vals)
format_str = u' {:>%s} {}' % max_len
wandb.termlog(format_str.format(key, line))
wandb_files = set([save_name for save_name in self._file_pusher.files() if util.is_wandb_file(save_name)])
media_files = set([save_name for save_name in self._file_pusher.files() if save_name.startswith('media')])
other_files = set(self._file_pusher.files()) - wandb_files - media_files
logger.info("syncing files to cloud storage")
if other_files:
wandb.termlog('Syncing files in %s:' % os.path.relpath(self._run.dir))
for save_name in sorted(other_files):
wandb.termlog(' %s' % save_name)
wandb.termlog('plus {} W&B file(s) and {} media file(s)'.format(len(wandb_files), len(media_files)))
else:
'heartbeat_seconds': 30,
}
self.client = Client(
transport=RequestsHTTPTransport(
headers={'User-Agent': self.user_agent, 'X-WANDB-USERNAME': env.get_username(env=self._environ)},
use_json=True,
# this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
# https://bugs.python.org/issue22889
timeout=self.HTTP_TIMEOUT,
auth=("api", self.api_key or ""),
url='%s/graphql' % self.settings('base_url')
)
)
self.gql = retry.Retry(self.execute,
retry_timedelta=retry_timedelta,
check_retry_fn=util.no_retry_auth,
retryable_exceptions=(RetryError, requests.RequestException))
self._current_run_id = None
self._file_stream_api = None
self._grouping = grouping
self._caption = caption
self._width = None
self._height = None
self._image = None
if isinstance(data_or_path, six.string_types):
super(Image, self).__init__(data_or_path, is_tmp=False)
else:
data = data_or_path
PILImage = util.get_module(
"PIL.Image", required='wandb.Image needs the PIL package. To get it, run "pip install pillow".')
if util.is_matplotlib_typename(util.get_full_typename(data)):
buf = six.BytesIO()
util.ensure_matplotlib_figure(data).savefig(buf)
self._image = PILImage.open(buf)
elif isinstance(data, PILImage.Image):
self._image = data
elif util.is_pytorch_tensor_typename(util.get_full_typename(data)):
vis_util = util.get_module(
"torchvision.utils", "torchvision is required to render images")
if hasattr(data, "requires_grad") and data.requires_grad:
data = data.detach()
data = vis_util.make_grid(data, normalize=True)
self._image = PILImage.fromarray(data.mul(255).clamp(
0, 255).byte().permute(1, 2, 0).cpu().numpy())
else:
if hasattr(data, "numpy"): # TF data eager tensors
data = data.numpy()
if data.ndim > 2:
data = data.squeeze() # get rid of trivial dimensions as a convenience
def val_to_json(run, key, val, step='summary'):
# Converts a wandb datatype to its JSON representation.
converted = val
typename = util.get_full_typename(val)
if util.is_pandas_data_frame(val):
assert step == 'summary', "We don't yet support DataFrames in History."
return data_frame_to_json(val, run, key, step)
elif util.is_matplotlib_typename(typename):
# This handles plots with images in it because plotly doesn't support it
# TODO: should we handle a list of plots?
val = util.ensure_matplotlib_figure(val)
if any(len(ax.images) > 0 for ax in val.axes):
PILImage = util.get_module(
"PIL.Image", required="Logging plots with images requires pil: pip install pillow")
buf = six.BytesIO()
val.savefig(buf)
val = Image(PILImage.open(buf))
else:
converted = plot_to_json(val)
elif util.is_plotly_typename(typename):
converted = plot_to_json(val)
elif isinstance(val, collections.Sequence) and all(isinstance(v, WBValue) for v in val):
# This check will break down if Image/Audio/... have child classes.
if len(val) and isinstance(val[0], BatchableMedia) and all(isinstance(v, type(val[0])) for v in val):
def _read_queue(self):
    """Pull a batch of pending items off ``self._queue``.

    This is called from the push thread (``_thread_body``). The first read
    blocks for up to ``self.rate_limit_seconds()``; after that, as much as
    possible is drained from the queue, up to ``self.MAX_ITEMS_PER_PUSH``
    items.

    Why drain in bulk: the HTTP post to the server happens inside
    ``_thread_body`` and can take longer than the rate limit. The next read
    should therefore pick up everything that queued up in the meantime.

    If more than ``MAX_ITEMS_PER_PUSH`` items keep arriving per push, the
    push thread falls behind and data buffers up in the queue.

    Returns:
        Whatever ``util.read_many_from_queue`` returns for this queue, item
        cap and wait time.
    """
    pending = self._queue
    batch_cap = self.MAX_ITEMS_PER_PUSH
    wait_seconds = self.rate_limit_seconds()
    return util.read_many_from_queue(pending, batch_cap, wait_seconds)
# handle non-git directories
if not root:
root = os.path.abspath(os.getcwd())
host = socket.gethostname()
remote_url = 'file://%s%s' % (host, root)
run.save(program=args['program'], api=api)
env = dict(os.environ)
run.set_environment(env)
try:
rm = wandb.run_manager.RunManager(api, run)
except wandb.run_manager.Error:
exc_type, exc_value, exc_traceback = sys.exc_info()
wandb.termerror('An Exception was raised during setup, see %s for full traceback.' %
util.get_log_file_path())
wandb.termerror(exc_value)
if 'permission' in str(exc_value):
wandb.termerror(
'Are you sure you provided the correct API key to "wandb login"?')
lines = traceback.format_exception(
exc_type, exc_value, exc_traceback)
logging.error('\n'.join(lines))
else:
rm.run_user_process(args['program'], args['args'], env)
def val_to_json(run, key, val, step='summary'):
# Converts a wandb datatype to its JSON representation.
converted = val
typename = util.get_full_typename(val)
if util.is_pandas_data_frame(val):
assert step == 'summary', "We don't yet support DataFrames in History."
return data_frame_to_json(val, run, key, step)
elif util.is_matplotlib_typename(typename):
# This handles plots with images in it because plotly doesn't support it
# TODO: should we handle a list of plots?
val = util.ensure_matplotlib_figure(val)
if any(len(ax.images) > 0 for ax in val.axes):
PILImage = util.get_module(
"PIL.Image", required="Logging plots with images requires pil: pip install pillow")
buf = six.BytesIO()
val.savefig(buf)
val = Image(PILImage.open(buf))
else:
converted = plot_to_json(val)
elif util.is_plotly_typename(typename):
converted = plot_to_json(val)