Skip to content

Commit 4c70c57

Browse files
committed
Add a test for WANT_READ during sendall()
1 parent 6119ee5 commit 4c70c57

File tree

1 file changed

+126
-3
lines changed

1 file changed

+126
-3
lines changed

tests/test_ssl.py

Lines changed: 126 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
321321
return server
322322

323323

324-
def loopback(server_factory=None, client_factory=None):
324+
def loopback(server_factory=None, client_factory=None, blocking=True):
325325
"""
326326
Create a connected socket pair and force two connected SSL sockets
327327
to talk to each other via memory BIOs.
@@ -337,8 +337,8 @@ def loopback(server_factory=None, client_factory=None):
337337

338338
handshake(client, server)
339339

340-
server.setblocking(True)
341-
client.setblocking(True)
340+
server.setblocking(blocking)
341+
client.setblocking(blocking)
342342
return server, client
343343

344344

@@ -3297,11 +3297,134 @@ def test_memoryview_really_doesnt_overfill(self):
32973297
self._doesnt_overfill_test(_make_memoryview)
32983298

32993299

3300+
@pytest.fixture
3301+
def nonblocking_tls_connections_pair():
3302+
"""Return a non-blocking TLS loopback connections pair."""
3303+
return loopback(blocking=False)
3304+
3305+
3306+
@pytest.fixture
3307+
def nonblocking_tls_server_connection(nonblocking_tls_connections_pair):
3308+
"""Return a non-blocking TLS server socket connected to loopback."""
3309+
return nonblocking_tls_connections_pair[0]
3310+
3311+
3312+
@pytest.fixture
3313+
def nonblocking_tls_client_connection(nonblocking_tls_connections_pair):
3314+
"""Return a non-blocking TLS client socket connected to loopback."""
3315+
return nonblocking_tls_connections_pair[1]
3316+
3317+
33003318
class TestConnectionSendall:
33013319
"""
33023320
Tests for `Connection.sendall`.
33033321
"""
33043322

3323+
def test_want_write(
3324+
self,
3325+
monkeypatch,
3326+
nonblocking_tls_server_connection,
3327+
nonblocking_tls_client_connection,
3328+
):
3329+
msg = b"x"
3330+
garbage_size = 1024 * 1024 * 64
3331+
large_payload = b"p" * garbage_size * 2
3332+
payload_size = len(large_payload)
3333+
3334+
sent_garbage_size = 0
3335+
try:
3336+
sent_garbage_size += nonblocking_tls_client_connection.send(
3337+
msg * garbage_size,
3338+
)
3339+
except WantWriteError:
3340+
pass
3341+
for i in range(garbage_size):
3342+
try:
3343+
sent_garbage_size += nonblocking_tls_client_connection.send(
3344+
msg,
3345+
)
3346+
except WantWriteError:
3347+
break
3348+
else:
3349+
pytest.fail(
3350+
"Failed to fill socket buffer, cannot test "
3351+
"'want write' in `sendall()`"
3352+
)
3353+
garbage_payload = sent_garbage_size * msg
3354+
3355+
def consume_garbage(conn):
3356+
assert patched_ssl_write.want_write_counter >= 1
3357+
assert not consume_garbage.garbage_consumed
3358+
3359+
while len(consume_garbage.consumed) < sent_garbage_size:
3360+
try:
3361+
consume_garbage.consumed += conn.recv(
3362+
sent_garbage_size - len(consume_garbage.consumed),
3363+
)
3364+
except WantReadError:
3365+
pass
3366+
3367+
assert consume_garbage.consumed == garbage_payload
3368+
3369+
consume_garbage.garbage_consumed = True
3370+
3371+
consume_garbage.garbage_consumed = False
3372+
consume_garbage.consumed = b""
3373+
3374+
def consume_payload(conn):
3375+
try:
3376+
consume_payload.consumed += conn.recv(payload_size)
3377+
except WantReadError:
3378+
pass
3379+
3380+
consume_payload.consumed = b""
3381+
3382+
original_ssl_write = _lib.SSL_write
3383+
3384+
def patched_ssl_write(ctx, data, size):
3385+
write_result = original_ssl_write(ctx, data, size)
3386+
try:
3387+
nonblocking_tls_client_connection._raise_ssl_error(
3388+
ctx,
3389+
write_result,
3390+
)
3391+
except WantWriteError:
3392+
patched_ssl_write.want_write_counter += 1
3393+
consume_data_on_server = (
3394+
consume_payload
3395+
if consume_garbage.garbage_consumed
3396+
else consume_garbage
3397+
)
3398+
3399+
consume_data_on_server(nonblocking_tls_server_connection)
3400+
# NOTE: We don't re-raise it as the calling code will do
3401+
# NOTE: the same after the call.
3402+
return write_result
3403+
3404+
patched_ssl_write.want_write_counter = 0
3405+
3406+
# NOTE: Make the client think it needs a handshake so that it'll
3407+
# NOTE: attempt to `do_handshake()` on the next `SSL_write()`
3408+
# NOTE: that originates from `sendall()`:
3409+
nonblocking_tls_client_connection.set_connect_state()
3410+
try:
3411+
nonblocking_tls_client_connection.do_handshake()
3412+
except WantWriteError:
3413+
assert True # Sanity check
3414+
except:
3415+
assert False # This should never happen (see the note above)
3416+
3417+
with monkeypatch.context() as mp_ctx:
3418+
mp_ctx.setattr(_lib, "SSL_write", patched_ssl_write)
3419+
nonblocking_tls_client_connection.sendall(large_payload)
3420+
3421+
assert consume_garbage.garbage_consumed
3422+
3423+
# NOTE: Read the leftover data from the very last `SSL_write()`
3424+
consume_payload(nonblocking_tls_server_connection)
3425+
3426+
assert consume_payload.consumed == large_payload
3427+
33053428
def test_wrong_args(self):
33063429
"""
33073430
When called with arguments other than a string argument for its first

0 commit comments

Comments
 (0)