Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_rload_data(self):
    """rload parses ``rdump_test.data.R`` into two scalars and a 2-D array."""
    source_path = os.path.join(DATAFILES_PATH, 'rdump_test.data.R')
    parsed = rload(source_path)
    # Scalar entries.
    self.assertEqual(parsed['N'], 128)
    self.assertEqual(parsed['M'], 2)
    # Array entry 'x' is expected to come back with shape (128, 2).
    self.assertEqual(parsed['x'].shape, (128, 2))
def test_roundtrip_metric(self):
    """Dumping an rload result with rdump and reloading it preserves inv_metric.

    Loads ``metric_diag.data.R``, writes it to a temporary R dump file with
    ``rdump``, reloads that file, and checks that ``inv_metric`` survives
    with the same shape and values as the source data.
    """
    # NOTE(review): test_roundtrip_metric is defined twice in this file; if
    # both definitions live in the same class, only the later one runs.
    dfile = os.path.join(DATAFILES_PATH, 'metric_diag.data.R')
    data_dict_1 = rload(dfile)
    self.assertEqual(data_dict_1['inv_metric'].shape, (3,))

    dfile_tmp = os.path.join(DATAFILES_PATH, 'tmp.data.R')
    try:
        rdump(dfile_tmp, data_dict_1)
        data_dict_2 = rload(dfile_tmp)

        self.assertIn('inv_metric', data_dict_2)
        expected = data_dict_1['inv_metric']
        actual = data_dict_2['inv_metric']
        self.assertEqual(actual.shape, expected.shape)
        # Compare the reloaded values against the *source* data (comparing
        # data_dict_2 with itself can never fail). The values pass through a
        # text file whose float formatting is not visible here, so compare
        # approximately rather than bit-for-bit.
        for want, got in zip(expected, actual):
            self.assertAlmostEqual(got, want)
    finally:
        # Always remove the temp file, even when an assertion fails, so no
        # stale tmp.data.R is left in the shared test-data directory.
        if os.path.exists(dfile_tmp):
            os.remove(dfile_tmp)
def test_rload_wrong_data(self):
    """rload returns None when pointed at a JSON file instead of an R dump."""
    dfile = os.path.join(DATAFILES_PATH, 'metric_diag.data.json')
    data_dict = rload(dfile)
    # Identity check against None is the idiomatic form and gives a clearer
    # failure message than assertEqual(data_dict, None).
    self.assertIsNone(data_dict)
def test_rload_bad_data_2(self):
    """rload must reject the ``rdump_bad_2.data.R`` fixture with ValueError."""
    bad_path = os.path.join(DATAFILES_PATH, 'rdump_bad_2.data.R')
    self.assertRaises(ValueError, rload, bad_path)
def test_rload_jags_data(self):
    """rload reads the ``rdump_jags.data.R`` fixture.

    Presumably this fixture uses JAGS-style dump syntax (inferred from the
    file name only).
    """
    jags_path = os.path.join(DATAFILES_PATH, 'rdump_jags.data.R')
    contents = rload(jags_path)
    expected_scalars = (('N', 128), ('M', 2))
    for key, value in expected_scalars:
        self.assertEqual(contents[key], value)
    # 'y' is expected to load as a 1-D array of length 128.
    self.assertEqual(contents['y'].shape, (128,))
def test_rload_bad_data_3(self):
    """rload must reject the ``rdump_bad_3.data.R`` fixture with ValueError."""
    bad_path = os.path.join(DATAFILES_PATH, 'rdump_bad_3.data.R')
    self.assertRaises(ValueError, rload, bad_path)
def test_rload_metric(self):
    """rload returns inv_metric with the expected shape for each metric fixture.

    The diagonal fixture yields a length-3 vector; the dense fixture yields
    a 3x3 array.
    """
    cases = (
        ('metric_diag.data.R', (3,)),
        ('metric_dense.data.R', (3, 3)),
    )
    for fname, expected_shape in cases:
        loaded = rload(os.path.join(DATAFILES_PATH, fname))
        self.assertEqual(loaded['inv_metric'].shape, expected_shape)
def test_rload_bad_data_1(self):
    """rload must reject the ``rdump_bad_1.data.R`` fixture with ValueError."""
    bad_path = os.path.join(DATAFILES_PATH, 'rdump_bad_1.data.R')
    self.assertRaises(ValueError, rload, bad_path)
def test_roundtrip_metric(self):
    """Dumping an rload result with rdump and reloading it preserves inv_metric.

    Loads ``metric_diag.data.R``, writes it to a temporary R dump file with
    ``rdump``, reloads that file, and checks that ``inv_metric`` survives
    with the same shape and values as the source data.
    """
    # NOTE(review): test_roundtrip_metric is defined twice in this file; if
    # both definitions live in the same class, this later one shadows the
    # earlier one. Consider deleting the duplicate.
    dfile = os.path.join(DATAFILES_PATH, 'metric_diag.data.R')
    data_dict_1 = rload(dfile)
    self.assertEqual(data_dict_1['inv_metric'].shape, (3,))

    dfile_tmp = os.path.join(DATAFILES_PATH, 'tmp.data.R')
    try:
        rdump(dfile_tmp, data_dict_1)
        data_dict_2 = rload(dfile_tmp)

        self.assertIn('inv_metric', data_dict_2)
        expected = data_dict_1['inv_metric']
        actual = data_dict_2['inv_metric']
        self.assertEqual(actual.shape, expected.shape)
        # Compare the reloaded values against the *source* data (comparing
        # data_dict_2 with itself can never fail). The values pass through a
        # text file whose float formatting is not visible here, so compare
        # approximately rather than bit-for-bit.
        for want, got in zip(expected, actual):
            self.assertAlmostEqual(got, want)
    finally:
        # Always remove the temp file, even when an assertion fails, so no
        # stale tmp.data.R is left in the shared test-data directory.
        if os.path.exists(dfile_tmp):
            os.remove(dfile_tmp)