# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Optionally sample a fixed evaluation-observation set, then select the RL
# model class and its constructor arguments for the chosen algorithm.
if args.eval:
    print("sample eval set...")
    env.reset()
    generate_map(env, args.map_size, handles)
    for i in range(len(handles)):
        # NOTE(review): presumably sample_observation returns per-handle
        # observations; confirm whether an index into its result is needed.
        eval_obs[i] = magent.utility.sample_observation(env, handles, 2048, 500)

# load models
batch_size = 256
unroll_step = 8
target_update = 1200
train_freq = 5

if args.alg == 'dqn':
    from magent.builtin.tf_model import DeepQNetwork
    RLModel = DeepQNetwork
    base_args = {'batch_size': batch_size,
                 'memory_size': 2 ** 20, 'learning_rate': 1e-4,
                 'target_update': target_update, 'train_freq': train_freq}
elif args.alg == 'drqn':
    from magent.builtin.tf_model import DeepRecurrentQNetwork
    RLModel = DeepRecurrentQNetwork
    # Floor division: batch_size must be an int, and "/" would yield a
    # float under Python 3.
    base_args = {'batch_size': batch_size // unroll_step, 'unroll_step': unroll_step,
                 'memory_size': 8 * 625, 'learning_rate': 1e-4,
                 'target_update': target_update, 'train_freq': train_freq}
elif args.alg == 'a2c':
    # see train_against.py to know how to use a2c
    raise NotImplementedError

# init models
names = [args.name + "-l", args.name + "-r"]
models = []
# Build the models for the deer/tiger demo: deer act randomly, the tiger
# group is trained with the selected algorithm.
env.set_render_dir("build/render")

# two groups of animal
deer_handle, tiger_handle = env.get_handles()

# init two models
models = [
    RandomActor(env, deer_handle, tiger_handle),
]

batch_size = 512
unroll = 8

if args.alg == 'dqn':
    from magent.builtin.tf_model import DeepQNetwork
    models.append(DeepQNetwork(env, tiger_handle, "tiger",
                               batch_size=batch_size,
                               memory_size=2 ** 20, learning_rate=4e-4))
    step_batch_size = None
elif args.alg == 'drqn':
    from magent.builtin.tf_model import DeepRecurrentQNetwork
    # Floor division keeps batch_size an int; plain "/" would produce a
    # float under Python 3.
    models.append(DeepRecurrentQNetwork(env, tiger_handle, "tiger",
                                        batch_size=batch_size // unroll, unroll_step=unroll,
                                        memory_size=20000, learning_rate=4e-4))
    step_batch_size = None
elif args.alg == 'a2c':
    from magent.builtin.mx_model import AdvantageActorCritic
    # a2c consumes on-policy rollouts; the step batch scales with map area
    step_batch_size = int(10 * args.map_size * args.map_size * 0.01)
    models.append(AdvantageActorCritic(env, tiger_handle, "tiger",
                                       batch_size=step_batch_size,
                                       learning_rate=1e-2))
