Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def setUpClass(cls):
    """Create the fixtures shared by every test in this class.

    Sets up three things:

    * ``cls.env`` is a ``CartPole-v0`` gym environment.
    * ``cls.replay_buffer`` is built by ``get_replay_buffer`` for a ``DQN``
      policy sized from that environment's observation and action spaces,
      with ``memory_capacity=2**4``.
    * ``cls.output_dir`` is a ``tests`` directory next to this file. It is
      created if missing.
    """
    cls.env = gym.make("CartPole-v0")
    policy = DQN(
        state_shape=cls.env.observation_space.shape,
        action_dim=cls.env.action_space.n,
        memory_capacity=2**4)
    cls.replay_buffer = get_replay_buffer(policy, cls.env)
    cls.output_dir = os.path.join(os.path.dirname(__file__), "tests")
    # exist_ok=True removes the race between a separate isdir() check and
    # makedirs(). A concurrent creator no longer triggers FileExistsError.
    # A non-directory at this path still raises, as before.
    os.makedirs(cls.output_dir, exist_ok=True)
def setUpClass(cls):
    """Build the categorical-DQN agent shared by this test class.

    Runs the parent class's setup first. That setup presumably provides
    ``discrete_env`` and ``batch_size``; confirm against the base class.
    Then stores on ``cls.agent`` a ``DQN`` built with:

    * ``enable_categorical_dqn=True``
    * ``epsilon=0.``
    * ``gpu=-1``
    """
    super().setUpClass()
    discrete_env = cls.discrete_env
    agent_kwargs = dict(
        state_shape=discrete_env.observation_space.shape,
        action_dim=discrete_env.action_space.n,
        batch_size=cls.batch_size,
        enable_categorical_dqn=True,
        epsilon=0.,
        gpu=-1,
    )
    cls.agent = DQN(**agent_kwargs)
def setUpClass(cls):
    """Build the double + dueling DQN agent shared by this test class.

    Runs the parent class's setup first. That setup presumably provides
    ``discrete_env`` and ``batch_size``; confirm against the base class.
    Then stores on ``cls.agent`` a ``DQN`` built with:

    * ``enable_double_dqn=True`` and ``enable_dueling_dqn=True``
    * ``epsilon=0.``
    * ``gpu=-1``
    """
    super().setUpClass()
    discrete_env = cls.discrete_env
    agent_kwargs = dict(
        state_shape=discrete_env.observation_space.shape,
        action_dim=discrete_env.action_space.n,
        batch_size=cls.batch_size,
        enable_double_dqn=True,
        enable_dueling_dqn=True,
        epsilon=0.,
        gpu=-1,
    )
    cls.agent = DQN(**agent_kwargs)
def setUpClass(cls):
    """Build the noisy-network DQN agent shared by this test class.

    Runs the parent class's setup first. That setup presumably provides
    ``discrete_env`` and ``batch_size``; confirm against the base class.
    Then stores on ``cls.agent`` a ``DQN`` built with:

    * ``enable_noisy_dqn=True``
    * ``epsilon=0.``
    * ``gpu=-1``
    """
    super().setUpClass()
    discrete_env = cls.discrete_env
    agent_kwargs = dict(
        state_shape=discrete_env.observation_space.shape,
        action_dim=discrete_env.action_space.n,
        batch_size=cls.batch_size,
        enable_noisy_dqn=True,
        epsilon=0.,
        gpu=-1,
    )
    cls.agent = DQN(**agent_kwargs)
if __name__ == '__main__':
    # Command-line interface: the generic Trainer options, extended with
    # the DQN-specific options.
    parser = DQN.get_argument(Trainer.get_argument())
    # Script-level defaults. Every one of these can still be overridden
    # from the command line.
    parser.set_defaults(
        test_interval=2000,
        max_steps=100000,
        gpu=-1,
        n_warmup=500,
        batch_size=32,
        memory_capacity=int(1e4),
    )
    parser.add_argument('--env-name', type=str, default="CartPole-v0")
    args = parser.parse_args()

    # Two independent instances of the same environment. The second is
    # handed to the Trainer as `test_env`, presumably for evaluation runs.
    env, test_env = gym.make(args.env_name), gym.make(args.env_name)

    policy = DQN(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        # The algorithm variants to enable come from the parsed CLI flags.
        enable_double_dqn=args.enable_double_dqn,
        enable_dueling_dqn=args.enable_dueling_dqn,
        enable_noisy_dqn=args.enable_noisy_dqn,
        enable_categorical_dqn=args.enable_categorical_dqn,
        target_replace_interval=300,
        discount=0.99,
        gpu=args.gpu,
        memory_capacity=args.memory_capacity,
        batch_size=args.batch_size,
        n_warmup=args.n_warmup,
    )
    trainer = Trainer(policy, env, args, test_env=test_env)
    trainer()
