Skip to content

Commit adcd48f

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 054a98b commit adcd48f

File tree

4 files changed

+91
-36
lines changed

4 files changed

+91
-36
lines changed

datashuttle/configs/canonical_tags.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def get_datetime_format(format_type: str) -> str:
4545
If format_type is not one of the supported types
4646
"""
4747
if format_type not in _DATETIME_FORMATS:
48-
raise ValueError(f"Invalid format type: {format_type}. Must be one of {list(_DATETIME_FORMATS.keys())}")
48+
raise ValueError(
49+
f"Invalid format type: {format_type}. Must be one of {list(_DATETIME_FORMATS.keys())}"
50+
)
4951
return _DATETIME_FORMATS[format_type]
50-

datashuttle/utils/folders.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -391,8 +391,12 @@ def filter_names_by_datetime_range(
391391
"""
392392
filtered_names: List[str] = []
393393
for candidate in names:
394-
candidate_basename = candidate if isinstance(candidate, str) else candidate.name
395-
value = get_values_from_bids_formatted_name([candidate_basename], format_type)[0]
394+
candidate_basename = (
395+
candidate if isinstance(candidate, str) else candidate.name
396+
)
397+
value = get_values_from_bids_formatted_name(
398+
[candidate_basename], format_type
399+
)[0]
396400
try:
397401
candidate_timepoint = datetime.strptime(
398402
value, canonical_tags.get_datetime_format(format_type)
@@ -445,10 +449,12 @@ def search_with_tags(
445449
"""
446450
new_all_names: List[str] = []
447451
for name in all_names:
448-
if not (canonical_tags.tags("*") in name or
449-
canonical_tags.tags("DATETO") in name or
450-
canonical_tags.tags("TIMETO") in name or
451-
canonical_tags.tags("DATETIMETO") in name):
452+
if not (
453+
canonical_tags.tags("*") in name
454+
or canonical_tags.tags("DATETO") in name
455+
or canonical_tags.tags("TIMETO") in name
456+
or canonical_tags.tags("DATETIMETO") in name
457+
):
452458
new_all_names.append(name)
453459
continue
454460

@@ -473,7 +479,9 @@ def search_with_tags(
473479

474480
if format_type is not None:
475481
assert tag is not None, "format and tag should be set together"
476-
search_str = validation.format_and_validate_datetime_search_str(search_str, format_type, tag)
482+
search_str = validation.format_and_validate_datetime_search_str(
483+
search_str, format_type, tag
484+
)
477485

478486
# Use the helper function to perform the glob search
479487
if sub:
@@ -491,13 +499,21 @@ def search_with_tags(
491499

492500
# Filter results by datetime range if one was present
493501
if format_type is not None and tag is not None:
494-
expected_values = validation.get_expected_num_datetime_values(format_type)
495-
full_tag_regex = fr"(\d{{{expected_values}}}){re.escape(tag)}(\d{{{expected_values}}})"
502+
expected_values = validation.get_expected_num_datetime_values(
503+
format_type
504+
)
505+
full_tag_regex = rf"(\d{{{expected_values}}}){re.escape(tag)}(\d{{{expected_values}}})"
496506
match = re.search(full_tag_regex, name)
497-
if match: # We know this is true because format_and_validate_datetime_search_str succeeded
507+
if (
508+
match
509+
): # We know this is true because format_and_validate_datetime_search_str succeeded
498510
start_str, end_str = match.groups()
499-
start_timepoint = datetime.strptime(start_str, canonical_tags.get_datetime_format(format_type))
500-
end_timepoint = datetime.strptime(end_str, canonical_tags.get_datetime_format(format_type))
511+
start_timepoint = datetime.strptime(
512+
start_str, canonical_tags.get_datetime_format(format_type)
513+
)
514+
end_timepoint = datetime.strptime(
515+
end_str, canonical_tags.get_datetime_format(format_type)
516+
)
501517
matching_names = filter_names_by_datetime_range(
502518
matching_names, format_type, start_timepoint, end_timepoint
503519
)

datashuttle/utils/validation.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424
from itertools import chain
2525
from pathlib import Path
2626

27-
from datashuttle.configs import canonical_configs, canonical_folders, canonical_tags
27+
from datashuttle.configs import (
28+
canonical_configs,
29+
canonical_folders,
30+
canonical_tags,
31+
)
2832
from datashuttle.utils import formatting, getters, utils
2933
from datashuttle.utils.custom_exceptions import NeuroBlueprintError
3034

@@ -432,7 +436,9 @@ def datetime_are_iso_format(
432436
"""
433437
Check formatting for date-, time-, or datetime- tags.
434438
"""
435-
key = next((key for key in ["datetime", "time", "date"] if key in name), None)
439+
key = next(
440+
(key for key in ["datetime", "time", "date"] if key in name), None
441+
)
436442

437443
error_message: List[str]
438444
if not key:
@@ -447,15 +453,22 @@ def datetime_are_iso_format(
447453

448454
try:
449455
if not validate_datetime(format_to_check, key):
450-
error_message = [get_datetime_error(
451-
key, name, canonical_tags.get_datetime_format(key), path_
452-
)]
456+
error_message = [
457+
get_datetime_error(
458+
key,
459+
name,
460+
canonical_tags.get_datetime_format(key),
461+
path_,
462+
)
463+
]
453464
else:
454465
error_message = []
455466
except ValueError:
456-
error_message = [get_datetime_error(
457-
key, name, canonical_tags.get_datetime_format(key), path_
458-
)]
467+
error_message = [
468+
get_datetime_error(
469+
key, name, canonical_tags.get_datetime_format(key), path_
470+
)
471+
]
459472

460473
return error_message
461474

@@ -477,7 +490,9 @@ def validate_datetime(datetime_str: str, format_type: str) -> bool:
477490
True if valid, False otherwise
478491
"""
479492
try:
480-
datetime.strptime(datetime_str, canonical_tags.get_datetime_format(format_type))
493+
datetime.strptime(
494+
datetime_str, canonical_tags.get_datetime_format(format_type)
495+
)
481496
return True
482497
except ValueError:
483498
return False
@@ -502,7 +517,9 @@ def get_expected_num_datetime_values(format_type: str) -> int:
502517
return len(today.strftime(format_str))
503518

504519

505-
def format_and_validate_datetime_search_str(search_str: str, format_type: str, tag: str) -> str:
520+
def format_and_validate_datetime_search_str(
521+
search_str: str, format_type: str, tag: str
522+
) -> str:
506523
"""
507524
Validate and format a search string containing a datetime range.
508525
@@ -526,7 +543,9 @@ def format_and_validate_datetime_search_str(search_str: str, format_type: str, t
526543
If the datetime format is invalid or the range is malformed
527544
"""
528545
expected_values = get_expected_num_datetime_values(format_type)
529-
full_tag_regex = fr"(\d{{{expected_values}}}){re.escape(tag)}(\d{{{expected_values}}})"
546+
full_tag_regex = (
547+
rf"(\d{{{expected_values}}}){re.escape(tag)}(\d{{{expected_values}}})"
548+
)
530549
match = re.search(full_tag_regex, search_str)
531550

532551
if not match:
@@ -549,8 +568,12 @@ def format_and_validate_datetime_search_str(search_str: str, format_type: str, t
549568
NeuroBlueprintError,
550569
)
551570

552-
start_timepoint = datetime.strptime(start_str, canonical_tags.get_datetime_format(format_type))
553-
end_timepoint = datetime.strptime(end_str, canonical_tags.get_datetime_format(format_type))
571+
start_timepoint = datetime.strptime(
572+
start_str, canonical_tags.get_datetime_format(format_type)
573+
)
574+
end_timepoint = datetime.strptime(
575+
end_str, canonical_tags.get_datetime_format(format_type)
576+
)
554577

555578
if end_timepoint < start_timepoint:
556579
utils.log_and_raise_error(
@@ -1079,5 +1102,3 @@ def check_datatypes_are_valid(
10791102
return message
10801103

10811104
return None
1082-
1083-

tests/test_date_search_range.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import re
44
import shutil
55
import tempfile
6-
from datetime import datetime
76
from pathlib import Path
87
from typing import List
98

@@ -20,7 +19,7 @@ def tags(x: str) -> str:
2019
"*": "@*@",
2120
"DATETO": "@DATETO@",
2221
"TIMETO": "@TIMETO@",
23-
"DATETIMETO": "@DATETIMETO@"
22+
"DATETIMETO": "@DATETIMETO@",
2423
}
2524
return tags_dict.get(x, x)
2625

@@ -40,8 +39,13 @@ def get_datetime_format(format_type: str) -> str:
4039
@pytest.fixture(autouse=True)
4140
def patch_canonical_tags(monkeypatch):
4241
from datashuttle.configs import canonical_tags
42+
4343
monkeypatch.setattr(canonical_tags, "tags", DummyCanonicalTags.tags)
44-
monkeypatch.setattr(canonical_tags, "get_datetime_format", DummyCanonicalTags.get_datetime_format)
44+
monkeypatch.setattr(
45+
canonical_tags,
46+
"get_datetime_format",
47+
DummyCanonicalTags.get_datetime_format,
48+
)
4549

4650

4751
# Dummy implementation for search_sub_or_ses_level that simply performs globbing.
@@ -57,11 +61,16 @@ def dummy_search_sub_or_ses_level(
5761
@pytest.fixture(autouse=True)
5862
def patch_search_sub_or_ses_level(monkeypatch):
5963
from datashuttle.utils import folders
60-
monkeypatch.setattr(folders, "search_sub_or_ses_level", dummy_search_sub_or_ses_level)
64+
65+
monkeypatch.setattr(
66+
folders, "search_sub_or_ses_level", dummy_search_sub_or_ses_level
67+
)
6168

6269

6370
# Dummy implementation for get_values_from_bids_formatted_name
64-
def dummy_get_values_from_bids_formatted_name(names: List[str], key: str, return_as_int: bool = False) -> List[str]:
71+
def dummy_get_values_from_bids_formatted_name(
72+
names: List[str], key: str, return_as_int: bool = False
73+
) -> List[str]:
6574
results = []
6675
for name in names:
6776
if key == "date":
@@ -75,7 +84,12 @@ def dummy_get_values_from_bids_formatted_name(names: List[str], key: str, return
7584
@pytest.fixture(autouse=True)
7685
def patch_get_values_from_bids(monkeypatch):
7786
from datashuttle.utils import utils
78-
monkeypatch.setattr(utils, "get_values_from_bids_formatted_name", dummy_get_values_from_bids_formatted_name)
87+
88+
monkeypatch.setattr(
89+
utils,
90+
"get_values_from_bids_formatted_name",
91+
dummy_get_values_from_bids_formatted_name,
92+
)
7993

8094

8195
# Fixture to create a temporary directory with a simulated folder structure.
@@ -104,6 +118,7 @@ def test_date_range_wildcard(temp_project_dir: Path):
104118
only folders whose embedded date falls between 20250306 and 20250309 (inclusive)
105119
should be returned.
106120
"""
121+
107122
class Configs:
108123
pass
109124

@@ -130,6 +145,7 @@ def test_simple_wildcard(temp_project_dir: Path):
130145
When given a simple wildcard pattern like "sub-01_@*@",
131146
all folders should be returned.
132147
"""
148+
133149
class Configs:
134150
pass
135151

@@ -146,6 +162,7 @@ def test_invalid_date_range(temp_project_dir: Path):
146162
"""
147163
Test that invalid date ranges raise appropriate errors.
148164
"""
165+
149166
class Configs:
150167
pass
151168

@@ -170,6 +187,7 @@ def test_combined_wildcards(temp_project_dir: Path):
170187
"""
171188
Test that wildcard and date range can be combined in the same pattern.
172189
"""
190+
173191
class Configs:
174192
pass
175193

@@ -197,4 +215,3 @@ class Configs:
197215
"sub-03_date-20250308",
198216
}
199217
assert matched_folders == expected_folders
200-

0 commit comments

Comments
 (0)