Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_reservation_enviroment_not_exists_get_server_ip_return_actual_host_ip(self):
    """Without TFOS_SERVER_HOST, get_server_ip() returns util.get_ip_address()."""
    # patch.dict with no values snapshots os.environ and restores it on exit,
    # so removing the variable here cannot leak into other tests. Clearing it
    # explicitly keeps the "not exists" precondition true even when the
    # variable happens to be set in the developer's or CI runner's environment.
    with mock.patch.dict(os.environ):
        os.environ.pop('TFOS_SERVER_HOST', None)
        tfso_server = Server(5)
        assert tfso_server.get_server_ip() == util.get_ip_address()
def test_reservation_enviroment_not_exists_start_listening_socket_return_socket(self):
    """Without TFOS_SERVER_PORT, start_listening_socket() binds an integer port."""
    tfso_server = Server(1)
    # Make the "variable not set" precondition explicit; patch.dict restores
    # os.environ on exit so this cannot affect other tests.
    with mock.patch.dict(os.environ):
        os.environ.pop('TFOS_SERVER_PORT', None)
        sock = tfso_server.start_listening_socket()
    try:
        port = sock.getsockname()[1]
        assert isinstance(port, int)
    finally:
        # Open exactly one socket and always release it, so no listening
        # sockets are leaked between tests.
        sock.close()
def test_reservation_enviroment_exists_get_server_ip_return_environment_value(self):
    """get_server_ip() reports the host given by TFOS_SERVER_HOST when it is set."""
    tfso_server = Server(5)
    expected_host = "my_host_ip"
    # keyword form of patch.dict: the variable is set only inside the block
    with mock.patch.dict(os.environ, TFOS_SERVER_HOST=expected_host):
        actual_host = tfso_server.get_server_ip()
        assert actual_host == expected_host
def test_reservation_enviroment_exists_start_listening_socket_return_socket_listening_to_environment_port_value(self):
    """With TFOS_SERVER_PORT=9999, the listening socket is bound to port 9999."""
    tfso_server = Server(1)
    with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9999'}):
        sock = tfso_server.start_listening_socket()
    try:
        assert sock.getsockname()[1] == 9999
    finally:
        # Release the fixed port so later tests (or a rerun) can bind it again.
        sock.close()
def test_reservation_server(self):
    """Test reservation server, expecting 1 reservation.

    Starts a single-reservation Server, registers one client, and checks that
    both get_reservations() and await_reservations() report exactly one
    reservation. It then asks the server to stop and closes the client.
    """
    s = Server(1)
    addr = s.start()
    # add first reservation
    c = Client(addr)
    try:
        resp = c.register({'node': 1})
        self.assertEqual(resp, 'OK')
        # get list of reservations
        reservations = c.get_reservations()
        self.assertEqual(len(reservations), 1)
        # should return immediately with list of reservations
        reservations = c.await_reservations()
        self.assertEqual(len(reservations), 1)
        # request server stop (previously only a comment, so the server
        # was never told to shut down)
        c.request_stop()
    finally:
        # always release the client connection, even if an assertion fails
        c.close()
def test_reservation_server_multi(self):
    """Test reservation server, expecting multiple reservations.

    Starts a Server sized for ``num_clients`` and registers each client from
    its own thread; each thread then calls await_reservations() and closes
    its client. The main thread joins every worker and re-raises the first
    failure any of them recorded, so thread-side assertion errors fail the test.
    """
    num_clients = 4
    join_timeout = 30  # seconds per thread; keeps a stuck client from hanging the suite
    s = Server(num_clients)
    addr = s.start()
    # Exceptions raised inside a worker thread (including failed assertEqual
    # calls) are not propagated to the test runner, so collect them here.
    # list.append is safe to call from several threads.
    errors = []

    def reserve(num):
        try:
            c = Client(addr)
            try:
                # a random sleep here would simulate varying start times
                resp = c.register({'node': num})
                self.assertEqual(resp, 'OK')
                c.await_reservations()
            finally:
                c.close()
        except Exception as e:  # AssertionError is an Exception subclass
            errors.append(e)

    # start/register clients
    threads = [threading.Thread(target=reserve, args=(i,)) for i in range(num_clients)]
    for t in threads:
        t.start()
    # Wait for every client; without the join the test returned before any
    # client had registered and could never observe a failure.
    for t in threads:
        t.join(timeout=join_timeout)
    self.assertFalse(any(t.is_alive() for t in threads), "client thread(s) did not finish")
    if errors:
        raise errors[0]
# NOTE(review): this span is a fragment of a larger function whose header is
# not visible here; num_workers, cluster_template, executors, sc,
# num_executors, num_ps, driver_ps_nodes, logger, reservation, os and random
# are all defined outside this view.
# When workers are requested, assign the first num_workers entries of
# executors to the 'worker' role (presumably executors was already trimmed
# past any earlier roles such as 'ps' -- confirm against the caller).
if num_workers > 0:
cluster_template['worker'] = executors[:num_workers]
logger.info("cluster_template: {}".format(cluster_template))
# get default filesystem from spark
# Reads "fs.defaultFS" from the Hadoop configuration reached through
# sc._jsc (presumably the JavaSparkContext behind SparkContext `sc`).
defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS")
# strip trailing "root" slash from "file:///" to be consistent w/ "hdfs://..."
# Only "file://" URIs longer than the bare scheme that end in "/" lose exactly
# one trailing character, e.g. "file:///" -> "file://".
if defaultFS.startswith("file://") and len(defaultFS) > 7 and defaultFS.endswith("/"):
defaultFS = defaultFS[:-1]
# get current working dir of spark launch
working_dir = os.getcwd()
# start a server to listen for reservations and broadcast cluster_spec
# The server is sized with num_executors; the address returned by start()
# is recorded below as cluster_meta['server_addr'].
server = reservation.Server(num_executors)
server_addr = server.start()
# start TF nodes on all executors
logger.info("Starting TensorFlow on executors")
# Cluster-wide metadata (presumably handed to every TF node -- confirm).
# 'id' is a random integer in [0, 2**64); random.getrandbits is not a secure
# source, which is presumably fine if the id only distinguishes clusters.
cluster_meta = {
'id': random.getrandbits(64),
'cluster_template': cluster_template,
'num_executors': num_executors,
'default_fs': defaultFS,
'working_dir': working_dir,
'server_addr': server_addr
}
# Build an RDD of node ids with one partition per id. When driver_ps_nodes
# is set, ids 0..num_ps-1 are left out (presumably because those PS nodes
# run on the driver rather than on executors -- confirm).
if driver_ps_nodes:
nodeRDD = sc.parallelize(range(num_ps, num_executors), num_executors - num_ps)
else:
nodeRDD = sc.parallelize(range(num_executors), num_executors)