How to use the elephas.ml_model.ElephasTransformer class in elephas

To help you get started, we’ve selected a few elephas examples, based on popular ways it is used in public projects.

From maxpumperla/elephas, tests/test_ml_model.py:
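from elephas.ml_model import ElephasTransformer, load_ml_transformer

def test_serialization_transformer():
    # `model` is a Keras model defined elsewhere in the test module; note that
    # Model.to_yaml() raises an error since TensorFlow 2.6
    transformer = ElephasTransformer()
    transformer.set_keras_model_config(model.to_yaml())
    transformer.save("test.h5")
    load_ml_transformer("test.h5")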
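
The test above only checks that saving and loading do not raise. Below is a minimal, stdlib-only sketch of the same round trip that also checks equality. `ToyTransformer`, `save_config` and `load_config` are hypothetical stand-ins, not elephas APIs. A JSON file replaces the HDF5 file that elephas writes, and the saved params are wrapped under a `config` key, which is where `load_ml_transformer` looks for them.

```python
import json
import os
import tempfile


class ToyTransformer:
    """Hypothetical stand-in: a transformer whose state is a dict of params."""

    def __init__(self, **params):
        self.params = dict(params)

    def get_config(self):
        return dict(self.params)


def save_config(transformer, file_name):
    # Wrap the params under "config", as load_ml_transformer expects to find them
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump({"config": transformer.get_config()}, f)


def load_config(file_name):
    # Rebuild the transformer from the stored keyword arguments
    with open(file_name, encoding="utf-8") as f:
        conf = json.load(f)
    return ToyTransformer(**conf["config"])


def test_round_trip():
    original = ToyTransformer(labelCol="label", outputCol="prediction",
                              keras_model_config='{"class_name": "Sequential"}')
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "transformer.json")
        save_config(original, path)
        restored = load_config(path)
    # Unlike the original test, assert that the configuration survived
    assert restored.get_config() == original.get_config()


test_round_trip()
```

Writing into a `TemporaryDirectory` also avoids leaving a stray file behind. The original test writes `test.h5` to the working directory and never removes it.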

From maxpumperla/elephas, elephas/ml_model.py:
# Excerpt from the estimator's fit logic; keras_model, loss, metrics and
# simple_rdd are defined earlier in the same method.
optimizer = get_optimizer(self.get_optimizer_config())
keras_model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

# Train the compiled Keras model on the RDD with elephas' SparkModel
spark_model = SparkModel(model=keras_model,
                         mode=self.get_mode(),
                         frequency=self.get_frequency(),
                         num_workers=self.get_num_workers())
spark_model.fit(simple_rdd,
                epochs=self.get_epochs(),
                batch_size=self.get_batch_size(),
                verbose=self.get_verbosity(),
                validation_split=self.get_validation_split())

# Broadcast the trained weights and wrap them, with the model config,
# in a new ElephasTransformer
model_weights = spark_model.master_network.get_weights()
weights = simple_rdd.ctx.broadcast(model_weights)
return ElephasTransformer(labelCol=self.getLabelCol(),
                          outputCol='prediction',
                          keras_model_config=spark_model.master_network.to_yaml(),
                          weights=weights)
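
This excerpt follows the Spark ML pattern in which an estimator's fit step returns a separate transformer. That transformer carries the learned state, which here is the broadcast weights plus the model config. The stdlib-only sketch below shows the same pattern with a trivial "model" that only learns the mean label. `MeanEstimator` and `MeanTransformer` are hypothetical stand-ins for the estimator in the excerpt and for ElephasTransformer. No Spark or Keras is involved.

```python
from statistics import mean


class MeanTransformer:
    """Hypothetical transformer: holds learned state and applies it to rows."""

    def __init__(self, outputCol, weights):
        self.outputCol = outputCol
        self.weights = weights  # learned state handed over by the estimator

    def transform(self, rows):
        # Add a prediction column to each row using the learned value
        return [{**row, self.outputCol: self.weights} for row in rows]


class MeanEstimator:
    """Hypothetical estimator: fit() learns state and returns a transformer."""

    def __init__(self, labelCol):
        self.labelCol = labelCol

    def fit(self, rows):
        learned = mean(row[self.labelCol] for row in rows)
        # Like the excerpt, return a new transformer instead of mutating self
        return MeanTransformer(outputCol="prediction", weights=learned)


train = [{"x": 1, "label": 2.0}, {"x": 2, "label": 4.0}]
fitted = MeanEstimator(labelCol="label").fit(train)
print(fitted.transform([{"x": 3}]))  # [{'x': 3, 'prediction': 3.0}]
```

Keeping the fitted state in a separate object means the estimator can be refit without affecting transformers it has already produced.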

From maxpumperla/elephas, elephas/ml_model.py:
def __init__(self, **kwargs):
    super(ElephasTransformer, self).__init__()
    if "weights" in kwargs:
        # Strip model weights from parameters to init Transformer
        self.weights = kwargs.pop('weights')
    self.set_params(**kwargs)
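
The `pop` matters because `set_params` is meant to receive only declared parameters. Presumably the broadcast `weights` object is not one of them, so passing it through would fail. The sketch below imitates that behaviour with a hypothetical `StrictParams` base class that rejects unknown keyword arguments. It is an illustration, not the actual `Params` machinery in pyspark or elephas. One deliberate difference: here `weights` defaults to `None` when absent, whereas the original only sets the attribute when weights are passed.

```python
class StrictParams:
    """Hypothetical base class that only accepts declared parameter names."""

    declared = {"labelCol", "outputCol", "keras_model_config"}

    def __init__(self):
        self.params = {}

    def set_params(self, **kwargs):
        unknown = set(kwargs) - self.declared
        if unknown:
            raise TypeError(f"unknown params: {sorted(unknown)}")
        self.params.update(kwargs)
        return self


class StrictTransformer(StrictParams):
    """Hypothetical transformer mirroring the __init__ shown above."""

    def __init__(self, **kwargs):
        super().__init__()
        # Keep weights out of the params, since set_params would reject them
        self.weights = kwargs.pop("weights", None)
        self.set_params(**kwargs)


t = StrictTransformer(labelCol="label", weights=[0.5, -1.0])
print(t.params, t.weights)  # {'labelCol': 'label'} [0.5, -1.0]
```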

From maxpumperla/elephas, elephas/ml_model.py:
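import json
import h5py

def load_ml_transformer(file_name):
    # Read the serialized config stored as a JSON attribute of the HDF5 file
    with h5py.File(file_name, mode='r') as f:
        elephas_conf = json.loads(f.attrs.get('distributed_config'))
    # Rebuild the transformer from its saved parameters
    return ElephasTransformer(**elephas_conf.get('config'))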
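
`load_ml_transformer` expects the `distributed_config` attribute to hold a JSON document whose `config` entry is a dict of constructor keyword arguments. The stdlib-only sketch below isolates that decoding step. `decode_distributed_config` is a hypothetical helper, and the sample payload is invented for illustration. `json.loads` accepts `bytes` as well as `str` (Python 3.6+), so the same code works whether the stored value is read back as text or as raw bytes. Only the config is passed to the constructor, so nothing in this loader restores trained weights.

```python
import json


def decode_distributed_config(raw):
    """Hypothetical helper: turn the stored attribute into constructor kwargs.

    `raw` may be a str or bytes object holding JSON such as
    {"config": {"labelCol": ..., "outputCol": ...}}.
    """
    elephas_conf = json.loads(raw)
    config = elephas_conf.get("config")
    if not isinstance(config, dict):
        # Fail early instead of hitting a confusing TypeError on **config
        raise ValueError("distributed_config has no 'config' mapping")
    return config


payload = json.dumps({"config": {"labelCol": "label", "outputCol": "prediction"}})
print(decode_distributed_config(payload.encode("utf8")))
# {'labelCol': 'label', 'outputCol': 'prediction'}
```

The explicit check replaces the `TypeError` that `ElephasTransformer(**None)` would raise if the `config` key were missing.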