Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def to_numpy(nd4j_array):
    """Convert an ND4J array to a numpy array.

    The numpy array is created with ``np.ctypeslib.as_array`` on top of the
    native address reported by the ND4J data buffer, reinterpreted as a
    pointer to the C type matching the current ND4J context data type.

    NOTE(review): the element type comes from the global ND4J context
    (``DataTypeUtil.getDtypeFromContext``), not from ``nd4j_array`` itself;
    confirm the two always agree, otherwise the buffer is misinterpreted.
    NOTE(review): the result presumably aliases the ND4J buffer's memory
    rather than owning a copy; confirm the ND4J array outlives it.

    :param nd4j_array: ND4J array exposing ``data()`` and ``shape()``.
    :return: numpy array with shape ``tuple(nd4j_array.shape())``.
    :raises KeyError: if the context data type is not 'double' or 'float'.
    """
    # Native address of the first element of the ND4J data buffer.
    address = nd4j_array.data().pointer().address()

    type_name = java_classes.DataTypeUtil.getDtypeFromContext()
    data_type = java_classes.DataTypeUtil.getDTypeForName(type_name)

    # Only these two ND4J data types have a ctypes equivalent wired up here.
    mapping = {
        'double': ctypes.c_double,
        'float': ctypes.c_float,
    }
    if data_type not in mapping:
        # Still a KeyError (as before), but with an actionable message
        # instead of just the offending dtype name.
        raise KeyError(
            'Unsupported ND4J data type {!r}; to_numpy supports only: {}'
            .format(data_type, ', '.join(sorted(mapping))))

    # Reinterpret the raw address as a typed pointer so numpy can wrap it.
    pointer = ctypes.cast(address, ctypes.POINTER(mapping[data_type]))
    return np.ctypeslib.as_array(pointer, tuple(nd4j_array.shape()))
def get_context_dtype():
    """Return the nd4j dtype of the current context.

    Fetches the context data type with ``DataTypeUtil.getDtypeFromContext()``
    and hands it to ``DataTypeUtil.getDTypeForName()``.

    :return: the value produced by ``getDTypeForName`` (presumably a dtype
        name such as 'double' or 'float' -- confirm against ND4J).
    """
    util = java_classes.DataTypeUtil
    context_dtype = util.getDtypeFromContext()
    return util.getDTypeForName(context_dtype)