How to use the magent.builtin.tf_model.DeepQNetwork function in magent

To help you get started, we’ve selected a few magent examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Example from the geek-ai/MAgent repository: examples/train_battle.py (view on GitHub).
# NOTE(review): excerpt from the middle of examples/train_battle.py --
# `args`, `env`, `handles` and `eval_obs` are defined in unseen code above,
# and the first line lost its indentation during extraction.
if args.eval:
        print("sample eval set...")
        env.reset()
        generate_map(env, args.map_size, handles)
        # draw a fixed observation sample per group for later evaluation
        for i in range(len(handles)):
            eval_obs[i] = magent.utility.sample_observation(env, handles, 2048, 500)

    # load models
    # shared hyper-parameters for the algorithm choices below
    batch_size = 256
    unroll_step = 8
    target_update = 1200
    train_freq = 5

    if args.alg == 'dqn':
        from magent.builtin.tf_model import DeepQNetwork
        RLModel = DeepQNetwork
        base_args = {'batch_size': batch_size,
                     'memory_size': 2 ** 20, 'learning_rate': 1e-4,
                     'target_update': target_update, 'train_freq': train_freq}
    elif args.alg == 'drqn':
        from magent.builtin.tf_model import DeepRecurrentQNetwork
        RLModel = DeepRecurrentQNetwork
        # NOTE(review): under Python 3, batch_size / unroll_step is true
        # division (256 / 8 == 32.0, a float); `//` is presumably intended --
        # confirm what DeepRecurrentQNetwork accepts as batch_size.
        base_args = {'batch_size': batch_size / unroll_step, 'unroll_step': unroll_step,
                     'memory_size': 8 * 625, 'learning_rate': 1e-4,
                     'target_update': target_update, 'train_freq': train_freq}
    elif args.alg == 'a2c':
        # see train_against.py to know how to use a2c
        raise NotImplementedError

    # init models
    # one model per side of the battle: "<name>-l" vs "<name>-r"
    names = [args.name + "-l", args.name + "-r"]
    models = []
Example from the geek-ai/MAgent repository: examples/train_tiger.py (view on GitHub).
# NOTE(review): excerpt from the middle of examples/train_tiger.py --
# `env`, `args` and `RandomActor` come from unseen code above, and the
# first line lost its indentation during extraction.
env.set_render_dir("build/render")

    # two groups of animal
    deer_handle, tiger_handle = env.get_handles()

    # init two models
    # deer are driven by a random policy; only the tiger model is trained
    models = [
        RandomActor(env, deer_handle, tiger_handle),
    ]

    batch_size = 512
    unroll     = 8

    if args.alg == 'dqn':
        from magent.builtin.tf_model import DeepQNetwork
        models.append(DeepQNetwork(env, tiger_handle, "tiger",
                                   batch_size=batch_size,
                                   memory_size=2 ** 20, learning_rate=4e-4))
        step_batch_size = None
    elif args.alg == 'drqn':
        from magent.builtin.tf_model import DeepRecurrentQNetwork
        # NOTE(review): batch_size/unroll is true division on Python 3
        # (512 / 8 == 64.0, a float); `//` is presumably intended.
        models.append(DeepRecurrentQNetwork(env, tiger_handle, "tiger",
                                   batch_size=batch_size/unroll, unroll_step=unroll,
                                   memory_size=20000, learning_rate=4e-4))
        step_batch_size = None
    elif args.alg == 'a2c':
        from magent.builtin.mx_model import AdvantageActorCritic
        # A2C steps with a batch size scaled to the map area
        step_batch_size = int(10 * args.map_size * args.map_size*0.01)
        models.append(AdvantageActorCritic(env, tiger_handle, "tiger",
                                   batch_size=step_batch_size,
                                   learning_rate=1e-2))
    # excerpt is truncated here: the body of this else is not shown
    else:
