Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _fundamental(input, value, op='Add', out=None):
    """Run the binary ``Fundamental`` operator ``op`` on ``input`` and ``value``.

    NOTE(review): a second ``_fundamental`` (MakeContext-based) is defined
    later in this file and will shadow this one at import time — confirm
    which definition is intended.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor
        The left operand; its dtype and device are used when ``value``
        has to be wrapped.
    value : dragon.vm.torch.Tensor or number
        The right operand. Anything that is not a Tensor is wrapped
        with ``WrapScalar``.
    op : str, optional, default='Add'
        The operator type handed to the ``Fundamental`` module.
    out : dragon.vm.torch.Tensor, optional
        The output tensor.

    Returns
    -------
    The result of ``Fundamental.forward`` (presumably the output tensor).

    """
    # Wrap a plain number using the left operand's dtype and device.
    value = (value if isinstance(value, Tensor)
             else WrapScalar(value, input.dtype, input.device))
    device = MakeDevice(inputs=[input, value])
    # The lookup key combines the op and the device; presumably get_module
    # reuses an existing module for a key it has already seen.
    cache_key = '{op}/{dev}'.format(op=op, dev=device)
    fundamental_op = get_module(Fundamental, cache_key, device, op_type=op)
    return fundamental_op.forward(input, value, out)
def _rfundamental(input, value, op='RAdd', out=None):
    """Run the reflected ``Fundamental`` operator ``op`` (e.g. ``RAdd``).

    Unlike ``_fundamental``, the operands are passed to the module in
    swapped order: ``value`` first, then ``input``.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor
        The tensor operand; its dtype and device are used when ``value``
        has to be wrapped.
    value : dragon.vm.torch.Tensor or number
        The other operand. Anything that is not a Tensor is wrapped
        with ``WrapScalar``.
    op : str, optional, default='RAdd'
        The operator type handed to the ``Fundamental`` module.
    out : dragon.vm.torch.Tensor, optional
        The output tensor.

    Returns
    -------
    The result of ``Fundamental.forward`` (presumably the output tensor).

    """
    if isinstance(value, Tensor):
        lhs = value
    else:
        # Wrap a plain number using the tensor operand's dtype and device.
        lhs = WrapScalar(value, input.dtype, input.device)
    device = MakeDevice(inputs=[input, lhs])
    cache_key = '{op}/{dev}'.format(op=op, dev=device)
    reflected_op = get_module(Fundamental, cache_key, device, op_type=op)
    # Swapped operand order is what makes this the reflected variant.
    return reflected_op.forward(lhs, input, out)
def _fundamental(input, value, op='Add', out=None):
    """Run the binary ``Fundamental`` operator ``op`` on ``input`` and ``value``.

    NOTE(review): this redefines an earlier ``_fundamental`` in the same
    file (the MakeDevice-based one); being later, this one wins at import.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor
        The left operand; its dtype and ``_ctx`` are used when ``value``
        has to be wrapped.
    value : dragon.vm.torch.Tensor or number
        The right operand. Anything that is not a Tensor is wrapped
        with ``WrapScalar``.
    op : str, optional, default='Add'
        The operator type handed to the ``Fundamental`` module; it is
        lower-cased for the lookup key.
    out : dragon.vm.torch.Tensor, optional
        The output tensor.

    Returns
    -------
    The result of ``Fundamental.forward`` (presumably the output tensor).

    """
    rhs = value
    if not isinstance(rhs, Tensor):
        rhs = WrapScalar(rhs, input.dtype, input._ctx)
    context = MakeContext(inputs=[input, rhs])
    # Key layout: torch.ops.<op>/<context[0]>:<context[1]>
    cache_key = 'torch.ops.{name}/{first}:{second}'.format(
        name=op.lower(), first=context[0], second=context[1])
    return get_module(Fundamental, cache_key, context,
                      op_type=op).forward(input, rhs, out)
