Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
:param step_delay:
:param step_trigger:
:param threshold_step:
:param collect_stats:
:param save_file:
:param args:
:param kwargs:
"""
SparkModel.__init__(self, model=model, num_workers=num_workers, batch_size=batch_size, mode='asynchronous',
shake_frequency=shake_frequency, min_threshold=min_threshold,
update_threshold=update_threshold, workers_per_node=workers_per_node,
num_batches_prefetch=num_batches_prefetch, step_delay=step_delay, step_trigger=step_trigger,
threshold_step=threshold_step, collect_stats=collect_stats, *args, **kwargs)
self.save(save_file)
model_file = java_classes.File(save_file)
keras_model_type = model.__class__.__name__
self.java_spark_model = dl4j_import(
java_spark_context, model_file, keras_model_type)
def main():
    """Build a Keras MNIST classifier and wrap it in a DL4J-backed Elephas model."""
    # Spark setup: local master using all available cores.
    spark_conf = java_classes.SparkConf().setMaster('local[*]').setAppName("elephas_dl4j")
    jsc = java_classes.JavaSparkContext(spark_conf)

    # Simple Keras MLP: 784 inputs -> 128 hidden -> 10-way softmax.
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(128, input_dim=784))
    model.add(keras.layers.Dense(units=10, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

    # Wrap the compiled Keras model as a DL4J Elephas model (4 workers, batch 32).
    spark_model = ParameterAveragingModel(java_spark_context=jsc, model=model, num_workers=4, batch_size=32)

    # Fetch MNIST; flatten each 28x28 image into a 784-dim float vector.
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    x_train = x_train.reshape(60000, 784).astype("float64")
    x_test = x_test.reshape(10000, 784).astype("float64")
def to_java_rdd(jsc, features, labels, batch_size):
    """Convert numpy features and labels into a JavaRDD of
    DL4J DataSet type.

    Only complete batches are converted: any trailing samples beyond
    ``len(features) // batch_size * batch_size`` are silently dropped
    (NOTE(review): presumably intentional for fixed-size batching — confirm).

    :param jsc: JavaSparkContext from pyjnius
    :param features: numpy array with features
    :param labels: numpy array with labels
    :param batch_size: number of samples per DataSet batch
    :return: JavaRDD of DL4J DataSet objects
    """
    data_sets = java_classes.ArrayList()
    # Floor division: number of full batches that fit in the data.
    num_batches = len(features) // batch_size
    for i in range(num_batches):
        # Slice by offset instead of repeatedly rebinding/truncating the
        # inputs — equivalent result, and avoids O(n^2) copying when the
        # caller passes plain Python sequences.
        start = i * batch_size
        end = start + batch_size
        xi = ndarray(features[start:end].copy())
        yi = ndarray(labels[start:end].copy())
        data_sets.add(java_classes.DataSet(xi.array, yi.array))
    return jsc.parallelize(data_sets)
def dl4j_import(jsc, model_file, keras_model_type):
    """Import a saved Keras model into DL4J via ElephasModelImport.

    :param jsc: JavaSparkContext from pyjnius
    :param model_file: java.io.File pointing at the saved Keras model
    :param keras_model_type: Keras class name, "Sequential" or "Model"
    :return: the imported DL4J Elephas model, or None if the import fails
    :raises Exception: if keras_model_type is neither "Sequential" nor "Model"
    """
    emi = java_classes.ElephasModelImport
    # Dispatch table replaces the duplicated if/elif try blocks.
    importers = {
        "Sequential": emi.importElephasSequentialModelAndWeights,
        "Model": emi.importElephasModelAndWeights,
    }
    if keras_model_type not in importers:
        raise Exception(
            "Keras model not understood, got: {}".format(keras_model_type))
    try:
        return importers[keras_model_type](jsc, model_file.absolutePath)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; keeps the original best-effort behavior of
        # printing and returning None on failure.
        print("Couldn't load Keras model into DL4J")