Example from the geek-ai/MAgent repository: python/magent/renderer/server/battle_server.py (view on GitHub).
# NOTE(review): __init__ of the battle render server; the enclosing class
# header is not shown, and the method may continue past this excerpt
# (add_counter and add_interval are accepted but not stored in these lines).
def __init__(self, path="data/battle_model", total_step=1000, add_counter=10, add_interval=50):
        # some parameter
        map_size = 125
        # presumably the epsilon-greedy exploration rate -- confirm at use site
        eps = 0.05

        # init the game
        env = magent.GridWorld(load_config(map_size))

        handles = env.get_handles()
        models = []
        # one pre-trained DQN per side of the battle
        models.append(DeepQNetwork(env, handles[0], 'trusty-battle-game-l', use_conv=True))
        models.append(DeepQNetwork(env, handles[1], 'trusty-battle-game-r', use_conv=True))

        # load model
        # checkpoint index 0 from `path` for each side
        models[0].load(path, 0, 'trusty-battle-game-l')
        models[1].load(path, 0, 'trusty-battle-game-r')

        # init environment
        env.reset()
        generate_map(env, map_size, handles)

        # save to member variable
        self.env = env
        self.handles = handles
        self.eps = eps
        self.models = models
        self.map_size = map_size
        self.total_step = total_step
Example from the geek-ai/MAgent repository: scripts/tournament.py (view on GitHub).
# NOTE(review): excerpt from scripts/tournament.py.  The first two lines are
# the tail of a loop inside a checkpoint-scanning helper; the surrounding
# for-loop, `match`, `ret`, `savedir`, `name`, `model_class`, `begin` and
# `pick_every` are defined in unseen code above.
if match and int(match.group(1)) > begin:
            # each entry is (savedir, name, round_number, model_class)
            ret.append((savedir, name, int(match.group(1)), model_class))

    # sort by round number, then keep every `pick_every`-th checkpoint
    ret.sort(key=lambda x: x[2])
    ret = [ret[i] for i in range(0, len(ret), pick_every)]

    return ret


if __name__ == '__main__':
    map_size = 125
    env = magent.GridWorld("battle", map_size=map_size)
    env.set_render_dir("build/render")

    # scan file names
    model_name = extract_model_names('save_model', 'battle', DeepQNetwork, begin=0, pick_every=5)

    print("total models = %d" % len(model_name))
    print("models", [x[:-1] for x in model_name])
    handles = env.get_handles()

    # presumably plays the two selected checkpoints against each other for
    # n_rounds games; the body is truncated in this excerpt
    def play_wrapper(model_names, n_rounds):
        time_stamp = time.time()

        # one model server process per side (item[1] is the model name,
        # item[-1] the model class)
        models = []
        for i, item in enumerate(model_names):
            models.append(magent.ProcessingModel(env, handles[i], item[1], 0, item[-1]))

        # load each side's checkpoint: item[0] is savedir, item[2] the round
        for i, item in enumerate(model_names):
            models[i].load(item[0], item[2])

        leftID, rightID = 0, 1
Example from the geek-ai/MAgent repository: examples/api_demo.py (view on GitHub).
# init the game "pursuit"  (config files are stored in python/magent/builtin/config/)
# NOTE(review): excerpt from examples/api_demo.py; `map_size` is defined in
# unseen code above and the excerpt is truncated inside the while loop.
    env = magent.GridWorld("pursuit", map_size=map_size)
    env.set_render_dir("build/render")

    # get group handles
    predator, prey = env.get_handles()

    # init env and agents
    # NOTE(review): `n` is a float here (map_size**2 * 0.01); add_walls /
    # add_agents presumably coerce it to int -- confirm against GridWorld.
    env.reset()
    env.add_walls(method="random", n=map_size * map_size * 0.01)
    env.add_agents(predator, method="random", n=map_size * map_size * 0.02)
    env.add_agents(prey,     method="random", n=map_size * map_size * 0.02)

    # init two models
    model1 = DeepQNetwork(env, predator, "predator")
    model2 = DeepQNetwork(env, prey,     "prey")

    # load trained model
    model1.load("data/pursuit_model")
    model2.load("data/pursuit_model")

    done = False
    step_ct = 0
    print("nums: %d vs %d" % (env.get_num(predator), env.get_num(prey)))
    # main stepping loop
    while not done:
        # take actions for predators (original comment said "deers")
        obs_1 = env.get_observation(predator)
        ids_1 = env.get_agent_id(predator)
        acts_1 = model1.infer_action(obs_1, ids_1)
        env.set_action(predator, acts_1)
