Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
/**
 * Builds the ACGAN "combined" model that is used to train the generator.
 *
 * How the model is wired:
 * - A latent vector and a desired class label are fed through `generator`.
 * - The generated images are then passed through `discriminator`.
 * - The model outputs the discriminator's two results for the generated
 *   images: the realness score and the auxiliary class prediction.
 *
 * Side effect: sets `discriminator.trainable = false` on the model passed in.
 * The intent is that training the combined model only trains generation.
 *
 * @param {number} latentSize Length of the latent input vector.
 * @param {tf.LayersModel} generator Model applied to `[latent, imageClass]`.
 * @param {tf.LayersModel} discriminator Model whose `apply()` yields
 *     `[realnessScore, aux]`.
 * @param {tf.Optimizer|string} optimizer Optimizer handed to `compile()`.
 * @returns {tf.LayersModel} The compiled combined model.
 */
function buildCombinedModel(latentSize, generator, discriminator, optimizer) {
  // The two generator inputs, created in this order: the latent vector, then
  // the desired image class (a single value).
  const latentInput = tf.input({shape: [latentSize]});
  const classInput = tf.input({shape: [1]});

  // Symbolic tensor for the images the generator produces from those inputs.
  const generatedImages = generator.apply([latentInput, classInput]);

  // Freeze the discriminator so that fitting the combined model only trains
  // the generator. Then score the generated images with it.
  discriminator.trainable = false;
  const [realnessScore, auxClassProbs] = discriminator.apply(generatedImages);

  const combinedModel = tf.model({
    inputs: [latentInput, classInput],
    outputs: [realnessScore, auxClassProbs],
  });

  // One loss per output, in output order: the real/fake score first, then the
  // auxiliary class prediction (integer labels, hence the sparse variant).
  const losses = ['binaryCrossentropy', 'sparseCategoricalCrossentropy'];
  combinedModel.compile({optimizer, loss: losses});
  combinedModel.summary();
  return combinedModel;
}
// Desired image class: a single integer label per example. It is used as an
// index into the embedding table below (inputDim: NUM_CLASSES), so it must lie
// in [0, NUM_CLASSES).
const imageClass = tf.input({shape: [1]});
// The desired label is converted to a vector of length `latentSize`
// through embedding lookup.
const classEmbedding = tf.layers.embedding({
inputDim: NUM_CLASSES,
outputDim: latentSize,
embeddingsInitializer: 'glorotNormal'
}).apply(imageClass);
// Hadamard (element-wise) product between the latent z-space input and the
// class-conditional embedding, so the generator's input depends on both the
// noise and the requested class.
// NOTE(review): `latent` is defined earlier in this function (not visible
// in this excerpt); presumably it is the latent tf.input — TODO confirm.
const h = tf.layers.multiply().apply([latent, classEmbedding]);
// NOTE(review): `cnn` is defined earlier in this function (not visible in
// this excerpt); presumably it maps `h` to an image — TODO confirm.
const fakeImage = cnn.apply(h);
// The generator model takes [latent, imageClass] and outputs `fakeImage`.
return tf.model({inputs: [latent, imageClass], outputs: fakeImage});
}
// Unlike most TensorFlow.js models, the discriminator has two outputs.
// The 1st output is the probability score (sigmoid, in (0, 1)) assigned by the
// discriminator to how likely the input example is a real MNIST image (as
// opposed to a "fake" one generated by the generator).
// NOTE(review): `features` and `image` are defined earlier in this function
// (not visible in this excerpt). Presumably `image` is the model input and
// `features` is the feature tensor extracted from it — TODO confirm.
const realnessScore =
tf.layers.dense({units: 1, activation: 'sigmoid'}).apply(features);
// The 2nd output is the softmax probabilities assigned by the discriminator
// to the 10 MNIST digit classes (0 through 9). "aux" stands for "auxiliary"
// (the namesake of ACGAN). It reflects that, unlike a standard GAN (which
// performs just binary real/fake classification), the discriminator part of
// ACGAN also performs multi-class classification.
const aux = tf.layers.dense({units: NUM_CLASSES, activation: 'softmax'})
.apply(features);
// Output order [realnessScore, aux] matters: callers destructure it
// positionally.
return tf.model({inputs: image, outputs: [realnessScore, aux]});
}