Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_dunder_str():
    """Cursing ``int.__str__`` changes what ``str()`` returns for ints.

    NOTE(review): the patch is never reversed here, so anything running later
    in the same process also sees the replaced ``int.__str__``.
    """
    # Baseline: builtin behaviour before any patching.
    assert str(1) == "1"

    def _say_one(_self):
        return 'one'

    curse(int, '__str__', _say_one)
    # After the curse every int stringifies as the replacement's result.
    assert str(1) == "one"
def test_overriding_non_c_things():
    """``curse`` must also work on plain Python classes, not only C types."""
    class Yo(object):
        pass

    # The instance is created *before* patching; the cursed method is set on
    # the class, so it must still be reachable from this existing object.
    instance = Yo()

    def _doubled(*_args, **_kwargs):
        return "Yo" * 2

    curse(Yo, "my_method", _doubled)
    assert instance.my_method() == "YoYo"
def test_dunder_func_chaining():
    """Overload the ``*`` (mul) operator to chain functions together.

    After cursing ``FunctionType.__mul__``, ``f * g`` builds a new callable
    that calls ``g`` first and feeds its result into ``f``.
    """
    def mul_chaining(self, other):
        # Binary-operator protocol: an unsupported operand is signalled by
        # *returning* NotImplemented, so Python can still try
        # ``other.__rmul__`` and otherwise raise the usual TypeError.
        if not isinstance(other, FunctionType):
            return NotImplemented

        def wrapper(*args, **kwargs):
            res = other(*args, **kwargs)
            # Iterable results are splatted into ``self`` as positional args.
            # NOTE(review): str results are iterable too and would be split
            # into characters — confirm that is intended.
            if hasattr(res, "__iter__"):
                return self(*res)
            return self(res)
        return wrapper

    curse(FunctionType, "__mul__", mul_chaining)
    f = lambda x, y: x * y
    g = lambda x: (x, x)
    squared = f * g
    for i in range(0, 10, 2):
        assert squared(i) == i ** 2
def test_sequence_dunder():
    """Use ``f[n]`` on functions to obtain their n-th numerical derivative."""
    def derive_func(func, deriv_grad):
        # ``f[n]`` -> n-th derivative of ``f`` via repeated central differences.
        if deriv_grad < 0:
            # Without this guard a negative (or fractional) order would
            # recurse until RecursionError.
            raise IndexError("derivative order must be non-negative")
        if deriv_grad == 0:
            return func
        e = 0.0000001

        def wrapper(x):
            # Central difference: (f(x + e) - f(x - e)) / 2e.
            return (func(x + e) - func(x - e)) / (2 * e)
        if deriv_grad == 1:
            return wrapper
        # Recurse directly rather than via ``wrapper[...]`` so this helper
        # does not depend on the ``__getitem__`` curse still being installed.
        return derive_func(wrapper, deriv_grad - 1)

    curse(FunctionType, "__getitem__", derive_func)
    # a function and its derivatives
    f = lambda x: x ** 3 - 2 * x ** 2
    f_1 = lambda x: 3 * x ** 2 - 4 * x
    f_2 = lambda x: 6 * x - 4
    for step in range(0, 10):
        x = float(step) / 10.
        assert almost_equal(f(x), f[0](x))
        assert almost_equal(f_1(x), f[1](x))
        # our hacky derivation becomes numerically unstable here
        assert almost_equal(f_2(x), f[2](x), e=.01)
# NOTE(review): the enclosing decorator header is missing from this snippet.
# It is presumably ``def require_unicode_arg2(f):``, since
# ``@require_unicode_arg2`` is applied below and this code closes over ``f``.
# As written, ``f`` is unbound and ``return wrapped`` sits at module level.
@wraps(f)
def wrapped(*args, **kwargs):
    # Reject calls whose second positional argument is not ``unicode``
    # (args[0] is presumably self/cls). Fewer than two positional args
    # raises IndexError here.
    if not isinstance(args[1], unicode):
        raise TypeError('must be unicode, not bytes')
    return f(*args, **kwargs)
