# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): this fragment is the tail of a conditional on `args.opponent`
# whose opening `if` lies outside this chunk. When a pretrained battle model
# is requested it is loaded; otherwise the opponent falls back to RandomActor.
models[0].load("data/battle_model", args.opponent)
else:
# port 20000, sample-buffer size 0: the random opponent needs no replay memory
models.append(magent.ProcessingModel(env, handles[1], names[1], 20000, 0, RandomActor))
# load our model: build the trainable side (handles[0]) for the chosen algorithm.
# Each ProcessingModel runs the network in a subprocess serving on port 20001.
if args.alg == 'dqn':
    from magent.builtin.tf_model import DeepQNetwork
    models.append(magent.ProcessingModel(env, handles[0], names[0], 20001, 1000, DeepQNetwork,
                                         batch_size=batch_size,
                                         learning_rate=3e-4,
                                         memory_size=2 ** 20, train_freq=train_freq, eval_obs=eval_obs[0]))
    step_batch_size = None
elif args.alg == 'drqn':
    from magent.builtin.tf_model import DeepRecurrentQNetwork
    # floor division: `/` yields a float on Python 3, but batch_size must be an int
    models.append(magent.ProcessingModel(env, handles[0], names[0], 20001, 1000, DeepRecurrentQNetwork,
                                         batch_size=batch_size // unroll_step, unroll_step=unroll_step,
                                         learning_rate=3e-4,
                                         memory_size=4 * 625, train_freq=train_freq, eval_obs=eval_obs[0]))
    step_batch_size = None
elif args.alg == 'a2c':
    from magent.builtin.mx_model import AdvantageActorCritic
    # cast to int: the 0.04 factor makes the product a float, but a step batch
    # size must be integral (matches the int(...) used by the sibling section)
    step_batch_size = int(10 * args.map_size * args.map_size * 0.04)
    models.append(magent.ProcessingModel(env, handles[0], names[0], 20001, 1000, AdvantageActorCritic,
                                         learning_rate=1e-3))
# load if
# Resume training from a saved checkpoint when --load_from is given;
# the checkpoint epoch is passed straight to Model.load.
savedir = 'save_model'
if args.load_from is not None:
start_from = args.load_from
print("load ... %d" % start_from)
models[0].load(savedir, start_from)
# Predator-prey setup: the deer side is driven by a random policy; only the
# tiger model (appended in the alg-selection block below) is trained.
models = [
RandomActor(env, deer_handle, tiger_handle),
]
# hyper-parameters for the tiger learner
batch_size = 512
unroll = 8
# build the trainable tiger model for the chosen algorithm
if args.alg == 'dqn':
    from magent.builtin.tf_model import DeepQNetwork
    models.append(DeepQNetwork(env, tiger_handle, "tiger",
                               batch_size=batch_size,
                               memory_size=2 ** 20, learning_rate=4e-4))
    step_batch_size = None
elif args.alg == 'drqn':
    from magent.builtin.tf_model import DeepRecurrentQNetwork
    # floor division: `/` yields a float on Python 3, but batch_size must be an int
    models.append(DeepRecurrentQNetwork(env, tiger_handle, "tiger",
                                        batch_size=batch_size // unroll, unroll_step=unroll,
                                        memory_size=20000, learning_rate=4e-4))
    step_batch_size = None
elif args.alg == 'a2c':
    from magent.builtin.mx_model import AdvantageActorCritic
    # A2C consumes fixed-size step batches scaled with the map area
    step_batch_size = int(10 * args.map_size * args.map_size * 0.01)
    models.append(AdvantageActorCritic(env, tiger_handle, "tiger",
                                       batch_size=step_batch_size,
                                       learning_rate=1e-2))
else:
    raise NotImplementedError
# load if
# NOTE(review): this fragment is cut off — the statements that actually call
# `load` on the models continue beyond this chunk.
savedir = 'save_model'
if args.load_from is not None:
start_from = args.load_from
# training hyper-parameters for the battle learner
batch_size = 512
unroll_step = 8
target_update = 1200
train_freq = 5

models = []
if args.alg == 'dqn':
    from magent.builtin.tf_model import DeepQNetwork
    models.append(DeepQNetwork(env, handles[0], args.name,
                               batch_size=batch_size,
                               learning_rate=3e-4,
                               memory_size=2 ** 21, target_update=target_update,
                               train_freq=train_freq, eval_obs=eval_obs))
elif args.alg == 'drqn':
    from magent.builtin.tf_model import DeepRecurrentQNetwork
    models.append(DeepRecurrentQNetwork(env, handles[0], args.name,
                                        learning_rate=3e-4,
                                        # floor division: `/` yields a float batch size on Python 3
                                        batch_size=batch_size // unroll_step, unroll_step=unroll_step,
                                        memory_size=2 * 8 * 625, target_update=target_update,
                                        train_freq=train_freq, eval_obs=eval_obs))
else:
    # see train_against.py to know how to use a2c
    raise NotImplementedError
