Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _ensemble_predict(self, obs):
    """Return the action proposed by the best-scoring actor of the ensemble.

    Steps:
      1. Each of the ``self.ensemble_num`` actors proposes an action for
         ``obs``; the proposals are concatenated along axis 0.
      2. Each critic scores every proposal; a critic's scores are divided
         by the sum of that critic's scores over all proposals.
      3. The proposal with the highest mean normalized score (averaged
         over critics) is gathered and returned with axis 0 squeezed out.

    NOTE(review): the expand/concat pattern looks like it expects ``obs``
    to be a rank-2 tensor holding a single observation -- confirm with
    the caller.
    """
    member_ids = range(self.ensemble_num)

    # One proposal per actor, concatenated along axis 0.
    proposals = [self.actors[k].predict(obs) for k in member_ids]
    batch_actions = layers.concat(proposals, axis=0)

    # Repeat the observation along axis 0, once per ensemble member.
    batch_obs = layers.expand(obs, expand_times=[self.ensemble_num, 1])

    # Column k of the score matrix holds critic k's score for each proposal.
    score_columns = [
        layers.unsqueeze(
            self.critics[k].predict(batch_obs, batch_actions), axes=[1])
        for k in member_ids
    ]
    score_matrix = layers.concat(score_columns, axis=1)

    # Normalize scores given by each critic: divide every column by its sum.
    column_totals = layers.expand(
        layers.reduce_sum(score_matrix, dim=0, keep_dim=True),
        expand_times=[self.ensemble_num, 1])
    norm_score_matrix = score_matrix / column_totals

    # Average over critics, then pick the row with the best mean score.
    mean_scores = layers.reduce_mean(norm_score_matrix, dim=1, keep_dim=True)
    winner = layers.cast(layers.argmax(mean_scores, axis=0), dtype='int32')
    chosen = layers.gather(batch_actions, winner)
    return layers.squeeze(chosen, axes=[0])
