Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
###############################################################################
def eval_model_multithread(pred, nr_eval, get_player_fn):
    """
    Evaluate a predictor with ``eval_with_funcs`` and log the resulting
    score and distance statistics.

    Args:
        pred (OfflinePredictor): state -> Qvalue
        nr_eval: forwarded unchanged to ``eval_with_funcs``.
        get_player_fn: forwarded unchanged to ``eval_with_funcs``.
    """
    # Use half of the available cores, capped at 8, but never fewer than one:
    # ``cpu_count() // 2`` is 0 on a single-core machine, which would hand
    # ``eval_with_funcs`` an empty list of predictors.
    NR_PROC = max(1, min(multiprocessing.cpu_count() // 2, 8))
    with pred.sess.as_default():
        # The same predictor object is shared by every worker slot.
        mean_score, max_score, mean_dist, max_dist = eval_with_funcs(
            [pred] * NR_PROC, nr_eval, get_player_fn)
    logger.info("Average Score: {}; Max Score: {}; Average Distance: {}; Max Distance: {}".format(mean_score, max_score, mean_dist, max_dist))
###############################################################################
class Evaluator(Callback):
def __init__(self, nr_eval, input_names, output_names,
             get_player_fn, directory, files_list=None):
    """
    Store the evaluation settings on the callback.

    Args:
        nr_eval: kept as ``self.eval_episode``.
        input_names: input names later given to ``trainer.get_predictor``.
        output_names: output names later given to ``trainer.get_predictor``.
        get_player_fn: kept as ``self.get_player_fn``.
        directory: kept as ``self.directory``.
        files_list: kept as ``self.files_list``; defaults to None.
    """
    # Predictor description.
    self.input_names = input_names
    self.output_names = output_names
    # Evaluation settings.
    self.eval_episode = nr_eval
    self.get_player_fn = get_player_fn
    # Data location.
    self.directory = directory
    self.files_list = files_list
def _setup_graph(self):
    """
    Build ``self.pred_funcs``: one predictor obtained from the trainer,
    repeated once per worker slot (all entries are the same object).
    """
    # Half of the available cores, capped at 20, but at least one:
    # ``cpu_count() // 2`` is 0 on a single-core machine, which would
    # leave ``self.pred_funcs`` empty.
    NR_PROC = max(1, min(multiprocessing.cpu_count() // 2, 20))
    self.pred_funcs = [self.trainer.get_predictor(
        self.input_names, self.output_names)] * NR_PROC
def _trigger(self):
# File: misc.py
import numpy as np
import os
import time
from collections import deque
from ..utils import logger
from ..utils.utils import humanize_time_delta
from .base import Callback
__all__ = ['SendStat', 'InjectShell', 'EstimatedTimeLeft']
class SendStat(Callback):
    """
    An equivalent of :class:`SendMonitorData`, but as a normal callback.

    On every trigger, the latest values of the chosen statistics are
    substituted into a command template and the command is run with
    ``os.system``.
    """

    def __init__(self, command, names):
        """
        Args:
            command: a format-string template; it is formatted with the
                statistics passed as keyword arguments.
            names: a statistic name, or a list of statistic names.
                Anything that is not a list is wrapped into a
                one-element list.
        """
        self.command = command
        self.names = names if isinstance(names, list) else [names]

    def _trigger(self):
        """Format the command with the latest statistics and execute it."""
        monitors = self.trainer.monitors
        latest = {name: monitors.get_latest(name) for name in self.names}
        shell_cmd = self.command.format(**latest)
        status = os.system(shell_cmd)
        if status == 0:
            return
        # A non-zero status is only reported, never raised.
        logger.error("Command {} failed with ret={}!".format(shell_cmd, status))
# File: summary.py
import numpy as np
from collections import deque
from ..compat import tfv1 as tf
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .base import Callback
__all__ = ['MovingAverageSummary', 'MergeAllSummaries', 'SimpleMovingAverage']
class MovingAverageSummary(Callback):
"""
Maintain the moving average of summarized tensors in every step,
by ops added to the collection.
Note that it only **maintains** the moving averages by updating
the relevant variables in the graph,
the actual summary should be done in other callbacks.
This callback is one of the :func:`DEFAULT_CALLBACKS()`.
"""
def __init__(self, collection=MOVING_SUMMARY_OPS_KEY, train_op=None):
"""
Args:
collection(str): the collection of EMA-maintaining ops.
The default value would work with
the tensors you added by :func:`tfutils.summary.add_moving_summary()`,
but you can use other collections as well.
else:
if isinstance(self._train_op, tf.Tensor):
self._train_op = self._train_op.op
if not isinstance(self._train_op, tf.Operation):
self._train_op = self.graph.get_operation_by_name(self._train_op)
self._train_op._add_control_inputs(ops)
logger.info("[MovingAverageSummary] {} operations in collection '{}'"
" will be run together with operation '{}'.".format(
len(ops), self._collection, self._train_op.name))
def _before_run(self, _):
    """
    Return ``self._fetch`` when no training op was configured
    (``self._train_op`` is None); otherwise return None.
    """
    return self._fetch if self._train_op is None else None
class MergeAllSummaries_RunAlone(Callback):
    """
    Merge all summaries of one collection into a single op, evaluate that
    op on its own with ``eval()``, and pass the result to the trainer's
    monitors.
    """

    def __init__(self, period, key):
        """
        Args:
            period: when truthy, the summaries are also evaluated every
                ``period`` local steps (see ``_trigger_step``).
            key: the collection key handed to ``tf.get_collection`` and
                ``tf.summary.merge_all``.
        """
        self._period = period
        self._key = key

    def _setup_graph(self):
        """Create the merged summary op for the configured collection."""
        size = len(tf.get_collection(self._key))
        logger.info("Summarizing collection '{}' of size {}.".format(self._key, size))
        self.summary_op = tf.summary.merge_all(self._key)

    def _trigger_step(self):
        """Trigger whenever ``local_step + 1`` is a multiple of the period."""
        if self._period:
            if (self.local_step + 1) % self._period == 0:
                self._trigger()

    def _trigger(self):
        """Evaluate the merged summary op, if there is one, and record it."""
        # Compare against None explicitly (as MergeAllSummaries_RunWithOp
        # does) instead of relying on the truth value of a tensor.
        if self.summary_op is not None:
            summary = self.summary_op.eval()
            self.trainer.monitors.put_summary(summary)
class MergeAllSummaries_RunWithOp(Callback):
def __init__(self, period, key):
    """
    Args:
        period: kept as ``self._period``.
        key: the collection key later handed to ``tf.get_collection``
            and ``tf.summary.merge_all``.
    """
    self._key = key
    self._period = period
def _setup_graph(self):
    """
    Create the merged summary op for the configured collection and
    prepare the matching session fetches (None when there is no op).
    """
    num_summaries = len(tf.get_collection(self._key))
    logger.info("Summarizing collection '{}' of size {}.".format(self._key, num_summaries))
    self.summary_op = tf.summary.merge_all(self._key)
    if self.summary_op is None:
        self._fetches = None
        return
    self._fetches = tf.train.SessionRunArgs(self.summary_op)
def _need_run(self):
if self.local_step == self.trainer.steps_per_epoch - 1:
return True
collection_list=self.graph.get_all_collection_keys())
def _trigger(self):
    """
    Save the default session with ``self.saver`` (without the meta
    graph) and log the path of the resulting checkpoint.

    OS/IO errors and TensorFlow permission or resource-exhaustion errors
    are logged and swallowed so a failed save does not stop the caller.
    """
    try:
        session = tf.get_default_session()
        step = tf.train.get_global_step()
        self.saver.save(session, self.path,
                        global_step=step, write_meta_graph=False)
        ckpt_state = tf.train.get_checkpoint_state(self.checkpoint_dir)
        logger.info("Model saved to %s." % ckpt_state.model_checkpoint_path)
    except (OSError, IOError, tf.errors.PermissionDeniedError,
            tf.errors.ResourceExhaustedError):  # disk error sometimes.. just ignore it
        logger.exception("Exception in ModelSaver!")
class MinSaver(Callback):
"""
Separately save the model with minimum value of some statistics.
"""
def __init__(self, monitor_stat, reverse=False, filename=None, checkpoint_dir=None):
"""
Args:
monitor_stat(str): the name of the statistics.
reverse (bool): if True, will save the maximum.
filename (str): the name for the saved model.
Defaults to ``min-{monitor_stat}.tfmodel``.
checkpoint_dir (str): the directory containing checkpoints.
Example:
Save the model with minimum validation error to
"min-val-error.tfmodel":
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: graph.py
# Author: Yuxin Wu
""" Graph related callbacks"""
from .base import Callback
from ..utils import logger
__all__ = ['RunOp']
class RunOp(Callback):
""" Run an op periodically"""
def __init__(self, setup_func, run_before=True, run_epoch=True):
    """
    Args:
        setup_func: a function that returns the op in the graph.
        run_before (bool): run the op before training.
        run_epoch (bool): run the op on every epoch trigger.
    """
    # When to run the op.
    self.run_before = run_before
    self.run_epoch = run_epoch
    # How to obtain the op; called in ``_setup_graph``.
    self.setup_func = setup_func
def _setup_graph(self):
    """Call ``setup_func`` and keep the op it returns as ``self._op``."""
    self._op = self.setup_func()
def _before_train(self):
You shouldn't need to use this.
"""
def __init__(self, cb):
    """
    Args:
        cb: the wrapped object; its ``before_run`` and ``after_run``
            are what this wrapper's methods delegate to.
    """
    self._cb = cb
@HIDE_DOC
def before_run(self, ctx):
    """Forward ``ctx`` to the wrapped object's ``before_run`` and return its result."""
    run_args = self._cb.before_run(ctx)
    return run_args
@HIDE_DOC
def after_run(self, ctx, vals):
    """Forward ``ctx`` and ``vals`` to the wrapped object's ``after_run``; returns nothing."""
    wrapped = self._cb
    wrapped.after_run(ctx, vals)
class HookToCallback(Callback):
"""
Make a ``tf.train.SessionRunHook`` into a callback.
Note that when ``SessionRunHook.after_create_session`` is called, the ``coord`` argument will be None.
"""
_chief_only = False
def __init__(self, hook):
    """
    Keep the hook that this callback wraps.

    Args:
        hook (tf.train.SessionRunHook): the hook to wrap.
    """
    self._hook = hook
def _setup_graph(self):
with tf.name_scope(None): # jump out of the name scope
import threading
from tqdm import tqdm
import six
from six.moves import queue
from ..dataflow import DataFlow
from ..utils import *
from ..utils.concurrency import LoopThread
from ..callbacks.base import Callback
__all__ = ['ExpReplay']
Experience = namedtuple('Experience',
['state', 'action', 'reward', 'isOver'])
class ExpReplay(DataFlow, Callback):
"""
Implement experience replay in the paper
`Human-level control through deep reinforcement learning`.
This implementation provides the interface as a DataFlow.
This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
"""
def __init__(self,
predictor_io_names,
player,
batch_size=32,
memory_size=1e6,
init_memory_size=50000,
exploration=1,
end_exploration=0.1,
exploration_epoch_anneal=0.002,
clear_extraneous_savers=True)
def _trigger(self):
    """
    Save the default session with ``self.saver`` (without the meta
    graph) and log the path of the resulting checkpoint.

    OS/IO errors and TensorFlow permission or resource-exhaustion errors
    are logged and swallowed so a failed save does not stop the caller.
    """
    try:
        session = tf.get_default_session()
        step = tf.train.get_global_step()
        self.saver.save(session, self.path,
                        global_step=step, write_meta_graph=False)
        ckpt_state = tf.train.get_checkpoint_state(self.checkpoint_dir)
        logger.info("Model saved to %s." % ckpt_state.model_checkpoint_path)
    except (OSError, IOError, tf.errors.PermissionDeniedError,
            tf.errors.ResourceExhaustedError):  # disk error sometimes.. just ignore it
        logger.exception("Exception in ModelSaver!")
class MinSaver(Callback):
"""
Separately save the model with minimum value of some statistics.
"""
def __init__(self, monitor_stat, reverse=False, filename=None, checkpoint_dir=None):
"""
Args:
monitor_stat(str): the name of the statistics.
reverse (bool): if True, will save the maximum.
filename (str): the name for the saved model.
Defaults to ``min-{monitor_stat}.tfmodel``.
checkpoint_dir (str): the directory containing checkpoints.
Example:
Save the model with minimum validation error to
"min-val-error.tfmodel":