def to_java_rdd(jsc, features, labels, batch_size):
    """Convert numpy features and labels into a JavaRDD of
    DL4J DataSet type.

    :param jsc: JavaSparkContext from pyjnius
    :param features: numpy array with features
    :param labels: numpy array with labels
    :return: JavaRDD of DataSet
    """
    data_sets = java_classes.ArrayList()
    num_batches = len(features) // batch_size
    for i in range(num_batches):
        xi = ndarray(features[:batch_size].copy())
        yi = ndarray(labels[:batch_size].copy())
        data_set = java_classes.DataSet(xi.array, yi.array)
        data_sets.add(data_set)
        features = features[batch_size:]
        labels = labels[batch_size:]
    return jsc.parallelize(data_sets)
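
# Usage sketch for to_java_rdd (a minimal sketch, not part of the original
# source). Assumptions: a JVM with ND4J/DL4J on the classpath has been
# started via pyjnius, and `jsc` is a JavaSparkContext created by the
# caller; the shapes below are illustrative only.
#
#     import numpy as np
#     features = np.random.rand(100, 784).astype('float32')
#     labels = np.random.rand(100, 10).astype('float32')
#     rdd = to_java_rdd(jsc, features, labels, batch_size=10)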
def array(*args, **kwargs):
    return ndarray(*args, **kwargs)
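
# Example (illustrative, assuming the ndarray constructor accepts numpy
# arrays as shown throughout this module):
#
#     import numpy as np
#     a = array(np.arange(6).reshape(2, 3))  # jumpy ndarray backed by an INDArray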
def __getitem__(self, key):
    # Inferred enclosing method for this snippet: the indexing logic below
    # belongs to ndarray.__getitem__.
    if type(key) is slice:
        start = key.start
        stop = key.stop
        step = key.step
        if start is None:
            start = 0
        if stop is None:
            shape = self.array.shape()
            if shape[0] == 1:
                stop = shape[1]
            else:
                stop = shape[0]
        if stop - start <= 0:
            return None
        if step is None or step == 1:
            return ndarray(self.array.get(NDArrayIndex.interval(start, stop)))
        else:
            return ndarray(self.array.get(NDArrayIndex.interval(start, step, stop)))
    if type(key) is list:
        # NotImplemented is a value, not an exception type; raise the
        # proper NotImplementedError instead.
        raise NotImplementedError(
            'Sorry, this type of indexing is not supported yet.')
    if type(key) is tuple:
        key = list(key)
        shape = self.array.shape()
        ndim = len(shape)
        nk = len(key)
        key += [slice(None)] * (ndim - nk)
        args = []
        for i, dim in enumerate(key):
            if type(dim) is int:
                args.append(NDArrayIndex.point(dim))
            elif type(dim) is slice:
                # Truncated in the original snippet; a plausible completion
                # maps the slice onto NDArrayIndex, mirroring the 1-D slice
                # branch above.
                start = dim.start or 0
                stop = dim.stop if dim.stop is not None else shape[i]
                if dim.step is None or dim.step == 1:
                    args.append(NDArrayIndex.interval(start, stop))
                else:
                    args.append(NDArrayIndex.interval(start, dim.step, stop))
        return ndarray(self.array.get(*args))
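
# Slicing sketch (illustrative, not part of the original source): basic
# slices translate to NDArrayIndex.interval/point calls on the backing
# INDArray.
#
#     m = array(np.arange(12).reshape(4, 3))
#     first = m[0:2]      # rows 0-1, via NDArrayIndex.interval(0, 2)
#     strided = m[0:4:2]  # rows 0 and 2, via NDArrayIndex.interval(0, 2, 4)
#     cell = m[1, 2]      # entry (1, 2), via two NDArrayIndex.point calls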
def __add__(self, other):
    # Coerce `other` to an INDArray, broadcast both operands to a common
    # shape, then add elementwise.
    other = _indarray(other)
    x, y = broadcast(self.array, other)
    return ndarray(x.add(y))
def __div__(self, other):
    other = _indarray(other)
    x, y = broadcast(self.array, other)
    return ndarray(x.div(y))

# Python 3 resolves the `/` operator through __truediv__, so alias it to
# keep division working on both interpreters.
__truediv__ = __div__
def _indarray(x):
    """Convert supported Python, numpy, and ND4J inputs to an INDArray."""
    if type(x) is INDArray:
        return x
    elif type(x) is ndarray:
        return x.array
    elif 'numpy' in str(type(x)):
        return _from_numpy(x)
    elif type(x) in (list, tuple):
        return _from_numpy(np.array(x))
    elif type(x) in (int, float):
        return Nd4j.scalar(x)
    else:
        raise TypeError('Data type not understood: ' + str(type(x)))
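
# Conversion sketch (illustrative): _indarray is the single funnel every
# operator below uses, so scalars, lists, numpy arrays, and existing
# INDArrays all end up as INDArray.
#
#     _indarray(3.0)               # -> Nd4j.scalar(3.0)
#     _indarray([1, 2, 3])         # -> via _from_numpy(np.array([1, 2, 3]))
#     _indarray(np.zeros((2, 2)))  # -> via _from_numpy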
def __mul__(self, other):
    other = _indarray(other)
    x, y = broadcast(self.array, other)
    return ndarray(x.mul(y))

def __sub__(self, other):
    other = _indarray(other)
    x, y = broadcast(self.array, other)
    return ndarray(x.sub(y))
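
# Operator sketch (illustrative, not part of the original source): because
# every arithmetic dunder routes through _indarray and broadcast, mixed
# operands (jumpy ndarray, numpy array, plain scalar) combine elementwise.
#
#     import numpy as np
#     a = array(np.ones((2, 3)))
#     b = a + 1.0        # scalar broadcast via __add__
#     c = a * a          # elementwise product via __mul__
#     d = (a - b) / 2.0  # __sub__ then __div__ (__truediv__ on Python 3)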