# self-play: both sides share the same model instance
models.append(models[0])
# load if
savedir = 'save_model'
if args.load_from is not None:
start_from = args.load_from
print("load ... %d" % start_from)
# NOTE(review): the loop body is missing from this chunk — presumably each
# model is loaded from `savedir` at epoch `start_from`; verify downstream.
for model in models:
# init models: one learner per side, each served by its own subprocess
batch_size = 256
unroll_step = 8
target_update = 1200
train_freq = 5

if args.alg == 'dqn':
    from magent.builtin.tf_model import DeepQNetwork
    RLModel = DeepQNetwork
    base_args = {'batch_size': batch_size,
                 'memory_size': 2 ** 20, 'learning_rate': 1e-4,
                 'target_update': target_update, 'train_freq': train_freq}
elif args.alg == 'drqn':
    from magent.builtin.tf_model import DeepRecurrentQNetwork
    RLModel = DeepRecurrentQNetwork
    # floor division so batch_size stays an int under Python 3
    base_args = {'batch_size': batch_size // unroll_step, 'unroll_step': unroll_step,
                 'memory_size': 8 * 625, 'learning_rate': 1e-4,
                 'target_update': target_update, 'train_freq': train_freq}
else:
    # a2c is not supported here; see train_against.py to know how to use a2c.
    # Fail fast: previously an unknown --alg fell through this chain and
    # crashed later with a NameError on RLModel.
    raise NotImplementedError

# init models
names = [args.name + "-l", args.name + "-r"]
models = []
for i in range(len(names)):
    model_args = {'eval_obs': eval_obs[i]}
    model_args.update(base_args)
    # each ProcessingModel gets its own port (20000, 20001) to avoid collisions
    models.append(magent.ProcessingModel(env, handles[i], names[i], 20000 + i, 1000, RLModel, **model_args))
# init models: four learners (two groups on each side)
batch_size = 256
unroll_step = 8
target_update = 1000
train_freq = 5

if args.alg == 'dqn':
    from magent.builtin.tf_model import DeepQNetwork
    RLModel = DeepQNetwork
    base_args = {'batch_size': batch_size,
                 'memory_size': 2 ** 20,
                 'target_update': target_update, 'train_freq': train_freq}
elif args.alg == 'drqn':
    from magent.builtin.tf_model import DeepRecurrentQNetwork
    RLModel = DeepRecurrentQNetwork
    # floor division so batch_size stays an int under Python 3
    base_args = {'batch_size': batch_size // unroll_step, 'unroll_step': unroll_step,
                 'memory_size': 8 * 300,
                 'target_update': target_update, 'train_freq': train_freq}
elif args.alg == 'a2c':
    # see train_against.py to know how to use a2c
    raise NotImplementedError
else:
    raise NotImplementedError

names = [args.name + "-l0", args.name + "-l1", args.name + "-r0", args.name + "-r1"]
models = []
for i in range(len(names)):
    model_args = {'eval_obs': eval_obs[i]}
    model_args.update(base_args)
    # one subprocess per group, each on its own port (20000..20003)
    models.append(magent.ProcessingModel(env, handles[i], names[i], 20000 + i, 1000, RLModel, **model_args))
# sample observations for periodic evaluation.
# NOTE(review): each iteration passes the full `handles` list with identical
# arguments, so every eval_obs[i] receives the same sample — confirm whether
# `handles[i]` was intended.
for i in range(len(handles)):
    eval_obs[i] = magent.utility.sample_observation(env, handles, 2048, 500)

# init models
batch_size = 256
unroll_step = 8
target_update = 1200
train_freq = 5

if args.alg == 'dqn':
    # branch-local import, consistent with the sibling init sections
    from magent.builtin.tf_model import DeepQNetwork
    RLModel = DeepQNetwork
    base_args = {'batch_size': batch_size,
                 'memory_size': 2 ** 21, 'learning_rate': 1e-4,
                 'target_update': target_update, 'train_freq': train_freq}
elif args.alg == 'drqn':
    from magent.builtin.tf_model import DeepRecurrentQNetwork
    RLModel = DeepRecurrentQNetwork
    # floor division so batch_size stays an int under Python 3
    base_args = {'batch_size': batch_size // unroll_step, 'unroll_step': unroll_step,
                 'memory_size': 8 * 625, 'learning_rate': 1e-4,
                 'target_update': target_update, 'train_freq': train_freq}
elif args.alg == 'a2c':
    raise NotImplementedError
else:
    raise NotImplementedError

# init models
names = [args.name + "-l", args.name + "-r"]
models = []
for i in range(len(names)):
    model_args = {'eval_obs': eval_obs[i]}
    model_args.update(base_args)
    # give each model-server subprocess its own port: using 20000 for both
    # made the two ProcessingModel servers collide on the same port
    models.append(magent.ProcessingModel(env, handles[i], names[i], 20000 + i, 1000, RLModel, **model_args))