return wrapped
@require_unicode_arg2
def fromhex(self, string):
    """Stub ``fromhex`` replacement: ignores *string*, returns two 0xFF bytes."""
    return b'\xff' * 2
def myrepr(self):
    """Announce the call, then return ``b'b'`` + the saved original repr.

    Expects ``self.__oldrepr__()`` to exist and to return bytes.
    """
    print('Calling fancy fn!')
    original = self.__oldrepr__()
    return b'b' + original
# Install ``myrepr`` as the ``__repr__`` of bytes, then exercise it once.
# NOTE(review): ``myrepr`` calls ``self.__oldrepr__()``, which is only
# installed inside ``new_bytes_context``. If this curse takes effect,
# ``repr(b'')`` here may raise AttributeError — confirm the intended ordering.
curse(bytes, "__repr__", myrepr)
print(repr(b''))
@contextmanager
def new_bytes_context():
    """Temporarily patch ``bytes`` with the custom repr/``fromhex`` helpers.

    On entry this installs three attributes:

    - ``__oldrepr__``: the current ``bytes.__repr__``.
    - ``_c___repr``: ``myrepr``.
    - ``fromhex``: a classmethod wrapping ``fromhex``.

    On exit they are reverted, even if the ``with`` body raised.
    """
    curse(bytes, "__oldrepr__", bytes.__repr__)
    curse(bytes, "_c___repr", myrepr)
    curse(bytes, "fromhex", classmethod(fromhex))
    try:
        yield
    finally:
        # Undo in reverse installation order so ``bytes`` is never left
        # patched after the block, whatever happened inside it.
        reverse(bytes, "fromhex")
        reverse(bytes, "_c___repr")
        reverse(bytes, "__oldrepr__")
# Call the patched ``fromhex`` through an instance and through the class
# itself while the ``new_bytes_context`` patches are active.
with new_bytes_context():
    b = b'Byte string'
    print(repr(b.fromhex(u'aa 0f')))
    print(repr(bytes.fromhex(u'b3 2e')))
import codecs
def _unhex(x):
return codecs.decode(x, 'hex')
def _bytes(x):
return codecs.encode(x, 'latin1')
def _str(x):
return x.decode('latin1')
def _hex(x):
return codecs.encode(x.bytes(), 'hex')
def _nop(x):
return x
# Attach the conversion helpers as methods on the builtin types, e.g.
# b'..'.str(), '..'.bytes(), '..'.unhex(), '..'.hex().
# ``_hex`` calls ``x.bytes()``, which relies on the ``str.bytes`` curse below.
forbiddenfruit.curse(bytes, "str", _str)
forbiddenfruit.curse(str, "bytes", _bytes)
forbiddenfruit.curse(str, "str", _nop)
forbiddenfruit.curse(bytes, "bytes", _nop)
forbiddenfruit.curse(bytes, "unhex", _unhex)
forbiddenfruit.curse(str, "unhex", _unhex)
forbiddenfruit.curse(str, "hex", _hex)
# NOTE(review): this ``else:`` has no visible matching ``if`` in this snippet.
# Presumably an interpreter-version check guards the block above — confirm.
else:
    logging.error("Unsupported python variant.")
# NOTE(review): stray tail of a duplicated ``_unhex`` definition. Its ``def``
# line is missing from this snippet, so this ``return`` sits at module level.
# Restore the header or remove the line.
return codecs.decode(x, 'hex')
def _bytes(x):
return codecs.encode(x, 'latin1')
def _str(x):
return x.decode('latin1')
def _hex(x):
return codecs.encode(x.bytes(), 'hex')
def _nop(x):
return x
# Attach the conversion helpers as methods on the builtin types, e.g.
# b'..'.str(), '..'.bytes(), '..'.unhex(), '..'.hex().
# ``_hex`` calls ``x.bytes()``, which relies on the ``str.bytes`` curse below.
forbiddenfruit.curse(bytes, "str", _str)
forbiddenfruit.curse(str, "bytes", _bytes)
forbiddenfruit.curse(str, "str", _nop)
forbiddenfruit.curse(bytes, "bytes", _nop)
forbiddenfruit.curse(bytes, "unhex", _unhex)
forbiddenfruit.curse(str, "unhex", _unhex)
forbiddenfruit.curse(str, "hex", _hex)
# NOTE(review): this ``else:`` has no visible matching ``if`` in this snippet.
# Presumably an interpreter-version check guards the block above — confirm.
else:
    logging.error("Unsupported python variant.")
