Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
(ASGIWebsocketState.CLOSED, "websocket.http.response.body"),
],
)
async def test_websocket_asgi_send_invalid_message_given_state(
    state: ASGIWebsocketState, message_type: str
) -> None:
    """Check that ``asgi_send`` rejects a message type not allowed in ``state``.

    Each parametrized (state, message_type) pair is expected to make
    ``asgi_send`` raise ``UnexpectedMessage``.
    """
    mock_stream = MockH2WebsocketStream()
    mock_stream.state = state
    invalid_message = {"type": message_type}
    with pytest.raises(UnexpectedMessage):
        await mock_stream.asgi_send(invalid_message)
async def handle_connection(self) -> None:
    """Drive a single websocket connection from the opening request to close.

    Reads the opening request, then runs ``read_messages`` as a background
    trio task while ``handle_websocket`` processes that request. If the
    state ends up as ``HTTPCLOSED``, ``MustCloseError`` is raised while
    still inside the nursery, which cancels the background reader. A
    broken or closed transport is reported to the application as a
    ``websocket.disconnect`` message. ``aclose`` always runs on exit.
    """
    try:
        opening_request = await self.read_request()
        async with trio.open_nursery() as nursery:
            nursery.start_soon(self.read_messages)
            await self.handle_websocket(opening_request)
            must_close = self.state == ASGIWebsocketState.HTTPCLOSED
            if must_close:
                # Raising inside the nursery body cancels the read_messages
                # task rather than waiting for it to finish on its own.
                raise MustCloseError()
    # NOTE(review): under trio's strict_exception_groups mode (the default in
    # newer trio releases) exceptions leaving the nursery arrive wrapped in an
    # ExceptionGroup and the clauses below would not match — confirm the
    # pinned trio version.
    except (trio.BrokenResourceError, trio.ClosedResourceError):
        # The transport went away; tell the application.
        await self.asgi_put({"type": "websocket.disconnect"})
    except MustCloseError:
        # Deliberate shutdown path; nothing to report. ``finally`` still runs.
        return
    finally:
        await self.aclose()
def maybe_close(self, future: asyncio.Future) -> None:
    """Close the connection only when the state is ``HTTPCLOSED``.

    ``HTTPCLOSED`` marks that an HTTP response (rather than a websocket
    session) was sent. ``future`` is not inspected; presumably this is
    registered as a future done-callback — confirm at the call site.
    """
    if self.state != ASGIWebsocketState.HTTPCLOSED:
        return
    self.close()
raise_if_subprotocol_present(headers)
headers.extend(self.response_headers())
await self.asend(
AcceptConnection(
extensions=[PerMessageDeflate()],
extra_headers=headers,
subprotocol=message.get("subprotocol"),
)
)
self.state = ASGIWebsocketState.CONNECTED
self.config.access_logger.access(
self.scope, {"status": 101, "headers": []}, time() - self.start_time
)
elif (
message["type"] == "websocket.http.response.start"
and self.state == ASGIWebsocketState.HANDSHAKE
):
self.response = message
self.config.access_logger.access(self.scope, self.response, time() - self.start_time)
elif message["type"] == "websocket.http.response.body" and self.state in {
ASGIWebsocketState.HANDSHAKE,
ASGIWebsocketState.RESPONSE,
}:
await self._asgi_send_rejection(message)
elif message["type"] == "websocket.send" and self.state == ASGIWebsocketState.CONNECTED:
data: Union[bytes, str]
if message.get("bytes") is not None:
await self.asend(BytesMessage(data=bytes(message["bytes"])))
elif not isinstance(message["text"], str):
raise TypeError(f"{message['text']} should be a str")
else:
await self.asend(TextMessage(data=message["text"]))