Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
return
for key in self.logs:
train_sats = []
val_sats = []
for i, layer_name in enumerate(self.logs[key]):
if layer_name in self.ignore_layer_names:
continue
if self.logs[key][layer_name]._cov_mtx is None:
raise ValueError("Attempting to compute intrinsic"
"dimensionality when covariance"
"is not initialized")
cov_mat = self.logs[key][layer_name].fix()
log_values = {}
for stat in self.stats:
if stat == 'lsat':
log_values[key.replace(STATMAP['cov'], STATMAP['lsat'])+'_'+layer_name] = compute_saturation(cov_mat, thresh=self.threshold)
elif stat == 'idim':
log_values[key.replace(STATMAP['cov'], STATMAP['idim'])+'_'+layer_name] = compute_intrinsic_dimensionality(cov_mat, thresh=self.threshold)
elif stat == 'cov':
log_values[key+'_'+layer_name] = cov_mat.cpu().numpy()
elif stat == 'det':
log_values[key.replace(STATMAP['cov'], STATMAP['det'])+'_'+layer_name] = compute_cov_determinant(cov_mat)
elif stat == 'trc':
log_values[key.replace(STATMAP['cov'], STATMAP['trc'])+'_'+layer_name] = compute_cov_trace(cov_mat)
elif stat == 'dtrc':
log_values[key.replace(STATMAP['cov'], STATMAP['dtrc'])+'_'+layer_name] = compute_diag_trace(cov_mat)
self.seen_samples[key.split('-')[0]][layer_name] = 0
if self.reset_covariance:
self.logs[key][layer_name]._cov_mtx = None
if self.layerwise_sat:
self.writer.add_scalars(prefix='', value_dict=log_values)
def _filter_by_stat_shortcuts(paths: List[str], stats: List[str], neg: bool = False) -> List[str]:
    """Filter ``paths`` once per stat shortcut in ``stats`` and merge the matches.

    Each shortcut in ``stats`` is translated through ``STATMAP`` and passed,
    together with ``paths`` and ``neg``, to ``_filter_by_stat``. The per-stat
    results are concatenated and then de-duplicated.

    Args:
        paths: Candidate paths forwarded to ``_filter_by_stat``.
        stats: Stat shortcuts; each one is used as a key into ``STATMAP``.
        neg: Forwarded unchanged to ``_filter_by_stat`` (presumably inverts
            the match; confirm against that helper's definition).

    Returns:
        The union of all per-stat matches without duplicates, ordered by
        first occurrence (stats in the order given, and within each stat
        the order ``_filter_by_stat`` produced).

    Raises:
        KeyError: Presumably, if a shortcut in ``stats`` is not a key of
            ``STATMAP`` (assumes ``STATMAP`` is a plain dict).
    """
    result: List[str] = []
    for stat in stats:
        result.extend(_filter_by_stat(paths, STATMAP[stat], neg))
    # dict.fromkeys keeps first-occurrence order (dicts preserve insertion
    # order). The previous list(set(...)) had an order that depended on
    # string hash randomization, so it could change from run to run.
    return list(dict.fromkeys(result))
if not isinstance(stats, list):
stats = list(stats)
supported_stats = [
'lsat',
'idim',
'cov',
'det',
'trc',
'dtrc',
]
compatible = [stat in supported_stats for stat in stats]
incompatible = [i for i, x in enumerate(compatible) if not x]
assert all(compatible), "Stat {} is not supported".format(
stats[incompatible[0]])
name_mapper = STATMAP
logs = {
f'{mode}-{name_mapper[stat]}': OrderedDict()
for mode, stat in product(['train', 'eval'], ['cov'])
}
return logs, stats