# Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
# --- Training setup (fragment of the enclosing fit-style method) ---
# Sampler draws training rows conditioned on the transformer's column metadata.
data_sampler = Sampler(train_data, self.transformer.output_info)
# Width of one transformed data row.
data_dim = self.transformer.output_dimensions
# Conditional-vector generator; n_opt below is the width of the condition
# vector it produces.
self.cond_generator = ConditionalGenerator(
train_data,
self.transformer.output_info,
log_frequency
)
# Generator input = latent noise (embedding_dim) concatenated with the
# condition vector (n_opt); output is one transformed data row (data_dim).
self.generator = Generator(
self.embedding_dim + self.cond_generator.n_opt,
self.gen_dim,
data_dim
).to(self.device)
# Discriminator sees a data row plus the same condition vector.
discriminator = Discriminator(
data_dim + self.cond_generator.n_opt,
self.dis_dim
).to(self.device)
# Generator gets L2 regularization via weight_decay; the discriminator
# deliberately does not (only the generator uses self.l2scale here).
optimizerG = optim.Adam(
self.generator.parameters(), lr=2e-4, betas=(0.5, 0.9),
weight_decay=self.l2scale
)
optimizerD = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.9))
# Batch size must be even — presumably so batches split cleanly for the
# discriminator's sample packing; confirm against Discriminator(pack=...).
# NOTE(review): assert is stripped under `python -O`; an explicit raise
# would be more robust input validation.
assert self.batch_size % 2 == 0
# Parameters of the standard-normal latent noise: zero mean, unit std
# (std = mean + 1 yields a tensor of ones), presumably consumed by
# torch.normal inside the training loop — body not visible in this chunk.
mean = torch.zeros(self.batch_size, self.embedding_dim, device=self.device)
std = mean + 1
# At least one optimization step per epoch, even for tiny datasets.
steps_per_epoch = max(len(train_data) // self.batch_size, 1)
# Per-epoch training loop; its body continues beyond this chunk.
for i in range(epochs):
def __init__(self, input_dim, dis_dims, pack=10):
    """Build the discriminator MLP.

    `pack` rows are concatenated and scored as a single input, so the
    first layer's width is ``input_dim * pack``.

    Args:
        input_dim: width of a single (un-packed) input row.
        dis_dims: iterable of hidden-layer widths.
        pack: number of rows packed into one discriminator input.
    """
    super(Discriminator, self).__init__()
    packed_width = input_dim * pack
    self.pack = pack
    self.packdim = packed_width

    # Stack Linear -> LeakyReLU -> Dropout for each hidden width,
    # then a final Linear head producing one score per packed input.
    layers = []
    width = packed_width
    for hidden in list(dis_dims):
        layers.extend([Linear(width, hidden), LeakyReLU(0.2), Dropout(0.5)])
        width = hidden
    layers.append(Linear(width, 1))
    self.seq = Sequential(*layers)