How to use the magent.builtin.tf_model.DeepRecurrentQNetwork function in magent

To help you get started, we've selected a few magent examples that show popular ways DeepRecurrentQNetwork is used in public projects.

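Before looking at the snippets, here is a minimal sketch of constructing a DeepRecurrentQNetwork directly. The built-in "battle" config, the map size, the group name "drqn-demo", and the hyperparameter values are illustrative assumptions; the keyword arguments mirror the ones used throughout the examples below.

import magent
from magent.builtin.tf_model import DeepRecurrentQNetwork

# assumed setup: built-in "battle" config with an arbitrary map size
env = magent.GridWorld("battle", map_size=125)
handle = env.get_handles()[0]               # first agent group

batch_size = 256
unroll_step = 8                             # length of each recurrent unroll

model = DeepRecurrentQNetwork(
    env, handle, "drqn-demo",
    batch_size=batch_size // unroll_step,   # the examples divide the batch size by the unroll length
    unroll_step=unroll_step,
    memory_size=8 * 625,
    learning_rate=1e-4,
    target_update=1200,
    train_freq=5,
)

# a previously saved checkpoint can be restored the same way the examples do:
# model.load("save_model", start_epoch)

Several of the snippets below pass the same keyword arguments through magent.ProcessingModel, a wrapper that runs the model in a separate process.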

geek-ai/MAgent · examples/train_against.py
        models[0].load("data/battle_model", args.opponent)
    else:
        models.append(magent.ProcessingModel(env, handles[1], names[1], 20000, 0, RandomActor))

    # load our model
    if args.alg == 'dqn':
        from magent.builtin.tf_model import DeepQNetwork
        models.append(magent.ProcessingModel(env, handles[0], names[0], 20001, 1000, DeepQNetwork,
                                   batch_size=batch_size,
                                   learning_rate=3e-4,
                                   memory_size=2 ** 20, train_freq=train_freq, eval_obs=eval_obs[0]))
                                   
        step_batch_size = None
    elif args.alg == 'drqn':
        from magent.builtin.tf_model import DeepRecurrentQNetwork
        models.append(magent.ProcessingModel(env, handles[0], names[0], 20001, 1000, DeepRecurrentQNetwork,
                                   batch_size=batch_size // unroll_step, unroll_step=unroll_step,
                                   learning_rate=3e-4,
                                   memory_size=4 * 625, train_freq=train_freq, eval_obs=eval_obs[0]))
        step_batch_size = None
    elif args.alg == 'a2c':
        from magent.builtin.mx_model import AdvantageActorCritic
        step_batch_size = int(10 * args.map_size * args.map_size * 0.04)
        models.append(magent.ProcessingModel(env, handles[0], names[0], 20001, 1000, AdvantageActorCritic,
                                             learning_rate=1e-3))

    # load a saved model if specified
    savedir = 'save_model'
    if args.load_from is not None:
        start_from = args.load_from
        print("load ... %d" % start_from)
        models[0].load(savedir, start_from)
geek-ai/MAgent · examples/train_tiger.py
    models = [
        RandomActor(env, deer_handle, tiger_handle),
    ]

    batch_size = 512
    unroll     = 8

    if args.alg == 'dqn':
        from magent.builtin.tf_model import DeepQNetwork
        models.append(DeepQNetwork(env, tiger_handle, "tiger",
                                   batch_size=batch_size,
                                   memory_size=2 ** 20, learning_rate=4e-4))
        step_batch_size = None
    elif args.alg == 'drqn':
        from magent.builtin.tf_model import DeepRecurrentQNetwork
        models.append(DeepRecurrentQNetwork(env, tiger_handle, "tiger",
                                   batch_size=batch_size // unroll, unroll_step=unroll,
                                   memory_size=20000, learning_rate=4e-4))
        step_batch_size = None
    elif args.alg == 'a2c':
        from magent.builtin.mx_model import AdvantageActorCritic
        step_batch_size = int(10 * args.map_size * args.map_size * 0.01)
        models.append(AdvantageActorCritic(env, tiger_handle, "tiger",
                                   batch_size=step_batch_size,
                                   learning_rate=1e-2))
    else:
        raise NotImplementedError

    # load a saved model if specified
    savedir = 'save_model'
    if args.load_from is not None:
        start_from = args.load_from
geek-ai/MAgent · examples/train_single.py
    batch_size = 512
    unroll_step = 8
    target_update = 1200
    train_freq = 5

    models = []
    if args.alg == 'dqn':
        from magent.builtin.tf_model import DeepQNetwork
        models.append(DeepQNetwork(env, handles[0], args.name,
                                   batch_size=batch_size,
                                   learning_rate=3e-4,
                                   memory_size=2 ** 21, target_update=target_update,
                                   train_freq=train_freq, eval_obs=eval_obs))
    elif args.alg == 'drqn':
        from magent.builtin.tf_model import DeepRecurrentQNetwork
        models.append(DeepRecurrentQNetwork(env, handles[0], args.name,
                                   learning_rate=3e-4,
                                   batch_size=batch_size // unroll_step, unroll_step=unroll_step,
                                   memory_size=2 * 8 * 625, target_update=target_update,
                                   train_freq=train_freq, eval_obs=eval_obs))
    else:
        # see train_against.py for how to use a2c
        raise NotImplementedError

    models.append(models[0])  # reuse the same model instance for the second handle

    # load a saved model if specified
    savedir = 'save_model'
    if args.load_from is not None:
        start_from = args.load_from
        print("load ... %d" % start_from)
        for model in models:
geek-ai/MAgent · examples/train_battle.py
    # load models
    batch_size = 256
    unroll_step = 8
    target_update = 1200
    train_freq = 5

    if args.alg == 'dqn':
        from magent.builtin.tf_model import DeepQNetwork
        RLModel = DeepQNetwork
        base_args = {'batch_size': batch_size,
                     'memory_size': 2 ** 20, 'learning_rate': 1e-4,
                     'target_update': target_update, 'train_freq': train_freq}
    elif args.alg == 'drqn':
        from magent.builtin.tf_model import DeepRecurrentQNetwork
        RLModel = DeepRecurrentQNetwork
        base_args = {'batch_size': batch_size // unroll_step, 'unroll_step': unroll_step,
                     'memory_size': 8 * 625, 'learning_rate': 1e-4,
                     'target_update': target_update, 'train_freq': train_freq}
    elif args.alg == 'a2c':
        # see train_against.py for how to use a2c
        raise NotImplementedError

    # init models
    names = [args.name + "-l", args.name + "-r"]
    models = []

    for i in range(len(names)):
        model_args = {'eval_obs': eval_obs[i]}
        model_args.update(base_args)
        models.append(magent.ProcessingModel(env, handles[i], names[i], 20000+i, 1000, RLModel, **model_args))
geek-ai/MAgent · examples/train_multi.py
    # load models
    batch_size = 256
    unroll_step = 8
    target_update = 1000
    train_freq = 5

    if args.alg == 'dqn':
        from magent.builtin.tf_model import DeepQNetwork
        RLModel = DeepQNetwork
        base_args = {'batch_size': batch_size,
                     'memory_size': 2 ** 20,
                     'target_update': target_update, 'train_freq': train_freq}
    elif args.alg == 'drqn':
        from magent.builtin.tf_model import DeepRecurrentQNetwork
        RLModel = DeepRecurrentQNetwork
        base_args = {'batch_size': batch_size // unroll_step, 'unroll_step': unroll_step,
                     'memory_size': 8 * 300,
                     'target_update': target_update, 'train_freq': train_freq}
    elif args.alg == 'a2c':
        raise NotImplementedError
    else:
        raise NotImplementedError

    # load models
    names = [args.name + "-l0", args.name + "-l1", args.name + "-r0", args.name + "-r1"]
    models = []

    for i in range(len(names)):
        model_args = {'eval_obs': eval_obs[i]}
        model_args.update(base_args)
        models.append(magent.ProcessingModel(env, handles[i], names[i], 20000+i, 1000, RLModel, **model_args))
geek-ai/MAgent · examples/train_battle_game.py
        for i in range(len(handles)):
            eval_obs[i] = magent.utility.sample_observation(env, handles, 2048, 500)

    # load models
    batch_size = 256
    unroll_step = 8
    target_update = 1200
    train_freq = 5

    if args.alg == 'dqn':
        from magent.builtin.tf_model import DeepQNetwork
        RLModel = DeepQNetwork
        base_args = {'batch_size': batch_size,
                     'memory_size': 2 ** 21, 'learning_rate': 1e-4,
                     'target_update': target_update, 'train_freq': train_freq}
    elif args.alg == 'drqn':
        from magent.builtin.tf_model import DeepRecurrentQNetwork
        RLModel = DeepRecurrentQNetwork
        base_args = {'batch_size': batch_size // unroll_step, 'unroll_step': unroll_step,
                     'memory_size': 8 * 625, 'learning_rate': 1e-4,
                     'target_update': target_update, 'train_freq': train_freq}
    elif args.alg == 'a2c':
        raise NotImplementedError
    else:
        raise NotImplementedError

    # init models
    names = [args.name + "-l", args.name + "-r"]
    models = []

    for i in range(len(names)):
        model_args = {'eval_obs': eval_obs[i]}
        model_args.update(base_args)
        models.append(magent.ProcessingModel(env, handles[i], names[i], 20000, 1000, RLModel, **model_args))