How to use the elephas.java.java_classes.ArrayList function in elephas

To help you get started, we’ve selected a few elephas examples, based on popular ways it is used in public projects.


Source: maxpumperla/elephas on GitHub, elephas/utils/rdd_utils.py
def to_java_rdd(jsc, features, labels, batch_size):
    """Convert numpy features and labels into a JavaRDD of
    DL4J DataSet type.

    :param jsc: JavaSparkContext from pyjnius
    :param features: numpy array with features
    :param labels: numpy array with labels
    :param batch_size: number of samples per DataSet
    :return: JavaRDD
    """
    data_sets = java_classes.ArrayList()
    # Integer division: a trailing partial batch is dropped.
    num_batches = len(features) // batch_size
    for i in range(num_batches):
        xi = ndarray(features[:batch_size].copy())
        yi = ndarray(labels[:batch_size].copy())
        data_set = java_classes.DataSet(xi.array, yi.array)
        data_sets.add(data_set)
        features = features[batch_size:]
        labels = labels[batch_size:]

    return jsc.parallelize(data_sets)
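
The batching step in to_java_rdd can be tried out without a JVM. The following is a minimal sketch, not code from elephas: split_into_batches is a hypothetical helper, plain Python lists stand in for the numpy arrays, and tuples stand in for the DL4J DataSet objects. It uses index arithmetic instead of repeatedly trimming the inputs, but like the original it keeps only full batches.

```python
def split_into_batches(features, labels, batch_size):
    """Hypothetical helper: split paired sequences into full batches.

    Mirrors the loop in to_java_rdd, but returns (features, labels)
    tuples instead of DL4J DataSet objects.
    """
    num_batches = len(features) // batch_size  # trailing remainder is dropped
    batches = []
    for i in range(num_batches):
        start = i * batch_size
        stop = start + batch_size
        # Slicing a list copies it. numpy slices are views, which is why
        # the original snippet calls .copy() on each slice.
        batches.append((features[start:stop], labels[start:stop]))
    return batches


# Ten samples with batch_size=4: two full batches, two samples left over.
features = [[i, i + 1] for i in range(10)]
labels = list(range(10))
batches = split_into_batches(features, labels, batch_size=4)
print(len(batches))
print(batches[1][1])
```

Because only len() and slicing are used, the same helper should also accept numpy arrays, as long as the caller is fine with receiving views rather than copies.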
Source: maxpumperla/elephas on GitHub, elephas/java/adapter.py
def retrieve_keras_weights(java_model):
    """For a previously imported Keras model, after training it with DL4J Spark,
    we want to set the resulting weights back to the original Keras model.

    :param java_model: DL4J model (MultiLayerNetwork or ComputationGraph)
    :return: list of numpy arrays in correct order for model.set_weights(...) of a corresponding Keras model
    """
    weights = []
    layers = java_model.getLayers()
    for layer in layers:
        params = layer.paramTable()
        keys = params.keySet()
        key_list = java_classes.ArrayList(keys)
        for key in key_list:
            weight = params.get(key)
            np_weight = np.squeeze(to_numpy(weight))
            weights.append(np_weight)
    return weights
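
The traversal order in retrieve_keras_weights (layer by layer, then key by key within each layer) can be illustrated without DL4J. This is a rough sketch under stated assumptions: collect_weights is a hypothetical name, each Python dict stands in for the map returned by layer.paramTable(), and the to_numpy and np.squeeze steps are left out so the example needs only the standard library.

```python
def collect_weights(layer_param_tables):
    """Hypothetical stand-in for retrieve_keras_weights.

    Each element of layer_param_tables is a dict playing the role of
    layer.paramTable(); values are returned in layer order, then in
    key order within each layer.
    """
    weights = []
    for params in layer_param_tables:
        # Copy the keys into a list first, as the original does with
        # java_classes.ArrayList(keys), then iterate over that list.
        key_list = list(params.keys())
        for key in key_list:
            weights.append(params[key])
    return weights


# Two toy "layers", each with a weight matrix "W" and a bias "b".
layers = [
    {"W": [[1.0, 2.0]], "b": [0.5]},
    {"W": [[3.0]], "b": [0.1]},
]
flat = collect_weights(layers)
print(len(flat))
```

Python dicts preserve insertion order, so the result here is deterministic. Whether the key order of a real DL4J parameter table matches what a given Keras model expects in model.set_weights(...) is something to check against the actual model.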