def softmax_256_loss(x, l):
# Based on PixelRNN paper (https://arxiv.org/pdf/1601.06759v3.pdf)
# x: (3,32,32)
# l: (3x256, 32, 32)
x = (x + 1.) * 127.5
x = ct.one_hot(x, 256, axis=0) # (256, 3,32,32)
l = ct.reshape(l, (256, 3,32,32))
return ct.reduce_sum(ct.cross_entropy_with_softmax(l, x, axis=0))
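# Hedged usage sketch (not part of the excerpt above): wiring softmax_256_loss to
# cntk input variables; the variable names below are assumptions for illustration.
import cntk as ct

x_var = ct.input_variable((3, 32, 32))        # image already rescaled to [-1, 1]
l_var = ct.input_variable((3 * 256, 32, 32))  # 256-way logits per RGB sub-pixel
loss = softmax_256_loss(x_var, l_var)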
"""
Porting discretized_mix_logistic_loss from https://github.com/openai/pixel-cnn/blob/master/pixel_cnn_pp/nn.py.
log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval
"""
xs = x.shape # true image (i.e. labels) to regress to. CHW C=3
ls = l.shape # predicted distribution. DHW D=100
nr_mix = int(ls[0] / 10) # here and below: unpacking the params of the mixture of logistics
logit_probs = l[:nr_mix, :, :]
l = ct.reshape(l[nr_mix:100, :, :], (nr_mix*3,) + xs)
means = l[:nr_mix, :, :, :]
log_scales = ct.element_max(l[nr_mix:2*nr_mix, :, :, :], -7.)
coeffs = ct.tanh(l[2*nr_mix:3*nr_mix, :, :, :])
# here and below: getting the means and adjusting them based on preceding sub-pixels
x = ct.reshape(x, (1,)+xs) + ct.constant(value=0., shape=(nr_mix,)+xs)
m2 = ct.reshape(means[:, 1, :, :] + coeffs[:, 0, :, :] * x[:, 0, :, :], (nr_mix,1,xs[1],xs[2]))
m3 = ct.reshape(means[:, 2, :, :] + coeffs[:, 1, :, :] * x[:, 0, :, :] + coeffs[:, 2, :, :] * x[:, 1, :, :], (nr_mix,1,xs[1],xs[2]))
means = ct.splice(ct.reshape(means[:,0,:,:], (nr_mix,1,xs[1],xs[2])), m2, m3, axis=1)
centered_x = x - means
inv_stdv = ct.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1./255.)
cdf_plus = ct.sigmoid(plus_in)
min_in = inv_stdv * (centered_x - 1./255.)
cdf_min = ct.sigmoid(min_in)
log_cdf_plus = plus_in - ct.softplus(plus_in) # log probability for edge case of 0 (before scaling)
log_one_minus_cdf_min = -ct.softplus(min_in) # log probability for edge case of 255 (before scaling)
cdf_delta = cdf_plus - cdf_min # probability for all other cases
mid_in = inv_stdv * centered_x
log_pdf_mid = mid_in - log_scales - 2. * ct.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code)
# now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us)
# this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select()
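# Hedged sketch (not the source's code): the "robust version" the comment above
# refers to, ported to cntk with ct.element_select; the 0.999 clip thresholds and
# the 1e-5 / 1e-12 epsilons follow the TensorFlow reference in pixel_cnn_pp/nn.py,
# and np is numpy as used elsewhere in these excerpts.
log_probs = ct.element_select(
    ct.less(x, -0.999), log_cdf_plus,
    ct.element_select(
        ct.greater(x, 0.999), log_one_minus_cdf_min,
        ct.element_select(
            ct.greater(cdf_delta, 1e-5),
            ct.log(ct.element_max(cdf_delta, 1e-12)),
            log_pdf_mid - np.log(127.5))))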
n_outputs_per_class = int(np.ceil(label_dim / label_classes))
target_class = C.floor((label_index + 0.5) / n_outputs_per_class)
target_output_in_class = C.round(label_index - target_class * n_outputs_per_class)
w1 = C.parameter(shape=(input_dim, label_classes), init=C.glorot_normal(), name='hsoftmax_w1')
b1 = C.parameter(shape=(label_classes), init=C.glorot_normal(), name='hsoftmax_b1')
w2s = C.parameter(shape=(label_classes, input_dim, n_outputs_per_class,), init=C.glorot_normal(), name='hsoftmax_w2s')
b2s = C.parameter(shape=(label_classes, n_outputs_per_class,), init=C.glorot_normal(), name='hsoftmax_b2s')
class_probs = C.softmax(b1 + C.times(input_var, w1))
# TODO: fix the bug in backprop for sparse, and use sparse embedding to accelerate
target_class_one_hot = C.one_hot(target_class, num_classes=label_classes, sparse_output=False)
w2 = C.reshape(C.times(target_class_one_hot, w2s, output_rank=2), [input_dim, -1])
b2 = C.reshape(C.times(target_class_one_hot, b2s, output_rank=1), [-1])
probs_in_class = C.softmax(b2 + C.times(input_var, w2))
prob_in_class = C.times_transpose(C.one_hot(target_output_in_class, num_classes=n_outputs_per_class, sparse_output=False), probs_in_class)
class_prob = C.times_transpose(C.one_hot(target_class, num_classes=label_classes, sparse_output=False), class_probs)
output_prob = prob_in_class * class_prob
# this is for calculating all the outputs' probabilities
all_probs = []
for i in range(label_classes):
ci = C.constant(i)
ci_one_hot = C.one_hot(ci, num_classes=label_classes, sparse_output=False)
w2a = C.times(ci_one_hot, w2s, output_rank=2)
b2a = C.times(ci_one_hot, b2s, output_rank=1)
probs_in_classa = C.softmax(b2a + C.times(input_var, w2a))
class_proba = C.times_transpose(ci_one_hot, class_probs)
all_probs.append(class_proba * probs_in_classa)  # assumed completion: the excerpt cuts the loop off before it stores this class's output probabilities
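# Hedged usage sketch: assuming the code above is the body of a helper such as
# hierarchical_softmax_layer(input_var, label_index, label_dim, label_classes)
# (the wrapper name and its return values are assumptions), the training
# criterion is the negative log of the two-level probability of the target word.
import cntk as C

input_var = C.input_variable(input_dim)
label_index = C.input_variable(1)
output_prob, class_probs, all_probs = hierarchical_softmax_layer(
    input_var, label_index, label_dim, label_classes)
loss = -C.log(output_prob + 1e-6)  # small epsilon for numerical safety (assumption)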
def _reshape_batch(x, shape):
# cntk 2.1's unpack_batch implementation has a bug, so the native unpack_batch/to_batch path is only taken on cntk 2.2 and later
if hasattr(C, 'unpack_batch') and _get_cntk_version() >= 2.2:
const_a = C.unpack_batch(x)
const_a = C.reshape(const_a, shape)
return C.to_batch(const_a)
else:
return C.user_function(ReshapeBatch(x, shape[1:]))
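# Hedged usage sketch: _get_cntk_version and ReshapeBatch are helpers from the
# surrounding backend code, not defined in this excerpt; the example below
# (assumed, not from the source) flattens (4, 4) features to 16 per sample while
# keeping the leading -1 slot for the batch dimension.
import cntk as C

b = C.input_variable((4, 4))
flat = _reshape_batch(b, (-1, 16))   # per-sample shape becomes (16,)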
Arguments:
    x: input tensor of shape ((ndim + 2) * nmix,)
    nmix (int): number of mixture components
    ndim (int): dimensionality of each gaussian component
Returns:
    tuple of (alpha, mu, sigma)
"""
if len(x.shape) != 1:
raise ValueError("Must be a 1d tensor, but input has shape {0}".format(x.shape))
alpha = C.softmax(C.slice(x, 0, 0, nmix), name='alpha')
sigma = C.exp(C.slice(x, 0, nmix, 2 * nmix), name='sigma')  # one scale per mixture component, shared across all ndim dimensions of that gaussian
mu = C.reshape(C.slice(x, 0, 2 * nmix, (ndim + 2) * nmix), shape=(nmix, ndim), name='mu')
return alpha, mu, sigma
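# Hedged usage sketch: assuming the body above belongs to a helper taking
# (x, nmix, ndim) -- called gaussian_mdn_coeff below, the name is an assumption --
# the network has to emit (ndim + 2) * nmix values per sample for this split.
import cntk as C

nmix, ndim = 3, 2
net_out = C.input_variable((ndim + 2) * nmix)        # 12 values per sample
alpha, mu, sigma = gaussian_mdn_coeff(net_out, nmix, ndim)
# alpha: (3,) mixture weights, mu: (3, 2) component means, sigma: (3,) scales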
if left_shape[0] != right_shape[0]:
raise ValueError("first axis of left operand and right operand must be the same")
if (left_shape[0] < 0 or right_shape[0] < 0) and seq_axis_present:
raise ValueError("Static batch axis cannot be a free axis when dynamic sequence axis is also present")
# Combine dynamic sequence axis and static batch axis
if not seq_axis_present:
left_unpacked = left
right_unpacked = right
else:
left_unpacked = C.sequence.unpack(left, padding_value=0, no_mask_output=True)
right_unpacked = C.sequence.unpack(right, padding_value=0, no_mask_output=True)
left_unpacked = C.reshape(left_unpacked, (-1,) + left_shape[1:])
right_unpacked = C.reshape(right_unpacked, (-1,) + right_shape[1:])
# Fold static batch axis into dynamic sequence axis
left_folded = C.to_sequence(left_unpacked) # do not set sequence length as batch axis has been folded in
right_folded = C.to_sequence_like(right_unpacked, left_folded) # seq_length / axis set here to tell cntk they have the same seq axis
# Matrix Multiply when no static batch axis is present
result = C.times(left_folded, right_folded, output_rank=output_rank, infer_input_rank_to_map=infer_input_rank_to_map)
# Split dynamic sequence axis back to original dynamic sequence and static batch axis
result_unpacked = C.sequence.unpack(result, padding_value=0, no_mask_output=True)
if not seq_axis_present:
result_packed = C.reshape(result_unpacked, (static_batch_axis, ) + result.shape)
else:
result_unfolded = C.reshape(result_unpacked, (-1, static_batch_axis) + result.shape)
result_packed = C.to_sequence_like(result_unfolded, left)
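# Hedged usage sketch: assuming the fragment above is the body of a helper,
# called seq_batch_times(left, right, output_rank, infer_input_rank_to_map) below
# (the name is an assumption), it applies C.times once per sequence step and per
# static-batch entry.
import cntk as C

left = C.sequence.input_variable((4, 2, 3))    # static batch axis of 4
right = C.sequence.input_variable((4, 3, 5))
out = seq_batch_times(left, right, output_rank=1,
                      infer_input_rank_to_map=C.TIMES_NO_INFERRED_INPUT_RANK)
# each sequence step of `out` has shape (4, 2, 5)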
if first_run:
V = ct.parameter(output_channels_shape + x_channels_shape + filter_shape, init=init, name='V'); set_parameter(scope, 'V', V)
g = ct.parameter(output_channels_shape, init=global_g_init, name='g'); set_parameter(scope, 'g', g)
b = ct.parameter(output_channels_shape, name='b'); set_parameter(scope, 'b', b)
# use weight normalization (Salimans & Kingma, 2016)
V_norm = l2_normalize(V, axes=(1, 2, 3))
x_init = ct.convolution(V_norm, x, strides=x_channels_shape + strides, auto_padding=paddings)
m_init, v_init = moments(x_init, axes=(ct.Axis.default_batch_axis(),1,2))
scale_init = init_scale / ct.sqrt(v_init + 1e-8)
g_new = ct.assign(g, scale_init)
b_new = ct.assign(b, -m_init*scale_init)
# normalise the initial batch with the data-dependent scale/shift; the trailing `* 0` term only pulls the g/b assign ops into the graph
x_init = ct.reshape(scale_init, (num_filters, 1, 1))*(x_init-ct.reshape(m_init, (num_filters, 1, 1))) + ct.reshape(g_new + b_new, (num_filters, 1, 1))*0
if nonlinearity is not None:
x_init = nonlinearity(x_init)
return x_init
else:
V,g,b = get_parameters(scope, ['V','g','b'])
# use weight normalization (Salimans & Kingma, 2016)
V_norm = l2_normalize(V, axes=(1, 2, 3))
W = ct.reshape(g, (num_filters, 1, 1, 1)) * V_norm
x = ct.convolution(W, x, strides=x_channels_shape + strides, auto_padding=paddings) + ct.reshape(b, (num_filters, 1, 1))
if nonlinearity is not None:
x = nonlinearity(x)
return x
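# Hedged sketch of the two helpers the weight-norm block above assumes
# (l2_normalize and moments); these definitions are assumptions built from
# standard cntk ops, not the source's own implementations.
import cntk as ct

def l2_normalize(x, axes, epsilon=1e-12):
    # scale x to unit L2 norm over the given static axes
    sq_sum = ct.square(x)
    for ax in axes:
        sq_sum = ct.reduce_sum(sq_sum, axis=ax)
    return x / ct.sqrt(sq_sum + epsilon)

def moments(x, axes):
    # mean and variance over the given axes (batch + spatial above), via E[x^2] - E[x]^2
    mean, sq_mean = x, ct.square(x)
    for ax in axes:
        mean = ct.reduce_mean(mean, axis=ax)
        sq_mean = ct.reduce_mean(sq_mean, axis=ax)
    return mean, sq_mean - ct.square(mean)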
_axis = [_ + len(shape) if _ < 0 else _ for _ in axis]
if shape.count(C.InferredDimension) > 1 or shape.count(C.FreeDimension) > 1:
result = x
for index in sorted(_axis, reverse=True):
result = C.reshape(result,
shape=(),
begin_axis=index,
end_axis=index + 1)
return result
else:
for index in sorted(_axis, reverse=True):
del shape[index]
shape = [C.InferredDimension if _ == C.FreeDimension else _ for _ in shape]
return C.reshape(x, shape)
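# Hedged usage sketch: assuming the branches above form the tail of a
# squeeze-style helper taking (x, axis) with `shape` bound to list(x.shape)
# earlier in the function (the wrapper, here called squeeze, is an assumption):
import cntk as C

x = C.input_variable((3, 1, 5))
y = squeeze(x, axis=[1])   # static shape becomes (3, 5)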
# so W * [h; u; h.* u] becomes w1 * h + w2 * u + w3 * (h.*u)
ws1 = C.parameter(shape=(2 * self.hidden_dim, 1), init=C.glorot_uniform())
ws2 = C.parameter(shape=(2 * self.hidden_dim, 1), init=C.glorot_uniform())
ws3 = C.parameter(shape=(1, 2 * self.hidden_dim), init=C.glorot_uniform())
att_bias = C.parameter(shape=(), init=0)
wh = C.times (c_processed, ws1)
wu = C.reshape(C.times (qvw, ws2), (-1,))
whu = C.reshape(C.reduce_sum(c_processed * C.sequence.broadcast_as(qvw * ws3, c_processed), axis=1), (-1,))
S = wh + whu + C.sequence.broadcast_as(wu, c_processed) + att_bias
# mask out positions outside the query, filling them with -1e+30, which acts as a neutral value for both reduce_log_sum_exp and reduce_max
qvw_mask_expanded = C.sequence.broadcast_as(qvw_mask, c_processed)
S = C.element_select(qvw_mask_expanded, S, C.constant(-1e+30))
q_attn = C.reshape(C.softmax(S), (-1,1))
c2q = C.reshape(C.reduce_sum(C.sequence.broadcast_as(qvw, q_attn) * q_attn, axis=0),(-1))
max_col = C.reduce_max(S)
c_attn = C.sequence.softmax(max_col)
htilde = C.sequence.reduce_sum(c_processed * c_attn)
q2c = C.sequence.broadcast_as(htilde, c_processed)
q2c_out = c_processed * q2c
att_context = C.splice(c_processed, c2q, c_processed * c2q, q2c_out)
return C.as_block(
att_context,
[(c_processed, context), (q_processed, query)],
'attention_layer',
'attention_layer')
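# Hedged usage sketch: the block above closes a BiDAF-style attention method, so
# it would be invoked on the contextual encodings of passage and query (the
# variable names below are assumptions for illustration).
att_context = self.attention_layer(context_encoding, query_encoding)
# per step, att_context is [c; c2q; c * c2q; q2c_out], i.e. 8 * hidden_dim features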