Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion source/websocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,24 @@ static void s_websocket_on_connection_setup(
PyObject *tuple_py = PyTuple_New(2);
AWS_FATAL_ASSERT(tuple_py && "header tuple allocation failed");

/* Header names are tokens as per RFC 7230 Section 3.2 (strict ASCII),
* which means aws-c-http rejects on the wire if they contain non-ASCII bytes.
* So errors related to http header decoding will be caught at the protocol level.
* We should never fail wrangling the header name. */
PyObject *name_py = PyUnicode_FromAwsByteCursor(&header_i->name);
AWS_FATAL_ASSERT(name_py && "header name wrangling failed");
PyTuple_SetItem(tuple_py, 0, name_py); /* Steals a reference */

/* Header value can contain RFC 7230 obs-text (0x80-0xFF), which is
* not guaranteed valid UTF-8. On decode failure, log it and drop
* the whole header list rather than aborting the process. */
PyObject *value_py = PyUnicode_FromAwsByteCursor(&header_i->value);
AWS_FATAL_ASSERT(value_py && "header value wrangling failed");
if (!value_py) {
PyErr_WriteUnraisable(websocket_core_py);
Py_DECREF(tuple_py);
Py_CLEAR(headers_py);
break;
}
PyTuple_SetItem(tuple_py, 1, value_py); /* Steals a reference */

PyList_SetItem(headers_py, i, tuple_py); /* Steals a reference */
Expand Down
104 changes: 104 additions & 0 deletions test/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from queue import Empty, Queue
import secrets
import socket
import subprocess
import sys
from test import NativeResourceTest
import threading
from time import sleep, time
Expand Down Expand Up @@ -182,6 +184,54 @@ def send_async(self, msg):
asyncio.run_coroutine_threadsafe(self._current_connection.send(msg), self._server_loop)


class MockHandshakeServer:
# A raw-socket server that accepts one connection, drains the client's
# HTTP handshake request, and sends back a caller-provided response.
# Use this when tests need to send byte sequences that the 3rdparty
# `websockets` library can't produce (e.g. malformed headers).
#
# Usage:
# with MockHandshakeServer(host, response=b"HTTP/1.1 ...") as server:
# # spawn a client that connects to (host, server.port)
# ...

def __init__(self, host, response):
self._host = host
self._response = response
self._listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._listener.bind((host, 0))
self._listener.listen(1)
self.port = self._listener.getsockname()[1]
self._thread = threading.Thread(target=self._serve, daemon=True)

def __enter__(self):
self._thread.start()
return self

def __exit__(self, exc_type, exc_value, exc_tb):
self._listener.close()
self._thread.join(TIMEOUT)

def _serve(self):
try:
conn, _ = self._listener.accept()
except OSError:
return
with closing(conn):
conn.settimeout(TIMEOUT)
try:
buf = b""
while b"\r\n\r\n" not in buf:
chunk = conn.recv(4096)
if not chunk:
return
buf += chunk
conn.sendall(self._response)
except OSError:
pass


class TestClient(NativeResourceTest):
def setUp(self):
super().setUp()
Expand Down Expand Up @@ -324,6 +374,60 @@ def test_connect_failure_with_response(self):
# check that body is a valid string
self.assertGreater(len(setup_data.handshake_response_body.decode()), 0)

def test_connect_response_header_with_invalid_name_is_protocol_error(self):
# A response header whose name contains a non-tchar byte (e.g. 0xE9) is
# rejected by aws-c-http's HTTP/1.1 decoder before reaching the binding.
# The connection should fail with AWS_ERROR_HTTP_PROTOCOL_ERROR.
response = (
b"HTTP/1.1 403 Forbidden\r\n"
b"Content-Length: 0\r\n"
b"X-Bad\xe9Name: whatever\r\n"
b"\r\n"
)
with MockHandshakeServer(self.host, response=response) as server:
setup_future = Future()
connect(
host=self.host,
port=server.port,
handshake_request=create_handshake_request(host=self.host),
on_connection_setup=lambda x: setup_future.set_result(x))

setup_data: OnConnectionSetupData = setup_future.result(TIMEOUT)

self.assertIsNone(setup_data.websocket)
self.assertIsNotNone(setup_data.exception)
self.assertEqual("AWS_ERROR_HTTP_PROTOCOL_ERROR", setup_data.exception.name)
# bad-name response is rejected at the parser, so no headers reach Python
self.assertIsNone(setup_data.handshake_response_headers)

def test_connect_response_header_with_obs_text_does_not_abort(self):
# A response header value containing a non-UTF-8 obs-text byte (e.g. lone 0xE9)
# must not crash the process. Run the client in a subprocess so that an abort,
# if it happens, is observable as a non-zero exit code.
response = (
b"HTTP/1.1 403 Forbidden\r\n"
b"Content-Length: 0\r\n"
b"X-Reason: caf\xe9\r\n"
b"\r\n"
)
with MockHandshakeServer(self.host, response=response) as server:
proc = subprocess.Popen(
[sys.executable, '-m', 'test.ws_connect_helper', self.host, str(server.port)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)

try:
stdout, stderr = proc.communicate(timeout=TIMEOUT)
except subprocess.TimeoutExpired:
proc.kill()
stdout, stderr = proc.communicate()
self.fail("client subprocess hung")

self.assertEqual(
0, proc.returncode,
f"client subprocess crashed (returncode={proc.returncode}). "
f"stdout={stdout!r} stderr={stderr!r}")

def test_exception_in_setup_callback_closes_websocket(self):
with WebSocketServer(self.host, self.port) as server:
setup_future = Future()
Expand Down
28 changes: 28 additions & 0 deletions test/ws_connect_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0.

# Helper for test_websocket subprocess scenarios.
# Runs awscrt.websocket.connect() against a host:port given on the command
# line and waits for on_connection_setup to fire. Used by tests that need
# to observe whether a malformed server response crashes the client process.

import sys
from concurrent.futures import Future

from awscrt.websocket import connect, create_handshake_request

TIMEOUT = 10.0


def main(host, port):
setup_future = Future()
connect(
host=host,
port=port,
handshake_request=create_handshake_request(host=host),
on_connection_setup=lambda x: setup_future.set_result(x))
setup_future.result(TIMEOUT)


if __name__ == '__main__':
main(sys.argv[1], int(sys.argv[2]))
Loading