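# Shared setup for the quantizer snippets below. These functions are
# excerpted from larq's quantizers module; the imports and the
# `_clipped_gradient` helper here are a minimal sketch (the helper's name
# and signature are assumptions) added so the excerpts run standalone.
import tensorflow as tf
import larq as lq
from larq import utils


def _clipped_gradient(x, dy, clip_value):
    # Clipped-identity pseudo-gradient: pass `dy` through where
    # |x| <= clip_value, zero it elsewhere. `clip_value=None` disables clipping.
    if clip_value is None:
        return dy
    return tf.where(tf.abs(x) <= clip_value, dy, tf.zeros_like(dy))
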
@utils.set_precision(1)
def ste_sign(x, clip_value=1.0):
    r"""Sign binarization function.
    \\[
    q(x) = \begin{cases}
      -1 & x < 0 \\\
      1 & x \geq 0
    \end{cases}
    \\]
    The gradient is estimated using the Straight-Through Estimator
    (essentially the binarization is replaced by a clipped identity on the
    backward pass).
    \\[\frac{\partial q(x)}{\partial x} = \begin{cases}
      1 & \left|x\right| \leq \texttt{clip_value} \\\
      0 & \left|x\right| > \texttt{clip_value}
    \end{cases}\\]
    """

    @tf.custom_gradient
    def _call(x):
        def grad(dy):
            return _clipped_gradient(x, dy, clip_value)

        # lq.math.sign returns +1 at zero, matching q(0) = 1 above.
        return lq.math.sign(x), grad

    return _call(x)
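
# Quick check of the behaviour described above (hypothetical values,
# assuming the sketched implementation): binary forward pass, clipped
# identity backward pass.
inputs = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])
with tf.GradientTape() as tape:
    tape.watch(inputs)
    outputs = ste_sign(inputs)
print(outputs.numpy())                         # [-1. -1.  1.  1.  1.]
print(tape.gradient(outputs, inputs).numpy())  # [0. 1. 1. 1. 0.]
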
@lq.utils.set_precision(1)
@lq.utils.register_keras_custom_object
def xnor_weight_scale(x):
    """
    Clips the weights between -1 and +1 and then calculates a scale factor per
    weight filter. See https://arxiv.org/abs/1603.05279 for more details.
    """
    x = tf.clip_by_value(x, -1, 1)
    # Per-filter scale: mean absolute value over height, width and input channels.
    alpha = tf.reduce_mean(tf.abs(x), axis=[0, 1, 2], keepdims=True)
    return alpha * lq.quantizers.ste_sign(x)
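
# Sketch of plugging the quantizer into a larq layer (the layer
# hyper-parameters here are arbitrary): `kernel_quantizer` accepts any
# callable, so the function above can be passed directly.
conv = lq.layers.QuantConv2D(
    filters=32,
    kernel_size=3,
    input_quantizer="ste_sign",
    kernel_quantizer=xnor_weight_scale,
    kernel_constraint="weight_clip",
)
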
@utils.set_precision(2)
def dorefa_quantizer(x, k_bit=2):
    r"""k_bit quantizer as in the DoReFa paper.
    \\[
    q(x) = \begin{cases}
      0 & x < \frac{1}{2n} \\\
      \frac{i}{n} & \frac{2i-1}{2n} < |x| < \frac{2i+1}{2n} \text{ for } i \in \\{1, \ldots, n-1\\} \\\
      1 & \frac{2n-1}{2n} < x
    \end{cases}
    \\]
    where \\(n = 2^{\text{k_bit}} - 1\\) and the number of bits, `k_bit`, is
    passed as an argument.
    The gradient is estimated using the Straight-Through Estimator (a clipped
    identity on the backward pass):
    \\[\frac{\partial q(x)}{\partial x} = \begin{cases}
      1 & 0 \leq x \leq 1 \\\
      0 & \text{else}
    \end{cases}\\]
    """
    # Clipping to [0, 1] also zeroes the gradient outside that range.
    x = tf.clip_by_value(x, 0.0, 1.0)

    @tf.custom_gradient
    def _k_bit_with_identity_grad(x):
        n = 2 ** k_bit - 1
        return tf.round(x * n) / n, lambda dy: dy  # n + 1 levels {0, 1/n, ..., 1}

    return _k_bit_with_identity_grad(x)
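
# Worked example of the levels (hypothetical values): with k_bit=2,
# n = 2**2 - 1 = 3, so outputs land on {0, 1/3, 2/3, 1}.
inputs = tf.constant([-0.2, 0.1, 0.4, 0.8, 1.3])
print(dorefa_quantizer(inputs, k_bit=2).numpy())
# -> approx. [0., 0., 0.333, 0.667, 1.]
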
@lq.utils.set_precision(1)
def magnitude_aware_sign_unclipped(x):
    """
    Scaled sign function with identity pseudo-gradient, as used for the
    weights in the DoReFa paper. The scale factor is calculated per layer.
    """
    scale_factor = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))

    @tf.custom_gradient
    def _magnitude_aware_sign(x):
        return lq.math.sign(x) * scale_factor, lambda dy: dy

    return _magnitude_aware_sign(x)
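
# Quick check (hypothetical values): the scale factor is the mean absolute
# value of the tensor, so outputs land in {-a, +a} rather than {-1, +1}.
weights = tf.constant([-0.8, -0.2, 0.4, 0.6])  # mean |w| = 0.5
print(magnitude_aware_sign_unclipped(weights).numpy())  # [-0.5 -0.5  0.5  0.5]
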
@utils.set_precision(2)
def ste_tern(x, threshold_value=0.05, ternary_weight_networks=False, clip_value=1.0):
    r"""Ternarization function.
    \\[
    q(x) = \begin{cases}
      +1 & x > \Delta \\\
      0 & |x| < \Delta \\\
      -1 & x < - \Delta
    \end{cases}
    \\]
    where the threshold \\(\Delta\\) is either passed as `threshold_value`, or
    calculated per tensor as in the original Ternary Weight Networks paper:
    \\(\Delta = \frac{0.7}{n} \sum_{i=1}^{n} |W_i|\\).
    The gradient is a clipped identity (Straight-Through Estimator), as for
    `ste_sign`.
    """
    if ternary_weight_networks:
        threshold = 0.7 * tf.reduce_mean(tf.abs(x))
    else:
        threshold = threshold_value

    @tf.custom_gradient
    def _call(x):
        def grad(dy):
            return _clipped_gradient(x, dy, clip_value)

        # Maps x onto {-1, 0, +1} around the threshold.
        return tf.sign(tf.sign(x + threshold) + tf.sign(x - threshold)), grad

    return _call(x)
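
# The three levels in action (hypothetical values, default threshold 0.05):
inputs = tf.constant([-0.3, -0.01, 0.02, 0.2])
print(ste_tern(inputs).numpy())  # [-1.  0.  0.  1.]
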
@utils.set_precision(1)
@tf.custom_gradient
def approx_sign(x):
    r"""
    Sign binarization function.
    \\[
    q(x) = \begin{cases}
      -1 & x < 0 \\\
      1 & x \geq 0
    \end{cases}
    \\]
    The gradient is estimated using the ApproxSign method.
    \\[\frac{\partial q(x)}{\partial x} = \begin{cases}
      (2 - 2 \left|x\right|) & \left|x\right| \leq 1 \\\
      0 & \left|x\right| > 1
    \end{cases}\\]
    """

    def grad(dy):
        # Triangular pseudo-gradient: 2 - 2|x| inside [-1, 1], zero outside.
        abs_x = tf.abs(x)
        return tf.where(abs_x <= 1.0, (2 - 2 * abs_x) * dy, tf.zeros_like(dy))

    return lq.math.sign(x), grad
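
# The ApproxSign pseudo-gradient is triangular (hypothetical values):
# largest at x = 0 and fading linearly to zero at |x| = 1.
inputs = tf.constant([-1.5, -0.5, 0.0, 0.5, 1.5])
with tf.GradientTape() as tape:
    tape.watch(inputs)
    outputs = approx_sign(inputs)
print(tape.gradient(outputs, inputs).numpy())  # [0. 1. 2. 1. 0.]
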
@utils.set_precision(1)
def swish_sign(x, beta=5.0):
    r"""Sign binarization function.
    \\[
    q(x) = \begin{cases}
      -1 & x < 0 \\\
      1 & x \geq 0
    \end{cases}
    \\]
    The gradient is estimated using the SignSwish method.
    \\[
    \frac{\partial q_{\beta}(x)}{\partial x} = \frac{\beta\left\\{2-\beta x \tanh \left(\frac{\beta x}{2}\right)\right\\}}{1+\cosh (\beta x)}
    \\]
    """

    @tf.custom_gradient
    def _call(x):
        def grad(dy):
            # Derivative of the SignSwish surrogate given above.
            b_x = beta * x
            return dy * beta * (2 - b_x * tf.tanh(b_x / 2)) / (1 + tf.cosh(b_x))

        return lq.math.sign(x), grad

    return _call(x)
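
# Sanity check at x = 0 (hypothetical values, assuming the sketched body):
# the derivative above reduces to beta * 2 / (1 + cosh(0)) = beta, here 5.
inputs = tf.constant([0.0])
with tf.GradientTape() as tape:
    tape.watch(inputs)
    outputs = swish_sign(inputs)
print(tape.gradient(outputs, inputs).numpy())  # [5.]
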
@utils.set_precision(1)
def magnitude_aware_sign(x, clip_value=1.0):
    r"""Magnitude-aware sign for Bi-Real Net.
    A scaled sign function computed according to Section 3.3 in
    [Zechun Liu et al](https://arxiv.org/abs/1808.00278).
    ```plot-activation
    quantizers._scaled_sign
    ```
    # Arguments
    x: Input tensor
    clip_value: Threshold for clipping gradients. If `None` gradients are not clipped.
    # Returns
    Scaled binarized tensor (with values in \\(\\{-a, a\\}\\), where \\(a\\) is a float).
    """
    # Per-channel scale (mean |x| over all but the last axis; the exact
    # reduction axes here are an assumption).
    scale = tf.stop_gradient(tf.reduce_mean(tf.abs(x), axis=list(range(len(x.shape) - 1))))

    @tf.custom_gradient
    def _call(x):
        def grad(dy):
            return _clipped_gradient(x, dy, clip_value)

        return scale * lq.math.sign(x), grad

    return _call(x)
@utils.set_precision(1)
def ste_heaviside(x, clip_value=1.0):
    r"""
    Binarization function with output values 0 and 1.
    \\[
    q(x) = \begin{cases}
      +1 & x > 0 \\\
      0 & x \leq 0
    \end{cases}
    \\]
    The gradient is estimated using the Straight-Through Estimator
    (essentially the binarization is replaced by a clipped identity on the
    backward pass).
    \\[\frac{\partial q(x)}{\partial x} = \begin{cases}
      1 & \left|x\right| \leq \texttt{clip_value} \\\
      0 & \left|x\right| > \texttt{clip_value}
    \end{cases}\\]
    """

    @tf.custom_gradient
    def _call(x):
        def grad(dy):
            return _clipped_gradient(x, dy, clip_value)

        # Heaviside step: 1 for x > 0, else 0.
        return tf.where(x > 0, tf.ones_like(x), tf.zeros_like(x)), grad

    return _call(x)
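
# Forward/backward behaviour (hypothetical values): outputs in {0, 1},
# gradient once more a clipped identity.
inputs = tf.constant([-2.0, -0.5, 0.5, 2.0])
with tf.GradientTape() as tape:
    tape.watch(inputs)
    outputs = ste_heaviside(inputs)
print(outputs.numpy())                         # [0. 0. 1. 1.]
print(tape.gradient(outputs, inputs).numpy())  # [0. 1. 1. 0.]
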