1+ import asyncio
12import io
23import logging
3- from concurrent .futures import Future
4+ from asyncio import Task
5+ from collections import Counter
6+ from typing import Coroutine
47
58import pytest
69import requests
710from requests_toolbelt import MultipartDecoder , MultipartEncoder
11+
812from unstructured_client ._hooks .custom import form_utils , pdf_utils , request_utils
913from unstructured_client ._hooks .custom .form_utils import (
1014 PARTITION_FORM_CONCURRENCY_LEVEL_KEY ,
1822 MAX_PAGES_PER_SPLIT ,
1923 MIN_PAGES_PER_SPLIT ,
2024 SplitPdfHook ,
21- get_optimal_split_size ,
25+ get_optimal_split_size , run_tasks ,
2226)
2327from unstructured_client .models import shared
2428
@@ -224,7 +228,6 @@ def test_unit_parse_form_data():
224228 b"--boundary--\r \n "
225229 )
226230
227-
228231 decoded_data = MultipartDecoder (
229232 test_form_data ,
230233 "multipart/form-data; boundary=boundary" ,
@@ -361,22 +364,22 @@ def test_get_optimal_split_size(num_pages, concurrency_level, expected_split_siz
361364 ({}, DEFAULT_CONCURRENCY_LEVEL ), # no value
362365 ({"split_pdf_concurrency_level" : 10 }, 10 ), # valid number
363366 (
364- # exceeds max value
365- {"split_pdf_concurrency_level" : f"{ MAX_CONCURRENCY_LEVEL + 1 } " },
366- MAX_CONCURRENCY_LEVEL ,
367+ # exceeds max value
368+ {"split_pdf_concurrency_level" : f"{ MAX_CONCURRENCY_LEVEL + 1 } " },
369+ MAX_CONCURRENCY_LEVEL ,
367370 ),
368371 ({"split_pdf_concurrency_level" : - 3 }, DEFAULT_CONCURRENCY_LEVEL ), # negative value
369372 ],
370373)
def test_unit_get_split_pdf_concurrency_level_returns_valid_number(form_data, expected_result):
    """The concurrency level parsed from the form data is defaulted/clamped to a valid number."""
    concurrency_level = form_utils.get_split_pdf_concurrency_level_param(
        form_data,
        key=PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
        fallback_value=DEFAULT_CONCURRENCY_LEVEL,
        max_allowed=MAX_CONCURRENCY_LEVEL,
    )
    assert concurrency_level == expected_result
381384
382385
@@ -404,16 +407,16 @@ def test_unit_get_starting_page_number(starting_page_number, expected_result):
404407@pytest .mark .parametrize (
405408 "page_range, expected_result" ,
406409 [
407- (["1" , "14" ], (1 , 14 )), # Valid range, start on boundary
408- (["4" , "16" ], (4 , 16 )), # Valid range, end on boundary
409- (None , (1 , 20 )), # Range not specified, defaults to full range
410+ (["1" , "14" ], (1 , 14 )), # Valid range, start on boundary
411+ (["4" , "16" ], (4 , 16 )), # Valid range, end on boundary
412+ (None , (1 , 20 )), # Range not specified, defaults to full range
410413 (["2" , "5" ], (2 , 5 )), # Valid range within boundary
411- (["2" , "100" ], None ), # End page too high
412- (["50" , "100" ], None ), # Range too high
413- (["-50" , "5" ], None ), # Start page too low
414- (["-50" , "-2" ], None ), # Range too low
415- (["10" , "2" ], None ), # Backwards range
416- (["foo" , "foo" ], None ), # Parse error
414+ (["2" , "100" ], None ), # End page too high
415+ (["50" , "100" ], None ), # Range too high
416+ (["-50" , "5" ], None ), # Start page too low
417+ (["-50" , "-2" ], None ), # Range too low
418+ (["10" , "2" ], None ), # Backwards range
419+ (["foo" , "foo" ], None ), # Parse error
417420 ],
418421)
419422def test_unit_get_page_range_returns_valid_range (page_range , expected_result ):
@@ -432,3 +435,96 @@ def test_unit_get_page_range_returns_valid_range(page_range, expected_result):
432435 return
433436
434437 assert result == expected_result
438+
439+
async def _request_mock(fails: bool, content: str) -> requests.Response:
    """Build a fake ``requests.Response`` carrying *content* as its body.

    The response reports HTTP 500 when *fails* is true, HTTP 200 otherwise.
    """
    status = 500 if fails else 200
    body = content.encode()
    response = requests.Response()
    response.status_code = status
    # ``_content`` is the private buffer ``Response.content`` reads from;
    # setting it directly is the usual way to fabricate a response in tests.
    response._content = body
    return response
