Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
critic_outputs.append(critic_output)
score_matrix = layers.concat(critic_outputs, axis=1)
# Normalize scores given by each critic
sum_critic_score = layers.reduce_sum(
score_matrix, dim=0, keep_dim=True)
sum_critic_score = layers.expand(
sum_critic_score, expand_times=[self.ensemble_num, 1])
norm_score_matrix = score_matrix / sum_critic_score
actions_mean_score = layers.reduce_mean(
norm_score_matrix, dim=1, keep_dim=True)
best_score_id = layers.argmax(actions_mean_score, axis=0)
best_score_id = layers.cast(best_score_id, dtype='int32')
ensemble_predict_action = layers.gather(batch_actions, best_score_id)
ensemble_predict_action = layers.squeeze(
ensemble_predict_action, axes=[0])
return ensemble_predict_action
def value(self, obs):
    """Estimate a state value for each image observation in the batch.

    Args:
        obs: A float32 tensor of shape [B, C, H, W]

    Returns:
        value: B
    """
    # Rescale by 255 (presumably raw 0-255 pixel intensities) before the
    # convolutional trunk.
    features = obs / 255.0
    # Apply the three conv layers in order: conv1 -> conv2 -> conv3.
    for conv_layer in (self.conv1, self.conv2, self.conv3):
        features = conv_layer(features)
    features = layers.flatten(features, axis=1)
    # Drop axis 1 of the value head's output so one value remains per sample.
    return layers.squeeze(self.value_fc(features), axes=[1])
def predict(self, obs, action):
    """Score an (observation, action) pair with a split-observation network.

    ``obs`` is sliced along axis 1 into two parts:
      * columns ``[0, self.obs_dim - self.vel_obs_dim)`` go through
        ``fc0`` -> ``fc1``;
      * columns ``[-self.vel_obs_dim, self.obs_dim)`` go through
        ``vel_fc0`` -> ``vel_fc1``.
    ``action`` is encoded by ``act_fc0``. The three encodings are concatenated
    along axis 1 (order: observation, action, velocity) and passed through
    ``fc2`` -> ``fc3``.

    Returns:
        ``fc3``'s output with axis 1 squeezed away.
    """
    split_at = self.obs_dim - self.vel_obs_dim
    obs_part = layers.slice(obs, axes=[1], starts=[0], ends=[split_at])
    # NOTE(review): the velocity slice counts from the end
    # (-vel_obs_dim); if vel_obs_dim were 0 it would start at column 0.
    vel_part = layers.slice(
        obs, axes=[1], starts=[-self.vel_obs_dim], ends=[self.obs_dim])

    obs_feat = self.fc1(self.fc0(obs_part))
    vel_feat = self.vel_fc1(self.vel_fc0(vel_part))
    act_feat = self.act_fc0(action)

    joined = layers.concat([obs_feat, act_feat, vel_feat], axis=1)
    score = self.fc3(self.fc2(joined))
    return layers.squeeze(score, axes=[1])
Returns:
policy_logits: B * ACT_DIM
values: B
"""
obs = obs / 255.0
conv1 = self.conv1(obs)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
flatten = layers.flatten(conv3, axis=1)
fc_output = self.fc(flatten)
policy_logits = self.policy_fc(fc_output)
values = self.value_fc(fc_output)
values = layers.squeeze(values, axes=[1])
return policy_logits, values
def value(self, obs, act):
    """Return a Q estimate for taking ``act`` at ``obs``.

    ``obs`` is encoded by ``self.fc1``. The raw ``act`` tensor is then
    concatenated onto that encoding along axis 1 and fed through
    ``self.fc2`` -> ``self.fc3``.

    Returns:
        ``fc3``'s output with axis 1 squeezed away.
    """
    obs_hidden = self.fc1(obs)
    merged = layers.concat([obs_hidden, act], axis=1)
    q_value = self.fc3(self.fc2(merged))
    return layers.squeeze(q_value, axes=[1])
Args:
obs: A float32 tensor of shape [B, C, H, W]
Returns:
values: B
"""
obs = obs / 255.0
conv1 = self.conv1(obs)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
flatten = layers.flatten(conv3, axis=1)
fc_output = self.fc(flatten)
values = self.value_fc(fc_output)
values = layers.squeeze(values, axes=[1])
return values
def value(self, obs):
    """Estimate state values with a four-layer fully connected stack.

    Args:
        obs: Observation tensor fed to ``self.fc1``.

    Returns:
        ``self.fc4``'s output with axis 1 squeezed away. If ``fc4`` emits
        [B, 1] this is [B] (assumption; confirm against the model __init__).
    """
    hid1 = self.fc1(obs)
    hid2 = self.fc2(hid1)
    hid3 = self.fc3(hid2)
    V = self.fc4(hid3)
    # Squeeze only axis 1. An empty ``axes`` list removes *every* size-1
    # dimension, which would also drop the batch axis when B == 1, making
    # the output rank depend on batch size. ``axes=[1]`` matches the other
    # value heads.
    V = layers.squeeze(V, axes=[1])
    return V