Example from the geek-ai/MAgent repository: examples/train_single.py (view on GitHub).
# NOTE(review): excerpt from the middle of examples/train_single.py --
# `args`, `env` and `handles` come from unseen code above.  If args.eval is
# false, `eval_obs` must already be bound earlier, otherwise the uses below
# would raise NameError -- confirm in the full file.
if args.eval:
        print("sample eval set...")
        env.reset()
        generate_map(env, args.map_size, handles)
        # keep only group 0's observation sample
        eval_obs = magent.utility.sample_observation(env, handles, 2048, 500)[0]

    # init models
    batch_size = 512
    unroll_step = 8
    target_update = 1200
    train_freq = 5

    models = []
    if args.alg == 'dqn':
        from magent.builtin.tf_model import DeepQNetwork
        models.append(DeepQNetwork(env, handles[0], args.name,
                                   batch_size=batch_size,
                                   learning_rate=3e-4,
                                   memory_size=2 ** 21, target_update=target_update,
                                   train_freq=train_freq, eval_obs=eval_obs))
    elif args.alg == 'drqn':
        from magent.builtin.tf_model import DeepRecurrentQNetwork
        # NOTE(review): batch_size/unroll_step is true division on Python 3
        # (512 / 8 == 64.0, a float); `//` is presumably intended.
        models.append(DeepRecurrentQNetwork(env, handles[0], args.name,
                                   learning_rate=3e-4,
                                   batch_size=batch_size/unroll_step, unroll_step=unroll_step,
                                   memory_size=2 * 8 * 625, target_update=target_update,
                                   train_freq=train_freq, eval_obs=eval_obs))
    else:
        # see train_against.py to know how to use a2c
        raise NotImplementedError

    # the same model object fills both slots (self-play with shared weights)
    models.append(models[0])
Example from the geek-ai/MAgent repository: examples/train_battle_game.py (view on GitHub).
# NOTE(review): excerpt from the middle of examples/train_battle_game.py --
# `args`, `env`, `handles`, DeepQNetwork and DeepRecurrentQNetwork are bound
# in unseen code above; the first line lost its indentation.
eval_obs = [None, None]
    if args.eval:
        print("sample eval set...")
        env.reset()
        generate_map(env, args.map_size, handles)
        # one fixed evaluation sample per group
        for i in range(len(handles)):
            eval_obs[i] = magent.utility.sample_observation(env, handles, 2048, 500)

    # load models
    batch_size = 256
    unroll_step = 8
    target_update = 1200
    train_freq = 5

    if args.alg == 'dqn':
        RLModel = DeepQNetwork
        base_args = {'batch_size': batch_size,
                     'memory_size': 2 ** 21, 'learning_rate': 1e-4,
                     'target_update': target_update, 'train_freq': train_freq}
    elif args.alg == 'drqn':
        RLModel = DeepRecurrentQNetwork
        # NOTE(review): batch_size / unroll_step is true division on
        # Python 3 (256 / 8 == 32.0, a float); `//` is presumably intended.
        base_args = {'batch_size': batch_size / unroll_step, 'unroll_step': unroll_step,
                     'memory_size': 8 * 625, 'learning_rate': 1e-4,
                     'target_update': target_update, 'train_freq': train_freq}
    elif args.alg == 'a2c':
        # a2c is not wired up in this script (see train_against.py)
        raise NotImplementedError
    else:
        raise NotImplementedError

    # init models
    # one model name per side of the battle
    names = [args.name + "-l", args.name + "-r"]
    models = []
