Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_is_integer(case, expected):
    """Check that ``_.is_integer`` reports ``expected`` for the input ``case``."""
    outcome = _.is_integer(case)
    assert outcome == expected
# Fragment of a recurrent network's constructor (the enclosing `def` is not
# visible here). Builds an optional fc body, the recurrent model, one or more
# output tails, then initialises layers, the loss function and device placement.
# fc body: skipped when no fc hidden layers are configured, in which case the
# input (in_dim wide) feeds the RNN directly
if ps.is_empty(self.fc_hid_layers):
    self.rnn_input_dim = self.in_dim
else:
    fc_dims = [self.in_dim] + self.fc_hid_layers
    self.fc_model = net_util.build_fc_model(fc_dims, self.hid_layers_activation)
    # the RNN consumes the output of the last fc layer
    self.rnn_input_dim = fc_dims[-1]
# RNN model
# NOTE(review): net_util.get_nn_name presumably maps cell_type to a torch.nn
# recurrent class name (e.g. 'LSTM', 'GRU') looked up on `nn` — confirm in net_util
self.rnn_model = getattr(nn, net_util.get_nn_name(self.cell_type))(
    input_size=self.rnn_input_dim,
    hidden_size=self.rnn_hidden_size,
    num_layers=self.rnn_num_layers,
    batch_first=True, bidirectional=self.bidirectional)
# tails. avoid list for single-tail for compute speed
# NOTE(review): tails take rnn_hidden_size inputs; with bidirectional=True the
# per-step RNN output is 2 * hidden_size wide, so this assumes forward() feeds
# the tails a single-direction hidden state — confirm
if ps.is_integer(self.out_dim):
    # single output head
    self.model_tail = net_util.build_fc_model([self.rnn_hidden_size, self.out_dim], self.out_layer_activation)
else:
    # one tail per entry of self.out_dim; a non-list activation is repeated
    # so that every tail gets one
    if not ps.is_list(self.out_layer_activation):
        # NOTE(review): bare `out_dim` is presumably the constructor argument
        # (not visible here) while the rest uses self.out_dim — confirm they match
        self.out_layer_activation = [self.out_layer_activation] * len(out_dim)
    assert len(self.out_layer_activation) == len(self.out_dim)
    tails = []
    for out_d, out_activ in zip(self.out_dim, self.out_layer_activation):
        tail = net_util.build_fc_model([self.rnn_hidden_size, out_d], out_activ)
        tails.append(tail)
    # nn.ModuleList (not a plain list) so the tails' parameters are registered
    self.model_tails = nn.ModuleList(tails)
# layer initialisation, loss function, device placement, then training mode
net_util.init_layers(self, self.init_fn)
self.loss_fn = net_util.get_loss_fn(self, self.loss_spec)
self.to(self.device)
self.train()
def get_policy_out_dim(body):
    '''Return the policy network out_dim for ``body``.

    The result depends on ``body.action_dim``, ``body.is_discrete`` and, for
    discrete bodies, ``body.action_type``:

    - discrete, ``action_type == 'multi_discrete'``: ``action_dim`` must be a
      list and is returned unchanged;
    - discrete, any other action type: ``action_dim`` must be an integer and
      is returned unchanged;
    - continuous with ``action_dim == 1``: returns ``2`` (one loc, one scale);
    - continuous otherwise: returns ``[action_dim, action_dim]`` (locs, scales).

    An ``assert`` (message: the offending ``action_dim``) guards the expected
    type in each branch.
    '''
    dim = body.action_dim
    if not body.is_discrete:
        assert ps.is_integer(dim), dim
        # single continuous action -> [loc, scale]; several -> [locs], [scales]
        return 2 if dim == 1 else [dim, dim]
    if body.action_type == 'multi_discrete':
        assert ps.is_list(dim), dim
    else:
        assert ps.is_integer(dim), dim
    # discrete bodies use action_dim as-is
    return dim
])
# Tail of a conv network's constructor: the enclosing `def` and the statement
# closed by the `])` line start before this fragment and are not visible here.
# conv body
self.conv_model = self.build_conv_layers(self.conv_hid_layers)
# presumably the flattened size of the conv output (the fc body below is built
# "from flattened conv") — confirm in get_conv_output_size
self.conv_out_dim = self.get_conv_output_size()
# fc body
if ps.is_empty(self.fc_hid_layers):
    # no fc hidden layers configured: tails read the conv output directly
    tail_in_dim = self.conv_out_dim
