Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_ctgan_numpy():
    """Fit CTGAN on a bare numpy array and check the sampled array round-trips.

    The training frame is converted with ``.values``, so no column names are
    passed. The discrete column is therefore identified by its position (1).
    """
    # Column order matters: 'continuous' is index 0, 'discrete' is index 1.
    frame = pd.DataFrame({
        'continuous': np.random.random(100),
        'discrete': np.random.choice(['a', 'b', 'c'], 100),
    })
    synthesizer = CTGANSynthesizer()
    synthesizer.fit(frame.values, [1], epochs=1)

    output = synthesizer.sample(100)

    assert output.shape == (100, 2)
    assert isinstance(output, np.ndarray)
    # Every training category must appear, and nothing else.
    observed_categories = set(np.unique(output[:, 1]))
    assert observed_categories == {'a', 'b', 'c'}
def test_ctgan_dataframe():
    """Fit CTGAN on a pandas DataFrame and check the sampled frame round-trips.

    The discrete column is identified by name. The sample is expected to
    come back as a DataFrame with the same column names.
    """
    frame = pd.DataFrame({
        'continuous': np.random.random(100),
        'discrete': np.random.choice(['a', 'b', 'c'], 100),
    })
    synthesizer = CTGANSynthesizer()
    synthesizer.fit(frame, ['discrete'], epochs=1)

    output = synthesizer.sample(100)

    assert output.shape == (100, 2)
    assert isinstance(output, pd.DataFrame)
    assert set(output.columns) == {'continuous', 'discrete'}
    # Every training category must appear, and nothing else.
    observed_categories = set(output['discrete'].unique())
    assert observed_categories == {'a', 'b', 'c'}
def main():
    """Command-line entry point: load data, fit a CTGANSynthesizer, write samples.

    Behaviour is driven by the parsed arguments:

    - When ``args.tsv`` is truthy, the input is read with ``read_tsv`` and the
      output is written with ``write_tsv``; both use ``args.metadata``.
    - Otherwise, the input is read with ``read_csv`` and the output is written
      with ``DataFrame.to_csv(index=False)``.
    - ``args.gen_dims`` and ``args.dis_dims`` are comma-separated integer
      strings that give the generator and discriminator layer sizes.
    - ``args.num_samples`` sets how many rows to sample. When it is falsy
      (e.g. ``None`` or ``0``), the number of training rows is used instead.
    """
    args = _parse_args()

    loaded = (
        read_tsv(args.data, args.metadata)
        if args.tsv
        else read_csv(args.data, args.metadata, args.header, args.discrete)
    )
    data, discrete_columns = loaded

    # Split the comma-separated layer sizes into lists of ints.
    hyperparams = dict(
        gen_dims=list(map(int, args.gen_dims.split(','))),
        dis_dims=list(map(int, args.dis_dims.split(','))),
    )
    hyperparams.update(
        z_dim=args.z_dim,
        gen_lr=args.gen_lr,
        gen_decay=args.gen_decay,
        dis_lr=args.dis_lr,
        dis_decay=args.dis_decay,
        batch_size=args.bs,
    )
    model = CTGANSynthesizer(**hyperparams)
    model.fit(data, discrete_columns, args.epochs)

    sample_count = args.num_samples or len(data)
    synthetic = model.sample(sample_count)

    if not args.tsv:
        synthetic.to_csv(args.output, index=False)
    else:
        write_tsv(synthetic, args.metadata, args.output)