Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_action(self, state, test=False):
    """Return an action and its log-probability for ``state``.

    ``state`` may be a ``LazyFrames`` (converted to an array first) or an
    ``np.ndarray``. An array whose rank equals ``self._state_ndim`` is
    treated as one unbatched observation: it gets a leading axis and is cast
    to float32 before being handed to ``self._get_action_body``. Arrays of
    any other rank are passed through unchanged.

    Args:
        state: Observation(s) to act on.
        test: Forwarded as-is to ``self._get_action_body``.

    Returns:
        ``(action, logp)`` as NumPy values. For an unbatched input only the
        action is indexed with ``[0]``; ``logp`` is returned as produced.
    """
    if isinstance(state, LazyFrames):
        state = np.array(state)
    assert isinstance(state, np.ndarray), \
        "Input instance should be np.ndarray, not {}".format(type(state))

    unbatched = state.ndim == self._state_ndim
    model_input = state
    if unbatched:
        # Add the batch axis the action body expects.
        model_input = np.expand_dims(state, axis=0).astype(np.float32)

    action, logp, _ = self._get_action_body(model_input, test)
    action_np = action.numpy()
    if unbatched:
        action_np = action_np[0]
    return action_np, logp.numpy()
# Epsilon-greedy action selection; only the random-exploration branch is visible here.
# NOTE(review): this snippet has lost its indentation and appears to end mid-function
# (no greedy path or final return is shown) -- confirm against the full source.
def get_action(self, state, test=False, tensor=False):
# Convert LazyFrames to an ndarray before inspecting the input.
if isinstance(state, LazyFrames):
state = np.array(state)
# With discrete inputs, anything that is not an ndarray counts as a single state.
if self.discrete_input:
is_single_input = not isinstance(state, np.ndarray)
else:
# The ndarray check is skipped when ``tensor`` is truthy.
if not tensor:
assert isinstance(state, np.ndarray)
# Presumably still inside the ``else`` above (nesting is not recoverable here):
# a single input is one whose rank equals ``self._state_ndim``.
is_single_input = state.ndim == self._state_ndim
# Explore only outside test mode, when a uniform draw falls below ``self.epsilon``.
if not test and np.random.rand() < self.epsilon:
if is_single_input:
action = np.random.randint(self._action_dim)
else:
# One independent random action per entry along the first axis of ``state``.
action = np.array([np.random.randint(self._action_dim)
for _ in range(state.shape[0])], dtype=np.int64)
# When ``tensor`` is truthy the action is returned as a TensorFlow tensor.
if tensor:
return tf.convert_to_tensor(action)
def _get_ob(self):
    """Wrap the currently buffered frames in a ``LazyFrames`` object.

    Asserts that ``self.frames`` holds exactly ``self.k`` entries before
    wrapping a list copy of them.
    """
    frame_count = len(self.frames)
    assert frame_count == self.k
    buffered = list(self.frames)
    return LazyFrames(buffered)
def get_action_and_val(self, state, test=False):
    """Return action, log-probability and value estimate for ``state``.

    A ``LazyFrames`` input is converted to an array first. An input whose
    rank equals ``self._state_ndim`` is treated as one unbatched observation:
    it gets a leading axis and a float32 cast before the call to
    ``self._get_action_logp_v_body``. The value and action are then indexed
    with ``[0]``, while ``logp`` is left as produced.

    Args:
        state: Observation(s) to evaluate.
        test: Forwarded as-is to ``self._get_action_logp_v_body``.

    Returns:
        ``(action, logp, v)`` converted with ``.numpy()``.
    """
    if isinstance(state, LazyFrames):
        state = np.array(state)

    unbatched = state.ndim == self._state_ndim
    if unbatched:
        state = np.expand_dims(state, axis=0).astype(np.float32)

    action, logp, value = self._get_action_logp_v_body(state, test)
    if unbatched:
        # Strip the batch axis that was added above.
        value, action = value[0], action[0]

    return action.numpy(), logp.numpy(), value.numpy()
def get_action(self, state, test=False):
    """Pick an action for a single ``state`` with epsilon-greedy exploration.

    A ``LazyFrames`` input is converted to an array first, and the input must
    then be an ``np.ndarray``. Outside test mode, a uniform draw below
    ``self.epsilon`` yields a random action in ``[0, self._action_dim)``.
    Otherwise the state is batched, cast to float64 and passed to
    ``self._get_action_body``. The result is multiplied by
    ``self._z_list_broadcasted``, summed over axis 2, and the argmax over
    axis 1 is taken (presumably the action with the highest expected return;
    confirm against the model definition).

    Args:
        state: One unbatched observation.
        test: When truthy, exploration is disabled and no random draw is made.

    Returns:
        The chosen action index.
    """
    if isinstance(state, LazyFrames):
        state = np.array(state)
    assert isinstance(state, np.ndarray)

    # Short-circuit keeps the RNG untouched in test mode.
    if not test and np.random.rand() < self.epsilon:
        return np.random.randint(self._action_dim)

    batched = np.expand_dims(state, axis=0).astype(np.float64)
    action_probs = self._get_action_body(tf.constant(batched))
    weighted = tf.reduce_sum(action_probs * self._z_list_broadcasted, axis=2)
    best = tf.argmax(weighted, axis=1)
    return best.numpy()[0]
def step(self, action):
    """Step the wrapped environment and return the observation as an array.

    Forwards ``action`` to ``self.env.step``, asserts that the observation it
    returns is a ``LazyFrames`` instance, and converts it with ``np.array``.

    Returns:
        ``(observation_array, rew, done, env_info)``; the last three items are
        passed through unchanged from the wrapped environment.
    """
    transition = self.env.step(action)
    next_obs, rew, done, env_info = transition
    assert isinstance(next_obs, LazyFrames)
    observation = np.array(next_obs)
    return observation, rew, done, env_info