Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
class TupleIndex(Op):
    """
    Extract component `idx` from a tuple-valued input node.

    Takes a single tuple-typed input and produces its `idx`-th component.
    """
    call_type = "valret"

    def __init__(self, idx):
        # Position of the component to extract.
        self.idx = idx

    def get_py_impl(self):
        def f(reads):
            # reads[0] is the runtime value of the tuple input.
            return reads[0][self.idx]
        return PyImpl(valret_func=f)

    def shp_apply(self, inputs):
        # shape() of a tuple-valued node is indexable per component.
        return shape(inputs[0])[self.idx]

    def typ_apply(self, inputs):
        intype = inputs[0].get_type()
        assert isinstance(intype, Tuple)
        # Reuse the type we already fetched and checked rather than querying again.
        return intype[self.idx]
class MakeTuple(Op):
    """
    Pack any number of input nodes into a single tuple-valued output.
    """
    call_type = "valret"

    def get_py_impl(self):
        def _pack(inputs):
            # Runtime value: the input values gathered into a Python tuple.
            return tuple(inputs)
        return PyImpl(valret_func=_pack)

    def shp_apply(self, inputs):
        # One shape entry per packed input, in input order.
        component_shapes = [shape(node) for node in inputs]
        return tuple(component_shapes)

    def typ_apply(self, inputs):
        # Tuple type built from each input's type, in input order.
        component_types = [node.get_type() for node in inputs]
        return Tuple(*component_types)
def unpack(tup):
    """
    Split a tuple-valued node into a list of per-component TupleIndex nodes.
    """
    n_items = len(tup.get_type())
    return [Result(TupleIndex(k), [tup]) for k in range(n_items)]
# Assertion and debug operations
# ----------------------------------------------------------------
# NOTE(review): this block looks spliced. The header says `Assertion`, but the
# body computes an outer product (np.outer), which does not match the
# `Assertion` op defined further down in this file; the `get_py_impl(self):`
# header that should enclose `f` is also missing. Confirm which op this body
# belongs to before relying on it.
class Assertion(Op):
# Runtime kernel: writes the outer product of the two inputs into `write`.
def f(reads, write):
np.outer(reads[0], reads[1], out=write)
return PyImpl(inplace_func=f)
# Gradients w.r.t. both inputs, given the output gradient `goutput`.
# NOTE(review): for out = outer(a, b) the usual gradients are goutput.dot(b)
# for a and a.dot(goutput) for b; the expressions below use inputs[0] and
# inputs[1] the other way round, which only type-checks when len(a) == len(b).
# Verify.
def pullback(self, inputs, _output, goutput):
return [goutput.dot(inputs[0]), inputs[1].dot(goutput)]
# Output shape: [len(inputs[0]), len(inputs[1])].
def shp_apply(self, inputs):
return [size(inputs[0],0), size(inputs[1],0)]
# Output type: rank-2 tensor of floatX.
def typ_apply(self, _inputs):
return Tensor(floatX, 2)
# BLAS 1
# ----------------------------------------------------------------
class Dot(Op):
    """
    np.dot of two inputs, declared as a rank-0 (scalar) floatX result.

    Presumably both inputs are vectors, since the output is declared scalar.
    """
    call_type = "valret"

    def get_py_impl(self):
        def inner_product(reads):
            lhs, rhs = reads[0], reads[1]
            return np.dot(lhs, rhs)
        return PyImpl(valret_func=inner_product)

    def pullback(self, inputs, _output, goutput):
        # d(x.y)/dx = y and d(x.y)/dy = x, each scaled by the incoming gradient.
        lhs, rhs = inputs
        return [rhs * goutput, lhs * goutput]

    def shp_apply(self, _inputs):
        # Scalar output: empty shape list.
        return []

    def typ_apply(self, _inputs):
        return Tensor(floatX, 0)
# Composition
# ----------------------------------------------------------------
def shp_apply(self, _):
    """Report an empty shape list (a rank-0 output); inputs are ignored."""
    scalar_shape = []
    return scalar_shape
def get_py_impl(self):
    """
    Build the runtime kernel: when the input's scalar value (reads[0].item())
    is falsy, report the failure through self.display_error().
    """
    def check(reads, write):
        if reads[0].item():
            return
        self.display_error()
    return PyImpl(inplace_func=check)
def display_error(self):
    """
    Print the stack recorded in self.stack between banner lines, then raise
    AssertionError carrying self.msg.
    """
    banner = "**************************"
    print("Stack trace at failed assertion:")
    print(banner)
    traceback.print_list(self.stack)
    print(banner)
    raise AssertionError("Assertion failed. Message: %s. Above, you can find the stack trace of the failed node" % self.msg)
