How to use the elephas.spark_model.SparkModel class in elephas

To help you get started, we’ve selected a few elephas examples based on popular ways it is used in public projects.
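
As a quick orientation before the project snippets below, here is a minimal end-to-end sketch of the typical SparkModel workflow: compile a Keras model, convert NumPy data into an RDD with to_simple_rdd, wrap the model in SparkModel, and call fit. The data, layer sizes, and hyperparameters are purely illustrative, and depending on your elephas version the Keras imports may come from keras or tensorflow.keras.

import numpy as np
from pyspark import SparkConf, SparkContext
from keras.models import Sequential
from keras.layers import Dense
from elephas.spark_model import SparkModel
from elephas.utils.rdd_utils import to_simple_rdd

# Local Spark context for experimentation; cluster deployments configure this differently
sc = SparkContext(conf=SparkConf().setAppName('elephas_sketch').setMaster('local[2]'))

# Toy data standing in for a real dataset
x_train = np.random.rand(1000, 784)
y_train = np.eye(10)[np.random.randint(0, 10, 1000)]  # one-hot labels

model = Sequential([Dense(64, activation='relu', input_dim=784),
                    Dense(10, activation='softmax')])
model.compile(optimizer='rmsprop', loss='categorical_crossentropy',
              metrics=['accuracy'])

# Distribute the training data as an RDD of (features, label) pairs
rdd = to_simple_rdd(sc, x_train, y_train)

# Asynchronous training, parameters synchronized once per epoch
spark_model = SparkModel(model, frequency='epoch', mode='asynchronous')
spark_model.fit(rdd, epochs=5, batch_size=32, verbose=0, validation_split=0.1)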


github maxpumperla / elephas / tests / test_spark_model.py
# synchronous batch
spark_model = SparkModel(model, frequency='batch',
                         mode='synchronous', num_workers=2)
spark_model.fit(rdd, epochs=epochs, batch_size=batch_size,
                verbose=2, validation_split=0.1)
score = spark_model.master_network.evaluate(x_test, y_test, verbose=2)
print('Test accuracy:', score[1])

# async epoch
spark_model = SparkModel(model, frequency='epoch', mode='asynchronous')
spark_model.fit(rdd, epochs=epochs, batch_size=batch_size,
                verbose=2, validation_split=0.1)
score = spark_model.master_network.evaluate(x_test, y_test, verbose=2)
print('Test accuracy:', score[1])

# hog wild epoch
spark_model = SparkModel(model, frequency='epoch', mode='hogwild')
spark_model.fit(rdd, epochs=epochs, batch_size=batch_size,
                verbose=2, validation_split=0.1)
score = spark_model.master_network.evaluate(x_test, y_test, verbose=2)
print('Test accuracy:', score[1])
github maxpumperla / elephas / tests / test_model_serialization.py
# This returns a tensor
inputs = Input(shape=(784,))

# A layer instance is callable on a tensor, and returns a tensor
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)

# This creates a model that includes
# the Input layer and three Dense layers
model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Additional keyword arguments (here foo="bar") are passed through to SparkModel
spark_model = SparkModel(model, frequency='epoch',
                         mode='synchronous', foo="bar")
spark_model.save("elephas_model.h5")
github maxpumperla / elephas / tests / test_model_serialization.py
# Spark context is provided by the "spark_context" pytest fixture
pytest.mark.usefixtures("spark_context")

seq_model = Sequential()
seq_model.add(Dense(128, input_dim=784))
seq_model.add(Activation('relu'))
seq_model.add(Dropout(0.2))
seq_model.add(Dense(128))
seq_model.add(Activation('relu'))
seq_model.add(Dropout(0.2))
seq_model.add(Dense(10))
seq_model.add(Activation('softmax'))

seq_model.compile(
    optimizer="sgd", loss="categorical_crossentropy", metrics=["acc"])
spark_model = SparkModel(seq_model, frequency='epoch', mode='synchronous')
spark_model.save("elephas_sequential.h5")
github maxpumperla / elephas / examples / mnist_mlp_spark.py
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(10))
model.add(Activation('softmax'))

sgd = SGD(lr=0.1)
model.compile(sgd, 'categorical_crossentropy', ['acc'])

# Build RDD from numpy features and labels
rdd = to_simple_rdd(sc, x_train, y_train)

# Initialize SparkModel from Keras model and Spark context
spark_model = SparkModel(model, frequency='epoch', mode='asynchronous')

# Train Spark model
spark_model.fit(rdd, epochs=epochs, batch_size=batch_size, verbose=2, validation_split=0.1)

# Evaluate Spark model by evaluating the underlying model
score = spark_model.master_network.evaluate(x_test, y_test, verbose=2)
print('Test accuracy:', score[1])
github abhishekakumar / GatorSquad / SparkElephasModel.py
# Obtaining 3D training and testing vectors
(feature_train, label_train), (feature_test, label_test) = lstm.train_test_split(modelFeatures, modelLabel, trainSize, timeSteps)

# Check whether failure cases exist in the data
if len(feature_train) == 0:
    print("DiskModel has no failure elements. Training of the model cannot proceed!!")
    return

# Initializing the Adam optimizer for Elephas
adam = elephas_optimizers.Adam()
print("Adam Optimizer initialized")

# Converting the DataFrame to a Spark RDD
rddataset = to_simple_rdd(sc, feature_train, label_train)
print("Training data converted into Resilient Distributed Dataset")

# Initializing the SparkModel with optimizer, master-worker mode and number of workers
spark_model = SparkModel(sc, lstmModel, optimizer=adam, frequency='epoch', mode='asynchronous', num_workers=2)
print("Spark Model Initialized")

# Initial training run of the model
spark_model.train(rddataset, nb_epoch=10, batch_size=200, verbose=1, validation_split=0)

# Evaluating the model on the test data
score = spark_model.evaluate(feature_test, label_test, show_accuracy=True)

while score <= 0.5:
    # Re-training on the input data set until the score exceeds 0.5
    spark_model.train(rddataset, nb_epoch=10, batch_size=200, verbose=1, validation_split=0)
    print("LSTM model training done !!")
    score = spark_model.evaluate(feature_test, label_test, show_accuracy=True)

print("Saving weights!!")
outFilePath = os.environ.get('GATOR_SQUAD_HOME')
outFilePath = outFilePath + "Weights/" + str(year) + "/" + str(month) + "/" + str(modelName) + "_my_model_weights.h5"
spark_model.save_weights(outFilePath)
print("LSTM model testing commencing !!")
github maxpumperla / elephas / elephas / spark_model.py
def load_spark_model(file_name):
    model = load_model(file_name)
    f = h5py.File(file_name, mode='r')

    elephas_conf = json.loads(f.attrs.get('distributed_config'))
    class_name = elephas_conf.get('class_name')
    config = elephas_conf.get('config')
    if class_name == "SparkModel":
        return SparkModel(model=model, **config)
    elif class_name == "SparkMLlibModel":
        return SparkMLlibModel(model=model, **config)


class SparkMLlibModel(SparkModel):

    def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_server_mode='http',
                 num_workers=4, elephas_optimizer=None, custom_objects=None, batch_size=32, *args, **kwargs):
        """SparkMLlibModel

        The Spark MLlib model takes RDDs of LabeledPoints for training.

        :param model: Compiled Keras model
        :param mode: String, choose from `asynchronous`, `synchronous` and `hogwild`
        :param frequency: String, either `epoch` or `batch`
        :param parameter_server_mode: String, either `http` or `socket`
        :param num_workers: int, number of workers used for training (defaults to 4)
        :param elephas_optimizer: Elephas optimizer
        :param custom_objects: Keras custom objects
        """
        SparkModel.__init__(self, model=model, mode=mode, frequency=frequency,
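
Since the docstring notes that SparkMLlibModel trains on RDDs of LabeledPoints, a minimal usage sketch might look as follows. The to_labeled_point helper and the exact train signature vary between elephas versions, so treat them as assumptions rather than a verified recipe.

from elephas.spark_model import SparkMLlibModel
from elephas.utils.rdd_utils import to_labeled_point

# Build an RDD of MLlib LabeledPoints from NumPy arrays (x_train and y_train assumed to exist)
lp_rdd = to_labeled_point(sc, x_train, y_train, categorical=True)

# Wrap the compiled Keras model and train on the LabeledPoint RDD
mllib_model = SparkMLlibModel(model, frequency='epoch', mode='synchronous', num_workers=2)
mllib_model.train(lp_rdd, epochs=5, batch_size=32, verbose=0,
                  validation_split=0.1, categorical=True, nb_classes=10)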
github maxpumperla / elephas / elephas / dl4j.py
from .spark_model import SparkModel
try:
    from elephas.java import java_classes, adapter
except:
    raise Exception("Warning: java classes couldn't be loaded.")


class ParameterAveragingModel(SparkModel):
    def __init__(self, java_spark_context, model, num_workers, batch_size, averaging_frequency=5,
                 num_batches_prefetch=0, collect_stats=False, save_file='temp.h5', *args, **kwargs):
        """ParameterAveragingModel

         :param java_spark_context: JavaSparkContext, initialized through pyjnius
         :param model: compiled Keras model
         :param num_workers: number of Spark workers/executors
         :param batch_size: batch size used for model training
         :param averaging_frequency: int, after how many batches parameter averaging takes place
         :param num_batches_prefetch: int, how many batches to pre-fetch; deactivated if 0
         :param collect_stats: boolean, whether statistics are collected during training
         :param save_file: where to store the elephas model temporarily
         """
        SparkModel.__init__(self, model=model, batch_size=batch_size, mode='synchronous',
                            averaging_frequency=averaging_frequency, num_batches_prefetch=num_batches_prefetch,
                            num_workers=num_workers, collect_stats=collect_stats, *args, **kwargs)
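
For completeness, a hypothetical instantiation of ParameterAveragingModel is sketched below. It assumes a pyjnius-backed JavaSparkContext named jsc and a compiled Keras model are already available; obtaining the Java context is environment-specific and not shown here.

from elephas.dl4j import ParameterAveragingModel

# jsc: JavaSparkContext obtained through pyjnius; model: compiled Keras model (see the examples above)
dl4j_model = ParameterAveragingModel(java_spark_context=jsc, model=model,
                                     num_workers=4, batch_size=32,
                                     averaging_frequency=5, collect_stats=False)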