# Fragment of a check_sampler_csv configuration test: csv_good points at a
# known-good sampler CSV (set up in the omitted part of the test), and each
# call below is expected to fail because the declared configuration disagrees
# with the file's header.
with self.assertRaisesRegex(
    ValueError, 'config error, expected thin = 2'
):
    check_sampler_csv(
        path=csv_good, iter_warmup=100, iter_sampling=20, thin=2
    )
with self.assertRaisesRegex(
    ValueError, 'config error, expected save_warmup'
):
    check_sampler_csv(
        path=csv_good,
        iter_warmup=100,
        iter_sampling=10,
        save_warmup=True,
    )
with self.assertRaisesRegex(ValueError, 'expected 1000 draws'):
    check_sampler_csv(path=csv_good, iter_warmup=100)
# Fragment of a thinned-run test: `dict` holds the result of an earlier
# check_sampler_csv call on `csv_file` (omitted here), a run thinned by 7.
self.assertEqual(dict['thin'], 7)
self.assertEqual(dict['draws_sampling'], 70)
self.assertEqual(dict['seed'], 12345)
self.assertEqual(dict['max_depth'], 11)
self.assertEqual(dict['delta'], 0.98)
with self.assertRaisesRegex(ValueError, 'config error'):
    check_sampler_csv(
        path=csv_file,
        is_fixed_param=False,
        iter_sampling=490,
        iter_warmup=490,
        thin=9,
    )
with self.assertRaisesRegex(ValueError, 'expected 490 draws, found 70'):
    check_sampler_csv(
        path=csv_file,
        is_fixed_param=False,
        iter_sampling=490,
        iter_warmup=490,
    )
def test_check_sampler_csv_2(self):
    csv_bad = os.path.join(DATAFILES_PATH, 'no_such_file.csv')
    with self.assertRaises(Exception):
        check_sampler_csv(csv_bad)

def test_check_sampler_csv_metric_1(self):
    csv_bad = os.path.join(DATAFILES_PATH, 'output_bad_metric_1.csv')
    with self.assertRaisesRegex(Exception, 'expecting metric'):
        check_sampler_csv(csv_bad)

def test_check_sampler_csv_3(self):
    csv_bad = os.path.join(DATAFILES_PATH, 'output_bad_cols.csv')
    with self.assertRaisesRegex(Exception, '8 items'):
        check_sampler_csv(csv_bad)

def test_check_sampler_csv_metric_3(self):
    csv_bad = os.path.join(DATAFILES_PATH, 'output_bad_metric_3.csv')
    with self.assertRaisesRegex(
        Exception, 'invalid or missing mass matrix specification'
    ):
        check_sampler_csv(csv_bad)
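
For reference, a minimal sketch of calling check_sampler_csv directly: the file name and the import path are assumptions, but the keyword arguments and the returned configuration keys (thin, draws_sampling, seed) are the ones the tests above exercise; the call raises ValueError when the declared arguments disagree with the file's header.

import os
from cmdstanpy.utils import check_sampler_csv  # import path assumed

# Hypothetical sampler output file; any CmdStan sampler CSV under
# DATAFILES_PATH would do.
csv_path = os.path.join(DATAFILES_PATH, 'bernoulli_output_1.csv')
meta = check_sampler_csv(
    path=csv_path,
    is_fixed_param=False,
    iter_warmup=100,
    iter_sampling=10,
    thin=1,
)
print(meta['thin'], meta['draws_sampling'], meta['seed'])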
def _validate_csv_files(self) -> None:
    """
    Checks that csv output files for all chains are consistent.
    Populates attributes for draws, column_names, num_params, metric_type.
    Raises exception when inconsistencies detected.
    """
    dzero = {}
    for i in range(self.runset.chains):
        if i == 0:
            # The first chain's header is the reference configuration.
            dzero = check_sampler_csv(
                path=self.runset.csv_files[i],
                is_fixed_param=self._is_fixed_param,
                iter_sampling=self._iter_sampling,
                iter_warmup=self._iter_warmup,
                save_warmup=self._save_warmup,
                thin=self._thin,
            )
        else:
            drest = check_sampler_csv(
                path=self.runset.csv_files[i],
                is_fixed_param=self._is_fixed_param,
                iter_sampling=self._iter_sampling,
                iter_warmup=self._iter_warmup,
                save_warmup=self._save_warmup,
                thin=self._thin,
            )
            # Every later chain must agree with the first on all config
            # entries except those that legitimately differ per chain.
            for key in dzero:
                if (
                    key not in ['id', 'diagnostic_file']
                    and dzero[key] != drest[key]
                ):
                    raise ValueError(
                        'csv file header mismatch, '
                        'file {}, key {} is {}, expected {}'.format(
                            self.runset.csv_files[i],
                            # trailing arguments truncated in the snippet;
                            # assumed to be the key and the mismatched values
                            key,
                            drest[key],
                            dzero[key],
                        )
                    )
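
The same check can be sketched outside the class. A minimal illustration, assuming hypothetical per-chain file names and the same check_sampler_csv import as above, of comparing each chain's header dict against the first chain's while ignoring the per-chain keys 'id' and 'diagnostic_file':

from cmdstanpy.utils import check_sampler_csv  # import path assumed

# Hypothetical per-chain output files from one sampler run.
csv_files = ['run_chain_1.csv', 'run_chain_2.csv', 'run_chain_3.csv']
per_chain_keys = ['id', 'diagnostic_file']  # allowed to differ across chains

dzero = check_sampler_csv(path=csv_files[0])
for path in csv_files[1:]:
    drest = check_sampler_csv(path=path)
    mismatches = {
        key: (drest[key], dzero[key])
        for key in dzero
        if key not in per_chain_keys and dzero[key] != drest[key]
    }
    if mismatches:
        raise ValueError(
            'csv file header mismatch, file {}: {}'.format(path, mismatches)
        )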