class DebugFunc(Op):
    """
    Call a function when the graph is executed.

    The user callback receives the runtime values of this node's inputs as
    positional arguments. The node's own output is a rank-0 'i8' tensor that
    the kernel leaves untouched.
    """
    def __init__(self, yourfunc):
        # User callback invoked with the input values at execution time.
        self.yourfunc = yourfunc

    def typ_apply(self, _):
        return Tensor('i8',0)

    def shp_apply(self, _):
        return []

    def get_py_impl(self):
        def f(reads, write):
            # Previously this only *defined* a nested function and never called
            # it, so the callback never ran. Invoke it directly instead.
            self.yourfunc(*reads)
        return PyImpl(inplace_func=f)
# NOTE(review): spliced fragment. The header declares a module-level helper
# `assert_(x, msg=None)`, but the lines that follow (`call_type`, and
# get_py_impl / shp_apply / typ_apply taking `self`) repeat the body of the
# MakeTuple op defined earlier in this file and never use `x` or `msg`.
# The real body of `assert_` appears to be missing -- confirm before use.
def assert_(x,msg=None):
call_type="valret"
# Runtime kernel: returns the input values packed into a tuple.
def get_py_impl(self):
def f(inputs):
return tuple(inputs)
return PyImpl(valret_func=f)
# Symbolic shape: one shape entry per input.
def shp_apply(self, inputs):
return tuple(shape(x) for x in inputs)
# Type: a Tuple of the input types.
def typ_apply(self, inputs):
return Tuple(*(x.get_type() for x in inputs))
def unpack(tup):
    """
    Return one TupleIndex result node per component of tuple-valued `tup`.
    """
    # NOTE(review): this re-defines an identical `unpack` found earlier in the file.
    component_count = len(tup.get_type())
    pieces = []
    for position in range(component_count):
        pieces.append(Result(TupleIndex(position), [tup]))
    return pieces
# Assertion and debug operations
# ----------------------------------------------------------------
class Assertion(Op):
    """
    Graph-time assertion on a rank-0 'i1' condition.

    The Python call stack at construction time is recorded so that, when the
    check fails during execution, it can be printed together with `msg`.
    """
    def __init__(self, msg):
        self.msg = msg
        # Drop the two innermost frames: this __init__ and its immediate caller.
        self.stack = traceback.extract_stack()[:-2]

    def typ_apply(self, inputs):
        (cond,) = inputs
        # Only rank-0 conditions of dtype 'i1' are accepted.
        assert cond.ndim == 0 and cond.dtype == 'i1'
        return Tensor('i8', 0)

    def shp_apply(self, _):
        return []
# NOTE(review): spliced fragment. It begins as the Assertion runtime kernel
# (checks reads[0].item()), but the body of the `if` is missing and is followed
# by the tail of a cuDNN pooling get_native_compile_info, which uses `code` and
# `self.info` -- neither is defined in this method. Compare the complete kernel
# earlier in this file that calls self.display_error(). Restore before use.
def get_py_impl(self):
def f(reads, write):
x = reads[0]
if not x.item():
return core.NativeCompileInfo(code, closure_triples = poolinfo2closure(self.info),
includes=["cudnn_support.h"], link_flags="-lcudnn -lcudart")
def shp_apply(self, inputs):
    """
    Symbolic output shape of the cuDNN pooling forward pass for an input
    unpacked as (batch_size, channels, height, width).

    Each pooled spatial extent is ceil((size + 2*pad - kernel) / stride) + 1,
    i.e. the number of window positions. The previous formula omitted the
    "+ 1" and, e.g., produced 0 rows when kernel_h == height + 2*pad_h.
    """
    info = self.info
    batch_size, channels, height, width = cgt.shape(inputs[0])
    # NOTE(review): cuDNN's own output-dim rule rounds down (1 + floor(...));
    # ceil matches Caffe-style pooling. Confirm which cudnn_support.h expects.
    pooled_height = cgt.ceil_divide(height + 2*info.pad_h - info.kernel_h, info.stride_h) + 1
    pooled_width = cgt.ceil_divide(width + 2*info.pad_w - info.kernel_w, info.stride_w) + 1
    return [batch_size, channels, pooled_height, pooled_width]
def typ_apply(self, input_types):
    """Return the type of the first input unchanged."""
    result_type = input_types[0]
    return result_type
