def broadcast(opname, x, y, bcpat):
    # Dispatch to cgt.broadcast for symbolic graph nodes; for plain
    # numpy/Python values, apply the operator directly and let numpy
    # broadcasting handle the shapes.
    return cgt.broadcast(opname, x, y, bcpat) if isinstance(x, core.Node) else eval("x %s y" % opname)
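# Usage sketch (not from the original source): cgt.broadcast takes a pattern
# string "p1,p2" with one character per axis of each operand, where "1" marks
# a size-one axis to be broadcast and "x" marks an ordinary axis.
import cgt
a = cgt.matrix("a")                     # shape (m, n)
b = cgt.matrix("b")                     # shape (m, 1)
c = cgt.broadcast("*", a, b, "xx,x1")   # b's size-one second axis is broadcast along n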
def circ_conv_1d(wg_bhn, s_bh3, axis=2):
    "VERY inefficient way to implement circular convolution for the special case of filter size 3"
    assert axis == 2
    n = cgt.size(wg_bhn, 2)
    # Rotate the weighting one step backward and one step forward along the
    # memory axis, then blend the three shifted copies by the shift weights s.
    wback = cgt.concatenate([wg_bhn[:, :, n-1:n], wg_bhn[:, :, :n-1]], axis=2)
    w = wg_bhn
    wfwd = cgt.concatenate([wg_bhn[:, :, 1:n], wg_bhn[:, :, 0:1]], axis=2)
    return cgt.broadcast("*", s_bh3[:, :, 0:1], wback, "xx1,xxx") \
         + cgt.broadcast("*", s_bh3[:, :, 1:2], w,     "xx1,xxx") \
         + cgt.broadcast("*", s_bh3[:, :, 2:3], wfwd,  "xx1,xxx")
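# Cross-check sketch (assumption, not in the original source): the same size-3
# circular convolution written directly in numpy on concrete arrays.
import numpy as np
def circ_conv_1d_np(w, s):
    # w: (b, h, n) weighting; s: (b, h, 3) shift weights over {-1, 0, +1}
    return (s[:, :, 0:1] * np.roll(w, 1, axis=2)    # shift backward (wback)
          + s[:, :, 1:2] * w                        # no shift
          + s[:, :, 2:3] * np.roll(w, -1, axis=2))  # shift forward (wfwd)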
def ntm_address(opt, M_bnm, k_bhm, beta_bh, g_bh, s_bh3, gamma_bh, wprev_bhn):
    # (Hypothetical signature, inferred from the variables the body uses;
    # the original snippet omits the enclosing def.)
    # Content addressing: cosine similarity between keys and memory rows.
    # Take inner product along memory axis: k . M
    numer_bhn = cgt.einsum("bhm,bnm->bhn", k_bhm, M_bnm)
    # Compute denominator |k| * |M|
    denom_bhn = cgt.broadcast("*",
        cgt.norm(k_bhm, axis=2, keepdims=True),                       # -> shape bh1
        cgt.norm(M_bnm, axis=2, keepdims=True).transpose([0, 2, 1]),  # -> bn1 -> b1n
        "xx1,x1x"
    )
    csim_bhn = numer_bhn / denom_bhn
    assert infer_shape(csim_bhn) == (opt.b, 2*opt.h, opt.n)
    # Scale by the key strength beta and normalize into a distribution
    tmp_bhn = cgt.broadcast("*", beta_bh[:, :, None], csim_bhn, "xx1,xxx")
    wc_bhn = sum_normalize2(cgt.exp(tmp_bhn))
    # Interpolation between the previous weighting and the content weighting
    g_bh1 = g_bh[:, :, None]
    wg_bhn = cgt.broadcast("*", wprev_bhn, (1 - g_bh1), "xxx,xx1") \
           + cgt.broadcast("*", wc_bhn, g_bh1, "xxx,xx1")
    # Shift
    wtil_bhn = circ_conv_1d(wg_bhn, s_bh3, axis=2)
    # Sharpening
    wfin_bhn = sum_normalize2(cgt.broadcast("**", wtil_bhn, gamma_bh.reshape([opt.b, 2*opt.h, 1]), "xxx,xx1"))
    b, h, n = opt.b, 2*opt.h, opt.n
    assert infer_shape(wtil_bhn) == (b, h, n)
    assert infer_shape(gamma_bh) == (b, h)
    assert infer_shape(gamma_bh[:, :, None]) == (b, h, 1)
    return wfin_bhn
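# sum_normalize2 is called above but not defined in the snippet; a minimal
# sketch (assumption) that normalizes along the last (memory) axis:
def sum_normalize2(x_bhn):
    return cgt.broadcast("/", x_bhn, cgt.sum(x_bhn, axis=2, keepdims=True), "xxx,xx1")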
def __call__(self, Y, U):
    if Y.ndim > (self.axis + 1):
        # Flatten the trailing axes so Y has exactly self.axis + 1 dimensions
        Y = Y.reshape(Y.shape[:self.axis] + [cgt.mul_multi(Y.shape[self.axis:])])
    # Batched outer product of Y and U via broadcasting
    outer_YU = cgt.broadcast('*',
        Y.dimshuffle(list(range(Y.ndim)) + ['x']),
        U.dimshuffle([0] + ['x']*self.axis + [1]),
        ''.join(['x']*Y.ndim + ['1', ',', 'x'] + ['1']*self.axis + ['x']))
    bilinear = cgt.dot(outer_YU.reshape((outer_YU.shape[0], cgt.mul_multi(outer_YU.shape[1:]))),
                       self.M.reshape((self.y_dim, self.y_dim * self.u_dim)).T)
    if self.axis > 1:
        bilinear = bilinear.reshape((-1,) + self.y_shape[:self.axis-1] + (self.y_dim,))
    linear = cgt.dot(U, self.N.T)
    if self.axis > 1:
        linear = linear.dimshuffle([0] + ['x']*(self.axis-1) + [1])
    activation = bilinear + linear
    if self.b is not None:
        # Note: the original used `activation += cgt.broadcast('+', activation, ...)`,
        # which would add activation twice; plain assignment is the intended bias add.
        activation = cgt.broadcast('+',
            activation,
            self.b.dimshuffle(['x']*self.axis + [0]),
            ''.join(['x']*activation.ndim + [','] + ['1']*(activation.ndim-1) + ['x']))
    return activation  # inferred: the snippet ends without the method's return
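# Illustration (assumption): for self.axis == 1, Y has shape (b, y_dim) and
# U has shape (b, u_dim); the dimshuffles above reduce to the batched outer
# product below, of shape (b, y_dim, u_dim).
Y = cgt.matrix("Y")
U = cgt.matrix("U")
outer = cgt.broadcast("*", Y.dimshuffle([0, 1, 'x']), U.dimshuffle([0, 'x', 1]), "xx1,x1x")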
# (The start of this snippet is truncated; the `if W is None:` block below is
# a reconstruction of the standard Glorot-style initialization that the
# surviving tail implies.)
if W is None:
    W_values = np.asarray(
        rng.uniform(
            low=-np.sqrt(6. / (n_in + n_out)),
            high=np.sqrt(6. / (n_in + n_out)),
            size=(n_in, n_out)
        ),
        dtype=cgt.floatX
    )
    if activation == cgt.sigmoid:
        W_values *= 4
    W = cgt.shared(W_values, name=prefix+"_W")
if b is None:
    b_values = np.zeros((n_out,), dtype=cgt.floatX)
    b = cgt.shared(b_values, name=prefix+"_b")
self.W = W
self.b = b
# XXX broadcast api may change
lin_output = cgt.broadcast("+", cgt.dot(input, self.W),
                           cgt.dimshuffle(self.b, ["x", 0]), "xx,1x")
self.output = (
    lin_output if activation is None
    else activation(lin_output)
)
# parameters of the model
self.params = [self.W, self.b]
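# Usage sketch (hypothetical class name HiddenLayer, assuming the fragment
# above is the body of its __init__): wiring the layer into a CGT graph.
X = cgt.matrix("X")  # (batch, n_in)
layer = HiddenLayer(rng=np.random.RandomState(0), input=X,
                    n_in=784, n_out=256, activation=cgt.tanh, prefix="h0")
f = cgt.function([X], layer.output)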