Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def abs(x):
    """Build a ``core.Result`` that applies the ``"abs"`` elementwise op to ``x``.

    Note: at module level this name shadows the builtin ``abs``.
    """
    abs_op = core.ElwiseUnary("abs")
    return core.Result(abs_op, [x])
def cast(x, dtype):
    """Return ``x`` converted to ``dtype``.

    ``dtype`` is first normalized with ``Dtype.canon``. If ``x`` already has
    that dtype, ``x`` itself is returned. Otherwise a ``Result`` is built
    around an ``ElwiseUnary`` op named ``"castto<dtype>"``. The op is marked
    differentiable only when ``_dtype_kind`` reports the target kind as
    ``'c'`` or ``'f'``.
    """
    target = Dtype.canon(dtype)
    if x.dtype == target:
        # Already the requested dtype: no op needed.
        return x

    differentiable = _dtype_kind(target) in "cf"
    name = "castto%s" % target
    # Gradient template: cast gy back to the source dtype, or "_no_grad()"
    # when the target is not differentiable.
    # NOTE(review): "%s" inserts the dtype unquoted (e.g. "cast(gy,f4)"); if
    # this string is ever eval'd, that would be a name lookup rather than a
    # dtype string -- confirm against how UnaryInfo consumes it.
    grad_expr = "cast(gy,%s)" % x.dtype if differentiable else "_no_grad()"
    # C-side expression: a C-style cast using the type name from np2c.
    c_expr = "((%s)x)" % np2c[target]
    info = UnaryInfo(name, lambda val: val.astype(target), differentiable,
                     target, grad_expr, c_expr)
    return Result(ElwiseUnary(name, info), [x])
def sign(x):
    """Build a ``core.Result`` that applies the ``"sign"`` elementwise op to ``x``."""
    parents = [x]
    return core.Result(core.ElwiseUnary("sign"), parents)
def cast(x, dtype):
    """
    Convert ``x`` (after ``core.as_node``) to the datatype ``dtype``.

    If the node's dtype already equals ``dtype``, that node is returned
    unchanged. Otherwise an elementwise unary op named ``"cast_to_<dtype>"``
    is created. It is flagged differentiable only when ``core.dtype_kind``
    reports the target kind as ``'c'`` or ``'f'``.
    """
    node = core.as_node(x)
    if node.dtype == dtype:
        return node

    differentiable = core.dtype_kind(dtype) in 'cf'
    name = 'cast_to_%s' % dtype

    def grad(x, y, gy):
        # Cast the incoming gradient back to the input's dtype. Returns None
        # when the target dtype is not differentiable.
        return cast(gy, x.dtype) if differentiable else None

    info = core.UnaryInfo(name, _get_nu_cast(dtype), differentiable, dtype,
                          grad, 'x')
    return core.Result(core.ElwiseUnary(name, info), [node])
def iceil(x):
    """Build a ``core.Result`` that applies the ``"iceil"`` elementwise op to ``x``."""
    iceil_op = core.ElwiseUnary("iceil")
    return core.Result(iceil_op, [x])
def __neg__(self):
    """Unary minus: build a ``Result`` applying the ``"neg"`` elementwise op to ``self``."""
    negate = ElwiseUnary("neg")
    return Result(negate, [self])
# Binary ops
def ceil(x):
    """Build a ``core.Result`` that applies the ``"ceil"`` elementwise op to ``x``."""
    parents = [x]
    return core.Result(core.ElwiseUnary("ceil"), parents)
node2shape[node] = node2shape[node.parents[0]][node.op.idx]
else:
newparents = node.parents
node2shape[node] = node.op.shp_apply(newparents)
# assert all([s.get_dtype() == "i8" for s in node2shape[node]])
assert len(node2shape[node]) == node.ndim or isinstance(node.get_type(),Tuple)
# -- SCALAR VALUE --
if isinstance(node, Result):
op = node.op
if isinstance(op, Fill):
node2sv[node] = op.value
elif isinstance(op, Constant) and utils.is_singleton(op.value):
node2sv[node] = op.value.flat[0]
elif isinstance(op, Repeat) and newparents[0] in node2sv:
node2sv[node] = node2sv[newparents[0]]
elif isinstance(op, (ElwiseUnary, ElwiseBinary)) and all(p in node2sv for p in newparents):
node2sv[node] = node.op.info.pyfunc(*(node2sv[p] for p in newparents))