def pullback(self, inputs, output, gout):
    """
    Gradient w.r.t. the pooled input: a CudnnPoolBackward node fed with the
    forward input, the forward output and the output gradient.
    """
    backward_op = CudnnPoolBackward(self.info)
    grad_input = core.Result(backward_op, [inputs[0], output, gout])
    return [grad_input]
class CudnnPoolBackward(core.Op):
    """
    Native-GPU op that computes the gradient of cuDNN pooling.

    The generated kernel forwards reads[0..2] and the output buffer to
    performPoolingBackward (from cudnn_support.h). When built by the forward
    op's pullback, those reads are (forward input, forward output, output grad).
    """
    available_impls = ("native_gpu",)

    def __init__(self, info):
        # Pooling parameters; converted to the C closure by poolinfo2closure.
        self.info = info

    def get_native_compile_info(self, _input_types, devtype):
        assert devtype == "gpu"
        # $setup/$teardown manage the cuDNN handle; $function also sets it up
        # lazily if it is still null.
        code = """
CGT_EXPORT_C void $setup(pooling_closure* closure) {setup_cudnn(closure);}
CGT_EXPORT_C void $teardown(pooling_closure* closure) {teardown_cudnn(closure);}
CGT_EXPORT_C void $function(pooling_closure* closure, cgtArray** reads, cgtArray* write) {
if (!closure->handle) setup_cudnn(closure);
performPoolingBackward(closure, reads[0], reads[1], reads[2], write);
}"""
        closure = poolinfo2closure(self.info)
        return core.NativeCompileInfo(
            code,
            closure_triples=closure,
            includes=["cudnn_support.h"],
            link_flags="-lcudnn -lcudart")
# NOTE(review): spliced fragment. The header is shp_apply(self, inputs), but the
# body is an FFT runtime kernel that reads `reads` and `write` (undefined here;
# the enclosing `get_py_impl` / `def f(reads, write):` headers appear missing).
# It takes np.fft.fftn of reads[0] over self.axes with output sizes reads[1:]
# (NumPy pads with zeros or crops to those sizes).
def shp_apply(self, inputs):
x = reads[0]
shp = map(int,reads[1:])
np.copyto(write, np.fft.fftn(x,shp,self.axes))
return PyImpl(inplace_func=f)
def pullback(self, inputs, _outputs, goutput):
    """
    Gradient of the FFT w.r.t. its data input: the real part of an RFFT node
    applied to `goutput`, forwarding the size inputs inputs[1:].
    """
    # NOTE(review): returns a single node rather than a list with one entry
    # per input, unlike other pullbacks in this file -- confirm the contract.
    size_args = inputs[1:]
    rfft_node = Result(RFFT(self.axes), [goutput] + size_args)
    return real(rfft_node)
def shp_apply(self, inputs):
    """
    Output shape: shape(inputs[0]) with each axis in self.axes replaced by the
    corresponding size node from inputs[1:] (lengths must match, via safezip).
    """
    out_shape = shape(inputs[0])
    requested_sizes = inputs[1:]
    for axis, new_extent in utils.safezip(self.axes, requested_sizes):
        out_shape[axis] = new_extent
    return out_shape
def typ_apply(self, inputs):
    """Complex (complexX) tensor with the rank of the floatX data input."""
    data_type = inputs[0]
    assert data_type.dtype == floatX
    return Tensor(complexX, data_type.ndim)
class IRFFT(Op):
    """
    Real part of the inverse FFT of inputs[0] over `axes`.

    Optional further inputs give one size per entry of `axes`; along each such
    axis the result is cropped to its first `size` entries.
    """
    def __init__(self, axes):
        # Axes over which the inverse transform is taken.
        self.axes = axes

    def get_diff(self, _):
        return [True]

    def get_py_impl(self):
        def f(reads, write):
            x = reads[0]
            shp = [int(s) for s in reads[1:]]
            slis = [slice(0, None)] * x.ndim
            for (ax, s) in zip(self.axes, shp):
                slis[ax] = slice(0, s)
            # Index with a tuple: indexing an ndarray with a *list* of slices is
            # deprecated in NumPy and raises an error in recent releases.
            np.copyto(write, np.real(np.fft.ifftn(x, axes=self.axes)[tuple(slis)]))
        return PyImpl(inplace_func=f)

    def pullback(self, inputs, _outputs, goutput):
        return Result(IRFFT(self.axes),[goutput]) # XXX is this right?

    def shp_apply(self, inputs):
        # Mirror get_py_impl: axes with a requested size are cropped to it, so the
        # declared shape matches the buffer np.copyto writes into. zip (not a
        # strict zip) because nodes built without size inputs, e.g. by pullback,
        # must still get a shape.
        # NOTE(review): assumes each requested size <= the input extent, as the
        # slice silently clamps larger sizes at runtime.
        out = shape(inputs[0])
        for (ax, sz) in zip(self.axes, inputs[1:]):
            out[ax] = sz
        return out
#!/usr/bin/env python
import cgt
# Report public `cgt` attributes without a docstring, and Op subclasses in
# `cgt.core` that do not override `get_native_compile_info`.
# Items are sorted by name so the report is deterministic.
for (name, val) in sorted(cgt.__dict__.items()):
    if not name.startswith("_"):
        if not val.__doc__:
            print("API function %s requires docstring!" % name)
for (name, val) in sorted(cgt.core.__dict__.items()):
    # Skip the Op base class itself: it trivially "inherits" its own method
    # and was previously reported as missing it.
    if isinstance(val, type) and issubclass(val, cgt.core.Op) and val is not cgt.core.Op:
        if val.get_native_compile_info == cgt.core.Op.get_native_compile_info:
            print("Op %s is missing 'get_native_compile_info'!" % name)
from collections import namedtuple
def cudnn_conv_closure(*ints):
    """Pack the given integers, in order, into a ctypes array of c_int."""
    array_type = ctypes.c_int * len(ints)
    return array_type(*ints)
def make_closure(ph, pw, sv, sh):
    """
    Describe the convolution closure as (field name, ctypes type, initial value)
    triples: ph/pw are padding (height/width), sv/sh are strides
    (vertical/horizontal), followed by pointer slots "handle" and "stream"
    initialised to 0.
    """
    int_fields = [("ph", ph), ("pw", pw), ("sv", sv), ("sh", sh)]
    triples = [(field, ctypes.c_int, value) for field, value in int_fields]
    triples += [(field, ctypes.c_void_p, 0) for field in ("handle", "stream")]
    return triples
