Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def phi_e(self, inputs):
    """Compute updated edge features from node, edge and global-state inputs.

    ``index1`` and ``index2`` are flattened to 1-D. ``nodes`` is gathered at
    each of them along axis 1, and the two gathered tensors are concatenated
    on the last axis together with ``edges``. The global state ``u`` is
    repeated according to ``gbond`` (via ``repeat_with_index``). The combined
    tensor is fed to ``self._mlp`` with the edge-update weights and biases.

    Args:
        inputs: exactly seven items ``(nodes, edges, u, index1, index2,
            gnode, gbond)``; ``gnode`` is unused here. Presumably
            ``index1``/``index2`` pick the two endpoint nodes of each edge
            -- TODO confirm against the caller.

    Returns:
        The output of ``self._mlp`` for the concatenated edge input.
    """
    nodes, edges, u, index1, index2, _gnode, gbond = inputs
    flat_index1 = tf.reshape(index1, (-1,))
    flat_index2 = tf.reshape(index2, (-1,))
    # Both gathers run (first index1, then index2) before the global expansion.
    endpoint_features = tf.concat(
        [
            tf.gather(nodes, flat_index1, axis=1),
            tf.gather(nodes, flat_index2, axis=1),
        ],
        axis=-1,
    )
    globals_per_edge = repeat_with_index(u, gbond, axis=1)
    edge_mlp_input = tf.concat(
        [endpoint_features, edges, globals_per_edge], axis=-1
    )
    return self._mlp(edge_mlp_input, self.phi_e_weights, self.phi_e_biases)
def phi_v(self, b_ei_p, inputs):
    """Compute updated node features from aggregated edges, nodes and globals.

    The global state ``u`` is repeated according to ``gnode`` (via
    ``repeat_with_index``). ``b_ei_p``, ``nodes`` and that expanded state are
    concatenated on the last axis and fed to ``self._mlp`` with the
    node-update weights and biases.

    Args:
        b_ei_p: per-node tensor concatenated first; presumably the edge
            features already aggregated onto each node -- TODO confirm.
        inputs: exactly seven items ``(nodes, edges, u, index1, index2,
            gnode, gbond)``; only ``nodes``, ``u`` and ``gnode`` are used.

    Returns:
        The output of ``self._mlp`` for the concatenated node input.
    """
    nodes, _edges, u, _index1, _index2, gnode, _gbond = inputs
    node_mlp_input = tf.concat(
        [b_ei_p, nodes, repeat_with_index(u, gnode, axis=1)],
        axis=-1,
    )
    return self._mlp(node_mlp_input, self.phi_v_weights, self.phi_v_biases)
