Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# along with the software. If not, See,
#
#
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch.autograd import no_grad
from dragon.vm.torch.tensor import _ReferenceTensor
from dragon.vm.torch.ops.modules.base import BaseModule
class Indexing(BaseModule):
"""This module imports the *CropOp* from backend.
Arbitrary length of starts and sizes will be take,
and the resulting memory is deep copied.
"""
def __init__(self, key, dev, **kwargs):
super(Indexing, self).__init__(key, dev, **kwargs)
self.nstarts = kwargs.get('nstarts', 0)
self.nsizes = kwargs.get('nsizes', 0)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Crop',
'arguments': {
def register_op(self):
self.op_meta = {
'op_type': 'Compare',
'arguments': {
'operation': self.operation,
'to_uint8': True,
}}
def forward(self, x1, x2, y):
inputs = [x1, x2]; self.unify_devices(inputs)
outputs = [y] if y else [self.register_output()]
return self.run(inputs, outputs)
class Assign(BaseModule):
"""This module imports the *AssignOp* from backend.
Arbitrary length of starts and sizes will be take.
"""
def __init__(self, key, dev, **kwargs):
super(Assign, self).__init__(key, dev, **kwargs)
self.nstarts = kwargs.get('nstarts', 0)
self.nsizes = kwargs.get('nsizes', 0)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Assign',
'arguments': {
'starts_desc': [
'arguments': {
'dtype': self.dtype,
'value': float(self.value),
'dims_desc': [d for d in self.shape] if self.n_dim > 0 else None,
}
}
def forward(self, x, shape):
outputs = [x]; self.unify_devices(outputs)
if shape is not None:
for ix, d in enumerate(shape):
self.set_argument_i(self.shape[ix], d)
return self.run([], outputs)
class Reshape(BaseModule):
def __init__(self, key, ctx, **kwargs):
super(Reshape, self).__init__(key, ctx, **kwargs)
self.n_dim = kwargs.get('n_dim', 0)
self.register_arguments()
self.register_op()
def register_arguments(self):
self.dims = [self.register_argument('dims[{}]'.format(i))
for i in range(self.n_dim)]
def register_op(self):
self.op_meta = {
'op_type': 'Reshape',
'n_inputs': 1, 'n_outputs': 1,
'arguments': {
'dims_desc': [d for d in self.dims]
self.op_meta = {
'op_type': 'Reduce',
'arguments': {
'operation': self.operation,
'axes': [self.dim] if self.dim is not None else None,
'keep_dims': self.keepdim,
},
}
def forward(self, x, y):
inputs = [x]; self.unify_devices(inputs)
outputs = [y] if y else [self.register_output()]
return self.run(inputs, outputs)
class ArgReduce(BaseModule):
def __init__(self, key, dev, **kwargs):
super(ArgReduce, self).__init__(key, dev, **kwargs)
self.operation = kwargs.get('operation', 'ARGMAX')
self.axis = kwargs.get('axis', None)
self.keepdim = kwargs.get('keepdim', True)
self.topk = kwargs.get('topk', 1)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'ArgReduce',
'arguments': {
'operation': self.operation
if 'ARG' in self.operation \
else 'ARG' + self.operation,
'axis': self.axis if self.axis else 2147483647,
def register_op(self):
self.op_meta = {
'op_type': 'ChannelShuffle',
'arguments': {
'axis': self.axis,
'group': self.group,
},
}
def forward(self, x, y):
inputs = [x]; self.unify_devices(inputs)
outputs = [y] if y else [self.register_output()]
return self.run(inputs, outputs)
class Repeat(BaseModule):
def __init__(self, key, dev, **kwargs):
super(Repeat, self).__init__(key, dev, **kwargs)
self.ntimes = kwargs.get('ntimes', 0)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Tile',
'arguments': {
'multiples_desc': [
'${{HANDLE}}/multiples[{}]'.format(n)
for n in range(self.ntimes)
],
},
}
else:
if y:
if not isinstance(y, (tuple, list)):
raise TypeError('Excepted outputs as a tuple or list, got {}.'.format(type(y)))
if len(y) != 2:
raise ValueError('Excepted 2 outputs, got {}.'.format(len(y)))
outputs = [y[1], y[0]]
else: outputs = [self.register_output(), self.register_output()]
returns = self.run(inputs, outputs)
# Return values only
if self.axis is None: return returns[1]
# Return values and indices
return returns[1], returns[0]
class Reshape(BaseModule):
def __init__(self, key, dev, **kwargs):
super(Reshape, self).__init__(key, dev, **kwargs)
self.ndim = kwargs.get('ndim', 0)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Reshape',
'arguments': {
'dims_desc': [
'${{HANDLE}}/dims[{}]'.format(n)
for n in range(self.ndim)
],
},
}
},
}
def update_args(self, A, starts, sizes):
    """Write every (start, size) pair into the op's int64 arguments.

    For each position ``pos`` in ``starts`` this sets the arguments named
    ``'<A>/starts[pos]'`` and ``'<A>/sizes[pos]'`` via ``set_arg_i64``.
    Iteration is driven by ``starts``; ``sizes`` is indexed at the same
    position, so it must be at least as long as ``starts``.
    """
    for pos, begin in enumerate(starts):
        starts_key = '{}/starts[{}]'.format(A, pos)
        sizes_key = '{}/sizes[{}]'.format(A, pos)
        self.set_arg_i64(starts_key, begin)
        # ``sizes[pos]`` is read only after the start has been written,
        # matching the original evaluation order.
        self.set_arg_i64(sizes_key, sizes[pos])
def forward(self, x, starts, sizes):
    """Run the op on ``x`` into a newly registered output.

    ``run`` receives a callback that forwards its single argument (presumably
    the op handle used as an argument-name prefix; confirm in
    ``BaseModule.run``) to ``update_args`` together with ``starts`` and
    ``sizes``.
    """
    inputs = [x]
    self.unify_devices(inputs)
    outputs = [self.register_output()]

    def _write_slice_args(handle):
        return self.update_args(handle, starts, sizes)

    return self.run(inputs, outputs, callback=_write_slice_args)
class Concat(BaseModule):
"""This module imports the *ConcatOp* from backend.
Concatenate the inputs along the given axis.
"""
def __init__(self, key, dev, **kwargs):
super(Concat, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Concat',
'arguments': {
'axis': self.axis
},
class Maximum(BaseModule):
    """Wrap the backend operator whose ``op_type`` is ``'Maximum'``.

    ``forward`` takes two inputs plus an output slot ``y``; when ``y`` is
    falsy, a fresh output is registered instead.
    """

    def __init__(self, key, dev, **kwargs):
        super(Maximum, self).__init__(key, dev, **kwargs)
        self.register_op()

    def register_op(self):
        # This op takes no extra arguments.
        self.op_meta = dict(op_type='Maximum', arguments=dict())

    def forward(self, x1, x2, y):
        inputs = [x1, x2]
        # Presumably moves both inputs onto a common device; see BaseModule.
        self.unify_devices(inputs)
        if y:
            outputs = [y]
        else:
            outputs = [self.register_output()]
        return self.run(inputs, outputs)
class Minimum(BaseModule):
    """Wrap the backend operator whose ``op_type`` is ``'Minimum'``.

    ``forward`` combines ``x1`` and ``x2``; it writes into ``y`` when ``y`` is
    truthy and otherwise into a newly registered output.
    """

    def __init__(self, key, dev, **kwargs):
        super(Minimum, self).__init__(key, dev, **kwargs)
        self.register_op()

    def register_op(self):
        self.op_meta = {
            'op_type': 'Minimum',
            'arguments': {},  # no op-specific arguments
        }

    def forward(self, x1, x2, y):
        operands = [x1, x2]
        self.unify_devices(operands)
        # ``y or ...`` keeps ``y`` when truthy, else registers a new output.
        return self.run(operands, [y or self.register_output()])
class Clamp(BaseModule):
def __init__(self, key, dev, **kwargs):
super(Clamp, self).__init__(key, dev, **kwargs)
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
#
#
# ------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.vm.torch.ops.modules.base import BaseModule
class Reduce(BaseModule):
def __init__(self, key, ctx, **kwargs):
    """Read the reduction settings from ``kwargs`` and register the op.

    Recognized keyword arguments (with defaults):
        operation: ``'SUM'``
        dim: ``None``
        keepdim: ``True``
    """
    super(Reduce, self).__init__(key, ctx, **kwargs)
    option = kwargs.get
    self.operation, self.dim, self.keepdim = (
        option('operation', 'SUM'),
        option('dim', None),
        option('keepdim', True),
    )
    self.register_arguments()
    self.register_op()
def register_arguments(self):
    """Register nothing: the reduce op exposes no mutable arguments.

    Making ``axis`` and ``keep_dims`` mutable would be non-trivial for the
    backend, so they are simply hashed into the persistent key instead.
    """
    return None
def register_op(self):
self.op_meta = {
'op_type': 'Stack',
'arguments': {
'axis': self.axis
},
}
def forward(self, xs, y):
inputs = xs; self.unify_devices(inputs)
outputs = [y] if y else [self.register_output()]
return self.run(inputs, outputs)
class Chunk(BaseModule):
"""This module imports the *SliceOp* from backend.
Slice the inputs into several parts along the given axis.
"""
def __init__(self, key, dev, **kwargs):
super(Chunk, self).__init__(key, dev, **kwargs)
self.axis = kwargs.get('axis', 0)
self.chunks = kwargs.get('chunks', 1)
self.register_op()
def register_op(self):
self.op_meta = {
'op_type': 'Slice',
'arguments': {
'axis': self.axis,