Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
// Make new latent vectors.
const zVectors = tf.randomUniform([batchSize, latentSize], -1, 1);
const sampledLabels =
tf.randomUniform([batchSize, 1], 0, NUM_CLASSES, 'int32')
.asType('float32');
// We want to train the generator to trick the discriminator.
// For the generator, we want all the {fake, not-fake} labels to say
// not-fake.
const trick = tf.tidy(() => tf.ones([batchSize, 1]).mul(SOFT_ONE));
return [zVectors, sampledLabels, trick];
});
const losses = await combined.trainOnBatch(
[noise, sampledLabels], [trick, sampledLabels]);
tf.dispose([noise, sampledLabels, trick]);
return losses;
}
const generatedImages =
generator.predict([zVectors, sampledLabels], {batchSize: batchSize});
const x = tf.concat([imageBatch, generatedImages], 0);
const y = tf.tidy(
() => tf.concat(
[tf.ones([batchSize, 1]).mul(SOFT_ONE), tf.zeros([batchSize, 1])]));
const auxY = tf.concat([labelBatch, sampledLabels], 0);
return [x, y, auxY];
});
const losses = await discriminator.trainOnBatch(x, [y, auxY]);
tf.dispose([x, y, auxY]);
return losses;
}