Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_multi_agent_train():
    """POST a multi-agent multipart form to ``/train`` and expect HTTP 200."""
    requests = Queue(maxsize=1)
    responses = Queue(maxsize=1)
    # Pre-load the response queue with one action before the request is made.
    responses.put_nowait({'action': 'RotateRight'})
    srv = ai2thor.server.Server(requests, responses, '127.0.0.1')
    client = srv.app.test_client()

    # The form is built with the server's current sequence id.
    form_body = generate_multi_agent_form(metadata_simple, srv.sequence_id)
    response = client.post(
        '/train',
        buffered=True,
        content_type='multipart/form-data; boundary=OVCo05I3SVXLPeTvCgJjHl1EOleL4u9TDx5raRVt',
        input_stream=BytesIO(form_body),
    )
    assert response.status_code == 200
def test_non_multipart():
    """POST ``/train`` as ordinary form data (no multipart body) and expect HTTP 200.

    The metadata is sent as a JSON string alongside the server's client token.
    """
    requests = Queue(maxsize=1)
    responses = Queue(maxsize=1)
    # Pre-load the response queue with one action before the request is made.
    responses.put_nowait({'action': 'RotateRight'})
    srv = ai2thor.server.Server(requests, responses, '127.0.0.1')
    client = srv.app.test_client()
    srv.client_token = '1234567'

    payload = {'agents': [metadata_simple], 'sequenceId': srv.sequence_id}
    form_fields = {
        'metadata': json.dumps(payload),
        'token': srv.client_token,
    }
    response = client.post('/train', data=form_fields)
    assert response.status_code == 200
def test_train():
    """POST a single-agent multipart form to ``/train`` and expect HTTP 200."""
    requests = Queue(maxsize=1)
    responses = Queue(maxsize=1)
    # Pre-load the response queue with one action before the request is made.
    responses.put_nowait({'action': 'RotateRight'})
    srv = ai2thor.server.Server(requests, responses, '127.0.0.1')
    client = srv.app.test_client()

    # The form is built with the server's current sequence id.
    form_body = generate_form(metadata_simple, srv.sequence_id)
    response = client.post(
        '/train',
        buffered=True,
        content_type='multipart/form-data; boundary=OVCo05I3SVXLPeTvCgJjHl1EOleL4u9TDx5raRVt',
        input_stream=BytesIO(form_body),
    )
    assert response.status_code == 200
def test_client_token_mismatch():
    """Expect HTTP 403 from ``/train`` once the server has a client token set.

    The server's ``client_token`` is set to ``'123456'`` and the posted form
    is generated with ``sequence_id + 1``.
    NOTE(review): the token carried by the generated form is decided inside
    ``generate_form`` (defined elsewhere) -- presumably it differs from the
    server's token; confirm there.
    """
    requests = Queue(maxsize=1)
    responses = Queue(maxsize=1)
    # Pre-load the response queue with one action before the request is made.
    responses.put_nowait({'action': 'RotateRight'})
    srv = ai2thor.server.Server(requests, responses, '127.0.0.1')
    srv.client_token = '123456'
    client = srv.app.test_client()

    form_body = generate_form(metadata_simple, srv.sequence_id + 1)
    response = client.post(
        '/train',
        buffered=True,
        content_type='multipart/form-data; boundary=OVCo05I3SVXLPeTvCgJjHl1EOleL4u9TDx5raRVt',
        input_stream=BytesIO(form_body),
    )
    assert response.status_code == 403
def test_train_numpy_action():
    """Check that an action holding numpy scalars comes back as plain JSON.

    The queued action uses numpy scalar values for ``rotation.y`` and
    ``moveMagnitude``; the decoded ``/train`` response must equal the same
    action with plain numbers plus ``sequenceId`` 1, with HTTP status 200.
    """
    requests = Queue(maxsize=1)
    responses = Queue(maxsize=1)

    # Index into the arrays so the values are numpy scalars, not Python numbers.
    y_rotation = np.array([24])[0]
    magnitude = np.array([55.5])[0]
    responses.put_nowait({
        'action': 'Teleport',
        'rotation': {'y': y_rotation},
        'moveMagnitude': magnitude,
    })

    srv = ai2thor.server.Server(requests, responses, '127.0.0.1')
    client = srv.app.test_client()

    form_body = generate_form(metadata_simple, srv.sequence_id)
    response = client.post(
        '/train',
        buffered=True,
        content_type='multipart/form-data; boundary=OVCo05I3SVXLPeTvCgJjHl1EOleL4u9TDx5raRVt',
        input_stream=BytesIO(form_body),
    )

    decoded = json.loads(response.get_data())
    expected = {
        'action': 'Teleport',
        'rotation': {'y': 24},
        'sequenceId': 1,
        'moveMagnitude': 55.5,
    }
    assert decoded == expected
    assert response.status_code == 200
def server():
    """Return an ``ai2thor.server.Server`` on 127.0.0.1 with fresh queues.

    Both the request queue and the response queue are created here with a
    maximum size of 1.
    """
    # Arguments are evaluated left to right: request queue, then response queue.
    return ai2thor.server.Server(
        Queue(maxsize=1),
        Queue(maxsize=1),
        '127.0.0.1',
    )
if self.server_thread is not None:
print('start() method depreciated. The server has already started when Controller was initialized.')
# Stops the current server and creates a new one. This is done so
# that the arguments passed in will be used on the server.
self.stop()
env = os.environ.copy()
image_name = None
if self.docker_enabled:
self.check_docker()
host = ai2thor.docker.bridge_gateway()
self.server = ai2thor.server.Server(
self.request_queue,
self.response_queue,
host,
port=port)
_, port = self.server.wsgi_server.socket.getsockname()
self.server_thread = threading.Thread(target=self._start_server_thread)
self.server_thread.daemon = True
self.server_thread.start()
if start_unity:
if platform.system() == 'Linux':
if self.docker_enabled: