Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
import torch
import numpy
import math
import sys
import torchsnooper
from python_toolbox import sys_tools
import re
import snoop
import copy
# Matches ANSI escape sequences (ESC "[" ... final byte); used below to strip
# them from the captured trace output before comparison.
ansi_escape = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]')
# Shallow copy of snoop's configuration object taken at import time, i.e.
# before any register_snoop() call has modified it; assert_output() assigns it
# back to snoop.config after each check.
default_config = copy.copy(snoop.config)
# Tracing fixture: run under snoop by assert_output(). It rebinds ``x`` to, in
# order, a scalar inf tensor, a scalar nan tensor, a scalar tensor with
# requires_grad=True, a 3-element tensor containing nan and inf, a 2x2 numpy
# zero array, and finally a tuple holding that array twice.
# NOTE(review): the expected-trace strings (e.g. verbose_expect) embed this
# function's line numbers and source text, so no inline comments are added here
# and any edit to these lines must be mirrored in those strings.
def func():
    x = torch.tensor(math.inf)
    x = torch.tensor(math.nan)
    x = torch.tensor(1.0, requires_grad=True)
    x = torch.tensor([1.0, math.nan, math.inf])
    x = numpy.zeros((2, 2))
    x = (x, x)
verbose_expect = '''
01:24:31.56 >>> Call to func in File "test_snoop.py", line 16
01:24:31.56 16 | def func():
01:24:31.56 17 | x = torch.tensor(math.inf)
01:24:31.56 .......... x = tensor<(), float32, cpu, has_inf>
def assert_output(verbose, expect):
    """Trace ``func`` with snoop and compare the captured output to ``expect``.

    Registers torchsnooper's reprs, runs ``snoop(func)()`` while capturing
    stderr, strips ANSI escape sequences from the captured text, and asserts
    that it equals ``expect`` once both have been passed through
    ``clean_output``.

    Args:
        verbose: Forwarded to ``torchsnooper.register_snoop(verbose=...)``.
        expect: The expected trace text.

    Raises:
        AssertionError: If a trace function is installed immediately before or
            after the traced call, or if the cleaned output differs from the
            cleaned expectation.
    """
    try:
        torchsnooper.register_snoop(verbose=verbose)
        with sys_tools.OutputCapturer(stdout=False, stderr=True) as output_capturer:
            # The tracer must not be active outside of the traced call.
            assert sys.gettrace() is None
            snoop(func)()
            assert sys.gettrace() is None
        output = output_capturer.string_io.getvalue()
        output = ansi_escape.sub('', output)
        assert clean_output(output) == clean_output(expect)
    finally:
        # Restore the configuration even when an assertion above fails, so a
        # failing check cannot leak its modified config into later checks.
        # A fresh copy is assigned (rather than default_config itself) so that
        # the next register_snoop() call, which writes to
        # snoop.config.watch_extras, cannot mutate the saved baseline.
        snoop.config = copy.copy(default_config)
unwanted = {
snoop.configuration.len_shape_watch,
snoop.configuration.dtype_watch,
}
snoop.config.watch_extras = tuple(x for x in snoop.config.watch_extras if x not in unwanted)
if verbose:
class TensorWrap:
def __init__(self, tensor):
self.tensor = tensor
def __repr__(self):
return self.tensor.__repr__()
snoop.config.watch_extras += (
lambda source, value: ('{}.data'.format(source), TensorWrap(value.data)),
)
def register_snoop(verbose=False, tensor_format=default_format, numpy_format=default_numpy_format):
    """Hook tensor/array formatting into cheap_repr and adjust snoop's watch extras.

    Args:
        verbose: When true, additionally register a watch extra that reports
            ``<source>.data`` for traced values, shown via the value's own
            ``__repr__``.
        tensor_format: Callable producing the text shown for ``torch.Tensor``
            values.
        numpy_format: Callable producing the text shown for ``numpy.ndarray``
            values.

    Returns:
        None. The function works by mutating cheap_repr's registry and
        ``snoop.config.watch_extras``.
    """
    import cheap_repr
    import snoop
    import snoop.configuration

    # Route tensor and ndarray reprs through the supplied formatters; the
    # second argument cheap_repr passes to a repr function is ignored.
    register_tensor = cheap_repr.register_repr(torch.Tensor)
    register_tensor(lambda tensor, _helper: tensor_format(tensor))
    register_ndarray = cheap_repr.register_repr(numpy.ndarray)
    register_ndarray(lambda array, _helper: numpy_format(array))

    # NOTE(review): presumably exercises the freshly registered tensor repr
    # once up front -- the result is discarded; confirm the intent.
    cheap_repr.cheap_repr(torch.zeros(6))

    # Drop snoop's own shape/dtype watchers from the configured extras.
    dropped = {
        snoop.configuration.len_shape_watch,
        snoop.configuration.dtype_watch,
    }
    kept = tuple(
        extra for extra in snoop.config.watch_extras if extra not in dropped
    )
    snoop.config.watch_extras = kept

    if not verbose:
        return

    class TensorWrap:
        """Holds a value and reprs as that value's own ``__repr__`` output."""

        def __init__(self, tensor):
            self.tensor = tensor

        def __repr__(self):
            return self.tensor.__repr__()

    # Verbose mode: also report `<source>.data`, wrapped so that the value's
    # own __repr__ is what gets displayed.
    snoop.config.watch_extras = snoop.config.watch_extras + (
        lambda source, value: ('{}.data'.format(source), TensorWrap(value.data)),
    )