# unittest fixture: build a CPU-only SAC agent once for the whole test class.
@classmethod
def setUpClass(cls):
    super().setUpClass()
    cls.agent = SAC(
        state_shape=cls.continuous_env.observation_space.shape,
        action_dim=cls.continuous_env.action_space.low.size,
        batch_size=cls.batch_size,
        gpu=-1)
# NOTE: the enclosing function definition was missing from this snippet; the
# name `make_policies` and its (env, name, tf2rl) signature are assumptions
# inferred from the body.
def make_policies(env, name, tf2rl=False):
    dim_state, dim_action = env.observation_space.shape[0], env.action_space.shape[0]
    max_action = env.action_space.high

    with tf.device("/gpu:0"):
        if name == "DDPG":
            if tf2rl:
                policy = DDPG_tf2rl(state_shape=(dim_state,), action_dim=dim_action,
                                    max_action=max_action[0], max_grad=1.)
                saved_policy = DDPG_tf2rl(state_shape=(dim_state,), action_dim=dim_action,
                                          max_action=max_action[0], max_grad=1.)
            else:
                policy = DDPG.DDPG(dim_state, dim_action, max_action, training=True)
                saved_policy = DDPG.DDPG(dim_state, dim_action, max_action, training=False)
        elif name == "SAC":
            if tf2rl:
                policy = SAC_tf2rl(state_shape=(dim_state,), action_dim=dim_action, max_action=max_action[0])
                saved_policy = SAC_tf2rl(state_shape=(dim_state,), action_dim=dim_action, max_action=max_action[0])
            else:
                policy = SAC.SAC(dim_state, dim_action, max_action, training=True)
                saved_policy = SAC.SAC(dim_state, dim_action, max_action, training=False)
        else:
            raise ValueError("invalid policy")
    return policy, saved_policy
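# Usage sketch for the factory above. Hedged: `make_policies` is the name
# assumed for the reconstructed function, "Pendulum-v0" is only an
# illustrative continuous-control environment, and the SAC_tf2rl / DDPG_tf2rl
# aliases are expected to be imported elsewhere in the original file.
import gym

demo_env = gym.make("Pendulum-v0")
demo_policy, demo_saved_policy = make_policies(demo_env, name="SAC", tf2rl=True)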
def get_argument(parser=None):
    parser = SAC.get_argument(parser)
    parser.add_argument('--target-update-interval', type=int, default=None)
    return parser
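# Sketch of how the argument helper above is typically composed with the
# Trainer parser. Hedged: the tf2rl.algos.sac_discrete module path and
# "CartPole-v0" are assumptions, chosen to mirror the run scripts below.
from tf2rl.algos.sac_discrete import SACDiscrete
from tf2rl.experiments.trainer import Trainer

parser = Trainer.get_argument()
parser = SACDiscrete.get_argument(parser)
parser.add_argument('--env-name', type=str, default="CartPole-v0")
args = parser.parse_args([])  # parse defaults only, for illustration
print(args.target_update_interval)  # None unless --target-update-interval is given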
import roboschool
import gym

from tf2rl.algos.sac import SAC
from tf2rl.experiments.trainer import Trainer

if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = SAC.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="RoboschoolAnt-v1")
    parser.set_defaults(batch_size=100)
    parser.set_defaults(n_warmup=10000)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = SAC(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.high.size,
        gpu=args.gpu,
        memory_capacity=args.memory_capacity,
        max_action=env.action_space.high[0],
        batch_size=args.batch_size,
        n_warmup=args.n_warmup,
        auto_alpha=args.auto_alpha)
    trainer = Trainer(policy, env, args, test_env=test_env)
    trainer()
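# Example invocation of the script above (hedged: the file name run_sac.py and
# the --gpu flag spelling are assumptions; --env-name is defined in the script).
#   python run_sac.py --env-name RoboschoolAnt-v1 --gpu -1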
        # Tail of CriticQ.__init__: the last hidden layer and the output layer,
        # which emits one Q-value per discrete action. The output layer name is
        # corrected to "L3" so it does not collide with the previous layer.
        self.l2 = Dense(critic_units[1], name="L2", activation='relu')
        self.l3 = Dense(action_dim, name="L3", activation='linear')

        dummy_state = tf.constant(
            np.zeros(shape=(1,) + state_shape, dtype=np.float32))
        self(dummy_state)

    def call(self, states):
        features = self.l1(states)
        features = self.l2(features)
        values = self.l3(features)
        return values
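# Shape sketch for the critic above: for discrete SAC the network maps states
# directly to one Q-value per action (Q: S -> R^|A|), so calling it on a batch
# of states yields shape (batch_size, action_dim). Hedged: the constructor
# signature is not shown in the truncated snippet, so the call below is only
# indicative.
#   qf = CriticQ(state_shape=(4,), action_dim=2)
#   qf(tf.zeros((32, 4))).shape  # -> (32, 2)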
class SACDiscrete(SAC):
    def __init__(
            self,
            state_shape,
            action_dim,
            *args,
            actor_fn=None,
            critic_fn=None,
            target_update_interval=None,
            **kwargs):
        kwargs["name"] = "SAC_discrete"
        self.actor_fn = actor_fn if actor_fn is not None else CategoricalActor
        self.critic_fn = critic_fn if critic_fn is not None else CriticQ
        self.target_hard_update = target_update_interval is not None
        self.target_update_interval = target_update_interval
        self.n_training = tf.Variable(0, dtype=tf.int32)
        super().__init__(state_shape, action_dim, *args, **kwargs)
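# Usage sketch for SACDiscrete on a discrete-action env. Hedged: it mirrors
# the continuous run script above; "CartPole-v0", the tf2rl.algos.sac_discrete
# module path, and the exact set of constructor kwargs are assumptions.
import gym
from tf2rl.algos.sac_discrete import SACDiscrete
from tf2rl.experiments.trainer import Trainer

if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = SACDiscrete.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="CartPole-v0")
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = SACDiscrete(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,  # number of discrete actions
        gpu=args.gpu,
        target_update_interval=args.target_update_interval)
    trainer = Trainer(policy, env, args, test_env=test_env)
    trainer()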