@@ -321,7 +321,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
321
321
return server
322
322
323
323
324
- def loopback (server_factory = None , client_factory = None ):
324
+ def loopback (server_factory = None , client_factory = None , blocking = True ):
325
325
"""
326
326
Create a connected socket pair and force two connected SSL sockets
327
327
to talk to each other via memory BIOs.
@@ -337,8 +337,8 @@ def loopback(server_factory=None, client_factory=None):
337
337
338
338
handshake (client , server )
339
339
340
- server .setblocking (True )
341
- client .setblocking (True )
340
+ server .setblocking (blocking )
341
+ client .setblocking (blocking )
342
342
return server , client
343
343
344
344
@@ -3297,11 +3297,134 @@ def test_memoryview_really_doesnt_overfill(self):
3297
3297
self ._doesnt_overfill_test (_make_memoryview )
3298
3298
3299
3299
3300
+ @pytest .fixture
3301
+ def nonblocking_tls_connections_pair ():
3302
+ """Return a non-blocking TLS loopback connections pair."""
3303
+ return loopback (blocking = False )
3304
+
3305
+
3306
+ @pytest .fixture
3307
+ def nonblocking_tls_server_connection (nonblocking_tls_connections_pair ):
3308
+ """Return a non-blocking TLS server socket connected to loopback."""
3309
+ return nonblocking_tls_connections_pair [0 ]
3310
+
3311
+
3312
+ @pytest .fixture
3313
+ def nonblocking_tls_client_connection (nonblocking_tls_connections_pair ):
3314
+ """Return a non-blocking TLS client socket connected to loopback."""
3315
+ return nonblocking_tls_connections_pair [1 ]
3316
+
3317
+
3300
3318
class TestConnectionSendall :
3301
3319
"""
3302
3320
Tests for `Connection.sendall`.
3303
3321
"""
3304
3322
3323
+ def test_want_write (
3324
+ self ,
3325
+ monkeypatch ,
3326
+ nonblocking_tls_server_connection ,
3327
+ nonblocking_tls_client_connection ,
3328
+ ):
3329
+ msg = b"x"
3330
+ garbage_size = 1024 * 1024 * 64
3331
+ large_payload = b"p" * garbage_size * 2
3332
+ payload_size = len (large_payload )
3333
+
3334
+ sent_garbage_size = 0
3335
+ try :
3336
+ sent_garbage_size += nonblocking_tls_client_connection .send (
3337
+ msg * garbage_size ,
3338
+ )
3339
+ except WantWriteError :
3340
+ pass
3341
+ for i in range (garbage_size ):
3342
+ try :
3343
+ sent_garbage_size += nonblocking_tls_client_connection .send (
3344
+ msg ,
3345
+ )
3346
+ except WantWriteError :
3347
+ break
3348
+ else :
3349
+ pytest .fail (
3350
+ "Failed to fill socket buffer, cannot test "
3351
+ "'want write' in `sendall()`"
3352
+ )
3353
+ garbage_payload = sent_garbage_size * msg
3354
+
3355
+ def consume_garbage (conn ):
3356
+ assert patched_ssl_write .want_write_counter >= 1
3357
+ assert not consume_garbage .garbage_consumed
3358
+
3359
+ while len (consume_garbage .consumed ) < sent_garbage_size :
3360
+ try :
3361
+ consume_garbage .consumed += conn .recv (
3362
+ sent_garbage_size - len (consume_garbage .consumed ),
3363
+ )
3364
+ except WantReadError :
3365
+ pass
3366
+
3367
+ assert consume_garbage .consumed == garbage_payload
3368
+
3369
+ consume_garbage .garbage_consumed = True
3370
+
3371
+ consume_garbage .garbage_consumed = False
3372
+ consume_garbage .consumed = b""
3373
+
3374
+ def consume_payload (conn ):
3375
+ try :
3376
+ consume_payload .consumed += conn .recv (payload_size )
3377
+ except WantReadError :
3378
+ pass
3379
+
3380
+ consume_payload .consumed = b""
3381
+
3382
+ original_ssl_write = _lib .SSL_write
3383
+
3384
+ def patched_ssl_write (ctx , data , size ):
3385
+ write_result = original_ssl_write (ctx , data , size )
3386
+ try :
3387
+ nonblocking_tls_client_connection ._raise_ssl_error (
3388
+ ctx ,
3389
+ write_result ,
3390
+ )
3391
+ except WantWriteError :
3392
+ patched_ssl_write .want_write_counter += 1
3393
+ consume_data_on_server = (
3394
+ consume_payload
3395
+ if consume_garbage .garbage_consumed
3396
+ else consume_garbage
3397
+ )
3398
+
3399
+ consume_data_on_server (nonblocking_tls_server_connection )
3400
+ # NOTE: We don't re-raise it as the calling code will do
3401
+ # NOTE: the same after the call.
3402
+ return write_result
3403
+
3404
+ patched_ssl_write .want_write_counter = 0
3405
+
3406
+ # NOTE: Make the client think it needs a handshake so that it'll
3407
+ # NOTE: attempt to `do_handshake()` on the next `SSL_write()`
3408
+ # NOTE: that originates from `sendall()`:
3409
+ nonblocking_tls_client_connection .set_connect_state ()
3410
+ try :
3411
+ nonblocking_tls_client_connection .do_handshake ()
3412
+ except WantWriteError :
3413
+ assert True # Sanity check
3414
+ except :
3415
+ assert False # This should never happen (see the note above)
3416
+
3417
+ with monkeypatch .context () as mp_ctx :
3418
+ mp_ctx .setattr (_lib , "SSL_write" , patched_ssl_write )
3419
+ nonblocking_tls_client_connection .sendall (large_payload )
3420
+
3421
+ assert consume_garbage .garbage_consumed
3422
+
3423
+ # NOTE: Read the leftover data from the very last `SSL_write()`
3424
+ consume_payload (nonblocking_tls_server_connection )
3425
+
3426
+ assert consume_payload .consumed == large_payload
3427
+
3305
3428
def test_wrong_args (self ):
3306
3429
"""
3307
3430
When called with arguments other than a string argument for its first
0 commit comments