def _maximum(input, other, out=None):
    """Run the ``Maximum`` operator on ``input`` and ``other``.

    Presumably computes the element-wise maximum. Either operand may be a
    plain number, but at least one must be a Tensor. A number is wrapped
    with ``WrapScalar`` using the Tensor operand's dtype and ``_ctx``.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor or number
        The first operand.
    other : dragon.vm.torch.Tensor or number
        The second operand.
    out : dragon.vm.torch.Tensor, optional
        The output tensor.

    Returns
    -------
    The result of ``Maximum.forward`` (presumably the output tensor).

    Raises
    ------
    TypeError
        If neither ``input`` nor ``other`` is a Tensor.

    """
    if not isinstance(input, Tensor):
        if not isinstance(other, Tensor):
            # Without this check the failure was an opaque AttributeError
            # from ``other.dtype`` below.
            raise TypeError(
                'Expected at least one Tensor operand, got {} and {}.'.format(
                    type(input).__name__, type(other).__name__))
        input = WrapScalar(input, other.dtype, other._ctx)
    elif not isinstance(other, Tensor):
        other = WrapScalar(other, input.dtype, input._ctx)
    # Build the context from both operands, as _fundamental does, instead of
    # from ``input`` alone. How MakeContext reconciles two differing contexts
    # is defined elsewhere — NOTE(review): confirm.
    ctx = MakeContext(inputs=[input, other])
    key = 'torch.ops.maximum/{}:{}'.format(ctx[0], ctx[1])
    module = get_module(Maximum, key, ctx)
    return module.forward(input, other, out)
def _minimum(input, other, out=None):
    """Run the ``Minimum`` operator on ``input`` and ``other``.

    Presumably computes the element-wise minimum. Either operand may be a
    plain number, but at least one must be a Tensor. A number is wrapped
    with ``WrapScalar`` using the Tensor operand's dtype and ``_ctx``.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor or number
        The first operand.
    other : dragon.vm.torch.Tensor or number
        The second operand.
    out : dragon.vm.torch.Tensor, optional
        The output tensor.

    Returns
    -------
    The result of ``Minimum.forward`` (presumably the output tensor).

    Raises
    ------
    TypeError
        If neither ``input`` nor ``other`` is a Tensor.

    """
    if not isinstance(input, Tensor):
        if not isinstance(other, Tensor):
            # Without this check the failure was an opaque AttributeError
            # from ``other.dtype`` below.
            raise TypeError(
                'Expected at least one Tensor operand, got {} and {}.'.format(
                    type(input).__name__, type(other).__name__))
        input = WrapScalar(input, other.dtype, other._ctx)
    elif not isinstance(other, Tensor):
        other = WrapScalar(other, input.dtype, input._ctx)
    # Build the context from both operands, as _fundamental does, instead of
    # from ``input`` alone. How MakeContext reconciles two differing contexts
    # is defined elsewhere — NOTE(review): confirm.
    ctx = MakeContext(inputs=[input, other])
    key = 'torch.ops.minimum/{}:{}'.format(ctx[0], ctx[1])
    module = get_module(Minimum, key, ctx)
    return module.forward(input, other, out)
def _compare(input, other, operation, out=None):
    """Run the ``Compare`` operator ``operation`` on ``input`` and ``other``.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor
        The left operand; its dtype and device are used when ``other``
        has to be wrapped.
    other : dragon.vm.torch.Tensor or number
        The right operand. Anything that is not a Tensor is wrapped
        with ``WrapScalar``.
    operation : str
        The comparison name handed to the ``Compare`` module; it is also
        part of the lookup key.
    out : dragon.vm.torch.Tensor, optional
        The output tensor.

    Returns
    -------
    The result of ``Compare.forward`` (presumably the output tensor).

    """
    if isinstance(other, Tensor):
        rhs = other
    else:
        rhs = WrapScalar(other, input.dtype, input.device)
    device = MakeDevice(inputs=[input, rhs])
    cache_key = 'Compare/{op}/{dev}'.format(op=operation, dev=device)
    comparator = get_module(Compare, cache_key, device, operation=operation)
    return comparator.forward(input, rhs, out)
The input tensor.
other : dragon.vm.torch.Tensor or number
The input tensor.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
if not isinstance(input, Tensor):
input = WrapScalar(input, other.dtype, other.device)
elif not isinstance(other, Tensor):
other = WrapScalar(other, input.dtype, input.device)
dev = MakeDevice(inputs=[input])
key = 'Minimum/{}'.format(dev)
module = get_module(Minimum, key, dev)
return module.forward(input, other, out)