Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
axis = [1, 2]
ch = shape[3]
new_shape = [1, 1, 1, ch]
else:
axis = [2, 3]
ch = shape[1]
new_shape = [1, ch, 1, 1]
assert ch is not None, "Input of InstanceNorm require known channel!"
mean, var = tf.nn.moments(x, axis, keep_dims=True)
if not use_affine:
return tf.divide(x - mean, tf.sqrt(var + epsilon), name='output')
beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
beta = tf.reshape(beta, new_shape)
if gamma_init is None:
gamma_init = tf.constant_initializer(1.0)
gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
gamma = tf.reshape(gamma, new_shape)
ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
vh = ret.variables = VariableHolder()
if use_affine:
vh.gamma = gamma
vh.beta = beta
return ret
def _check_unused_regularization():
    """
    Warn about regularization losses that nothing consumes.

    Scans every tensor registered in the ``REGULARIZATION_LOSSES`` collection
    and, if at least one of them has no consumers, logs two warnings: a
    general explanation followed by the names of the offending tensors.
    Nothing is logged when every tensor has a consumer.
    """
    reg_losses = tfv1.get_collection(tfv1.GraphKeys.REGULARIZATION_LOSSES)
    orphans = [t for t in reg_losses if len(t.consumers()) == 0]
    if not orphans:
        return
    logger.warn("The following tensors appear in REGULARIZATION_LOSSES collection but have no "
                "consumers! You may have forgotten to add regularization to total cost.")
    orphan_names = ', '.join(t.name for t in orphans)
    logger.warn("Unconsumed regularization: {}".format(orphan_names))
def internal_update_bn_ema(xn, batch_mean, batch_var,
                           moving_mean, moving_var, decay):
    """
    Attach the moving-average updates of the batch statistics to ``xn``.

    Args:
        xn: the tensor to return once both updates have been scheduled.
        batch_mean, batch_var: the new values fed into the moving averages.
        moving_mean, moving_var: the variables holding the moving averages.
        decay: decay passed to ``assign_moving_average``.

    Returns:
        An identity of ``xn`` named ``output`` that carries a control
        dependency on both update ops.
    """
    # The mean op is created before the variance op, each with zero_debias off.
    ema_ops = [
        moving_averages.assign_moving_average(
            moving_stat, batch_stat, decay, zero_debias=False, name=op_name)
        for moving_stat, batch_stat, op_name in (
            (moving_mean, batch_mean, 'mean_ema_op'),
            (moving_var, batch_var, 'var_ema_op'),
        )
    ]

    # When sync_statistics is True, always enable internal_update.
    # Otherwise the update ops (only executed on main tower)
    # will hang when some BatchNorm layers are unused (https://github.com/tensorpack/tensorpack/issues/1078)
    with tf.control_dependencies(ema_ops):
        return tf.identity(xn, name='output')