# NOTE(review): this "else:" is orphaned — its matching if/elif chain is not
# part of this fragment (file appears to be a concatenation of scripts).
else:
# Build a pre-trained battle demo: construct the environment, load the two
# "trusty-battle-game" DQN models, and store everything on self.
# NOTE(review): the body's indentation appears to have been lost in this
# paste; the statements below belong inside __init__.
def __init__(self, path="data/battle_model", total_step=1000, add_counter=10, add_interval=50):
# some parameter
map_size = 125
eps = 0.05  # exploration rate kept for later action sampling
# init the game
env = magent.GridWorld(load_config(map_size))
handles = env.get_handles()
models = []
models.append(DeepQNetwork(env, handles[0], 'trusty-battle-game-l', use_conv=True))
models.append(DeepQNetwork(env, handles[1], 'trusty-battle-game-r', use_conv=True))
# load model (checkpoint 0 of each side)
models[0].load(path, 0, 'trusty-battle-game-l')
models[1].load(path, 0, 'trusty-battle-game-r')
# init environment
env.reset()
generate_map(env, map_size, handles)
# save to member variable
self.env = env
self.handles = handles
self.eps = eps
self.models = models
self.map_size = map_size
self.total_step = total_step
# NOTE(review): fragment — the enclosing function (a checkpoint-directory
# scanner, likely extract_model_names) begins outside this view.
# keep only checkpoints whose step number is greater than `begin`
if match and int(match.group(1)) > begin:
ret.append((savedir, name, int(match.group(1)), model_class))
# sort by checkpoint step, then keep every `pick_every`-th entry
ret.sort(key=lambda x: x[2])
ret = [ret[i] for i in range(0, len(ret), pick_every)]
return ret
# Entry point: round-robin evaluation of saved battle models.
if __name__ == '__main__':
map_size = 125
env = magent.GridWorld("battle", map_size=map_size)
env.set_render_dir("build/render")
# scan file names for saved checkpoints (every 5th, starting at step 0)
model_name = extract_model_names('save_model', 'battle', DeepQNetwork, begin=0, pick_every=5)
print("total models = %d" % len(model_name))
print("models", [x[:-1] for x in model_name])
handles = env.get_handles()
# Launch one ProcessingModel per side, load its checkpoint, and play
# n_rounds games between them.
# NOTE(review): the body below is truncated and its indentation lost.
def play_wrapper(model_names, n_rounds):
time_stamp = time.time()
models = []
for i, item in enumerate(model_names):
models.append(magent.ProcessingModel(env, handles[i], item[1], 0, item[-1]))
for i, item in enumerate(model_names):
models[i].load(item[0], item[2])
leftID, rightID = 0, 1
# --- fragment: pursuit demo (render pre-trained predator vs prey) ---
# init the game "pursuit" (config file are stored in python/magent/builtin/config/)
env = magent.GridWorld("pursuit", map_size=map_size)
env.set_render_dir("build/render")
# get group handles
predator, prey = env.get_handles()
# init env and agents
env.reset()
# walls cover ~1% of cells, each agent group ~2%
env.add_walls(method="random", n=map_size * map_size * 0.01)
env.add_agents(predator, method="random", n=map_size * map_size * 0.02)
env.add_agents(prey, method="random", n=map_size * map_size * 0.02)
# init two models
model1 = DeepQNetwork(env, predator, "predator")
model2 = DeepQNetwork(env, prey, "prey")
# load trained model
model1.load("data/pursuit_model")
model2.load("data/pursuit_model")
done = False
step_ct = 0
print("nums: %d vs %d" % (env.get_num(predator), env.get_num(prey)))
# main loop: one iteration per environment step
while not done:
# take actions for deers
# NOTE(review): comment says "deers" but these are the predator agents
obs_1 = env.get_observation(predator)
ids_1 = env.get_agent_id(predator)
acts_1 = model1.infer_action(obs_1, ids_1)
env.set_action(predator, acts_1)
# NOTE(review): fragment truncated here — the rest of the loop body
# (prey actions, env.step, render) is missing from this view.
# Optionally sample a fixed evaluation-observation set, then build the single
# shared model used for self-play training.
eval_obs = None  # must exist even when --eval is off: it is passed below
if args.eval:
    print("sample eval set...")
    env.reset()
    generate_map(env, args.map_size, handles)
    eval_obs = magent.utility.sample_observation(env, handles, 2048, 500)[0]

# init models
batch_size = 512
unroll_step = 8
target_update = 1200
train_freq = 5

models = []
if args.alg == 'dqn':
    from magent.builtin.tf_model import DeepQNetwork
    models.append(DeepQNetwork(env, handles[0], args.name,
                               batch_size=batch_size,
                               learning_rate=3e-4,
                               memory_size=2 ** 21, target_update=target_update,
                               train_freq=train_freq, eval_obs=eval_obs))
elif args.alg == 'drqn':
    from magent.builtin.tf_model import DeepRecurrentQNetwork
    # Floor division: batch_size must be an int, and "/" would yield a
    # float under Python 3.
    models.append(DeepRecurrentQNetwork(env, handles[0], args.name,
                                        learning_rate=3e-4,
                                        batch_size=batch_size // unroll_step, unroll_step=unroll_step,
                                        memory_size=2 * 8 * 625, target_update=target_update,
                                        train_freq=train_freq, eval_obs=eval_obs))
else:
    # see train_against.py to know how to use a2c
    raise NotImplementedError

# self-play: both sides share the same model instance
models.append(models[0])
# Sample (or skip) per-group evaluation observations, then select the RL
# model class and constructor arguments for the chosen algorithm.
eval_obs = [None, None]
if args.eval:
    print("sample eval set...")
    env.reset()
    generate_map(env, args.map_size, handles)
    for i in range(len(handles)):
        eval_obs[i] = magent.utility.sample_observation(env, handles, 2048, 500)

