How to use the magent.model.BaseModel class in magent

To help you get started, we’ve selected a few magent examples based on popular ways BaseModel is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: geek-ai/MAgent — python/magent/builtin/rule_model/rushgather.py (view on GitHub)
"""gather agent, rush to food according to minimap"""

import numpy as np

from magent.model import BaseModel
from magent.c_lib import _LIB, as_int32_c_array, as_float_c_array


class RushGatherer(BaseModel):
    def __init__(self, env, handle, *args, **kwargs):
        """Build a rule-based gatherer bound to one agent group.

        Parameters
        ----------
        env: magent.Environment
        handle: handle (ctypes.c_int32)
            group handle forwarded to the environment's per-group queries
        *args, **kwargs:
            accepted for interface compatibility and ignored
        """
        BaseModel.__init__(self, env, handle)

        # keep explicit references to the environment and group handle
        self.env, self.handle = env, handle

        # per-group metadata, queried once from the environment
        self.n_action = env.get_action_space(handle)
        self.view_size = env.get_view_space(handle)
        base, mapping = env.get_view2attack(handle)
        self.attack_base = base  # read back by infer_action
        self.view2attack = mapping

    def infer_action(self, states, *args, **kwargs):
        obs_buf = as_float_c_array(states[0])
        hp_buf  = as_float_c_array(states[1])
        n, height, width, n_channel = states[0].shape
        buf = np.empty((n,), dtype=np.int32)
        act_buf = as_int32_c_array(buf)
        attack_base = self.attack_base
GitHub: geek-ai/MAgent — python/magent/builtin/mx_model/base.py (view on GitHub)
import os
import mxnet as mx

from magent.utility import has_gpu
from magent.model import BaseModel


class MXBaseModel(BaseModel):
    def __init__(self, env, handle, name, subclass_name):
        """Initialize the state shared by MXNet-backed models.

        Parameters
        ----------
        env: magent.Environment
            environment the model is attached to
        handle: handle (ctypes.c_int32)
            group handle, forwarded to BaseModel
        name: str
            identifier of this model
        subclass_name: str
            name of the concrete subclass
        """
        BaseModel.__init__(self, env, handle)
        self.name, self.subclass_name = name, subclass_name

    def _get_ctx(self):
GitHub: geek-ai/MAgent — python/magent/builtin/tf_model/base.py (view on GitHub)
import os
import tensorflow as tf

from magent.model import BaseModel


class TFBaseModel(BaseModel):
    """base model for tensorflow model

    Stores the model name and subclass name that ``save`` uses to locate the
    model's variables and to build checkpoint paths. ``save`` reads
    ``self.sess``, which this class does not set; presumably a subclass
    assigns the TensorFlow session before saving.
    """
    def __init__(self, env, handle, name, subclass_name):
        """init a model

        Parameters
        ----------
        env: magent.Environment
        handle: handle (ctypes.c_int32)
        name: str
            model name; ``save`` uses it as the variable-collection scope and
            as the checkpoint sub-directory
        subclass_name: str
            name of subclass; ``save`` uses it as the checkpoint file prefix
        """
        BaseModel.__init__(self, env, handle)
        self.name = name
        self.subclass_name = subclass_name

    def save(self, dir_name, epoch):
        """save model to dir

        Saves the global variables collected under the ``self.name`` scope
        to the checkpoint prefix ``<dir_name>/<self.name>/<subclass_name>_<epoch>``.

        Parameters
        ----------
        dir_name: str
            root directory; it and the per-model sub-directory are created
            (including any missing parents) when absent
        epoch: int
            epoch number appended to the checkpoint prefix

        Raises
        ------
        OSError
            if the model directory cannot be created
        """
        model_dir = os.path.join(dir_name, self.name)
        # Create the whole path in one call and tolerate "already exists".
        # This avoids the check-then-create race of exists()/mkdir() and
        # also works when parent directories are missing.
        try:
            os.makedirs(model_dir)
        except OSError:
            if not os.path.isdir(model_dir):
                raise
        model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.name)
        saver = tf.train.Saver(model_vars)
        # Format only the epoch: a '%' inside subclass_name must not be
        # interpreted as a format directive.
        prefix = self.subclass_name + "_%d" % epoch
        saver.save(self.sess, os.path.join(model_dir, prefix))
GitHub: geek-ai/MAgent — python/magent/builtin/rule_model/rush.py (view on GitHub)
"""deprecated"""

import ctypes
import numpy as np

from magent.model import BaseModel
from magent.c_lib import _LIB, as_int32_c_array, as_float_c_array


class RushPredator(BaseModel):
    def __init__(self, env, handle, attack_handle, *args, **kwargs):
        """Set up the (deprecated) rule-based predator model.

        Parameters
        ----------
        env: magent.Environment
        handle: handle (ctypes.c_int32)
            group controlled by this model
        attack_handle:
            handle of another group; resolved to a view channel through
            ``env.get_channel``
        *args, **kwargs:
            accepted for interface compatibility and ignored
        """
        BaseModel.__init__(self, env, handle)

        channel = env.get_channel(attack_handle)
        self.attack_channel = channel
        base, mapping = env.get_view2attack(handle)
        self.attack_base = base
        self.view2attack = mapping

        # debug output retained from the original implementation
        print("attack_channel", channel)
        print("view2attack", mapping)

    def infer_action(self, observations, *args, **kwargs):
        obs_buf = as_float_c_array(observations[0])
        hp_buf  = as_float_c_array(observations[1])
        n, height, width, n_channel = observations[0].shape
        buf = np.empty((n,), dtype=np.int32)
        act_buf = as_int32_c_array(buf)
        attack_channel = self.attack_channel
GitHub: geek-ai/MAgent — python/magent/builtin/rule_model/runaway.py (view on GitHub)
"""deprecated"""

import numpy as np

from magent.model import BaseModel
from magent.c_lib import _LIB, as_int32_c_array, as_float_c_array


class RunawayPrey(BaseModel):
    def __init__(self, env, handle, away_handle, *args, **kwargs):
        """Set up the (deprecated) rule-based prey model.

        Parameters
        ----------
        env: magent.Environment
        handle: handle (ctypes.c_int32)
            group controlled by this model
        away_handle:
            handle of another group; resolved to a view channel through
            ``env.get_channel``
        *args, **kwargs:
            accepted for interface compatibility and ignored
        """
        BaseModel.__init__(self, env, handle)

        self.away_channel = env.get_channel(away_handle)
        base, _unused_view2attack = env.get_view2attack(handle)
        self.attack_base = base
        # NOTE(review): meaning of this constant is not visible here; it is
        # presumably consumed by the inference rule — verify before changing
        self.move_back = 4

        print("attack base", self.attack_base, "away", self.away_channel)

    def infer_action(self, observations, *args, **kwargs):
        obs_buf = as_float_c_array(observations[0])
        hp_buf  = as_float_c_array(observations[1])
        n, height, width, n_channel = observations[0].shape
        buf = np.empty((n,), dtype=np.int32)
        act_buf = as_int32_c_array(buf)
        _LIB.runaway_infer_action(obs_buf, hp_buf, n, height, width, n_channel,