def _arg_reduce(input, operation, dim=None, keepdim=False, topk=1, out=None):
    """Dispatch an argument reduction (e.g. arg-max/arg-min) module."""
    if dim is None:
        # Reducing over all elements, so ``keepdim`` has no effect.
        keepdim = False
    dev = MakeDevice(inputs=[input])
    key = '{}/{}/dim:{}/keepdim:{}/topk:{}'.format(
        operation, dev, dim, int(keepdim), topk)
    module = get_module(
        ArgReduce, key, dev,
        axis=dim,
        topk=topk,
        keepdim=keepdim,
        operation=operation,
    )
    return module.forward(input, out)
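
# Illustrative sketch of how the dispatcher above is typically used; the
# operation name 'ARGMAX' is an assumption about what the ArgReduce module
# accepts, not something stated in this file:
#
#     x = Tensor([[1., 5.], [7., 3.]], dtype='float32')
#     idx = _arg_reduce(x, 'ARGMAX', dim=1)   # per-row index of the maximum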
def _assign(output, starts, sizes, input):
    """Assign ``input`` to the region of ``output`` given by ``starts`` and ``sizes``."""
    if not isinstance(input, Tensor):
        if isinstance(input, (tuple, list)):
            input = Tensor(input, dtype=output.dtype, device=output.device)
        else:
            # Wrap a python scalar into a tensor of matching dtype/device.
            input = WrapScalar(input, output.dtype, output.device)
    nstarts, nsizes = len(starts), len(sizes)
    dev = MakeDevice(inputs=[input])
    key = 'Assign/{}/nstarts:{}/nsizes:{}'.format(dev, nstarts, nsizes)
    module = get_module(Assign, key, dev, nstarts=nstarts, nsizes=nsizes)
    return module.forward(input, output, starts, sizes)
def where(condition, x, y):
    """Select the elements from x or y, depending on condition.

    Parameters
    ----------
    condition : dragon.vm.torch.Tensor
        The byte condition tensor.
    x : dragon.vm.torch.Tensor
        The elements selected where *condition* is *1*.
    y : dragon.vm.torch.Tensor
        The elements selected where *condition* is *0*.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    """
    dev = MakeDevice(inputs=[condition, x, y])
    key = 'Where/{}'.format(dev)
    module = get_module(Where, key, dev)
    return module.forward(condition, x, y)
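
# Usage sketch, assuming the signature reconstructed above and the list-based
# Tensor constructor used by ``_assign``:
#
#     cond = Tensor([1, 0, 1], dtype='uint8')
#     a = Tensor([1., 2., 3.], dtype='float32')
#     b = Tensor([4., 5., 6.], dtype='float32')
#     c = where(cond, a, b)   # expected: [1., 5., 3.]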
def _masked_assign(output, mask, input):
    """Assign ``input`` to the elements of ``output`` selected by ``mask``."""
    if not isinstance(input, Tensor):
        if isinstance(input, (tuple, list)):
            input = Tensor(input, dtype=output.dtype, device=output.device)
        else:
            input = WrapScalar(input, output.dtype, output.device)
    dev = MakeDevice(inputs=[input])
    key = 'MaskedAssign/{}'.format(dev)
    module = get_module(MaskedAssign, key, dev)
    return module.forward(input, output, mask)
def _compare(input, other, operation, out=None):
    """Run an element-wise comparison, wrapping ``other`` if it is a scalar."""
    if not isinstance(other, Tensor):
        other = WrapScalar(other, input.dtype, input.device)
    dev = MakeDevice(inputs=[input, other])
    key = 'Compare/{}/{}'.format(operation, dev)
    module = get_module(Compare, key, dev, operation=operation)
    return module.forward(input, other, out)
"""Return the indices of non-zero elements.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The optional output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
dev = MakeDevice(inputs=[input])
key = 'NonZero/{}'.format(dev)
module = get_module(NonZero, key, dev)
return module.forward(input, out)
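
# Usage sketch, assuming the signature reconstructed above; the exact output
# layout (one row of indices per non-zero element) follows the usual torch
# convention and is an assumption here:
#
#     x = Tensor([0., 2., 0., 3.], dtype='float32')
#     idx = nonzero(x)   # expected to contain the indices 1 and 3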
def multinomial(input, num_samples, eps=0., out=None):
    """Sample indices from the multinomial distribution of input.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor
        The input tensor.
    num_samples : int
        The number of samples.
    eps : float, optional, default=0.
        The probability of sampling uniformly instead.
    out : dragon.vm.torch.Tensor, optional
        The output tensor.

    Returns
    -------
    dragon.vm.torch.Tensor
        The output tensor.

    """
    dev = MakeDevice(inputs=[input])
    key = 'Multinomial/{}' \
          '/num_samples:{}' \
          '/eps:{}'.format(dev, num_samples, eps)
    module = get_module(
        Multinomial, key, dev,
        eps=eps,
        num_samples=num_samples,
    )
    return module.forward(input, out)
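
# Usage sketch, assuming the signature reconstructed above: each row of
# ``probs`` is treated as a row of sampling weights, and ``eps`` mixes in a
# uniform sampling with the given probability.
#
#     probs = Tensor([[0.1, 0.9], [0.8, 0.2]], dtype='float32')
#     samples = multinomial(probs, num_samples=4)   # 2 x 4 tensor of indices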
def channel_shuffle(input, dim=0, group=1, out=None):
    """Shuffle channels between groups along the given axis.

    Parameters
    ----------
    input : dragon.vm.torch.Tensor
        The input tensor.
    dim : int, optional, default=0
        The axis of channels.
    group : int, optional, default=1
        The number of groups.
    out : dragon.vm.torch.Tensor, optional
        The output tensor.

    Returns
    -------
    dragon.vm.torch.Tensor
        The new tensor.

    """
    dev = MakeDevice(inputs=[input])
    key = 'ChannelShuffle/{}/dim:{}/group:{}'.format(dev, dim, group)
    module = get_module(
        ChannelShuffle, key, dev,
        axis=dim,
        group=group,
    )
    return module.forward(input, out)
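
# Usage sketch, assuming the signature reconstructed above and the standard
# channel-shuffle behavior (split the channels along ``dim`` into ``group``
# groups, then interleave them):
#
#     x = Tensor([0., 1., 2., 3.], dtype='float32')
#     y = channel_shuffle(x, dim=0, group=2)   # expected order: 0., 2., 1., 3.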
"""Compute the square-root of input.
Parameters
----------
input : dragon.vm.torch.Tensor
The input tensor.
out : dragon.vm.torch.Tensor, optional
The output tensor.
Returns
-------
dragon.vm.torch.Tensor
The output tensor.
"""
dev = MakeDevice(inputs=[input])
key = 'Sqrt/{}'.format(dev)
module = get_module(Sqrt, key, dev)
return module.forward(input, out)
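
# Usage sketch, assuming the signature reconstructed above:
#
#     x = Tensor([4., 9., 16.], dtype='float32')
#     y = sqrt(x)   # expected: [2., 3., 4.]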