# NOTE(review): the statements below are the body of a method whose `def`
# header is not present in this file. That header would bind `self`,
# `features` and `feature_graph_index`, and the body has lost its
# indentation, so this cannot run as written. TODO restore the header and
# indentation from upstream.
#
# Purpose (from the visible code): an attention readout over groups of
# feature rows keyed by `feature_graph_index`. For `self.T` steps it:
#   1. runs `self._lstm` on `q_star`;
#   2. scores each projected feature row against the LSTM output;
#   3. softmax-normalises the scores within each index group;
#   4. sums the weighted features per group;
#   5. sets `q_star = [h, r_t]`.
# `q_star` is returned.
# Flatten the group index to 1-D.
feature_graph_index = tf.reshape(feature_graph_index, (-1,))
# Only the length of `count` is used below, as the number of distinct group ids.
_, _, count = tf.unique_with_counts(feature_graph_index)
# Project the input features with the learned matrix, plus optional bias.
m = kb.dot(features, self.m_weight)
if self.use_bias:
m += self.m_bias
# LSTM hidden/cell states and the query vector start as zeros. Each has shape
# [features.shape[0], <number of distinct group ids>, ...]. q_star is
# 2 * n_hidden wide because it is later built as concat(h, r_t).
self.h = tf.zeros(tf.stack(
[tf.shape(input=features)[0], tf.shape(input=count)[0], self.n_hidden]))
self.c = tf.zeros(tf.stack(
[tf.shape(input=features)[0], tf.shape(input=count)[0], self.n_hidden]))
q_star = tf.zeros(tf.stack(
[tf.shape(input=features)[0], tf.shape(input=count)[0], 2 * self.n_hidden]))
# `i` is unused; the loop just runs self.T processing steps.
for i in range(self.T):
# NOTE(review): the returned cell state goes to the local `c` and is never
# stored back into `self.c`. Every step therefore passes the initial
# zero cell state to `self._lstm`. Confirm whether that is intended.
self.h, c = self._lstm(q_star, self.c)
# Attention logits: dot product (sum over the last axis) of each projected
# row with the hidden state. `repeat_with_index` presumably expands
# per-group rows to per-element rows -- TODO confirm.
e_i_t = tf.reduce_sum(
input_tensor=m * repeat_with_index(self.h, feature_graph_index), axis=-1)
# NOTE(review): raw exp without subtracting a per-group max can overflow to
# inf for large logits. Consider a numerically stabilised softmax.
exp = tf.exp(e_i_t)
# Per-group sum of exp. Axes 0 and 1 are transposed so segment_sum reduces
# over the index axis, then transposed back. tf.math.segment_sum expects
# sorted segment ids; presumably feature_graph_index is sorted -- verify.
seg_sum = tf.transpose(
a=tf.math.segment_sum(
tf.transpose(a=exp, perm=[1, 0]),
feature_graph_index),
perm=[1, 0])
seg_sum = tf.expand_dims(seg_sum, axis=-1)
# Softmax weights within each group: exp / (sum of exp over the element's group).
# NOTE(review): tf.squeeze has no axis argument, so it drops every size-1
# dimension, not only the one added by expand_dims above. Consider
# axis=-1 -- confirm the expected shapes.
a_i_t = exp / tf.squeeze(
repeat_with_index(seg_sum, feature_graph_index))
# Attention readout: per-group sum of projected features weighted by a_i_t
# (same transpose / segment_sum / transpose pattern, on a rank-3 tensor).
r_t = tf.transpose(a=tf.math.segment_sum(
tf.transpose(a=tf.multiply(m, a_i_t[:, :, None]), perm=[1, 0, 2]),
feature_graph_index), perm=[1, 0, 2])
# Next query: hidden state concatenated with the readout on the last axis.
q_star = kb.concatenate([self.h, r_t], axis=-1)
# NOTE(review): everything from the next line up to the final `return` is a
# verbatim duplicate of the q_star initialisation tail and the loop above,
# apparently a paste artifact. The next line is a dangling continuation with
# no `q_star = tf.zeros(tf.stack(` before it. If the two were joined, the
# loop would run 2 * self.T times. TODO remove after checking upstream.
[tf.shape(input=features)[0], tf.shape(input=count)[0], 2 * self.n_hidden]))
for i in range(self.T):
self.h, c = self._lstm(q_star, self.c)
e_i_t = tf.reduce_sum(
input_tensor=m * repeat_with_index(self.h, feature_graph_index), axis=-1)
exp = tf.exp(e_i_t)
# (duplicate) per-group sum of exp
seg_sum = tf.transpose(
a=tf.math.segment_sum(
tf.transpose(a=exp, perm=[1, 0]),
feature_graph_index),
perm=[1, 0])
seg_sum = tf.expand_dims(seg_sum, axis=-1)
# (duplicate) softmax weights within each group
a_i_t = exp / tf.squeeze(
repeat_with_index(seg_sum, feature_graph_index))
# (duplicate) weighted per-group readout
r_t = tf.transpose(a=tf.math.segment_sum(
tf.transpose(a=tf.multiply(m, a_i_t[:, :, None]), perm=[1, 0, 2]),
feature_graph_index), perm=[1, 0, 2])
q_star = kb.concatenate([self.h, r_t], axis=-1)
# Final query vector after the processing steps.
return q_star