# NOTE(review): this definition looks truncated. The C template opened by
# `code = """` below is not closed before unrelated code begins, and the
# `return core.NativeCompileInfo(...)` used by sibling cuDNN ops is missing.
# Confirm against a complete version before use.
class CudnnConvForward(core.Op):
# Only a native GPU implementation is offered.
available_impls = ("native_gpu",)
def __init__(self, ph, pw, sv, sh):
"pad_height, pad_width, stride_vertical, stride_horizontal"
self.ph = ph
self.pw = pw
self.sv = sv
self.sh = sh
# Builds the native forward-convolution kernel (calls performConvForward); GPU only.
def get_native_compile_info(self, _input_types, devtype):
assert devtype=="gpu"
code = """
CGT_EXPORT_C void $setup(conv_closure* closure) {setup_cudnn(closure);}
CGT_EXPORT_C void $teardown(conv_closure* closure) {teardown_cudnn(closure);}
CGT_EXPORT_C void $function(conv_closure* closure, cgtArray** reads, cgtArray* write) {
if (!closure->handle) setup_cudnn(closure);
performConvForward(closure, reads[0], reads[1], reads[2], write);
# NOTE(review): spliced fragment. This is the runtime kernel of a
# matrix-vector op (presumably Mul21, which the pullback below constructs): it
# writes op(reads[0]) . reads[1] into `write`, transposing reads[0] when
# self.tA is set. The enclosing `def get_py_impl(self):` and class headers
# are missing here.
def f(reads, write):
x,y = reads
if self.tA: x = x.T
x.dot(y, out=write)
return PyImpl(inplace_func=f)
def get_replacement(self, inputs, analysis):
    """
    Rewrite op(A) . x when x appears in analysis["node2sv"] (presumably nodes
    known to hold a single repeated scalar): the product becomes A summed
    along the contracted axis times that scalar. Returns None otherwise.
    """
    node2sv = analysis["node2sv"]
    vec = inputs[1]
    if vec not in node2sv:
        return None
    contracted_axis = 0 if self.tA else 1
    return sum(inputs[0], contracted_axis) * node2sv[vec]
def pullback(self, inputs, _output, goutput):
    """
    Gradients of y = op(A) . x, where op(A) = A.T if self.tA else A.

    Returns [dA, dx]:
      dA = outer(goutput, x) when tA is False,
      dA = outer(x, goutput) when tA is True (the transpose of the above),
      dx = op(A).T . goutput, i.e. Mul21 with the transpose flag flipped.
    """
    mat, vec = inputs[0], inputs[1]
    if self.tA:
        # y = A.T x  =>  dA = x goutput^T. The previous code ignored tA here and
        # returned a gradient shaped like A.T instead of A.
        grad_mat = outer(vec, goutput)
    else:
        grad_mat = outer(goutput, vec)
    grad_vec = Result(Mul21(not self.tA), [mat, goutput])
    return [grad_mat, grad_vec]
def shp_apply(self, inputs):
    """Length of op(A) . x: rows of A, or columns of A when self.tA is set."""
    out_axis = 1 if self.tA else 0
    return [size(inputs[0], out_axis)]
def typ_apply(self, inputs):
    """Rank-1 tensor whose dtype follows the matrix input."""
    elem_dtype = inputs[0].get_dtype()
    return Tensor(elem_dtype, 1)
class Mul22(Op):
    """
    Matrix-matrix product op(A) . op(B), where op(M) is M.T when the
    corresponding flag (tA for A, tB for B) is set, otherwise M.
    """
    c_extra_includes = ["cblas.h"]
    c_extra_link_flags = "-lblas"

    def __init__(self, tA, tB):
        # Transpose flags for the first and second operand.
        self.tA = tA
        self.tB = tB

    def get_py_impl(self):
        def f(reads, write):
            x, y = reads
            if self.tA: x = x.T
            if self.tB: y = y.T
            x.dot(y, out=write)
        return PyImpl(inplace_func=f)

    def pullback(self, inputs, output, goutput):
        """
        Return [dA, dB] for C = op(A) . op(B), given G = dC:

            dA = G . op(B).T    if not tA  -> Mul22(False, not tB)[G, B]
            dA = op(B) . G.T    if tA      -> Mul22(tB, True)[B, G]
            dB = op(A).T . G    if not tB  -> Mul22(not tA, False)[A, G]
            dB = G.T . op(A)    if tB      -> Mul22(True, tA)[G, A]

        The transposed cases previously reused the untransposed formulas and
        produced gradients shaped like A.T / B.T instead of A / B.
        """
        a, b = inputs[0], inputs[1]
        if self.tA:
            grad_a = Result(Mul22(self.tB, True), [b, goutput])
        else:
            grad_a = Result(Mul22(False, not self.tB), [goutput, b])
        if self.tB:
            grad_b = Result(Mul22(True, self.tA), [goutput, a])
        else:
            grad_b = Result(Mul22(not self.tA, False), [a, goutput])
        return [grad_a, grad_b]
# NOTE(review): spliced fragment. The header says shp_apply(self, inputs), but
# the body returns _get_value_type(self.value), which looks like a type
# computation for a constant-valued op; `inputs` is unused. Confirm which
# method this body belongs to.
def shp_apply(self, inputs):
return _get_value_type(self.value)
def get_hash(self):
    """Identity-based hash key: the decimal string of id(self)."""
    return "%d" % id(self)
def get_closure(self, _):
    """Not implemented yet: always raises NotImplementedError."""
    # TODO: build the closure data for this op.
    raise NotImplementedError()
def c_code(self, inputs):
    """
    Disabled code-generation entry point: always raises RuntimeError.

    The original note said this logic should move to get_*_impl. The C template
    that used to follow the raise was unreachable; it is kept below as a comment
    for whoever ports it.
    """
    # Former template. NOTE(review): it was also invalid C -- the struct member
    # lacked its terminating ';' (`{void* ptr}`); corrected here:
    #
    #   typedef struct constcl {void* ptr;} constcl;
    #   void CGT_FUNCNAME(void* cldata, cgt_array** io) {
    #     io[0]->data = ((constcl*)cldata)->ptr;
    #     io[0]->ownsdata = false;
    #   }
    raise RuntimeError("c_code is disabled; move to get_*_impl")
class Fill(Op):
"""
(value, shape...) -> array filled with `value`, with shape `shape`
"""
def __init__(self, value):
    """
    Store `value` (normalised by _as_valid_array) and remember its dtype.

    The fill value must be 0-dimensional.
    """
    self.value = _as_valid_array(value)
    # Only scalar fill values are supported (this check was previously duplicated).
    assert self.value.ndim == 0
    self.dtype = self.value.dtype
def get_diff(self, num_inputs):
    """Mark every one of the `num_inputs` inputs as non-differentiable."""
    return [False for _ in range(num_inputs)]
def get_name(self):
    """Display name such as 'fill{0.5}', with the value formatted via %g."""
    template = "fill{%g}"
    return template % self.value
def get_py_impl(self):
    """Kernel that assigns self.value to every element of the output buffer."""
    def fill_output(reads, write):
        # The input values are not read; the fill value is stored on the op.
        write[...] = self.value
    return PyImpl(inplace_func=fill_output)