How to use the delve.metrics.get_explained_variance function in delve

To help you get started, we’ve selected a few delve examples based on popular ways the library is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: GitHub, delve-team/delve, file delve/hooks.py (View on GitHub, external link)
def add_layer_saturation(layer: torch.nn.Module,
                         eig_vals: Optional[np.ndarray] = None,
                         n_iter: Optional[int] = None,
                         method='cumvar99'):
    training_state = get_training_state(layer)
    layer_type = layer._get_name().lower()

    if eig_vals is None:
        eig_vals = get_layer_prop(layer, f'{training_state}_eig_vals')
    if n_iter is None:
        n_iter = layer.forward_iter
    nr_eig_vals = get_explained_variance(eig_vals)

    layer_name = layer.name + (f'_{layer.conv_method}'
                               if layer_type == 'conv2d' else '')
    if method == 'cumvar99':
        saturation = get_layer_saturation(nr_eig_vals, layer.out_features)
        layer.writer.add_scalar(
            f'{training_state}-{layer_name}-percent_saturation-{method}',
            saturation, n_iter)
    elif method == 'simpson_di':
        saturation = get_eigenval_diversity_index(eig_vals)
        layer.writer.add_scalar(
            f'{training_state}-{layer_name}-percent_saturation-{method}',
            saturation, n_iter)
    elif method == 'all':
        cumvar99_saturation = get_layer_saturation(nr_eig_vals,
                                                   layer.out_features)