else:
    # fc body from flattened conv
    self.fc_model = net_util.build_fc_model([self.conv_out_dim] + self.fc_hid_layers, self.hid_layers_activation)
    # tails read the output of the last fc hidden layer
    tail_in_dim = self.fc_hid_layers[-1]
# tails. avoid list for single-tail for compute speed
if ps.is_integer(self.out_dim):
    # single output head
    self.model_tail = net_util.build_fc_model([tail_in_dim, self.out_dim], self.out_layer_activation)
else:
    # one tail per entry of self.out_dim; a non-list activation is repeated
    # so that every tail gets one
    if not ps.is_list(self.out_layer_activation):
        # NOTE(review): bare `out_dim` is presumably the constructor argument
        # (not visible here) while the rest uses self.out_dim — confirm they match
        self.out_layer_activation = [self.out_layer_activation] * len(out_dim)
    assert len(self.out_layer_activation) == len(self.out_dim)
    tails = []
    for out_d, out_activ in zip(self.out_dim, self.out_layer_activation):
        tail = net_util.build_fc_model([tail_in_dim, out_d], out_activ)
        tails.append(tail)
    # nn.ModuleList (not a plain list) so the tails' parameters are registered
    self.model_tails = nn.ModuleList(tails)
# layer initialisation, loss function, device placement, then training mode
net_util.init_layers(self, self.init_fn)
self.loss_fn = net_util.get_loss_fn(self, self.loss_spec)
self.to(self.device)
self.train()
])
# Tail of a conv network's constructor: the enclosing `def` and the statement
# closed by the `])` line start before this fragment and are not visible here.
# conv body
self.conv_model = self.build_conv_layers(self.conv_hid_layers)
# presumably the flattened size of the conv output (the fc body below is built
# "from flattened conv") — confirm in get_conv_output_size
self.conv_out_dim = self.get_conv_output_size()
# fc body
if ps.is_empty(self.fc_hid_layers):
    # no fc hidden layers configured: tails read the conv output directly
    tail_in_dim = self.conv_out_dim
else:
    # fc body from flattened conv
    self.fc_model = net_util.build_fc_model([self.conv_out_dim] + self.fc_hid_layers, self.hid_layers_activation)
    # tails read the output of the last fc hidden layer
    tail_in_dim = self.fc_hid_layers[-1]
# tails. avoid list for single-tail for compute speed
if ps.is_integer(self.out_dim):
    # single output head
    self.model_tail = net_util.build_fc_model([tail_in_dim, self.out_dim], self.out_layer_activation)
else:
    # one tail per entry of self.out_dim; a non-list activation is repeated
    # so that every tail gets one
    if not ps.is_list(self.out_layer_activation):
        # NOTE(review): bare `out_dim` is presumably the constructor argument
        # (not visible here) while the rest uses self.out_dim — confirm they match
        self.out_layer_activation = [self.out_layer_activation] * len(out_dim)
    assert len(self.out_layer_activation) == len(self.out_dim)
    tails = []
    for out_d, out_activ in zip(self.out_dim, self.out_layer_activation):
        tail = net_util.build_fc_model([tail_in_dim, out_d], out_activ)
        tails.append(tail)
    # nn.ModuleList (not a plain list) so the tails' parameters are registered
    self.model_tails = nn.ModuleList(tails)
