shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
batch_mean = gen_nccl_ops.nccl_all_reduce(
input=batch_mean,
reduction='sum',
num_devices=num_dev,
shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
batch_mean_square = gen_nccl_ops.nccl_all_reduce(
input=batch_mean_square,
reduction='sum',
num_devices=num_dev,
shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
elif sync_statistics == 'horovod':
# Requires https://github.com/uber/horovod/pull/331
import horovod.tensorflow as hvd
if hvd.size() == 1:
logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
else:
import horovod
hvd_version = tuple(map(int, horovod.__version__.split('.')[:3]))
assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"
batch_mean = hvd.allreduce(batch_mean, average=True)
batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
batch_var = batch_mean_square - tf.square(batch_mean)
batch_mean_vec = batch_mean
batch_var_vec = batch_var
beta, gamma, moving_mean, moving_var = get_bn_variables(
num_chan, scale, center, beta_initializer, gamma_initializer)
if new_shape is not None:
batch_mean = tf.reshape(batch_mean, new_shape)
batch_var = tf.reshape(batch_var, new_shape)
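
The identity behind the two all-reduces above: averaging each device's per-channel mean and mean-of-squares yields the global moments, from which the variance follows as E[x^2] - E[x]^2. A minimal NumPy sketch (not tensorpack code) of that identity:

import numpy as np

x = np.random.randn(8, 4)              # 8 samples, 4 channels
shards = np.split(x, 2)                # two equal-sized "devices"
mean = np.mean([s.mean(0) for s in shards], axis=0)
mean_sq = np.mean([(s ** 2).mean(0) for s in shards], axis=0)
var = mean_sq - mean ** 2              # Var[x] = E[x^2] - E[x]^2
assert np.allclose(mean, x.mean(0)) and np.allclose(var, x.var(0))
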
def maybe_download_and_extract(dest_directory, cifar_classnum):
"""Download and extract the tarball from Alex's website.
copied from tensorflow example """
assert cifar_classnum == 10 or cifar_classnum == 100
if cifar_classnum == 10:
cifar_foldername = 'cifar-10-batches-py'
else:
cifar_foldername = 'cifar-100-python'
if os.path.isdir(os.path.join(dest_directory, cifar_foldername)):
logger.info("Found cifar{} data in {}.".format(cifar_classnum, dest_directory))
return
else:
DATA_URL = DATA_URL_CIFAR_10 if cifar_classnum == 10 else DATA_URL_CIFAR_100
download(DATA_URL, dest_directory)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
import tarfile
with tarfile.open(filepath, 'r:gz') as tar:
    tar.extractall(dest_directory)
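
A hypothetical call (the directory name is illustrative); a repeated call returns early because the extracted folder is detected:

maybe_download_and_extract('/tmp/cifar_data', 10)   # downloads and extracts CIFAR-10
maybe_download_and_extract('/tmp/cifar_data', 10)   # finds the folder, returns early
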
def load_existing_json():
"""
Look for an existing json under :meth:`logger.get_logger_dir()` named "stats.json",
and return the loaded list of statistics if found. Returns None otherwise.
"""
dirname = logger.get_logger_dir()
fname = os.path.join(dirname, JSONWriter.FILENAME)
if tf.gfile.Exists(fname):
with open(fname) as f:
stats = json.load(f)
assert isinstance(stats, list), type(stats)
return stats
return None
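
A hypothetical use of this loader to resume from previously logged statistics (the 'epoch_num' key is an assumption about what earlier runs wrote):

stats = load_existing_json()
if stats:
    last_epoch = stats[-1].get('epoch_num', 0)  # assumed key; depends on what was logged
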
def _save(self):
ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
if ckpt is None:
raise RuntimeError(
"Cannot find a checkpoint state. Did you forget to use ModelSaver?")
path = ckpt.model_checkpoint_path
newname = os.path.join(logger.LOG_DIR,
self.filename or
('max-' if self.reverse else 'min-') + self.monitor_stat + '.tfmodel')
shutil.copy(path, newname)
logger.info("Model with {} '{}' saved.".format(
'maximum' if self.reverse else 'minimum', self.monitor_stat))
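
This _save matches tensorpack's MinSaver/MaxSaver callback; a usage sketch, assuming a monitored statistic named 'val-error':

callbacks = [
    ModelSaver(),            # must be present so a checkpoint state exists
    MinSaver('val-error'),   # copies the best checkpoint to min-val-error.tfmodel
]
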
def __getattr__(self, layer_name):
layer = eval(layer_name)  # resolve the name to a layer function or module in scope
if hasattr(layer, 'f'):
# this is a registered tensorpack layer
if layer.use_scope:
def f(name, *args, **kwargs):
ret = layer(name, self._t, *args, **kwargs)
return LinearWrap(ret)
else:
def f(*args, **kwargs):
ret = layer(self._t, *args, **kwargs)
return LinearWrap(ret)
return f
else:
if layer_name != 'tf':
logger.warn("You're calling LinearWrap.__getattr__ with something neither a layer nor 'tf'!")
assert isinstance(layer, ModuleType)
return LinearWrap.TFModuleFunc(layer, self._t)
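
The chaining style this __getattr__ enables, sketched with illustrative layer names and arguments:

logits = (LinearWrap(image)
          .Conv2D('conv0', 32, 3)
          .MaxPooling('pool0', 2)
          .FullyConnected('fc0', 10)())   # trailing () unwraps the underlying tensor
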
import multiprocessing as mp
import six
from abc import ABCMeta
from ..callbacks import Callback
from ..tfutils.varmanip import SessionUpdate
from ..predict import OfflinePredictor
from ..utils import logger
from ..utils.timer import *
from ..utils.serialize import *
from ..utils.concurrency import *
__all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
'TransitionExperience', 'WeightSync']
try:
import zmq
except ImportError:
logger.warn("Error in 'import zmq'. RL simulator won't be available.")
__all__ = []
class TransitionExperience(object):
""" A transition of state, or experience"""
def __init__(self, state, action, reward, **kwargs):
""" kwargs: whatever other attribute you want to save"""
self.state = state
self.action = action
self.reward = reward
for k, v in six.iteritems(kwargs):
setattr(self, k, v)
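
A hypothetical instantiation, storing an extra value attribute via kwargs:

exp = TransitionExperience(state=[0.0, 1.0], action=1, reward=0.5, value=0.8)
assert exp.value == 0.8   # arbitrary kwargs become attributes
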
class SimulatorProcessBase(mp.Process):
__metaclass__ = ABCMeta
def __init__(self, idx):
    super(SimulatorProcessBase, self).__init__()
    self.idx = int(idx)
def __init__(self, logdir=None, max_queue=10, flush_secs=120, split_files=False):
"""
Args:
logdir: ``logger.get_logger_dir()`` by default.
max_queue, flush_secs: Same as in :class:`tf.summary.FileWriter`.
split_files: if True, split events to multiple files rather than
append to a single file. Useful on certain filesystems where append is expensive.
"""
if logdir is None:
logdir = logger.get_logger_dir()
assert tf.gfile.IsDirectory(logdir), logdir
self._logdir = logdir
self._max_queue = max_queue
self._flush_secs = flush_secs
self._split_files = split_files
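
A hedged usage sketch, assuming the enclosing class is tensorpack's TFEventWriter monitor:

# logdir defaults to logger.get_logger_dir(); split_files avoids appending,
# which is useful on filesystems where append is expensive.
writer = TFEventWriter(split_files=True)
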
inputs = tf.reshape(inputs, [-1, 1, 1, n_out])  # fused_bn only takes 4D input
# fused_bn appears to have issues with NCHW (see #190)
xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
inputs, gamma, beta, epsilon=epsilon,
is_training=True, data_format=data_format)
if ndims == 2:
xn = tf.squeeze(xn, [1, 2])
else:
if ctx.is_training:
assert get_tf_version_tuple() >= (1, 4), \
"Fine tuning a BatchNorm model with fixed statistics is only " \
"supported after https://github.com/tensorflow/tensorflow/pull/12580 "
if ctx.is_main_training_tower: # only warn in first tower
logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
# Using moving_mean/moving_variance in training, which means we
# loaded a pre-trained BN and only fine-tuning the affine part.
xn, _, _ = tf.nn.fused_batch_norm(
inputs, gamma, beta,
mean=moving_mean, variance=moving_var, epsilon=epsilon,
data_format=data_format, is_training=False)
else:
if ndims == 4:
xn, _, _ = tf.nn.fused_batch_norm(
inputs, gamma, beta,
mean=moving_mean, variance=moving_var, epsilon=epsilon,
data_format=data_format, is_training=False)
else:
xn = tf.nn.batch_normalization(
inputs, moving_mean, moving_var, beta, gamma, epsilon)
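
The non-fused fallback on the last line computes the textbook inference-time normalization; as a NumPy sketch (not library code):

import numpy as np

def bn_inference(x, moving_mean, moving_var, beta, gamma, eps=1e-5):
    # xn = gamma * (x - mean) / sqrt(var + eps) + beta
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta
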
data_train = AugmentImageComponent(data_train, augmentors)
data_train = BatchData(data_train, 128)
data_train = MultiProcessRunner(data_train, 5, 5)
augmentors = [imgaug.Resize((40, 40))]
data_test = AugmentImageComponent(data_test, augmentors)
data_test = BatchData(data_test, 128, remainder=True)
return data_train, data_test
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--prob', help='disturb prob', type=float, required=True)
args = parser.parse_args()
logger.auto_set_dir()
data_train, data_test = get_data()
config = TrainConfig(
model=Model(),
data=QueueInput(data_train),
callbacks=[
ModelSaver(),
InferenceRunner(data_test,
ScalarStats(['cost', 'accuracy']))
],
max_epoch=350,
)
launch_train_with_config(config, SimpleTrainer())
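
In tensorpack, InferenceRunner is itself a callback that by default runs over data_test at the end of every epoch, so the 'cost' and 'accuracy' scalars above are refreshed once per epoch alongside the checkpoints written by ModelSaver.
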
def process(self):
for idx, layer in enumerate(self.net.layers):
param = layer.blobs
name = self.layer_names[idx]
if layer.type in self.processors:
logger.info("Processing layer {} of type {}".format(
name, layer.type))
dic = self.processors[layer.type](idx, name, param)
self.param_dict.update(dic)
elif len(param) != 0:
logger.warn(
"Layer '{}' of type {} contains parameters but is not supported!".format(name, layer.type))
return self.param_dict
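
Each processor receives (idx, name, param) and returns a dict of arrays keyed by variable name. A sketch of a hypothetical convolution processor (the key names and caffe-to-TF axis order are assumptions, not the library's exact code):

def process_conv(idx, name, param):
    # caffe stores conv weights as (out, in, h, w); TF expects (h, w, in, out)
    return {name + '/W': param[0].data.transpose(2, 3, 1, 0),
            name + '/b': param[1].data}
# registered as: self.processors['Convolution'] = process_conv
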