445+
446+
@pytest.mark.parametrize(
    ("allow_failed", "tasks", "expected_responses"), [
        pytest.param(
            True, [
                _request_mock(fails=False, content="1"),
                _request_mock(fails=False, content="2"),
                _request_mock(fails=False, content="3"),
                _request_mock(fails=False, content="4"),
            ],
            ["1", "2", "3", "4"],
            # fixed typo in the id: "allower" -> "allowed"
            id="no failures, fails allowed"
        ),
        pytest.param(
            True, [
                _request_mock(fails=False, content="1"),
                _request_mock(fails=True, content="2"),
                _request_mock(fails=False, content="3"),
                _request_mock(fails=True, content="4"),
            ],
            ["1", "2", "3", "4"],
            id="failures, fails allowed"
        ),
        pytest.param(
            False, [
                _request_mock(fails=True, content="failure"),
                _request_mock(fails=False, content="2"),
                _request_mock(fails=True, content="failure"),
                _request_mock(fails=False, content="4"),
            ],
            ["failure"],
            id="failures, fails disallowed"
        ),
        pytest.param(
            False, [
                _request_mock(fails=False, content="1"),
                _request_mock(fails=False, content="2"),
                _request_mock(fails=False, content="3"),
                _request_mock(fails=False, content="4"),
            ],
            ["1", "2", "3", "4"],
            id="no failures, fails disallowed"
        ),
    ]
)
@pytest.mark.asyncio
async def test_unit_disallow_failed_coroutines(
    allow_failed: bool,
    # the parametrized values are bare coroutines, not asyncio Tasks,
    # so annotate them as such (``Coroutine`` is already imported above)
    tasks: list[Coroutine],
    expected_responses: list[str],
):
    """``run_tasks`` gathers responses in submission order; with
    ``allow_failed=False`` it returns only the first failed (HTTP 500)
    response instead of the full list.
    """
    responses = await run_tasks(tasks, allow_failed=allow_failed)
    # each element of ``responses`` pairs some index-like value with the
    # response object; the response itself sits at position 1
    response_contents = [response[1].content.decode() for response in responses]
    assert response_contents == expected_responses
501+
502+
async def _fetch_canceller_error(fails: bool, content: str, cancelled_counter: Counter):
    """Wrap ``_request_mock`` and record in *cancelled_counter* when this
    coroutine gets cancelled before finishing.
    """
    try:
        if fails:
            print("Fails")
        else:
            # small delay so a failing sibling has a chance to finish first
            await asyncio.sleep(0.01)
            print("Doesn't fail")
        return await _request_mock(fails=fails, content=content)
    except asyncio.CancelledError:
        cancelled_counter["cancelled"] += 1
        print(cancelled_counter["cancelled"])
        print("Cancelled")
515+
516+
@pytest.mark.asyncio
async def test_remaining_tasks_cancelled_when_fails_disallowed():
    """A failing coroutine with ``allow_failed=False`` cancels (some of) the
    still-pending coroutines, but not ones that already completed.
    """
    cancelled_counter = Counter()
    tasks = [_fetch_canceller_error(fails=True, content="1", cancelled_counter=cancelled_counter)]
    tasks.extend(
        _fetch_canceller_error(fails=False, content=str(i), cancelled_counter=cancelled_counter)
        for i in range(2, 200)
    )

    await run_tasks(tasks, allow_failed=False)
    # give some time to actually cancel the tasks in background
    await asyncio.sleep(1)
    print("Cancelled amount: ", cancelled_counter["cancelled"])
    assert len(tasks) > cancelled_counter["cancelled"] > 0