def _compare(input, other, operation, out=None):
if not isinstance(other, Tensor):
other = WrapScalar(other, input.dtype, input.device)
dev = MakeDevice(inputs=[input, other])
key = 'Compare/{}/{}'.format(operation, dev)
module = get_module(Compare, key, dev, operation=operation)
return module.forward(input, other, out)
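# These snippets assume Dragon's internal helpers (Tensor, WrapScalar,
# MakeDevice/MakeContext, get_module, and the op module classes) are already
# in scope. They all follow one pattern: build a string cache key from the
# device and the op arguments, then fetch a memoized module. A minimal
# sketch of that key-based caching idiom, assuming get_module simply
# memoizes constructed modules by key (names below are hypothetical):
_MODULE_CACHE = {}

def _get_module_sketch(cls, key, dev, **kwargs):
    # Construct the module once per unique (op, device, arguments) key,
    # then reuse it for every later call with the same configuration.
    if key not in _MODULE_CACHE:
        _MODULE_CACHE[key] = cls(key, dev, **kwargs)
    return _MODULE_CACHE[key]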
# Reconstructed head for this snippet: the signature and docstring opening
# were missing; the name xw_plus_b and the defaults are inferred from the
# body below (assumptions).
def xw_plus_b(x, w, bias=None, transW=True, out=None):
    """Compute matmul(x, w) + bias.

    Parameters
    ----------
    x : dragon.vm.torch.Tensor
        The input tensor.
    w : dragon.vm.torch.Tensor
        The weight tensor.
    bias : dragon.vm.torch.Tensor, optional
        The bias tensor.
    transW : boolean
        Whether to transpose the ``w``.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.
    """
    # Test against None explicitly: truth-testing a Tensor is ambiguous.
    dev = MakeDevice(inputs=[x, w] + ([bias] if bias is not None else []))
key = 'FullyConnected/{}/transW:{}'.format(dev, transW)
module = get_module(FullyConnected, key, dev, transW=transW)
return module.forward(x, w, bias, out)
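# Hedged usage sketch for the wrapper above, with transW=True so w is stored
# as (out_features, in_features); assumes dragon.vm.torch exposes
# pytorch-style ones/zeros constructors:
import dragon.vm.torch as torch

x = torch.ones(4, 64)        # batch of 4, 64 input features
w = torch.ones(128, 64)      # 128 output features
b = torch.zeros(128)
y = xw_plus_b(x, w, b, transW=True)   # y has shape (4, 128)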
def roi_align(feature, rois, pooled_h, pooled_w,
spatial_scale, sampling_ratio=2):
ctx = MakeContext(inputs=[feature])
key = 'torch.ops.roi_align/{}:{}/pool_h:{}/pool_w:{}/' \
'spatial_scale:{}/sampling_ratio:{}'.format(
ctx[0], ctx[1], pooled_h, pooled_w, spatial_scale, sampling_ratio)
    module = get_module(
        RoIAlign, key, ctx,
        pooled_h=pooled_h,
        pooled_w=pooled_w,
        spatial_scale=spatial_scale,
        sampling_ratio=sampling_ratio,
    )
return module.forward(feature, rois)
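# Hedged usage sketch for roi_align above. RoIs follow the common
# [batch_index, x1, y1, x2, y2] layout (an assumption, not stated in the
# snippet), and the constructors assume dragon.vm.torch's pytorch-style API:
import dragon.vm.torch as torch

feature = torch.ones(1, 256, 50, 50)   # (N, C, H, W) feature map
rois = torch.zeros(2, 5)               # two RoIs: [batch_idx, x1, y1, x2, y2]
# Pool each RoI to 7x7 bins; spatial_scale maps image coordinates onto a
# 16x-downsampled feature map.
output = roi_align(feature, rois, pooled_h=7, pooled_w=7,
                   spatial_scale=1.0 / 16, sampling_ratio=2)
# output shape: (2, 256, 7, 7)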
def _fundamental(input, value, op='Add', out=None):
if not isinstance(value, Tensor):
value = WrapScalar(value, input.dtype, input.device)
dev = MakeDevice(inputs=[input, value])
key = '{}/{}'.format(op, dev)
module = get_module(Fundamental, key, dev, op_type=op)
return module.forward(input, value, out)
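# Like _compare above, _fundamental normalizes its operands: a Python number
# is wrapped into a scalar Tensor of the input's dtype/device via WrapScalar,
# so the module always sees two tensors. Hedged illustration:
#
#   _fundamental(x, 2.0, op='Add')   # 2.0 -> WrapScalar(2.0, x.dtype, x.device)
#   _fundamental(x, t, op='Mul')     # t is already a Tensor, no wrapping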
def _uniform(input, shape, low, high):
    dev = MakeDevice(inputs=[input])
    ndim = len(shape)
key = 'Uniform/{}/dtype:{}/ndim:{}/low:{}/high:{}'.format(
dev, input.dtype, ndim, float(low), float(high))
module = get_module(
RandomUniform, key, dev,
ndim=ndim,
low=low,
high=high,
dtype=input.dtype,
)
return module.forward(input, shape)
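# The key above encodes only the rank (ndim), dtype, and range - not the
# concrete shape - so one cached module serves every shape of the same rank,
# with the actual dims passed at forward time. Hedged illustration:
#
#   _uniform(x, (2, 3, 4), 0.0, 1.0)  # key '.../ndim:3/low:0.0/high:1.0'
#   _uniform(x, (5, 6, 7), 0.0, 1.0)  # same key -> cache hit, module reused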
def _maximum(input, other, out=None):
if not isinstance(input, Tensor):
input = WrapScalar(input, other.dtype, other._ctx)
elif not isinstance(other, Tensor):
other = WrapScalar(other, input.dtype, input._ctx)
ctx = MakeContext(inputs=[input])
key = 'torch.ops.maximum/{}:{}'.format(ctx[0], ctx[1])
module = get_module(Maximum, key, ctx)
return module.forward(input, other, out)
def _log(input, out=None):
ctx = MakeContext(inputs=[input])
key = 'torch.ops.log/{}:{}'.format(ctx[0], ctx[1])
module = get_module(Log, key, ctx)
return module.forward(input, out)
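# The next snippet is a context-based (ctx) variant of the device-based
# _fundamental above: the same wrap-then-dispatch logic, but keyed on the
# pair returned by MakeContext (apparently a device type and id, judging by
# the '{}:{}' key format) instead of MakeDevice.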
def _fundamental(input, value, op='Add', out=None):
if not isinstance(value, Tensor):
value = WrapScalar(value, input.dtype, input._ctx)
ctx = MakeContext(inputs=[input, value])
key = 'torch.ops.{}/{}:{}'.format(op.lower(), ctx[0], ctx[1])
module = get_module(Fundamental, key, ctx, op_type=op)
return module.forward(input, value, out)
def _fill(input, shape, value):
    dev = MakeDevice(inputs=[input])
    ndim = len(shape)
key = 'Fill/{}/dtype:{}/ndim:{}/value:{}' \
.format(dev, input.dtype, ndim, value)
module = get_module(
Fill, key, dev,
ndim=ndim,
value=value,
dtype=input.dtype,
)
return module.forward(input, shape)