diff --git a/pyproject.toml b/pyproject.toml index 7bc3ec19..f40b295c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,13 @@ extend-select = [ "UP035", # Missing function argument type-annotation "ANN001", + "ANN002", + "ANN003", + "ANN201", + "ANN202", + "ANN204", + "ANN205", + "ANN206", # Using except without specifying an exception type to catch "BLE001", ] diff --git a/tests/test_cli.py b/tests/test_cli.py index 9d866a74..4715a7ec 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -48,7 +48,7 @@ realisation_to_srf, ], ) -def test_invocation_of_script(script: Callable): +def test_invocation_of_script(script: Callable) -> None: """Basic check that the scripts can be invoked.""" runner = CliRunner() result = runner.invoke(script.app, ["--help"]) diff --git a/tests/test_log_utils.py b/tests/test_log_utils.py index 5d599ae9..db385373 100644 --- a/tests/test_log_utils.py +++ b/tests/test_log_utils.py @@ -11,11 +11,11 @@ @log_utils.log_call() -def foo(a: int, b: int): +def foo(a: int, b: int) -> int: return a + b -def test_basic_log(): +def test_basic_log() -> None: log_capture = structlog.testing.LogCapture() structlog.configure(processors=[log_capture]) @@ -35,11 +35,11 @@ def test_basic_log(): @log_utils.log_call(exclude_args={"b"}) -def foo_less_b(a: int, b: int): +def foo_less_b(a: int, b: int) -> int: return a + b -def test_excluded_log(): +def test_excluded_log() -> None: log_capture = structlog.testing.LogCapture() structlog.configure(processors=[log_capture]) @@ -59,11 +59,11 @@ def test_excluded_log(): @log_utils.log_call(action_name="FOOBAR") -def bar(a: Any): +def bar(a: Any) -> None: pass -def test_renamed_bar(): +def test_renamed_bar() -> None: log_capture = structlog.testing.LogCapture() structlog.configure(processors=[log_capture]) @@ -81,11 +81,11 @@ def test_renamed_bar(): @log_utils.log_call(include_result=False) -def baz(a: Any): +def baz(a: Any) -> int: return 1 -def test_no_result(): +def test_no_result() -> None: log_capture = structlog.testing.LogCapture() structlog.configure(processors=[log_capture]) @@ -104,11 +104,11 @@ def test_no_result(): @log_utils.log_call() -def failing_function(): +def failing_function() -> None: raise ValueError("This function should fail!") -def test_failing_function(): +def test_failing_function() -> None: log_capture = structlog.testing.LogCapture() structlog.configure(processors=[log_capture]) @@ -122,7 +122,7 @@ def test_failing_function(): assert "error" in return_log -def test_successful_check_call_log(tmp_path: Path): +def test_successful_check_call_log(tmp_path: Path) -> None: log_capture = structlog.testing.LogCapture() structlog.configure(processors=[log_capture]) @@ -140,7 +140,7 @@ def test_successful_check_call_log(tmp_path: Path): assert "stdout" in completion_message and "test.txt" in completion_message["stdout"] -def test_failing_check_call_log(): +def test_failing_check_call_log() -> None: log_capture = structlog.testing.LogCapture() structlog.configure(processors=[log_capture]) @@ -156,7 +156,7 @@ def test_failing_check_call_log(): ) -def test_repeated_logs(): +def test_repeated_logs() -> None: log_capture = structlog.testing.LogCapture() structlog.configure(processors=[log_capture]) @@ -177,12 +177,12 @@ def test_repeated_logs(): ) -def _thread_worker(logger_name: str): +def _thread_worker(logger_name: str) -> None: logger = log_utils.get_logger(logger_name) logger.info("Threaded log message") -def test_thread_safety(): +def test_thread_safety() -> None: log_capture = structlog.testing.LogCapture() structlog.configure(processors=[log_capture]) diff --git a/tests/test_realisation.py b/tests/test_realisation.py index e8e5289b..14296d96 100644 --- a/tests/test_realisation.py +++ b/tests/test_realisation.py @@ -18,7 +18,7 @@ from workflow import defaults, realisations -def test_bounding_box_example(tmp_path: Path): +def test_bounding_box_example(tmp_path: Path) -> None: domain_parameters = realisations.DomainParameters( resolution=0.1, # a 0.1km resolution domain=bounding_box.BoundingBox.from_centroid_bearing_extents( @@ -60,7 +60,7 @@ def test_bounding_box_example(tmp_path: Path): ).all() -def test_domain_parameters_properties(): +def test_domain_parameters_properties() -> None: domain_parameters = realisations.DomainParameters( resolution=0.1, # a 0.1km resolution domain=bounding_box.BoundingBox.from_centroid_bearing_extents( @@ -78,7 +78,7 @@ def test_domain_parameters_properties(): assert domain_parameters.nz == 400 -def test_srf_config_example(tmp_path: Path): +def test_srf_config_example(tmp_path: Path) -> None: domain_parameters = realisations.DomainParameters( resolution=0.1, # a 0.1km resolution domain=bounding_box.BoundingBox.from_centroid_bearing_extents( @@ -124,7 +124,7 @@ def test_srf_config_example(tmp_path: Path): assert realisations.SRFConfig.read_from_realisation(realisation_ffp) == srf_config -def test_bad_domain_parameters(tmp_path: Path): +def test_bad_domain_parameters(tmp_path: Path) -> None: bad_json = tmp_path / "bad_domain_parameters.json" bad_json.write_text( json.dumps( @@ -160,7 +160,7 @@ def test_bad_domain_parameters(tmp_path: Path): realisations.DomainParameters.read_from_realisation(bad_json) -def test_bad_config_key(tmp_path: Path): +def test_bad_config_key(tmp_path: Path) -> None: bad_json = tmp_path / "bad_domain_parameters.json" bad_json.write_text( json.dumps( @@ -196,7 +196,7 @@ def test_bad_config_key(tmp_path: Path): realisations.DomainParameters.read_from_realisation(bad_json) -def test_metadata(tmp_path: Path): +def test_metadata(tmp_path: Path) -> None: metadata = realisations.RealisationMetadata( name="consecutive write test", version="1", @@ -220,7 +220,7 @@ def test_metadata(tmp_path: Path): ) -def test_velocity_model(tmp_path: Path): +def test_velocity_model(tmp_path: Path) -> None: velocity_model = realisations.VelocityModelParameters( min_vs=1.0, version="2.06", @@ -257,7 +257,7 @@ def test_velocity_model(tmp_path: Path): ) -def test_rupture_prop_config(tmp_path: Path): +def test_rupture_prop_config(tmp_path: Path) -> None: rup_prop = realisations.RupturePropagationConfig( rupture_causality_tree={"A": None, "B": "A", "C": "B"}, jump_points={ @@ -307,7 +307,7 @@ def test_rupture_prop_config(tmp_path: Path): assert rupture_prop_config.hypocentre.tolist() == [0.0, 0.6] -def test_rupture_prop_properties(): +def test_rupture_prop_properties() -> None: rup_prop = realisations.RupturePropagationConfig( rupture_causality_tree={"A": None, "B": "A", "C": "B"}, jump_points={ @@ -325,7 +325,7 @@ def test_rupture_prop_properties(): assert rup_prop.initial_fault == "A" -def test_hf_config(tmp_path: Path): +def test_hf_config(tmp_path: Path) -> None: test_realisation = tmp_path / "realisation.json" test_realisation.write_text("{}") hf_config = realisations.HFConfig.read_from_realisation_or_defaults( @@ -344,7 +344,7 @@ def test_hf_config(tmp_path: Path): ) -def test_emod3d(tmp_path: Path): +def test_emod3d(tmp_path: Path) -> None: test_realisation = tmp_path / "realisation.json" test_realisation.write_text("{}") emod3d = realisations.EMOD3DParameters.read_from_realisation_or_defaults( @@ -363,7 +363,7 @@ def test_emod3d(tmp_path: Path): ) -def test_broadband_parameters(tmp_path: Path): +def test_broadband_parameters(tmp_path: Path) -> None: test_realisation = tmp_path / "realisation.json" broadband_parameters = realisations.BroadbandParameters( flo=0.5, dt=0.005, fmidbot=0.5, fmin=0.25, site_amp_version="2014" @@ -385,14 +385,14 @@ def test_broadband_parameters(tmp_path: Path): ) -def test_logtrail_init_empty(): +def test_logtrail_init_empty() -> None: """Test LogTrail initialization with no log provided.""" trail = realisations.LogTrail([]) assert trail.log == [] assert trail._config_key == "log_trail" -def test_logtrail_init_with_log_entries(): +def test_logtrail_init_with_log_entries() -> None: """Test LogTrail initialization with a list of LogEntry objects.""" entry1 = realisations.LogEntry( utility="util1", args=["a"], version="1", timestamp=datetime.now() @@ -404,7 +404,7 @@ def test_logtrail_init_with_log_entries(): assert trail.log == [entry1, entry2] -def test_logtrail_init_with_dicts_post_init(): +def test_logtrail_init_with_dicts_post_init() -> None: """Test LogTrail post_init conversion of dicts to LogEntry objects.""" log_data = [ { @@ -431,7 +431,7 @@ def test_logtrail_init_with_dicts_post_init(): assert trail.log[1].args == ["b"] -def test_logtrail_log_entry_method(): +def test_logtrail_log_entry_method() -> None: """Test adding an entry using the log_entry method.""" trail = realisations.LogTrail([]) trail.log_entry("my_util", ["--flag", "value"]) @@ -442,7 +442,7 @@ def test_logtrail_log_entry_method(): assert isinstance(trail.log[0].timestamp, datetime) -def test_logtrail_to_dict(): +def test_logtrail_to_dict() -> None: """Test converting LogTrail to a dictionary.""" ts = datetime.now() entry1 = realisations.LogEntry( @@ -480,7 +480,7 @@ def test_logtrail_to_dict(): def test_append_log_entry_file_exists_no_key( tmp_path: Path, -): +) -> None: """Test append_log_entry when file exists but lacks the 'log_trail' key.""" realisation_file = tmp_path / "test_realisation.json" # Create a file with unrelated content @@ -504,7 +504,7 @@ def test_append_log_entry_file_exists_no_key( assert data["log_trail"]["log"][0]["utility"] == "script_name.py" -def test_seeds(): +def test_seeds() -> None: seeds = realisations.Seeds.random_seeds() assert all( 0 <= seed <= 2 ** (struct.Struct("i").size * 8 - 1) - 1 @@ -512,7 +512,7 @@ def test_seeds(): ) -def test_velocity_model_1d(tmp_path: Path): +def test_velocity_model_1d(tmp_path: Path) -> None: velocity_model_1d = realisations.VelocityModel1D( model=pd.DataFrame( { @@ -564,7 +564,7 @@ def test_velocity_model_1d(tmp_path: Path): ) -def test_intensity_measure_calculation_parameters(tmp_path: Path): +def test_intensity_measure_calculation_parameters(tmp_path: Path) -> None: im_calc_params = realisations.IntensityMeasureCalculationParameters( ims=[im_calculation.IM("PGA"), im_calculation.IM("PGV")], valid_periods=np.array([0.1, 0.2, 0.3]), @@ -605,5 +605,5 @@ def test_defaults_are_loadable( tmp_path: Path, realisation_config: realisations.RealisationConfiguration, defaults_version: defaults.DefaultsVersion, -): +) -> None: realisation_config.read_from_defaults(defaults_version) diff --git a/tests/test_utils.py b/tests/test_utils.py index 51936f0b..de823b6f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,14 +4,16 @@ from workflow import utils -def test_get_available_cores_slurm_cpus_on_node(): +def test_get_available_cores_slurm_cpus_on_node() -> None: with patch.dict(os.environ, {"SLURM_CPUS_ON_NODE": "4"}): assert utils.get_available_cores() == 4 -def get_available_cores_slurm_nprocs(): + +def get_available_cores_slurm_nprocs() -> None: with patch.dict(os.environ, {"SLURM_NPROCS": "8"}): assert utils.get_available_cores() == 8 -def get_available_cores_no_slurm(): + +def get_available_cores_no_slurm() -> None: with patch("multiprocessing.cpu_count", return_value=16): assert utils.get_available_cores() == 16 diff --git a/workflow/log_utils.py b/workflow/log_utils.py index 073735c3..bfde980e 100644 --- a/workflow/log_utils.py +++ b/workflow/log_utils.py @@ -84,7 +84,9 @@ def log_call( def decorator(f: Callable) -> Callable: # numpydoc ignore=GL08 @functools.wraps(f) - def wrapper(*args, **kwargs): # numpydoc ignore=GL08 + def wrapper( + *args: list[Any], **kwargs: dict[str, Any] + ) -> Any: # numpydoc ignore=GL08 nonlocal exclude_args signature = inspect.signature(f) function_id = str(uuid.uuid4()) diff --git a/workflow/realisations.py b/workflow/realisations.py index 6a8c7959..d679ea69 100644 --- a/workflow/realisations.py +++ b/workflow/realisations.py @@ -243,7 +243,7 @@ class Seeds(RealisationConfiguration): @classmethod def read_from_realisation_or_defaults( - cls, realisation_ffp: Path, *args + cls, realisation_ffp: Path, *args: list[Any] ) -> Self: # *args is to maintain compat with superclass (remove this and see the error in mypy). """Read seeds configuration from a realisation file or generate random seeds if not present. @@ -256,7 +256,7 @@ def read_from_realisation_or_defaults( ---------- realisation_ffp : Path The realisation filepath to read from. - *args : Any + *args : list Ignored arguments. Returns @@ -315,7 +315,7 @@ class SourceConfig(RealisationConfiguration): source_geometries: dict[str, IsSource] """Dictionary mapping source names to their definitions.""" - def to_dict(self): + def to_dict(self) -> dict[str, Any]: """ Convert the object to a dictionary representation. @@ -522,7 +522,7 @@ class VelocityModel1D(RealisationConfiguration): model: pd.DataFrame - def write_velocity_model(self, velocity_model_path: Path): + def write_velocity_model(self, velocity_model_path: Path) -> None: """Write a 1D velocity model to the specified path. Parameters diff --git a/workflow/scripts/bb_sim.py b/workflow/scripts/bb_sim.py index 689fea2f..3771f648 100644 --- a/workflow/scripts/bb_sim.py +++ b/workflow/scripts/bb_sim.py @@ -81,7 +81,7 @@ def bb_simulate_station( n2: float, station_name: str, station: pd.Series, -): +) -> None: """Simulate broadband seismic for a single station. Combines the low frequency and high frequency waveforms together @@ -185,7 +185,7 @@ def combine_hf_and_lf( Path, typer.Argument(file_okay=False, exists=True) ], output_ffp: Annotated[Path, typer.Argument(dir_okay=False, writable=True)], -): +) -> None: """Combine low-frequency and high-frequency seismic waveforms. Parameters diff --git a/workflow/scripts/check_domain.py b/workflow/scripts/check_domain.py index 6e3e8551..5340ce1e 100644 --- a/workflow/scripts/check_domain.py +++ b/workflow/scripts/check_domain.py @@ -47,7 +47,7 @@ def check_domain( realisation_ffp: Annotated[Path, typer.Argument()], srf_ffp: Annotated[Path, typer.Argument()], velocity_model: Annotated[Path, typer.Argument()], -): +) -> None: """Check an SRF's contents for viability. Parameters diff --git a/workflow/scripts/check_srf.py b/workflow/scripts/check_srf.py index 358cb57d..10de9d94 100644 --- a/workflow/scripts/check_srf.py +++ b/workflow/scripts/check_srf.py @@ -54,7 +54,7 @@ def check_srf( realisation_ffp: Annotated[Path, typer.Argument()], srf_ffp: Annotated[Path, typer.Argument()], -): +) -> None: """Check an SRF's contents for viability. Parameters diff --git a/workflow/scripts/create_e3d_par.py b/workflow/scripts/create_e3d_par.py index ed25b693..488f2ad7 100644 --- a/workflow/scripts/create_e3d_par.py +++ b/workflow/scripts/create_e3d_par.py @@ -254,7 +254,7 @@ def create_e3d_par( "/EMOD3D/tools/emod3d-mpi_v3.0.8" ), emod3d_version: Annotated[str, typer.Option()] = "3.0.8", -): +) -> None: """Create EMOD3D parameter file from provided inputs. Parameters diff --git a/workflow/scripts/gcmt_auto_simulate.py b/workflow/scripts/gcmt_auto_simulate.py index c7d366e2..7744dfcc 100755 --- a/workflow/scripts/gcmt_auto_simulate.py +++ b/workflow/scripts/gcmt_auto_simulate.py @@ -77,7 +77,7 @@ def get_nz_outline_polygon() -> Polygon: def gcmt_auto_simulate( gcmt_solutions_url: Annotated[str, typer.Argument()], old_gcmt_solutions_path: Annotated[Path, typer.Argument()], -): +) -> None: """Automatically simulate GCMT solutions that are new, large, and within 30 km of New Zealand. Parameters diff --git a/workflow/scripts/gcmt_to_realisation.py b/workflow/scripts/gcmt_to_realisation.py index e9dc9aaa..994a6fa2 100644 --- a/workflow/scripts/gcmt_to_realisation.py +++ b/workflow/scripts/gcmt_to_realisation.py @@ -103,7 +103,7 @@ def gcmt_to_realisation( nodal_plane: Annotated[ NodalPlaneChoice, typer.Option() ] = NodalPlaneChoice.MOST_LIKELY, -): +) -> None: """Generate a realisation from a GCMT solution. Parameters diff --git a/workflow/scripts/generate_rupture_propagation.py b/workflow/scripts/generate_rupture_propagation.py index 49f14795..36617220 100644 --- a/workflow/scripts/generate_rupture_propagation.py +++ b/workflow/scripts/generate_rupture_propagation.py @@ -137,7 +137,7 @@ def generate_rupture_propagation( ], shypo: Annotated[Optional[float], typer.Option(min=0, max=1)] = None, dhypo: Annotated[Optional[float], typer.Option(min=0, max=1)] = None, -): +) -> None: """Generate a likely rupture propagation for a given set of sources. Parameters diff --git a/workflow/scripts/generate_stoch.py b/workflow/scripts/generate_stoch.py index 3704492e..f59e4a97 100644 --- a/workflow/scripts/generate_stoch.py +++ b/workflow/scripts/generate_stoch.py @@ -46,7 +46,7 @@ def generate_stoch( srf2stoch_path: Annotated[Path, typer.Option(exists=True)] = Path( "/EMOD3D/tools/srf2stoch" ), -): +) -> None: """Generate a stoch file from an SRF file. This function uses the `srf2stoch` binary to generate a stoch file from the provided SRF file. diff --git a/workflow/scripts/generate_velocity_model_parameters.py b/workflow/scripts/generate_velocity_model_parameters.py index 3eda0ddf..3ffed444 100644 --- a/workflow/scripts/generate_velocity_model_parameters.py +++ b/workflow/scripts/generate_velocity_model_parameters.py @@ -405,7 +405,7 @@ def pgv_target( @log_utils.log_call() def generate_velocity_model_parameters( realisation_ffp: Annotated[Path, typer.Argument()], -): +) -> None: """Generate velocity model parameters for a given realisation file. This function reads the source and rupture propagation information and computes: diff --git a/workflow/scripts/hf_sim.py b/workflow/scripts/hf_sim.py index c438e264..0ded2fd9 100644 --- a/workflow/scripts/hf_sim.py +++ b/workflow/scripts/hf_sim.py @@ -193,7 +193,7 @@ def run_hf( Path, typer.Option(exists=True, writable=True, file_okay=False), ] = Path("/out"), -): +) -> None: """Run the HF (High-Frequency) simulation and generate the HF output file. This function performs the following steps: diff --git a/workflow/scripts/nshm2022_to_realisation.py b/workflow/scripts/nshm2022_to_realisation.py index 23fd970d..06e1288e 100755 --- a/workflow/scripts/nshm2022_to_realisation.py +++ b/workflow/scripts/nshm2022_to_realisation.py @@ -158,7 +158,7 @@ def generate_realisation( max=1, ), ] = None, -): +) -> None: """Generate realisation stub files from ruptures in the NSHM 2022 database. This function initializes a connection to the NSHM database, retrieves diff --git a/workflow/scripts/plan_workflow.py b/workflow/scripts/plan_workflow.py index 93ae3d0d..e13bc09f 100644 --- a/workflow/scripts/plan_workflow.py +++ b/workflow/scripts/plan_workflow.py @@ -708,7 +708,7 @@ def plan_workflow( ] = None, container: Annotated[Optional[Path], typer.Option()] = None, emod3d_path: Annotated[Optional[Path], typer.Option()] = None, -): +) -> None: """Plan and generate a Cylc workflow file for a number of realisations. Parameters diff --git a/workflow/scripts/realisation_to_srf.py b/workflow/scripts/realisation_to_srf.py index 0cd3af52..a470074c 100644 --- a/workflow/scripts/realisation_to_srf.py +++ b/workflow/scripts/realisation_to_srf.py @@ -130,7 +130,7 @@ def generate_fault_srf( srf_config: SRFConfig, seeds: Seeds, genslip_path: Path, -): +) -> None: """Generate an SRF file for a given fault. Parameters @@ -450,7 +450,7 @@ def generate_fault_srfs_parallel( seeds: Seeds, velocity_model_1d: VelocityModel1D, genslip_path: Path, -): +) -> None: """Generate fault SRF files in parallel. Parameters @@ -524,7 +524,7 @@ def generate_srf( genslip_path: Annotated[Path, typer.Option(readable=True, dir_okay=False)] = Path( "/EMOD3D/tools/genslip_v5.4.2" ), -): +) -> None: """Generate an SRF file from a given realisation specification. This function reads the realisation metadata and configurations from the specified YAML file. It then generates