Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Construct the Q-function from `network`, forwarding `network_kwargs` as
# keyword arguments. `build_q_func` is defined outside this excerpt; the result
# is handed to `deepq.build_train` as its `q_func` argument.
q_func = build_q_func(network, **network_kwargs)
def make_obs_ph(name):
    """Return an ``ObservationInput`` over the Dota observation space.

    Parameters
    ----------
    name:
        Forwarded unchanged as the ``name`` keyword of ``ObservationInput``.
    """
    observation_space = DotaEnvironment.get_observation_space()
    return ObservationInput(observation_space, name=name)
# Call `deepq.build_train` under the 'deepq_train' scope, using an Adam
# optimizer with learning rate `lr`. Four values come back: the first is
# discarded, the remaining three are bound to `train`, `update_target` and
# `debug` for use further down.
# NOTE(review): the exact contract of the returned objects lives in
# `deepq.build_train`, which is not visible here — confirm before relying on it.
_, train, update_target, debug = deepq.build_train(
    scope='deepq_train',
    make_obs_ph=make_obs_ph,
    q_func=q_func,
    # Action count is read from the environment's action space (`.n`).
    num_actions=DotaEnvironment.get_action_space().n,
    optimizer=tf.train.AdamOptimizer(learning_rate=lr),
    gamma=gamma,
    grad_norm_clipping=10,)
# Pick the replay buffer implementation. `beta_schedule` is bound on both
# paths so later code can reference it unconditionally.
if not prioritized_replay:
    # Plain replay: no schedule is created.
    replay_buffer = ReplayBuffer(buffer_size)
    beta_schedule = None
else:
    replay_buffer = PrioritizedReplayBuffer(buffer_size, alpha=prioritized_replay_alpha)
    # When no explicit iteration count is given, fall back to the total
    # number of training timesteps.
    if prioritized_replay_beta_iters is None:
        prioritized_replay_beta_iters = total_timesteps
    # Schedule running from `prioritized_replay_beta0` to 1.0
    # (presumably linear annealing of beta — confirm against LinearSchedule).
    beta_schedule = LinearSchedule(
        prioritized_replay_beta_iters,
        initial_p=prioritized_replay_beta0,
        final_p=1.0,
    )
# Run `U.initialize()` and then invoke the `update_target` callable returned by
# `deepq.build_train` once, before any training happens.
# NOTE(review): presumably this initialises TF variables and syncs the target
# network with the online network — confirm against `U` and `deepq`.
U.initialize()
update_target()
# Create the action-advice reward shaper from `config`, then call its `load()`
# and `generate_merged_demo()` steps in that order.
reward_shaper = ActionAdviceRewardShaper(config=config)
reward_shaper.load()
reward_shaper.generate_merged_demo()
"""Set up a prioritized replay buffer.

Parameters
----------
size: int
    Maximum number of transitions kept in the buffer. Once the buffer
    overflows, the oldest memories are dropped.
alpha: float
    Degree of prioritization; must be non-negative
    (0 - no prioritization, 1 - full prioritization).

See Also
--------
ReplayBuffer.__init__
"""
super(PrioritizedReplayBuffer, self).__init__(size)
assert alpha >= 0
self._alpha = alpha
self._max_priority = 1.0

# Capacity shared by both segment trees: the smallest power of two that is
# at least `size` (never below 1).
it_capacity = 1
while it_capacity < size:
    it_capacity *= 2

self._it_sum = SumSegmentTree(it_capacity)
self._it_min = MinSegmentTree(it_capacity)