Default weight initializer is variance_scaling_initializer(2.0).
Variable Names:
* ``W``: weights of shape [in_dim, out_dim]
* ``b``: bias
"""
if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0) # deprecated
else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
inputs = batch_flatten(inputs)
with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
layer = tf.layers.Dense(
units=units,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
_reuse=tf.get_variable_scope().reuse)
ret = layer.apply(inputs, scope=tf.get_variable_scope())
ret = tf.identity(ret, name='output')
ret.variables = VariableHolder(W=layer.kernel)
if use_bias:
ret.variables.b = layer.bias
return ret
gamma_init=None, data_format='channels_last'):
"""
Layer Normalization layer, as described in the paper:
`Layer Normalization `_.
Args:
x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
epsilon (float): epsilon to avoid divide-by-zero.
use_scale, use_bias (bool): whether to use the extra affine transformation or not.
"""
data_format = get_data_format(data_format, keras_mode=False)
shape = x.get_shape().as_list()
ndims = len(shape)
assert ndims in [2, 4]
mean, var = tf.nn.moments(x, list(range(1, len(shape))), keep_dims=True)
if data_format == 'NCHW':
chan = shape[1]
new_shape = [1, chan, 1, 1]
else:
chan = shape[-1]
new_shape = [1, 1, 1, chan]
if ndims == 2:
new_shape = [1, chan]
if use_bias:
beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
beta = tf.reshape(beta, new_shape)
else:
beta = tf.zeros([1] * ndims, name='beta')
if use_scale:
def start(self):
    """
    Remember the current default session, then delegate to the parent ``start()``.

    The session is looked up at the moment ``start()`` is called and stored
    in ``self._sess`` before the parent class implementation runs.
    """
    # Imported locally, as in the original code.
    from ..compat import tfv1

    default_session = tfv1.get_default_session()
    self._sess = default_session
    super(ShareSessionThread, self).start()
def process(self, grads):
    """
    Process the symbolic gradients.

    The first call opens a name scope named after the concrete class and
    remembers it in ``self._name_scope``; later calls re-enter that
    remembered scope instead of opening a new one.

    Args:
        grads (list): list of (grad, var).

    Returns:
        list: processed gradients, with the same type as input.
    """
    # Later calls: reuse the name scope recorded by the first call.
    if self._name_scope is not None:
        with tfv1.name_scope(self._name_scope):
            return self._process(grads)

    # First call: open a scope named after the class and record it for reuse.
    with tfv1.name_scope(type(self).__name__) as opened_scope:
        self._name_scope = opened_scope
        return self._process(grads)
def MaxPooling(
        inputs,
        pool_size,
        strides=None,
        padding='valid',
        data_format='channels_last'):
    """
    Same as `tf.layers.MaxPooling2D`. Default strides is equal to pool_size.

    Args:
        inputs: the tensor the pooling layer is applied to.
        pool_size: forwarded to `tf.layers.MaxPooling2D`.
        strides: forwarded to `tf.layers.MaxPooling2D`; ``None`` means "use pool_size".
        padding: forwarded to `tf.layers.MaxPooling2D`.
        data_format: forwarded to `tf.layers.MaxPooling2D`.

    Returns:
        The pooled result wrapped in an identity op named ``output``.
    """
    effective_strides = pool_size if strides is None else strides
    pool_layer = tf.layers.MaxPooling2D(
        pool_size, effective_strides, padding=padding, data_format=data_format)
    pooled = pool_layer.apply(inputs, scope=tf.get_variable_scope())
    return tf.identity(pooled, name='output')
tf_args['virtual_batch_size'] = virtual_batch_size
else:
assert virtual_batch_size is None, "Feature not supported in this version of TF!"
use_fp16 = inputs.dtype == tf.float16
if use_fp16:
# non-fused does not support fp16; fused does not support all layouts.
# we made our best guess here
tf_args['fused'] = True
layer = tf.layers.BatchNormalization(**tf_args)
xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())
# Add EMA variables to the correct collection
if ctx.is_main_training_tower:
for v in layer.non_trainable_variables:
if isinstance(v, tf.Variable):
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
if not do_ema_update:
restore_collection(coll_bk)
if do_ema_update and ema_update == "internal":
# Implement "internal" update.
restore_collection(coll_bk)
assert layer.updates
with tf.control_dependencies(layer.updates):
ret = tf.identity(xn, name='output')
else:
ret = tf.identity(xn, name='output')
vh = ret.variables = VariableHolder(
moving_mean=layer.moving_mean,
mean=layer.moving_mean, # for backward-compatibility
moving_variance=layer.moving_variance,
def _keys_to_freeze(self):
    """
    Return the collection keys to freeze.

    Returns:
        list: the summaries key, the moving-summary-ops key and the
        update-ops key, in that order.
    """
    frozen_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]
    # freeze UPDATE_OPS during inference because they should never be used
    frozen_keys.append(tf.GraphKeys.UPDATE_OPS)
    return frozen_keys