2. For each actor, calculate its score by
averaging the scores given by all critics
3. Choose the action of the actor whose score is best
"""
actor_outputs = []
for i in range(self.ensemble_num):
actor_outputs.append(self.models[i].policy(obs))
batch_actions = layers.concat(actor_outputs, axis=0)
batch_obs = layers.expand(obs, expand_times=[self.ensemble_num, 1])
critic_outputs = []
for i in range(self.ensemble_num):
critic_output = self.models[i].value(batch_obs, batch_actions)
critic_output = layers.unsqueeze(critic_output, axes=[1])
critic_outputs.append(critic_output)
score_matrix = layers.concat(critic_outputs, axis=1)
# Normalize scores given by each critic
sum_critic_score = layers.reduce_sum(
score_matrix, dim=0, keep_dim=True)
sum_critic_score = layers.expand(
sum_critic_score, expand_times=[self.ensemble_num, 1])
norm_score_matrix = score_matrix / sum_critic_score
actions_mean_score = layers.reduce_mean(
norm_score_matrix, dim=1, keep_dim=True)
best_score_id = layers.argmax(actions_mean_score, axis=0)
best_score_id = layers.cast(best_score_id, dtype='int32')
ensemble_predict_action = layers.gather(batch_actions, best_score_id)
return ensemble_predict_action
def value(self, obs, action):
    """Compute the Q output for ``obs`` and ``action``.

    ``obs`` is split along axis 1 into two parts: the leading columns and
    the trailing ``self.vel_obs_dim`` columns (the target-related
    features). Each part and the action pass through their own layers;
    the three results are concatenated and fed through ``fc2`` and
    ``fc3``. Axis 1 of the result is squeezed out before returning.
    """
    # A negative index counts from the end of axis 1.
    split_at = -self.vel_obs_dim
    real_obs = layers.slice(obs, axes=[1], starts=[0], ends=[split_at])
    # Target-related features.
    vel_obs = layers.slice(
        obs, axes=[1], starts=[split_at], ends=[self.obs_dim])

    state_feat = self.fc1(self.fc0(real_obs))
    vel_feat = self.vel_fc1(self.vel_fc0(vel_obs))
    action_feat = self.act_fc0(action)

    merged = layers.concat([state_feat, action_feat, vel_feat], axis=1)
    q_value = self.fc3(self.fc2(merged))
    return layers.squeeze(q_value, axes=[1])
def _ensemble_predict(self, obs):
    """Return the action of the ensemble actor whose proposal scores best.

    Each actor proposes an action for ``obs``, each critic scores every
    proposal, and each critic's scores are normalized by their sum. The
    proposal with the highest mean normalized score is returned.

    This copy previously stopped after computing the normalized score
    matrix: the result was never used and the function returned None.
    The selection and return steps are restored here.

    NOTE(review): the expand/concat pattern looks like it expects ``obs``
    to be a rank-2 tensor holding a single observation -- confirm with
    the caller.
    """
    # Collect one proposed action per actor, stacked along axis 0.
    actor_outputs = []
    for i in range(self.ensemble_num):
        actor_outputs.append(self.actors[i].predict(obs))
    batch_actions = layers.concat(actor_outputs, axis=0)
    # Repeat the observation so every proposal is paired with it.
    batch_obs = layers.expand(obs, expand_times=[self.ensemble_num, 1])

    # Column i of the score matrix holds critic i's score per proposal.
    critic_outputs = []
    for i in range(self.ensemble_num):
        critic_output = self.critics[i].predict(batch_obs, batch_actions)
        critic_output = layers.unsqueeze(critic_output, axes=[1])
        critic_outputs.append(critic_output)
    score_matrix = layers.concat(critic_outputs, axis=1)

    # Normalize scores given by each critic
    sum_critic_score = layers.reduce_sum(
        score_matrix, dim=0, keep_dim=True)
    sum_critic_score = layers.expand(
        sum_critic_score, expand_times=[self.ensemble_num, 1])
    norm_score_matrix = score_matrix / sum_critic_score

    # Average each proposal's normalized score over critics and select
    # the proposal with the highest mean.
    actions_mean_score = layers.reduce_mean(
        norm_score_matrix, dim=1, keep_dim=True)
    best_score_id = layers.argmax(actions_mean_score, axis=0)
    best_score_id = layers.cast(best_score_id, dtype='int32')
    ensemble_predict_action = layers.gather(batch_actions, best_score_id)
    ensemble_predict_action = layers.squeeze(
        ensemble_predict_action, axes=[0])
    return ensemble_predict_action
def predict(self, obs):
    """Compute the ``fc3`` output ("means") for ``obs``.

    ``obs`` is split along axis 1 at column
    ``self.obs_dim - self.vel_obs_dim``. The leading part goes through
    ``fc0``/``fc1`` and the trailing ``self.vel_obs_dim`` columns go
    through ``vel_fc0``/``vel_fc1``. Both results are concatenated and
    passed through ``fc2`` and ``fc3``.
    """
    # Leading columns of the observation.
    real_obs = layers.slice(
        obs,
        axes=[1],
        starts=[0],
        ends=[self.obs_dim - self.vel_obs_dim])
    # Trailing columns, addressed with a negative start index.
    vel_obs = layers.slice(
        obs,
        axes=[1],
        starts=[-self.vel_obs_dim],
        ends=[self.obs_dim])

    state_feat = self.fc1(self.fc0(real_obs))
    vel_feat = self.vel_fc1(self.vel_fc0(vel_obs))

    merged = layers.concat([state_feat, vel_feat], axis=1)
    return self.fc3(self.fc2(merged))
def policy(self, obs):
    """Compute the ``fc3`` output ("means") for ``obs``.

    ``obs`` is split along axis 1 into the leading columns and the
    trailing ``self.vel_obs_dim`` columns (the target-related features).
    Each part passes through its own two layers; the results are
    concatenated and fed through ``fc2`` and ``fc3``.
    """
    # A negative index counts from the end of axis 1.
    split_at = -self.vel_obs_dim
    real_obs = layers.slice(obs, axes=[1], starts=[0], ends=[split_at])
    # Target-related features.
    vel_obs = layers.slice(
        obs, axes=[1], starts=[split_at], ends=[self.obs_dim])

    state_feat = self.fc1(self.fc0(real_obs))
    vel_feat = self.vel_fc1(self.vel_fc0(vel_obs))

    merged = layers.concat([state_feat, vel_feat], axis=1)
    return self.fc3(self.fc2(merged))
def value(self, obs, act):
    """Compute the Q output for ``obs`` and ``act``.

    ``obs`` goes through ``fc1``, the result is concatenated with ``act``
    along axis 1, and the combination passes through ``fc2`` and ``fc3``.
    Axis 1 of the result is squeezed out before returning.
    """
    obs_feat = self.fc1(obs)
    joint = layers.concat([obs_feat, act], axis=1)
    q_value = self.fc3(self.fc2(joint))
    return layers.squeeze(q_value, axes=[1])