Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _resource_apply_dense(self, grad, var):
    """Apply one Bop update step to ``var`` (TF1-compat assignment variant).

    The gradient moving average stored in slot ``"m"`` is updated first.
    ``var`` is then negated wherever ``var * m_t`` is strictly greater than
    ``threshold``, and kept otherwise. The result is passed through
    ``lq.math.sign`` before being written back into ``var``.

    Args:
        grad: Gradient tensor for ``var``.
        var: The variable being optimized.

    Returns:
        The op that assigns the new value to ``var``.
    """
    var_dtype = var.dtype.base_dtype
    gamma = self._get_decayed_hyper("gamma", var_dtype)
    threshold = self._get_decayed_hyper("threshold", var_dtype)
    m = self.get_slot(var, "m")

    # m_t = (1 - gamma) * m + gamma * grad, written back into the "m" slot.
    m_t = tf.compat.v1.assign(
        m, (1 - gamma) * m + gamma * grad, use_locking=self._use_locking
    )
    # Flip only where var * m_t strictly exceeds the threshold. The former
    # expression `-tf.sign(var * m_t - threshold) * var` evaluated to 0 on an
    # exact tie, so the stored value became lq.math.sign(0) regardless of the
    # previous sign of var. tf.where keeps var unchanged at the tie instead.
    flip = tf.math.greater(var * m_t, threshold)
    var_t = lq.math.sign(tf.where(flip, -var, var))
    return tf.compat.v1.assign(var, var_t, use_locking=self._use_locking).op
# Returns
Binarized tensor.
# References
- [Bi-Real Net: Enhancing the Performance of 1-bit CNNs With Improved
Representational Capability and Advanced
Training Algorithm](http://arxiv.org/abs/1808.00278)
"""
def grad(dy):
    """Surrogate gradient for the sign of the enclosing ``x``.

    Where ``|x| <= 1`` the upstream gradient is scaled by ``(1 - |x|) * 2``;
    everywhere else the gradient is zero.
    """
    magnitude = tf.math.abs(x)
    inside_unit_band = tf.math.less_equal(magnitude, 1.0)
    surrogate = (1 - magnitude) * 2 * dy
    return tf.where(inside_unit_band, surrogate, tf.zeros_like(dy))
return math.sign(x), grad
def _call(x):
    """Return ``math.sign(x)`` paired with a clipped pass-through gradient.

    The gradient function delegates to ``_clipped_gradient(x, dy, clip_value)``.
    Both ``_clipped_gradient`` and ``clip_value`` come from the enclosing scope.
    """
    def _backward(dy):
        return _clipped_gradient(x, dy, clip_value)

    forward = math.sign(x)
    return forward, _backward
ms = [
tf.keras.backend.zeros(
tf.keras.backend.int_shape(p), dtype=tf.keras.backend.dtype(p)
)
for p in params
]
fp_params = []
for p, g, m in zip(params, grads, ms):
if self.is_binary(p):
m_t = (1 - self.gamma) * m + self.gamma * g
self.updates.append(tf.assign(m, m_t))
self.updates.append(
tf.assign(p, lq.math.sign(-p * tf.sign(p * m_t - self.threshold)))
)
else:
fp_params.append(p)
return self.updates + self.fp_optimizer.get_updates(loss, fp_params)
def _magnitude_aware_sign(x):
    """Return ``lq.math.sign(x) * scale_factor`` with an identity gradient.

    ``scale_factor`` comes from the enclosing scope. The returned gradient
    function hands the upstream gradient ``dy`` back unchanged.
    """
    def _pass_through(dy):
        return dy

    scaled_sign = lq.math.sign(x) * scale_factor
    return scaled_sign, _pass_through
def _resource_apply_dense(self, grad, var):
    """Apply one Bop update step to ``var`` (TF1-compat assignment variant).

    The gradient moving average stored in slot ``"m"`` is updated first.
    ``var`` is then negated wherever ``var * m_t`` is strictly greater than
    ``threshold``, and kept otherwise. The result is passed through
    ``lq.math.sign`` before being written back into ``var``.

    Args:
        grad: Gradient tensor for ``var``.
        var: The variable being optimized.

    Returns:
        The op that assigns the new value to ``var``.
    """
    var_dtype = var.dtype.base_dtype
    gamma = self._get_decayed_hyper("gamma", var_dtype)
    threshold = self._get_decayed_hyper("threshold", var_dtype)
    m = self.get_slot(var, "m")

    # m_t = (1 - gamma) * m + gamma * grad, written back into the "m" slot.
    m_t = tf.compat.v1.assign(
        m, (1 - gamma) * m + gamma * grad, use_locking=self._use_locking
    )
    # Flip only where var * m_t strictly exceeds the threshold. The former
    # expression `-tf.sign(var * m_t - threshold) * var` evaluated to 0 on an
    # exact tie, so the stored value became lq.math.sign(0) regardless of the
    # previous sign of var. tf.where keeps var unchanged at the tie instead.
    flip = tf.math.greater(var * m_t, threshold)
    var_t = lq.math.sign(tf.where(flip, -var, var))
    return tf.compat.v1.assign(var, var_t, use_locking=self._use_locking).op
def _resource_apply_dense(self, grad, var):
    """Apply one Bop update step to ``var``.

    The gradient moving average stored in slot ``"m"`` is updated in place.
    ``var`` is then negated wherever ``var * m_t`` is strictly greater than
    ``threshold``, and kept otherwise. The result is passed through
    ``lq.math.sign`` before being written back into ``var``.

    Args:
        grad: Gradient tensor for ``var``.
        var: The variable being optimized.

    Returns:
        The op that assigns the new value to ``var``.
    """
    var_dtype = var.dtype.base_dtype
    gamma = self._get_decayed_hyper("gamma", var_dtype)
    threshold = self._get_decayed_hyper("threshold", var_dtype)
    m = self.get_slot(var, "m")

    # m + gamma * (grad - m) == (1 - gamma) * m + gamma * grad.
    m_t = m.assign_add(gamma * (grad - m))
    # Flip only where var * m_t strictly exceeds the threshold. The former
    # expression `-tf.sign(var * m_t - threshold) * var` evaluated to 0 on an
    # exact tie, so the stored value became lq.math.sign(0) regardless of the
    # previous sign of var. tf.where keeps var unchanged at the tie instead.
    flip = tf.math.greater(var * m_t, threshold)
    var_t = lq.math.sign(tf.where(flip, -var, var))
    return var.assign(var_t).op
def _call(x):
    """Return ``math.heaviside(x)`` paired with a clipped pass-through gradient.

    The gradient function delegates to ``_clipped_gradient(x, dy, clip_value)``.
    Both ``_clipped_gradient`` and ``clip_value`` come from the enclosing scope.
    """
    def _backward(dy):
        return _clipped_gradient(x, dy, clip_value)

    forward = math.heaviside(x)
    return forward, _backward