Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""Gather agent: rushes toward food according to the minimap."""
import numpy as np
from magent.model import BaseModel
from magent.c_lib import _LIB, as_int32_c_array, as_float_c_array
class RushGatherer(BaseModel):
def __init__(self, env, handle, *args, **kwargs):
    """Bind this rule-based gatherer to one agent group of ``env``.

    Parameters
    ----------
    env: magent.Environment
        environment queried for this group's action/view/attack metadata
    handle: handle of the agent group this model controls
    *args, **kwargs:
        accepted but unused by this constructor
    """
    BaseModel.__init__(self, env, handle)
    self.env, self.handle = env, handle

    # Cache per-group metadata queried once from the environment.
    self.n_action = env.get_action_space(handle)
    self.view_size = env.get_view_space(handle)

    base, mapping = env.get_view2attack(handle)
    self.attack_base = base
    self.view2attack = mapping
def infer_action(self, states, *args, **kwargs):
obs_buf = as_float_c_array(states[0])
hp_buf = as_float_c_array(states[1])
n, height, width, n_channel = states[0].shape
buf = np.empty((n,), dtype=np.int32)
act_buf = as_int32_c_array(buf)
attack_base = self.attack_base
import os
import mxnet as mx
from magent.utility import has_gpu
from magent.model import BaseModel
class MXBaseModel(BaseModel):
def __init__(self, env, handle, name, subclass_name):
    """Initialise the MXNet model wrapper.

    Parameters
    ----------
    env: magent.Environment
        environment, forwarded to ``BaseModel.__init__``
    handle: handle (ctypes.c_int32)
        agent-group handle, forwarded to ``BaseModel.__init__``
    name: str
        name of this model instance (stored as ``self.name``)
    subclass_name: str
        name of the concrete subclass (stored as ``self.subclass_name``)
    """
    BaseModel.__init__(self, env, handle)
    self.name, self.subclass_name = name, subclass_name
def _get_ctx(self):
import os
import tensorflow as tf
from magent.model import BaseModel
class TFBaseModel(BaseModel):
"""Base model for TensorFlow models."""
def __init__(self, env, handle, name, subclass_name):
    """Initialise the TensorFlow model wrapper.

    Parameters
    ----------
    env: magent.Environment
        environment, forwarded to ``BaseModel.__init__``
    handle: handle (ctypes.c_int32)
        agent-group handle, forwarded to ``BaseModel.__init__``
    name: str
        name of this model instance (stored as ``self.name``; ``save`` also
        uses it as the variable-collection scope and as a sub-directory name)
    subclass_name: str
        name of the concrete subclass (stored as ``self.subclass_name``)
    """
    BaseModel.__init__(self, env, handle)
    self.name, self.subclass_name = name, subclass_name
def save(self, dir_name, epoch):
    """Save this model's global variables as a TensorFlow checkpoint.

    The checkpoint prefix is ``<dir_name>/<self.name>/<self.subclass_name>_<epoch>``.

    Parameters
    ----------
    dir_name: str
        root save directory; it and any missing parent directories are
        created if they do not exist
    epoch: int
        epoch number embedded in the checkpoint file name (formatted with %d)
    """
    model_dir = os.path.join(dir_name, self.name)
    # makedirs (instead of mkdir per component) also creates missing parents,
    # so nested roots such as "save/run1" work.
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # NOTE(review): get_collection filters by re.match(scope, var.name), i.e. a
    # regex prefix match -- a model named "a" would also collect variables of a
    # model named "ab". Confirm model names cannot prefix one another.
    model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.name)
    saver = tf.train.Saver(model_vars)

    # Format via "%s" so a '%' inside subclass_name cannot corrupt the format.
    checkpoint_prefix = os.path.join(model_dir, "%s_%d" % (self.subclass_name, epoch))
    # self.sess is not assigned in this class's visible code; presumably set
    # by the subclass -- confirm.
    saver.save(self.sess, checkpoint_prefix)
"""deprecated"""
import ctypes
import numpy as np
from magent.model import BaseModel
from magent.c_lib import _LIB, as_int32_c_array, as_float_c_array
class RushPredator(BaseModel):
def __init__(self, env, handle, attack_handle, *args, **kwargs):
    """Bind this (deprecated) rule-based predator to one agent group.

    Parameters
    ----------
    env: magent.Environment
        environment queried for channel and attack metadata
    handle: handle of the agent group this model controls
    attack_handle: handle of the group whose channel is looked up
        (presumably the group to be attacked -- confirm)
    *args, **kwargs:
        accepted but unused by this constructor
    """
    BaseModel.__init__(self, env, handle)

    self.attack_channel = env.get_channel(attack_handle)
    base, mapping = env.get_view2attack(handle)
    self.attack_base = base
    self.view2attack = mapping

    # Debug output of the resolved configuration.
    print("attack_channel", self.attack_channel)
    print("view2attack", self.view2attack)
def infer_action(self, observations, *args, **kwargs):
obs_buf = as_float_c_array(observations[0])
hp_buf = as_float_c_array(observations[1])
n, height, width, n_channel = observations[0].shape
buf = np.empty((n,), dtype=np.int32)
act_buf = as_int32_c_array(buf)
attack_channel = self.attack_channel
"""deprecated"""
import numpy as np
from magent.model import BaseModel
from magent.c_lib import _LIB, as_int32_c_array, as_float_c_array
class RunawayPrey(BaseModel):
def __init__(self, env, handle, away_handle, *args, **kwargs):
    """Bind this (deprecated) rule-based prey to one agent group.

    Parameters
    ----------
    env: magent.Environment
        environment queried for channel and attack metadata
    handle: handle of the agent group this model controls
    away_handle: handle of the group whose channel is looked up
        (presumably the group to run away from -- confirm)
    *args, **kwargs:
        accepted but unused by this constructor
    """
    BaseModel.__init__(self, env, handle)

    self.away_channel = env.get_channel(away_handle)
    # Keep only the attack base; the view-to-attack mapping is discarded.
    base, _discarded_mapping = env.get_view2attack(handle)
    self.attack_base = base
    # NOTE(review): meaning of 4 is not visible here; presumably consumed by
    # infer_action -- confirm.
    self.move_back = 4

    print("attack base", self.attack_base, "away", self.away_channel)
def infer_action(self, observations, *args, **kwargs):
obs_buf = as_float_c_array(observations[0])
hp_buf = as_float_c_array(observations[1])
n, height, width, n_channel = observations[0].shape
buf = np.empty((n,), dtype=np.int32)
act_buf = as_int32_c_array(buf)
_LIB.runaway_infer_action(obs_buf, hp_buf, n, height, width, n_channel,