How to use the ivis.nn.network.triplet_network function in ivis

To help you get started, we’ve selected a few ivis examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source: beringresearch/ivis on GitHub — tests/nn/test_network.py
def test_triplet_network():
    """Check the layout of the model returned by ``triplet_network``.

    A one-layer ``Sequential`` base model is wrapped by ``triplet_network`` and
    the resulting model is expected to have:

    * three leading ``InputLayer`` objects (presumably one per triplet member —
      confirm against ``triplet_network``),
    * an encoder layer at index 3 whose output is ``(None, embedding_dims)``,
    * an encoder whose first weight tensor equals the base model's first
      weight tensor (same shape and same values).
    """
    X = np.zeros(shape=(10, 5))
    embedding_dims = 3

    base_model = Sequential()
    base_model.add(Dense(8, input_shape=(X.shape[-1],)))

    # Only the model itself is inspected; the other three return values are ignored.
    model, _, _, _ = triplet_network(base_model, embedding_dims=embedding_dims, embedding_l2=0.1)
    input_layers = model.layers[:3]
    encoder = model.layers[3]

    # Compare against embedding_dims rather than a literal so the check cannot
    # drift out of sync with the fixture value above.
    assert encoder.output_shape == (None, embedding_dims)
    # array_equal also requires matching shapes; a plain `a == b` broadcasts and
    # could pass for differently-shaped but broadcast-compatible tensors.
    assert np.array_equal(base_model.get_weights()[0], encoder.get_weights()[0])
    assert all(isinstance(layer, keras.layers.InputLayer) for layer in input_layers)
Source: beringresearch/ivis on GitHub — ivis/ivis.py
loss_monitor = 'loss'
        try:
            triplet_loss_func = triplet_loss(distance=self.distance,
                                             margin=self.margin)
        except KeyError:
            raise ValueError('Loss function `{}` not implemented.'.format(self.distance))

        if self.model_ is None:
            if type(self.model_def) is str:
                input_size = (X.shape[-1],)
                self.model_, anchor_embedding, _, _ = \
                    triplet_network(base_network(self.model_def, input_size),
                                    embedding_dims=self.embedding_dims)
            else:
                self.model_, anchor_embedding, _, _ = \
                    triplet_network(self.model_def,
                                    embedding_dims=self.embedding_dims)

            if Y is None:
                self.model_.compile(optimizer='adam', loss=triplet_loss_func)
            else:
                if is_categorical(self.supervision_metric):
                    if not is_multiclass(self.supervision_metric):
                        if not is_hinge(self.supervision_metric):
                            # Binary logistic classifier
                            if len(Y.shape) > 1:
                                self.n_classes = Y.shape[-1]
                            else:
                                self.n_classes = 1
                            supervised_output = Dense(self.n_classes, activation='sigmoid',
                                                      name='supervised')(anchor_embedding)
                        else: