Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
import sys
# Python 3 has no ``long`` or ``unicode`` builtins; alias them to ``int`` and
# ``str`` so later code can use either spelling.
# The check uses the numeric ``sys.version_info`` tuple instead of the
# ``sys.version`` string: comparing strings is lexicographic, so a version
# string such as "10.0" would sort before "3".
if sys.version_info[0] >= 3:
    long = int
    unicode = str
import py4j.protocol
from py4j.protocol import Py4JJavaError
from py4j.java_gateway import JavaObject
from py4j.java_collections import JavaArray, JavaList
from pyspark import RDD, SparkContext
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql import DataFrame, SQLContext
# Hack to support float('inf') in Py4J: Py4J's smart_decode is wrapped so
# that floats are rendered with str() and the special values listed in
# _float_str_mapping are rewritten.
_old_smart_decode = py4j.protocol.smart_decode

_float_str_mapping = {
    'nan': 'NaN',
    'inf': 'Infinity',
    '-inf': '-Infinity',
}


def _new_smart_decode(obj):
    """Decode *obj*, special-casing floats.

    A float is converted with ``str()`` and the result is looked up in
    ``_float_str_mapping``; when there is no entry the plain ``str()``
    result is returned. Any other object is handed to the original
    ``smart_decode`` saved in ``_old_smart_decode``.
    """
    if not isinstance(obj, float):
        return _old_smart_decode(obj)
    text = str(obj)
    return _float_str_mapping.get(text, text)


# Install the wrapper in place of Py4J's own function.
py4j.protocol.smart_decode = _new_smart_decode
raise Py4JNetworkError(
"Error while sending", e, proto.ERROR_ON_SEND)
try:
while True:
answer = smart_decode(self.stream.readline()[:-1])
logger.debug("Answer received: {0}".format(answer))
# Happens when the other end is dead. There might be an empty
# answer before the socket raises an error.
if answer.strip() == "":
raise Py4JNetworkError("Answer from Java side is empty")
if answer.startswith(proto.RETURN_MESSAGE):
return answer[1:]
else:
command = answer
obj_id = smart_decode(self.stream.readline())[:-1]
if command == proto.CALL_PROXY_COMMAND_NAME:
return_message = self._call_proxy(obj_id, self.stream)
self.socket.sendall(return_message.encode("utf-8"))
elif command == proto.GARBAGE_COLLECT_PROXY_COMMAND_NAME:
self.stream.readline()
_garbage_collect_proxy(self.pool, obj_id)
self.socket.sendall(
proto.SUCCESS_RETURN_MESSAGE.encode("utf-8"))
else:
logger.error("Unknown command {0}".format(command))
# We're sending something to prevent blocking,
# but at this point, the protocol is broken.
self.socket.sendall(
proto.ERROR_RETURN_MESSAGE.encode("utf-8"))
except Exception as e:
def put(self, object, force_id=None):
    """Adds a proxy to the pool.

    The pool's lock is held while the identifier is chosen and the proxy
    is stored.

    :param object: The proxy to add to the pool.
    :param force_id: When truthy, used as the identifier as-is; otherwise
        an identifier is built from the proxy prefix and the pool's
        counter, which is then incremented.
    :rtype: A unique identifier associated with the object.
    """
    with self.lock:
        proxy_id = force_id
        if not proxy_id:
            # No identifier supplied: derive one from the running counter
            # and advance the counter for the next call.
            proxy_id = proto.PYTHON_PROXY_PREFIX + smart_decode(self.next_id)
            self.next_id += 1
        self.dict[proxy_id] = object
        return proxy_id
def _get_params(self, input):
    """Read parameters from ``input`` until the end marker is reached.

    Each line read is decoded with ``smart_decode`` and stripped of its
    final character. Reading stops at the first value equal to
    ``proto.END``; every earlier value is prefixed with "y" and converted
    with ``get_return_value`` using ``self.java_client``.

    :param input: Object providing ``readline()`` to read from.
    :rtype: The list of converted parameters, in the order they were read.
    """
    def read_token():
        # Decode one line and drop its last character.
        return smart_decode(input.readline())[:-1]

    return [
        get_return_value("y" + token, self.java_client)
        for token in iter(read_token, proto.END)
    ]
def send_command(self, command):
# TODO At some point extract common code from wait_for_commands
logger.debug("Command to send: {0}".format(command))
try:
self.socket.sendall(command.encode("utf-8"))
except Exception as e:
logger.info("Error while sending or receiving.", exc_info=True)
raise Py4JNetworkError(
"Error while sending", e, proto.ERROR_ON_SEND)
try:
while True:
answer = smart_decode(self.stream.readline()[:-1])
logger.debug("Answer received: {0}".format(answer))
# Happens when the other end is dead. There might be an empty
# answer before the socket raises an error.
if answer.strip() == "":
raise Py4JNetworkError("Answer from Java side is empty")
if answer.startswith(proto.RETURN_MESSAGE):
return answer[1:]
else:
command = answer
obj_id = smart_decode(self.stream.readline())[:-1]
if command == proto.CALL_PROXY_COMMAND_NAME:
return_message = self._call_proxy(obj_id, self.stream)
self.socket.sendall(return_message.encode("utf-8"))
elif command == proto.GARBAGE_COLLECT_PROXY_COMMAND_NAME:
self.stream.readline()
# Keep a handle on the current Py4J decoder; the float-aware wrapper below
# delegates to it for every non-float value.
_old_smart_decode = py4j.protocol.smart_decode

# str() spellings of the special floats and their replacement text.
_float_str_mapping = {
    'nan': 'NaN',
    'inf': 'Infinity',
    '-inf': '-Infinity',
}


def _new_smart_decode(obj):
    """Decode *obj*, rewriting the special float values.

    Floats are turned into text with ``str()`` and replaced by their
    ``_float_str_mapping`` entry when one exists; all other objects go
    through the previously installed ``smart_decode``.
    """
    if not isinstance(obj, float):
        return _old_smart_decode(obj)
    rendered = str(obj)
    return _float_str_mapping.get(rendered, rendered)


# Replace Py4J's decoder with the wrapper.
py4j.protocol.smart_decode = _new_smart_decode
# Class names, as strings, collected under ``_picklable_classes``.
# NOTE(review): the code that consumes this list is not visible in this
# chunk; presumably these are the types registered for pickling - confirm.
_picklable_classes = [
'SparseVector',
'DenseVector',
'SparseMatrix',
'DenseMatrix',
]
# this will call the ML version of pythonToJava()
def _to_java_object_rdd(rdd):
""" Return an JavaRDD of Object by unpickling
It will convert each Python object into a Java object using Pyrolite, whether the
RDD is serialized in batch or not.
def _garbage_collect_object(gateway_client, target_id):
    """Best-effort garbage collection of ``target_id``.

    Calls ``ThreadSafeFinalizer.remove_finalizer`` with the key built from
    the client's decoded address, decoded port and ``target_id``, then asks
    ``gateway_client`` to garbage collect the object. Any ``Exception``
    raised on the way is logged at debug level, and an ``Exception`` raised
    by the logging call itself is ignored.

    :param gateway_client: Client providing ``address``, ``port`` and
        ``garbage_collect_object``.
    :param target_id: Identifier of the object to collect.
    """
    try:
        try:
            remove_finalizer = ThreadSafeFinalizer.remove_finalizer
            finalizer_key = (
                smart_decode(gateway_client.address)
                + smart_decode(gateway_client.port)
                + target_id
            )
            remove_finalizer(finalizer_key)
            gateway_client.garbage_collect_object(target_id)
        except Exception:
            logger.debug(
                "Exception while garbage collecting an object",
                exc_info=True)
    except Exception:
        # Maybe logger is dead at this point.
        pass
# Remember the decoder Py4J currently exposes so the wrapper defined below
# can fall back to it.
_old_smart_decode = py4j.protocol.smart_decode

# Maps the str() form of a special float to the text used instead.
_float_str_mapping = {
    'nan': 'NaN',
    'inf': 'Infinity',
    '-inf': '-Infinity',
}


def _new_smart_decode(obj):
    """Decode *obj* with special handling for floats.

    For a float, ``str(obj)`` is computed and swapped for the matching
    ``_float_str_mapping`` value if present, otherwise returned unchanged.
    Non-float objects are decoded by ``_old_smart_decode``.
    """
    if not isinstance(obj, float):
        return _old_smart_decode(obj)
    as_text = str(obj)
    return _float_str_mapping.get(as_text, as_text)


# Make Py4J use the float-aware wrapper from now on.
py4j.protocol.smart_decode = _new_smart_decode
# Class names, as strings, collected under ``_picklable_classes``.
# NOTE(review): the code that consumes this list is not visible in this
# chunk; presumably these are the types registered for pickling - confirm.
_picklable_classes = [
'LinkedList',
'SparseVector',
'DenseVector',
'DenseMatrix',
'Rating',
'LabeledPoint',
]
# this will call the MLlib version of pythonToJava()
def _to_java_object_rdd(rdd):
""" Return a JavaRDD of Object by unpickling
:rtype: the `string` answer received from the JVM (The answer follows
the Py4J protocol).
"""
logger.debug("Command to send: {0}".format(command))
try:
# Write will only fail if remote is closed for large payloads or
# if it sent a RST packet (SO_LINGER)
self.socket.sendall(command.encode("utf-8"))
except Exception as e:
logger.info("Error while sending.", exc_info=True)
raise Py4JNetworkError(
"Error while sending", e, proto.ERROR_ON_SEND)
try:
answer = smart_decode(self.stream.readline()[:-1])
logger.debug("Answer received: {0}".format(answer))
if answer.startswith(proto.RETURN_MESSAGE):
answer = answer[1:]
# Happens when the other end is dead. There might be an empty
# answer before the socket raises an error.
if answer.strip() == "":
raise Py4JNetworkError("Answer from Java side is empty")
return answer
except Exception as e:
logger.info("Error while receiving.", exc_info=True)
raise Py4JNetworkError(
"Error while receiving", e, proto.ERROR_ON_RECEIVE)