def new_bytes_context():
    """Generator that patches ``bytes`` around a single ``yield``.

    On entry this installs three attributes:

    - ``__oldrepr__``: the current ``bytes.__repr__``.
    - ``_c___repr``: ``myrepr``.
    - ``fromhex``: a classmethod wrapping ``fromhex``.

    Resuming the generator (or closing it) reverts them.

    NOTE(review): no decorator is visible on this definition. Presumably it
    is wrapped with ``contextlib.contextmanager`` or used as a fixture — confirm.
    """
    curse(bytes, "__oldrepr__", bytes.__repr__)
    curse(bytes, "_c___repr", myrepr)
    curse(bytes, "fromhex", classmethod(fromhex))
    try:
        yield
    finally:
        # Undo in reverse installation order even if an exception was thrown
        # into the generator, so ``bytes`` is never left patched.
        reverse(bytes, "fromhex")
        reverse(bytes, "_c___repr")
        reverse(bytes, "__oldrepr__")
def center(self):
    """Identity: return the receiver unchanged."""
    return self


def ub(self):
    """Identity: return the receiver unchanged."""
    return self
def cudify(self, cuda_async=True):
    """Return ``self.cuda(non_blocking=cuda_async)`` when ``use_cuda`` is set.

    ``use_cuda`` is a module-level flag. When it is falsy, *self* is
    returned unchanged.
    """
    if not use_cuda:
        return self
    return self.cuda(non_blocking=cuda_async)
def log_softmax(self, *args, dim=1, **kwargs):
    """Call ``F.log_softmax`` on *self*, defaulting ``dim`` to 1.

    Extra positional and keyword arguments are forwarded unchanged.
    """
    call_kwargs = dict(kwargs)
    call_kwargs["dim"] = dim
    return F.log_softmax(self, *args, **call_kwargs)
# On torch 0.x releases other than exactly "0.4.1", attach the local
# ``log_softmax`` wrapper (which defaults ``dim`` to 1) to ``Point``.
# NOTE(review): the version strings are compared literally, so a suffixed
# build such as "0.4.1+..." is not treated as 0.4.1 — confirm intent.
if torch.__version__[0] == "0" and torch.__version__ != "0.4.1":
    Point.log_softmax = log_softmax
# Copy every method named by ``getMethodNames(Point)`` onto ``Pt`` via curse.
for nm in getMethodNames(Point):
    curse(Pt, nm, getattr(Point, nm))
def spy_on_file_io():
    """Generator that wraps ``BufferedReader.read`` / ``TextIOWrapper.write``.

    While suspended at the single ``yield``, both methods are replaced by the
    wrappers returned from ``spy_read`` / ``spy_write``. The saved originals
    are restored afterwards, even if the guarded code raised.

    NOTE(review): no decorator is visible here. Presumably this is used via
    ``contextlib.contextmanager`` or as a pytest fixture — confirm.
    """
    orig_read = io.BufferedReader.read
    orig_write = io.TextIOWrapper.write
    try:
        # Installing inside ``try`` means a failure while patching the second
        # method still restores the first.
        forbiddenfruit.curse(io.BufferedReader, "read", spy_read(orig_read))
        forbiddenfruit.curse(io.TextIOWrapper, "write", spy_write(orig_write))
        yield
    finally:
        # These are process-wide patches on builtin I/O classes; never let
        # them outlive this generator.
        forbiddenfruit.curse(io.BufferedReader, "read", orig_read)
        forbiddenfruit.curse(io.TextIOWrapper, "write", orig_write)