def call(self, inputs):
    """Compute per-action probability distributions over the return atoms.

    Args:
        inputs: Tensors concatenated along axis 1 to form the network
            input. Presumably a list such as ``[states]``; confirm with
            callers.

    Returns:
        A tensor of shape ``[batch_size, action_dim, n_atoms]``. It is a
        softmax over the atom axis, clipped to ``[1e-8, 1 - 1e-8]``.

    Raises:
        NotImplementedError: If the dueling architecture is enabled. This
            categorical head does not implement it.
    """
    features = tf.concat(inputs, axis=1)
    features = self.l1(features)
    features = self.l2(features)
    # The check stays after l1/l2 so the layers are called in the same
    # order as before, even on the error path.
    if self._enable_dueling_dqn:
        raise NotImplementedError(
            "Dueling architecture is not supported by the categorical "
            "Q-function")
    features = self.l3(features)  # [batch_size, action_dim * n_atoms]
    features = tf.reshape(
        features, (-1, self._action_dim, self._n_atoms))  # [batch_size, action_dim, n_atoms]
    features = tf.keras.activations.softmax(features, axis=2)  # normalize over atoms
    # Keep probabilities strictly inside (0, 1). Presumably this keeps any
    # downstream log() finite; verify against the loss code.
    return tf.clip_by_value(features, 1e-8, 1.0 - 1e-8)
class CategoricalDQN(DQN):
    """DQN variant that always uses ``CategoricalQFunc`` as its Q-function.

    It precomputes a fixed support of ``n_atoms`` evenly spaced values from
    ``-10.`` to ``10.``. The support is also kept tiled once per action.
    """

    def __init__(self, *args, **kwargs):
        # The categorical head is forced. Any ``q_func`` the caller passed
        # is overwritten.
        kwargs["q_func"] = CategoricalQFunc
        super().__init__(*args, **kwargs)
        self._v_max = 10.
        self._v_min = -10.
        n_atoms = self.q_func._n_atoms
        # Spacing between neighbouring support points.
        self._delta_z = (self._v_max - self._v_min) / (n_atoms - 1)
        v_min, delta = self._v_min, self._delta_z
        support = [v_min + idx * delta for idx in range(n_atoms)]
        self._z_list = tf.constant(support, dtype=tf.float64)
        # Support repeated per action: shape [action_dim, n_atoms].
        row = tf.reshape(self._z_list, [1, n_atoms])
        self._z_list_broadcasted = tf.tile(
            row, tf.constant([self._action_dim, 1]))

    def get_action(self, state, test=False):
        # Materialize LazyFrames into a NumPy array; other inputs pass
        # through unchanged.
        state = np.array(state) if isinstance(state, LazyFrames) else state
        assert isinstance(state, np.ndarray)
        # NOTE(review): as visible here, this method only normalizes and
        # validates `state`, then implicitly returns None. The
        # action-selection logic appears to be missing from this excerpt;
        # confirm against the full source.
def policy_fn(env, name, memory_capacity=int(1e6),
gpu=-1, noise_level=0.3):
return DQN(
name=name,
enable_double_dqn=args.enable_double_dqn,
enable_dueling_dqn=args.enable_dueling_dqn,
enable_noisy_dqn=args.enable_noisy_dqn,
enable_categorical_dqn=args.enable_categorical_dqn,
state_shape=env.observation_space.shape,
action_dim=env.action_space.n,
n_warmup=n_warmup,
target_replace_interval=target_replace_interval,
batch_size=batch_size,
memory_capacity=memory_capacity,
discount=0.99,
epsilon=1.,
epsilon_min=0.1,
epsilon_decay_step=epsilon_decay_rate,
optimizer=optimizer,