# load models
batch_size = 256
unroll_step = 8
target_update = 1200
train_freq = 5

if args.alg == 'dqn':
    RLModel = DeepQNetwork
    base_args = {'batch_size': batch_size,
                 'memory_size': 2 ** 21, 'learning_rate': 1e-4,
                 'target_update': target_update, 'train_freq': train_freq}
elif args.alg == 'drqn':
    RLModel = DeepRecurrentQNetwork
    # Floor division keeps batch_size an int; "/" would produce a float
    # under Python 3.
    base_args = {'batch_size': batch_size // unroll_step, 'unroll_step': unroll_step,
                 'memory_size': 8 * 625, 'learning_rate': 1e-4,
                 'target_update': target_update, 'train_freq': train_freq}
elif args.alg == 'a2c':
    # a2c is not supported here; see train_against.py
    raise NotImplementedError
else:
    raise NotImplementedError

# init models
names = [args.name + "-l", args.name + "-r"]
models = []
# --- fragment: duplicate of the pursuit demo above ---
# init the game "pursuit" (config file are stored in python/magent/builtin/config/)
env = magent.GridWorld("pursuit", map_size=map_size)
env.set_render_dir("build/render")
# get group handles
predator, prey = env.get_handles()
# init env and agents
env.reset()
# walls cover ~1% of cells, each agent group ~2%
env.add_walls(method="random", n=map_size * map_size * 0.01)
env.add_agents(predator, method="random", n=map_size * map_size * 0.02)
env.add_agents(prey, method="random", n=map_size * map_size * 0.02)
# init two models
model1 = DeepQNetwork(env, predator, "predator")
model2 = DeepQNetwork(env, prey, "prey")
# load trained model
model1.load("data/pursuit_model")
model2.load("data/pursuit_model")
done = False
step_ct = 0
print("nums: %d vs %d" % (env.get_num(predator), env.get_num(prey)))
# main loop: one iteration per environment step
while not done:
# take actions for deers
# NOTE(review): comment says "deers" but these are the predator agents
obs_1 = env.get_observation(predator)
ids_1 = env.get_agent_id(predator)
acts_1 = model1.infer_action(obs_1, ids_1)
env.set_action(predator, acts_1)
# NOTE(review): fragment from an "args.eval" branch — the opening "if" is
# not part of this view, which orphans the "else:" below.
eval_obs = magent.utility.sample_observation(env, handles, n_obs=2048, step=500)
else:
eval_obs = [None, None]
# init models
names = [args.name + "-a", "battle"]
batch_size = 512
unroll_step = 16
train_freq = 5
models = []
# load opponent: a saved checkpoint if requested, otherwise a random policy
if args.opponent >= 0:
from magent.builtin.tf_model import DeepQNetwork
models.append(magent.ProcessingModel(env, handles[1], names[1], 20000, 0, DeepQNetwork))
models[0].load("data/battle_model", args.opponent)
else:
models.append(magent.ProcessingModel(env, handles[1], names[1], 20000, 0, RandomActor))
# load our model
if args.alg == 'dqn':
from magent.builtin.tf_model import DeepQNetwork
models.append(magent.ProcessingModel(env, handles[0], names[0], 20001, 1000, DeepQNetwork,
batch_size=batch_size,
learning_rate=3e-4,
memory_size=2 ** 20, train_freq=train_freq, eval_obs=eval_obs[0]))
step_batch_size = None
elif args.alg == 'drqn':
from magent.builtin.tf_model import DeepRecurrentQNetwork
models.append(magent.ProcessingModel(env, handles[0], names[0], 20001, 1000, DeepRecurrentQNetwork,
# Convert a self-play checkpoint ("selfplay" under data/) into a battle_model
# checkpoint. Usage: python <script> <round-number>
import sys

import magent
from magent.builtin.tf_model import DeepQNetwork

env = magent.GridWorld("battle", map_size=125)
handles = env.get_handles()

# SECURITY: parse the argument with int() rather than eval(), which would
# execute arbitrary code supplied on the command line.
rounds = int(sys.argv[1])

for i in [rounds]:
    model = DeepQNetwork(env, handles[0], "battle")
    print("load %d" % i)
    model.load("data/", i, "selfplay")
    print("save %d" % i)
    model.save("data/battle_model", i)