# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Trainer-process entry point: builds the deepq training graph and (in the
# missing remainder) would consume transitions from `updates_queue` and push
# updated weights to `weights_queue`.
#
# NOTE(review): this chunk is extraction-damaged — all indentation has been
# stripped, and the function body is truncated: the final
# `if prioritized_replay_beta_iters is None:` has no suite before the next
# `def` begins. Do not treat this block as runnable as-is; recover the full
# body from version control before editing logic.
def do_network_training(
updates_queue: multiprocessing.Queue,
weights_queue: multiprocessing.Queue,
network, seed, config, lr, total_timesteps, learning_starts,
buffer_size, exploration_fraction, exploration_initial_eps, exploration_final_eps,
train_freq, batch_size, print_freq, checkpoint_freq, gamma,
target_network_update_freq, prioritized_replay, prioritized_replay_alpha,
prioritized_replay_beta0, prioritized_replay_beta_iters,
prioritized_replay_eps, experiment_name, load_path, network_kwargs):
# Ensure a TF session exists for graph construction; the handle itself is unused.
_ = get_session()
set_global_seeds(seed)
# Build the Q-function constructor from the named network architecture.
q_func = build_q_func(network, **network_kwargs)
# Placeholder factory for observations. Uses the class-level observation
# space (not an env instance), so no environment object is captured here.
def make_obs_ph(name):
return ObservationInput(DotaEnvironment.get_observation_space(), name=name)
# Build only the training-side ops (the act function is discarded) under a
# dedicated variable scope so they don't collide with the actor process.
_, train, update_target, debug = deepq.build_train(
scope='deepq_train',
make_obs_ph=make_obs_ph,
q_func=q_func,
num_actions=DotaEnvironment.get_action_space().n,
optimizer=tf.train.AdamOptimizer(learning_rate=lr),
gamma=gamma,
grad_norm_clipping=10,)
# Optional prioritized experience replay buffer.
if prioritized_replay:
replay_buffer = PrioritizedReplayBuffer(buffer_size, alpha=prioritized_replay_alpha)
# NOTE(review): truncated here — the suite for this `if` (presumably
# defaulting prioritized_replay_beta_iters to total_timesteps, as in
# baselines deepq) is missing from this chunk; confirm against the original.
if prioritized_replay_beta_iters is None:
def do_agent_exploration(
updates_queue: multiprocessing.Queue,
q_func_vars_trained_queue: multiprocessing.Queue,
network, seed, config, lr, total_timesteps, learning_starts,
buffer_size, exploration_fraction, exploration_initial_eps, exploration_final_eps,
train_freq, batch_size, print_freq, checkpoint_freq, gamma,
target_network_update_freq, prioritized_replay, prioritized_replay_alpha,
prioritized_replay_beta0, prioritized_replay_beta_iters,
prioritized_replay_eps, experiment_name, load_path, network_kwargs):
env = DotaEnvironment()
sess = get_session()
set_global_seeds(seed)
q_func = build_q_func(network, **network_kwargs)
# capture the shape outside the closure so that the env object is not serialized
# by cloudpickle when serializing make_obs_ph
observation_space = env.observation_space
def make_obs_ph(name):
return ObservationInput(observation_space, name=name)
act, _, _, debug = deepq.build_train(
scope='deepq_act',
make_obs_ph=make_obs_ph,
q_func=q_func,
num_actions=env.action_space.n,
optimizer=tf.train.AdamOptimizer(learning_rate=lr),
gamma=gamma,
grad_norm_clipping=10, )