Example from the geek-ai/MAgent repository: examples/api_demo.py (view on GitHub).
# init the game "pursuit"  (config files are stored in python/magent/builtin/config/)
# NOTE(review): this excerpt duplicates the earlier api_demo.py snippet;
# `map_size` is defined in unseen code above and the loop is truncated.
    env = magent.GridWorld("pursuit", map_size=map_size)
    env.set_render_dir("build/render")

    # get group handles
    predator, prey = env.get_handles()

    # init env and agents
    env.reset()
    env.add_walls(method="random", n=map_size * map_size * 0.01)
    env.add_agents(predator, method="random", n=map_size * map_size * 0.02)
    env.add_agents(prey,     method="random", n=map_size * map_size * 0.02)

    # init two models
    model1 = DeepQNetwork(env, predator, "predator")
    model2 = DeepQNetwork(env, prey,     "prey")

    # load trained model
    model1.load("data/pursuit_model")
    model2.load("data/pursuit_model")

    done = False
    step_ct = 0
    print("nums: %d vs %d" % (env.get_num(predator), env.get_num(prey)))
    while not done:
        # take actions for predators (original comment said "deers")
        obs_1 = env.get_observation(predator)
        ids_1 = env.get_agent_id(predator)
        acts_1 = model1.infer_action(obs_1, ids_1)
        env.set_action(predator, acts_1)

        # take actions for prey (original comment said "tigers")
Example from the geek-ai/MAgent repository: examples/train_against.py (view on GitHub).
# NOTE(review): excerpt from examples/train_against.py -- the first line is
# the body of an `if args.eval:` branch whose header is not shown; `args`,
# `env`, `handles` and `RandomActor` come from unseen code above.
eval_obs = magent.utility.sample_observation(env, handles, n_obs=2048, step=500)
    else:
        eval_obs = [None, None]

    # init models
    names = [args.name + "-a", "battle"]
    batch_size = 512
    unroll_step = 16
    train_freq = 5

    models = []

    # load opponent
    # NOTE(review): ProcessingModel runs the model out-of-process; 20000 and
    # 20001 below are presumably communication ports -- confirm its signature.
    if args.opponent >= 0:
        from magent.builtin.tf_model import DeepQNetwork
        models.append(magent.ProcessingModel(env, handles[1], names[1], 20000, 0, DeepQNetwork))
        models[0].load("data/battle_model", args.opponent)
    else:
        models.append(magent.ProcessingModel(env, handles[1], names[1], 20000, 0, RandomActor))

    # load our model
    if args.alg == 'dqn':
        from magent.builtin.tf_model import DeepQNetwork
        models.append(magent.ProcessingModel(env, handles[0], names[0], 20001, 1000, DeepQNetwork,
                                   batch_size=batch_size,
                                   learning_rate=3e-4,
                                   memory_size=2 ** 20, train_freq=train_freq, eval_obs=eval_obs[0]))

        step_batch_size = None
    elif args.alg == 'drqn':
        from magent.builtin.tf_model import DeepRecurrentQNetwork
        # excerpt is truncated in the middle of this call
        models.append(magent.ProcessingModel(env, handles[0], names[0], 20001, 1000, DeepRecurrentQNetwork,
Example from the geek-ai/MAgent repository: scripts/rename.py (view on GitHub).
import sys

import magent
from magent.builtin.tf_model import DeepQNetwork

# Re-save a "selfplay" checkpoint under the "battle" name: load round N from
# data/ and write it to data/battle_model.

env = magent.GridWorld("battle", map_size=125)

handles = env.get_handles()

# The round index comes straight from the command line.  The original used
# eval(sys.argv[1]), which executes arbitrary code from an untrusted
# argument; int() covers the same numeric usage safely and raises a clear
# ValueError for anything else.
rounds = int(sys.argv[1])

for i in [rounds]:
    model = DeepQNetwork(env, handles[0], "battle")
    print("load %d" % i)
    model.load("data/", i, "selfplay")
    print("save %d" % i)
    model.save("data/battle_model", i)