# layer initialisation, loss function, device placement, then training mode
net_util.init_layers(self, self.init_fn)
self.loss_fn = net_util.get_loss_fn(self, self.loss_spec)
self.to(self.device)
self.train()
def get_policy_out_dim(body):
    '''Return the policy network out_dim for ``body``.

    The result depends on ``body.action_dim``, ``body.is_discrete`` and, for
    discrete bodies, ``body.action_type``:

    - discrete, ``action_type == 'multi_discrete'``: ``action_dim`` must be a
      list and is returned unchanged;
    - discrete, any other action type: ``action_dim`` must be an integer and
      is returned unchanged;
    - continuous with ``action_dim == 1``: returns ``2`` (one loc, one scale);
    - continuous otherwise: returns ``[action_dim, action_dim]`` (locs, scales).

    An ``assert`` (message: the offending ``action_dim``) guards the expected
    type in each branch.
    '''
    dim = body.action_dim
    if not body.is_discrete:
        assert ps.is_integer(dim), dim
        # single continuous action -> [loc, scale]; several -> [locs], [scales]
        return 2 if dim == 1 else [dim, dim]
    if body.action_type == 'multi_discrete':
        assert ps.is_list(dim), dim
    else:
        assert ps.is_integer(dim), dim
    # discrete bodies use action_dim as-is
    return dim
'init_fn',
'clip_grad_val',
'loss_spec',
'optim_spec',
'lr_scheduler_spec',
'update_type',
'update_frequency',
'polyak_coef',
'gpu',
])
# The lines above close a list of spec key names whose opening (and the
# enclosing constructor `def`) is not visible here.
# NOTE(review): presumably these keys are read from a net spec and set as
# attributes on self — confirm against the enclosing constructor.
# fc body: in_dim followed by the configured hidden layer sizes
dims = [self.in_dim] + self.hid_layers
self.model = net_util.build_fc_model(dims, self.hid_layers_activation)
# output layer(s) are added as the tails below, built with self.out_layer_activation
# tails. avoid list for single-tail for compute speed
if ps.is_integer(self.out_dim):
    # single output head reading the last hidden layer
    self.model_tail = net_util.build_fc_model([dims[-1], self.out_dim], self.out_layer_activation)
else:
    # one tail per entry of self.out_dim; a non-list activation is repeated
    # so that every tail gets one
    if not ps.is_list(self.out_layer_activation):
        # NOTE(review): bare `out_dim` is presumably the constructor argument
        # (not visible here) while the rest uses self.out_dim — confirm they match
        self.out_layer_activation = [self.out_layer_activation] * len(out_dim)
    assert len(self.out_layer_activation) == len(self.out_dim)
    tails = []
    for out_d, out_activ in zip(self.out_dim, self.out_layer_activation):
        tail = net_util.build_fc_model([dims[-1], out_d], out_activ)
        tails.append(tail)
    # nn.ModuleList (not a plain list) so the tails' parameters are registered
    self.model_tails = nn.ModuleList(tails)
# layer initialisation, loss function, device placement, then training mode
net_util.init_layers(self, self.init_fn)
self.loss_fn = net_util.get_loss_fn(self, self.loss_spec)
self.to(self.device)
self.train()
])
# Part of a conv network's constructor: the enclosing `def` and the statement
# closed by the `])` line start before this fragment, and the constructor may
# continue past its last line.
# conv layer
self.conv_model = self.build_conv_layers(self.conv_hid_layers)
self.conv_out_dim = self.get_conv_output_size()
# fc layer
if not ps.is_empty(self.fc_hid_layers):
    # fc layer from flattened conv
    self.fc_model = self.build_fc_layers(self.fc_hid_layers)
    # tails read the output of the last fc hidden layer
    tail_in_dim = self.fc_hid_layers[-1]
else:
    # no fc hidden layers configured: tails read the conv output directly
    tail_in_dim = self.conv_out_dim
# tails. avoid list for single-tail for compute speed
# (here each tail is a plain nn.Linear, i.e. no output activation)
if ps.is_integer(self.out_dim):
    self.model_tail = nn.Linear(tail_in_dim, self.out_dim)
else:
    self.model_tails = nn.ModuleList([nn.Linear(tail_in_dim, out_d) for out_d in self.out_dim])
net_util.init_layers(self, self.init_fn)
# move every module to self.device; self.modules() also yields self itself
for module in self.modules():
    module.to(self.device)
# loss, optimizer and learning-rate scheduler are all built from their specs
self.loss_fn = net_util.get_loss_fn(self, self.loss_spec)
self.optim = net_util.get_optim(self, self.optim_spec)
self.lr_scheduler = net_util.get_lr_scheduler(self, self.lr_scheduler_spec)