diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 81ecaac..080acfb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -518,7 +518,7 @@ They'll help us keep things organized and make your contribution process as effi ##### C.b Lint, Format and Test your Code > [!TIP] -> For VS Code users, we recommend installing: +> For VSCode users, we recommend installing: > - [Ruff extension](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) > - [Mypy extension](https://marketplace.visualstudio.com/items?itemName=ms-python.mypy-type-checker) @@ -540,9 +540,9 @@ To ensure your code meets `torchmeter`'s standards, please complete these 3 crit ``` 2. Linting and Formatting - - `torchmeter` uses `ruff` for linting/formatting (already installed in step [B.e.3](#be-configure-python-environment)). + - `torchmeter` uses [`ruff`](https://docs.astral.sh/ruff) for linting/formatting (already installed in step [B.e.3](#be-configure-python-environment)). - Our style rules are defined in `ruff.toml`, please respect these configurations. If you find any rules unreasonable, please start a [Discussions](#πŸ’¬-discussions--lets-collaborate--innovate) - - Ensure your changes comply with the project's code style by running the following linting commands: + - Ensure the code format of your changes meets the project requirements by running the following formatting commands: ```bash # pwd: path/to/your/working/directory/torchmeter-yourname @@ -550,15 +550,14 @@ To ensure your code meets `torchmeter`'s standards, please complete these 3 crit # Replace `torchmeter-dev` with your virtual environment name conda activate torchmeter-dev - ruff check \ + ruff format \ --preview \ - --target-version=py38 \ - --output-format=grouped - - # You should promise output is `All checks passed!` + --target-version=py38 + + # Make sure the command exits successfully ``` - - If the code analysis passes, then ensure that the code format meets the project requirements. 
+ - After that, ensure your changes comply with the project's code style by running the following commands: ```bash # pwd: path/to/your/working/directory/torchmeter-yourname @@ -566,15 +565,16 @@ To ensure your code meets `torchmeter`'s standards, please complete these 3 crit # Replace `torchmeter-dev` with your virtual environment name conda activate torchmeter-dev - ruff format \ - --diff \ + ruff check \ + --fix \ + --unsafe-fixes \ --preview \ --target-version=py38 - # You should promise output is only one line showing` files already formatted` + # Make sure the output is `All checks passed!` and no errors are reported. ``` - - If any step fails, please modify the code according to the terminal output and re-execute the above steps until both steps are successful. If you are a VSCode user, we recommend using the `ruff` plugin to automatically perform code linting and formatting. This plugin uses underlines to highlight code snippets that do not conform to the predefined rules and allows you to automatically fix some common errors. + - If any step fails, please modify the code according to the terminal output and re-execute the above steps until both steps are successful. If you are a VSCode user, we recommend using the `ruff` [plugin](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) to automatically perform code linting and formatting. This plugin uses underlines to highlight code snippets that do not conform to the predefined rules and allows you to automatically fix some common errors. > [!TIP] > If you have a way to run the `shell` script (on Unix-like systems or `cygwin` on `windows`), then: @@ -585,7 +585,7 @@ To ensure your code meets `torchmeter`'s standards, please complete these 3 crit > ``` > This runs all linting and formatting in one command. -3. Testing +1. Testing - `torchmeter` uses `pytest` for testing code. Yes, `pytest` and the related plugins have also been installed in step [B.e.3](#be-configure-python-environment). 
- `torchmeter` has written the `pytest` running configuration in the `pytest.ini` file at the root directory of the project. This file defines how the tests are run, including the test directory, test filters, test configuration, etc. Specifically, `pytest` will only discover tests in the `tests` directory at the root of the project, and requires a test coverage rate of **> 90%**. diff --git a/misc/lint_format.sh b/misc/lint_format.sh index dbad6a5..dc79c7d 100644 --- a/misc/lint_format.sh +++ b/misc/lint_format.sh @@ -1,5 +1,13 @@ #!/usr/bin/env bash +green_output() { + echo -e "\033[32m$1\033[0m" +} + +cyan_output() { + echo -e "\033[36m$1\033[0m" +} + find_dir() { local target_path=$1 local current_path=$(realpath $(dirname $0)) @@ -46,43 +54,45 @@ do if [ -n "$env" ]; then cyan_output "$env selected." conda activate "$env" - green_output "$env activated." + green_output "$env activated.\n" break else red_output "Invalid selection. Please try again." fi done -# ---------------------------------------------- Lint ----------------------------------------------- +# --------------------------------------------- Format ----------------------------------------------- set +e -ruff check \ +ruff format \ --preview \ - --target-version=py38 \ - --output-format=grouped + --target-version=py38 exit_code=$? set -e if [[ $exit_code -eq 0 ]]; then - echo -e "\nβœ… Linting passed! Code quality check successful! πŸŽ‰" + echo -e "βœ… Formatting finished! πŸŽ‰\n" else - echo -e "\n❌ Linting failed! Some code does not meet the linting rules!" >&2 + echo -e "❌ Formatting failed! 
Some code does not meet the format requirements!" >&2 + echo -e "❌ Ruff terminates abnormally due to invalid configuration, invalid CLI options, or an internal error" >&2 exit 1 fi -# --------------------------------------------- Format ----------------------------------------------- +# ---------------------------------------------- Lint ----------------------------------------------- set +e -ruff format \ - --diff \ +ruff check \ --preview \ + --fix \ + --unsafe-fixes \ + --target-version=py38 \ + --output-format=grouped exit_code=$? - set -e + if [[ $exit_code -eq 0 ]]; then - echo -e "\nβœ… Formatting passed! All code is well-formated! πŸŽ‰" + echo -e "βœ… Linting passed! Code quality check successful! πŸŽ‰\n" else - echo -e "\n❌ Formatting failed! Some code does not meet the format requirements!" >&2 + echo -e "❌ Linting failed! Some code does not meet the linting rules!\n" >&2 exit 1 fi \ No newline at end of file diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..8e6f2a9 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,7 @@ +[mypy] +packages = torchmeter + +exclude = (?x)( + tests + | refers + ) \ No newline at end of file diff --git a/ruff.toml b/ruff.toml index f4aba2a..555a081 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,48 +1,125 @@ -# See https://docs.astral.sh/ruff/settings for help +# TorchMeter, MIT license +# Author: Ahzyuan +# Repo: https://github.com/TorchMeter/torchmeter -target-version = "py37" # Always generate Python 3.6-compatible code. -line-length = 250 # Allow lines to be as long as 120. +include = [ + "torchmeter/**/*.py", + "tests/**/*.py", + "examples/**/*.py", + "examples/*/*.ipynb" +] +extend-exclude = [ + "refers/**/*" ] -[format] -docstring-code-format = true # Enable reformatting of code snippets in docstrings. +src = [".", "torchmeter/*"] +preview = true +line-length = 120 # Allow lines to be as long as 120. +target-version = "py38" # Always generate Python 3.8-compatible code. 
+output-format = "grouped" +required-version = ">=0.6.0" # ruff version +# =========================================== Linter =========================================== [lint] +select = [ + # flake8-builtins + "A", + # flake8-annotations + "ANN", + # flake8-unused-arguments + "ARG", + # mccabe + "C90", + # pycodestyle + "E", + # Pyflakes + "F", + # isort + "I", + # flake8-no-pep420 + "INP", + # flake8-implicit-str-concat + # "ISC", + # flake8-pie + "PIE", + # flake8-pytest-style + "PT", + # Error + "PLE", + # ruff-specific rules + "RUF", + # flake8-simplify + "SIM", + # flake8-2020 + "YTT" +] +extend-select = [ + "ISC001", # Implicitly concatenated string literals on one line + "Q004", # Unnecessary escape on inner quote character + "DOC201", # return is not documented in docstring + "DOC402", # yield is not documented in docstring + "DOC403", # Docstring has a "Yields" section but the function doesn't yield anything + "DOC501", # Raised exception {id} missing from docstring +] + # Skip unused variable rules ignore = [ - "ANN101", # Missing type annotation for `self` in method - "ANN102", # Missing type annotation for `cls` in classmethod - "ANN401", # Dynamically typed expressions (typing.Any) are disallowed - "C901", # function is too complex (12 > 10) - "COM812", # Trailing comma missing - "D", # Docstring rules - "EM101", # Exception must not use a string literal, assign to variable first - "EM102", # Exception must not use an f-string literal, assign to variable first - "ERA001", # Found commented-out code - "FBT001", # Boolean positional arg in function definition - "FBT002", # Boolean default value in function definition - "FBT003", # Boolean positional value in function call - "FIX002", # Line contains TODO - "ISC001", # Isort - "PLR0911", # Too many return statements (11 > 6) - "PLR2004", # Magic value used in comparison, consider replacing 2 with a constant variable - "PLR0912", # Too many branches - "PLR0913", # Too many arguments to function call - 
"PLR0915", # Too many statements - "S101", # Use of `assert` detected - "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes - "T201", # print() found - "T203", # pprint() found - "TD002", # Missing author in TODO; try: `# TODO(): ...` - "TD003", # Missing issue link on the line following this TODO - "TD005", # Missing issue description after `TODO` - "TRY003", # Avoid specifying long messages outside the exception class - "PLW2901", # `for` loop variable `name` overwritten by assignment target - "SLF001", # Private member accessed: `_modules` + "ANN002", # Missing type annotation for *{name} + "ANN003", # Missing type annotation for **{name} + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed in {name} + "E111", # Indentation is not a multiple of {indent_width} + "E114", # Indentation is not a multiple of {indent_width} (comment) + "E117", # Over-indented (comment) + "E261", # Insert at least two spaces before an inline comment + "E731", # Allow lambda expressions + "FA102", # Missing from __future__ import annotations, but uses {reason} + "PIE790", # Unnecessary pass statement + "PIE794", # Class field {name} is defined multiple times + "PIE810", # Call {attr} once with a tuple + "PLE0101", # Explicit return in __init__ + "PT007", # Wrong values type in pytest.mark.parametrize expected {values} of {row} + "PT008", # Use return_value= instead of patching with lambda + "PT009", # Use a regular assert instead of unittest-style {assertion} + "PT011", # pytest.raises({exception}) is too broad, set the match parameter or use a more specific exception + "PT012", # pytest.raises() block should contain a single simple statement + "PT019", # Fixture {name} without value is injected as parameter, use @pytest.mark.usefixtures instead + "PT021", # Use yield instead of request.addfinalizer + "PT023", # Use @pytest.mark.{mark_name}{expected_parens} over @pytest.mark.{mark_name}{actual_parens} + "RUF022", # __all__ is not 
sorted + "RUF023", # {}.__slots__ is not sorted + "RUF031", # Use parentheses for tuples in subscripts + "RUF034", # Useless if-else condition + "RUF052", # Local dummy variable {} is accessed + "SIM105", # Use contextlib.suppress({exception}) instead of try-except-pass + "SIM107", # Don't use return in try-except and finally + "SIM910", # Use {expected} instead of {actual} + "W191", # Indentation contains tabs ] +unfixable = ["E501"] # long lines should be wrapped manually + +[lint.per-file-ignores] +"*.ipynb" = ["E402"] # Module level import not at top of cell +"tests/**/*.py" = [ + "ANN001", # Missing type annotation for function argument {name} + "ANN201", # Missing return type annotation for public function {name} + "ANN202", # Missing return type annotation for private function {name} + "ARG002", # Unused method argument: {name} + "ARG005", # Unused lambda argument: {name} + "PT030", # pytest.warns({warning}) is too broad, set the match parameter or use a more specific warning + "PT031", # pytest.warns() block should contain a single simple statement + "DOC" # pydoclint +] + +[lint.flake8-implicit-str-concat] +allow-multiline = true + [lint.isort] length-sort = true # sort imports by their string length -combine-as-imports = true known-first-party = ["torchmeter"] -lines-after-imports = 1 # Use a single line after each import block. -single-line-exclusions = ["os", "json", "re"] + +# =========================================== Formatter =========================================== + +[format] +quote-style = "double" +docstring-code-format = true # Enable reformatting of code snippets in docstrings. 
\ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_config.py b/tests/test_config.py index aaa17c4..300f0d3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,97 +1,112 @@ import os from enum import Enum -from unittest.mock import patch, Mock, PropertyMock +from unittest.mock import Mock, PropertyMock, patch import yaml import pytest from torchmeter.config import ( - UNSAFE_KV, DEFAULT_CFG, DEFAULT_FIELDS, - dict_to_namespace, namespace_to_dict, list_to_callbacklist, - FlagNameSpace, CallbackList, CallbackSet, - get_config, Config + UNSAFE_KV, + DEFAULT_CFG, + DEFAULT_FIELDS, + Config, + CallbackSet, + CallbackList, + FlagNameSpace, + get_config, + dict_to_namespace, + namespace_to_dict, + list_to_callbacklist, ) -@pytest.fixture() + +@pytest.fixture def default_cfg_path(tmpdir): temp_cfg_path = tmpdir.join("default_cfg.yaml") - with open(temp_cfg_path, 'w') as f: + with open(temp_cfg_path, "w") as f: f.write(DEFAULT_CFG) yield temp_cfg_path.strpath if tmpdir.exists(): tmpdir.remove(rec=1) - -@pytest.fixture() + + +@pytest.fixture def invalid_cfg_path(tmpdir): temp_cfg_path = tmpdir.join("invalid_cfg.txt") - with open(temp_cfg_path, 'w') as f: + with open(temp_cfg_path, "w") as f: f.write(DEFAULT_CFG) yield temp_cfg_path.strpath if tmpdir.exists(): tmpdir.remove(rec=1) -@pytest.fixture() + +@pytest.fixture def custom_cfg_path(tmpdir): temp_cfg_path = tmpdir.join("custom_cfg.yaml") - + fake_interval = 0.34 fake_content = f"render_interval: {fake_interval}" with open(temp_cfg_path, "w") as f: f.write(fake_content) - + yield temp_cfg_path.strpath - + if tmpdir.exists(): tmpdir.remove(rec=1) + @pytest.fixture def clist(): - obj = CallbackList([1,2,3,2]) + obj = CallbackList([1, 2, 3, 2]) obj._callback_func = Mock() return obj + @pytest.fixture def cset(): - obj = CallbackSet([1,3,2]) + obj = CallbackSet([1, 3, 2]) obj._callback_func = Mock() return obj -def 
pytest_generate_tests(metafunc): + +def pytest_generate_tests(metafunc) -> None: if "all_type_data" in metafunc.fixturenames: metafunc.parametrize( - argnames="all_type_data", + argnames="all_type_data", argvalues=[ "string", False, 123, 1.23, None, - (1,2,3), [4,5,6], {7,8,9}, {"A":1, "B":2}, + (1, 2, 3), [4, 5, 6], {7, 8, 9}, {"A": 1, "B": 2}, lambda : None, (_ for _ in range(5)) ], - ids=map(lambda x:f"val({x})", [ + ids=map(lambda x: f"val({x})", [ "string", "bool", "int", "float", "None", - "tuple", "list", "set","dict", + "tuple", "list", "set", "dict", "function", "iterable_obj" ] ) - ) + ) # fmt: skip + @pytest.mark.vital -def test_unsafe_kv(): +def test_unsafe_kv() -> None: """Test whether the UNSAFE_KV is correctly defined.""" - + for key, val in UNSAFE_KV.items(): assert isinstance(key, str) - assert issubclass(val,Enum) - + assert issubclass(val, Enum) + # test whether the val'repr not equal to its corresponding key for k, v in val.__members__.items(): assert str(v) != k + @pytest.mark.vital -def test_default_fields_in_default_setting(): +def test_default_fields_in_default_setting() -> None: """Test whether all default fields are defined in default setting""" - + setting_lines_generator = (line for line in DEFAULT_CFG.split('\n') - if len(line) and not line.isspace()) - + if len(line) and not line.isspace()) # fmt: skip + assure_fields = [] for valid_line in setting_lines_generator: for field in DEFAULT_FIELDS: @@ -100,63 +115,63 @@ def test_default_fields_in_default_setting(): assure_fields.append(field) if len(assure_fields) == len(DEFAULT_FIELDS): return - pytest.fail(f"These fields are missing in default setting: {set(DEFAULT_FIELDS)-set(assure_fields)}") + pytest.fail(f"These fields are missing in default setting: {set(DEFAULT_FIELDS) - set(assure_fields)}") + -def test_list_to_callbacklist(): +def test_list_to_callbacklist() -> None: """Test the logic of list_to_callbacklist function""" - - ls = [1, "2", 3., None, (6,), - {7}, {"eight":8}, [9]] + 
+ ls = [1, "2", 3.0, None, (6,), {7}, {"eight": 8}, [9]] res = list_to_callbacklist(ls) - assert res[:6] == [1, "2", 3., None, (6,), {7}] + assert res[:6] == [1, "2", 3.0, None, (6,), {7}] assert isinstance(res[6], FlagNameSpace) assert res[6].eight == 8 assert isinstance(res[7], CallbackList) assert res[7] == [9] - -class TestListToCallbackList: - ... + + +class TestListToCallbackList: ... + class TestDictToNamespace: - @pytest.mark.parametrize( argnames=("key", "is_error"), argvalues=[ - ("string", False), - (False, True), - (123, True), - (1.23, True), - (None, True), - ((1,2,3), True) + ("string", False), + (False, True), + (123, True), + (1.23, True), + (None, True), + ((1, 2, 3), True), ], - ids=map(lambda x:f"key({x})", ["string", "bool", "int", "float", "None", "tuple"]) + ids=map(lambda x: f"key({x})", ["string", "bool", "int", "float", "None", "tuple"]), ) - def test_valid_input(self, key, is_error, all_type_data): + def test_valid_input(self, key, is_error, all_type_data) -> None: """Test normal dictionary conversion""" input_dict = {key: all_type_data} - + if is_error: with pytest.raises(TypeError): dict_to_namespace(input_dict) else: result = dict_to_namespace(input_dict) assert isinstance(result, FlagNameSpace) - + key_res = getattr(result, key) if isinstance(all_type_data, dict): assert isinstance(key_res, FlagNameSpace) else: assert key_res == all_type_data - def test_invalid_input(self, all_type_data): + def test_invalid_input(self, all_type_data) -> None: """Test non-dictionary input""" if not isinstance(all_type_data, dict): with pytest.raises(TypeError): dict_to_namespace(all_type_data) - def test_nested_dict(self): + def test_nested_dict(self) -> None: """Test the conversion of nested dictionary""" - + nested_dict = { "nested_one": {"key": "value"}, @@ -170,50 +185,51 @@ def test_nested_dict(self): } } } - } + } # fmt: skip result = dict_to_namespace(nested_dict) - - def dfs_assert(namespace, depth=0): + + def dfs_assert(namespace, depth=0) -> 
None: for k, v in namespace.data_dict.items(): - if isinstance(v, FlagNameSpace): - dfs_assert(v, depth+1) + dfs_assert(v, depth + 1) else: assert k == "key" assert v == "value" - + assert isinstance(result, FlagNameSpace) dfs_assert(result) - def test_list(self): + def test_list(self) -> None: """Test the conversion of dictionary containing list""" input_dict = {"list": [{"key1": "value1"}, "item2"]} result = dict_to_namespace(input_dict) - + assert isinstance(result, FlagNameSpace) - assert isinstance(result.list, CallbackList) + assert isinstance(result.list, CallbackList) assert isinstance(result.list[0], FlagNameSpace) - + assert result.list[0].key1 == "value1" assert result.list[1] == "item2" - - def test_set(self): + + def test_set(self) -> None: """Test the conversion of dictionary containing set""" input_dict = {"set": {"item1", "item2"}} result = dict_to_namespace(input_dict) - + assert isinstance(result, FlagNameSpace) - assert isinstance(result.set, CallbackSet) - + assert isinstance(result.set, CallbackSet) + assert result.set == {"item1", "item2"} - @pytest.mark.parametrize(argnames="unsafe_key", - argvalues=UNSAFE_KV.keys(), - ids=map(lambda x:f"unsafe_key({x})", UNSAFE_KV.keys())) - def test_unsafe_key(self, unsafe_key): - """""Test the conversion of dict containing unsafe key""" + @pytest.mark.parametrize( + argnames="unsafe_key", + argvalues=UNSAFE_KV.keys(), + ids=map(lambda x: f"unsafe_key({x})", UNSAFE_KV.keys()), + ) + def test_unsafe_key(self, unsafe_key) -> None: + """Test the conversion of dict containing unsafe key""" vals_enum = UNSAFE_KV[unsafe_key] - + valid_safevals = [] for member in vals_enum: input_dict = {unsafe_key: member.name} @@ -221,87 +237,92 @@ def test_unsafe_key(self, unsafe_key): assert isinstance(result, FlagNameSpace) assert getattr(result, unsafe_key) is member.value valid_safevals.append(member.name) - + # verify the invalid value error with pytest.raises(AttributeError): - invalid_safeval = 'invalid_safeval' + 
invalid_safeval = "invalid_safeval" while invalid_safeval in valid_safevals: invalid_safeval *= 2 result = dict_to_namespace({unsafe_key: invalid_safeval}) - def test_invalid_key(self): + def test_invalid_key(self) -> None: """Test the conversion of dictionary containing invalid key""" with pytest.raises(AttributeError): dict_to_namespace({"__FLAG": "value"}) - + with pytest.raises(AttributeError): dict_to_namespace({"__flag_key": 123}) + class TestNamespaceToDict: - def test_valid_input(self, all_type_data): + def test_valid_input(self, all_type_data) -> None: """Test normal namespace conversion""" ns = FlagNameSpace(key1=all_type_data) result = namespace_to_dict(ns) assert isinstance(result, dict) assert result["key1"] == all_type_data - def test_invalid_input(self, all_type_data): + def test_invalid_input(self, all_type_data) -> None: """Test non-FlagNameSpace input""" with pytest.raises(TypeError): namespace_to_dict(all_type_data) - @pytest.mark.parametrize(argnames="unsafe_key", - argvalues=UNSAFE_KV.keys(), - ids=map(lambda x:f"unsafe_key({x})", UNSAFE_KV.keys())) - @pytest.mark.parametrize(argnames="safe_resolve", - argvalues=(True, False), - ids=lambda x:f"safe_resolve={x}") - def test_unsafe_key(self, unsafe_key, safe_resolve): + @pytest.mark.parametrize( + argnames="unsafe_key", + argvalues=UNSAFE_KV.keys(), + ids=map(lambda x: f"unsafe_key({x})", UNSAFE_KV.keys()), + ) + @pytest.mark.parametrize( + argnames="safe_resolve", + argvalues=(True, False), + ids=lambda x: f"safe_resolve={x}", + ) + def test_unsafe_key(self, unsafe_key, safe_resolve) -> None: """Test the conversion of FlagNameSpace containing unsafe key""" vals_enum = UNSAFE_KV[unsafe_key] - + valid_vals = [] for member in list(vals_enum): ns = FlagNameSpace() setattr(ns, unsafe_key, member.value) res_dict = namespace_to_dict(ns, safe_resolve=safe_resolve) assert isinstance(res_dict, dict) - + if safe_resolve: assert res_dict[unsafe_key] == member.name else: assert res_dict[unsafe_key] == 
member.value valid_vals.append(res_dict[unsafe_key]) - invalid_safeval = 'invalid_val' + invalid_safeval = "invalid_val" while invalid_safeval in valid_vals: invalid_safeval *= 2 - + ns = FlagNameSpace() setattr(ns, unsafe_key, invalid_safeval) - + if not safe_resolve: namespace_to_dict(ns, safe_resolve=safe_resolve) - + invalid_unsafeval = lambda x: "invalid_unsafeval" ns = FlagNameSpace() setattr(ns, unsafe_key, invalid_unsafeval) - + namespace_to_dict(ns, safe_resolve=safe_resolve) - + else: with pytest.raises(Exception): namespace_to_dict(ns, safe_resolve=safe_resolve) - + invalid_unsafeval = lambda x: "invalid_unsafeval" ns = FlagNameSpace() setattr(ns, unsafe_key, invalid_unsafeval) - + with pytest.raises(Exception): namespace_to_dict(ns, safe_resolve=safe_resolve) - - def test_nested_namespace(self): + + def test_nested_namespace(self) -> None: """Test the conversion of nested FlagNameSpace""" ns = FlagNameSpace( nested_one=FlagNameSpace( @@ -311,15 +332,15 @@ def test_nested_namespace(self): ) ) ) - ) + ) # fmt: skip - def dfs_assert(res_dict, depth=0): + def dfs_assert(res_dict, depth=0) -> None: for k, v in res_dict.items(): if "__FLAG" in k: continue - + if isinstance(v, dict): - dfs_assert(v, depth+1) + dfs_assert(v, depth + 1) assert k.startswith("nested") else: assert k == "key" @@ -327,10 +348,10 @@ def dfs_assert(res_dict, depth=0): result = namespace_to_dict(ns) assert isinstance(result, dict) - + dfs_assert(result) - def test_list(self): + def test_list(self) -> None: """Test the conversion of FlagNameSpace containing list""" nested_ns = FlagNameSpace(key1="value1") ns = FlagNameSpace(list=[nested_ns, "item2"]) @@ -340,308 +361,311 @@ def test_list(self): assert result["list"][0]["key1"] == "value1" assert result["list"][1] == "item2" - def test_invalid_key(self): + def test_invalid_key(self) -> None: """Test the conversion of FlagNameSpace containing invalid key""" - + with pytest.raises(AttributeError): FlagNameSpace(__FLAG="value") - + with 
pytest.raises(AttributeError): FlagNameSpace(__flag_key="value") + class TestCallbackList: - def test_init(self): + def test_init(self) -> None: """Test the initialization of callback list""" - + # verify default callback function clist = CallbackList((1, 2, 3)) assert isinstance(clist, list) assert clist == [1, 2, 3] assert clist._callback_func() is None - + # verify callback funtion specification clist = CallbackList({4, 5, 6}, callback_func=lambda: 42) assert isinstance(clist, list) assert clist == [4, 5, 6] assert clist._callback_func() == 42 - - def test_inheritance(self): + + def test_inheritance(self) -> None: """Test whether the list type is maintained.""" assert issubclass(CallbackList, list) assert isinstance(CallbackList(), list) - - def test_append(self, clist): + + def test_append(self, clist) -> None: """Test the append method of callback list""" clist.append(42) - assert clist == [1,2,3,2,42] + assert clist == [1, 2, 3, 2, 42] clist._callback_func.assert_called_once() - def test_extend(self, clist): + def test_extend(self, clist) -> None: """Test the extend method of callback list""" clist.extend([1, 2, 3]) - assert clist == [1,2,3,2,1, 2, 3] + assert clist == [1, 2, 3, 2, 1, 2, 3] clist._callback_func.assert_called_once() - def test_insert(self, clist): + def test_insert(self, clist) -> None: """Test the insert method of callback list""" - + clist.insert(0, 10) - assert clist == [10,1,2,3,2] + assert clist == [10, 1, 2, 3, 2] assert clist._callback_func.call_count == 1 - def test_pop(self, clist): + def test_pop(self, clist) -> None: """Test the pop method of callback list""" - + clist.pop() - assert clist == [1,2,3] + assert clist == [1, 2, 3] assert clist._callback_func.call_count == 1 - def test_remove(self, clist): + def test_remove(self, clist) -> None: """Test the remove method of callback list""" - + clist.remove(2) assert clist == [1, 3, 2] assert clist._callback_func.call_count == 1 - - def test_clear(self, clist): + + def test_clear(self, 
clist) -> None: """Test the clear method of callback list""" - + clist.clear() assert not len(clist) - assert clist._callback_func.call_count == 1 + assert clist._callback_func.call_count == 1 - def test_reverse(self, clist): + def test_reverse(self, clist) -> None: """Test the reverse method of callback list""" - + clist.reverse() assert clist == [2, 3, 2, 1] assert clist._callback_func.call_count == 1 - - def test_sort(self, clist): + + def test_sort(self, clist) -> None: """Test the sort method of callback list""" - + clist.sort() assert clist == [1, 2, 2, 3] assert clist._callback_func.call_count == 1 - - def test_setitem(self, clist): + + def test_setitem(self, clist) -> None: """Test the setitem method of callback list""" - + clist[0] = 10 assert clist == [10, 2, 3, 2] assert clist._callback_func.call_count == 1 - def test_delitem(self, clist): + def test_delitem(self, clist) -> None: """Test the delitem method of callback list""" - + del clist[0] assert clist == [2, 3, 2] assert clist._callback_func.call_count == 1 - - def test_iadd(self, clist): + + def test_iadd(self, clist) -> None: """Test the iadd method of callback list""" - + clist += [10, 20, 30] assert clist == [1, 2, 3, 2, 10, 20, 30] assert clist._callback_func.call_count == 1 - - def test_imul(self, clist): + + def test_imul(self, clist) -> None: """Test the imul method of callback list""" - + clist *= 2 assert clist == [1, 2, 3, 2, 1, 2, 3, 2] assert clist._callback_func.call_count == 1 - - def test_multi_calls(self, clist): + + def test_multi_calls(self, clist) -> None: """Test whether the callback function is called correctly in multiple calls""" - + clist.append(10) clist.extend([20, 30]) clist.append(40) assert clist == [1, 2, 3, 2, 10, 20, 30, 40] assert clist._callback_func.call_count == 3 - def test_callback_trigger_order(self): + def test_callback_trigger_order(self) -> None: """Test callback function is triggered after origin api ends""" - + result = [] cl = CallbackList() 
cl._callback_func = lambda: result.append(len(cl)) - + cl.append(10) cl.extend([20, 30]) - assert result == [1, 3] - - def test_edge_cases(self, clist): + assert result == [1, 3] + + def test_edge_cases(self, clist) -> None: """Test some edge usage cases""" - + # empty operation clist.append(None) clist.extend([]) - assert clist == [1,2,3,2,None] + assert clist == [1, 2, 3, 2, None] assert clist._callback_func.call_count == 2 clist._callback_func.reset_mock() - + # invalid usage of origin api with pytest.raises(TypeError): - clist.append(1, 2, 3) + clist.append(1, 2, 3) assert clist._callback_func.call_count == 0 with pytest.raises(TypeError): - clist.extend(1) + clist.extend(1) assert clist._callback_func.call_count == 0 + class TestCallbackSet: - def test_init(self): + def test_init(self) -> None: """Test the initialization of callback set""" - + # verify default callback function cset = CallbackSet((1, 2, 3)) assert isinstance(cset, set) assert cset == {1, 2, 3} assert cset._callback_func() is None - + # verify callback funtion specification cset = CallbackSet([4, 4, 6], callback_func=lambda: 42) assert isinstance(cset, set) assert cset == {4, 6} assert cset._callback_func() == 42 - - def test_inheritance(self): + + def test_inheritance(self) -> None: """Test whether the set type is maintained.""" assert issubclass(CallbackSet, set) assert isinstance(CallbackSet(), set) - - def test_add(self, cset): + + def test_add(self, cset) -> None: """Test the add method of callback set""" cset.add(42) - assert cset == {1,2,3,42} + assert cset == {1, 2, 3, 42} cset._callback_func.assert_called_once() - def test_update(self, cset): + def test_update(self, cset) -> None: """Test the update method of callback set""" - cset.update({10,6,8}) + cset.update({10, 6, 8}) assert cset == {1, 2, 3, 6, 8, 10} cset._callback_func.assert_called_once() - def test_difference_update(self, cset): + def test_difference_update(self, cset) -> None: """Test the difference_update method of 
callback set""" - + cset.difference_update({2, 3}) assert cset == {1} assert cset._callback_func.call_count == 1 - def test_intersection_update(self, cset): + def test_intersection_update(self, cset) -> None: """Test the intersection_update method of callback set""" - + cset.intersection_update({2}) assert cset == {2} assert cset._callback_func.call_count == 1 - def test_symmetric_difference_update(self, cset): + def test_symmetric_difference_update(self, cset) -> None: """Test the symmetric_difference_update method of callback set""" - + cset.symmetric_difference_update({2, 3, 4}) assert cset == {1, 4} assert cset._callback_func.call_count == 1 - - def test_discard(self, cset): + + def test_discard(self, cset) -> None: """Test the discard method of callback set""" - + cset.discard(1) - assert cset == {2,3} - assert cset._callback_func.call_count == 1 + assert cset == {2, 3} + assert cset._callback_func.call_count == 1 - def test_pop(self, cset): + def test_pop(self, cset) -> None: """Test the pop method of callback set""" - + cset.pop() - assert cset == {2,3} + assert cset == {2, 3} assert cset._callback_func.call_count == 1 - - def test_remove(self, cset): + + def test_remove(self, cset) -> None: """Test the remove method of callback set""" - + cset.remove(2) assert cset == {1, 3} assert cset._callback_func.call_count == 1 - - def test_clear(self, cset): + + def test_clear(self, cset) -> None: """Test the clear method of callback set""" - + cset.clear() assert not len(cset) assert cset._callback_func.call_count == 1 - def test_isub(self, cset): + def test_isub(self, cset) -> None: """Test the isub method of callback set""" - + cset -= {2} assert cset == {1, 3} assert cset._callback_func.call_count == 1 - - def test_iand(self, cset): + + def test_iand(self, cset) -> None: """Test the iand method of callback set""" - + cset &= {3} assert cset == {3} assert cset._callback_func.call_count == 1 - - def test_ixor(self, cset): + + def test_ixor(self, cset) -> None: 
"""Test the iadd method of callback set""" - - cset ^= {2,4} - assert cset == {1,3,4} + + cset ^= {2, 4} + assert cset == {1, 3, 4} assert cset._callback_func.call_count == 1 - - def test_ior(self, cset): + + def test_ior(self, cset) -> None: """Test the ior method of callback set""" - + cset |= {4, 5} - assert cset == {1,2,3,4,5} + assert cset == {1, 2, 3, 4, 5} assert cset._callback_func.call_count == 1 - - def test_multi_calls(self, cset): + + def test_multi_calls(self, cset) -> None: """Test whether the callback function is called correctly in multiple calls""" - + cset.add(10) cset.update({20, 30}) cset.add(40) - assert cset == {1,2,3,10,20,30,40} + assert cset == {1, 2, 3, 10, 20, 30, 40} assert cset._callback_func.call_count == 3 - def test_callback_trigger_order(self): + def test_callback_trigger_order(self) -> None: """Test callback function is triggered after origin api ends""" - + result = [] cl = CallbackSet() cl._callback_func = lambda: result.append(len(cl)) - + cl.add(10) cl.update({20, 30}) - assert result == [1, 3] - - def test_edge_cases(self, cset): + assert result == [1, 3] + + def test_edge_cases(self, cset) -> None: """Test the edge cases of callback set""" - + # empty operation cset.add(None) cset.update(set()) - assert cset == {1,3,2,None} + assert cset == {1, 3, 2, None} assert cset._callback_func.call_count == 2 cset._callback_func.reset_mock() - + # invalid usage of origin api with pytest.raises(TypeError): - cset.add(1, 2, 3) + cset.add(1, 2, 3) assert cset._callback_func.call_count == 0 with pytest.raises(KeyError): - cset.remove(100) + cset.remove(100) assert cset._callback_func.call_count == 0 + class TestFlagNameSpace: - def test_init(self): + def test_init(self) -> None: flagns = FlagNameSpace(key1="value1", key2=123) assert hasattr(flagns, "key1") assert hasattr(flagns, "key2") @@ -650,16 +674,16 @@ def test_init(self): assert hasattr(flagns, "_FlagNameSpace__flag_key") assert not flagns.is_change() - def test_setattr(self, 
all_type_data): + def test_setattr(self, all_type_data) -> None: flagns = FlagNameSpace() flagns.key1 = all_type_data - + # verify the flag is toggled assert flagns.is_change() - + # verify the dict, list and set value is transformed to corresponding format if isinstance(all_type_data, dict): - assert isinstance(flagns.key1, FlagNameSpace) + assert isinstance(flagns.key1, FlagNameSpace) for k, v in all_type_data.items(): assert getattr(flagns.key1, k) == v elif isinstance(all_type_data, list): @@ -669,190 +693,192 @@ def test_setattr(self, all_type_data): assert isinstance(flagns.key1, CallbackSet) assert flagns.key1 == all_type_data else: - assert flagns.key1 == all_type_data - + assert flagns.key1 == all_type_data + # invalid key with pytest.raises(AttributeError): setattr(flagns, "__FLAG", "new_value") - + with pytest.raises(AttributeError): setattr(flagns, "__flag_key", "new_value") - def test_delattr(self): + def test_delattr(self) -> None: flagns = FlagNameSpace(key1="value1") del flagns.key1 assert not hasattr(flagns, "key1") - + # verify the flag is toggled - assert flagns.is_change() + assert flagns.is_change() with pytest.raises(AttributeError): del flagns.__FLAG - + with pytest.raises(AttributeError): del flagns._FlagNameSpace__FLAG - + with pytest.raises(AttributeError): del flagns.__flag_key - + with pytest.raises(AttributeError): del flagns._FlagNameSpace__flag_key - def test_data_dict(self): + def test_data_dict(self) -> None: """Test the data_dict property is set and retrieved correctly""" - + flagns = FlagNameSpace(key1="value1", key2=123) - + assert hasattr(flagns, "data_dict") - + # verify content data_dict = flagns.data_dict assert isinstance(data_dict, dict) assert "__FLAG" not in data_dict assert data_dict["key1"] == "value1" assert data_dict["key2"] == 123 - + # verify memory independence data_dict["key1"] = "value2" assert flagns.key1 == "value1" - def test_update(self): + def test_update(self) -> None: """Test the logic of update method""" 
-        
+
         flagns = FlagNameSpace(key1="value1", key2=123)
-        
+
         # invalid input type
         with pytest.raises(TypeError):
             flagns.update(123)
-        
+
         # verify flag toggle
         assert flagns.is_change() is False
-        
+
         # add new key
-        ## with dict
+        # with dict
         assert "key3" not in flagns.__dict__
         flagns.update({"key3": 456})
         assert flagns.key3 == 456
         assert "key2" in flagns.__dict__
         assert flagns.is_change()
-        
-        ## with FlagNameSpace
+
+        # with FlagNameSpace
         assert "key4" not in flagns.__dict__
         flagns.update(FlagNameSpace(key4=789))
         assert flagns.key4 == 789
         assert "key3" in flagns.__dict__
-        
+
         # update existing key
-        ## with dict
+        # with dict
         assert flagns.key1 != "000"
         flagns.update({"key1": "000"})
         assert flagns.key1 == "000"
         assert "key4" in flagns.__dict__
-        
-        ## with FlagNameSpace
+
+        # with FlagNameSpace
         assert flagns.key2 != "123"
         flagns.update(FlagNameSpace(key2="123"))
         assert flagns.key2 == "123"
         assert "key3" in flagns.__dict__
-        
-        # dict to FlagNameSpace 
+
+        # dict to FlagNameSpace
         flagns.update({"key5": {"subkey1": 901}})
         assert isinstance(flagns.key5, FlagNameSpace)
         assert flagns.key5.data_dict == {"subkey1": 901}
-        
+
         # verify origin structure keeping
         with pytest.raises(RuntimeError):
             flagns.update({"key5": 901})
-        
+
         # verify replace option
         flagns.mark_unchange()
         flagns.update({"key5": 901}, replace=True)
         assert flagns.data_dict == {"key5": 901}
         assert flagns.is_change()
 
-    def test_is_change(self):
-        flagns = FlagNameSpace(key1='1',
-                               key2=[2,[3, 3]],
-                               key3=FlagNameSpace(val3=4))
-        assert not flagns.is_change() 
+    def test_is_change(self) -> None:
+        flagns = FlagNameSpace(
+            key1="1",
+            key2=[2, [3, 3]],
+            key3=FlagNameSpace(val3=4),
+        )
+        assert not flagns.is_change()
 
         # common case
         flagns.key1 = "value1"
         assert flagns.is_change()
         flagns.mark_unchange()
-        assert not flagns.is_change() 
-        
+        assert not flagns.is_change()
+
         flagns.key1 += "value2"
         assert flagns.is_change()
         flagns.mark_unchange()
-        assert not flagns.is_change() 
-        
+        assert not flagns.is_change()
+
         # modify list
-        ## modify common element
+        # modify common element
         flagns.key2[0] = 5
         assert flagns.is_change()
         flagns.mark_unchange()
         assert not flagns.is_change()
-        
+
         flagns.key2.append(6)
         assert flagns.is_change()
         flagns.mark_unchange()
         assert not flagns.is_change()
-        
+
         del flagns.key2[0]
         assert flagns.is_change()
         flagns.mark_unchange()
         assert not flagns.is_change()
-        
-        ## modify nested list
+
+        # modify nested list
         flagns.key2[0][0] = 7
         assert flagns.is_change()
         flagns.mark_unchange()
         assert not flagns.is_change()
-        
+
         # modify namespace
         flagns.key3.val3 = 6
         assert flagns.is_change()
         flagns.mark_unchange()
         assert not flagns.is_change()
-        
+
         flagns.key3.val4 = 7
         assert flagns.is_change()
         flagns.mark_unchange()
         assert not flagns.is_change()
-        
+
         flagns.key3.val4 += 8
         assert flagns.is_change()
         flagns.mark_unchange()
         assert not flagns.is_change()
-        
+
         # add new key
         flagns.key4 = {1, 2}
         assert flagns.is_change()
         assert flagns.is_change()
         flagns.mark_unchange()
-        
+
         # del key
         del flagns.key1
         assert flagns.is_change()
         flagns.mark_unchange()
         assert not flagns.is_change()
-        
+
         # modify set
         flagns.key4.add(3)
         assert flagns.is_change()
         flagns.mark_unchange()
         assert not flagns.is_change()
-        
+
         flagns.key4.remove(2)
         assert flagns.is_change()
         flagns.mark_unchange()
-    
-    def test_mark_change_and_unchange(self):
+
+    def test_mark_change_and_unchange(self) -> None:
         flagns = FlagNameSpace(key=FlagNameSpace(subkey=1))
         assert not flagns.is_change()
         assert not flagns.key.is_change()
-        
+
         # parent change, child no need to change
         flagns.mark_change()
         assert flagns.is_change()
@@ -864,48 +890,49 @@ def test_mark_change_and_unchange(self):
         assert flagns.is_change()
         assert flagns.key.is_change()
         flagns.key.mark_unchange()
-        
+
         # parent reset to not change, childs do it too
         flagns.key.mark_change()
         flagns.mark_unchange()
         assert not flagns.is_change()
         assert not flagns.key.is_change()
 
+
 @pytest.mark.vital
 class TestGetConfig:
-    def teardown_method(self, method):
+    def teardown_method(self, method) -> None:
         cfg = Config()
         cfg.config_file = None
-    
-    def test_get_default(self):
-        with patch.dict(os.environ, {}, clear=True): 
+
+    def test_get_default(self) -> None:
+        with patch.dict(os.environ, {}, clear=True):
            config = get_config()
         assert isinstance(config, Config)
-        assert config.config_file is None 
+        assert config.config_file is None
         for field in DEFAULT_FIELDS:
             assert hasattr(config, field)
 
-    def test_get_from_env(self, default_cfg_path):
+    def test_get_from_env(self, default_cfg_path) -> None:
         with patch.dict(os.environ, {"TORCHMETER_CONFIG": default_cfg_path}):
             config = get_config()
         assert isinstance(config, Config)
         assert config.config_file == default_cfg_path
 
-    def test_get_from_path(self, default_cfg_path):
+    def test_get_from_path(self, default_cfg_path) -> None:
         config = get_config(default_cfg_path)
         assert isinstance(config, Config)
         assert config.config_file == default_cfg_path
 
-    def test_config_file_not_exist(self):
+    def test_config_file_not_exist(self) -> None:
         fake_config_path = "/fake/path/to/nonexistent.yaml"
         with pytest.raises(FileNotFoundError):
             get_config(fake_config_path)
 
-    def test_get_invalid_file(self, invalid_cfg_path):
+    def test_get_invalid_file(self, invalid_cfg_path) -> None:
         with pytest.raises(ValueError):
             get_config(invalid_cfg_path)
 
-    def test_get_custom_file(self, custom_cfg_path):
+    def test_get_custom_file(self, custom_cfg_path) -> None:
         with pytest.warns(UserWarning) as w:
             config = get_config(custom_cfg_path)
         assert isinstance(config, Config)
@@ -913,44 +940,45 @@ def test_get_custom_file(self, custom_cfg_path):
         assert config.render_interval == 0.34
         assert len(w) == len(DEFAULT_FIELDS) - 1
 
+
 @pytest.mark.vital
 class TestConfig:
-    def teardown_method(self, method):
+    def teardown_method(self, method) -> None:
         cfg = Config()
         cfg.restore()
-    
-    def test_init(self, custom_cfg_path):
+
+    def test_init(self, custom_cfg_path) -> None:
         # init with no config file
         config = Config()
         assert config.config_file is None
-        
+
         default_settings_dict = yaml.safe_load(DEFAULT_CFG)
         default_ns = dict_to_namespace(default_settings_dict)
         for field in DEFAULT_FIELDS:
             assert getattr(config, field) == getattr(default_ns, field)
-        
+
         # init with config file
         with pytest.warns(UserWarning):
             config = Config(custom_cfg_path)
         assert config.config_file == custom_cfg_path
         assert config.render_interval == 0.34
-    
-    def test_ban_delete_or_new_field(self):
+
+    def test_ban_delete_or_new_field(self) -> None:
         config = Config()
         with pytest.raises(AttributeError):
             config.new_attr = 123
-        
-        for field in DEFAULT_FIELDS + ["config_file"]:
+
+        for field in [*DEFAULT_FIELDS, "config_file"]:
             with pytest.raises(RuntimeError):
                 delattr(config, field)
 
-    def test_config_file_property(self, invalid_cfg_path, custom_cfg_path):
+    def test_config_file_property(self, invalid_cfg_path, custom_cfg_path) -> None:
         """Test the property `config_file` getter and setter"""
         config = Config()
         assert config.config_file is None
-        
+
         with pytest.raises(TypeError):
-            config.config_file = 123 
+            config.config_file = 123
 
         with pytest.raises(FileNotFoundError):
             config.config_file = "/fake/path/to/nonexistent.yaml"
@@ -960,108 +988,108 @@ def test_config_file_property(self, invalid_cfg_path, custom_cfg_path):
 
         # custom config file specified is tested in TestGetConfig::test_get_custom_file
 
-    def test_setattr(self):
+    def test_setattr(self) -> None:
         """Test the logic of setattr"""
-        
+
         config = Config()
-        
+
         # set attribute that is not in the DEFAULT_FIELDS
         with pytest.raises(AttributeError):
             config.invalid_attr = 1
-        
+
         # set attribute whose value is not a FlagNameSpace
         assert config.render_interval != 0.6
         config.render_interval = 0.6
         assert config.render_interval == 0.6
-        
+
         assert config.tree_fold_repeat is True
         config.tree_fold_repeat = False
         assert config.tree_fold_repeat is False
-        
+
         # set attribute whose value is a FlagNameSpace
         # verify the action is actually a update
-        ## with dict
+        # with dict
         assert config.tree_repeat_block_args.title_align != "left"
         config.tree_repeat_block_args = {"title_align": "left"}
         assert config.tree_repeat_block_args.title_align == "left"
-        
+
         assert config.tree_levels_args.default.guide_style != "red"
         config.tree_levels_args = {"default": {"guide_style": "red"}}
         assert config.tree_levels_args.default.guide_style == "red"
         assert "label" in config.tree_levels_args.default.__dict__
-        
+
         assert "new_field" not in config.table_column_args.__dict__
         config.table_column_args = {"new_field": 1}
         assert "new_field" in config.table_column_args.__dict__
         assert config.table_column_args.new_field == 1
-        
-        with pytest.raises(TypeError): 
+
+        with pytest.raises(TypeError):
             # typerror trigger in `FlagNameSpace.update`,
             # cause the value of `table_display_args` is a FlagNameSpace instance
             config.table_display_args = 1
-        
+
         config.combine = {"new_field": {"sub_field1": 1, "sub_field2": 2}}
         assert "new_field" in config.combine.__dict__
         assert isinstance(config.combine.new_field, FlagNameSpace)
         assert config.combine.new_field.data_dict == {"sub_field1": 1, "sub_field2": 2}
 
-        ## with another FlagNameSpace
+        # with another FlagNameSpace
         assert config.tree_repeat_block_args.title_align != "right"
         config.tree_repeat_block_args = FlagNameSpace(title_align="right")
         assert config.tree_repeat_block_args.title_align == "right"
-        
+
         assert config.tree_levels_args.default.guide_style != "blue"
         config.tree_levels_args = FlagNameSpace(default={"guide_style": "blue"})
         assert config.tree_levels_args.default.guide_style == "blue"
         assert "label" in config.tree_levels_args.default.__dict__
-        
+
         assert "new_field2" not in config.table_column_args.__dict__
         config.table_column_args = FlagNameSpace(new_field2=2)
         assert "new_field2" in config.table_column_args.__dict__
         assert config.table_column_args.new_field2 == 2
-        
+
         config.combine = FlagNameSpace(new_field3={"sub_field1": 1, "sub_field2": 2})
         assert "new_field3" in config.combine.__dict__
         assert isinstance(config.combine.new_field, FlagNameSpace)
         assert config.combine.new_field3.data_dict == {"sub_field1": 1, "sub_field2": 2}
 
-    def test_delattr(self):
+    def test_delattr(self) -> None:
         config = Config()
-        for field in DEFAULT_FIELDS + ["config_file"]:
+        for field in [*DEFAULT_FIELDS, "config_file"]:
             with pytest.raises(RuntimeError):
                 delattr(config, field)
 
-    def test_restore(self):
+    def test_restore(self) -> None:
         default_settings = yaml.safe_load(DEFAULT_CFG)
         config = Config()
-        
+
         config.render_interval = 0.45
         config.restore()
-        assert config.render_interval == default_settings['render_interval']
-        
+        assert config.render_interval == default_settings["render_interval"]
+
         config.tree_levels_args = {"2": {"label": "1"}}
         config.restore()
         assert "2" not in config.tree_levels_args.__dict__
 
-    def test_check_integrity(self, custom_cfg_path):
+    def test_check_integrity(self, custom_cfg_path) -> None:
         config = Config()
         assert config.check_integrity() is None
-        
+
         _ = TestGetConfig()
-        _.test_get_custom_file(custom_cfg_path) 
+        _.test_get_custom_file(custom_cfg_path)
 
-    def test_asdict(self):
+    def test_asdict(self) -> None:
         config = Config()
-        
+
         safe_dict = config.asdict(safe_resolve=True)
         default_safe_dict = yaml.safe_load(DEFAULT_CFG)
         assert isinstance(safe_dict, dict)
         assert set(safe_dict.keys()) == set(DEFAULT_FIELDS)
         assert safe_dict == default_safe_dict
-        
+
         unsafe_dict = config.asdict(safe_resolve=False)
         default_unsafe_dict = yaml.safe_load(DEFAULT_CFG)
-        
+
         def dfs_replace_unsafe_value(d):
             for k, v in d.items():
                 if isinstance(v, dict):
@@ -1069,20 +1097,20 @@ def dfs_replace_unsafe_value(d):
                 elif k in UNSAFE_KV:
                     d[k] = getattr(UNSAFE_KV[k], v).value
             return d
-        
+
         assert isinstance(safe_dict, dict)
         assert set(safe_dict.keys()) == set(DEFAULT_FIELDS)
         assert unsafe_dict == dfs_replace_unsafe_value(default_unsafe_dict)
 
-    def test_dump(self, custom_cfg_path):
+    def test_dump(self, custom_cfg_path) -> None:
         config = Config()
         config.dump(custom_cfg_path)
         assert os.path.exists(custom_cfg_path)
     @patch("torchmeter.config.Config.asdict")
-    def test_repr(self, mock_asdict):
+    def test_repr(self, mock_asdict) -> None:
         """Test the logic of `__repr__` method."""
-        
+
         expected = (
             "β€’ Config file: test_config.yaml\n\n"
             "β€’ field A: 0.45 | \n\n"
@@ -1108,41 +1136,43 @@ def test_repr(self, mock_asdict):
             "β”‚   └─ )\n"
             "└─ }"
         )
-        
-        cfg = Config()
-        
-        mock_asdict.return_value = {"field A": 0.45,
-                                    "field B": 123,
-                                    "field C": [4, 5, 6],
-                                    "field D": (7, 8, 9),
-                                    "field E": {"subfield A": None,
-                                                "subfield B": True,
-                                                "subfield C": "test",
-                                                "subfield D": ("1", "2", "3")}}
-        
-        with patch.object(Config, "config_file",
-                          new_callable=PropertyMock,
-                          return_value="test_config.yaml"):
-        
+
+        cfg = Config()
+
+        mock_asdict.return_value = {
+            "field A": 0.45,
+            "field B": 123,
+            "field C": [4, 5, 6],
+            "field D": (7, 8, 9),
+            "field E": {
+                "subfield A": None,
+                "subfield B": True,
+                "subfield C": "test",
+                "subfield D": ("1", "2", "3"),
+            },
+        }
+
+        with patch.object(Config, "config_file",
+                          new_callable=PropertyMock,
+                          return_value="test_config.yaml"):  # fmt: skip
             assert str(cfg).strip() == expected
 
-    def test_singleton(self, custom_cfg_path):
+    def test_singleton(self, custom_cfg_path) -> None:
         # verify same instance
         config1 = Config()
         config2 = Config()
         assert id(config1) == id(config2)
-        
+
         # verify synchronization of changes
         config2.render_interval = 0.45
         assert config1.render_interval == 0.45
-        
+
         # verify change is kept
         config3 = Config()
         assert config3.render_interval == 0.45
-        
+
         # verify reload when a new config_path is specified
         with pytest.warns(UserWarning):
             config4 = Config(custom_cfg_path)
         assert id(config1) == id(config4)
         assert config4.render_interval == 0.34
-        
\ No newline at end of file
diff --git a/tests/test_core.py b/tests/test_core.py
index 461a864..50d6fdf 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1,90 +1,90 @@
-from unittest.mock import patch, ANY
-from unittest.mock import MagicMock, PropertyMock
+from unittest.mock import ANY, MagicMock, PropertyMock, patch
 
 import pytest
 import torch.nn as nn
 from rich import get_console
-from rich.text import Text
-from rich.layout import Layout
-from torch import float16, float32
 from torch import equal as torch_equal
 from torch import randn as torch_randn
-from torch.utils.hooks import RemovableHandle
+from torch import float16, float32
+from rich.text import Text
 from torch.cuda import is_available as is_cuda
+from rich.layout import Layout
+from torch.utils.hooks import RemovableHandle
 
-from torchmeter.core import Meter, tc_device, __cfg__
+from torchmeter.core import Meter, __cfg__, tc_device
 from torchmeter.core import __cfg__ as core_cfg
-from torchmeter.utils import indent_str, data_repr
+from torchmeter.utils import data_repr, indent_str
+from torchmeter.engine import CalMeter, MemMeter, IttpMeter, ParamsMeter, OperationNode, OperationTree
 from torchmeter.display import TreeRenderer, TabularRenderer
-from torchmeter.engine import (
-    OperationNode, OperationTree,
-    ParamsMeter, CalMeter, MemMeter, IttpMeter
-)
 
 pytestmark = pytest.mark.vital
 
+
 class ExampleModel(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super(ExampleModel, self).__init__()
         self.layer0 = nn.Linear(10, 10)
         self.layer1 = nn.Sequential(
             nn.Linear(10, 10),
             nn.ReLU(),
             nn.Linear(10, 10),
-            nn.ReLU()
+            nn.ReLU(),
         )
-    
+
     def forward(self, ipt):
         return self.layer1(self.layer0(ipt))
 
+
 class EmptyModel(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super(EmptyModel, self).__init__()
 
-    def forward(self):
+
+    def forward(self) -> None:
         pass
 
+
 class RepeatModel(nn.Module):
-    def __init__(self, repeat_nodes=1):
+    def __init__(self, repeat_nodes=1) -> None:
         super(RepeatModel, self).__init__()
-        
+
         self.layer_ls = nn.ModuleList([nn.Identity() for _ in range(repeat_nodes)])
 
+
 class TestMeter:
-    model_getter = lambda _, metered_model: metered_model.optree.root.operation
-    
-    def test_valid_init(self):
+
+    def test_valid_init(self) -> None:
         """Test valid initialization and basic functionality"""
         model = ExampleModel()
-        
+
         # init a gpu model
         if is_cuda():
-            gpu_model = Meter(model, device="cuda:0") 
+            gpu_model = Meter(model, device="cuda:0")
             assert gpu_model._Meter__device.type == "cuda"
             assert self.model_getter(gpu_model).layer0.weight.device.type == "cuda"
-        
+
         # init a cpu model
-        cpu_model = Meter(model, device="cpu") 
+        cpu_model = Meter(model, device="cpu")
         assert cpu_model._Meter__device.type == "cpu"
         assert self.model_getter(cpu_model).layer0.weight.device.type == "cpu"
-        
-        assert cpu_model._ipt == {'args':tuple(), 'kwargs':dict()}
+
+        assert cpu_model._ipt == {"args": tuple(), "kwargs": dict()}
         assert isinstance(cpu_model.optree, OperationTree)
         assert isinstance(cpu_model.tree_renderer, TreeRenderer)
         assert isinstance(cpu_model.table_renderer, TabularRenderer)
-        
+
         assert cpu_model._Meter__measure_param is False
         assert cpu_model._Meter__measure_cal is False
         assert cpu_model._Meter__measure_mem is False
         assert cpu_model._Meter__has_nocall_nodes is None
         assert cpu_model._Meter__has_not_support_nodes is None
-        
+
         assert hasattr(cpu_model, "ittp_warmup")
         assert hasattr(cpu_model, "ittp_benchmark_time")
         # set ittp_warmup and ittp_benchmark_time to a lower value to save time
         cpu_model.ittp_warmup = 2
         cpu_model.ittp_benchmark_time = 2
-        
+
         cpu_model(torch_randn(1, 10))
         assert hasattr(cpu_model, "ipt")
         assert hasattr(cpu_model, "device")
@@ -101,246 +101,244 @@ def test_valid_init(self):
         assert hasattr(cpu_model, "model_info")
         assert hasattr(cpu_model, "subnodes")
 
-    def test_invalid_init(self):
+    def test_invalid_init(self) -> None:
         """Test invalid initialization"""
         with pytest.raises(TypeError):
             Meter("not a model")
 
-    def test_call(self):
+    def test_call(self) -> None:
         """Test the logic of __call__ method"""
-        input = torch_randn(1, 10)
+        ipt = torch_randn(1, 10)
         model = ExampleModel()
-        
+
         # cpu_model
-        ## call with positional argument
+        # call with positional argument
         cpu_model = Meter(model, device="cpu")
-        output = cpu_model(input)
+        output = cpu_model(ipt)
         assert output.shape == (1, 10)
         assert output.device.type == "cpu"
-        assert cpu_model.ipt["args"][0] is input
+        assert cpu_model.ipt["args"][0] is ipt
         assert not len(cpu_model.ipt["kwargs"])
-        
-        ## call with keyword argument
-        output2 = cpu_model(ipt=input)
+
+        # call with keyword argument
+        output2 = cpu_model(ipt=ipt)
         assert output2.shape == (1, 10)
         assert output2.device.type == "cpu"
-        assert cpu_model.ipt["kwargs"]["ipt"] is input
+        assert cpu_model.ipt["kwargs"]["ipt"] is ipt
         assert not len(cpu_model.ipt["args"])
-        
-        ## new input reset statistics measured flags
+
+        # new input reset statistics measured flags
         cpu_model._Meter__measure_param = True
         cpu_model._Meter__measure_cal = True
         cpu_model._Meter__measure_mem = True
-        cpu_model(torch_randn(2, 10)) # different input triggers reset
-        assert not cpu_model._Meter__measure_param 
-        assert not cpu_model._Meter__measure_cal 
+        cpu_model(torch_randn(2, 10))  # different input triggers reset
+        assert not cpu_model._Meter__measure_param
+        assert not cpu_model._Meter__measure_cal
         assert not cpu_model._Meter__measure_mem
-        
+
         cpu_model._Meter__measure_param = True
         cpu_model._Meter__measure_cal = True
         cpu_model._Meter__measure_mem = True
-        cpu_model(torch_randn(2, 10)) # same input not triggers reset
-        assert cpu_model._Meter__measure_param 
-        assert cpu_model._Meter__measure_cal 
+        cpu_model(torch_randn(2, 10))  # same input not triggers reset
+        assert cpu_model._Meter__measure_param
+        assert cpu_model._Meter__measure_cal
         assert cpu_model._Meter__measure_mem
-        
+
         if is_cuda():
             # gpu_model
-            ## call with positional argument
+            # call with positional argument
             gpu_model = Meter(model, device="cuda:0")
-            output = gpu_model(input)
+            output = gpu_model(ipt)
             assert output.shape == (1, 10)
             assert output.device.type == "cuda"
-            assert torch_equal(gpu_model.ipt["args"][0], input.to("cuda:0"))
+            assert torch_equal(gpu_model.ipt["args"][0], ipt.to("cuda:0"))
             assert not len(gpu_model.ipt["kwargs"])
-            
-            ## call with keyword argument
-            output = gpu_model(ipt=input)
+
+            # call with keyword argument
+            output = gpu_model(ipt=ipt)
             assert output.shape == (1, 10)
             assert output.device.type == "cuda"
-            assert torch_equal(gpu_model.ipt["kwargs"]['ipt'], input.to("cuda:0"))
+            assert torch_equal(gpu_model.ipt["kwargs"]["ipt"], ipt.to("cuda:0"))
             assert not len(gpu_model.ipt["args"])
-    
-    def test_attr_operation(self):
+
+    def test_attr_operation(self) -> None:
         """Test the logic of overwritten __get(del)attr__ method"""
-        
+
         model = ExampleModel()
         model.test_attr = "ATTR"
         model.param = "PARAM"
-        model.test_method = lambda : "enter test method"
-        
+        model.test_method = lambda: "enter test method"
+
         metered_model = Meter(model)
-        
+
         # getter
-        ## get self attr
+        # get self attr
         assert hasattr(metered_model, "param")
         assert isinstance(metered_model.optree, OperationTree)
-        
-        ## get origin model's attr
+
+        # get origin model's attr
         assert hasattr(metered_model, "test_attr")
         assert metered_model.test_attr == "ATTR"
-        
-        ## get attr with same name defined in origin model
+
+        # get attr with same name defined in origin model
         assert isinstance(metered_model.param, ParamsMeter)
         assert metered_model.ORIGIN_param == "PARAM"
-        
-        ## get not exist attr
+
+        # get not exist attr
         with pytest.raises(AttributeError):
             getattr(metered_model, "not_exist_attr")
-        
-        ## call origin model's method
+
+        # call origin model's method
         assert metered_model.test_method() == "enter test method"
-        
+
         # setter
-        ## set self attr
-        ### common attr
+        # set self attr
+        # common attr
         origin_val = metered_model.ittp_warmup
         setattr(metered_model, "ittp_warmup", "10")
         assert origin_val != metered_model.ittp_warmup
         assert metered_model.ittp_warmup == "10"
-        
-        ### class property that can be set
+
+        # class property that can be set
         setattr(metered_model, "device", "cpu")
-        
-        ### class property that can not be set
+
+        # class property that can not be set
         with pytest.raises(AttributeError):
             setattr(metered_model, "param", "Param")
 
-        ## set origin model's attr
+        # set origin model's attr
         model.test_attr = "NEW_ATTR"
         assert metered_model.test_attr == "NEW_ATTR"
-        
-        ## set attr with same name defined in origin model
+
+        # set attr with same name defined in origin model
         model.param = "NEW_PARAM"
         assert isinstance(metered_model.param, ParamsMeter)
         assert metered_model.ORIGIN_param == "NEW_PARAM"
-        
-        ## set not exist attr
+
+        # set not exist attr
         setattr(metered_model, "not_exist_attr", "NOW_EXIST")
         assert metered_model.not_exist_attr == "NOW_EXIST"
-        
+
         setattr(model, "not_exist_attr_2", "NOW_EXIST_2")
         assert metered_model.not_exist_attr_2 == "NOW_EXIST_2"
-        
-        ## set origin model's method
-        model.test_method = lambda : "enter test method 2"
+
+        # set origin model's method
+        model.test_method = lambda: "enter test method 2"
         assert metered_model.test_method() == "enter test method 2"
-        
+
         # delttr
-        ## del self attr
+        # del self attr
         del metered_model.ittp_warmup
         assert not hasattr(metered_model, "ittp_warmup")
 
         with pytest.raises(AttributeError):
             del metered_model.param
-        
-        ## del origin model's attr
+
+        # del origin model's attr
         del model.test_attr
-        assert not hasattr(metered_model.model,"test_attr")
-        
-        ## del attr with same name defined in origin model
+        assert not hasattr(metered_model.model, "test_attr")
+
+        # del attr with same name defined in origin model
         del model.param
         assert not hasattr(metered_model, "ORIGIN_param")
 
-        ## del not exist attr
+        # del not exist attr
         with pytest.raises(AttributeError):
             del metered_model.not_exist_attr_3
-        
-        ## del origin model's method
+
+        # del origin model's method
         del model.test_method
         with pytest.raises(AttributeError):
             metered_model.test_method()
-    
-    def test_ipt(self):
+
+    def test_ipt(self) -> None:
         """Test the ipt property is set and retrieved correctly"""
-        
+
         metered_model = Meter(ExampleModel())
         metered_model._ipt = "test_ipt"
-        
+
         # verify whether ipt property is linked to _ipt
         assert metered_model.ipt == "test_ipt"
-    
-    def test_device(self):
+
+    def test_device(self) -> None:
         """Test the device property is set and retrieved correctly"""
-        
+
         model = ExampleModel()
         metered_model = Meter(model, device="cpu")
-        metered_model.ipt["args"] = (torch_randn(1,10, device=tc_device("cpu")), )
-        metered_model.ipt["kwargs"] = {'ipt':torch_randn(1,10, device=tc_device("cpu"))}
-        
+        metered_model.ipt["args"] = (torch_randn(1, 10, device=tc_device("cpu")),)
+        metered_model.ipt["kwargs"] = {"ipt": torch_randn(1, 10, device=tc_device("cpu"))}
+
         # retrieve
         assert hasattr(metered_model, "device")
         assert metered_model.device is metered_model._Meter__device
         assert metered_model.device.type == "cpu"
-        
-        # set 
+
+        # set
         if is_cuda():
             metered_model.device = "cuda:0"
             assert metered_model.device.type == "cuda"
             assert metered_model.ipt["args"][0].device.type == "cuda"
-            assert metered_model.ipt["kwargs"]['ipt'].device.type == "cuda"
+            assert metered_model.ipt["kwargs"]["ipt"].device.type == "cuda"
             assert self.model_getter(metered_model).layer0.weight.device.type == "cuda"
-    
-    def test_to(self):
+
+    def test_to(self) -> None:
         """Test the logic of changing model's device through `to` method"""
-        
+
         metered_model = Meter(ExampleModel())
-        
+
         # verify link to `device` method
-        with patch.object(Meter, "device",
-                          new_callable=PropertyMock,
-                          return_value=Meter.device) as mock_device_property:
+        with patch.object(
+            Meter, "device", new_callable=PropertyMock, return_value=Meter.device
+        ) as mock_device_property:
             metered_model.to("cpu")
             mock_device_property.assert_called_once_with("cpu")
-        
+
         # invalid input type
         with pytest.raises(RuntimeError):
             metered_model.to("not a device")
-    
-    def test_auto_detect_device(self):
+
+    def test_auto_detect_device(self) -> None:
         """Test whether the device is auto detected when no device is specified"""
-        
+
         # verify auto detect the device of cpu model
-        autodevice_model = Meter(ExampleModel().to("cpu")) 
+        autodevice_model = Meter(ExampleModel().to("cpu"))
         assert autodevice_model.device.type == "cpu"
-        
+
         # verify auto detect the device of gpu model
         if is_cuda():
-            autodevice_model = Meter(ExampleModel().to("cuda:0")) 
+            autodevice_model = Meter(ExampleModel().to("cuda:0"))
             assert autodevice_model.device.type == "cuda"
-        
+
         # verify auto move no parameter model to cpu
         with pytest.warns(UserWarning):
             empty_model = Meter(EmptyModel())
         assert empty_model.device.type == "cpu"
-    
-    def test_is_ipt_empty(self):
+
+    def test_is_ipt_empty(self) -> None:
         """Test the logic of _is_ipt_empty method"""
-        
+
         # empty
         metered_model = Meter(ExampleModel())
         assert metered_model._is_ipt_empty()
-        
+
         # only positional argument
-        metered_model._ipt = {"args":(torch_randn(1,10), ),
-                              "kwargs":{}}
+        metered_model._ipt = {"args": (torch_randn(1, 10),), "kwargs": {}}
         assert not metered_model._is_ipt_empty()
-        
+
         # only keyword argument
-        metered_model._ipt = {"args":(),
-                              "kwargs":{'ipt':torch_randn(1,10)}}
+        metered_model._ipt = {"args": (), "kwargs": {"ipt": torch_randn(1, 10)}}
         assert not metered_model._is_ipt_empty()
-        
+
         # both
-        metered_model.ipt["args"] = (torch_randn(1,10), )
+        metered_model.ipt["args"] = (torch_randn(1, 10),)
         assert not metered_model._is_ipt_empty()
 
     @pytest.mark.skipif(not is_cuda(), reason="requires gpu")
-    def test_ipt2device(self):
+    def test_ipt2device(self) -> None:
         """Test the logic of _ipt2device method"""
-        
+
         metered_model = Meter(ExampleModel())
         to_method = torch_randn(1).to
-        
+
         # empty ipt (no need)
         empty_metered_model = Meter(EmptyModel(), device="cpu")
         empty_metered_model._ipt2device()
@@ -348,170 +346,169 @@ def test_ipt2device(self):
 
         # empty ipt(needed)
         with pytest.raises(RuntimeError):
             metered_model._ipt2device()
-        
+
         # non Tensor ipt
         with patch("torch.Tensor.to", side_effect=to_method) as mock_to:
-            non_tensor_ipt = {"args":(1,"2", 3., None, lambda x:x, {1,2}),
-                              "kwargs":{"A":1, "B":"2", "C":3., "D":None, "E":lambda x:x, "F":{1,2}}}
+            non_tensor_ipt = {
+                "args": (1, "2", 3.0, None, lambda x: x, {1, 2}),
+                "kwargs": {"A": 1, "B": "2", "C": 3.0, "D": None, "E": lambda x: x, "F": {1, 2}},
+            }
             metered_model._ipt = non_tensor_ipt
             metered_model._ipt2device()
             mock_to.assert_not_called()
-        
+
         # single tensor ipt
-        ## same device with model, will not move
+        # same device with model, will not move
         with patch("torch.Tensor.to", side_effect=to_method) as mock_to:
-            single_tensor_ipt = {"args":(torch_randn(1,10), ),
-                                 "kwargs": {"A":1, "B":"2", "C":3., "D":None, "E":lambda x:x, "F":{1,2}}}
+            single_tensor_ipt = {
+                "args": (torch_randn(1, 10),),
+                "kwargs": {"A": 1, "B": "2", "C": 3.0, "D": None, "E": lambda x: x, "F": {1, 2}},
+            }
            metered_model._ipt = single_tensor_ipt
             metered_model._ipt2device()
             mock_to.assert_not_called()
-        
-        ## different device with model, will move to model's device
+
+        # different device with model, will move to model's device
         with patch("torch.Tensor.to", side_effect=to_method) as mock_to:
-            single_tensor_ipt = {"args":(1,"2", 3., None, lambda x:x, {1,2}),
-                                 "kwargs": {"A":torch_randn(1,10, device=tc_device("cuda:0"))}}
+            single_tensor_ipt = {
+                "args": (1, "2", 3.0, None, lambda x: x, {1, 2}),
+                "kwargs": {"A": torch_randn(1, 10, device=tc_device("cuda:0"))},
+            }
             metered_model._ipt = single_tensor_ipt
             metered_model._ipt2device()
             mock_to.assert_called_once()
             assert metered_model.ipt["kwargs"]["A"].device.type == "cpu"
-        
+
         # multiple tensor ipt
-        ## same device with model, will not move
+        # same device with model, will not move
         with patch("torch.Tensor.to", side_effect=to_method) as mock_to:
-            multiple_tensor_ipt = {"args":(torch_randn(1,10), ),
-                                   "kwargs": {"A":torch_randn(1,10), "B":torch_randn(1,10)}}
+            multiple_tensor_ipt = {
+                "args": (torch_randn(1, 10),),
+                "kwargs": {"A": torch_randn(1, 10), "B": torch_randn(1, 10)},
+            }
            metered_model._ipt = multiple_tensor_ipt
             metered_model._ipt2device()
             mock_to.assert_not_called()
-        
-        ## mixed device, will move all tensor to model's device
+
+        # mixed device, will move all tensor to model's device
         with patch("torch.Tensor.to", side_effect=to_method) as mock_to:
-            multiple_tensor_ipt = {"args":(torch_randn(1,10), ),
-                                   "kwargs": {"A":torch_randn(1,10), "B":torch_randn(1,10, device=tc_device("cuda:0"))}}
+            multiple_tensor_ipt = {
+                "args": (torch_randn(1, 10),),
+                "kwargs": {"A": torch_randn(1, 10), "B": torch_randn(1, 10, device=tc_device("cuda:0"))},
+            }
             metered_model._ipt = multiple_tensor_ipt
             metered_model._ipt2device()
             assert mock_to.call_count == 3
             assert metered_model.ipt["kwargs"]["B"].device.type == "cpu"
-        
-        ## different device with model, will move all tensor to model's device
+
+        # different device with model, will move all tensor to model's device
         with patch("torch.Tensor.to", side_effect=to_method) as mock_to:
-            multiple_tensor_ipt = {"args":(torch_randn(1,10, device=tc_device("cuda:0")), ),
-                                   "kwargs": {"A":torch_randn(1,10, device=tc_device("cuda:0"))}}
+            multiple_tensor_ipt = {
+                "args": (torch_randn(1, 10, device=tc_device("cuda:0")),),
+                "kwargs": {"A": torch_randn(1, 10, device=tc_device("cuda:0"))},
+            }
             metered_model._ipt = multiple_tensor_ipt
             metered_model._ipt2device()
             assert mock_to.call_count == 2
             assert metered_model.ipt["args"][0].device.type == "cpu"
             assert metered_model.ipt["kwargs"]["A"].device.type == "cpu"
-    
+
     @pytest.mark.parametrize(
-        argnames=["origin", "new", "expected"],
+        argnames=("origin", "new", "expected"),
         argvalues=[
             # empty
-            ({"args":tuple(), "kwargs":{}},
-             {"args":(1,), "kwargs":{}}, True),
-            
+            ({"args": tuple(), "kwargs": {}}, {"args": (1,), "kwargs": {}}, True),
             # different number of anonymous args
-            ({"args":(1,)}, {"args":(1,2)}, True),
-            
+            ({"args": (1,)}, {"args": (1, 2)}, True),
             # same number but different inner value
-            ## different type in same position
-            ({"args":(1,2)}, {"args":(1.,2)}, True),
-            
-            ## different shape of tensor in same position
-            ({"args":(torch_randn(1, 10), )}, {"args":(torch_randn(1, 20), )}, True),
-            
-            ## different dtype of tensor in same position
-            ({"args":(torch_randn(1, 10, dtype=float16), )},
-             {"args":(torch_randn(1, 10, dtype=float32), )}, True),
-            
-            ## different value in same position
-            ({"args":(1, 2)}, {"args":(2, 2)}, True),
-            
-            ## all same input without tensor data
-            ({"args":(1, 2), "kwargs":{}},
-             {"args":(1, 2),"kwargs":{}}, False),
-            
-            ## all same input with tensor data of same shape and same dtype
-            ({"args":(torch_randn(1, 10, dtype=float32), ), "kwargs":{}},
-             {"args":(torch_randn(1, 10, dtype=float32), ), "kwargs":{}}, False),
-            
+            # different type in same position
+            ({"args": (1, 2)}, {"args": (1.0, 2)}, True),
+            # different shape of tensor in same position
+            ({"args": (torch_randn(1, 10),)}, {"args": (torch_randn(1, 20),)}, True),
+            # different dtype of tensor in same position
+            ({"args": (torch_randn(1, 10, dtype=float16),)}, {"args": (torch_randn(1, 10, dtype=float32),)}, True),
+            # different value in same position
+            ({"args": (1, 2)}, {"args": (2, 2)}, True),
+            # all same input without tensor data
+            ({"args": (1, 2), "kwargs": {}}, {"args": (1, 2), "kwargs": {}}, False),
+            # all same input with tensor data of same shape and same dtype
+            (
+                {"args": (torch_randn(1, 10, dtype=float32),), "kwargs": {}},
+                {"args": (torch_randn(1, 10, dtype=float32),), "kwargs": {}},
+                False,
+            ),
            # different number of keyword args
-            ({"args":tuple(), "kwargs":{"a":1, "b":2}},
-             {"args":tuple(), "kwargs":{"a":1}}, True),
-            
+            ({"args": tuple(), "kwargs": {"a": 1, "b": 2}}, {"args": tuple(), "kwargs": {"a": 1}}, True),
             # same number but different keys
-            ({"args":tuple(), "kwargs":{"a":1, "b":2}},
-             {"args":tuple(), "kwargs":{"a":1, "c":2}}, True),
-            
+            ({"args": tuple(), "kwargs": {"a": 1, "b": 2}}, {"args": tuple(), "kwargs": {"a": 1, "c": 2}}, True),
            # same number but different values
-            ## different value type of same key
-            ({"args":tuple(), "kwargs":{"a":1}},
-             {"args":tuple(), "kwargs":{"a":1.}}, True),
-            
-            ## different shape of tensor in value of same key
-            ({"args":tuple(), "kwargs":{"a":torch_randn(1, 10)}},
-             {"args":tuple(),
"kwargs":{"a":torch_randn(1, 20)}}, True), - - ## different dtype of tensor in value of same key - ({"args":tuple(), "kwargs":{"a":torch_randn(1, 10, dtype=float16)}}, - {"args":tuple(), "kwargs":{"a":torch_randn(1, 10, dtype=float32)}}, True), - - ## different value of same key - ({"args":tuple(), "kwargs":{"a":1}}, - {"args":tuple(), "kwargs":{"a":2}}, True), - - ## all same input without tensor data - ({"args":tuple(), "kwargs":{"b":2, "c":3}}, - {"args":tuple(), "kwargs":{"b":2, "c":3}}, False), - - ## all same input with tensor data of same shape and same dtype - ({"args":tuple(), "kwargs":{"d":torch_randn(1, 10, dtype=float32)}}, - {"args":tuple(), "kwargs":{"d":torch_randn(1, 10, dtype=float32)}}, False) - - ] + # different value type of same key + ({"args": tuple(), "kwargs": {"a": 1}}, {"args": tuple(), "kwargs": {"a": 1.0}}, True), + # different shape of tensor in value of same key + ( + {"args": tuple(), "kwargs": {"a": torch_randn(1, 10)}}, + {"args": tuple(), "kwargs": {"a": torch_randn(1, 20)}}, + True, + ), + # different dtype of tensor in value of same key + ( + {"args": tuple(), "kwargs": {"a": torch_randn(1, 10, dtype=float16)}}, + {"args": tuple(), "kwargs": {"a": torch_randn(1, 10, dtype=float32)}}, + True, + ), + # different value of same key + ({"args": tuple(), "kwargs": {"a": 1}}, {"args": tuple(), "kwargs": {"a": 2}}, True), + # all same input without tensor data + ({"args": tuple(), "kwargs": {"b": 2, "c": 3}}, {"args": tuple(), "kwargs": {"b": 2, "c": 3}}, False), + # all same input with tensor data of same shape and same dtype + ( + {"args": tuple(), "kwargs": {"d": torch_randn(1, 10, dtype=float32)}}, + {"args": tuple(), "kwargs": {"d": torch_randn(1, 10, dtype=float32)}}, + False, + ), + ], ) - def test_is_ipt_changed(self, origin, new, expected, monkeypatch): + def test_is_ipt_changed(self, origin, new, expected, monkeypatch) -> None: """Test the logic of __ipt_is_changed method""" - + metered_model = Meter(ExampleModel()) 
         target_method = metered_model._Meter__is_ipt_changed
-
+
         monkeypatch.setattr(metered_model, "_ipt", origin)
-        assert target_method(new) is expected
-
-    def test_repr(self):
+        assert target_method(new) is expected
+
+    def test_repr(self) -> None:
         """Test correct representation of Meter object"""
-
+
         metered_model = Meter(ExampleModel())
-
+
         mock_optree = MagicMock()
-        mock_optree.__repr__ = lambda _ : "model_info"
-
+        mock_optree.__repr__ = lambda _: "model_info"
+
         mock_device = MagicMock()
-        mock_device.__repr__ = lambda _ : "device_info"
-
+        mock_device.__repr__ = lambda _: "device_info"
+
         with patch.object(metered_model, "optree", new=mock_optree), \
-            patch("torchmeter.core.Meter.device", new=mock_device):
+            patch("torchmeter.core.Meter.device", new=mock_device):  # fmt: skip
             res = str(metered_model)
-
+
         assert res == "Meter(model=model_info, device=device_info)"
-
-    def test_tree_fold_repeat(self):
+
+    def test_tree_fold_repeat(self) -> None:
         """Test whether the `tree_fold_repeat` property is set and retrieved correctly."""
-
+
         metered_model = Meter(ExampleModel())
-
+
         assert hasattr(metered_model, "tree_fold_repeat")
-
+
         # retrieve
         assert metered_model.tree_fold_repeat == __cfg__.tree_fold_repeat
-
+
         # valid set
         metered_model.tree_fold_repeat = False
         assert metered_model.tree_fold_repeat is False
         assert __cfg__.tree_fold_repeat is False
-
+
         # invalid set
         with pytest.raises(TypeError):
             metered_model.tree_fold_repeat = 1
@@ -522,147 +519,145 @@ def test_tree_fold_repeat(self):
             ("tree_levels_args", "tree_renderer.tree_levels_args"),
             ("tree_repeat_block_args", "tree_renderer.repeat_block_args"),
             ("table_display_args", "table_renderer.tb_args"),
-            ("table_column_args", "table_renderer.col_args")
-        ]
+            ("table_column_args", "table_renderer.col_args"),
+        ],
     )
-    def test_setting_related_property(self, attr_name, upper_bound):
-        """Test the setting related properties are set and retrieved correctly"""
+    def test_setting_related_property(self, attr_name, upper_bound) -> None:
+        """Test the setting related properties are set and retrieved correctly"""
         from operator import attrgetter
-
+
         metered_model = Meter(ExampleModel())
-
+
         upper_getter = attrgetter(upper_bound)
-
+
         assert getattr(metered_model, attr_name) is upper_getter(metered_model)
-
+
     @patch("torchmeter.config.FlagNameSpace.mark_unchange")
     @patch("torchmeter.core.__cfg__.tree_levels_args.is_change")
     @patch("torchmeter.core.__cfg__.tree_repeat_block_args.is_change")
-    def test_structure(self,
-                       mock_rpbk_change, mock_level_change,
-                       mock_mark_unchange, monkeypatch):
+    def test_structure(self, mock_rpbk_change, mock_level_change, mock_mark_unchange, monkeypatch) -> None:
         """Test the rendered tree is correctly cached until some settings are changed"""
         metered_model = Meter(ExampleModel())
-
+
         mock_tree_renderer = MagicMock(spec=TreeRenderer)
         # overwrite tree_renderer()
-        mock_tree_renderer.return_value = "re-render_tree"
+        mock_tree_renderer.return_value = "re-render_tree"
         # overwrite tree_renderer property with mock object
         monkeypatch.setattr(metered_model, "tree_renderer", mock_tree_renderer)
-
+
         # verify if the fold tree is cached
         with monkeypatch.context() as m:
             m.setattr(core_cfg, "tree_fold_repeat", True)
             # overwrite render result
             mock_tree_renderer.render_fold_tree = "rendered_fold_tree"
-
+
             # no change, no need to re-render
             mock_rpbk_change.return_value = False
             mock_level_change.return_value = False
             assert metered_model.structure == "rendered_fold_tree"
             mock_mark_unchange.assert_not_called()
-
+
             # only repeat block args changed, need to re-render
             mock_rpbk_change.return_value = True
             mock_level_change.return_value = False
             assert metered_model.structure == "re-render_tree"
             mock_mark_unchange.assert_called_once()
             mock_mark_unchange.reset_mock()
-
+
             # only level args changed, need to re-render
             mock_rpbk_change.return_value = False
             mock_level_change.return_value = True
             assert metered_model.structure == "re-render_tree"
             mock_mark_unchange.assert_called_once()
             mock_mark_unchange.reset_mock()
-
+
             # both settings changed, need to re-render
             mock_rpbk_change.return_value = True
             mock_level_change.return_value = True
             assert metered_model.structure == "re-render_tree"
             assert mock_mark_unchange.call_count == 2
             mock_mark_unchange.reset_mock()
-
+
         # verify if the unfold tree is cached
         with monkeypatch.context() as m:
             m.setattr(core_cfg, "tree_fold_repeat", False)
             # overwrite render result
             mock_tree_renderer.render_unfold_tree = "rendered_unfold_tree"
-
+
             # no change, no need to re-render
             mock_rpbk_change.return_value = False
             mock_level_change.return_value = False
             assert metered_model.structure == "rendered_unfold_tree"
             mock_mark_unchange.assert_not_called()
-
+
             # only repeat block args changed, no need to re-render
             mock_rpbk_change.return_value = True
             mock_level_change.return_value = False
             assert metered_model.structure == "rendered_unfold_tree"
             mock_mark_unchange.assert_not_called()
-
+
             # only level args changed, need to re-render
             mock_rpbk_change.return_value = False
             mock_level_change.return_value = True
             assert metered_model.structure == "re-render_tree"
             mock_mark_unchange.assert_called_once()
             mock_mark_unchange.reset_mock()
-
+
             # both settings changed, need to re-render
             mock_rpbk_change.return_value = True
             mock_level_change.return_value = True
             assert metered_model.structure == "re-render_tree"
             mock_mark_unchange.assert_called_once()
             mock_mark_unchange.reset_mock()
-
+
     @patch("torchmeter.statistic.ParamsMeter.measure")
-    def test_param_property(self, mock_measure):
+    def test_param_property(self, mock_measure) -> None:
         """Test the parameter measurement result is retrieved and cached correctly"""
-
+
         metered_model = Meter(ExampleModel())
         assert metered_model._Meter__measure_param is False
-
+
         # verify the measurement is triggered for all operationnode
         res = metered_model.param
         assert isinstance(res, ParamsMeter)
         assert mock_measure.call_count == len(metered_model.subnodes)
         assert metered_model._Meter__measure_param is True
-
+
         # verify the result is cached
         mock_measure.reset_mock()
         res2 = metered_model.param
         assert res2 is res
         mock_measure.assert_not_called()
-
+
     @patch("torchmeter.core.Meter._ipt2device")
     @patch("torchmeter.statistic.CalMeter.measure")
-    def test_cal_property(self, mock_measure, mock_ipt2device):
+    def test_cal_property(self, mock_measure, mock_ipt2device) -> None:
         """Test the calculation measurement result is retrieved and cached correctly"""
-
+
         metered_model = Meter(ExampleModel())
         assert metered_model._Meter__measure_cal is False
-
+
         # mock a RemovableHandle object
         mock_handle = MagicMock(spec=RemovableHandle)
         mock_handle.remove.return_value = "removed"
         mock_measure.return_value = mock_handle
-
+
         # verify access the property when the input is unknown
         with pytest.raises(RuntimeError):
             metered_model.cal
-
+
         # verify auto move input the model's device
         # verify the measurement is triggered for all operationnode
-        metered_model._ipt = {"args":tuple(),"kwargs":{"ipt":torch_randn(1,10)}}
+        metered_model._ipt = {"args": tuple(), "kwargs": {"ipt": torch_randn(1, 10)}}
         res = metered_model.cal
         mock_ipt2device.assert_called_once()
         assert isinstance(res, CalMeter)
         assert metered_model._Meter__measure_cal is True
         assert mock_measure.call_count == len(metered_model.subnodes)
         assert mock_handle.remove.call_count == len(metered_model.subnodes)
-
+
         # verify the result is cached
         mock_measure.reset_mock()
         mock_handle.reset_mock()
@@ -673,31 +668,31 @@ def test_cal_property(self, mock_measure, mock_ipt2device):

     @patch("torchmeter.core.Meter._ipt2device")
     @patch("torchmeter.statistic.MemMeter.measure")
-    def test_mem_property(self, mock_measure, mock_ipt2device):
+    def test_mem_property(self, mock_measure, mock_ipt2device) -> None:
         """Test the memory-access measurement result is retrieved and cached correctly"""
-
+
         metered_model = Meter(ExampleModel())
         assert metered_model._Meter__measure_mem is False
-
+
         # mock a RemovableHandle object
         mock_handle = MagicMock(spec=RemovableHandle)
         mock_handle.remove.return_value = "removed"
         mock_measure.return_value = mock_handle
-
+
         # verify access the property when the input is unknown
         with pytest.raises(RuntimeError):
             metered_model.mem
-
+
         # verify auto move input the model's device
         # verify the measurement is triggered for all operationnode
-        metered_model._ipt = {"args":tuple(torch_randn(1,10),),"kwargs":{}}
+        metered_model._ipt = {"args": (torch_randn(1, 10),), "kwargs": {}}
         res = metered_model.mem
         mock_ipt2device.assert_called_once()
         assert isinstance(res, MemMeter)
         assert metered_model._Meter__measure_mem is True
         assert mock_measure.call_count == len(metered_model.subnodes)
         assert mock_handle.remove.call_count == len(metered_model.subnodes)
-
+
         # verify the result is cached
         mock_measure.reset_mock()
         mock_handle.reset_mock()
@@ -708,34 +703,34 @@ def test_mem_property(self, mock_measure, mock_ipt2device):

     @patch("torchmeter.core.Meter._ipt2device")
     @patch("torchmeter.statistic.IttpMeter.measure")
-    def test_ittp_property(self, mock_measure, mock_ipt2device, monkeypatch):
+    def test_ittp_property(self, mock_measure, mock_ipt2device, monkeypatch) -> None:
         """Test the inference time & throughput measurement result is retrieved correctly"""
-
+
         metered_model = Meter(ExampleModel())
-
+
         # mock a RemovableHandle object
         mock_handle = MagicMock(spec=RemovableHandle)
         mock_handle.remove.return_value = "removed"
         mock_measure.return_value = mock_handle
-
+
         # verify access the property when the input is unknown
         with pytest.raises(RuntimeError):
             metered_model.ittp
-
-        metered_model._ipt = {"args":tuple(torch_randn(1,10),),"kwargs":{}}
-
+
+        metered_model._ipt = {"args": (torch_randn(1, 10),), "kwargs": {}}
+
         # invalid warmup type
         with pytest.raises(TypeError):
             monkeypatch.setattr(metered_model, "ittp_warmup", "invalid")
             metered_model.ittp
-
+
         # invalid warmup value
        with pytest.raises(ValueError):
             monkeypatch.setattr(metered_model, "ittp_warmup", -1)
             metered_model.ittp
-
+
monkeypatch.undo() - + # normal usage with patch.object(metered_model.model, "forward", wraps=metered_model.model.forward) as mock_call: monkeypatch.setattr(metered_model, "ittp_warmup", 10) @@ -745,13 +740,13 @@ def test_ittp_property(self, mock_measure, mock_ipt2device, monkeypatch): mock_ipt2device.assert_called_once() # verify the model is warmup for specified times before measurement - assert mock_call.call_count == 10 + 1 # warmup + 1 feed-forward + assert mock_call.call_count == 10 + 1 # warmup + 1 feed-forward # verify the measurement is triggered for all operationnode assert isinstance(res, IttpMeter) assert mock_measure.call_count == len(metered_model.subnodes) assert mock_handle.remove.call_count == len(metered_model.subnodes) - + # verify the result is not cached mock_ipt2device.reset_mock() mock_call.reset_mock() @@ -766,314 +761,300 @@ def test_ittp_property(self, mock_measure, mock_ipt2device, monkeypatch): @patch("torchmeter.utils.data_repr", wraps=data_repr) @patch("torchmeter.utils.indent_str", wraps=indent_str) - def test_model_info_property(self, mock_indent_str, mock_data_repr, monkeypatch): + def test_model_info_property(self, mock_indent_str, mock_data_repr, monkeypatch) -> None: """Test the model_info property is set and retrieved correctly""" - + from dataclasses import dataclass + from numpy import ones as np_ones - + @dataclass - class TextData(): - model: str # type: ignore - device: str # type: ignore - forward_sig: str # type: ignore - input: str # type: ignore - - def text_resolve(info:Text) -> TextData: + class TextData: + model: str # type: ignore + device: str # type: ignore + forward_sig: str # type: ignore + input: str # type: ignore + + def text_resolve(info: Text) -> TextData: plain_str = info.plain infos = plain_str.split("\n") - return TextData(model=infos[0], device=infos[1], - forward_sig=infos[2], input="\n".join((infos[3:]))) + return TextData(model=infos[0], device=infos[1], forward_sig=infos[2], 
input="\n".join((infos[3:]))) - def reset_all_mock(): + def reset_all_mock() -> None: mock_data_repr.reset_mock() mock_indent_str.reset_mock() metered_model = Meter(ExampleModel()) - + # verify output type direct_res = metered_model.model_info assert isinstance(direct_res, Text) - - # verify model name + + # verify model name monkeypatch.setattr(metered_model.optree.root, "name", "test_name") res = text_resolve(metered_model.model_info) assert "test_name" in res.model - + # verify device - ## cpu + # cpu with patch("torchmeter.core.Meter.device", new=tc_device("cpu")): res = text_resolve(metered_model.model_info) assert "cpu" in res.device - - ## gpu + + # gpu with patch("torchmeter.core.Meter.device", new=tc_device("cuda:20")): res = text_resolve(metered_model.model_info) assert "cuda:20" in res.device - + # verify forward args representation - ## without args - monkeypatch.setattr(metered_model.model, "forward", lambda : None) + # without args + monkeypatch.setattr(metered_model.model, "forward", lambda: None) res = text_resolve(metered_model.model_info) assert "forward(self)" in res.forward_sig - ## with multiple args - monkeypatch.setattr(metered_model.model, "forward", lambda a,b,c=2: None) + # with multiple args + monkeypatch.setattr(metered_model.model, "forward", lambda a, b, c=2: None) res = text_resolve(metered_model.model_info) assert "forward(self, a, b, c)" in res.forward_sig - - ## with variable args - monkeypatch.setattr(metered_model.model, "forward", lambda a,*var_position, **var_kw: None) + + # with variable args + monkeypatch.setattr(metered_model.model, "forward", lambda a, *var_position, **var_kw: None) res = text_resolve(metered_model.model_info) assert "forward(self, a, var_position, var_kw)" in res.forward_sig - + # verify input representation - ## empty input + # empty input reset_all_mock() - metered_model._ipt = {"args":tuple(),"kwargs":{}} + metered_model._ipt = {"args": tuple(), "kwargs": {}} res = text_resolve(metered_model.model_info) 
mock_data_repr.assert_not_called() mock_indent_str.assert_called_once() assert "Not Provided" in res.input - - ## any input (all have pass-in value) + + # any input (all have pass-in value) reset_all_mock() - metered_model._ipt = {"args":(torch_randn(1,10),20,3),"kwargs":{}} - monkeypatch.setattr(metered_model.model, "forward", lambda a,b,c=2: None) + metered_model._ipt = {"args": (torch_randn(1, 10), 20, 3), "kwargs": {}} + monkeypatch.setattr(metered_model.model, "forward", lambda a, b, c=2: None) res = text_resolve(metered_model.model_info) assert mock_data_repr.call_count == 3 mock_indent_str.assert_called_once() - assert all(t in res.input for t in ["a = Shape([1, 10])", - "b = 20", - "c = 3"]) - - ## any input (part of them have pass-in value) + assert all(t in res.input for t in ["a = Shape([1, 10])", "b = 20", "c = 3"]) + + # any input (part of them have pass-in value) reset_all_mock() - metered_model._ipt = {"args":(torch_randn(1,10),20),"kwargs":{}} - monkeypatch.setattr(metered_model.model, "forward", lambda a,b,c=2: None) + metered_model._ipt = {"args": (torch_randn(1, 10), 20), "kwargs": {}} + monkeypatch.setattr(metered_model.model, "forward", lambda a, b, c=2: None) res = text_resolve(metered_model.model_info) assert mock_data_repr.call_count == 2 mock_indent_str.assert_called_once() - assert all(t in res.input for t in ["a = Shape([1, 10])", - "b = 20"]) + assert all(t in res.input for t in ["a = Shape([1, 10])", "b = 20"]) assert "c = 2" not in res.input - - ## any input (pass-in value through keyword argument) + + # any input (pass-in value through keyword argument) reset_all_mock() - metered_model._ipt = {"args":tuple(),"kwargs":{"a":torch_randn(1,10), "b":20, "c":20}} - monkeypatch.setattr(metered_model.model, "forward", lambda a,b,c=2: None) + metered_model._ipt = {"args": tuple(), "kwargs": {"a": torch_randn(1, 10), "b": 20, "c": 20}} + monkeypatch.setattr(metered_model.model, "forward", lambda a, b, c=2: None) res = 
text_resolve(metered_model.model_info) assert mock_data_repr.call_count == 3 mock_indent_str.assert_called_once() - assert all(t in res.input for t in ["a = Shape([1, 10])", - "b = 20", - "c = 20"]) + assert all(t in res.input for t in ["a = Shape([1, 10])", "b = 20", "c = 20"]) - ## any input (pass-in value through keyword argument and positional argument) + # any input (pass-in value through keyword argument and positional argument) reset_all_mock() - metered_model._ipt = {"args":(torch_randn(1,10),np_ones([3,4,5])),"kwargs":{"c":40}} - monkeypatch.setattr(metered_model.model, "forward", lambda a,b,c=2: None) + metered_model._ipt = {"args": (torch_randn(1, 10), np_ones([3, 4, 5])), "kwargs": {"c": 40}} + monkeypatch.setattr(metered_model.model, "forward", lambda a, b, c=2: None) res = text_resolve(metered_model.model_info) assert mock_data_repr.call_count == 3 mock_indent_str.assert_called_once() - assert all(t in res.input for t in ["a = Shape([1, 10])", - "b = Shape([3, 4, 5])", - "c = 40"]) + assert all(t in res.input for t in ["a = Shape([1, 10])", "b = Shape([3, 4, 5])", "c = 40"]) @pytest.mark.parametrize( - argnames="repeat_num", - argvalues=[2, 2**2, 2**4, 2**8] + argnames="repeat_num", + argvalues=[2, 2**2, 2**4, 2**8], ) - def test_subnodes(self, repeat_num): + def test_subnodes(self, repeat_num) -> None: """Test the subnodes property is set and retrieved correctly""" - + metered_model = Meter(RepeatModel(repeat_nodes=repeat_num), device="cpu") - assert len(metered_model.subnodes) == repeat_num + 2 # 2: root + modulelist - - def test_rebase(self): - """Test the logic of rebase method""" - + assert len(metered_model.subnodes) == repeat_num + 2 # 2: root + modulelist + + def test_rebase(self) -> None: + """Test the logic of rebase method""" + metered_model = Meter(ExampleModel()) - + # rebase to root itself, return self directly - with patch.object(Meter, "__init__", wraps=Meter.__init__) as mock_new: + with patch.object(Meter, "__init__", 
wraps=Meter.__init__) as mock_new: rebase_model = metered_model.rebase("0") mock_new.assert_not_called() assert rebase_model is metered_model - + # rebase to child node rebase_model = metered_model.rebase("2.1") assert rebase_model.optree.root.type == "Linear" - + # invalid argument type with pytest.raises(TypeError): metered_model.rebase(0) - + # invalid argument value with pytest.raises(ValueError): metered_model.rebase("10") - - def test_stat_info(self): + + def test_stat_info(self) -> None: """Test the logic of stat_info method""" - + metered_model = Meter(ExampleModel()) - metered_model(torch_randn(1,10)) + metered_model(torch_randn(1, 10)) metered_model._Meter__has_nocall_nodes = False - + # verify output type # input a stat name - with patch.object(ParamsMeter, "crucial_data", - new_callable=PropertyMock) as mock_crucial_data: + with patch.object(ParamsMeter, "crucial_data", new_callable=PropertyMock) as mock_crucial_data: direct_res = metered_model.stat_info("param") mock_crucial_data.assert_called_once() assert isinstance(direct_res, Text) - + # input a stat obj - with patch.object(CalMeter, "crucial_data", - new_callable=PropertyMock) as mock_crucial_data: + with patch.object(CalMeter, "crucial_data", new_callable=PropertyMock) as mock_crucial_data: direct_res = metered_model.stat_info(metered_model.cal) mock_crucial_data.assert_called_once() assert isinstance(direct_res, Text) - + # invalid input type with pytest.raises(TypeError): metered_model.stat_info(["param", "cal"]) - + # verify ittp special field - with patch.object(IttpMeter, "crucial_data", - new_callable=PropertyMock) as mock_crucial_data: + with patch.object(IttpMeter, "crucial_data", new_callable=PropertyMock) as mock_crucial_data: direct_res = metered_model.stat_info("ittp") mock_crucial_data.assert_called_once() assert "Benchmark Times" in direct_res.plain - + # verify content is the crucial data of the specified stat - with patch.object(MemMeter, "crucial_data", - 
new_callable=PropertyMock) as mock_crucial_data: + with patch.object(MemMeter, "crucial_data", new_callable=PropertyMock) as mock_crucial_data: direct_res = metered_model.stat_info("mem") mock_crucial_data.assert_called_once() - - def test_stat_info_warning(self, monkeypatch): + + def test_stat_info_warning(self, monkeypatch) -> None: """Test the logic of generating warning info""" - + class NotSupportModel(nn.Module): - def __init__(self): + def __init__(self) -> None: super(NotSupportModel, self).__init__() - self.layer0 = nn.Linear(10,5) + self.layer0 = nn.Linear(10, 5) self.layer1 = nn.AdaptiveAvgPool1d(1) + def forward(self, x): return self.layer1(x) - + metered_model = Meter(NotSupportModel(), device="cpu") # set ittp_warmup and ittp_benchmark_time to a lower value to save time metered_model.ittp_warmup = 2 metered_model.ittp_benchmark_time = 2 - metered_model(torch_randn(1,10,5)) - - nocall_flag = lambda :metered_model._Meter__has_nocall_nodes - notsupport_flag = lambda :metered_model._Meter__has_not_support_nodes - + metered_model(torch_randn(1, 10, 5)) + + nocall_flag = lambda: metered_model._Meter__has_nocall_nodes + notsupport_flag = lambda: metered_model._Meter__has_not_support_nodes + assert nocall_flag() is None assert notsupport_flag() is None - + # only take effect when the stat is cal or mem metered_model.stat_info("param", show_warning=True) assert nocall_flag() is None assert notsupport_flag() is None - + metered_model.stat_info("ittp", show_warning=True) assert nocall_flag() is None assert notsupport_flag() is None - + metered_model.stat_info("mem", show_warning=True) assert nocall_flag() is True assert notsupport_flag() is None - + metered_model._Meter__has_nocall_nodes = None metered_model.stat_info("cal", show_warning=True) assert nocall_flag() is True assert notsupport_flag() is True - + # verify show_warning option metered_model._Meter__has_nocall_nodes = None metered_model._Meter__has_not_support_nodes = None metered_model.stat_info("cal", 
show_warning=False) assert nocall_flag() is None assert notsupport_flag() is None - + # verify __has_nocall_nodes flag is properly set & - # verify cache of __has_nocall_nodes - with patch.object(MemMeter, "crucial_data", - new_callable=PropertyMock) as mock_crucial_data: - ## when there is no nocalled nodes (crucial_data is mocked and will raise error) + # verify cache of __has_nocall_nodes + with patch.object(MemMeter, "crucial_data", new_callable=PropertyMock) as mock_crucial_data: + # when there is no nocalled nodes (crucial_data is mocked and will raise error) metered_model.stat_info("mem", show_warning=True) - assert mock_crucial_data.call_count == 4 # 1: provide info to info_ls; 3: traverse all 3 nodes + assert mock_crucial_data.call_count == 4 # 1: provide info to info_ls; 3: traverse all 3 nodes assert nocall_flag() is False - - ## when the second traversal node is no called + + # when the second traversal node is no called mock_crucial_data.reset_mock() metered_model._Meter__has_nocall_nodes = None mock_crucial_data.side_effect = [{}, True, RuntimeError] metered_model.stat_info("mem", show_warning=True) - assert mock_crucial_data.call_count == 3 # 1: provide info to info_ls; 2: traverse the leading 2 nodes + # 1: provide info to info_ls; 2: traverse the leading 2 nodes + assert mock_crucial_data.call_count == 3 assert nocall_flag() is True - + # verify cache of __has_nocall_nodes mock_crucial_data.reset_mock() mock_crucial_data.side_effect = [{}] metered_model.stat_info("mem", show_warning=True) - mock_crucial_data.assert_called_once() # 1: provide info to info_ls; 0: no need to traverse due to the cache + # 1: provide info to info_ls; 0: no need to traverse due to the cache + mock_crucial_data.assert_called_once() # verify __has_not_support_nodes flag is properly set & # verify cache of __has_not_support_nodes - with patch.object(OperationNode, "cal", - new_callable=PropertyMock) as mock_cal: - + with patch.object(OperationNode, "cal", 
new_callable=PropertyMock) as mock_cal: mock_cal_instance = MagicMock(spec=CalMeter) type(mock_cal_instance).name = PropertyMock(return_value="cal") mock_cal.return_value = mock_cal_instance - - ## when all nodes are supported + + # when all nodes are supported mock_cal_instance.is_not_supported = False metered_model.stat_info("cal", show_warning=True) assert notsupport_flag() is False - ## when any node is not supported + # when any node is not supported mock_cal_instance.is_not_supported = True metered_model._Meter__has_not_support_nodes = None metered_model.stat_info("cal", show_warning=True) assert notsupport_flag() is True - - ## verify cache of __has_not_support_nodes + + # verify cache of __has_not_support_nodes mock_cal.reset_mock() metered_model._Meter__has_not_support_nodes = None metered_model.stat_info("cal", show_warning=True) - assert mock_cal.call_count == 2 # 1: provide info to info_ls; 1: check each node.cal.is_not_supported - + assert mock_cal.call_count == 2 # 1: provide info to info_ls; 1: check each node.cal.is_not_supported + mock_cal.reset_mock() metered_model.stat_info("cal", show_warning=True) - assert mock_cal.call_count == 1 # 1: provide info to info_ls - + assert mock_cal.call_count == 1 # 1: provide info to info_ls + # verify warning info is added to info_ls correctly - with patch.object(OperationNode, "cal", - new_callable=PropertyMock) as mock_cal, \ - patch.object(CalMeter, "crucial_data", - new_callable=PropertyMock) as mock_crucial_data: - + with patch.object(OperationNode, "cal", new_callable=PropertyMock) as mock_cal, \ + patch.object(CalMeter, "crucial_data", new_callable=PropertyMock) as mock_crucial_data: # fmt: skip mock_cal_instance = MagicMock(spec=CalMeter) type(mock_cal_instance).name = PropertyMock(return_value="cal") type(mock_cal_instance).crucial_data = mock_crucial_data mock_cal.return_value = mock_cal_instance - - ## when there is no warning info, i.e. 
no nocalled nodes and not-supported nodes - mock_crucial_data.side_effect=[dict()] + + # when there is no warning info, i.e. no nocalled nodes and not-supported nodes + mock_crucial_data.side_effect = [dict()] mock_cal_instance.is_not_supported = False metered_model._Meter__has_nocall_nodes = None metered_model._Meter__has_not_support_nodes = None res = metered_model.stat_info("cal", show_warning=True).plain assert "Warning" not in res - - ## when there are only nocall nodes, but all nodes are supported - mock_crucial_data.side_effect=[dict(), RuntimeError] + + # when there are only nocall nodes, but all nodes are supported + mock_crucial_data.side_effect = [dict(), RuntimeError] mock_cal_instance.is_not_supported = False metered_model._Meter__has_nocall_nodes = None metered_model._Meter__has_not_support_nodes = None @@ -1081,9 +1062,9 @@ def forward(self, x): assert "Warning" in res assert "not called" in res assert "don't support" not in res - - ## when there are only not-supported nodes, but all nodes are called - mock_crucial_data.side_effect=[dict()]*10 + + # when there are only not-supported nodes, but all nodes are called + mock_crucial_data.side_effect = [dict()] * 10 mock_cal_instance.is_not_supported = True metered_model._Meter__has_nocall_nodes = None metered_model._Meter__has_not_support_nodes = None @@ -1091,9 +1072,9 @@ def forward(self, x): assert "Warning" in res assert "not called" not in res assert "don't support" in res - - ## when nocall nodes and not-supported nodes exist in the same time - mock_crucial_data.side_effect=[dict(), RuntimeError] + + # when nocall nodes and not-supported nodes exist in the same time + mock_crucial_data.side_effect = [dict(), RuntimeError] mock_cal_instance.is_not_supported = True metered_model._Meter__has_nocall_nodes = None metered_model._Meter__has_not_support_nodes = None @@ -1101,100 +1082,99 @@ def forward(self, x): assert "Warning" in res assert "not called" in res assert "don't support" in res - - def 
test_overview(self): + + def test_overview(self) -> None: """Test the logic of overview method""" - + from rich.panel import Panel from rich.columns import Columns - + metered_model = Meter(ExampleModel()) # set ittp_warmup and ittp_benchmark_time to a lower value to save time metered_model.ittp_warmup = 2 metered_model.ittp_benchmark_time = 2 - metered_model(torch_randn(1,10)) - + metered_model(torch_randn(1, 10)) + order_getter = lambda res: [p._title.plain.split(" INFO")[0].strip().lower() - for p in res.renderables] - + for p in res.renderables] # fmt: skip + # verify output type res = metered_model.overview() assert isinstance(res, Columns) - + # verify default order is the order defined in OperationNode.statistics res_order = order_getter(res) - assert res_order == ["model"] + list(OperationNode.statistics) - + assert res_order == ["model", *list(OperationNode.statistics)] + # verify custom order res = metered_model.overview("param", "mem") res_order = order_getter(res) assert len(res.renderables) == 3 assert res_order == ["model", "param", "mem"] - + # invalid stat name with pytest.raises(ValueError): metered_model.overview("invalid_stat") - + # verify content - with patch.object(metered_model, "stat_info", - wraps=metered_model.stat_info) as mock_stat_info: - ## whether model info always exists, see the second section above - - ## each item is a panel of stat_info + with patch.object(metered_model, "stat_info", wraps=metered_model.stat_info) as mock_stat_info: + # whether model info always exists, see the second section above + + # each item is a panel of stat_info res = metered_model.overview("mem", "cal") assert all(isinstance(p, Panel) for p in res.renderables) assert mock_stat_info.call_count == 2 mock_stat_info.assert_any_call("mem", show_warning=True) mock_stat_info.assert_any_call("cal", show_warning=True) - - ## default setting is True + + # default setting is True metered_model.overview("param", "cal") call_args_ls = mock_stat_info.call_args_list 
assert all(call_args.kwargs["show_warning"] for call_args in call_args_ls) - - ## custom setting + + # custom setting mock_stat_info.reset_mock() metered_model.overview("param", "cal", show_warning=False) call_args_ls = mock_stat_info.call_args_list assert all(not call_args.kwargs["show_warning"] for call_args in call_args_ls) - - def test_table_cols(self): + + def test_table_cols(self) -> None: """Test the logic of table_cols method""" - + from polars import DataFrame - + metered_model = Meter(ExampleModel()) - + # invalid stat name with pytest.raises(KeyError): metered_model.table_cols("invalid_stat") - + # invalid input type with pytest.raises(TypeError): metered_model.table_cols(["param", "cal"]) - + # verify the result got from a empty dict of dataframe metered_model.table_renderer.stats_data["param"] = DataFrame() cols = metered_model.table_cols("param") assert cols == ParamsMeter.detail_val_container._fields - + # verify the result got from a non-empty dict of dataframe - metered_model.table_renderer.stats_data["cal"] = DataFrame({"test_A":[1,2], "test_B":[1,2]}) + metered_model.table_renderer.stats_data["cal"] = DataFrame({"test_A": [1, 2], "test_B": [1, 2]}) cols = metered_model.table_cols("cal") assert cols == ("test_A", "test_B") - def test_profile_iopt(self): + def test_profile_iopt(self) -> None: """Test the type and content of input and output.""" - - from rich.table import Table + from polars import DataFrame - + from rich.table import Table + metered_model = Meter(ExampleModel()) - + # invalid input type with pytest.raises(TypeError): metered_model.profile(stat_name={"cal", "param"}) - + # invalid input stat name with pytest.raises(AttributeError): metered_model.profile(stat_name="invalid_stat") @@ -1203,221 +1183,223 @@ def test_profile_iopt(self): output_a, output_b = metered_model.profile("param", show=False) assert isinstance(output_a, Table) assert isinstance(output_b, DataFrame) - + @patch("torchmeter.core.render_perline") - def 
test_profile_option(self, mock_render, capsys):
+    def test_profile_option(self, mock_render, capsys) -> None:
         """Test the logic of different profile option."""
-        terminal_output_strls = lambda: capsys.readouterr().out.strip().split("\n")
         metered_model = Meter(ExampleModel())
         model_structure = metered_model.structure
-        
+
         # show = False, no_tree = False
         mock_render.reset_mock()
         metered_model.profile("param", show=False, no_tree=False)
         mock_render.assert_not_called()
-        
+
         # show = False, no_tree = True
         mock_render.reset_mock()
         metered_model.profile("param", show=False, no_tree=True)
         mock_render.assert_not_called()
-        
-        with patch.object(Meter, "structure",
-                          new_callable=PropertyMock,
-                          return_value=model_structure) as mock_structure:
+
+        with patch.object(
+            Meter, "structure", new_callable=PropertyMock, return_value=model_structure
+        ) as mock_structure:
             # show = True, no_tree = False
             mock_render.reset_mock()
             mock_structure.reset_mock()
             metered_model.profile("param", show=True, no_tree=False)
             mock_structure.assert_called_once()
             mock_render.assert_called_once()
-            
+
             # show = True, no_tree = True
             mock_render.reset_mock()
             mock_structure.reset_mock()
             metered_model.profile("param", show=True, no_tree=True)
             mock_structure.assert_not_called()
             mock_render.assert_called_once()
-    
+
     @patch("torchmeter.core.render_perline")
-    def test_profile_horizon_gap(self, mock_render):
+    def test_profile_horizon_gap(self, mock_render) -> None:
         """Test the gap between tree and table is as same as the one in config"""
-        
+
         origin_init = Layout.__init__
-        def layout_init_wrapper(self, *args, **kwargs):
-            origin_init(self, *args, **kwargs)
+
+        def layout_init_wrapper(self, *args, **kwargs) -> None:
+            origin_init(self, *args, **kwargs)
 
         metered_model = Meter(ExampleModel())
         model_structure = metered_model.structure
-        
+
         console = get_console()
         tree_width = console.measure(model_structure).maximum
-        
+
         # negative gap
         with pytest.raises(ValueError):
__cfg__.combine.horizon_gap = -10 metered_model.profile("param", show=True, no_tree=False) - # verify custom gap + # verify custom gap __cfg__.combine.horizon_gap = 10 with patch.object(Layout, "__init__", autospec=True, - side_effect=layout_init_wrapper) as mock_init_layout: + side_effect=layout_init_wrapper) as mock_init_layout: # fmt: skip metered_model.profile("param", show=True, no_tree=False) kwargs_dict = mock_init_layout.call_args_list[1].kwargs assert kwargs_dict["size"] == tree_width + 10 - + @patch("torchmeter.core.render_perline") - def test_profile_content(self, mock_render): + def test_profile_content(self, mock_render) -> None: """Test the logic of generating the profile content.""" - + from rich.rule import Rule from rich.table import Table from rich.columns import Columns - + origin_init = Layout.__init__ - def layout_init_wrapper(self, *args, **kwargs): - origin_init(self, *args, **kwargs) - + + def layout_init_wrapper(self, *args, **kwargs) -> None: + origin_init(self, *args, **kwargs) + console = get_console() metered_model = Meter(ExampleModel()) tree = metered_model.structure tree_width = console.measure(tree).maximum - + renderable_getter = lambda c: c.args[0]._renderable - + # verify main content generation with patch.object(Layout, "__init__", autospec=True, side_effect=layout_init_wrapper) as mock_init_layout,\ patch.object(metered_model, "stat_info", - wraps=metered_model.stat_info) as mock_stat_info: - ## no tree, main content is a Table + wraps=metered_model.stat_info) as mock_stat_info: # fmt: skip + # no tree, main content is a Table metered_model.profile("param", show=True, no_tree=True) main_content = renderable_getter(mock_init_layout.call_args_list[-2]) - + assert mock_init_layout.call_count == 3 assert isinstance(main_content, Table) - - ## with tree, main content is a Layout with two columns + + # with tree, main content is a Layout with two columns mock_init_layout.reset_mock() - tb,_ = metered_model.profile("param", show=True, 
no_tree=False) + tb, _ = metered_model.profile("param", show=True, no_tree=False) main_content = renderable_getter(mock_init_layout.call_args_list[-2]) - + assert mock_init_layout.call_count == 6 assert isinstance(main_content, Layout) assert len(main_content._children) == 2 - + assert main_content["left"]._renderable is tree assert main_content["right"]._renderable is tb assert main_content["left"].size == tree_width + __cfg__.combine.horizon_gap - + # verify footer generation footer = renderable_getter(mock_init_layout.call_args_list[-1]) assert isinstance(footer, Columns) assert len(footer.renderables) == 2 assert all(isinstance(ctt, Text) for ctt in footer.renderables) mock_stat_info.assert_called_with(stat_or_statname=ANY, show_warning=False) - + assert isinstance(footer.title, Rule) - + @patch("torchmeter.core.render_perline") - def test_profile_console_management(self, mock_render, monkeypatch): + def test_profile_console_management(self, mock_render, monkeypatch) -> None: """Test the rendering interactive logic related to console size""" - - + from rich.console import Console - + console = get_console() metered_model = Meter(ExampleModel()) - + # verify auto show_lines when there is no enough space for table - ## when the table width is smaller than the terminal width + # when the table width is smaller than the terminal width monkeypatch.setattr(console, "width", float("inf")) tb, _ = metered_model.profile("param", show=True, no_tree=True) assert tb.show_lines is False - - ## when the table width exceeds the terminal width + + # when the table width exceeds the terminal width monkeypatch.setattr(console, "width", console.measure(tb).maximum - 10) tb, _ = metered_model.profile("param", show=True, no_tree=False) assert tb.show_lines is True - - + # verify minimal console width error monkeypatch.setattr(console, "width", 5) with pytest.raises(RuntimeError): metered_model.profile("param", show=True, no_tree=False) - + monkeypatch.setattr(console, "width", 1) 
with pytest.raises(RuntimeError): metered_model.profile("param", show=True, no_tree=False) - # verify console size change and restore - ## init a large enough console with size 120x60 + # init a large enough console with size 120x60 monkeypatch.setattr(console, "width", 120) monkeypatch.setattr(console, "height", 60) origin_width, origin_height = console.width, console.height - - with patch.object(Console, "width", new_callable=PropertyMock, + + # fmt: off + with patch.object(Console, "width", new_callable=PropertyMock, return_value=origin_width) as mock_console_width, \ patch.object(Console, "height", new_callable=PropertyMock, return_value=origin_height) as mock_console_height, \ patch("rich.layout.Layout.split_column", wraps=Layout().split_column) as mock_split_col: - + # fmt: on + metered_model.profile("param", show=True, no_tree=True) upper_layout, down_layout = mock_split_col.call_args.args content_width = console.measure(upper_layout._renderable).maximum content_height = upper_layout.size + down_layout.size - + is_empty_call = lambda c: not len(c.args) and not len(c.kwargs) - + for mock_size_attr, origin_val, content_val in zip( [mock_console_width, mock_console_height], [origin_width, origin_height], - [content_width, content_height] + [content_width, content_height], ): call_ls = mock_size_attr.call_args_list setter_calls = [(c_idx, c) for c_idx, c in enumerate(call_ls) - if not is_empty_call(c)] - ## 1: set to canvas size; 1: restore - assert len(setter_calls) == 2 - - ## the restore happen right after it is set - assert setter_calls[1][0] == setter_calls[0][0] + 1 - - ## set to the canvas corresponding size, here the canvas only contains a table + if not is_empty_call(c)] # fmt: skip + # 1: set to canvas size; 1: restore + assert len(setter_calls) == 2 + + # the restore happen right after it is set + assert setter_calls[1][0] == setter_calls[0][0] + 1 + + # set to the canvas corresponding size, here the canvas only contains a table assert 
setter_calls[0][1].args[0] == content_val - ## verify if restore to the original size - assert setter_calls[1][1].args[0] == origin_val + # verify if restore to the original size + assert setter_calls[1][1].args[0] == origin_val - # verify console size restore when the rendering is interrupted + # verify console size restore when the rendering is interrupted mock_render.side_effect = KeyboardInterrupt("Simulated interrupt") - + + # fmt: off with patch.object(Console, "width", new_callable=PropertyMock, return_value=origin_width) as mock_console_width, \ patch.object(Console, "height", new_callable=PropertyMock, return_value=origin_height) as mock_console_height, \ patch("rich.layout.Layout.split_column", wraps=Layout().split_column) as mock_split_col: - + # fmt: on + with pytest.raises(KeyboardInterrupt): metered_model.profile("param", show=True, no_tree=True) upper_layout, down_layout = mock_split_col.call_args.args content_width = console.measure(upper_layout._renderable).maximum content_height = upper_layout.size + down_layout.size - + is_empty_call = lambda c: not len(c.args) and not len(c.kwargs) - + for mock_size_attr, origin_val, content_val in zip( [mock_console_width, mock_console_height], [origin_width, origin_height], - [content_width, content_height] + [content_width, content_height], ): call_ls = mock_size_attr.call_args_list setter_calls = [(c_idx, c) for c_idx, c in enumerate(call_ls) - if not is_empty_call(c)] - assert len(setter_calls) == 2 - assert setter_calls[1][0] == setter_calls[0][0] + 1 + if not is_empty_call(c)] # fmt: skip + assert len(setter_calls) == 2 + assert setter_calls[1][0] == setter_calls[0][0] + 1 assert setter_calls[0][1].args[0] == content_val assert setter_calls[1][1].args[0] == origin_val - \ No newline at end of file diff --git a/tests/test_dfstask.py b/tests/test_dfstask.py index 35ce0a0..0d89039 100644 --- a/tests/test_dfstask.py +++ b/tests/test_dfstask.py @@ -2,17 +2,20 @@ from torchmeter.utils import dfs_task + class 
TreeNode: - def __init__(self, val, left=None, right=None): + def __init__(self, val, left=None, right=None) -> None: self.val = val self.left = left self.right = right + class GraphNode: - def __init__(self, val): + def __init__(self, val) -> None: self.val = val self.children = [] + @pytest.fixture def binary_tree_root(): """Build a binary tree for test: @@ -28,134 +31,139 @@ def binary_tree_root(): n3 = TreeNode(3) return TreeNode(1, n2, n3) + @pytest.fixture def cyclic_graph_stnode(): """Create a cyclic graph: A β†’ B β†’ C ↑ ↓ - D ← ← ← + D ← ← ← """ - node_a = GraphNode('A') - node_b = GraphNode('B') - node_c = GraphNode('C') - node_d = GraphNode('D') - + node_a = GraphNode("A") + node_b = GraphNode("B") + node_c = GraphNode("C") + node_d = GraphNode("D") + node_a.children = [node_b] node_b.children = [node_c] node_c.children = [node_d] node_d.children = [node_a] return node_a + @pytest.mark.vital class TestDfsTask: # basic funtion test - def test_binary_tree_traversal(self, binary_tree_root): + def test_binary_tree_traversal(self, binary_tree_root) -> None: """Test standard binary tree preorder traversal using DFS""" traversal_order = [] - + # task function: preorder traversal def record_node(subject, pre_res=[]): if subject is None: return pre_res traversal_order.append(subject.val) - return pre_res + [subject.val] - + return [*pre_res, subject.val] + dfs_task( dfs_subject=binary_tree_root, adj_func=lambda n: [child for child in (n.left, n.right) if child is not None], task_func=record_node, visited_signal_func=lambda x: id(x), - visited=[] + visited=[], ) - + assert traversal_order == [1, 2, 4, 5, 3] - def test_cyclic_graph_traversal(self, cyclic_graph_stnode): + def test_cyclic_graph_traversal(self, cyclic_graph_stnode) -> None: """Test the traversal of a cyclic graph""" visited_nodes = [] - + def track_nodes(subject, pre_res=[]): visited_nodes.append(subject.val) - return pre_res + [subject.val] - + return [*pre_res, subject.val] + dfs_task( 
dfs_subject=cyclic_graph_stnode, adj_func=lambda n: n.children, task_func=track_nodes, visited_signal_func=lambda x: x.val, - visited=[] + visited=[], ) - - assert visited_nodes == ['A', 'B', 'C', 'D'] + + assert visited_nodes == ["A", "B", "C", "D"] # boundary condition test - def test_single_node_traversal(self, binary_tree_root): + def test_single_node_traversal(self, binary_tree_root) -> None: """Test single node traversal""" + def identity_task(subject, pre_res=[]): - return pre_res + [subject.val] - + return [*pre_res, subject.val] + result = dfs_task( dfs_subject=binary_tree_root, adj_func=lambda _: [], task_func=identity_task, - visited=[] + visited=[], ) - + assert result == [1] - def test_custom_visit_signal(self): + def test_custom_visit_signal(self) -> None: """Test custom visit signal function""" visited_signals = [] - + def custom_signal(x): sig = f"CUSTOM_{x}" visited_signals.append(sig) return sig - + dfs_task( dfs_subject=1, - adj_func=lambda x: [x+1] if x < 3 else [], + adj_func=lambda x: [x + 1] if x < 3 else [], task_func=lambda subject, pre_res=[]: None, visited_signal_func=custom_signal, - visited=[] + visited=[], ) - + assert visited_signals == ["CUSTOM_1", "CUSTOM_2", "CUSTOM_3"] # Error handling test - def test_invalid_task_function(self): + def test_invalid_task_function(self) -> None: """Test invalid task function signature""" + def invalid_task(missing_arg): return missing_arg - + with pytest.raises(RuntimeError) as excinfo: dfs_task( dfs_subject=1, adj_func=lambda x: [], task_func=invalid_task, - visited=[] + visited=[], ) - + assert "missing following required args: ['subject', 'pre_res']" in str(excinfo.value).lower() # Special scenario testing - def test_mutable_default_visited(self): + def test_mutable_default_visited(self) -> None: """Test whether the default visited argument is isolated""" + def safe_task(subject, pre_res=[]): - return pre_res + [subject] - + return [*pre_res, subject] + result1 = dfs_task( dfs_subject="test", 
adj_func=lambda x: [], - task_func=safe_task + task_func=safe_task, ) - + result2 = dfs_task( dfs_subject="test", adj_func=lambda x: [], - task_func=safe_task + task_func=safe_task, ) - + assert result1 == ["test"] assert result2 == ["test"] # fail if the default value is shared diff --git a/tests/test_display.py b/tests/test_display.py index f45369b..1c34b3f 100644 --- a/tests/test_display.py +++ b/tests/test_display.py @@ -1,28 +1,23 @@ import os -from unittest.mock import ANY, Mock -from unittest.mock import call, patch +from unittest.mock import ANY, Mock, call, patch import pytest import torch.nn as nn from torch import randn as torch_randn -from rich.text import Text +from polars import Series, DataFrame from rich.rule import Rule +from rich.text import Text from rich.tree import Tree -from rich.table import Table from rich.panel import Panel +from rich.table import Table from rich.columns import Columns +from rich.console import Group, Console from rich.segment import Segment -from rich.console import Console, Group -from polars import DataFrame, Series from torchmeter.config import FlagNameSpace from torchmeter.engine import OperationNode, OperationTree -from torchmeter._stat_numeric import UpperLinkData, MetricsData, CountUnit -from torchmeter.display import ( - __cfg__, - dfs_task, render_perline, apply_setting, - TreeRenderer, TabularRenderer -) +from torchmeter.display import TreeRenderer, TabularRenderer, __cfg__, dfs_task, apply_setting, render_perline +from torchmeter._stat_numeric import CountUnit, MetricsData, UpperLinkData pytestmark = pytest.mark.vital @@ -34,49 +29,60 @@ TREE_CHILD2.add("2.1") TREE_CHILD2.add("2.2") -EXAMPLE_TABLE = Table("A","B") +EXAMPLE_TABLE = Table("A", "B") EXAMPLE_TABLE.add_row("1", "2") EXAMPLE_TABLE.add_row("3", "4") + class NoVPAObj: """No Variable Positional Arguments""" - def __init__(self, a, b, c=3): + + def __init__(self, a, b, c=3) -> None: self._a = a self._b = b self._c = c - + self._all = a + b + c + class 
VPAFObj: """Variable Positional Arguments at Front""" - def __init__(self, *a, b=2, c=3): + + def __init__(self, *a, b=2, c=3) -> None: self._a = a self._b = b self._c = c + class VPAMObj: """Variable Positional Arguments at Middle""" - def __init__(self, a, *b, c=3): + + def __init__(self, a, *b, c=3) -> None: self._a = a self._b = b self._c = c + class VPALObj: """Variable Positional Arguments as Last""" - def __init__(self, a, b, *c): + + def __init__(self, a, b, *c) -> None: self._a = a self._b = b self._c = c + class MixedArgsObj: """All types of arguments""" - def __init__(self, a, b=2, *c, d=4, **e): + + def __init__(self, a, b=2, *c, d=4, **e) -> None: self._a = a self._b = b self._c = c self._d = d self._e = e + @pytest.fixture def mock_console(): """Fixture providing a mocked console object""" @@ -84,63 +90,71 @@ def mock_console(): console.render_lines.return_value = [[Segment("line1")], [Segment("line2")]] return console + @pytest.fixture -def mock_config(monkeypatch): +def mock_config(monkeypatch) -> None: """Fixture to mock the __cfg__ object""" + class MockConfig: render_interval = 0.1 + monkeypatch.setattr("torchmeter.display.__cfg__", MockConfig()) + @pytest.fixture def simple_tree_renderer(): opnode = OperationNode(nn.Identity()) yield TreeRenderer(opnode) __cfg__.restore() + @pytest.fixture def repeat_tree_renderer(): class RepeatModel(nn.Module): - def __init__(self): + def __init__(self) -> None: super(RepeatModel, self).__init__() self.layer0 = nn.Linear(10, 10) self.layer1 = nn.Sequential( nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), - nn.ReLU() + nn.ReLU(), ) - + optree = OperationTree(RepeatModel()) - + yield TreeRenderer(optree.root) - + __cfg__.restore() + @pytest.fixture def simple_tabular_renderer(): opnode = OperationNode(nn.Identity()) yield TabularRenderer(opnode) __cfg__.restore() + @pytest.fixture def universal_tabular_renderer(): class UnuseModel(nn.Module): - def __init__(self): + def __init__(self) -> None: super(UnuseModel, 
self).__init__() - + self.conv = nn.Conv2d(3, 3, 3) self.unuse = nn.Identity() - + def forward(self, x): return self.conv(x) - + optree = OperationTree(UnuseModel()) tabular_renderer = TabularRenderer(optree.root) - + yield tabular_renderer - + __cfg__.restore() - + + @pytest.fixture def example_df(): """ @@ -154,311 +168,345 @@ def example_df(): β”‚ null ┆ c ┆ null ┆ null ┆ 0.00 Β± 0.00 β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ """ - + from polars import Object as pl_obj df = DataFrame({ "numeric": [1, 2, None], "text": ["a", None, "c"], - "list_col": [[1,2], [3], None], + "list_col": [[1, 2], [3], None], "nomal_obj": [ - Text("example"), - Text("dataframe"), - None - ]}) - + Text("example"), + Text("dataframe"), + None, + ], + }) + self_obj_col = Series( name="self_obj", values=[ UpperLinkData(1e5, unit_sys=CountUnit, none_str="test none_str"), None, - MetricsData() + MetricsData(), ], - dtype = pl_obj + dtype=pl_obj, ) df.insert_column(len(df.columns), self_obj_col) - + return df + @pytest.fixture def export_dir(tmpdir): yield tmpdir.strpath if tmpdir.exists(): tmpdir.remove(rec=1) + class TestApplySetting: + def test_valid_usage(self) -> None: + """ "Test basic functionality and common usage cases""" - def test_valid_usage(self): - """"Test basic functionality and common usage cases""" - # all settings are changed obj = NoVPAObj(1, 2, 3) - apply_setting(obj, setting={"a":10, "b":20, "c":30}) + apply_setting(obj, setting={"a": 10, "b": 20, "c": 30}) assert obj._a == 10 assert obj._b == 20 assert obj._c == 30 # partial settings are changed obj = NoVPAObj(1, 2, 3) - apply_setting(obj, setting={"a":10, "b":20}) + apply_setting(obj, setting={"a": 10, "b": 20}) assert obj._a == 10 assert obj._b == 20 assert obj._c == 3 - - def test_invalid_usage(self): + + def test_invalid_usage(self) -> None: """Test invalid usage cases""" - + # invlaid 
setting type with pytest.raises(TypeError): - apply_setting(NoVPAObj(1,2,3), setting=10) - + apply_setting(NoVPAObj(1, 2, 3), setting=10) + # invalid omit type, see `test_omit_type` - + # required initializatio argument absent with pytest.raises(RuntimeError) as e: - apply_setting(NoVPAObj(1,2,3), setting={'a':10}) + apply_setting(NoVPAObj(1, 2, 3), setting={"a": 10}) assert "`b` unknown" in str(e.value) - def test_setting(self): + def test_setting(self) -> None: """Test the logic of getting setting_dict""" - + # use FlagNameSpace to store the setting obj = NoVPAObj(1, 2, 3) apply_setting(obj, setting=FlagNameSpace(a=100, b=200, c=300)) assert obj._a == 100 assert obj._b == 200 assert obj._c == 300 - + # use dict to store the setting obj = NoVPAObj(1, 2, 3) - apply_setting(obj, setting={"a":1000, "b":2000, "c":3000}) + apply_setting(obj, setting={"a": 1000, "b": 2000, "c": 3000}) assert obj._a == 1000 assert obj._b == 2000 assert obj._c == 3000 - + # update setting with extra_settings obj = NoVPAObj(1, 2, 3) - apply_setting(obj, setting={"a":1000, "b":2000, "c":3000}, - c=30) + apply_setting(obj, setting={"a": 1000, "b": 2000, "c": 3000}, c=30) assert obj._a == 1000 assert obj._b == 2000 assert obj._c == 30 - + # invalid setting type with pytest.raises(TypeError): - apply_setting(NoVPAObj(1,2,3), setting=10) + apply_setting(NoVPAObj(1, 2, 3), setting=10) @pytest.mark.parametrize( argnames=("omit_args", "is_error", "key_error_info"), argvalues=[ - (None, False, None), # None - ("a", False, None), # str - (["a", "b"], False, None), # list - (("a", "c"), False, None), # list - ({"a", "b", "c"}, False, None), # set - ({"a":1, "b":2, "c":3}, True, "but got `dict`"), # dict - (123, True, "but got `int`"), # int - ([1,2,3], True, "`list` of `int`") # container of non-str - ] + (None, False, None), # None + ("a", False, None), # str + (["a", "b"], False, None), # list + (("a", "c"), False, None), # list + ({"a", "b", "c"}, False, None), # set + ({"a": 1, "b": 2, "c": 3}, 
True, "but got `dict`"), # dict + (123, True, "but got `int`"), # int + ([1, 2, 3], True, "`list` of `int`"), # container of non-str + ], ) - def test_omit_type(self, omit_args, is_error, key_error_info): + def test_omit_type(self, omit_args, is_error, key_error_info) -> None: """Test the pass-in type limitation""" if is_error: with pytest.raises(TypeError) as e: - apply_setting(NoVPAObj(1, 2, c=10), - setting={"a":10, "b":20, "c":30}, - omit=omit_args) + apply_setting( + NoVPAObj(1, 2, c=10), + setting={"a": 10, "b": 20, "c": 30}, + omit=omit_args, + ) assert key_error_info in str(e.value) else: - apply_setting(NoVPAObj(1, 2, c=10), - setting={"a":10, "b":20, "c":30}, - omit=omit_args) - + apply_setting( + NoVPAObj(1, 2, c=10), + setting={"a": 10, "b": 20, "c": 30}, + omit=omit_args, + ) + @pytest.mark.parametrize( argnames=("obj", "setting", "omit_args", "expected_state"), argvalues=[ # omit one argument - (NoVPAObj(1, 2, 10), {"a":10, "b":20, "c":30}, "_c", {"_a":10, "_b":20, "_c":10, "_all":60}), - + (NoVPAObj(1, 2, 10), {"a": 10, "b": 20, "c": 30}, "_c", {"_a": 10, "_b": 20, "_c": 10, "_all": 60}), # not omit one argument - (NoVPAObj(1, 2, 10), {"a":10, "b":20, "c":30}, "", {"_a":10, "_b":20, "_c":30, "_all":60}), - + (NoVPAObj(1, 2, 10), {"a": 10, "b": 20, "c": 30}, "", {"_a": 10, "_b": 20, "_c": 30, "_all": 60}), # omit multiple arguments - (NoVPAObj(1, 2, 10), {"a":10, "b":20, "c":30}, ["_a", "_b"], {"_a":1, "_b":2, "_c":30, "_all":60}), - (NoVPAObj(1, 2, 10), {"a":10, "b":20, "c":30}, ("_a", "_c"), {"_a":1, "_b":20, "_c":10, "_all":60}), - (NoVPAObj(1, 2, 10), {"a":10, "b":20, "c":30}, {"_a", "_b", "_c"}, {"_a":1, "_b":2, "_c":10, "_all":60}), - - # not omit multiple arguments - (NoVPAObj(1, 2, 10), {"a":10, "b":20, "c":30}, "", {"_a":10, "_b":20, "_c":30, "_all":60}), - + (NoVPAObj(1, 2, 10), {"a": 10, "b": 20, "c": 30}, ["_a", "_b"], {"_a": 1, "_b": 2, "_c": 30, "_all": 60}), + (NoVPAObj(1, 2, 10), {"a": 10, "b": 20, "c": 30}, ("_a", "_c"), {"_a": 1, 
"_b": 20, "_c": 10, "_all": 60}), + ( + NoVPAObj(1, 2, 10), + {"a": 10, "b": 20, "c": 30}, + {"_a", "_b", "_c"}, + {"_a": 1, "_b": 2, "_c": 10, "_all": 60}, + ), # omit variable positional arguments - (VPAFObj(1, 2, 3, 4, 5), {"a":[7,8,9], "b":40, "c":50}, "_a", {"_a":(1,2,3,4,5), "_b":40, "_c":50}), - (VPAMObj(1, 2, 3, 4, 5), {"a":10, "b":[7,8,9], "c":50}, "_b", {"_a":10, "_b":(2,3,4,5), "_c":50}), - (VPALObj(1, 2, 3, 4, 5), {"a":10, "b":20, "c":[7,8,9]}, "_c", {"_a":10, "_b":20, "_c":(3,4,5)}), - + ( + VPAFObj(1, 2, 3, 4, 5), + {"a": [7, 8, 9], "b": 40, "c": 50}, + "_a", + {"_a": (1, 2, 3, 4, 5), "_b": 40, "_c": 50}, + ), + ( + VPAMObj(1, 2, 3, 4, 5), + {"a": 10, "b": [7, 8, 9], "c": 50}, + "_b", + {"_a": 10, "_b": (2, 3, 4, 5), "_c": 50}, + ), + (VPALObj(1, 2, 3, 4, 5), {"a": 10, "b": 20, "c": [7, 8, 9]}, "_c", {"_a": 10, "_b": 20, "_c": (3, 4, 5)}), # not omit variable positional arguments - (VPAFObj(1, 2, 3, 4, 5), {"a":[7,8,9], "b":40, "c":50}, "", {"_a":(7,8,9), "_b":40, "_c":50}), - (VPAMObj(1, 2, 3, 4, 5), {"a":10, "b":[7,8,9], "c":50}, "", {"_a":10, "_b":(7,8,9), "_c":50}), - (VPALObj(1, 2, 3, 4, 5), {"a":10, "b":20, "c":[7,8,9]}, "", {"_a":10, "_b":20, "_c":(7,8,9)}), - + (VPAFObj(1, 2, 3, 4, 5), {"a": [7, 8, 9], "b": 40, "c": 50}, "", {"_a": (7, 8, 9), "_b": 40, "_c": 50}), + (VPAMObj(1, 2, 3, 4, 5), {"a": 10, "b": [7, 8, 9], "c": 50}, "", {"_a": 10, "_b": (7, 8, 9), "_c": 50}), + (VPALObj(1, 2, 3, 4, 5), {"a": 10, "b": 20, "c": [7, 8, 9]}, "", {"_a": 10, "_b": 20, "_c": (7, 8, 9)}), # omit mixed arguments - (MixedArgsObj(2, 4, 6, 8, d=10, f=12), {"a":10, "b":20, "c":[1,2,3], "d":40, "g":"G"}, - "_a", {"_a":2, "_b":20, "_c":(1,2,3), "_d":40, "_e":{"g":"G"}}), - (MixedArgsObj(2, 4, 6, 8, d=10, f=12), {"a":10, "b":20, "c":[1,2,3], "d":40, "g":"G"}, - "_c", {"_a":10, "_b":20, "_c":(6,8), "_d":40, "_e":{"g":"G"}}), - (MixedArgsObj(2, 4, 6, 8, d=10, f=12), {"a":10, "b":20, "c":[1,2,3], "d":40, "g":"G"}, - "_d", {"_a":10, "_b":20, "_c":(1,2,3), "_d":10, 
"_e":{"g":"G"}}), - (MixedArgsObj(2, 4, 6, 8, d=10, f=12), {"a":10, "b":20, "c":[1,2,3], "d":40, "g":"G"}, - "_e", {"_a":10, "_b":20, "_c":(1,2,3), "_d":40, "_e":{"f":12}}), - + ( + MixedArgsObj(2, 4, 6, 8, d=10, f=12), + {"a": 10, "b": 20, "c": [1, 2, 3], "d": 40, "g": "G"}, + "_a", + {"_a": 2, "_b": 20, "_c": (1, 2, 3), "_d": 40, "_e": {"g": "G"}}, + ), + ( + MixedArgsObj(2, 4, 6, 8, d=10, f=12), + {"a": 10, "b": 20, "c": [1, 2, 3], "d": 40, "g": "G"}, + "_c", + {"_a": 10, "_b": 20, "_c": (6, 8), "_d": 40, "_e": {"g": "G"}}, + ), + ( + MixedArgsObj(2, 4, 6, 8, d=10, f=12), + {"a": 10, "b": 20, "c": [1, 2, 3], "d": 40, "g": "G"}, + "_d", + {"_a": 10, "_b": 20, "_c": (1, 2, 3), "_d": 10, "_e": {"g": "G"}}, + ), + ( + MixedArgsObj(2, 4, 6, 8, d=10, f=12), + {"a": 10, "b": 20, "c": [1, 2, 3], "d": 40, "g": "G"}, + "_e", + {"_a": 10, "_b": 20, "_c": (1, 2, 3), "_d": 40, "_e": {"f": 12}}, + ), # not omit mixed arguments - (MixedArgsObj(2, 4, 6, 8, d=10, f=12), {"a":10, "b":20, "c":[1,2,3], "d":40, "g":"G"}, - "", {"_a":10, "_b":20, "_c":(1,2,3), "_d":40, "_e":{"g":"G"}}), - ] + ( + MixedArgsObj(2, 4, 6, 8, d=10, f=12), + {"a": 10, "b": 20, "c": [1, 2, 3], "d": 40, "g": "G"}, + "", + {"_a": 10, "_b": 20, "_c": (1, 2, 3), "_d": 40, "_e": {"g": "G"}}, + ), + ], ) - def test_omit_logic(self, obj, setting, omit_args, expected_state): + def test_omit_logic(self, obj, setting, omit_args, expected_state) -> None: """Test the logic of omitting the update of specified arguments""" apply_setting(obj, setting, omit=omit_args) assert obj.__dict__ == expected_state - def test_slots_object(self): + def test_slots_object(self) -> None: """Test the logic of dealing with slots object""" + class OneSlotObj: __slots__ = "_a" - def __init__(self, a): + + def __init__(self, a) -> None: self._a = a class MultiSlotObj: __slots__ = ["_a", "_b", "_c"] - def __init__(self, a, b, c=3): + + def __init__(self, a, b, c=3) -> None: self._a = a self._b = b self._c = c # slots just have one attribute 
         obj = OneSlotObj(1)
-        apply_setting(obj, {'a': 10})
+        apply_setting(obj, {"a": 10})
         assert obj._a == 10

         # slots just multi attributes
-        obj = MultiSlotObj(1,2,3)
-        apply_setting(obj, {"a": 10, "b":20})
+        obj = MultiSlotObj(1, 2, 3)
+        apply_setting(obj, {"a": 10, "b": 20})
         assert obj._a == 10
         assert obj._b == 20
         assert obj._c == 3

-    def test_private_property(self):
+    def test_private_property(self) -> None:
         """Test whether the function works well when the inner attribute is private"""
+
         class PrivateObj:
-            def __init__(self, a):
+            def __init__(self, a) -> None:
                 self.__a = a
-
+
             @property
             def a_val(self):
                 return self.__a
-
+
         class PrivateSlotObj:
             __slots__ = "__a"
-            def __init__(self, a):
+
+            def __init__(self, a) -> None:
                 self.__a = a
-
+
             @property
             def a_val(self):
                 return self.__a
-
+
         obj = PrivateObj(1)
-        apply_setting(obj, setting={'a': 10})
+        apply_setting(obj, setting={"a": 10})
         assert obj.a_val == 10
-
+
         obj = PrivateSlotObj(1)
-        apply_setting(obj, setting={'a': 10})
+        apply_setting(obj, setting={"a": 10})
         assert obj.a_val == 10

-    def test_indirect_property(self):
+    def test_indirect_property(self) -> None:
         """Test whether indirect initialization properties will change synchronously."""
+
         class IndirectObj:
-            def __init__(self, a):
+            def __init__(self, a) -> None:
                 self.a = a
-                self.computed = a * 2 
+                self.computed = a * 2

         obj = IndirectObj(2)
-        apply_setting(obj, setting={'a': 5})
+        apply_setting(obj, setting={"a": 5})
         assert obj.a == 5
-        assert obj.computed == 10 
+        assert obj.computed == 10

-    def test_inplace_update(self):
+    def test_inplace_update(self) -> None:
         """Test whether the settings are updated inplace"""
+
         class Child:
-            def __init__(self, value):
+            def __init__(self, value) -> None:
                 self.value = value

         class Parent:
-            def __init__(self, child: Child):
+            def __init__(self, child: Child) -> None:
                 self.child = child

         child = Child(1)
         parent = Parent(child)
-
+
         apply_setting(child, setting={"value": 10})
         assert child.value == 10
         assert parent.child.value == 10

-    def test_edge_cases(self):
+    def test_edge_cases(self) -> None:
         # omit list is empty
-        obj = NoVPAObj(1, 2, 3) 
-        apply_setting(obj, 
-                      setting={"a":10, "b":20, "c":30}, 
-                      omit=[])
+        obj = NoVPAObj(1, 2, 3)
+        apply_setting(
+            obj,
+            setting={"a": 10, "b": 20, "c": 30},
+            omit=[],
+        )
         assert obj._a == 10
         assert obj._b == 20
         assert obj._c == 30

+
 class TestRenderPerline:
-    def test_negative_interval(self, mock_config, monkeypatch):
+    def test_negative_interval(self, mock_config, monkeypatch) -> None:
         """Test ValueError when render_interval is negative"""
         monkeypatch.setattr("torchmeter.display.__cfg__.render_interval", -0.5)
         with pytest.raises(ValueError) as excinfo:
             render_perline("test")
         assert "non-negative" in str(excinfo.value)

-    def test_instant_render(self, mock_config, mock_console, monkeypatch):
+    def test_instant_render(self, mock_config, mock_console, monkeypatch) -> None:
         """Test immediate rendering when time_sep is 0"""
-        with patch("rich.get_console", return_value=mock_console), \
-             patch("time.sleep") as mock_sleep:
-
+        with patch("rich.get_console", return_value=mock_console), patch("time.sleep") as mock_sleep:
             monkeypatch.setattr("torchmeter.display.__cfg__.render_interval", 0)
-
+
             render_perline("test_content")
-
+
             # Verify console.print called once
             mock_console.print.assert_called_once_with("test_content")
             # and no sleep calls
             mock_sleep.assert_not_called()

-    def test_render_line_by_line(self, mock_config, mock_console, monkeypatch):
+    def test_render_line_by_line(self, mock_config, mock_console, monkeypatch) -> None:
         """Test line-by-line rendering with time interval"""
-        with patch("rich.get_console", return_value=mock_console), \
-             patch("time.sleep") as mock_sleep:
-
+        with patch("rich.get_console", return_value=mock_console), patch("time.sleep") as mock_sleep:
             monkeypatch.setattr("torchmeter.display.__cfg__.render_interval", 0.1)
-
+
             render_perline("multi\nline\ncontent")
-
             # Verify render_lines called
-            mock_console.render_lines.assert_called_once_with(
-                "multi\nline\ncontent",
-                new_lines=True
-            )
-
+            mock_console.render_lines.assert_called_once_with("multi\nline\ncontent", new_lines=True)
+
             # Verify buffer operations
             assert mock_console._buffer_index == 0
             mock_console._buffer.extend.assert_has_calls([
-                call([Segment("line1")]), # define in mock_console
-                call([Segment("line2")])
+                call([Segment("line1")]),  # define in mock_console
+                call([Segment("line2")]),
             ])
-
+
             # Verify sleep calls between lines
             mock_sleep.assert_has_calls([call(0.1), call(0.1)])
@@ -477,31 +525,29 @@ def test_render_line_by_line(self, mock_config, mock_console, monkeypatch):
             (EXAMPLE_TREE, 7),
             (Table(), 3),
             (EXAMPLE_TABLE, 7),
-        ]
+        ],
     )
-    def test_various_content(self, content, render_lines_num,
-                             mock_console, monkeypatch):
+    def test_various_content(self, content, render_lines_num, mock_console, monkeypatch) -> None:
         """Test handling empty renderable content"""
-
-
+
         with patch("rich.get_console", return_value=mock_console), \
-             patch("time.sleep") as mock_sleep:
-
+             patch("time.sleep") as mock_sleep:  # fmt: skip
             # no time interval
             monkeypatch.setattr("torchmeter.display.__cfg__.render_interval", 0)
             render_perline(content)
             mock_console.render_lines.assert_not_called()
-
+
             # with time interval
             monkeypatch.setattr("torchmeter.display.__cfg__.render_interval", 0.15)
             render_perline(content)
             mock_sleep.call_count == render_lines_num

+
 class TestTreeRenderer:
-    def teardown_method(self, method):
+    def teardown_method(self, method) -> None:
         __cfg__.restore()
-
-    def test_valid_init(self, simple_tree_renderer):
+
+    def test_valid_init(self, simple_tree_renderer) -> None:
         """Test valid initialization"""
         assert isinstance(simple_tree_renderer.opnode, OperationNode)
         assert simple_tree_renderer.render_unfold_tree is None
@@ -509,29 +555,37 @@ def test_valid_init(self, simple_tree_renderer):
         assert isinstance(simple_tree_renderer.loop_algebras, str)
         assert len(simple_tree_renderer.loop_algebras) >= 10

-    def test_invalid_init(self):
+    def test_invalid_init(self) -> None:
         """Test invalid initialization"""
         with pytest.raises(TypeError):
             TreeRenderer(1)

-    def test_default_level_args(self, simple_tree_renderer):
+    def test_default_level_args(self, simple_tree_renderer) -> None:
         """Test if default_level_args is set and retrieved correctly"""
-
+
         # retrieve
-        ## when default settings is defined
+        # when default settings is defined
         default_args = simple_tree_renderer.default_level_args
         assert isinstance(default_args, FlagNameSpace)
         assert default_args.is_change()  # newly created, mark as changed
         assert hasattr(default_args, "label")
-
-        ## when default settings is not defined
+
+        # when default settings is not defined
         delattr(simple_tree_renderer.tree_levels_args, "default")
         default_args = simple_tree_renderer.default_level_args
         assert isinstance(default_args, FlagNameSpace)
         assert default_args.is_change()
-        assert all(hasattr(default_args, f)
-                   for f in ['label', 'style', 'guide_style', # define in display.py::TreeRenderer::default_level_args
-                             'highlight', 'hide_root', 'expanded'])
+        assert all(
+            hasattr(default_args, f)
+            for f in [
+                "label",
+                "style",
+                "guide_style",  # define in display.py::TreeRenderer::default_level_args
+                "highlight",
+                "hide_root",
+                "expanded",
+            ]
+        )

         # new a single field
         default_args.mark_unchange()
@@ -546,38 +600,44 @@ def test_default_level_args(self, simple_tree_renderer):
         assert default_args.style == "cyan"

         # overwrite
-        ## with invalid type
+        # with invalid type
         with pytest.raises(TypeError):
             simple_tree_renderer.default_level_args = 1
-
-        ## with invalid field
+
+        # with invalid field
         with pytest.raises(KeyError):
-            simple_tree_renderer.default_level_args = {'invalid_field': 'value'}
-
-        ## with combination of invalid field and valid field
+            simple_tree_renderer.default_level_args = {"invalid_field": "value"}
+
+        # with combination of invalid field and valid field
         with pytest.raises(KeyError):
-            simple_tree_renderer.default_level_args = {'invalid_field': 'value',
-                                                       'label': "test"}
-
-        ## update nothing
+            simple_tree_renderer.default_level_args = {"invalid_field": "value", "label": "test"}
+
+        # update nothing
         default_args.mark_unchange()
         simple_tree_renderer.default_level_args = {}
         assert default_args.is_change()
-        assert all(hasattr(default_args, f)
-                   for f in ['label', 'style', 'guide_style', # define in display.py::TreeRenderer::default_level_args
-                             'highlight', 'hide_root', 'expanded'])
+        assert all(
+            hasattr(default_args, f)
+            for f in [
+                "label",
+                "style",
+                "guide_style",  # define in display.py::TreeRenderer::default_level_args
+                "highlight",
+                "hide_root",
+                "expanded",
+            ]
+        )

-        ## with parts of valid fields
+        # with parts of valid fields
         default_args.mark_unchange()
-        simple_tree_renderer.default_level_args = {'label': "test",
-                                                   'style': "magenta"}
+        simple_tree_renderer.default_level_args = {"label": "test", "style": "magenta"}
         assert default_args.is_change()
         assert default_args.label == "test"
         assert default_args.style == "magenta"
-
-    def test_tree_levels_args(self, simple_tree_renderer):
+
+    def test_tree_levels_args(self, simple_tree_renderer) -> None:
         """Test if tree_levels_args is set and retrieved correctly"""
-
+
         # retrieve
         levels_args = simple_tree_renderer.tree_levels_args
         assert isinstance(levels_args, FlagNameSpace)
@@ -600,74 +660,72 @@ def test_tree_levels_args(self, simple_tree_renderer):
         assert level_0_settings.label == "level zero"

         # overwrite
-        ## with invalid type
+        # with invalid type
         with pytest.raises(TypeError):
             simple_tree_renderer.tree_levels_args = 1
-
-        ## with invalid field
+
+        # with invalid field
         with pytest.raises(KeyError):
-            simple_tree_renderer.tree_levels_args = {'default':{'invalid_field': 'value'}}
-
-        ## with combination of invalid field and valid field
+            simple_tree_renderer.tree_levels_args = {"default": {"invalid_field": "value"}}
+
+        # with combination of invalid field and valid field
         with pytest.raises(KeyError):
-            simple_tree_renderer.tree_levels_args = {"default":{'invalid_field': 'value',
-                                                                'label': "test"}}
-
-        ## with invalid level
+            simple_tree_renderer.tree_levels_args = {"default": {"invalid_field": "value", "label": "test"}}
+
+        # with invalid level
         levels_args.mark_unchange()
         with pytest.warns(UserWarning):
-            simple_tree_renderer.tree_levels_args = {"invalid_level": {'label': "test"}}
+            simple_tree_renderer.tree_levels_args = {"invalid_level": {"label": "test"}}
         assert levels_args.is_change()
         assert not hasattr(levels_args, "invalid_level")
-
-        ## assign `default` settings
+        # assign `default` settings
         levels_args.mark_unchange()
-        simple_tree_renderer.tree_levels_args = {"default": {'label': "test"}}
+        simple_tree_renderer.tree_levels_args = {"default": {"label": "test"}}
         assert levels_args.default.label == "test"
         assert levels_args.is_change()

-        ## assign `all` settings
+        # assign `all` settings
         levels_args.mark_unchange()
-        simple_tree_renderer.tree_levels_args = {"1": {'label': "test"}}
-        simple_tree_renderer.tree_levels_args = {"all": {'label': "all label"}}
+        simple_tree_renderer.tree_levels_args = {"1": {"label": "test"}}
+        simple_tree_renderer.tree_levels_args = {"all": {"label": "all label"}}
         assert levels_args.is_change()
         assert levels_args.default.label == "all label"
         assert not hasattr(levels_args, "0")
         assert not hasattr(levels_args, "1")
         assert not hasattr(levels_args, "all")
-
-        ## update nothing
+
+        # update nothing
         levels_args.mark_unchange()
         simple_tree_renderer.tree_levels_args = {}
         assert levels_args.is_change()
         assert hasattr(levels_args, "default")
-
-        ## update and new level settings
+
+        # update and new level settings
         levels_args.mark_unchange()
-        simple_tree_renderer.tree_levels_args = {"default":{"guide_style": "blue"},
-                                                 "3": {"label": "label 3",
-                                                       "style": "green",
-                                                       "guide_style": "red"}}
+        simple_tree_renderer.tree_levels_args = {
+            "default": {"guide_style": "blue"},
+            "3": {"label": "label 3", "style": "green", "guide_style": "red"},
+        }
         assert levels_args.is_change()
-        assert hasattr(simple_tree_renderer.default_level_args, "label") # other fields will not be deleted
+        assert hasattr(simple_tree_renderer.default_level_args, "label")  # other fields will not be deleted
         assert hasattr(levels_args, "3")
         assert simple_tree_renderer.default_level_args.guide_style == "blue"
-
+
         level_3_settings = getattr(levels_args, "3")
         assert level_3_settings.label == "label 3"
         assert level_3_settings.style == "green"
         assert level_3_settings.guide_style == "red"
-
-        ## verify level case insensitive
+
+        # verify level case insensitive
         levels_args.mark_unchange()
         simple_tree_renderer.tree_levels_args = {"DeFaulT": {"highlight": False}}
         assert levels_args.is_change()
         assert simple_tree_renderer.default_level_args.highlight is False

-    def test_repeat_block_args(self, simple_tree_renderer):
+    def test_repeat_block_args(self, simple_tree_renderer) -> None:
         """Test if repeat_block_args is set and retrieved correctly"""
-
+
         # retrieve
         rpbk_args = simple_tree_renderer.repeat_block_args
         assert isinstance(rpbk_args, FlagNameSpace)
@@ -686,50 +744,53 @@ def test_repeat_block_args(self, simple_tree_renderer):
         assert rpbk_args.title_align == "left"

         # overwrite
-        ## with invalid type
+        # with invalid type
         with pytest.raises(TypeError):
             simple_tree_renderer.repeat_block_args = 1
-
-        ## with invalid field
+
+        # with invalid field
         with pytest.raises(KeyError):
-            simple_tree_renderer.repeat_block_args = {'invalid_field': 'value'}
-
-        ## with combination of invalid field and valid field
+            simple_tree_renderer.repeat_block_args = {"invalid_field": "value"}
+
+        # with combination of invalid field and valid field
         with pytest.raises(KeyError):
-            simple_tree_renderer.repeat_block_args = {'invalid_field': 'value',
-                                                      'border_style': 'yellow'}
-
-        ## update nothing
+            simple_tree_renderer.repeat_block_args = {"invalid_field": "value", "border_style": "yellow"}
+
+        # update nothing
         rpbk_args.mark_unchange()
         simple_tree_renderer.repeat_block_args = {}
         assert rpbk_args.is_change()
         assert hasattr(rpbk_args, "title")
-
-        ## update several settings without repeat_footer
+
+        # update several settings without repeat_footer
         rpbk_args.mark_unchange()
-        simple_tree_renderer.repeat_block_args = {"subtitle": "this is a subtitle",
-                                                  "subtitle_align": "left",
-                                                  "style": "cyan"}
+        simple_tree_renderer.repeat_block_args = {
+            "subtitle": "this is a subtitle",
+            "subtitle_align": "left",
+            "style": "cyan",
+        }
         assert rpbk_args.is_change()
         assert rpbk_args.subtitle == "this is a subtitle"
         assert rpbk_args.subtitle_align == "left"
         assert rpbk_args.style == "cyan"
-
-        ## update several settings with repeat_footer
+
+        # update several settings with repeat_footer
        rpbk_args.mark_unchange()
-        simple_tree_renderer.repeat_block_args = {"style": "red",
-                                                  "repeat_footer": lambda :"Footer"}
+        simple_tree_renderer.repeat_block_args = {
+            "style": "red",
+            "repeat_footer": lambda: "Footer",
+        }
         assert rpbk_args.is_change()
         assert rpbk_args.style == "red"
         assert not hasattr(rpbk_args, "repeat_footer")
         assert simple_tree_renderer.repeat_footer == "Footer"

-    def test_repeat_footer(self, simple_tree_renderer):
+    def test_repeat_footer(self, simple_tree_renderer) -> None:
         """Test if repeat_footer is set and retrieved correctly"""
-
+
         from inspect import signature
-
-        ## retrieve
+
+        # retrieve
         repeat_footer = simple_tree_renderer.repeat_footer
         rpbk_args = simple_tree_renderer.repeat_block_args
         assert rpbk_args.is_change()
@@ -737,59 +798,59 @@ def test_repeat_footer(self, simple_tree_renderer):
         if callable(repeat_footer):
             args_num = len(signature(repeat_footer).parameters)
             assert args_num <= 1
-
+
             if not args_num:
                 res = repeat_footer()
                 assert isinstance(res, (type(None), str))
-
-        ## set with None
+
+        # set with None
         rpbk_args.mark_unchange()
         simple_tree_renderer.repeat_footer = None
         assert simple_tree_renderer.repeat_footer is None
         assert rpbk_args.is_change()
-
-        ## set with str
+
+        # set with str
         rpbk_args.mark_unchange()
         simple_tree_renderer.repeat_footer = "Footer"
         assert simple_tree_renderer.repeat_footer == "Footer"
         assert rpbk_args.is_change()

-        ## set with no arg function
+        # set with no arg function
         rpbk_args.mark_unchange()
-        simple_tree_renderer.repeat_footer = lambda :"footer"
+        simple_tree_renderer.repeat_footer = lambda: "footer"
         assert simple_tree_renderer.repeat_footer == "footer"
         assert rpbk_args.is_change()
-
+
         with pytest.raises(RuntimeError):
-            simple_tree_renderer.repeat_footer = lambda :2
-
-        ## set with one arg function
+            simple_tree_renderer.repeat_footer = lambda: 2
+
+        # set with one arg function
         rpbk_args.mark_unchange()
         simple_tree_renderer.repeat_footer = lambda x: f"Footer {x}"
         assert simple_tree_renderer.repeat_footer("test") == "Footer test"
         assert rpbk_args.is_change()
-
-        ## set with more than one args function
+
+        # set with more than one args function
         with pytest.raises(RuntimeError):
             simple_tree_renderer.repeat_footer = lambda x, y: f"Footer {x} {y}"
-
-        ## set with invalid input
+
+        # set with invalid input
         with pytest.raises(RuntimeError):
             simple_tree_renderer.repeat_footer = 33

-    def test_default_footer(self, monkeypatch):
+    def test_default_footer(self, monkeypatch) -> None:
         """Test the default_rpft method logic"""
         from random import sample

         monkeypatch.setattr("torchmeter.display.TreeRenderer.loop_algebras", "xy")

         class RepeatWinszModel(nn.Module):
-            def __init__(self, repeat_winsz=1, repeat_time=3):
+            def __init__(self, repeat_winsz=1, repeat_time=3) -> None:
                 super(RepeatWinszModel, self).__init__()
-
+
                 candidate_layers = (
-                    nn.Linear(1,10),
-                    nn.Conv2d(3,10,1),
+                    nn.Linear(1, 10),
+                    nn.Conv2d(3, 10, 1),
                     nn.MaxPool2d(3),
                     nn.AvgPool2d(3),
                     nn.BatchNorm2d(10),
@@ -798,9 +859,8 @@ def __init__(self, repeat_winsz=1, repeat_time=3):
                     nn.Identity(),
                 )

-                self.layer = nn.ModuleList(repeat_time *
-                                           sample(candidate_layers, repeat_winsz))
-
+                self.layer = nn.ModuleList(repeat_time * sample(candidate_layers, repeat_winsz))
+
         model = RepeatWinszModel(repeat_winsz=1, repeat_time=3)
         optree = OperationTree(model)
         tree_renderer = TreeRenderer(optree.root)
@@ -809,51 +869,51 @@ def __init__(self, repeat_winsz=1, repeat_time=3):
         footer = res.children[0].children[0].label.renderable.renderables[-1]
         footer_str = Text.from_markup(footer).plain
         assert footer_str == "Where x ∈ [1, 3]"
-
+
         model = RepeatWinszModel(repeat_winsz=2, repeat_time=3)
         optree = OperationTree(model)
         tree_renderer = TreeRenderer(optree.root)
-
+
         res = tree_renderer()
         footer = res.children[0].children[0].label.renderable.renderables[-1]
         footer_str = Text.from_markup(footer).plain
-        assert footer_str == "Where x = 1, 3, 5" 
+        assert footer_str == "Where x = 1, 3, 5"

-    def test_resolve_attr(self, simple_tree_renderer):
+    def test_resolve_attr(self, simple_tree_renderer) -> None:
         """Test whether the resolve_attr method works well"""
         simple_tree_renderer.resolve_attr = lambda x: str(x)
         mock_node = Mock(node_id="1.2")
         result = simple_tree_renderer._TreeRenderer__resolve_argtext(
-            text="Node <node_id>", 
-            attr_owner=mock_node
+            text="Node <node_id>",
+            attr_owner=mock_node,
         )
         assert result == "Node 1.2"
-
+
         simple_tree_renderer.resolve_attr = lambda x: f"Custom_{x}"
         mock_node = Mock(node_id="1.2")
         result = simple_tree_renderer._TreeRenderer__resolve_argtext(
-            text="Node <node_id>", 
-            attr_owner=mock_node
+            text="Node <node_id>",
+            attr_owner=mock_node,
         )
         assert result == "Node Custom_1.2"
-
+
         simple_tree_renderer.resolve_attr = lambda x: 12345
         mock_node = Mock(node_id="1.2")
         result = simple_tree_renderer._TreeRenderer__resolve_argtext(
-            text="Node <node_id>", 
-            attr_owner=mock_node
+            text="Node <node_id>",
+            attr_owner=mock_node,
         )
         assert result == "Node 12345"

-    def test_resolve_argtext(self, simple_tree_renderer):
+    def test_resolve_argtext(self, simple_tree_renderer) -> None:
         """Test whether argtext is resolved correctly"""
-
+
         simple_tree_renderer.resolve_attr = lambda x: str(x)

         # resolve placeholders
         mock_node = Mock(node_id="1.2")
         result = simple_tree_renderer._TreeRenderer__resolve_argtext(
-            text="Node <node_id>", 
-            attr_owner=mock_node
+            text="Node <node_id>",
+            attr_owner=mock_node,
         )
         assert result == "Node 1.2"
@@ -862,66 +922,74 @@ def test_resolve_argtext(self, simple_tree_renderer):
         mock_node = Mock(node_name="test", node_id="1")
         result = simple_tree_renderer._TreeRenderer__resolve_argtext(
             text=text,
-            attr_owner=mock_node
+            attr_owner=mock_node,
         )
         assert result == " test"
-
+
         # resolve with extra args
         mock_node = Mock(node_id="1.2")
         result = simple_tree_renderer._TreeRenderer__resolve_argtext(
-            text="Node <node_id> <node_name>", 
+            text="Node <node_id> <node_name>",
             attr_owner=mock_node,
-            node_name="test"
+            node_name="test",
         )
         assert result == "Node 1.2 test"
-
-    def test_loop_algebra_rotation(self):
+
+    def test_loop_algebra_rotation(self) -> None:
         """Test algebraic symbol cyclic rotation logic"""
-
+
         # no use algebras
         optree = OperationTree(nn.Identity())
         tree_renderer = TreeRenderer(optree.root)
-        tree_renderer.loop_algebras = 'ab'
+        tree_renderer.loop_algebras = "ab"
         tree_renderer()
-        assert tree_renderer.loop_algebras == 'ab'
-
+        assert tree_renderer.loop_algebras == "ab"
+
         # use one
-        optree = OperationTree(nn.Sequential(nn.Identity(),
-                                             nn.Identity(),
-                                             nn.Identity()))
+        optree = OperationTree(
+            nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()),
+        )
         tree_renderer = TreeRenderer(optree.root)
-        tree_renderer.loop_algebras = 'ab'
+        tree_renderer.loop_algebras = "ab"
         tree_renderer()
-        assert tree_renderer.loop_algebras == 'ba'
-
+        assert tree_renderer.loop_algebras == "ba"
+
         # use twice
-        optree = OperationTree(nn.Sequential(nn.Identity(),
-                                             nn.Identity(),
-                                             nn.Identity(),
-                                             nn.ReLU(),
-                                             nn.Identity(),
-                                             nn.Identity()))
+        optree = OperationTree(
+            nn.Sequential(
+                nn.Identity(),
+                nn.Identity(),
+                nn.Identity(),
+                nn.ReLU(),
+                nn.Identity(),
+                nn.Identity(),
+            )
+        )
         tree_renderer = TreeRenderer(optree.root)
-        tree_renderer.loop_algebras = 'ab'
+        tree_renderer.loop_algebras = "ab"
         tree_renderer()
-        assert tree_renderer.loop_algebras == 'ab'
-
+        assert tree_renderer.loop_algebras == "ab"
+
         # use over preset len
-        optree = OperationTree(nn.Sequential(nn.Identity(),
-                                             nn.Identity(),
-                                             nn.Identity(),
-                                             nn.ReLU(),
-                                             nn.Identity(),
-                                             nn.Identity(),
-                                             nn.BatchNorm1d(10),
-                                             nn.Identity(),
-                                             nn.Identity()))
+        optree = OperationTree(
+            nn.Sequential(
+                nn.Identity(),
+                nn.Identity(),
+                nn.Identity(),
+                nn.ReLU(),
+                nn.Identity(),
+                nn.Identity(),
+                nn.BatchNorm1d(10),
+                nn.Identity(),
+                nn.Identity(),
+            )
+        )
         tree_renderer = TreeRenderer(optree.root)
-        tree_renderer.loop_algebras = 'ab'
+        tree_renderer.loop_algebras = "ab"
         tree_renderer()
-        assert tree_renderer.loop_algebras == 'ba'
-
-    def test_fold_repeat(self, repeat_tree_renderer, monkeypatch):
+        assert tree_renderer.loop_algebras == "ba"
+
+    def test_fold_repeat(self, repeat_tree_renderer, monkeypatch) -> None:
         """Test whether the fold_repeat option works well"""
         monkeypatch.setattr("torchmeter.display.__cfg__.tree_fold_repeat", True)
@@ -938,12 +1006,12 @@ def test_fold_repeat(self, repeat_tree_renderer, monkeypatch):
         assert isinstance(res.children[1].children[0].label, str)
         assert isinstance(res.children[1].children[0], Tree)

-    def test_isolated_rendering(self, repeat_tree_renderer, monkeypatch):
+    def test_isolated_rendering(self, repeat_tree_renderer, monkeypatch) -> None:
         """Test whether the rendering is performed in a deepcopy tree"""
-
+
         monkeypatch.setattr("torchmeter.display.__cfg__.tree_fold_repeat", True)
         origin_tree = repeat_tree_renderer.opnode.display_root
-
+
         res = repeat_tree_renderer()

         # not pollute the original display tree
@@ -955,15 +1023,15 @@ def test_isolated_rendering(self, repeat_tree_renderer, monkeypatch):
         # not pollute the original operation tree
         oproot = repeat_tree_renderer.opnode
-        assert all(c.node_id == f"2.{c_idx+1}"
-                   for c_idx, c in enumerate(oproot.childs["2"].childs.values()))
+        assert all(c.node_id == f"2.{c_idx + 1}"
+                   for c_idx, c in enumerate(oproot.childs["2"].childs.values()))  # fmt: skip

-    def test_node_id_generation(self, repeat_tree_renderer, monkeypatch):
+    def test_node_id_generation(self, repeat_tree_renderer, monkeypatch) -> None:
         """Test the generation logic of tree label"""
         repeat_tree_renderer.loop_algebras = "xx"
         repeat_tree_renderer.repeat_footer = None
-        repeat_tree_renderer.tree_levels_args = {"all":{"label": "<node_id>"}}
+        repeat_tree_renderer.tree_levels_args = {"all": {"label": "<node_id>"}}

         # fold_repeat = True
         monkeypatch.setattr("torchmeter.display.__cfg__.tree_fold_repeat", True)
@@ -977,8 +1045,8 @@ def test_node_id_generation(self, repeat_tree_renderer, monkeypatch):
         child_2_2 = repeat_block_inner_tree.children[1]

         # no child_2_3 and child_2_4, cause they are folded and skipped in rendering
-        assert len(repeat_block_inner_tree.children) == 2 
-
+        assert len(repeat_block_inner_tree.children) == 2
+
         assert all(isinstance(c, Tree) for c in [child_1, child_2, child_2_1, child_2_2])
         assert child_1.label == "1"
         assert child_2.label == "2"
@@ -995,8 +1063,7 @@ def test_node_id_generation(self, repeat_tree_renderer, monkeypatch):
         child_2_2 = child_2.children[1]
         child_2_3 = child_2.children[2]
         child_2_4 = child_2.children[3]
-        assert all(isinstance(c, Tree) for c in [child_1, child_2,
-                                                 child_2_1, child_2_2, child_2_3, child_2_4])
+        assert all(isinstance(c, Tree) for c in [child_1, child_2, child_2_1, child_2_2, child_2_3, child_2_4])
         assert child_1.label == "1"
         assert child_2.label == "2"
         assert child_2_1.label == "2.1"
@@ -1004,7 +1071,7 @@ def test_node_id_generation(self, repeat_tree_renderer, monkeypatch):
         assert child_2_3.label == "2.3"
         assert child_2_4.label == "2.4"

-    def test_skip_rendering(self, repeat_tree_renderer, monkeypatch):
+    def test_skip_rendering(self, repeat_tree_renderer, monkeypatch) -> None:
         """Test whether the skip logic when fold_repeat = True is correct"""
         monkeypatch.setattr("torchmeter.display.__cfg__.tree_fold_repeat", True)
@@ -1021,7 +1088,7 @@ def test_skip_rendering(self, repeat_tree_renderer, monkeypatch):
         res = repeat_tree_renderer()

-        ## display tree is not change
+        # display tree is not change
         assert res.label == "0"
         assert res.children[0].label == "1"
         assert res.children[1].label == "1"
@@ -1037,36 +1104,36 @@ def test_skip_rendering(self, repeat_tree_renderer, monkeypatch):
         oproot.childs["2"].childs["2.3"]._is_folded = True
         oproot.childs["2"].childs["2.4"]._is_folded = True

-        ## display tree is not change
+        # display tree is not change
         assert res.label == "0"
         assert res.children[0].label == "1"
         assert res.children[1].label == "1"
         assert all(c.label == "2" for c in res.children[1].children)

-    def test_repeat_body_generation(self, repeat_tree_renderer, monkeypatch):
+    def test_repeat_body_generation(self, repeat_tree_renderer, monkeypatch) -> None:
         """Test whether the repeat body tree is generated correctly"""
-
+
         monkeypatch.setattr("torchmeter.display.__cfg__.tree_fold_repeat", True)
         repeat_tree_renderer.repeat_footer = None
         repeat_tree_renderer.tree_levels_args = {"2": {"label": "<type>"}}

         res = repeat_tree_renderer()
         repeat_body_tree = res.children[1].children[0].label.renderable
-
+
         # repeat body tree structure
-        assert isinstance(repeat_body_tree, Tree) 
+        assert isinstance(repeat_body_tree, Tree)
         assert repeat_body_tree.hide_root is True
         assert len(repeat_body_tree.children) == 2
         assert all(isinstance(c, Tree) for c in repeat_body_tree.children)
-
+
         # repeat body tree content
         assert "Linear" in repeat_body_tree.children[0].label
         assert "ReLU" in repeat_body_tree.children[1].label

-    def test_repeat_block_rendering(self, repeat_tree_renderer, monkeypatch):
+    def test_repeat_block_rendering(self, repeat_tree_renderer, monkeypatch) -> None:
         """Test whether the repeat block(panel) can be rendered correctly"""
-
+
         monkeypatch.setattr("torchmeter.display.__cfg__.tree_fold_repeat", True)
         repeat_tree_renderer.repeat_footer = "Footer"
         repeat_tree_renderer.repeat_block_args = {"title": "test title"}
@@ -1077,7 +1144,7 @@ def test_repeat_block_rendering(self, repeat_tree_renderer, monkeypatch):
         # repeat panel content
         assert isinstance(repeat_panel, Panel)
         assert repeat_panel.title == "test title"
-
+
         # repeat panel structure
         assert isinstance(repeat_panel.renderable, Group)
         assert len(repeat_panel.renderable.renderables) == 3
@@ -1088,102 +1155,94 @@ def test_repeat_block_rendering(self, repeat_tree_renderer, monkeypatch):
         assert isinstance(footer, str)
         assert "Footer" in footer

-    def test_style_application(self, repeat_tree_renderer, monkeypatch):
+    def test_style_application(self, repeat_tree_renderer, monkeypatch) -> None:
         """Test the levels styles and repeat block styles are applied correctly"""
-
+
         monkeypatch.setattr("torchmeter.display.__cfg__.tree_fold_repeat", True)
         repeat_tree_renderer.repeat_footer = None
-
+
         repeat_tree_renderer.tree_levels_args = {
-            "default": {"label": "[<node_id>] <name>-<type>",
-                        "guide_style": "red"},
-
-            "0": {"label": "<name>",
-                  "style": "magenta"},
-
-            "1": {"label": "<node_id>",
-                  "guide_style": "blue"},
-        }
-
-        repeat_tree_renderer.repeat_block_args = {
-            "title": "test title",
-            "style": "cyan"
+            "default": {"label": "[<node_id>] <name>-<type>", "guide_style": "red"},
+            "0": {"label": "<name>", "style": "magenta"},
+            "1": {"label": "<node_id>", "guide_style": "blue"},
         }
-
+
+        repeat_tree_renderer.repeat_block_args = {"title": "test title", "style": "cyan"}
+
         # level 0
         res = repeat_tree_renderer()
         assert res.label == "RepeatModel"
         assert res.style == "magenta"
-
+
         # level 1
         child_1, child_2 = res.children
         assert child_1.label == "1"
         assert child_2.label == "2"
         assert all(c.guide_style == "blue" for c in res.children)
-
+
         # repeat block (panel)
         repeat_panel = child_2.children[0].label
         assert repeat_panel.title == "test title"
         assert repeat_panel.style == "cyan"
-
+
         # level 2 (use default setting)
         child_2_1, child_2_2 = repeat_panel.renderable.children
         assert child_2_1.label == "[2.x] 0-Linear"
         assert child_2_2.label == "[2.(x+1)] 1-ReLU"
         assert all(c.guide_style == "red" for c in [child_2_1, child_2_2])

-    def test_edge_cases(self):
+    def test_edge_cases(self) -> None:
         """Test the edge cases in rendering"""
-
+
         class EdgeModel(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super(EdgeModel, self).__init__()
-
+
         optree = OperationTree(EdgeModel())
         oproot = optree.root
         tree_renderer = TreeRenderer(oproot)
-
+
         # no child nodes
         res = tree_renderer()
         assert not res.children
-
+
         # repeat_time is modified to an invalid value
         oproot.repeat_time = 0
         with pytest.raises(RuntimeError):
             tree_renderer()
-
+
         # repeat_winsz is modified to an invalid value
         oproot.repeat_winsz = 0
         with pytest.raises(RuntimeError):
             tree_renderer()
-
-class TestTabularRenderer: 
+
+
+class TestTabularRenderer:
     tbval_getter = lambda _, row_idx, col_idx, tb: tb.columns[col_idx]._cells[row_idx]

-    def test_valid_init(self, simple_tabular_renderer):
+    def test_valid_init(self, simple_tabular_renderer) -> None:
         """Test valid initialization"""
         assert isinstance(simple_tabular_renderer.opnode, OperationNode)
-
+
         stats_data = simple_tabular_renderer._TabularRenderer__stats_data
         assert len(stats_data) == len(OperationNode.statistics)

-    def test_invalid_init(self):
+    def test_invalid_init(self) -> None:
         """Test invalid initialization"""
         with pytest.raises(TypeError):
             TabularRenderer(1)

-    def test_stats_data(self, simple_tabular_renderer):
+    def test_stats_data(self, simple_tabular_renderer) -> None:
         """Test if the stats data property is set and retrieved correctly"""
-
+
         stats_data = simple_tabular_renderer.stats_data
         assert stats_data is simple_tabular_renderer._TabularRenderer__stats_data
         assert tuple(stats_data.keys()) == OperationNode.statistics
         assert all(df.is_empty() for df in stats_data.values())

-    def test_tb_args(self, simple_tabular_renderer):
+    def test_tb_args(self, simple_tabular_renderer) -> None:
         """Test if tb_args is set and retrieved correctly"""
-
+
         # retrieve
         tb_args = simple_tabular_renderer.tb_args
         assert tb_args is __cfg__.table_display_args
@@ -1202,36 +1261,34 @@ def test_tb_args(self, simple_tabular_renderer):
         assert tb_args.highlight is False

         # overwrite
-        ## with invalid type
+        # with invalid type
         with pytest.raises(TypeError):
             simple_tabular_renderer.tb_args = 1
-
-        ## with invalid field
+
+        # with invalid field
         with pytest.raises(KeyError):
-            simple_tabular_renderer.tb_args = {'invalid_field': 'value'}
-
-        ## with combination of invalid field and valid field
+            simple_tabular_renderer.tb_args = {"invalid_field": "value"}
+
+        # with combination of invalid field and valid field
         with pytest.raises(KeyError):
-            simple_tabular_renderer.tb_args = {'invalid_field': 'value',
-                                               'border_style': 'yellow'}
-
-        ## update nothing
+            simple_tabular_renderer.tb_args = {"invalid_field": "value", "border_style": "yellow"}
+
+        # update nothing
         tb_args.mark_unchange()
         simple_tabular_renderer.tb_args = {}
         assert tb_args.is_change()
         assert hasattr(tb_args, "box")
-
-        ## update several settings
+
+        # update several settings
         tb_args.mark_unchange()
-        simple_tabular_renderer.tb_args = {"style": "red",
-                                           "expand": True}
+        simple_tabular_renderer.tb_args = {"style": "red", "expand": True}
         assert tb_args.is_change()
         assert tb_args.style == "red"
         assert tb_args.expand is True
-
-    def test_col_args(self, simple_tabular_renderer):
+
+    def test_col_args(self, simple_tabular_renderer) -> None:
         """Test if tb_args is set and retrieved correctly"""
-
+
         # retrieve
         col_args = simple_tabular_renderer.col_args
         assert col_args is __cfg__.table_column_args
@@ -1250,89 +1307,91 @@ def test_col_args(self, simple_tabular_renderer):
         assert col_args.justify == "left"

         # overwrite
-        ## with invalid type
+        # with invalid type
         with pytest.raises(TypeError):
             simple_tabular_renderer.col_args = 1
-
-        ## with invalid field
+
+        # with invalid field
         with pytest.raises(KeyError):
-            simple_tabular_renderer.col_args = {'invalid_field': 'value'}
-
-        ## with combination of invalid field and valid field
+            simple_tabular_renderer.col_args = {"invalid_field": "value"}
+
+        # with combination of invalid field and valid field
         with pytest.raises(KeyError):
-            simple_tabular_renderer.col_args = {'invalid_field': 'value',
-                                                'no_wrap': True}
-
-        ## update nothing
+            simple_tabular_renderer.col_args = {"invalid_field": "value", "no_wrap": True}
+
+        # update nothing
         col_args.mark_unchange()
         simple_tabular_renderer.col_args = {}
         assert col_args.is_change()
         assert hasattr(col_args, "style")
-
-        ## update several settings
+
+        # update several settings
         col_args.mark_unchange()
-        simple_tabular_renderer.col_args = {"style": "red",
-                                            "vertical": "top"}
+        simple_tabular_renderer.col_args = {"style": "red", "vertical": "top"}
         assert col_args.is_change()
         assert col_args.style == "red"
         assert col_args.vertical == "top"
-
-    def test_df2tb_structure(self, simple_tabular_renderer, example_df):
+
+    def test_df2tb_structure(self, simple_tabular_renderer, example_df) -> None:
         """Test the rendering table structure"""
         with patch("torchmeter.display.apply_setting", side_effect=apply_setting) as mock_apply:
             res = simple_tabular_renderer.df2tb(example_df, show_raw=False)
-
+
         assert isinstance(res, Table)
         assert len(res.columns) == 5
         assert res.row_count == 3

         tb_headers = [col_obj.header for col_obj in res.columns]
-        assert tb_headers == ["numeric" , "text" , "list_col" , "nomal_obj" , "self_obj"]
+        assert tb_headers == ["numeric", "text", "list_col", "nomal_obj", "self_obj"]
+
+        assert str(example_df[0, 0]) == self.tbval_getter(0, 0, res)
+        assert str(example_df[2, 1]) == self.tbval_getter(2, 1, res)
+        assert str(example_df[1, 2].to_list()) == self.tbval_getter(1, 2, res)
+        assert str(example_df[0, 3]) == self.tbval_getter(0, 3, res)

-        assert str(example_df[0,0]) == self.tbval_getter(0, 0, res)
-        assert str(example_df[2,1]) == self.tbval_getter(2, 1, res)
-        assert str(example_df[1,2].to_list()) == self.tbval_getter(1, 2, res)
-        assert str(example_df[0,3]) == self.tbval_getter(0, 3, res)
-
-        # ιͺŒθ―ζ ·εΌεΊ”用调用
-        mock_apply.assert_any_call(obj=ANY,
-                                   setting=simple_tabular_renderer.tb_args,
-                                   omit="columns",
-                                   headers=example_df.columns)
-        mock_apply.assert_any_call(obj=ANY,
-                                   omit="header",
-                                   setting=simple_tabular_renderer.col_args,
-                                   highlight=simple_tabular_renderer.tb_args.highlight)
-
-    def test_df2tb_none_handling(self, simple_tabular_renderer, example_df):
+        mock_apply.assert_any_call(
+            obj=ANY,
+            setting=simple_tabular_renderer.tb_args,
+            omit="columns",
+            headers=example_df.columns,
+        )
+        mock_apply.assert_any_call(
+            obj=ANY,
+            omit="header",
+            setting=simple_tabular_renderer.col_args,
+            highlight=simple_tabular_renderer.tb_args.highlight,
+        )
+
+    def test_df2tb_none_handling(self, simple_tabular_renderer, example_df) -> None:
         """Test none replacement in rendering table"""
         res = simple_tabular_renderer.df2tb(example_df)
-
+
         # int none
         assert self.tbval_getter(0, 0, res) == "1"
         assert self.tbval_getter(2, 0, res) == "-"
-
+
         # str none
         assert self.tbval_getter(1, 1, res) == "-"

         # list none
         assert self.tbval_getter(2, 2, res) == "-"
-
+
         # normal object none
         assert self.tbval_getter(2, 3, res) == "-"

         # self object none
         assert self.tbval_getter(1, 4, res) == "test none_str"

-    def test_df2tb_show_raw(self, simple_tabular_renderer, example_df):
+    def test_df2tb_show_raw(self, simple_tabular_renderer, example_df) -> None:
         """Test whether the show_raw argument works well"""
         noraml_res = simple_tabular_renderer.df2tb(example_df, show_raw=False)
         raw_res = simple_tabular_renderer.df2tb(example_df, show_raw=True)
-
+
         # verify not raw display
-        assert self.tbval_getter(0, 0, noraml_res) == "1" 
+        assert self.tbval_getter(0, 0, noraml_res) == "1"
         assert self.tbval_getter(0, 1, noraml_res) == "a"
         assert self.tbval_getter(1, 2, noraml_res) == "[3]"
         assert self.tbval_getter(1, 3, noraml_res) == "dataframe"
@@ -1340,487 +1399,553 @@ def test_df2tb_show_raw(self, simple_tabular_renderer, example_df):
         assert self.tbval_getter(2, 4, noraml_res) == "0.00 Β± 0.00"

         # verify raw display
-        assert self.tbval_getter(0, 0, raw_res) == "1" 
+        assert self.tbval_getter(0, 0, raw_res) == "1"
         assert self.tbval_getter(0, 1, raw_res) == "a"
         assert self.tbval_getter(1, 2, raw_res) == "[3]"
         assert self.tbval_getter(1, 3, raw_res) == "dataframe"
         assert self.tbval_getter(0, 4, raw_res) == "100000.0"
         assert self.tbval_getter(2, 4, raw_res) == "0.0"
-
-    def test_clear(self, simple_tabular_renderer, example_df, monkeypatch):
+
+    def test_clear(self, simple_tabular_renderer, example_df, monkeypatch) -> None:
         """Test the stat dataframe clearing logic"""
-        monkeypatch.setattr(simple_tabular_renderer,
-                            "_TabularRenderer__stats_data",
-                            {"param": example_df,
-                             "cal": example_df,
-                             "mem": example_df,
-                             "ittp": example_df}
+        monkeypatch.setattr(
+            simple_tabular_renderer,
+            "_TabularRenderer__stats_data",
+            {
+                "param": example_df,
+                "cal": example_df,
+                "mem": example_df,
+                "ittp": example_df,
+            },
         )

         # clear one stat
-        assert not simple_tabular_renderer.stats_data['param'].is_empty()
+        assert not simple_tabular_renderer.stats_data["param"].is_empty()
         simple_tabular_renderer.clear("param")
-        assert simple_tabular_renderer.stats_data['param'].is_empty()
+        assert simple_tabular_renderer.stats_data["param"].is_empty()

         # clear all data
-        assert not simple_tabular_renderer.stats_data['cal'].is_empty()
-        assert not simple_tabular_renderer.stats_data['mem'].is_empty()
-        assert not simple_tabular_renderer.stats_data['ittp'].is_empty()
+        assert not simple_tabular_renderer.stats_data["cal"].is_empty()
+        assert not simple_tabular_renderer.stats_data["mem"].is_empty()
+        assert not simple_tabular_renderer.stats_data["ittp"].is_empty()
         simple_tabular_renderer.clear()
         assert all(df.is_empty() for df in simple_tabular_renderer.stats_data.values())

         # clear invalid stat
         with pytest.raises(ValueError):
             simple_tabular_renderer.clear("invalid_stat")
-
+
         # invalid type of pass-in stat_name
         with pytest.raises(TypeError):
             simple_tabular_renderer.clear(1)

-    def test_export(self, simple_tabular_renderer,
-                    example_df, export_dir):
+    def test_export(self, simple_tabular_renderer, example_df, export_dir) -> None:
         """Test
whether the export method works well""" - + from polars import read_csv - - # format is not specified + + # extension is not specified with pytest.raises(ValueError): - simple_tabular_renderer.export(df=example_df, - save_path=export_dir) - - # format is unsupported + simple_tabular_renderer.export(df=example_df, save_path=export_dir) + + # extension is unsupported with pytest.raises(ValueError): - simple_tabular_renderer.export(df=example_df, - save_path=export_dir, - format="png") - - # without format specified - ## file path specified + simple_tabular_renderer.export( + df=example_df, + save_path=export_dir, + ext="png", + ) + + # without extension specified + # file path specified expected_file = os.path.join(export_dir, "Data.csv") assert not os.path.exists(expected_file) - simple_tabular_renderer.export(df=example_df, - save_path=expected_file) + simple_tabular_renderer.export( + df=example_df, + save_path=expected_file, + ) assert os.path.exists(expected_file) os.remove(expected_file) - - # with format specified - ## dir path specified - ## format without dot + + # with extension specified + # dir path specified + # extension without dot expected_file = os.path.join(export_dir, "Identity.xlsx") assert not os.path.exists(expected_file) - simple_tabular_renderer.export(df=example_df, - save_path=export_dir, - format="xlsx") + simple_tabular_renderer.export( + df=example_df, + save_path=export_dir, + ext="xlsx", + ) assert os.path.exists(expected_file) os.remove(expected_file) - + # without file suffix - ## format with dot + # extension with dot expected_file = os.path.join(export_dir, "Identity.csv") assert not os.path.exists(expected_file) - simple_tabular_renderer.export(df=example_df, - save_path=export_dir, - format=".csv") + simple_tabular_renderer.export( + df=example_df, + save_path=export_dir, + ext=".csv", + ) assert os.path.exists(expected_file) os.remove(expected_file) - + # with file suffix expected_file = os.path.join(export_dir, 
"Identity_suffix.csv") assert not os.path.exists(expected_file) - simple_tabular_renderer.export(df=example_df, - save_path=export_dir, - format=".csv", - file_suffix="suffix") + simple_tabular_renderer.export( + df=example_df, + save_path=export_dir, + ext=".csv", + file_suffix="suffix", + ) assert os.path.exists(expected_file) os.remove(expected_file) - + # enable raw_data expected_normal_file = os.path.join(export_dir, "Normal.csv") assert not os.path.exists(expected_normal_file) - simple_tabular_renderer.export(df=example_df, - save_path=expected_normal_file) + simple_tabular_renderer.export( + df=example_df, + save_path=expected_normal_file, + ) assert os.path.exists(expected_normal_file) - + expected_raw_file = os.path.join(export_dir, "Raw.csv") assert not os.path.exists(expected_raw_file) - simple_tabular_renderer.export(df=example_df, - save_path=expected_raw_file) + simple_tabular_renderer.export( + df=example_df, + save_path=expected_raw_file, + ) assert os.path.exists(expected_raw_file) - + normal_data = read_csv(expected_normal_file) raw_data = read_csv(expected_raw_file) - + assert not all(normal_data["self_obj"] == raw_data["self_obj"]) - + # list data is converted to str when exporting to csv file assert normal_data["list_col"][1] == "[3]" - + # object data is converted to str assert normal_data["self_obj"][0] == "100 K" - + os.remove(expected_normal_file) os.remove(expected_raw_file) - - def test_new_col(self, simple_tabular_renderer, example_df): + + def test_new_col(self, simple_tabular_renderer, example_df) -> None: """Test whether the __new_col method works well""" - + from polars import Float64 - + new_col = simple_tabular_renderer._TabularRenderer__new_col - + # invalid column name type with pytest.raises(TypeError): - new_col(df=example_df, - col_name=1, - col_func=lambda x: "test") - + new_col( + df=example_df, + col_name=1, + col_func=lambda x: "test", + ) + # duplicated column name with pytest.raises(ValueError): - new_col(df=example_df, 
- col_name="self_obj", - col_func=lambda x: "test") - + new_col( + df=example_df, + col_name="self_obj", + col_func=lambda x: "test", + ) + # invalid new col index type with pytest.raises(TypeError): - simple_tabular_renderer(stat_name="cal", newcol_idx=1.5) - + simple_tabular_renderer(stat_name="cal", newcol_idx=1.5) + # invalid column function type with pytest.raises(TypeError): - new_col(df=example_df, - col_name="new_col", - col_func="test") - + new_col( + df=example_df, + col_name="new_col", + col_func="test", + ) + # invalid column function argument num - ## lack + # lack with pytest.raises(TypeError): - new_col(df=example_df, - col_name="new_col", - col_func=lambda :...) - - ## exceed + new_col( + df=example_df, + col_name="new_col", + col_func=lambda: ..., + ) + + # exceed with pytest.raises(TypeError): - new_col(df=example_df, - col_name="new_col", - col_func=lambda x, y:...) - + new_col( + df=example_df, + col_name="new_col", + col_func=lambda x, y: ..., + ) + # invalid column function return - ## invalid return type + # invalid return type with pytest.raises(TypeError): - new_col(df=example_df, - col_name="new_col", - col_func=lambda x: 1) - - ## invalid return len + new_col( + df=example_df, + col_name="new_col", + col_func=lambda x: 1, + ) + + # invalid return len with pytest.raises(RuntimeError): - new_col(df=example_df, - col_name="new_col", - col_func=lambda x: [1]) - + new_col( + df=example_df, + col_name="new_col", + col_func=lambda x: [1], + ) + # verify function is applied correctly - new_df = new_col(df=example_df, - col_name="new_col", - col_func=lambda x: ["test"]*len(x), - col_idx=0) + new_df = new_col( + df=example_df, + col_name="new_col", + col_func=lambda x: ["test"] * len(x), + col_idx=0, + ) assert new_df.shape == (3, 6) assert new_df.columns[0] == "new_col" - assert new_df["new_col"].to_list() == ["test"]*3 - + assert new_df["new_col"].to_list() == ["test"] * 3 + # verify function operation will not influence the original dataframe - 
new_df = new_col(df=example_df, - col_name="origin_numeric", - col_func=lambda df: df.drop_in_place(name="numeric"), - col_idx=0) + new_df = new_col( + df=example_df, + col_name="origin_numeric", + col_func=lambda df: df.drop_in_place(name="numeric"), + col_idx=0, + ) assert new_df.shape == (3, 6) assert example_df.shape == (3, 5) assert example_df.columns == ["numeric", "text", "list_col", "nomal_obj", "self_obj"] assert new_df["origin_numeric"].to_list() == example_df["numeric"].to_list() # verify col_idx - ## non-negative and in range - new_df = new_col(df=example_df, - col_name="new_col", - col_func=lambda x: ["test"]*len(x), - col_idx=1) + # non-negative and in range + new_df = new_col( + df=example_df, + col_name="new_col", + col_func=lambda x: ["test"] * len(x), + col_idx=1, + ) assert new_df.shape == (3, 6) assert new_df.columns[1] == "new_col" - - ## non-negative and out of range (add last) - new_df = new_col(df=example_df, - col_name="new_col", - col_func=lambda x: ["test"]*len(x), - col_idx=8) + + # non-negative and out of range (add last) + new_df = new_col( + df=example_df, + col_name="new_col", + col_func=lambda x: ["test"] * len(x), + col_idx=8, + ) assert new_df.shape == (3, 6) assert new_df.columns[5] == "new_col" - ## negative and in range - new_df = new_col(df=example_df, - col_name="new_col", - col_func=lambda x: ["test"]*len(x), - col_idx=-1) + # negative and in range + new_df = new_col( + df=example_df, + col_name="new_col", + col_func=lambda x: ["test"] * len(x), + col_idx=-1, + ) assert new_df.shape == (3, 6) assert new_df.columns[-1] == "new_col" - - ## negative and out of range (add first) - new_df = new_col(df=example_df, - col_name="new_col", - col_func=lambda x: ["test"]*len(x), - col_idx=-9) + + # negative and out of range (add first) + new_df = new_col( + df=example_df, + col_name="new_col", + col_func=lambda x: ["test"] * len(x), + col_idx=-9, + ) assert new_df.shape == (3, 6) assert new_df.columns[0] == "new_col" - - ## verify 
return_type is correctly applied - new_df = new_col(df=example_df, - col_name="new_col", - col_func=lambda x: [1]*len(x), - return_type=float) + + # verify return_type is correctly applied + new_df = new_col( + df=example_df, + col_name="new_col", + col_func=lambda x: [1] * len(x), + return_type=float, + ) assert new_df["new_col"].dtype == Float64 - - def test_call_valid_use(self, universal_tabular_renderer): + + def test_call_valid_use(self, universal_tabular_renderer) -> None: """Test the valid usage cases of TabularRenderer.__call__""" - + tb, data = universal_tabular_renderer(stat_name="param") assert isinstance(tb, Table) assert isinstance(data, DataFrame) - assert data.shape == (3, 6) # 1 (unuse) + 2 (conv: weight + bias) - - def test_call_invalid_use(self, simple_tabular_renderer): + assert data.shape == (3, 6) # 1 (unuse) + 2 (conv: weight + bias) + + def test_call_invalid_use(self, simple_tabular_renderer) -> None: """Test the invalid usage cases of TabularRenderer.__call__""" - + # invalid stat name with pytest.raises(ValueError): simple_tabular_renderer(stat_name="invalid stat") - + # invalid pick_col type with pytest.raises(TypeError): simple_tabular_renderer(stat_name="cal", pick_col=1) - + # invalid exclude_cols type with pytest.raises(TypeError): simple_tabular_renderer(stat_name="cal", exclude_cols=1) - + # invalid custom_cols type with pytest.raises(TypeError): simple_tabular_renderer(stat_name="cal", custom_cols=1) - + @patch("torchmeter.display.dfs_task", side_effect=dfs_task) - def test_call_data_acquisition(self, mock_dfs_task, - simple_tabular_renderer, universal_tabular_renderer): + def test_call_data_acquisition(self, mock_dfs_task, simple_tabular_renderer, universal_tabular_renderer) -> None: """Test the stat data acquisition logic""" - + class EasyModel(nn.Module): - def __init__(self): + def __init__(self) -> None: super(EasyModel, self).__init__() - self.conv = nn.Conv2d(3,10,3) + self.conv = nn.Conv2d(3, 10, 3) + def forward(self, x): 
return self.conv(x) - + # skipping when using a layer as a model simple_tabular_renderer.clear() with pytest.raises(RuntimeError): simple_tabular_renderer("param") - + # fill dataframe through dfs in the first call optree = OperationTree(EasyModel()) tabular_renderer = TabularRenderer(optree.root) tabular_renderer.clear() - tb,data1 = tabular_renderer(stat_name="param") - assert data1.shape == (2, 6) # 2: weight + bias - assert mock_dfs_task.call_count == 2 # root + conv - + _tb, data1 = tabular_renderer(stat_name="param") + assert data1.shape == (2, 6) # 2: weight + bias + assert mock_dfs_task.call_count == 2 # root + conv + stat_data = tabular_renderer.stats_data assert not stat_data["param"].is_empty() - + # reuse data in latter call - tb, data2 = tabular_renderer(stat_name="param") - assert mock_dfs_task.call_count == 2 # stay no change - + _tb, _data2 = tabular_renderer(stat_name="param") + assert mock_dfs_task.call_count == 2 # stay no change + # verify no-called module warning oproot = universal_tabular_renderer.opnode universal_tabular_renderer.clear() - list(map(lambda n:n.cal.measure(), list(oproot.childs.values()) + [oproot])) + list(map(lambda n: n.cal.measure(), [*list(oproot.childs.values()), oproot])) with pytest.warns(RuntimeWarning) as w: oproot.operation(torch_randn(1, 3, 20, 20)) - tb, data = universal_tabular_renderer(stat_name="cal") - assert "not explicitly called" in str(w[0].message) - - def test_call_pick_col(self, universal_tabular_renderer): + _tb, _data = universal_tabular_renderer(stat_name="cal") + assert "not explicitly called" in str(w[0].message) + + def test_call_pick_col(self, universal_tabular_renderer) -> None: """Test the column selection logic""" - + universal_tabular_renderer.clear() - + # valid usage - _, data1 = universal_tabular_renderer(stat_name="param", - pick_cols=["Operation_Id", - "Operation_Name", - "Param_Name", - "Numeric_Num"]) + _, data1 = universal_tabular_renderer( + stat_name="param", + pick_cols=[ + 
"Operation_Id", + "Operation_Name", + "Param_Name", + "Numeric_Num", + ], + ) assert data1.columns == ["Operation_Id", "Operation_Name", "Param_Name", "Numeric_Num"] - + # pick and reorder columns - _, data2 = universal_tabular_renderer(stat_name="param", - pick_cols=["Numeric_Num", - "Param_Name", - "Operation_Name", - "Operation_Id"]) + _, data2 = universal_tabular_renderer( + stat_name="param", + pick_cols=[ + "Numeric_Num", + "Param_Name", + "Operation_Name", + "Operation_Id", + ], + ) assert data2.columns == ["Numeric_Num", "Param_Name", "Operation_Name", "Operation_Id"] - + # invalid column name with pytest.raises(ValueError): - universal_tabular_renderer(stat_name="param", - pick_cols=["invalid_col"]) - - def test_call_exclude_col(self, universal_tabular_renderer): + universal_tabular_renderer(stat_name="param", pick_cols=["invalid_col"]) + + def test_call_exclude_col(self, universal_tabular_renderer) -> None: """Test the column exclusion logic""" - + universal_tabular_renderer.clear() - - _, data = universal_tabular_renderer("param", - exclude_cols=["Operation_Type", - "Operation_Name"]) + + _, data = universal_tabular_renderer("param", exclude_cols=["Operation_Type", "Operation_Name"]) assert "Operation_Type" not in data.columns assert "Operation_Name" not in data.columns - def test_call_custom_col(self, universal_tabular_renderer): + def test_call_custom_col(self, universal_tabular_renderer) -> None: """Test column name customization logic""" - + universal_tabular_renderer.clear() - + # basic usage - _, data = universal_tabular_renderer(stat_name="param", - keep_custom_name=False, - custom_cols={"Operation_Id": "Operation Id", - "Operation_Name": "Operation Name", - "Numeric_Num": "Numeric Value"}) - assert all(col_name not in data.columns for col_name in ["Operation_Id", - "Operation_Name", - "Numeric_Num"]) - assert all(col_name in data.columns for col_name in ["Operation Id", - "Operation Name", - "Numeric Value"]) - + _, data = 
universal_tabular_renderer( + stat_name="param", + keep_custom_name=False, + custom_cols={ + "Operation_Id": "Operation Id", + "Operation_Name": "Operation Name", + "Numeric_Num": "Numeric Value", + }, + ) + assert all(col_name not in data.columns + for col_name in ["Operation_Id", "Operation_Name", "Numeric_Num"]) # fmt: skip + assert all(col_name in data.columns + for col_name in ["Operation Id", "Operation Name", "Numeric Value"]) # fmt: skip + # invalid column name - _, data = universal_tabular_renderer(stat_name="param", - keep_custom_name=False, - custom_cols={"invalid_col": "invalid col"}) + _, data = universal_tabular_renderer( + stat_name="param", + keep_custom_name=False, + custom_cols={"invalid_col": "invalid col"}, + ) assert "invalid_col" not in data.columns - + # verify option: whether keep customized column name - _, data = universal_tabular_renderer(stat_name="param", - keep_custom_name=True, - custom_cols={"Operation_Id": "Operation Id"}) + _, data = universal_tabular_renderer( + stat_name="param", + keep_custom_name=True, + custom_cols={"Operation_Id": "Operation Id"}, + ) assert "Operation_Id" not in data.columns assert "Operation Id" in data.columns - - ## the same column name's recustomization should base on the new name - _, data = universal_tabular_renderer(stat_name="param", - keep_custom_name=True, - custom_cols={"Operation_Id": "Operation ID"}) + + # the same column name's recustomization should base on the new name + _, data = universal_tabular_renderer( + stat_name="param", + keep_custom_name=True, + custom_cols={"Operation_Id": "Operation ID"}, + ) assert "Operation ID" not in data.columns - - _, data = universal_tabular_renderer(stat_name="param", - keep_custom_name=True, - custom_cols={"Operation Id": "Operation ID"}) + + _, data = universal_tabular_renderer( + stat_name="param", + keep_custom_name=True, + custom_cols={"Operation Id": "Operation ID"}, + ) assert "Operation ID" in data.columns - - def 
test_call_pick_exclude_cooperation(self, universal_tabular_renderer): + + def test_call_pick_exclude_cooperation(self, universal_tabular_renderer) -> None: """Test whether exclude logic and selection logic work well together""" - + universal_tabular_renderer.clear() - + # pick_col and exclude_col are mutually exclusive # then exclude_col does not take effect - _, data = universal_tabular_renderer(stat_name="param", - pick_cols=["Operation_Id", - "Param_Name", - "Numeric_Num"], - exclude_cols=["Operation_Name"]) + _, data = universal_tabular_renderer( + stat_name="param", + pick_cols=["Operation_Id", "Param_Name", "Numeric_Num"], + exclude_cols=["Operation_Name"], + ) assert data.columns == ["Operation_Id", "Param_Name", "Numeric_Num"] - + # pick_col and exclude_col have intersection # then exclude_col takes effect - _, data = universal_tabular_renderer(stat_name="param", - pick_cols=["Operation_Id", - "Param_Name", - "Numeric_Num"], - exclude_cols=["Operation_Id"]) + _, data = universal_tabular_renderer( + stat_name="param", + pick_cols=["Operation_Id", "Param_Name", "Numeric_Num"], + exclude_cols=["Operation_Id"], + ) assert data.columns == ["Param_Name", "Numeric_Num"] - def test_call_pick_custom_cooperation(self, universal_tabular_renderer): + def test_call_pick_custom_cooperation(self, universal_tabular_renderer) -> None: """Test whether custom_col logic works after selection logic""" - + universal_tabular_renderer.clear() - + # pick_col and custom_col are mutually exclusive # then custom_col does not take effect - _, data = universal_tabular_renderer(stat_name="param", - pick_cols=["Operation_Id", - "Param_Name"], - custom_cols={"Operation_Type": "Operation Type"}) + _, data = universal_tabular_renderer( + stat_name="param", + pick_cols=["Operation_Id", "Param_Name"], + custom_cols={"Operation_Type": "Operation Type"}, + ) assert data.columns == ["Operation_Id", "Param_Name"] - + # pick_col and custom_col have intersection # then use the origin name to pick 
columns and then customization takes effect - _, data = universal_tabular_renderer(stat_name="param", - pick_cols=["Operation_Id", - "Param_Name"], - custom_cols={"Operation_Id": "Operation ID"}) + _, data = universal_tabular_renderer( + stat_name="param", + pick_cols=["Operation_Id", "Param_Name"], + custom_cols={"Operation_Id": "Operation ID"}, + ) assert data.columns == ["Operation ID", "Param_Name"] - + # selection does happen before customization with pytest.raises(ValueError): - universal_tabular_renderer(stat_name="param", - pick_cols=["Operation ID2"], - custom_cols={"Operation ID": "Operation ID2"}) - - def test_call_keep_new_col(self, universal_tabular_renderer): + universal_tabular_renderer( + stat_name="param", + pick_cols=["Operation ID2"], + custom_cols={"Operation ID": "Operation ID2"}, + ) + + def test_call_keep_new_col(self, universal_tabular_renderer) -> None: """Test whether the keep_new_col option works well""" - + universal_tabular_renderer.clear() - + # not to keep - _, data = universal_tabular_renderer(stat_name="param", - newcol_name="test", - newcol_func=lambda x: ["test"]*len(x), - keep_new_col=False) + _, data = universal_tabular_renderer( + stat_name="param", + newcol_name="test", + newcol_func=lambda x: ["test"] * len(x), + keep_new_col=False, + ) assert "test" in data.columns assert "test" not in universal_tabular_renderer.stats_data["param"].columns - + # keep - _, data = universal_tabular_renderer(stat_name="param", - newcol_name="test", - newcol_func=lambda x: ["test"]*len(x), - keep_new_col=True) + _, data = universal_tabular_renderer( + stat_name="param", + newcol_name="test", + newcol_func=lambda x: ["test"] * len(x), + keep_new_col=True, + ) assert "test" in data.columns assert "test" in universal_tabular_renderer.stats_data["param"].columns - @patch('torchmeter.display.TabularRenderer.export') - def test_export_trigger(self, mock_export, universal_tabular_renderer): + @patch("torchmeter.display.TabularRenderer.export") + def 
test_export_trigger(self, mock_export, universal_tabular_renderer) -> None: """Test whether the save_to argument can trigger export""" # not trigger universal_tabular_renderer(stat_name="param") mock_export.assert_not_called() - + # trigger universal_tabular_renderer(stat_name="param", save_to="test.csv") mock_export.assert_called_once() - def test_style_application(self, universal_tabular_renderer): + def test_style_application(self, universal_tabular_renderer) -> None: """Test that the table and column styles are applied correctly""" - + universal_tabular_renderer.tb_args = { "style": "red", "highlight": False, "caption": "test caption", - "show_lines": True + "show_lines": True, } - + universal_tabular_renderer.col_args = { "style": "blue", "justify": "left", - "no_wrap": True + "no_wrap": True, } tb, _ = universal_tabular_renderer(stat_name="param") @@ -1835,14 +1960,14 @@ def test_style_application(self, universal_tabular_renderer): assert col.justify == "left" assert col.no_wrap is True - def test_edge_cases(self, simple_tabular_renderer): + def test_edge_cases(self, simple_tabular_renderer) -> None: """Test the edge cases in rendering""" - + from polars import Float64 as pl_float64 # render empty dataframe empty_df = DataFrame(schema={"col1": pl_float64}) res = simple_tabular_renderer.df2tb(empty_df) - + assert res.row_count == 0 assert len(res.columns) == 1 diff --git a/tests/test_engine.py b/tests/test_engine.py index 7ae18c4..fa61df5 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -4,32 +4,31 @@ import pytest import torch.nn as nn -from torchmeter.engine import ( - OperationNode, OperationTree, - ParamsMeter, CalMeter, - MemMeter, IttpMeter -) +from torchmeter.engine import CalMeter, MemMeter, IttpMeter, ParamsMeter, OperationNode, OperationTree + @pytest.fixture def linear_model(): return nn.Linear(10, 5) + @pytest.fixture def sequential_model(): """A sequential model with a repeat structure""" - return 
nn.Sequential(OrderedDict([ - ("first_conv", nn.Conv2d(3, 6, 3)), - ("first_relu", nn.ReLU()), - ("second_conv", nn.Conv2d(3, 6, 3)), - ("second_relu", nn.ReLU()) - ])) + return nn.Sequential( + OrderedDict([ + ("first_conv", nn.Conv2d(3, 6, 3)), + ("first_relu", nn.ReLU()), + ("second_conv", nn.Conv2d(3, 6, 3)), + ("second_relu", nn.ReLU()), + ]) + ) + @pytest.fixture def nested_model(linear_model, sequential_model): - return nn.Sequential(OrderedDict([ - ("s", sequential_model), - ("l", linear_model) - ])) + return nn.Sequential(OrderedDict([("s", sequential_model), ("l", linear_model)])) + @pytest.fixture def check_scanning_process(capsys): @@ -37,24 +36,25 @@ def check_scanning_process(capsys): captured = capsys.readouterr() assert "Finish Scanning model in" in captured.out + @pytest.mark.vital class TestOPN: - def test_invalid_init(self): + def test_invalid_init(self) -> None: """Test whether non-module object cannot be initialized""" with pytest.raises(TypeError): OperationNode(module="not_a_module") - - def test_valid_init(self, linear_model): + + def test_valid_init(self, linear_model) -> None: """Test basic attributes""" - - assert OperationNode.statistics == ('param', 'cal', 'mem', 'ittp') - + + assert OperationNode.statistics == ("param", "cal", "mem", "ittp") + node = OperationNode( module=linear_model, name="TestLinear", - node_id="1.2.3" + node_id="1.2.3", ) - + assert node.operation == linear_model assert node.type == "Linear" assert node.name == "TestLinear" @@ -68,7 +68,7 @@ def test_valid_init(self, linear_model): assert node._render_when_repeat is False assert node._is_folded is False - def test_module_repr(self, linear_model, sequential_model): + def test_module_repr(self, linear_model, sequential_model) -> None: """Test the module_repr attribute""" node = OperationNode(module=linear_model) @@ -77,31 +77,31 @@ def test_module_repr(self, linear_model, sequential_model): node = OperationNode(module=sequential_model) assert node.module_repr == 
str(sequential_model.__class__.__name__) - def test_default_name(self, linear_model): + def test_default_name(self, linear_model) -> None: """Test the default name is the type name when no name is provided""" node = OperationNode(module=linear_model) assert node.name == "Linear" - def test_hierarchical_attrs(self, linear_model): + def test_hierarchical_attrs(self, linear_model) -> None: """Test parent-child relationship""" parent = OperationNode(module=nn.Module(), node_id="1") child = OperationNode( module=linear_model, parent=parent, - node_id="1.1" + node_id="1.1", ) - + parent.childs["1.1"] = child - + assert child.parent is parent assert "1.1" in parent.childs assert parent.childs["1.1"] is child - def test_is_leaf(self, linear_model, sequential_model): + def test_is_leaf(self, linear_model, sequential_model) -> None: """Test the is_leaf property is correctly set""" leaf_node = OperationNode(module=linear_model) non_leaf_node = OperationNode(module=sequential_model) - + assert leaf_node.is_leaf is True assert non_leaf_node.is_leaf is False @@ -111,47 +111,48 @@ def test_is_leaf(self, linear_model, sequential_model): ("param", ParamsMeter), ("cal", CalMeter), ("mem", MemMeter), - ("ittp", IttpMeter) - ] + ("ittp", IttpMeter), + ], ) - def test_statistic_attrs(self, linear_model, stat_name, stat_cls): + def test_statistic_attrs(self, linear_model, stat_name, stat_cls) -> None: """Test whether all the statistic attributes are created correctly and are all read-only""" node = OperationNode(module=linear_model) - + stat = getattr(node, stat_name) assert isinstance(stat, stat_cls) assert stat._opnode is node - + with pytest.raises(AttributeError): node.param = None - + with pytest.raises(AttributeError): delattr(node, stat_name) - def test_repr(self, linear_model, sequential_model): + def test_repr(self, linear_model, sequential_model) -> None: """Test repr""" leaf_node = OperationNode(module=linear_model) non_leaf_node = OperationNode(module=sequential_model) - - 
assert repr(leaf_node) == f"0 Linear: {str(linear_model)}" + + assert repr(leaf_node) == f"0 Linear: {linear_model!s}" assert repr(non_leaf_node) == "0 Sequential: Sequential" + @pytest.mark.vital @pytest.mark.usefixtures("check_scanning_process") -class TestOPT: - def test_single_layer_model(self, linear_model): +class TestOPT: + def test_single_layer_model(self, linear_model) -> None: """Test building operation tree for a single-layer model""" tree = OperationTree(linear_model) - - assert len(tree.all_nodes) == 1 - + + assert len(tree.all_nodes) == 1 + # basic attributes root = tree.root assert root.operation is linear_model assert root.type == "Linear" assert root.name == "Linear" assert root.node_id == "0" - + # hierarchical attributes assert root.parent is None assert not root.childs @@ -167,118 +168,116 @@ def test_single_layer_model(self, linear_model): assert root._render_when_repeat is True assert root._is_folded is False - def test_sequential_model(self, sequential_model): + def test_sequential_model(self, sequential_model) -> None: """Test building operation tree for a sequential model""" tree = OperationTree(sequential_model) - + assert len(tree.all_nodes) == 5 - + # basic attributes root = tree.root assert root.operation is sequential_model assert root.type == "Sequential" assert root.name == "Sequential" assert root.node_id == "0" - + # hierarchical attributes assert root.parent is None assert len(root.childs) == 4 - assert list(root.childs.keys()) == ['1', '2', '3', '4'] # it is node_id also + assert list(root.childs.keys()) == ["1", "2", "3", "4"] # it is node_id also assert all(c.parent is root for c in root.childs.values()) assert all(not c.childs for c in root.childs.values()) - - assert root.childs['1'].type == "Conv2d" - assert root.childs['2'].type == "ReLU" - assert root.childs['3'].type == "Conv2d" - assert root.childs['4'].type == "ReLU" - - assert root.childs['1'].name == "first_conv" - assert root.childs['2'].name == "first_relu" - assert 
root.childs['3'].name == "second_conv" - assert root.childs['4'].name == "second_relu" - + + assert root.childs["1"].type == "Conv2d" + assert root.childs["2"].type == "ReLU" + assert root.childs["3"].type == "Conv2d" + assert root.childs["4"].type == "ReLU" + + assert root.childs["1"].name == "first_conv" + assert root.childs["2"].name == "first_relu" + assert root.childs["3"].name == "second_conv" + assert root.childs["4"].name == "second_relu" + # repeat-related attributes assert root.repeat_winsz == 1 assert root.repeat_time == 1 assert root._repeat_body == [] - - assert root.childs['1'].repeat_winsz == 2 - assert root.childs['1'].repeat_time == 2 - assert root.childs['1']._repeat_body == [("1", "first_conv"), - ("2", "first_relu")] + + assert root.childs["1"].repeat_winsz == 2 + assert root.childs["1"].repeat_time == 2 + assert root.childs["1"]._repeat_body == [("1", "first_conv"), ("2", "first_relu")] assert all(c.repeat_winsz == 1 for c in root.childs.values() if c.node_id != "1") assert all(c.repeat_time == 1 for c in root.childs.values() if c.node_id != "1") assert all(not c._repeat_body for c in root.childs.values() if c.node_id != "1") - + # display-related attributes assert all(hasattr(n, "display_root") for n in tree.all_nodes) assert root.display_root.label == "0" assert all(c.display_root.label == "1" for c in root.childs.values()) assert root._render_when_repeat is True - assert root.childs['1']._render_when_repeat is True - assert root.childs['2']._render_when_repeat is True - assert root.childs['3']._render_when_repeat is False - assert root.childs['4']._render_when_repeat is False + assert root.childs["1"]._render_when_repeat is True + assert root.childs["2"]._render_when_repeat is True + assert root.childs["3"]._render_when_repeat is False + assert root.childs["4"]._render_when_repeat is False assert root._is_folded is False - assert root.childs['1']._is_folded is False - assert root.childs['2']._is_folded is True # True because in the 
repeat_body - assert root.childs['3']._is_folded is False # False because skip the visit - assert root.childs['4']._is_folded is False # False because skip the visit - - def test_nested_model(self, nested_model): + assert root.childs["1"]._is_folded is False + assert root.childs["2"]._is_folded is True # True because in the repeat_body + assert root.childs["3"]._is_folded is False # False because skip the visit + assert root.childs["4"]._is_folded is False # False because skip the visit + + def test_nested_model(self, nested_model) -> None: """Test building operation tree for a sequential model""" tree = OperationTree(nested_model) - + assert len(tree.all_nodes) == 7 - + # basic attributes root = tree.root assert root.operation is nested_model assert root.type == "Sequential" assert root.name == "Sequential" assert root.node_id == "0" - + # hierarchical attributes assert root.parent is None assert len(root.childs) == 2 assert sum(len(c.childs) for c in root.childs.values()) == 4 - assert list(root.childs.keys()) == ['1', '2'] # it is node_id also - - child_1 = root.childs['1'] - child_2 = root.childs['2'] - assert list(child_1.childs.keys()) == ['1.1', '1.2', '1.3', '1.4'] # it is node_id also + assert list(root.childs.keys()) == ["1", "2"] # it is node_id also + + child_1 = root.childs["1"] + child_2 = root.childs["2"] + assert list(child_1.childs.keys()) == ["1.1", "1.2", "1.3", "1.4"] # it is node_id also assert not child_2.childs assert all(c.parent is root for c in root.childs.values()) assert all(c.parent is child_1 for c in child_1.childs.values()) - - assert root.childs['1'].type == "Sequential" - assert root.childs['2'].type == "Linear" - assert child_1.childs['1.1'].type == "Conv2d" - assert child_1.childs['1.2'].type == "ReLU" - assert child_1.childs['1.3'].type == "Conv2d" - assert child_1.childs['1.4'].type == "ReLU" - - assert root.childs['1'].name == "s" - assert root.childs['2'].name == "l" - assert child_1.childs['1.1'].name == "first_conv" - 
assert child_1.childs['1.2'].name == "first_relu" - assert child_1.childs['1.3'].name == "second_conv" - assert child_1.childs['1.4'].name == "second_relu" - + + assert root.childs["1"].type == "Sequential" + assert root.childs["2"].type == "Linear" + assert child_1.childs["1.1"].type == "Conv2d" + assert child_1.childs["1.2"].type == "ReLU" + assert child_1.childs["1.3"].type == "Conv2d" + assert child_1.childs["1.4"].type == "ReLU" + + assert root.childs["1"].name == "s" + assert root.childs["2"].name == "l" + assert child_1.childs["1.1"].name == "first_conv" + assert child_1.childs["1.2"].name == "first_relu" + assert child_1.childs["1.3"].name == "second_conv" + assert child_1.childs["1.4"].name == "second_relu" + # repeat-related attributes assert root.repeat_winsz == 1 assert root.repeat_time == 1 assert root._repeat_body == [] - + assert all(c.repeat_winsz * c.repeat_time == 1 for c in root.childs.values()) assert all(not c._repeat_body for c in root.childs.values()) - assert child_1.childs['1.1'].repeat_winsz == 2 - assert child_1.childs['1.1'].repeat_time == 2 - assert child_1.childs['1.1']._repeat_body == [("1.1", "first_conv"), - ("1.2", "first_relu")] + assert child_1.childs["1.1"].repeat_winsz == 2 + assert child_1.childs["1.1"].repeat_time == 2 + assert child_1.childs["1.1"]._repeat_body == [("1.1", "first_conv"), ("1.2", "first_relu")] assert all(c.repeat_winsz * c.repeat_time == 1 for c in child_1.childs.values() if c.node_id != "1.1") assert all(not c._repeat_body for c in child_1.childs.values() if c.node_id != "1.1") - + # display-related attributes assert all(hasattr(n, "display_root") for n in tree.all_nodes) assert root.display_root.label == "0" @@ -286,61 +285,58 @@ def test_nested_model(self, nested_model): assert all(c.display_root.label == "2" for c in child_1.childs.values()) assert root._render_when_repeat is True assert all(c._render_when_repeat is True for c in root.childs.values()) - assert child_1.childs['1.1']._render_when_repeat is 
True - assert child_1.childs['1.2']._render_when_repeat is True - assert child_1.childs['1.3']._render_when_repeat is False - assert child_1.childs['1.4']._render_when_repeat is False + assert child_1.childs["1.1"]._render_when_repeat is True + assert child_1.childs["1.2"]._render_when_repeat is True + assert child_1.childs["1.3"]._render_when_repeat is False + assert child_1.childs["1.4"]._render_when_repeat is False assert root._is_folded is False assert all(c._is_folded is False for c in root.childs.values()) - assert child_1.childs['1.1']._is_folded is False - assert child_1.childs['1.2']._is_folded is True # True because in the repeat_body - assert child_1.childs['1.3']._is_folded is False # False because skip the visit - assert child_1.childs['1.4']._is_folded is False # False because skip the visit + assert child_1.childs["1.1"]._is_folded is False + assert child_1.childs["1.2"]._is_folded is True # True because in the repeat_body + assert child_1.childs["1.3"]._is_folded is False # False because skip the visit + assert child_1.childs["1.4"]._is_folded is False # False because skip the visit - def test_repeat_detection(self): + def test_repeat_detection(self) -> None: """Test repeat detection""" - model = nn.Sequential( - *[nn.Conv2d(3, 6, 3) for _ in range(4)], - nn.ReLU() - ) - + model = nn.Sequential(*[nn.Conv2d(3, 6, 3) for _ in range(4)], nn.ReLU()) + tree = OperationTree(model) root = tree.root assert root.repeat_winsz == 1 assert root.repeat_time == 1 assert root._repeat_body == [] - + assert all(c.repeat_winsz * c.repeat_time == 1 for c in root.childs.values() if c.name != "0") assert all(not c._repeat_body for c in root.childs.values() if c.name != "0") - - assert root.childs['1'].repeat_winsz == 1 # not 4 because all the layers in the repeat block are the same - assert root.childs['1'].repeat_time == 4 - assert root.childs['1']._repeat_body == [("1", "0")] - + + assert root.childs["1"].repeat_winsz == 1 # not 4 because all the layers in the repeat 
block are the same
+        assert root.childs["1"].repeat_time == 4
+        assert root.childs["1"]._repeat_body == [("1", "0")]
+
         assert root._render_when_repeat is True
-        assert root.childs['1']._render_when_repeat is True
-        assert root.childs['2']._render_when_repeat is False
-        assert root.childs['3']._render_when_repeat is False
-        assert root.childs['4']._render_when_repeat is False
-        assert root.childs['5']._render_when_repeat is True
+        assert root.childs["1"]._render_when_repeat is True
+        assert root.childs["2"]._render_when_repeat is False
+        assert root.childs["3"]._render_when_repeat is False
+        assert root.childs["4"]._render_when_repeat is False
+        assert root.childs["5"]._render_when_repeat is True
         assert root._is_folded is False
-        assert root.childs['1']._is_folded is False
-        assert root.childs['2']._is_folded is False  # True because in the repeat_body
-        assert root.childs['3']._is_folded is False  # False because skip the visit
-        assert root.childs['4']._is_folded is False  # False because skip the visit
-        assert root.childs['5']._is_folded is False
+        assert root.childs["1"]._is_folded is False
+        assert root.childs["2"]._is_folded is False  # False because skip the visit
+        assert root.childs["3"]._is_folded is False  # False because skip the visit
+        assert root.childs["4"]._is_folded is False  # False because skip the visit
+        assert root.childs["5"]._is_folded is False

-    def test_display_tree_construction(self, nested_model):
+    def test_display_tree_construction(self, nested_model) -> None:
         """Test building display tree"""
         tree = OperationTree(nested_model)
-
+
         root = tree.root
-        child_1 = root.childs['1']
-        child_2 = root.childs['2']
-
+        child_1 = root.childs["1"]
+        child_2 = root.childs["2"]
+
         # verify display tree label, i.e.
the level of the node in the display tree - assert root.display_root.label == '0' + assert root.display_root.label == "0" assert all(c.display_root.label == "1" for c in root.childs.values()) assert all(c.display_root.label == "2" for c in child_1.childs.values()) @@ -350,41 +346,37 @@ def test_display_tree_construction(self, nested_model): assert not len(child_2.display_root.children) assert child_2.display_root in root.display_root.children - def test_large_scale_model_construction(self): - model = nn.Sequential( - *[nn.Sequential(nn.Linear(100, 100), nn.ReLU()) for _ in range(100)] - ) + def test_large_scale_model_construction(self) -> None: + model = nn.Sequential(*[nn.Sequential(nn.Linear(100, 100), nn.ReLU()) for _ in range(100)]) tree = OperationTree(model) - + assert len(tree.all_nodes) == 301 - assert tree.root.childs['100'].childs['100.2'].type == "ReLU" + assert tree.root.childs["100"].childs["100.2"].type == "ReLU" - def test_custom_module(self): + def test_custom_module(self) -> None: """Test model made up of custom module and standard module""" + class CustomLayer(nn.Module): - def __init__(self): + def __init__(self) -> None: super().__init__() self.layer = nn.Linear(10, 10) - - model = nn.Sequential( - CustomLayer(), - nn.Sequential(nn.ReLU(), nn.Tanh()) - ) - + + model = nn.Sequential(CustomLayer(), nn.Sequential(nn.ReLU(), nn.Tanh())) + tree = OperationTree(model) root = tree.root - assert root.childs['1'].type == "CustomLayer" - assert root.childs['2'].type == "Sequential" + assert root.childs["1"].type == "CustomLayer" + assert root.childs["2"].type == "Sequential" - assert root.childs['1'].childs['1.1'].name == "layer" - assert root.childs['1'].childs['1.1'].type == "Linear" - assert root.childs['2'].childs['2.1'].name == "0" - assert root.childs['2'].childs['2.1'].type == "ReLU" - assert root.childs['2'].childs['2.2'].name == "1" - assert root.childs['2'].childs['2.2'].type == "Tanh" + assert root.childs["1"].childs["1.1"].name == "layer" + 
assert root.childs["1"].childs["1.1"].type == "Linear" + assert root.childs["2"].childs["2.1"].name == "0" + assert root.childs["2"].childs["2.1"].type == "ReLU" + assert root.childs["2"].childs["2.2"].name == "1" + assert root.childs["2"].childs["2.2"].type == "Tanh" - def test_repr(self, nested_model): + def test_repr(self, nested_model) -> None: """Test __repr__ logic""" tree = OperationTree(nested_model) @@ -394,8 +386,9 @@ def test_repr(self, nested_model): str(tree) mock_opn_repr.assert_called_once() + @pytest.mark.vital -def test_invalid_init(): +def test_invalid_init() -> None: """Test whether non-module object cannot be initialized""" with pytest.raises(TypeError): - OperationTree(model="not_a_module") \ No newline at end of file + OperationTree(model="not_a_module") diff --git a/tests/test_stat_numeric.py b/tests/test_stat_numeric.py index 388707b..f6834ab 100644 --- a/tests/test_stat_numeric.py +++ b/tests/test_stat_numeric.py @@ -1,28 +1,28 @@ from typing import Union from decimal import Decimal -import pytest import numpy as np +import pytest -from torchmeter._stat_numeric import ( - CountUnit, BinaryUnit, - NumericData, UpperLinkData, MetricsData -) +from torchmeter._stat_numeric import CountUnit, BinaryUnit, MetricsData, NumericData, UpperLinkData pytestmark = pytest.mark.vital + class SimpleNumeric(NumericData): - def __init__(self, value: Union[int, float]): + def __init__(self, value: Union[int, float]) -> None: self._value = value @property def raw_data(self) -> float: return float(self._value) + @pytest.fixture def base_upperlink_data(): return UpperLinkData(val=100) + @pytest.fixture def linked_upperlink_data(): parent = UpperLinkData(val=200) @@ -31,36 +31,35 @@ def linked_upperlink_data(): class TestNumericData: - @pytest.mark.parametrize( - argnames=["a", "b", "expected"], + argnames=("a", "b", "expected"), argvalues=[ (5.0, 5.0, True), (5.0, 5, True), (5.0, 4.9, False), (-3.5, -3.5, True), - (0.0, 0, True) - ] + (0.0, 0, True), + ], ) - def 
test_equality(self, a, b, expected): + def test_equality(self, a, b, expected) -> None: """Test the logic of __eq__ and __ne__""" - + num_a = SimpleNumeric(a) assert (num_a == b) == expected assert (num_a != b) != expected @pytest.mark.parametrize( - argnames=["a", "b", "latter_larger"], + argnames=("a", "b", "latter_larger"), argvalues=[ (5.0, 3.0, False), (2.5, 3.0, True), (-4.0, -3.0, True), - (0.0, 0.0, False) - ] + (0.0, 0.0, False), + ], ) - def test_ordering(self, a, b, latter_larger): + def test_ordering(self, a, b, latter_larger) -> None: """Test the logic of __lt__, __le__, __gt__, __ge__""" - + num_a = SimpleNumeric(a) assert (num_a < b) == (latter_larger and a != b) assert (num_a <= b) == (latter_larger or (a == b)) @@ -68,92 +67,92 @@ def test_ordering(self, a, b, latter_larger): assert (num_a >= b) == (not latter_larger or a == b) @pytest.mark.parametrize( - argnames=["op", "a", "b", "expected"], + argnames=("op", "a", "b", "expected"), argvalues=[ - ('+', 5.0, 3.0, 8.0), - ('+', -2.5, 3.0, 0.5), - ('-', 10.0, 4.5, 5.5), - ('*', 2.5, 4.0, 10.0), - ('/', 9.0, 2.0, 4.5), - ('+', 5.0, SimpleNumeric(3.0), 8.0), - ('*', SimpleNumeric(2.0), 3, 6.0), - ('*', SimpleNumeric(3), 4.0, 12.0), - ('/', 9.0, SimpleNumeric(3), 3.0) - ] + ("+", 5.0, 3.0, 8.0), + ("+", -2.5, 3.0, 0.5), + ("-", 10.0, 4.5, 5.5), + ("*", 2.5, 4.0, 10.0), + ("/", 9.0, 2.0, 4.5), + ("+", 5.0, SimpleNumeric(3.0), 8.0), + ("*", SimpleNumeric(2.0), 3, 6.0), + ("*", SimpleNumeric(3), 4.0, 12.0), + ("/", 9.0, SimpleNumeric(3), 3.0), + ], ) - def test_arithmetic_operations(self, op, a, b, expected): + def test_arithmetic_operations(self, op, a, b, expected) -> None: """Test the logic of arithmetic operations""" - + if isinstance(a, float): a = SimpleNumeric(a) if isinstance(b, float): b = SimpleNumeric(b) - + result = { - '+': a + b, - '-': a - b, - '*': a * b, - '/': a / b + "+": a + b, + "-": a - b, + "*": a * b, + "/": a / b, }[op] - + assert expected == pytest.approx(result) - def 
test_reverse_operations(self):
+    def test_reverse_operations(self) -> None:
         """Test the logic of reverse arithmetic operations"""
-
+
         num = SimpleNumeric(3.0)
-
+
         # __radd__
         assert pytest.approx(2 + num) == 5.0
-
+
         # __rsub__
         assert pytest.approx(5 - num) == 2.0
-
+
         # __rmul__
         assert pytest.approx(2 * num) == 6.0
-
+
         # __rtruediv__
         assert pytest.approx(6 / num) == 2.0

     @pytest.mark.parametrize(
-        argnames=["value", "expected"],
+        argnames=("value", "expected"),
         argvalues=[
             (5.5, 5.5),
             (-3.2, -3.2),
-            (0.0, 0.0)
-        ]
+            (0.0, 0.0),
+        ],
     )
-    def test_type_conversion(self, value, expected):
+    def test_type_conversion(self, value, expected) -> None:
         """Test type conversion"""
-
+
         num = SimpleNumeric(value)
         assert float(num) == expected
         assert int(num) == int(expected)
         assert round(num) == round(expected)

-    def test_hash_behavior(self):
+    def test_hash_behavior(self) -> None:
         """Test unhashable"""
-
+
         with pytest.raises(TypeError):
             hash(SimpleNumeric(5.0))
-
+
         with pytest.raises(TypeError):
             _ = {SimpleNumeric(5.0), SimpleNumeric(5.0)}
-
-    def test_numpy_compatability(self):
+
+    def test_numpy_compatibility(self) -> None:
         """Test the compatibility with numpy"""
-
+
         num = SimpleNumeric(3.5)
         arr = np.array([num, 1.5, 2.0])
         assert np.sum(arr) == pytest.approx(7.0)
         assert np.mean(arr) == pytest.approx((3.5 + 1.5 + 2.0) / 3)
         assert np.sum(np.sort(arr) == [1.5, 2.0, 3.5]) == len(arr)
-
-    def test_polars_compatability(self):
+
+    def test_polars_compatibility(self) -> None:
         """Test the compatibility with polars"""
-
+
         from polars import Series
-
+
         num = SimpleNumeric(5.5)
         arr = Series(values=[num, 1.5, 2.0])
         assert arr.sum() == pytest.approx(9.0)
@@ -162,83 +161,87 @@
         assert arr.mean() == pytest.approx((1.5 + 2.0 + 5.5) / 3)
         assert sum(arr.sort() == [1.5, 2.0, 5.5]) == len(arr)

-    @pytest.mark.parametrize("invalid_input", [
-        "string",
-        Decimal('10.5'),
-        {'key': 'value'}
-    ])
-    def test_invalid_operations(self, invalid_input):
+
@pytest.mark.parametrize( + argnames="invalid_input", + argvalues=[ + "string", + Decimal("10.5"), + {"key": "value"}, + ], + ) + def test_invalid_operations(self, invalid_input) -> None: """Test exception thrown in invalid operations""" - + num = SimpleNumeric(5.0) with pytest.raises(TypeError): - _ = num + invalid_input + _ = num + invalid_input + class TestUpperLinkData: - def test_valid_init(self): + def test_valid_init(self) -> None: """Test initialization with different arguments""" data = UpperLinkData() assert data.val == 0 assert data._UpperLinkData__parent_data is None assert data._UpperLinkData__unit_sys is None assert data._UpperLinkData__access_cnt == 1 - assert data.none_str == '-' + assert data.none_str == "-" parent = UpperLinkData() data = UpperLinkData( val=100, parent_data=parent, unit_sys=BinaryUnit, - none_str="N/A" + none_str="N/A", ) assert data.val == 100 assert data._UpperLinkData__parent_data is parent assert data._UpperLinkData__unit_sys is BinaryUnit assert data.none_str == "N/A" - def test_invalid_init(self): + def test_invalid_init(self) -> None: """Test invalid initialization""" with pytest.raises(TypeError): UpperLinkData(val={}) - + with pytest.raises(TypeError): UpperLinkData(val=100, parent_data=100) - + with pytest.raises(TypeError): - UpperLinkData(unit_sys="KB") - + UpperLinkData(unit_sys="KB") + with pytest.raises(TypeError): UpperLinkData(none_str=22) - def test_slots(self, base_upperlink_data): + def test_slots(self, base_upperlink_data) -> None: """Test __slots__ restriction""" assert not hasattr(base_upperlink_data, "__dict__") - + with pytest.raises(AttributeError): base_upperlink_data.invalid_attribute = 42 - def test_raw_data(self, base_upperlink_data): + def test_raw_data(self, base_upperlink_data) -> None: """Test whether the raw_data property is correcttly calculated""" assert base_upperlink_data.raw_data == 100.0 base_upperlink_data.val = 150 assert base_upperlink_data.raw_data == 150.0 - def test_mark_access(self, 
base_upperlink_data): + def test_mark_access(self, base_upperlink_data) -> None: """Test whether the mark_access method is correct""" assert base_upperlink_data._UpperLinkData__access_cnt == 1 base_upperlink_data.mark_access() base_upperlink_data.mark_access() assert base_upperlink_data._UpperLinkData__access_cnt == 3 - def test_inplace_addition(self, base_upperlink_data): + def test_inplace_addition(self, base_upperlink_data) -> None: """Test inplace addition""" base_upperlink_data += 50 assert base_upperlink_data.val == 150 - def test_linked_update(self, linked_upperlink_data): + def test_linked_update(self, linked_upperlink_data) -> None: """Test whether the inplace addition will trigger the update of parent data""" parent, child = linked_upperlink_data - + # single linked data update child += 50 assert child.val == 100 @@ -249,53 +252,54 @@ def test_linked_update(self, linked_upperlink_data): parent._UpperLinkData__parent_data = grandparent child += 100 assert child.val == 200 - assert parent.val == 350 + assert parent.val == 350 assert grandparent.val == 600 - - # verify the common arithmetic operations + + # verify the common arithmetic operations # will not influence the linked update feature assert child + 100 == 300 assert child.val == 200 assert parent.val == 350 assert grandparent.val == 600 - + child += 100 assert child.val == 300 assert parent.val == 450 assert grandparent.val == 700 - def test_repr(self, linked_upperlink_data): + def test_repr(self, linked_upperlink_data) -> None: """Test correct representation""" - + # no unit_sys parent, child = linked_upperlink_data - assert repr(parent) == "200.0" + assert repr(parent) == "200.0" assert repr(child) == "50.0" # with unit_sys data = UpperLinkData(val=1500, unit_sys=BinaryUnit) - assert repr(data) == "1.46 KiB" # 1500/1024 + assert repr(data) == "1.46 KiB" # 1500/1024 # re-access representation data = UpperLinkData(val=300) data.mark_access() - assert repr(data) == "150.0 [dim](Γ—2)[/]" + assert 
repr(data) == "150.0 [dim](Γ—2)[/]" # noqa: RUF001 - def test_edge_cases(self, base_upperlink_data): + def test_edge_cases(self, base_upperlink_data) -> None: """Test some edge cases""" # add invalid data with pytest.raises(TypeError): base_upperlink_data += "invalid_type" - + + class TestMetricsData: - def test_valid_init(self): + def test_valid_init(self) -> None: """Test initialization with different arguments""" m = MetricsData() assert isinstance(m.vals, np.ndarray) assert not len(m.vals) assert m._MetricsData__reduce_func is np.mean assert m._MetricsData__unit_sys is CountUnit - assert m.none_str == '-' + assert m.none_str == "-" custom_func = np.median m = MetricsData(reduce_func=custom_func, unit_sys=BinaryUnit, none_str="N/A") @@ -303,93 +307,93 @@ def test_valid_init(self): assert m._MetricsData__unit_sys is BinaryUnit assert m.none_str == "N/A" - def test_invalid_init(self): + def test_invalid_init(self) -> None: """Test invalid initialization""" with pytest.raises(TypeError): MetricsData(reduce_func=100) - + with pytest.raises(RuntimeError): MetricsData(reduce_func=str) - + with pytest.raises(TypeError): MetricsData(unit_sys=100) - + with pytest.raises(TypeError): MetricsData(none_str=22) - def test_slots(self): - """Test __slots__ restriction""" + def test_slots(self) -> None: + """Test __slots__ restriction""" m = MetricsData() with pytest.raises(AttributeError): m.invalid_attr = 100 - def test_empty_data_properties(self): + def test_empty_data_properties(self) -> None: empty_m = MetricsData() assert empty_m.metrics == 0.0 assert empty_m.iqr == 0.0 assert empty_m.val == (0.0, 0.0) assert empty_m.raw_data == 0.0 - def test_single_value_properties(self): + def test_single_value_properties(self) -> None: m = MetricsData() m.append(5.0) assert m.metrics == 5.0 assert not m.iqr assert m.val == (5.0, 0.0) - def test_multi_value_properties(self): + def test_multi_value_properties(self) -> None: m = MetricsData() m.vals = np.array([2.0, 4.0, 6.0, 10.0]) 
assert m.metrics == 5.5 assert m.iqr == 3.5 # Q3=7.0, Q1=3.5 - def test_reduce_func(self): + def test_reduce_func(self) -> None: m = MetricsData() m.vals = np.array([1.0, 2.0, 6.0]) assert m.metrics == 3.0 # mean m._MetricsData__reduce_func = np.median assert m.metrics == 2.0 # median - + m._MetricsData__reduce_func = np.sum assert m.metrics == 9.0 # sum - def test_data_management(self): + def test_data_management(self) -> None: m = MetricsData() with pytest.raises(TypeError): m.append([1.0, 2.0, 3.0, 4.0, 5.0]) - + m.append(1.0) m.append(2.0) m.append(4.0) assert m.vals.tolist() == [1.0, 2.0, 4.0] - + m.clear() assert not len(m.vals) - def test_repr(self): + def test_repr(self) -> None: """Test correct representation""" # no unit m = MetricsData(unit_sys=None) m.append(1.5) m.append(2.5) - assert repr(m) == "2.00 Β± 0.50" # mean 2.0,IQR=0.5(Q3=2.25, Q1=1.75οΌ‰ + assert repr(m) == "2.00 Β± 0.50" # mean 2.0, IQR=0.5(Q3=2.25, Q1=1.75) # default unit: CountUnit m = MetricsData() m.append(1000) m.append(2000) - assert repr(m) == "1.50 K Β± 500.00" # mean 1500,IQR=500 - + assert repr(m) == "1.50 K Β± 500.00" # mean 1500, IQR=500 + # custom unit m = MetricsData(unit_sys=BinaryUnit) m.append(1000) m.append(2000) - assert repr(m) == "1.46 KiB Β± 500 B" # (mean 1500)/1024,IQR=500 - - def test_edge_cases(self): - """Test some edge cases""" + assert repr(m) == "1.46 KiB Β± 500 B" # (mean 1500)/1024, IQR=500 + + def test_edge_cases(self) -> None: + """Test some edge cases""" # all zero m = MetricsData() m.append(0) diff --git a/tests/test_statistic.py b/tests/test_statistic.py index 69cf713..5ee3a2e 100644 --- a/tests/test_statistic.py +++ b/tests/test_statistic.py @@ -3,51 +3,56 @@ from collections import namedtuple from unittest.mock import MagicMock, PropertyMock, patch -import pytest import numpy as np +import pytest import torch.nn as nn -from pympler.asizeof import asizeof +from torch import int8 as torch_int8 from torch import ones as torch_ones +from torch import int16 
as torch_int16 +from torch import int64 as torch_int64 from torch import randn as torch_randn from torch import device as torch_device +from torch import float16 as torch_float16 +from torch import float64 as torch_float64 from torch.cuda import is_available as is_cuda -from torch import ( - float16 as torch_float16, - float64 as torch_float64, - int16 as torch_int16, - int64 as torch_int64, - int8 as torch_int8 -) +from pympler.asizeof import asizeof -from torchmeter.engine import ( - OperationNode, - OperationTree -) +from torchmeter.engine import OperationNode, OperationTree from torchmeter.statistic import ( - UpperLinkData, MetricsData, - BinaryUnit, CountUnit, TimeUnit, SpeedUnit, - Statistics, ParamsMeter, CalMeter, MemMeter, IttpMeter + CalMeter, + MemMeter, + TimeUnit, + CountUnit, + IttpMeter, + SpeedUnit, + BinaryUnit, + Statistics, + MetricsData, + ParamsMeter, + UpperLinkData, ) pytestmark = pytest.mark.vital STAT_TESTED_NOW = "" + @pytest.fixture def empty_model_root(): model = nn.Sequential() optree = OperationTree(model) return model, optree.root + @pytest.fixture def simple_model_root(): from torch import mean as torch_mean - + class SimpleModel(nn.Module): - def __init__(self): + def __init__(self) -> None: super(SimpleModel, self).__init__() self.conv = nn.Conv2d(3, 16, 3) self.linear = nn.Linear(16, 10) - + def forward(self, x): conv_res = self.conv(x) pool_res = torch_mean(conv_res, dim=(2, 3)) @@ -57,10 +62,11 @@ def forward(self, x): for param in model.parameters(): param.requires_grad = False model.conv.weight.requires_grad = True - + optree = OperationTree(model) return model, optree.root + @pytest.fixture def measured_simple_model(simple_model_root): simple_model, simple_oproot = simple_model_root @@ -77,48 +83,51 @@ def measured_simple_model(simple_model_root): getattr(child, STAT_TESTED_NOW).measure() stat.measure() - simple_model(torch_randn(1,3,64,64)) + simple_model(torch_randn(1, 3, 64, 64)) return simple_model, simple_oproot, stat 
-@pytest.fixture() + +@pytest.fixture def toggle_to_cal(): global STAT_TESTED_NOW STAT_TESTED_NOW = "cal" yield STAT_TESTED_NOW = "" -@pytest.fixture() + +@pytest.fixture def toggle_to_mem(): global STAT_TESTED_NOW STAT_TESTED_NOW = "mem" yield STAT_TESTED_NOW = "" -@pytest.fixture() + +@pytest.fixture def toggle_to_ittp(): global STAT_TESTED_NOW STAT_TESTED_NOW = "ittp" yield STAT_TESTED_NOW = "" + class ConcreteStat(Statistics): detail_val_container = namedtuple('Detail', # type: ignore - ['field1', 'field2']) + ['field1', 'field2']) # fmt: skip overview_val_container = namedtuple('Overview', # type: ignore - ['summary', 'other_field']) + ['summary', 'other_field']) # fmt: skip def __init__(self) -> None: self.StatVal = UpperLinkData(val=50) - + @property def name(self) -> str: return "con_stat" @property def val(self): - return self.overview_val_container(summary=100, - other_field=200) + return self.overview_val_container(summary=100, other_field=200) @property def detail_val(self): @@ -128,87 +137,90 @@ def detail_val(self): def crucial_data(self): return {"key": "value"} - def measure(self): ... + def measure(self) -> None: ... class TestStatistics: - def test_mandatory_attributes_check(self): + def test_mandatory_attributes_check(self) -> None: """Test whether the necessary class properties are implemented.""" - class MissingAll(Statistics):... + + class MissingAll(Statistics): ... 
+
         with pytest.raises(AttributeError) as e1:
             MissingAll()
         assert "detail_val_container" in str(e1.value)
-
+
         class MissingOverviewContainer(Statistics):
-            detail_val_container = namedtuple('Detail', ['a'])
+            detail_val_container = namedtuple("Detail", ["a"])
+
         with pytest.raises(AttributeError) as e2:
             MissingOverviewContainer()
         assert "overview_val_container" in str(e2.value)
-
+
         class MissingDetailContainer(Statistics):
-            overview_val_val_container = namedtuple('Detail', ['a'])
+            overview_val_container = namedtuple("Overview", ["a"])
+
         with pytest.raises(AttributeError) as e3:
             MissingDetailContainer()
         assert "detail_val_container" in str(e3.value)
-
-    def test_required_method_property(self):
+
+    def test_required_method_property(self) -> None:
         """Test whether all abstract methods and property are implemented"""
+
         class InvalidSubclass(Statistics):
-            detail_val_container = namedtuple('Detail', ['a'])
-            overview_val_container = namedtuple('Overview', ['b'])
+            detail_val_container = namedtuple("Detail", ["a"])
+            overview_val_container = namedtuple("Overview", ["b"])

         with pytest.raises(TypeError) as e:
             InvalidSubclass()
-
+
         required = ["name", "val", "detail_val", "crucial_data", "measure"]
         assert all(method in str(e.value) for method in required)

-    def test_init_linkdata(self):
+    def test_init_linkdata(self) -> None:
         """Test init_linkdata method"""
-        # init a upperlinkdata without parent
+        # init a upperlinkdata without parent
         linked_data = ConcreteStat().init_linkdata("StatVal", init_val=100)
         assert linked_data.val == 100
         assert linked_data._UpperLinkData__parent_data is None

         # init a upperlinkdata with parent
-        mock_opnode = MagicMock(con_stat=ConcreteStat())  # val=50
-        mock_opnode.parent = MagicMock(con_stat=ConcreteStat())  # val=50
-        linked_data = mock_opnode.con_stat.init_linkdata("StatVal", init_val=100,
-                                                         opparent=mock_opnode.parent)
+        mock_opnode = MagicMock(con_stat=ConcreteStat())  # val=50
+        mock_opnode.parent = MagicMock(con_stat=ConcreteStat())  # val=50
+ linked_data = mock_opnode.con_stat.init_linkdata("StatVal", init_val=100, opparent=mock_opnode.parent) + assert linked_data.val == 100 assert linked_data._UpperLinkData__parent_data is mock_opnode.parent.con_stat.StatVal - + linked_data += 50 - assert mock_opnode.parent.con_stat.StatVal.val == 100 # 50 + 50 + assert mock_opnode.parent.con_stat.StatVal.val == 100 # 50 + 50 - def test_repr(self): + def test_repr(self) -> None: """Test correct representation""" - # without upperlinkdata + # without upperlinkdata stat = ConcreteStat() output = repr(stat) assert output == ( "Overview\n" "β€’ summary = 100\n" "β€’ other_field = 200\n" - ) - + ) # fmt: skip + # with upperlinkdata - with patch.object(ConcreteStat, 'val', - new_callable=PropertyMock) as mock_val: - mock_val.return_value = stat.overview_val_container(summary=100, - other_field=stat.StatVal) + with patch.object(ConcreteStat, "val", new_callable=PropertyMock) as mock_val: + mock_val.return_value = stat.overview_val_container(summary=100, other_field=stat.StatVal) + output = repr(stat) assert output == ( "Overview\n" "β€’ summary = 100\n" "β€’ other_field = 50.00 = 50.0\n" - ) + ) # fmt: skip + + assert mock_val.call_count == 3 # title + 2 fields - assert mock_val.call_count == 3 # title + 2 fields - # with invalid field - with patch.object(ConcreteStat, 'ov_fields', - new_callable=PropertyMock) as mock_val: + with patch.object(ConcreteStat, "ov_fields", new_callable=PropertyMock) as mock_val: mock_val.return_value = ("summary", "invalid_field") output = repr(stat) @@ -216,69 +228,70 @@ def test_repr(self): "Overview\n" "β€’ summary = 100\n" "β€’ invalid_field = N/A\n" - ) + ) # fmt: skip - assert mock_val.call_count == 2 # max_len + for loop + assert mock_val.call_count == 2 # max_len + for loop - def test_tbov_fields(self): + def test_tbov_fields(self) -> None: """Test tb_fields and ov_fields are set correctly""" stat = ConcreteStat() assert stat.tb_fields == ("field1", "field2") - assert stat.ov_fields == 
("summary","other_field") + assert stat.ov_fields == ("summary", "other_field") - def test_crucial_data(self): + def test_crucial_data(self) -> None: """Test crucial_data is set correctly""" stat = ConcreteStat() data = stat.crucial_data - assert data == {"key":"value"} + assert data == {"key": "value"} + class TestParamsMeter: - def test_cls_variable(self): + def test_cls_variable(self) -> None: """Test detail_val_container and overview_val_container settings""" assert hasattr(ParamsMeter, "detail_val_container") dc = ParamsMeter.detail_val_container assert all(v is None for v in dc._field_defaults.values()) - + assert hasattr(ParamsMeter, "overview_val_container") oc = ParamsMeter.overview_val_container assert all(v is None for v in oc._field_defaults.values()) - - def test_valid_init(self, simple_model_root): + + def test_valid_init(self, simple_model_root) -> None: """Test valid initialization""" model, oproot = simple_model_root - + param_meter = oproot.param assert param_meter._opnode == oproot assert param_meter._model is model assert not param_meter.is_measured assert not param_meter._ParamsMeter__stat_ls - - assert param_meter.name == "param" + + assert param_meter.name == "param" assert hasattr(param_meter, "RegNum") assert isinstance(param_meter.RegNum, UpperLinkData) assert param_meter.RegNum.val == 0 assert param_meter.RegNum._UpperLinkData__parent_data is None assert param_meter.RegNum._UpperLinkData__unit_sys is CountUnit - + assert hasattr(param_meter, "TotalNum") assert isinstance(param_meter.TotalNum, UpperLinkData) assert param_meter.TotalNum.val == 0 assert param_meter.TotalNum._UpperLinkData__parent_data is None assert param_meter.TotalNum._UpperLinkData__unit_sys is CountUnit - def test_invalid_init(self): + def test_invalid_init(self) -> None: """Test invalid initialization""" with pytest.raises(TypeError): ParamsMeter(opnode="0") - def test_val_property(self, simple_model_root): + def test_val_property(self, simple_model_root) -> None: 
"""Test whether the val property is properly set""" - simple_model, simple_oproot = simple_model_root + _simple_model, simple_oproot = simple_model_root param_meter = simple_oproot.param for child in simple_oproot.childs.values(): child.param.measure() param_meter.measure() - + overview = param_meter.val assert isinstance(overview, ParamsMeter.overview_val_container) assert overview.Operation_Id == "0" @@ -287,59 +300,61 @@ def test_val_property(self, simple_model_root): assert overview.Total_Params is param_meter.TotalNum assert overview.Learnable_Params is param_meter.RegNum - def test_crucial_data_format(self, simple_model_root): + def test_crucial_data_format(self, simple_model_root) -> None: """Test whether the crucial_data is return in correct format""" - simple_model, simple_oproot = simple_model_root + _simple_model, simple_oproot = simple_model_root param_meter = simple_oproot.param crucial_data = param_meter.crucial_data assert isinstance(crucial_data, dict) - + # verify align keys = list(crucial_data.keys()) - assert all(isinstance(k, str) for k in crucial_data.keys()) + assert all(isinstance(k, str) for k in crucial_data) assert all(len(k) == len(keys[0]) for k in keys[1:]) - + # verify value - assert all(isinstance(v,str) for v in crucial_data.values()) + assert all(isinstance(v, str) for v in crucial_data.values()) - def test_param_measure(self, empty_model_root, simple_model_root): + def test_param_measure(self, empty_model_root, simple_model_root) -> None: """Test whether the measure method works well""" # model without parameters - empty_model, empty_oproot = empty_model_root + _empty_model, empty_oproot = empty_model_root empty_pm = empty_oproot.param empty_pm.measure() - + assert empty_pm.is_measured assert empty_pm.RegNum.val == 0 assert empty_pm.TotalNum.val == 0 - - assert len(empty_pm.detail_val) == 1 + + assert len(empty_pm.detail_val) == 1 record = empty_pm.detail_val[0] assert record.Operation_Id == "0" assert record.Operation_Name == 
"Sequential" assert record.Operation_Type == "Sequential" assert record.Numeric_Num.val == 0 - + # model with parameters simple_model, simple_oproot = simple_model_root param_meter = simple_oproot.param for child in simple_oproot.childs.values(): child.param.measure() param_meter.measure() - + assert param_meter.is_measured assert all(c.param.is_measured for c in simple_oproot.childs.values()) - + assert param_meter.RegNum.val == simple_model.conv.weight.numel() - assert param_meter.TotalNum.val == sum([simple_model.conv.weight.numel(), - simple_model.conv.bias.numel(), - simple_model.linear.weight.numel(), - simple_model.linear.bias.numel()]) - + assert param_meter.TotalNum.val == sum([ + simple_model.conv.weight.numel(), + simple_model.conv.bias.numel(), + simple_model.linear.weight.numel(), + simple_model.linear.bias.numel(), + ]) + assert len(param_meter.detail_val) == 1 # empty record assert len(simple_oproot.childs["1"].param.detail_val) == 2 # Conv2d: weight+bias assert len(simple_oproot.childs["2"].param.detail_val) == 2 # Linear: weight+bias - + records = param_meter.detail_val records.extend(simple_oproot.childs["1"].param.detail_val) records.extend(simple_oproot.childs["2"].param.detail_val) @@ -350,73 +365,74 @@ def test_param_measure(self, empty_model_root, simple_model_root): else: assert record.Requires_Grad is False - def test_measure_cache(self): - """Test whether the measure method will be revisited after the first call""" + def test_measure_cache(self) -> None: + """Test whether the measure method will be revisited after the first call""" model = nn.Linear(10, 5) opnode = OperationNode(module=model) pm = ParamsMeter(opnode) - + pm.measure() initial_total = pm.TotalNum.val - - model.weight = nn.Parameter(torch_randn(5, 20)) + + model.weight = nn.Parameter(torch_randn(5, 20)) pm._model = model pm.measure() - + assert pm.TotalNum.val == initial_total - def test_none_parameter_handling(self): + def test_none_parameter_handling(self) -> None: """Test 
whether the None parameter is skipped correctly""" - + class BadModule(nn.Module): - def __init__(self): + def __init__(self) -> None: super().__init__() self.weight = nn.Parameter(torch_randn(10, 10)) self.bias = None - + mock_opnode = OperationNode(module=BadModule()) pm = ParamsMeter(mock_opnode) - + pm.measure() assert len(pm.detail_val) == 1 assert pm.detail_val[0].Param_Name == "weight" + @pytest.mark.usefixtures("toggle_to_cal") class TestCalMeter: - def test_cls_variable(self): + def test_cls_variable(self) -> None: """Test detail_val_container and overview_val_container settings""" assert hasattr(CalMeter, "detail_val_container") dc = CalMeter.detail_val_container assert all(v is None for v in dc._field_defaults.values()) - + assert hasattr(CalMeter, "overview_val_container") oc = CalMeter.overview_val_container assert all(v is None for v in oc._field_defaults.values()) - - def test_valid_init(self, simple_model_root): + + def test_valid_init(self, simple_model_root) -> None: """Test valid initialization""" model, oproot = simple_model_root - + cal_meter = oproot.cal assert cal_meter._opnode == oproot assert cal_meter._model is model assert not cal_meter.is_measured assert not cal_meter._CalMeter__is_not_supported assert not cal_meter._CalMeter__stat_ls - - assert cal_meter.name == "cal" - + + assert cal_meter.name == "cal" + assert hasattr(cal_meter, "is_not_supported") - assert not cal_meter.is_not_supported - + assert not cal_meter.is_not_supported + assert hasattr(cal_meter, "Macs") assert isinstance(cal_meter.Macs, UpperLinkData) assert cal_meter.Macs.val == 0 assert cal_meter.Macs._UpperLinkData__parent_data is None assert cal_meter.Macs._UpperLinkData__unit_sys is CountUnit assert cal_meter.Macs.none_str == "Not Supported" - + assert hasattr(cal_meter, "Flops") assert isinstance(cal_meter.Flops, UpperLinkData) assert cal_meter.Flops.val == 0 @@ -424,15 +440,15 @@ def test_valid_init(self, simple_model_root): assert 
cal_meter.Flops._UpperLinkData__unit_sys is CountUnit assert cal_meter.Flops.none_str == "Not Supported" - def test_invalid_init(self): + def test_invalid_init(self) -> None: """Test invalid initialization""" with pytest.raises(TypeError): CalMeter(opnode="0") - def test_val_property(self, measured_simple_model): + def test_val_property(self, measured_simple_model) -> None: """Test whether the val property is properly set""" *_, cal_meter = measured_simple_model - + overview = cal_meter.val assert isinstance(overview, CalMeter.overview_val_container) assert overview.Operation_Id == "0" @@ -441,45 +457,40 @@ def test_val_property(self, measured_simple_model): assert overview.MACs is cal_meter.Macs assert overview.FLOPs is cal_meter.Flops - def test_crucial_data_format(self, measured_simple_model): + def test_crucial_data_format(self, measured_simple_model) -> None: """Test whether the crucial_data is return in correct format""" *_, cal_meter = measured_simple_model crucial_data = cal_meter.crucial_data assert isinstance(crucial_data, dict) - + # verify align keys = list(crucial_data.keys()) - assert all(isinstance(k,str) for k in crucial_data.keys()) + assert all(isinstance(k, str) for k in crucial_data) assert all(len(k) == len(keys[0]) for k in keys[1:]) - + # verify value - assert all(isinstance(v,str) for v in crucial_data.values()) + assert all(isinstance(v, str) for v in crucial_data.values()) @pytest.mark.parametrize( - argnames=("module", "target_hook"), + argnames=("module", "target_hook"), argvalues=[ (nn.Sequential(nn.Identity()), "__container_hook"), - (nn.ModuleList([nn.Identity()]), "__container_hook"), - (nn.ModuleDict({"example":nn.Identity()}), "__container_hook"), - - (nn.Conv1d(10, 5, 3), "__conv_hook"), - (nn.Conv2d(10, 5, 3), "__conv_hook"), - (nn.Conv3d(10, 5, 3), "__conv_hook"), - - (nn.Linear(10, 5), "__linear_hook"), - - (nn.BatchNorm1d(10), "__BN_hook"), - (nn.BatchNorm2d(10), "__BN_hook"), - (nn.BatchNorm3d(10), "__BN_hook"), - - 
(nn.MaxPool1d(3), "__pool_hook"), - (nn.MaxPool2d(3), "__pool_hook"), + (nn.ModuleList([nn.Identity()]), "__container_hook"), + (nn.ModuleDict({"example": nn.Identity()}), "__container_hook"), + (nn.Conv1d(10, 5, 3), "__conv_hook"), + (nn.Conv2d(10, 5, 3), "__conv_hook"), + (nn.Conv3d(10, 5, 3), "__conv_hook"), + (nn.Linear(10, 5), "__linear_hook"), + (nn.BatchNorm1d(10), "__BN_hook"), + (nn.BatchNorm2d(10), "__BN_hook"), + (nn.BatchNorm3d(10), "__BN_hook"), + (nn.MaxPool1d(3), "__pool_hook"), + (nn.MaxPool2d(3), "__pool_hook"), (nn.MaxPool3d(3), "__pool_hook"), - (nn.AvgPool1d(3), "__pool_hook"), - (nn.AvgPool2d(3), "__pool_hook"), - (nn.AvgPool3d(3), "__pool_hook"), - - (nn.Sigmoid(), "__activate_hook"), + (nn.AvgPool1d(3), "__pool_hook"), + (nn.AvgPool2d(3), "__pool_hook"), + (nn.AvgPool3d(3), "__pool_hook"), + (nn.Sigmoid(), "__activate_hook"), (nn.Tanh(), "__activate_hook"), (nn.ReLU(), "__activate_hook"), (nn.ReLU6(), "__activate_hook"), @@ -487,43 +498,42 @@ def test_crucial_data_format(self, measured_simple_model): (nn.PReLU(), "__activate_hook"), (nn.RReLU(), "__activate_hook"), (nn.LeakyReLU(), "__activate_hook"), - - (nn.Dropout(0.5), "__not_support_hook"), - (nn.AdaptiveAvgPool1d(1), "__not_support_hook"), - (nn.Identity(), "__not_support_hook"), - ] + (nn.Dropout(0.5), "__not_support_hook"), + (nn.AdaptiveAvgPool1d(1), "__not_support_hook"), + (nn.Identity(), "__not_support_hook"), + ], ) - def test_cal_measure(self, module, target_hook): + def test_cal_measure(self, module, target_hook) -> None: """Test whether the measure method works well""" opnode = OperationNode(module=module) cal_meter = opnode.cal cal_meter.measure() - + assert cal_meter.is_measured assert len(module._forward_hooks) == 1 assert next(iter(module._forward_hooks.values())).__name__ == target_hook - - def test_measure_cache(self, simple_model_root): + + def test_measure_cache(self, simple_model_root) -> None: """Test whether the measure method will be revisited after the first 
call""" - model, oproot = simple_model_root + _model, oproot = simple_model_root cal_meter = oproot.cal - + res = cal_meter.measure() assert res is not None - + res = cal_meter.measure() assert res is None - def test_valid_access(self, simple_model_root): + def test_valid_access(self, simple_model_root) -> None: """Test whether the invalid access will be blocked""" - model, oproot = simple_model_root + _model, oproot = simple_model_root cal_meter = oproot.cal - + # access property before measure with pytest.raises(AttributeError) as e: cal_meter.detail_val assert "cal" in str(e.value) - + with pytest.raises(AttributeError) as e: cal_meter.val assert "cal" in str(e.value) @@ -531,76 +541,63 @@ def test_valid_access(self, simple_model_root): with pytest.raises(AttributeError) as e: cal_meter.crucial_data assert "cal" in str(e.value) - + # access skipped module after measure cal_meter.measure() with pytest.raises(RuntimeError): cal_meter.detail_val - + with pytest.raises(RuntimeError): cal_meter.val - + with pytest.raises(RuntimeError): cal_meter.crucial_data - + @pytest.mark.parametrize( argnames=("iopt", "expected"), argvalues=[ - (torch_randn(3,4,5), "[3, 4, 5]"), - + (torch_randn(3, 4, 5), "[3, 4, 5]"), (None, "None"), (123, "int"), (1.5, "float"), - (np.array([1,2,3]), "ndarray"), - - ((torch_randn(1,2,3),), "[1, 2, 3]"), - ([torch_randn(4,5,6)], "[4, 5, 6]"), - ({torch_randn(7,8,9)}, "[7, 8, 9]"), - ({"k":torch_randn(2,4,6)}, "{str: [2, 4, 6]}"), - - ((torch_randn(2,3),)*3, ("([2, 3],\n" - " [2, 3],\n" - " [2, 3])")), - ([torch_randn(3,4)]*3, ("([3, 4],\n" - " [3, 4],\n" - " [3, 4])")), - ({"k":torch_randn(2,3), - "l":torch_randn(4,5), - "m":torch_randn(6,7)}, ("{str: [2, 3],\n" - " str: [4, 5],\n" - " str: [6, 7]}")), - ] - ) - def test_iopt_repr(self, iopt, expected): + (np.array([1, 2, 3]), "ndarray"), + ((torch_randn(1, 2, 3),), "[1, 2, 3]"), + ([torch_randn(4, 5, 6)], "[4, 5, 6]"), + ({torch_randn(7, 8, 9)}, "[7, 8, 9]"), + ({"k": torch_randn(2, 4, 6)}, 
"{str: [2, 4, 6]}"), + ((torch_randn(2, 3),) * 3, ("([2, 3],\n [2, 3],\n [2, 3])")), + ([torch_randn(3, 4)] * 3, ("([3, 4],\n [3, 4],\n [3, 4])")), + ( + {"k": torch_randn(2, 3), "l": torch_randn(4, 5), "m": torch_randn(6, 7)}, + ("{str: [2, 3],\n str: [4, 5],\n str: [6, 7]}"), + ), + ], + ) + def test_iopt_repr(self, iopt, expected) -> None: """Test whether the __iopt_repr method works well""" oproot = OperationNode(module=nn.Identity()) cal_meter = oproot.cal iopt_repr = cal_meter._CalMeter__iopt_repr - - assert iopt_repr(iopt) == expected + + assert iopt_repr(iopt) == expected @pytest.mark.parametrize( argnames=("module", "ipt_shape"), argvalues=[ (nn.Sequential(nn.Identity()), (1, 10)), - (nn.Linear(10, 5), (1, 10)), - (nn.Conv1d(10, 5, 3), (1, 10, 32)), (nn.Conv2d(10, 5, 3), (1, 10, 32, 32)), (nn.Conv3d(10, 5, 3), (1, 10, 32, 32, 32)), - (nn.MaxPool1d(3), (1, 10, 32)), (nn.MaxPool2d(3), (1, 10, 32, 32)), (nn.MaxPool3d(3), (1, 10, 32, 32, 32)), (nn.AvgPool1d(3), (1, 10, 32)), (nn.AvgPool2d(3), (1, 10, 32, 32)), (nn.AvgPool3d(3), (1, 10, 32, 32, 32)), - (nn.BatchNorm1d(10), (1, 10, 32)), (nn.BatchNorm2d(10), (1, 10, 32, 32)), (nn.BatchNorm3d(10), (1, 10, 32, 32, 32)), - (nn.Sigmoid(), (1, 10)), (nn.Tanh(), (1, 10)), (nn.ReLU(), (1, 10)), @@ -609,25 +606,24 @@ def test_iopt_repr(self, iopt, expected): (nn.PReLU(), (1, 10)), (nn.RReLU(), (1, 10)), (nn.LeakyReLU(), (1, 10)), - - (nn.Dropout(0.5), (1, 10)), - (nn.AdaptiveAvgPool1d(1), (1, 32, 8)), - (nn.Identity(), (1, 10)), - ] + (nn.Dropout(0.5), (1, 10)), + (nn.AdaptiveAvgPool1d(1), (1, 32, 8)), + (nn.Identity(), (1, 10)), + ], ) - def test_reaccess_module(self, module, ipt_shape): + def test_reaccess_module(self, module, ipt_shape) -> None: """Test reaccess handling""" oproot = OperationTree(module).root cal_meter = oproot.cal - + cal_meter.measure() if not oproot.is_leaf: - list(map(lambda x:x.cal.measure(), oproot.childs.values())) + list(map(lambda x: x.cal.measure(), oproot.childs.values())) 
module(torch_randn(*ipt_shape)) - + assert cal_meter.Macs._UpperLinkData__access_cnt == 1 assert cal_meter.Flops._UpperLinkData__access_cnt == 1 - + hook_func = next(iter(module._forward_hooks.values())).__name__ if "not_support_hook" not in hook_func: module(torch_randn(*ipt_shape)) @@ -638,40 +634,51 @@ def test_reaccess_module(self, module, ipt_shape): assert cal_meter.Flops._UpperLinkData__access_cnt == 1 @pytest.mark.parametrize( - argnames=("module", "ipt_shape", "expected_opt_shape", - "expected_macs", "expected_flops"), + argnames=( + "module", + "ipt_shape", + "expected_opt_shape", + "expected_macs", + "expected_flops", + ), argvalues=[ (nn.Sequential(nn.Identity()), (1, 10), (1, 10), 0, 0), - (nn.Sequential(nn.Conv2d(3,10,3), - nn.Conv2d(10,30,1)), (1, 3, 32, 32), (1, 30, 30, 30), 30**2*10*27+30**3*10, 30**2*20*27+30**3*20), - - (nn.Linear(10, 5, bias=True), (1, 10), (1, 5), 5*10, 5*10*2), - (nn.Linear(10, 5, bias=False), (1, 10), (1, 5), 5*10, 5*10*2-5), - - (nn.Conv1d(10, 5, 3, bias=True), (1, 10, 32), (1, 5, 30), 150*30, 150*30*2), - (nn.Conv1d(10, 5, 3, bias=False), (1, 10, 32), (1, 5, 30), 150*30, 150*30*2-150), - (nn.Conv2d(10, 5, 3, bias=True), (1, 10, 32, 32), (1, 5, 30, 30), 4500*90, 4500*2*90), - (nn.Conv2d(10, 5, 3, bias=False), (1, 10, 32, 32), (1, 5, 30, 30), 4500*90, 4500*2*90-4500), - (nn.Conv3d(10, 5, 3, bias=True), (1, 10, 32, 32, 32), (1, 5, 30, 30, 30), 135000*270, 135000*270*2), - (nn.Conv3d(10, 5, 3, bias=False), (1, 10, 32, 32, 32), (1, 5, 30, 30, 30), 135000*270, 135000*270*2-135000), - - (nn.MaxPool1d(3, ceil_mode=False), (1, 10, 32), (1, 10, 10), 2*100, 2*100), - (nn.MaxPool2d(3, ceil_mode=False), (1, 10, 32, 32), (1, 10, 10, 10), 8*10**3, 8*10**3), - (nn.MaxPool3d(3, ceil_mode=False), (1, 10, 32, 32, 32), (1, 10, 10, 10, 10), 26*10**4, 26*10**4), - (nn.MaxPool1d(3, ceil_mode=True), (1, 10, 32), (1, 10, 11), 2*110, 2*110), - (nn.MaxPool2d(3, ceil_mode=True), (1, 10, 32, 32), (1, 10, 11, 11), 80*11**2, 80*11**2), - (nn.MaxPool3d(3, 
ceil_mode=True), (1, 10, 32, 32, 32), (1, 10, 11, 11, 11), 260*11**3, 260*11**3), - (nn.AvgPool1d(3, ceil_mode=False), (1, 10, 32), (1, 10, 10), 2*100, 5*100), - (nn.AvgPool2d(3, ceil_mode=False), (1, 10, 32, 32), (1, 10, 10, 10), 8*10**3, 17*10**3), - (nn.AvgPool3d(3, ceil_mode=False), (1, 10, 32, 32, 32), (1, 10, 10, 10, 10), 26*10**4, 53*10**4), - (nn.AvgPool1d(3, ceil_mode=True), (1, 10, 32), (1, 10, 11), 2*110, 5*110), - (nn.AvgPool2d(3, ceil_mode=True), (1, 10, 32, 32), (1, 10, 11, 11), 80*11**2, 170*11**2), - (nn.AvgPool3d(3, ceil_mode=True), (1, 10, 32, 32, 32), (1, 10, 11, 11, 11), 260*11**3, 530*11**3), - - (nn.BatchNorm1d(10), (1, 10, 32), (1, 10, 32), 320*2, 320*4), - (nn.BatchNorm2d(10), (1, 10, 32, 32), (1, 10, 32, 32), 32*32*20, 32*32*40), - (nn.BatchNorm3d(10), (1, 10, 32, 32, 32), (1, 10, 32, 32, 32), 32**3*20, 32**3*40), - + ( + nn.Sequential(nn.Conv2d(3, 10, 3), nn.Conv2d(10, 30, 1)), + (1, 3, 32, 32), + (1, 30, 30, 30), + 30**2 * 10 * 27 + 30**3 * 10, + 30**2 * 20 * 27 + 30**3 * 20, + ), + (nn.Linear(10, 5, bias=True), (1, 10), (1, 5), 5 * 10, 5 * 10 * 2), + (nn.Linear(10, 5, bias=False), (1, 10), (1, 5), 5 * 10, 5 * 10 * 2 - 5), + (nn.Conv1d(10, 5, 3, bias=True), (1, 10, 32), (1, 5, 30), 150 * 30, 150 * 30 * 2), + (nn.Conv1d(10, 5, 3, bias=False), (1, 10, 32), (1, 5, 30), 150 * 30, 150 * 30 * 2 - 150), + (nn.Conv2d(10, 5, 3, bias=True), (1, 10, 32, 32), (1, 5, 30, 30), 4500 * 90, 4500 * 2 * 90), + (nn.Conv2d(10, 5, 3, bias=False), (1, 10, 32, 32), (1, 5, 30, 30), 4500 * 90, 4500 * 2 * 90 - 4500), + (nn.Conv3d(10, 5, 3, bias=True), (1, 10, 32, 32, 32), (1, 5, 30, 30, 30), 135000 * 270, 135000 * 270 * 2), + ( + nn.Conv3d(10, 5, 3, bias=False), + (1, 10, 32, 32, 32), + (1, 5, 30, 30, 30), + 135000 * 270, + 135000 * 270 * 2 - 135000, + ), + (nn.MaxPool1d(3, ceil_mode=False), (1, 10, 32), (1, 10, 10), 2 * 100, 2 * 100), + (nn.MaxPool2d(3, ceil_mode=False), (1, 10, 32, 32), (1, 10, 10, 10), 8 * 10**3, 8 * 10**3), + (nn.MaxPool3d(3, ceil_mode=False), 
(1, 10, 32, 32, 32), (1, 10, 10, 10, 10), 26 * 10**4, 26 * 10**4), + (nn.MaxPool1d(3, ceil_mode=True), (1, 10, 32), (1, 10, 11), 2 * 110, 2 * 110), + (nn.MaxPool2d(3, ceil_mode=True), (1, 10, 32, 32), (1, 10, 11, 11), 80 * 11**2, 80 * 11**2), + (nn.MaxPool3d(3, ceil_mode=True), (1, 10, 32, 32, 32), (1, 10, 11, 11, 11), 260 * 11**3, 260 * 11**3), + (nn.AvgPool1d(3, ceil_mode=False), (1, 10, 32), (1, 10, 10), 2 * 100, 5 * 100), + (nn.AvgPool2d(3, ceil_mode=False), (1, 10, 32, 32), (1, 10, 10, 10), 8 * 10**3, 17 * 10**3), + (nn.AvgPool3d(3, ceil_mode=False), (1, 10, 32, 32, 32), (1, 10, 10, 10, 10), 26 * 10**4, 53 * 10**4), + (nn.AvgPool1d(3, ceil_mode=True), (1, 10, 32), (1, 10, 11), 2 * 110, 5 * 110), + (nn.AvgPool2d(3, ceil_mode=True), (1, 10, 32, 32), (1, 10, 11, 11), 80 * 11**2, 170 * 11**2), + (nn.AvgPool3d(3, ceil_mode=True), (1, 10, 32, 32, 32), (1, 10, 11, 11, 11), 260 * 11**3, 530 * 11**3), + (nn.BatchNorm1d(10), (1, 10, 32), (1, 10, 32), 320 * 2, 320 * 4), + (nn.BatchNorm2d(10), (1, 10, 32, 32), (1, 10, 32, 32), 32 * 32 * 20, 32 * 32 * 40), + (nn.BatchNorm3d(10), (1, 10, 32, 32, 32), (1, 10, 32, 32, 32), 32**3 * 20, 32**3 * 40), (nn.Sigmoid(), (1, 10), (1, 10), 20, 40), (nn.Tanh(), (1, 10), (1, 10), 50, 90), (nn.ReLU(), (1, 10), (1, 10), 10, 10), @@ -680,105 +687,106 @@ def test_reaccess_module(self, module, ipt_shape): (nn.PReLU(), (1, 10), (1, 10), 20, 40), (nn.RReLU(), (1, 10), (1, 10), 20, 40), (nn.LeakyReLU(), (1, 10), (1, 10), 20, 40), - (nn.Dropout(0.5), (1, 10), (1, 10), 0, 0), - (nn.AdaptiveAvgPool1d(1), (1, 32, 8), (1, 32, 1), 0, 0), - (nn.Identity(), (1, 10), (1, 10), 0, 0), - ] + (nn.AdaptiveAvgPool1d(1), (1, 32, 8), (1, 32, 1), 0, 0), + (nn.Identity(), (1, 10), (1, 10), 0, 0), + ], ) - def test_module_measurement_logic(self, module, ipt_shape, expected_opt_shape, - expected_macs, expected_flops): + def test_module_measurement_logic( + self, module, ipt_shape, expected_opt_shape, expected_macs, expected_flops + ) -> None: """Test whether the 
measurement logic is true""" oproot = OperationTree(module).root cal_meter = oproot.cal - + assert not cal_meter._CalMeter__stat_ls cal_meter.measure() if not oproot.is_leaf: - list(map(lambda x:x.cal.measure(), oproot.childs.values())) + list(map(lambda x: x.cal.measure(), oproot.childs.values())) opt = module(torch_randn(*ipt_shape)) assert tuple(opt.shape) == expected_opt_shape assert len(cal_meter._CalMeter__stat_ls) == 1 - - assert cal_meter.Macs.val == expected_macs + + assert cal_meter.Macs.val == expected_macs assert cal_meter.Flops.val == expected_flops - def test_not_supported_flag(self): + def test_not_supported_flag(self) -> None: """Test the is_not_supported property is set and retrieved correctly""" - + model = nn.Identity() opnode = OperationNode(module=model) cal_meter = opnode.cal - + # retrieve assert not cal_meter.is_not_supported - + # valid set model.register_forward_hook(cal_meter._CalMeter__not_support_hook) model(torch_randn(1, 10)) - + assert cal_meter.is_not_supported - + # invalid set with pytest.raises(AttributeError): del cal_meter.is_not_supported + @pytest.mark.usefixtures("toggle_to_mem") class TestMemMeter: - def test_cls_variable(self): + def test_cls_variable(self) -> None: """Test detail_val_container and overview_val_container settings""" assert hasattr(MemMeter, "detail_val_container") dc = MemMeter.detail_val_container assert all(v is None for v in dc._field_defaults.values()) - + assert hasattr(MemMeter, "overview_val_container") oc = MemMeter.overview_val_container assert all(v is None for v in oc._field_defaults.values()) - - def test_valid_init(self, simple_model_root): + + def test_valid_init(self, simple_model_root) -> None: """Test valid initialization""" model, oproot = simple_model_root - + mem_meter = oproot.mem assert mem_meter._opnode == oproot assert mem_meter._model is model assert not mem_meter.is_measured assert not mem_meter._MemMeter__stat_ls - - assert mem_meter.name == "mem" + + assert mem_meter.name == 
"mem" assert hasattr(mem_meter, "ParamCost") assert isinstance(mem_meter.ParamCost, UpperLinkData) assert mem_meter.ParamCost.val == 0 assert mem_meter.ParamCost._UpperLinkData__parent_data is None assert mem_meter.ParamCost._UpperLinkData__unit_sys is BinaryUnit - + assert hasattr(mem_meter, "BufferCost") assert isinstance(mem_meter.BufferCost, UpperLinkData) assert mem_meter.BufferCost.val == 0 assert mem_meter.BufferCost._UpperLinkData__parent_data is None assert mem_meter.BufferCost._UpperLinkData__unit_sys is BinaryUnit - + assert hasattr(mem_meter, "OutputCost") assert isinstance(mem_meter.OutputCost, UpperLinkData) assert mem_meter.OutputCost.val == 0 assert mem_meter.OutputCost._UpperLinkData__parent_data is None assert mem_meter.OutputCost._UpperLinkData__unit_sys is BinaryUnit - + assert hasattr(mem_meter, "TotalCost") assert isinstance(mem_meter.TotalCost, UpperLinkData) assert mem_meter.TotalCost.val == 0 assert mem_meter.TotalCost._UpperLinkData__parent_data is None assert mem_meter.TotalCost._UpperLinkData__unit_sys is BinaryUnit - def test_invalid_init(self): + def test_invalid_init(self) -> None: """Test invalid initialization""" with pytest.raises(TypeError): MemMeter(opnode="0") - def test_val_property(self, measured_simple_model): + def test_val_property(self, measured_simple_model) -> None: """Test whether the val property is properly set""" *_, mem_meter = measured_simple_model - + overview = mem_meter.val assert isinstance(overview, MemMeter.overview_val_container) assert overview.Operation_Id == "0" @@ -789,52 +797,52 @@ def test_val_property(self, measured_simple_model): assert overview.Output_Cost is mem_meter.OutputCost assert overview.Total is mem_meter.TotalCost - def test_crucial_data_format(self, measured_simple_model): + def test_crucial_data_format(self, measured_simple_model) -> None: """Test whether the crucial_data is return in correct format""" *_, mem_meter = measured_simple_model crucial_data = mem_meter.crucial_data assert 
isinstance(crucial_data, dict) - + # verify align keys = list(crucial_data.keys()) - assert all(isinstance(k,str) for k in crucial_data.keys()) + assert all(isinstance(k, str) for k in crucial_data) assert all(len(k) == len(keys[0]) for k in keys[1:]) - + # verify value - assert all(isinstance(v,str) for v in crucial_data.values()) + assert all(isinstance(v, str) for v in crucial_data.values()) - def test_mem_measure(self): + def test_mem_measure(self) -> None: """Test whether the measure method works well""" module = nn.Identity() opnode = OperationNode(module) mem_meter = opnode.mem mem_meter.measure() - + assert mem_meter.is_measured assert len(module._forward_hooks) == 1 assert next(iter(module._forward_hooks.values())).__name__ == "__hook_func" - - def test_measure_cache(self, simple_model_root): + + def test_measure_cache(self, simple_model_root) -> None: """Test whether the measure method will be revisited after the first call""" - model, oproot = simple_model_root + _model, oproot = simple_model_root mem_meter = oproot.mem - + res = mem_meter.measure() assert res is not None - + res = mem_meter.measure() assert res is None - def test_valid_access(self, simple_model_root): + def test_valid_access(self, simple_model_root) -> None: """Test whether the invalid access will be blocked""" - model, oproot = simple_model_root + _model, oproot = simple_model_root mem_meter = oproot.mem - + # access property before measure with pytest.raises(AttributeError) as e: mem_meter.detail_val assert "mem" in str(e.value) - + with pytest.raises(AttributeError) as e: mem_meter.val assert "mem" in str(e.value) @@ -842,18 +850,18 @@ def test_valid_access(self, simple_model_root): with pytest.raises(AttributeError) as e: mem_meter.crucial_data assert "mem" in str(e.value) - + # access skipped module after measure mem_meter.measure() with pytest.raises(RuntimeError): mem_meter.detail_val - + with pytest.raises(RuntimeError): mem_meter.val - + with pytest.raises(RuntimeError): 
mem_meter.crucial_data - + @pytest.mark.parametrize( argnames=("module", "ipt_shape", "is_inplace"), argvalues=[ @@ -865,7 +873,6 @@ def test_valid_access(self, simple_model_root): (nn.SELU(), (1, 10), False), (nn.Dropout(0.5), (1, 10), False), (nn.Threshold(0.1, 20), (1, 10), False), - (nn.ReLU(inplace=True), (1, 10), True), (nn.ReLU6(inplace=True), (1, 10), True), (nn.SiLU(inplace=True), (1, 10), True), @@ -874,12 +881,11 @@ def test_valid_access(self, simple_model_root): (nn.SELU(inplace=True), (1, 10), True), (nn.Dropout(0.5, inplace=True), (1, 10), True), (nn.Threshold(0.1, 20, inplace=True), (1, 10), True), - (nn.GELU(), (1, 10), False), (nn.PReLU(), (1, 10), False), (nn.Sigmoid(), (1, 10), False), (nn.Tanh(), (1, 10), False), - (nn.Conv1d(10,5,3), (1, 10, 32), False), + (nn.Conv1d(10, 5, 3), (1, 10, 32), False), (nn.Linear(10, 5), (1, 10), False), (nn.BatchNorm1d(10), (1, 10, 32), False), (nn.AvgPool1d(3), (1, 10, 32), False), @@ -887,21 +893,21 @@ def test_valid_access(self, simple_model_root): (nn.Identity(), (1, 10), False), (nn.Sequential(), (1, 10), False), (nn.Sequential(nn.Identity()), (1, 10), False), - ] + ], ) - def test_inplace_module_handling(self, module, ipt_shape, is_inplace): + def test_inplace_module_handling(self, module, ipt_shape, is_inplace) -> None: """Test whether the inplace module will be handled properly""" opnode = OperationNode(module) mem_meter = opnode.mem assert mem_meter.is_inplace is is_inplace mem_meter.measure() - + module(torch_randn(*ipt_shape)) - + record = mem_meter.detail_val[0] if is_inplace: assert record.Operation_Type.endswith("(inplace)") - assert mem_meter.OutputCost.val == 0 + assert mem_meter.OutputCost.val == 0 else: assert not record.Operation_Type.endswith("(inplace)") if opnode.is_leaf: @@ -910,52 +916,45 @@ def test_inplace_module_handling(self, module, ipt_shape, is_inplace): @pytest.mark.parametrize( argnames=("opts", "expected_opt_cost"), argvalues=[ - (1, 32), - (1., 24), # python default size for 
float - + (1, 32), + (1.0, 24), # python default size for float ("1", 1 + (49 if sys.version_info < (3, 12) else 41)), - ("-"*50, 50 + (49 if sys.version_info < (3, 12) else 41)), - - (None, 16), # python default size for None - + ("-" * 50, 50 + (49 if sys.version_info < (3, 12) else 41)), + (None, 16), # python default size for None (tuple(), 0), - ((1,2,3), 32*3), - + ((1, 2, 3), 32 * 3), # value change between python version (list(), asizeof([])), - ([1,2,3], asizeof([1,2,3])), - + ([1, 2, 3], asizeof([1, 2, 3])), (set(), 216), - ({1,2,3}, 216 + 32*3), - + ({1, 2, 3}, 216 + 32 * 3), # hard to resolve the component - (dict(), asizeof(dict())), - ({"a":1, "b":2}, asizeof({"a":1, "b":2})), - ({"a":1., "b":2.}, asizeof({"a":1., "b":2.})), - - (np.array([1,2,3], dtype=np.int8), 1*3), - (np.array([1,2,3], dtype=np.int16), 2*3), - (np.array([1,2,3], dtype=np.int64), 8*3), - (np.array([1,2,3], dtype=np.float16), 2*3), - (np.array([1,2,3], dtype=np.float64), 8*3), - - (torch_randn(1,2,3), 6*4), - (torch_randn(1,2,3, dtype=torch_float16), 6*2), - (torch_randn(1,2,3, dtype=torch_float64), 6*8), - (torch_ones(1,2,3, dtype=torch_int8), 6*1), - (torch_ones(1,2,3, dtype=torch_int16), 6*2), - (torch_ones(1,2,3, dtype=torch_int64), 6*8) - ] + (dict(), asizeof(dict())), + ({"a": 1, "b": 2}, asizeof({"a": 1, "b": 2})), + ({"a": 1.0, "b": 2.0}, asizeof({"a": 1.0, "b": 2.0})), + (np.array([1, 2, 3], dtype=np.int8), 1 * 3), + (np.array([1, 2, 3], dtype=np.int16), 2 * 3), + (np.array([1, 2, 3], dtype=np.int64), 8 * 3), + (np.array([1, 2, 3], dtype=np.float16), 2 * 3), + (np.array([1, 2, 3], dtype=np.float64), 8 * 3), + (torch_randn(1, 2, 3), 6 * 4), + (torch_randn(1, 2, 3, dtype=torch_float16), 6 * 2), + (torch_randn(1, 2, 3, dtype=torch_float64), 6 * 8), + (torch_ones(1, 2, 3, dtype=torch_int8), 6 * 1), + (torch_ones(1, 2, 3, dtype=torch_int16), 6 * 2), + (torch_ones(1, 2, 3, dtype=torch_int64), 6 * 8), + ], ) - def test_multitype_output_handling(self, opts, expected_opt_cost): + 
def test_multitype_output_handling(self, opts, expected_opt_cost) -> None: """Test whether the different types' output will be handled properly""" + class MultiOutputModel(nn.Module): - def __init__(self): + def __init__(self) -> None: super(MultiOutputModel, self).__init__() def forward(self): return opts - + model = MultiOutputModel() opnode = OperationNode(model) mem_meter = opnode.mem @@ -968,42 +967,36 @@ def forward(self): @pytest.mark.parametrize( argnames=("opts", "expected_opt_cost"), argvalues=[ - ((1, 1.), 32 + 24), + ((1, 1.0), 32 + 24), ((1, "1"), 32 + 1 + (49 if sys.version_info < (3, 12) else 41)), - ((1, None), 32 + 16), + ((1, None), 32 + 16), ((1, ()), 32 + 40), - ((1, (1,2,3)), 32 + 40*4), - ((1, [1,2,3]), 32 + asizeof([1,2,3])), - ((1, {1,2,3}), 32 + 216 + 32*3), - ((1, {"a":1, "b":2}), 32 + asizeof({"a":1, "b":2})), - - (("1", "2."), 3 + 2*(49 if sys.version_info < (3, 12) else 41)), - (("1", 2.), 1 + 24 + (49 if sys.version_info < (3, 12) else 41)), + ((1, (1, 2, 3)), 32 + 40 * 4), + ((1, [1, 2, 3]), 32 + asizeof([1, 2, 3])), + ((1, {1, 2, 3}), 32 + 216 + 32 * 3), + ((1, {"a": 1, "b": 2}), 32 + asizeof({"a": 1, "b": 2})), + (("1", "2."), 3 + 2 * (49 if sys.version_info < (3, 12) else 41)), + (("1", 2.0), 1 + 24 + (49 if sys.version_info < (3, 12) else 41)), (("1", None), 1 + 16 + (49 if sys.version_info < (3, 12) else 41)), - ((None, None), 16 + 16), - ((None, 2.), 16 + 24), - - ((1, np.array([1,2,3], dtype=np.int8)), 32 + 1*3), - ((1, torch_ones(1,2,3, dtype=torch_int8)), 32 + 1*6), - ((torch_randn(1,2,3, dtype=torch_float64), None), 6*8 + 16), - - ((torch_randn(1,2,3, dtype=torch_float16), - np.array([1,2,3], dtype=np.int8)), 6*2 + 1*3), - - ((torch_randn(1,2,3, dtype=torch_float64), - torch_ones(1,2,3, dtype=torch_int64)), 6*8 + 6*8) - ] + ((None, 2.0), 16 + 24), + ((1, np.array([1, 2, 3], dtype=np.int8)), 32 + 1 * 3), + ((1, torch_ones(1, 2, 3, dtype=torch_int8)), 32 + 1 * 6), + ((torch_randn(1, 2, 3, dtype=torch_float64), None), 6 * 8 + 
16), + ((torch_randn(1, 2, 3, dtype=torch_float16), np.array([1, 2, 3], dtype=np.int8)), 6 * 2 + 1 * 3), + ((torch_randn(1, 2, 3, dtype=torch_float64), torch_ones(1, 2, 3, dtype=torch_int64)), 6 * 8 + 6 * 8), + ], ) - def test_multi_output_handling(self, opts, expected_opt_cost): + def test_multi_output_handling(self, opts, expected_opt_cost) -> None: """Test whether the multi output module will be handled properly""" + class MultiOutputModel(nn.Module): - def __init__(self): + def __init__(self) -> None: super(MultiOutputModel, self).__init__() def forward(self): return opts - + model = MultiOutputModel() opnode = OperationNode(model) mem_meter = opnode.mem @@ -1017,24 +1010,19 @@ def forward(self): argnames=("module", "ipt_shape"), argvalues=[ (nn.Sequential(nn.Identity()), (1, 10)), - (nn.Linear(10, 5), (1, 10)), - (nn.Conv1d(10, 5, 3), (1, 10, 32)), (nn.Conv2d(10, 5, 3), (1, 10, 32, 32)), (nn.Conv3d(10, 5, 3), (1, 10, 32, 32, 32)), - (nn.MaxPool1d(3), (1, 10, 32)), (nn.MaxPool2d(3), (1, 10, 32, 32)), (nn.MaxPool3d(3), (1, 10, 32, 32, 32)), (nn.AvgPool1d(3), (1, 10, 32)), (nn.AvgPool2d(3), (1, 10, 32, 32)), (nn.AvgPool3d(3), (1, 10, 32, 32, 32)), - (nn.BatchNorm1d(10), (1, 10, 32)), (nn.BatchNorm2d(10), (1, 10, 32, 32)), (nn.BatchNorm3d(10), (1, 10, 32, 32, 32)), - (nn.Sigmoid(), (1, 10)), (nn.Tanh(), (1, 10)), (nn.ReLU(), (1, 10)), @@ -1043,20 +1031,19 @@ def forward(self): (nn.PReLU(), (1, 10)), (nn.RReLU(), (1, 10)), (nn.LeakyReLU(), (1, 10)), - - (nn.Dropout(0.5), (1, 10)), - (nn.AdaptiveAvgPool1d(1), (1, 32, 8)), - (nn.Identity(), (1, 10)), - ] + (nn.Dropout(0.5), (1, 10)), + (nn.AdaptiveAvgPool1d(1), (1, 32, 8)), + (nn.Identity(), (1, 10)), + ], ) - def test_reaccess_module(self, module, ipt_shape): + def test_reaccess_module(self, module, ipt_shape) -> None: """Test reaccess handling""" opnode = OperationNode(module) mem_meter = opnode.mem - + mem_meter.measure() module(torch_randn(*ipt_shape)) - + assert mem_meter.ParamCost._UpperLinkData__access_cnt == 
1 assert mem_meter.BufferCost._UpperLinkData__access_cnt == 1 assert mem_meter.OutputCost._UpperLinkData__access_cnt == 1 @@ -1070,7 +1057,7 @@ def test_reaccess_module(self, module, ipt_shape): module(torch_randn(*ipt_shape)) assert mem_meter.ParamCost._UpperLinkData__access_cnt == 1 assert mem_meter.BufferCost._UpperLinkData__access_cnt == 1 - assert mem_meter.OutputCost._UpperLinkData__access_cnt == 2 # revisit will only take output into account + assert mem_meter.OutputCost._UpperLinkData__access_cnt == 2 # revisit will only take output into account assert mem_meter.TotalCost._UpperLinkData__access_cnt == 1 assert mem_meter.ParamCost.val == origin_paramcost assert mem_meter.BufferCost.val == origin_buffercost @@ -1078,34 +1065,33 @@ def test_reaccess_module(self, module, ipt_shape): assert mem_meter.TotalCost.val == origin_totalcost + origin_outputcost @pytest.mark.parametrize( - argnames=("module", "ipt_shape", "expected_param_cost", - "expected_buffer_cost", "expected_output_cost"), + argnames=("module", "ipt_shape", "expected_param_cost", "expected_buffer_cost", "expected_output_cost"), argvalues=[ - (nn.Sequential(nn.Identity()), (1, 10), 0, 0, 10*4), - (nn.Sequential(nn.Conv2d(3,10,3), - nn.Conv2d(10,30,1)), (1, 3, 32, 32), 610*4, 0, 9000*4+27000*4), - - (nn.Linear(10, 5, bias=True), (1, 10), 55*4, 0, 5*4), - (nn.Linear(10, 5, bias=False), (1, 10), 50*4, 0, 5*4), - - (nn.Conv1d(10, 5, 3, bias=True), (1, 10, 32), 155*4, 0, 150*4), - (nn.Conv1d(10, 5, 3, bias=False), (1, 10, 32), 150*4, 0, 150*4), - (nn.Conv2d(10, 5, 3, bias=True), (1, 10, 32, 32), 455*4, 0, 4500*4), - (nn.Conv2d(10, 5, 3, bias=False), (1, 10, 32, 32), 450*4, 0, 4500*4), - (nn.Conv3d(10, 5, 3, bias=True), (1, 10, 32, 32, 32), 1355*4, 0, 135000*4), - (nn.Conv3d(10, 5, 3, bias=False), (1, 10, 32, 32, 32), 1350*4, 0, 135000*4), - - (nn.MaxPool1d(3), (1, 10, 32), 0, 0, 4*1e2), - (nn.MaxPool2d(3), (1, 10, 32, 32), 0, 0, 4*1e3), - (nn.MaxPool3d(3), (1, 10, 32, 32, 32), 0, 0, 4*1e4), - 
(nn.AvgPool1d(3), (1, 10, 32), 0, 0, 4*1e2), - (nn.AvgPool2d(3), (1, 10, 32, 32), 0, 0, 4*1e3), - (nn.AvgPool3d(3), (1, 10, 32, 32, 32), 0, 0, 4*1e4), - - (nn.BatchNorm1d(10), (1, 10, 32), 80, 88, 32*40), - (nn.BatchNorm2d(10), (1, 10, 32, 32), 80, 88, 32*32*40), - (nn.BatchNorm3d(10), (1, 10, 32, 32, 32), 80, 88, 32**3*40), - + (nn.Sequential(nn.Identity()), (1, 10), 0, 0, 10 * 4), + ( + nn.Sequential(nn.Conv2d(3, 10, 3), nn.Conv2d(10, 30, 1)), + (1, 3, 32, 32), + 610 * 4, + 0, + 9000 * 4 + 27000 * 4, + ), + (nn.Linear(10, 5, bias=True), (1, 10), 55 * 4, 0, 5 * 4), + (nn.Linear(10, 5, bias=False), (1, 10), 50 * 4, 0, 5 * 4), + (nn.Conv1d(10, 5, 3, bias=True), (1, 10, 32), 155 * 4, 0, 150 * 4), + (nn.Conv1d(10, 5, 3, bias=False), (1, 10, 32), 150 * 4, 0, 150 * 4), + (nn.Conv2d(10, 5, 3, bias=True), (1, 10, 32, 32), 455 * 4, 0, 4500 * 4), + (nn.Conv2d(10, 5, 3, bias=False), (1, 10, 32, 32), 450 * 4, 0, 4500 * 4), + (nn.Conv3d(10, 5, 3, bias=True), (1, 10, 32, 32, 32), 1355 * 4, 0, 135000 * 4), + (nn.Conv3d(10, 5, 3, bias=False), (1, 10, 32, 32, 32), 1350 * 4, 0, 135000 * 4), + (nn.MaxPool1d(3), (1, 10, 32), 0, 0, 4 * 1e2), + (nn.MaxPool2d(3), (1, 10, 32, 32), 0, 0, 4 * 1e3), + (nn.MaxPool3d(3), (1, 10, 32, 32, 32), 0, 0, 4 * 1e4), + (nn.AvgPool1d(3), (1, 10, 32), 0, 0, 4 * 1e2), + (nn.AvgPool2d(3), (1, 10, 32, 32), 0, 0, 4 * 1e3), + (nn.AvgPool3d(3), (1, 10, 32, 32, 32), 0, 0, 4 * 1e4), + (nn.BatchNorm1d(10), (1, 10, 32), 80, 88, 32 * 40), + (nn.BatchNorm2d(10), (1, 10, 32, 32), 80, 88, 32 * 32 * 40), + (nn.BatchNorm3d(10), (1, 10, 32, 32, 32), 80, 88, 32**3 * 40), (nn.Sigmoid(), (1, 10), 0, 0, 40), (nn.Tanh(), (1, 10), 0, 0, 40), (nn.ReLU(), (1, 10), 0, 0, 40), @@ -1114,7 +1100,6 @@ def test_reaccess_module(self, module, ipt_shape): (nn.PReLU(), (1, 10), 4, 0, 40), (nn.RReLU(), (1, 10), 0, 0, 40), (nn.LeakyReLU(), (1, 10), 0, 0, 40), - (nn.ReLU(inplace=True), (1, 10), 0, 0, 0), (nn.ReLU6(inplace=True), (1, 10), 0, 0, 0), (nn.SiLU(inplace=True), (1, 10), 0, 0, 0), 
@@ -1123,91 +1108,98 @@ def test_reaccess_module(self, module, ipt_shape): (nn.SELU(inplace=True), (1, 10), 0, 0, 0), (nn.Dropout(0.5, inplace=True), (1, 10), 0, 0, 0), (nn.Threshold(0.1, 20, inplace=True), (1, 10), 0, 0, 0), - (nn.Dropout(0.5), (1, 10), 0, 0, 40), - (nn.AdaptiveAvgPool1d(1), (1, 32, 8), 0, 0, 32*4), - (nn.Identity(), (1, 10), 0, 0, 40), - ] + (nn.AdaptiveAvgPool1d(1), (1, 32, 8), 0, 0, 32 * 4), + (nn.Identity(), (1, 10), 0, 0, 40), + ], ) - def test_module_measurement_logic(self, module, ipt_shape, - expected_param_cost, expected_buffer_cost, expected_output_cost): + def test_module_measurement_logic( + self, + module, + ipt_shape, + expected_param_cost, + expected_buffer_cost, + expected_output_cost, + ) -> None: """Test whether the measurement logic is true""" oproot = OperationTree(module).root mem_meter = oproot.mem - + assert not mem_meter._MemMeter__stat_ls mem_meter.measure() if not oproot.is_leaf: list(map(lambda x: x.mem.measure(), oproot.childs.values())) module(torch_randn(*ipt_shape)) assert len(mem_meter._MemMeter__stat_ls) == 1 - - assert mem_meter.ParamCost.val == expected_param_cost + + assert mem_meter.ParamCost.val == expected_param_cost assert mem_meter.BufferCost.val == expected_buffer_cost - assert mem_meter.OutputCost.val == expected_output_cost - assert mem_meter.TotalCost.val == expected_param_cost + \ - expected_buffer_cost + \ - expected_output_cost + assert mem_meter.OutputCost.val == expected_output_cost + assert mem_meter.TotalCost.val == expected_param_cost + expected_buffer_cost + expected_output_cost record = mem_meter._MemMeter__stat_ls[0] if not oproot.is_leaf: - assert all(isinstance(getattr(record, field), UpperLinkData) - for field in ["Param_Cost", "Buffer_Cost", "Output_Cost", "Total"]) + assert all( + isinstance(getattr(record, field), UpperLinkData) + for field in ["Param_Cost", "Buffer_Cost", "Output_Cost", "Total"] + ) else: - for expected_val, field_name in zip([expected_param_cost, expected_buffer_cost, 
- expected_output_cost, mem_meter.TotalCost.val], - ["Param_Cost", "Buffer_Cost", "Output_Cost", "Total"]): + for expected_val, field_name in zip( + [expected_param_cost, expected_buffer_cost, expected_output_cost, mem_meter.TotalCost.val], + ["Param_Cost", "Buffer_Cost", "Output_Cost", "Total"], + ): field_val = getattr(record, field_name) if not expected_val: assert field_val is None else: assert isinstance(field_val, UpperLinkData) + @pytest.mark.usefixtures("toggle_to_ittp") class TestIttpMeter: - def test_cls_variable(self): + def test_cls_variable(self) -> None: """Test detail_val_container and overview_val_container settings""" assert hasattr(IttpMeter, "detail_val_container") dc = MemMeter.detail_val_container assert all(v is None for v in dc._field_defaults.values()) - + assert hasattr(IttpMeter, "overview_val_container") oc = MemMeter.overview_val_container assert all(v is None for v in oc._field_defaults.values()) - - def test_valid_init(self, simple_model_root): + + def test_valid_init(self, simple_model_root) -> None: """Test valid initialization""" model, oproot = simple_model_root - + ittp_meter = oproot.ittp assert ittp_meter._opnode == oproot assert ittp_meter._model is model assert not ittp_meter.is_measured assert not ittp_meter._IttpMeter__stat_ls - - assert ittp_meter.name == "ittp" + + assert ittp_meter.name == "ittp" assert hasattr(ittp_meter, "InferTime") assert isinstance(ittp_meter.InferTime, MetricsData) assert not len(ittp_meter.InferTime.vals) assert ittp_meter.InferTime._MetricsData__reduce_func is np.median assert ittp_meter.InferTime._MetricsData__unit_sys is TimeUnit - - assert ittp_meter.name == "ittp" + + assert ittp_meter.name == "ittp" assert hasattr(ittp_meter, "Throughput") assert isinstance(ittp_meter.Throughput, MetricsData) assert not len(ittp_meter.Throughput.vals) assert ittp_meter.Throughput._MetricsData__reduce_func is np.median assert ittp_meter.Throughput._MetricsData__unit_sys is SpeedUnit - def 
test_invalid_init(self): + def test_invalid_init(self) -> None: """Test invalid initialization""" with pytest.raises(TypeError): IttpMeter(opnode="0") - def test_val_property(self, measured_simple_model): + def test_val_property(self, measured_simple_model) -> None: """Test whether the val property is properly set""" *_, ittp_meter = measured_simple_model - + overview = ittp_meter.val assert isinstance(overview, IttpMeter.overview_val_container) assert overview.Operation_Id == "0" @@ -1216,21 +1208,21 @@ def test_val_property(self, measured_simple_model): assert overview.Infer_Time is ittp_meter.InferTime assert overview.Throughput is ittp_meter.Throughput - def test_crucial_data_format(self, measured_simple_model): + def test_crucial_data_format(self, measured_simple_model) -> None: """Test whether the crucial_data is return in correct format""" *_, ittp_meter = measured_simple_model crucial_data = ittp_meter.crucial_data assert isinstance(crucial_data, dict) - + # verify align keys = list(crucial_data.keys()) - assert all(isinstance(k,str) for k in crucial_data.keys()) + assert all(isinstance(k, str) for k in crucial_data) assert all(len(k) == len(keys[0]) for k in keys[1:]) - + # verify value - assert all(isinstance(v,str) for v in crucial_data.values()) + assert all(isinstance(v, str) for v in crucial_data.values()) - def test_ittp_measure(self): + def test_ittp_measure(self) -> None: """Test whether the measure method works well""" module = nn.Identity() opnode = OperationNode(module) @@ -1240,22 +1232,22 @@ def test_ittp_measure(self): assert ittp_meter.is_measured assert len(module._forward_hooks) == 1 assert next(iter(module._forward_hooks.values())).func.__name__ == "__hook_func" - - def test_no_measure_cache(self, simple_model_root): + + def test_no_measure_cache(self, simple_model_root) -> None: """Test whether the measure method will be revisited after the first call""" - model, oproot = simple_model_root + _model, oproot = simple_model_root ittp_meter 
= oproot.ittp - + res = ittp_meter.measure(device=torch_device("cpu")) assert res is not None - + res = ittp_meter.measure(device=torch_device("cpu")) assert res is not None @pytest.mark.skipif(not is_cuda(), reason="No GPUs detected") - def test_model_device_dismatch(self): - """Test whether the measure method works well - when model's device is the same with given argument""" + def test_model_device_dismatch(self) -> None: + """Test whether the measure method works well + when the model's device differs from the given device argument""" model = nn.Linear(10, 5) opnode = OperationNode(model) ittp_meter = opnode.ittp @@ -1264,7 +1256,7 @@ def test_model_device_dismatch(self): ittp_meter.measure(device=torch_device("cuda:0"), repeat=1) assert len(model._forward_hooks) == 1 - def test_measure_on_different_device(self): + def test_measure_on_different_device(self) -> None: """Test whether the measure method works well for model on different device""" model = nn.Linear(10, 5) opnode = OperationNode(model) @@ -1273,33 +1265,36 @@ def test_measure_on_different_device(self): # cpu ittp_meter.measure(device=torch_device("cpu"), repeat=1) with patch("torchmeter.statistic.perf_counter") as cpu_timer, \ - patch("torchmeter.statistic.cuda_event.elapsed_time") as gpu_timer: + patch("torchmeter.statistic.cuda_event.elapsed_time") as gpu_timer: # fmt: skip cpu_timer.side_effect = [1, 2] gpu_timer.side_effect = [1, 2] model(torch_randn(1, 10, device=torch_device("cpu"))) assert cpu_timer.call_count == 2 assert gpu_timer.call_count == 0 - + # gpu if is_cuda(): ittp_meter.measure(device=torch_device("cuda:0"), repeat=1) with patch("torchmeter.statistic.perf_counter") as cpu_timer, \ - patch("torchmeter.statistic.cuda_event.elapsed_time") as gpu_timer: + patch("torchmeter.statistic.cuda_event.elapsed_time") as gpu_timer: # fmt: skip cpu_timer.side_effect = [1, 2] gpu_timer.side_effect = [1, 2] model(torch_randn(1, 10, device=torch_device("cuda:0"))) assert cpu_timer.call_count == 0 assert
gpu_timer.call_count == 1 else: - warnings.warn(message="No Nvidia GPU detected on this device, the test of measuring ittp of model on GPU will be skipped.", - category=UserWarning) + warnings.warn( + category=UserWarning, + message="No Nvidia GPU detected on this device, " + + "the test of measuring ittp of model on GPU will be skipped.", + ) @pytest.mark.parametrize( argnames="repeat_time", - argvalues=range(10,101,10), - ids=lambda x: f"repeat measurement {x} times" + argvalues=range(10, 101, 10), + ids=lambda x: f"repeat measurement {x} times", ) - def test_repeat_measure(self, repeat_time): + def test_repeat_measure(self, repeat_time) -> None: """Test whether the repeat setting works well""" model = nn.Linear(10, 5) opnode = OperationNode(model) @@ -1309,18 +1304,18 @@ def test_repeat_measure(self, repeat_time): model(torch_randn(1, 10)) assert len(ittp_meter.InferTime.vals) == repeat_time - assert len(ittp_meter.Throughput.vals) == repeat_time + assert len(ittp_meter.Throughput.vals) == repeat_time - def test_valid_access(self, simple_model_root): + def test_valid_access(self, simple_model_root) -> None: """Test whether the invalid access will be blocked""" - model, oproot = simple_model_root + _model, oproot = simple_model_root ittp_meter = oproot.ittp - + # access property before measure with pytest.raises(AttributeError) as e: ittp_meter.detail_val assert "ittp" in str(e.value) - + with pytest.raises(AttributeError) as e: ittp_meter.val assert "ittp" in str(e.value) @@ -1328,78 +1323,78 @@ def test_valid_access(self, simple_model_root): with pytest.raises(AttributeError) as e: ittp_meter.crucial_data assert "ittp" in str(e.value) - + # access skipped module after measure ittp_meter.measure(device=torch_device("cpu")) with pytest.raises(RuntimeError): ittp_meter.detail_val - + with pytest.raises(RuntimeError): ittp_meter.val - + with pytest.raises(RuntimeError): ittp_meter.crucial_data - def test_reaccess_module(self): + def test_reaccess_module(self) -> 
None: """Test reaccess handling""" model = nn.Linear(10, 5) opnode = OperationNode(model) ittp_meter = opnode.ittp - + assert not len(model._forward_hooks) ittp_meter.measure(device=torch_device("cpu")) assert len(model._forward_hooks) == 1 - + model(torch_randn(1, 10)) assert not len(model._forward_hooks) - + # reaccess it_val = ittp_meter.InferTime.metrics tp_val = ittp_meter.Throughput.metrics model(torch_randn(1, 10)) - assert it_val == ittp_meter.InferTime.metrics + assert it_val == ittp_meter.InferTime.metrics assert tp_val == ittp_meter.Throughput.metrics @pytest.mark.parametrize( argnames=("repeat_time", "expected_it", "expected_tp"), argvalues=[ - (11, 6, 1/6), - (21, 11, 1/11), - (31, 16, 1/16), - (41, 21, 1/21), - (51, 26, 1/26) + (11, 6, 1 / 6), + (21, 11, 1 / 11), + (31, 16, 1 / 16), + (41, 21, 1 / 21), + (51, 26, 1 / 26), ], - ids=lambda x: f"{x}" if isinstance(x, int) else f"{x:g}" + ids=lambda x: f"{x}" if isinstance(x, int) else f"{x:g}", ) - def test_module_measurement_logic(self, repeat_time, - expected_it, expected_tp): + def test_module_measurement_logic(self, repeat_time, expected_it, expected_tp) -> None: """Test whether the measurement logic is true""" model = nn.Linear(10, 5) opnode = OperationNode(model) ittp_meter = opnode.ittp ittp_meter._IttpMeter__reduce_func = np.median - + # cpu ittp_meter.measure(device=torch_device("cpu"), repeat=repeat_time) with patch("torchmeter.statistic.perf_counter") as cpu_timer: se_time_vals = [] - for sub_res in range(1, repeat_time+1): + for sub_res in range(1, repeat_time + 1): se_time_vals.extend([0, sub_res]) cpu_timer.side_effect = se_time_vals model(torch_randn(1, 10, device=torch_device("cpu"))) assert ittp_meter.InferTime.metrics == expected_it assert ittp_meter.Throughput.metrics == pytest.approx(expected_tp) - + # gpu if is_cuda(): ittp_meter.measure(device=torch_device("cuda:0"), repeat=repeat_time) with patch("torchmeter.statistic.cuda_event.elapsed_time") as gpu_timer: - gpu_timer.side_effect 
= map(lambda x:int(x*1e3), range(1, repeat_time+1)) + gpu_timer.side_effect = map(lambda x: int(x * 1e3), range(1, repeat_time + 1)) model(torch_randn(1, 10, device=torch_device("cuda:0"))) assert ittp_meter.InferTime.metrics == expected_it assert ittp_meter.Throughput.metrics == pytest.approx(expected_tp) else: - warnings.warn(message="No Nvidia GPU detected on this device, the test of ittp measuring logic on GPU will be skipped.", - category=UserWarning) - - + warnings.warn( + category=UserWarning, + message="No Nvidia GPU detected on this device, " + + "the test of ittp measuring logic on GPU will be skipped.", + ) diff --git a/tests/test_unit.py b/tests/test_unit.py index b47b46b..8b75773 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -2,74 +2,75 @@ import pytest -from torchmeter.unit import ( - CountUnit, - BinaryUnit, - TimeUnit, - SpeedUnit, - auto_unit -) - -def is_unitsys_valid(unit_sys): +from torchmeter.unit import TimeUnit, CountUnit, SpeedUnit, BinaryUnit, auto_unit + + +def is_unitsys_valid(unit_sys) -> None: assert issubclass(unit_sys, Enum) - assert all(unit_val > 0 for unit_val in unit_sys._value2member_map_.keys()) + assert all(unit_val > 0 for unit_val in unit_sys._value2member_map_) + @pytest.fixture(params=[CountUnit, BinaryUnit, TimeUnit, SpeedUnit]) -def all_type_unit(request): +def all_type_unit(request): return request.param -def test_count_unit(): + +def test_count_unit() -> None: is_unitsys_valid(CountUnit) - + assert CountUnit.T.value == 1e12 assert CountUnit.G.value == 1e9 assert CountUnit.M.value == 1e6 assert CountUnit.K.value == 1e3 - + assert len(list(CountUnit)) == 4 -def test_binary_unit(): + +def test_binary_unit() -> None: is_unitsys_valid(BinaryUnit) - + assert BinaryUnit.TiB.value == 2**40 assert BinaryUnit.GiB.value == 2**30 assert BinaryUnit.MiB.value == 2**20 assert BinaryUnit.KiB.value == 2**10 assert BinaryUnit.B.value == 2**0 - + assert len(list(BinaryUnit)) == 5 -def test_time_unit(): + +def 
test_time_unit() -> None: is_unitsys_valid(TimeUnit) - + assert TimeUnit.h.value == 60**2 assert TimeUnit.min.value == 60**1 assert TimeUnit.s.value == 60**0 assert TimeUnit.ms.value == 1e-3 assert TimeUnit.us.value == 1e-6 assert TimeUnit.ns.value == 1e-9 - + assert len(list(TimeUnit)) == 6 -def test_speed_unit(): + +def test_speed_unit() -> None: is_unitsys_valid(SpeedUnit) - + assert SpeedUnit.TIPS.value == 1e12 assert SpeedUnit.GIPS.value == 1e9 assert SpeedUnit.MIPS.value == 1e6 assert SpeedUnit.KIPS.value == 1e3 assert SpeedUnit.IPS.value == 1e0 - + assert len(list(SpeedUnit)) == 5 + @pytest.mark.vital -def test_auto_unit(all_type_unit): +def test_auto_unit(all_type_unit) -> None: stage_vals = list(all_type_unit._value2member_map_.keys()) stage_vals.sort() - + # in range - for i in range(len(stage_vals)-1): - low_stage, high_stage = stage_vals[i:i+2] + for i in range(len(stage_vals) - 1): + low_stage, high_stage = stage_vals[i : i + 2] unit = all_type_unit(low_stage).name if 2 * low_stage < high_stage: @@ -78,30 +79,24 @@ def test_auto_unit(all_type_unit): else: integral_multiple_val = low_stage int_multiple_time = 1 - assert f"{int_multiple_time} {unit}" == auto_unit(integral_multiple_val, - unit_system=all_type_unit) + assert f"{int_multiple_time} {unit}" == auto_unit(integral_multiple_val, unit_system=all_type_unit) float_multiple_val = 1.9 * low_stage - while not float_multiple_val % low_stage and float_multiple_val < 2*low_stage-1: + while not float_multiple_val % low_stage and float_multiple_val < 2 * low_stage - 1: float_multiple_val += 1 - assert f"{1.9:.2f} {unit}" == auto_unit(float_multiple_val, - unit_system=all_type_unit) - + assert f"{1.9:.2f} {unit}" == auto_unit(float_multiple_val, unit_system=all_type_unit) + # out of range(smaller) underflow_float_val = stage_vals[0] / 2 - assert f"{underflow_float_val:.2f}" == auto_unit(underflow_float_val, - unit_system=all_type_unit) - + assert f"{underflow_float_val:.2f}" == 
auto_unit(underflow_float_val, unit_system=all_type_unit) + underflow_int_val = int(underflow_float_val) - assert f"{underflow_int_val}" == auto_unit(underflow_int_val, - unit_system=all_type_unit) - + assert f"{underflow_int_val}" == auto_unit(underflow_int_val, unit_system=all_type_unit) + # out of range(bigger) unit = all_type_unit(stage_vals[-1]).name - + overflow_integral_multiple_val = stage_vals[-1] * 2 - assert f"2 {unit}" == auto_unit(overflow_integral_multiple_val, - unit_system=all_type_unit) + assert f"2 {unit}" == auto_unit(overflow_integral_multiple_val, unit_system=all_type_unit) overflow_float_multiple_val = stage_vals[-1] * 1.5 - assert f"{1.5:.2f} {unit}" == auto_unit(overflow_float_multiple_val, - unit_system=all_type_unit) \ No newline at end of file + assert f"{1.5:.2f} {unit}" == auto_unit(overflow_float_multiple_val, unit_system=all_type_unit) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0420d47..1768eda 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,43 +1,37 @@ import os +from typing import NoReturn from decimal import Decimal -from datetime import datetime, date +from datetime import date, datetime from unittest.mock import Mock, patch -import pytest import numpy as np import polars as pl -from rich.text import Text +import pytest from torch import rand as torch_rand +from rich.text import Text from numpy.random import rand as np_rand -from torchmeter._stat_numeric import ( - UpperLinkData, - MetricsData -) +from torchmeter.utils import Timer, Status, hasargs, data_repr, indent_str, resolve_savepath, match_polars_type +from torchmeter._stat_numeric import MetricsData, UpperLinkData -from torchmeter.utils import ( - resolve_savepath, - hasargs, - indent_str, data_repr, - match_polars_type, - Status, Timer -) -def func_no_args(): ... -def func_one_arg(a): ... -def func_multi_args(a, b, c): ... +def func_no_args() -> None: ... +def func_one_arg(a) -> None: ... +def func_multi_args(a, b, c) -> None: ... 
+ @pytest.fixture def chang_to_temp_dir(tmpdir): origin_work_dir = os.getcwd() os.chdir(tmpdir) - + yield tmpdir.strpath - + os.chdir(origin_work_dir) if tmpdir.exists(): tmpdir.remove(rec=1) + @pytest.fixture def mock_status(monkeypatch): """To mock rich.status.Status.__enter__ and .__exit__""" @@ -47,17 +41,16 @@ def mock_status(monkeypatch): monkeypatch.setattr(Status, "__exit__", mock_exit) return {"enter": mock_enter, "exit": mock_exit} + @pytest.mark.parametrize( - argnames="func, given_args, will_error", + argnames=("func", "given_args", "will_error"), argvalues=[ (func_no_args, [], False), (func_no_args, ["a"], True), - (func_one_arg, [], False), (func_one_arg, ["a"], False), (func_one_arg, ["b"], True), (func_one_arg, ["a", "b"], True), - (func_multi_args, [], False), (func_multi_args, ["a"], False), (func_multi_args, ["b"], False), @@ -66,151 +59,154 @@ def mock_status(monkeypatch): (func_multi_args, ["a", "d"], True), (func_multi_args, ["a", "b", "c"], False), (func_multi_args, ["a", "b", "d"], True), - ] + ], ) -def test_hasargs(func, given_args, will_error): +def test_hasargs(func, given_args, will_error) -> None: if will_error: - with pytest.raises(RuntimeError) as e: hasargs(func, *given_args) assert func.__name__ in str(e.value) - + else: - hasargs(func, *given_args) + hasargs(func, *given_args) + @pytest.mark.usefixtures("chang_to_temp_dir") class TestResolveSavePath: - def test_relative_filepath_input(self): + def test_relative_filepath_input(self) -> None: temp_dir = os.getcwd() - + file_dir = "relative_dir" file_name = "relative_file.txt" file_path = os.path.join(temp_dir, file_dir, file_name) - - save_dir, save_file = resolve_savepath(os.path.join(file_dir, file_name), - target_ext="txt") + + save_dir, save_file = resolve_savepath(os.path.join(file_dir, file_name), target_ext="txt") assert save_dir == os.path.join(temp_dir, file_dir) assert save_file == file_path assert os.path.exists(save_dir) - - def test_absolute_filepath_input(self): + + 
def test_absolute_filepath_input(self) -> None: temp_dir = os.getcwd() file_dir = "absolute_dir" file_name = "absolute_file.txt" file_path = os.path.join(temp_dir, file_dir, file_name) - - save_dir, save_file = resolve_savepath(file_path, - target_ext="txt") + + save_dir, save_file = resolve_savepath(file_path, target_ext="txt") assert save_dir == os.path.join(temp_dir, file_dir) assert save_file == file_path assert os.path.exists(save_dir) - - def test_relative_dirpath_input(self): + + def test_relative_dirpath_input(self) -> None: temp_dir = os.getcwd() dir_name = "relative_dir" default_file_name = "TestData" defaule_file_ext = "txt" dir_path = os.path.join(temp_dir, dir_name) - - save_dir, save_file = resolve_savepath(dir_path, - target_ext=defaule_file_ext, - default_filename=default_file_name) + + save_dir, save_file = resolve_savepath( + dir_path, + target_ext=defaule_file_ext, + default_filename=default_file_name, + ) assert save_dir == os.path.join(temp_dir, dir_name) - assert save_file == os.path.join(temp_dir, dir_name, - default_file_name + "." + defaule_file_ext) + assert save_file == os.path.join( + temp_dir, + dir_name, + default_file_name + "." + defaule_file_ext, + ) assert os.path.exists(save_dir) - - def test_absolute_dirpath_input(self): + + def test_absolute_dirpath_input(self) -> None: temp_dir = os.getcwd() - + dir_name = "absolute_dir" default_file_name = "TestData" defaule_file_ext = "txt" dir_path = os.path.join(temp_dir, dir_name) - - save_dir, save_file = resolve_savepath(dir_path, - target_ext=defaule_file_ext, - default_filename=default_file_name) - + + save_dir, save_file = resolve_savepath( + dir_path, target_ext=defaule_file_ext, default_filename=default_file_name + ) + assert save_dir == os.path.join(temp_dir, dir_name) - assert save_file == os.path.join(temp_dir, dir_name, - default_file_name + "." + defaule_file_ext) + assert save_file == os.path.join(temp_dir, dir_name, default_file_name + "." 
+ defaule_file_ext) assert os.path.exists(save_dir) + @pytest.mark.vital class TestIndentStr: # basic function test - def test_single_line(self): + def test_single_line(self) -> None: """Test single-line string with default args""" result = indent_str("hello") - assert result == " "*4 + "hello" + assert result == " " * 4 + "hello" - def test_multi_line(self): + def test_multi_line(self) -> None: """Test multi-line strings with default args""" input_str = "first\nsecond\nthird" expected = ( "β”‚ first\n" "β”‚ second\n" "└─ third" - ) + ) # fmt: skip assert indent_str(input_str) == expected - def test_string_list_input(self): + def test_string_list_input(self) -> None: """Test string list input""" ipt_list = ["line1", "line2"] expected = ( "β”‚ line1\n" "└─ line2" - ) + ) # fmt: skip assert indent_str(ipt_list) == expected # arguments combination test @pytest.mark.parametrize( - argnames=("indent", "expected"), + argnames=("indent", "expected"), argvalues=[ (-1, "first\nsecond"), (0, "first\nsecond"), (1, "β”‚first\nβ””second"), (2, "β”‚ first\n└─second"), (3, "β”‚ first\n└─ second"), - (4, "β”‚ first\n└─ second") - ] + (4, "β”‚ first\n└─ second"), + ], ) - def test_indent_variations(self, indent, expected): + def test_indent_variations(self, indent, expected) -> None: """Test different indentation levels""" assert indent_str("first\nsecond", indent=indent) == expected @pytest.mark.parametrize( - argnames=("guideline", "expected"), + argnames=("guideline", "expected"), argvalues=[ (True, "β”‚ line1\n└─ line2"), - (False, " line1\n line2") - ] + (False, " line1\n line2"), + ], ) - def test_guideline_toggle(self, guideline, expected): + def test_guideline_toggle(self, guideline, expected) -> None: """Test guideline activation""" assert indent_str("line1\nline2", guideline=guideline) == expected - def test_not_process_first(self): + def test_not_process_first(self) -> None: """Test not to indent first line""" result = indent_str("a\nb", process_first=False) assert 
result == "a\n└─ b" # boundary condition test - def test_empty_input(self): + def test_empty_input(self) -> None: """Test empty string input""" assert indent_str("") == " " # single line assert indent_str("\n") == "β”‚ \n└─ " # multi-line - def test_no_guideline_for_single_line(self): + def test_no_guideline_for_single_line(self) -> None: """Test single-line automatic disabling of guide lines""" assert indent_str("hello", guideline=True) == " hello" # Special scenario testing - def test_mixed_lengths_input(self): + def test_mixed_lengths_input(self) -> None: """Test mixed input with different line lengths""" input_lines = "short\nvery long line\nmedium" result = indent_str(input_lines) @@ -218,71 +214,70 @@ def test_mixed_lengths_input(self): "β”‚ short\n" "β”‚ very long line\n" "└─ medium" - ) + ) # fmt: skip assert result == expected # Error handling test - def test_invalid_indent_type(self): + def test_invalid_indent_type(self) -> None: """Test non-integer indent""" with pytest.raises(TypeError): indent_str("test", indent="4") # type: ignore - def test_invalid_input_type(self): + def test_invalid_input_type(self) -> None: """Test non-str input""" with pytest.raises(TypeError): indent_str(123) - + with pytest.raises(TypeError): - indent_str(["a", "b", 123]) - + indent_str(["a", "b", 123]) + + @pytest.mark.vital class TestDataRepr: @pytest.mark.parametrize( - argnames=("val", "type"), + argnames=("val", "type_repr"), argvalues=[ (42, "int"), (3.14, "float"), ("hello", "str"), (True, "bool"), (None, "NoneType"), - ] + ], ) - def test_simple_data(self, val, type): + def test_simple_data(self, val, type_repr) -> None: """Test repr of basic data types""" - assert data_repr(val) == f"[b green]{val}[/] [dim]<{type}>[/]" + assert data_repr(val) == f"[b green]{val}[/] [dim]<{type_repr}>[/]" @pytest.mark.parametrize( - argnames=("val", "type"), + argnames=("val", "type_repr"), argvalues=[ (np_rand(2, 3, 4), "ndarray"), (torch_rand(3, 224, 224), "Tensor"), - (Mock(shape=(5, 5)), "Mock")
- ] + (Mock(shape=(5, 5)), "Mock"), + ], ) - def test_shape_objects(self, val, type): + def test_shape_objects(self, val, type_repr) -> None: """Test objects with shape attributes""" - assert data_repr(val) == f"[dim]Shape[/]([b green]{list(val.shape)}[/]) [dim]<{type}>[/]" + assert data_repr(val) == f"[dim]Shape[/]([b green]{list(val.shape)}[/]) [dim]<{type_repr}>[/]" @pytest.mark.parametrize( - argnames=("val", "type", "inner_type"), + argnames=("val", "type_repr", "inner_type"), argvalues=[ ([1, 2, 3], "list", "int"), - ([np_rand(2,3), np_rand(3,4)], "list", "ndarray"), - + ([np_rand(2, 3), np_rand(3, 4)], "list", "ndarray"), (("1", "2", "3"), "tuple", "str"), - ((torch_rand(1,2), torch_rand(3,4)), "tuple", "Tensor"), - - ({1., 2., 3.}, "set", "float"), - ({Mock(shape=(1,2)), Mock(shape=(3,4))}, "set", "Mock"), - ] + ((torch_rand(1, 2), torch_rand(3, 4)), "tuple", "Tensor"), + ({1.0, 2.0, 3.0}, "set", "float"), + ({Mock(shape=(1, 2)), Mock(shape=(3, 4))}, "set", "Mock"), + ], ) - def test_container_data(self, val, type, inner_type): + def test_container_data(self, val, type_repr, inner_type) -> None: actual = data_repr(val) - + # verify the overall structure - assert actual.startswith(f"[dim]{type}[/](") + assert actual.startswith(f"[dim]{type_repr}[/](") assert actual.endswith(")") - + # verify the repr of each item for v in val: if hasattr(v, "shape"): @@ -290,99 +285,99 @@ def test_container_data(self, val, type, inner_type): else: item_segment = f"[b green]{v}[/] [dim]<{inner_type}>[/]" assert item_segment in actual - + # verify indentation lines = actual.split("\n") - assert all(line.startswith("β”‚" + " "*(len(f"{type}"))) for line in lines[1:-1]) - + assert all(line.startswith("β”‚" + " " * (len(f"{type_repr}"))) for line in lines[1:-1]) + @pytest.mark.parametrize( argnames=("ipt", "key_type", "value_type"), argvalues=[ ({"a": "1", "b": "2"}, "str", "str"), ({1: 1, 2: 2}, "int", "int"), - ({1.0: 1., 2.0: 2.}, "float", "float"), + ({1.0: 1.0, 2.0: 2.0}, 
"float", "float"), ({True: True, False: False}, "bool", "bool"), - ({None: None}, "NoneType", "NoneType") - ] + ({None: None}, "NoneType", "NoneType"), + ], ) - def test_simple_data_dict(self, ipt, key_type, value_type): + def test_simple_data_dict(self, ipt, key_type, value_type) -> None: actual = data_repr(ipt) - + # verify the overall structure assert actual.startswith("[dim]dict[/](") assert actual.endswith(")") - + # verify the repr of each key-value pair for k, v in ipt.items(): key_segment = f"[b green]{k}[/] [dim]<{key_type}>[/]" value_segment = f"[b green]{v}[/] [dim]<{value_type}>[/]" assert key_segment in actual assert value_segment in actual - + # verify indentation lines = actual.split("\n") - assert all(line.startswith("β”‚" + " "*4) for line in lines[1:-1]) + assert all(line.startswith("β”‚" + " " * 4) for line in lines[1:-1]) @pytest.mark.parametrize( argnames=("ipt", "key_type", "value_type"), argvalues=[ - ({"a": np_rand(2,3,4), "b": np_rand(5,6,7)}, "str", "ndarray"), - ({"a": torch_rand(2,3,4), "b": torch_rand(5,6,7)}, "str", "Tensor"), - ({"a": Mock(shape=(1,2,3,4)), "b": Mock(shape=(5,6,7,8))}, "str", "Mock") - ] + ({"a": np_rand(2, 3, 4), "b": np_rand(5, 6, 7)}, "str", "ndarray"), + ({"a": torch_rand(2, 3, 4), "b": torch_rand(5, 6, 7)}, "str", "Tensor"), + ({"a": Mock(shape=(1, 2, 3, 4)), "b": Mock(shape=(5, 6, 7, 8))}, "str", "Mock"), + ], ) - def test_shape_objects_dict(self, ipt, key_type, value_type): + def test_shape_objects_dict(self, ipt, key_type, value_type) -> None: actual = data_repr(ipt) - + # verify the overall structure assert actual.startswith("[dim]dict[/](") assert actual.endswith(")") - + # verify the repr of each key-value pair for k, v in ipt.items(): if hasattr(k, "shape"): key_segment = f"[dim]Shape[/]([b green]{list(k.shape)}[/]) [dim]<{key_type}>[/]" else: key_segment = f"[b green]{k}[/] [dim]<{key_type}>[/]" - + value_segment = f"[dim]Shape[/]([b green]{list(v.shape)}[/]) [dim]<{value_type}>[/]" assert key_segment in 
actual assert value_segment in actual - + # verify indentation lines = actual.split("\n") - assert all(line.startswith("β”‚" + " "*4) for line in lines[1:-1]) + assert all(line.startswith("β”‚" + " " * 4) for line in lines[1:-1]) @pytest.mark.parametrize( argnames=("ipt", "key_type", "container_type", "container_key_type", "container_val_type"), argvalues=[ ({"a": {"b": 1, "c": 2}, "d": {"e": 3, "f": 4}}, "str", "dict", "str", "int"), - ({"a": [1., 2., 3.], "b": [4., 5., 6.]}, "str", "list", None, "float"), + ({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]}, "str", "list", None, "float"), ({"a": (True, False, True), "b": (False, True, False)}, "str", "tuple", None, "bool"), ({"a": {"1", "2", "3"}, "b": {"4", "5", "6"}}, "str", "set", None, "str"), - ] + ], ) - def test_container_data_dict(self, ipt, key_type, container_type, container_key_type, container_val_type): + def test_container_data_dict(self, ipt, key_type, container_type, container_key_type, container_val_type) -> None: """Test dict made up of container, i.e. 
the nested situation""" actual = data_repr(ipt) - + # verify the overall structure assert actual.startswith("[dim]dict[/](") assert actual.endswith(")") - + # verify the repr of each key-value pair for k, v in ipt.items(): if hasattr(k, "shape"): key_segment = f"[dim]Shape[/]([b green]{list(k.shape)}[/]) [dim]<{key_type}>[/]" else: key_segment = f"[b green]{k}[/] [dim]<{key_type}>[/]" - + value_segment = f"[dim]{container_type}[/](" assert key_segment in actual assert value_segment in actual - + if isinstance(v, dict): - for ck,cv in v.items(): + for ck, cv in v.items(): if hasattr(ck, "shape"): ck_segment = f"[dim]Shape[/]([b green]{list(ck.shape)}[/]) [dim]<{container_key_type}>[/]" else: @@ -390,10 +385,10 @@ def test_container_data_dict(self, ipt, key_type, container_type, container_key_ # simplify logic here # test case one must not have object with shape attribute cv_segment = f"[b green]{cv}[/] [dim]<{container_val_type}>[/]" - + assert ck_segment in actual assert cv_segment in actual - + else: for cv in v: if hasattr(cv, "shape"): @@ -402,24 +397,25 @@ def test_container_data_dict(self, ipt, key_type, container_type, container_key_ cv_segment = f"[b green]{cv}[/] [dim]<{container_val_type}>[/]" assert cv_segment in actual - + # verify indentation lines = actual.split("\n")[1:] for idx, container in enumerate(ipt.values()): section_len = len(container) - assert all(line.startswith("β”‚" + " "*len("dict(a :") + "β”‚") - for line in lines[:section_len-2]) - if idx < len(ipt)-1: - assert lines[section_len-2].startswith("β”‚" + " "*len("dict(a :") + "└─") - - lines = lines[section_len-1:] - assert lines[0].startswith("β”‚" + " "*len("dict")) - + assert all(line.startswith("β”‚" + " " * len("dict(a :") + "β”‚") + for line in lines[:section_len - 2]) # fmt: skip + + if idx < len(ipt) - 1: + assert lines[section_len - 2].startswith("β”‚" + " " * len("dict(a :") + "└─") + + lines = lines[section_len - 1 :] + assert lines[0].startswith("β”‚" + " " * len("dict")) + 
lines.pop(0) else: - assert lines[section_len-2].startswith("└─" + " "*len("dict(a ") + "└─") + assert lines[section_len - 2].startswith("└─" + " " * len("dict(a ") + "└─") - def test_empty_container(self): + def test_empty_container(self) -> None: """Test empty container input""" assert data_repr([]) == "[b green][][/] [dim][/]" @@ -427,26 +423,28 @@ def test_empty_container(self): assert data_repr([[], {}]) == ( "[dim]list[/]([b green][][/] [dim][/],\n" "└─ [b green]{}[/] [dim][/])" - ) + ) # fmt: skip - def test_uncommon_input(self): + def test_uncommon_input(self) -> None: """Test uncommon input""" + class CustomType: ... + assert data_repr(CustomType()) == f"[b green]obj[/] [dim]<{CustomType.__module__}.CustomType>[/]" func = func_no_args assert data_repr(func) == "[b green]func_no_args[/] [dim][/]" - - mock_obj = Mock(shape="invalid") # invalid shape + + mock_obj = Mock(shape="invalid") # invalid shape assert data_repr(mock_obj) == "[b green]obj[/] [dim][/]" + @pytest.mark.vital class TestMatchPolarsType: - is_same_type = lambda _, val, pl_type: match_polars_type(val).is_(pl_type) @pytest.mark.parametrize( - argnames=("input_value", "expected_type"), + argnames=("input_value", "expected_type"), argvalues=[ # basic types (42, pl.Int64), @@ -463,10 +461,8 @@ class TestMatchPolarsType: (np.uint64(5), pl.UInt64), (np.float32(1.2), pl.Float32), (np.float64(1.2), pl.Float64), - - # decimal + # decimal (Decimal("1.2"), pl.Decimal), - # time related types (date.today(), pl.Date), (datetime.now(), pl.Datetime("us")), @@ -474,149 +470,147 @@ class TestMatchPolarsType: (np.datetime64("2023-01-01T12"), pl.Object), (np.timedelta64(1, "us"), pl.Duration("us")), (np.timedelta64(1, "D"), pl.Object), - # container types ([1, 2, 3], pl.List(pl.Int64)), - ((1., 2., 3.), pl.List(pl.Float64)), - ((1, 2., 3.), pl.List(pl.Int64)), + ((1.0, 2.0, 3.0), pl.List(pl.Float64)), + ((1, 2.0, 3.0), pl.List(pl.Int64)), ({"A": 1, "B": "b"}, pl.Struct({"A": pl.Int64, "B": pl.Utf8})), (set(), 
pl.Object), - # ndarray - (np.array([1, 2, 3], dtype=np.int64), pl.Int64), # 1D int - (np.array([1.1, 2.2], dtype=np.float64), pl.Float64), # 1D float + (np.array([1, 2, 3], dtype=np.int64), pl.Int64), # 1D int + (np.array([1.1, 2.2], dtype=np.float64), pl.Float64), # 1D float (np.array([[1, 2], [3, 4]], dtype=np.int64), pl.Array(pl.Int64, 2)), # 2D - (np.array([1, "a"], dtype=object), pl.Object), # structed ndarray - + (np.array([1, "a"], dtype=object), pl.Object), # structed ndarray # class instance (Timer(task_desc="test"), pl.Object), (Status(status="test"), pl.Object), (UpperLinkData(2), pl.Object), - (MetricsData(), pl.Object) - ] + (MetricsData(), pl.Object), + ], ) - def test_type_inference(self, input_value, expected_type): + def test_type_inference(self, input_value, expected_type) -> None: """Test basic functionality""" - - assert self.is_same_type(input_value, expected_type) + + assert self.is_same_type(input_value, expected_type) @pytest.mark.parametrize( - argnames="pre_res_value", + argnames="pre_res_value", argvalues=[ pl.Int64, pl.Float64, pl.Date, pl.Object, pl.List(pl.Float64), - pl.Array(pl.Float64, 2) - ] + pl.Array(pl.Float64, 2), + ], ) - def test_pre_res_option(self, pre_res_value): + def test_pre_res_option(self, pre_res_value) -> None: """Test the logic of pre_res option(early return)""" - + result = match_polars_type( - "return pre_res no matter what this is", + "return pre_res no matter what this is", recheck=False, - pre_res=pre_res_value + pre_res=pre_res_value, ) assert result.is_(pre_res_value) - def test_recheck_option(self): + def test_recheck_option(self) -> None: """Test the logic of recheck option(force recheck the type)""" result = match_polars_type( - 42, - recheck=True, - pre_res=pl.Utf8 + 42, + recheck=True, + pre_res=pl.Utf8, ) - + assert result.is_(pl.Int64) - def test_edge_cases(self): + def test_edge_cases(self) -> None: """Test the edge use of the function""" - + # structured array structured_array = np.array( - 
[('Rex', 9, 81.0), ('Fido', 3, 27.0)], - dtype=[('name', 'U10'), ('age', 'i4'), ('weight', 'f4')] - ) + [("Rex", 9, 81.0), ("Fido", 3, 27.0)], + dtype=[("name", "U10"), ("age", "i4"), ("weight", "f4")], + ) # fmt: skip assert self.is_same_type(structured_array, pl.Object) # nested container - nested_ls = [[1, 2], (3., 4.)] - + nested_ls = [[1, 2], (3.0, 4.0)] + assert self.is_same_type(nested_ls, pl.List(pl.List(pl.Int64))) + @pytest.mark.vital class TestTimer: # basic function test @pytest.mark.parametrize( - argnames="task", - argvalues=[ - "basic task", # normal string - "", # empty string - "任劑描述", # Non-ASCII characters - "a" * 200, # long string - "special_!@#$%^&*" # special characters - ] - ) - def test_basic_use(self, task, mock_status, capsys): + argnames="task", + argvalues=[ + "basic task", # normal string + "", # empty string + "任劑描述", # Non-ASCII characters + "a" * 200, # long string + "special_!@#$%^&*", # special characters + ], + ) + def test_basic_use(self, task, mock_status, capsys) -> None: from time import sleep + from rich import get_console console = get_console() console_width = console.width if len(task) > console_width: - task = [task[i:i+console_width] for i in range(0, len(task), console_width)] + task = [task[i : i + console_width] for i in range(0, len(task), console_width)] task = "\n" + "\n".join(task) with Timer(task_desc=task): - sleep(1) - + sleep(1) + mock_status["enter"].assert_called_once() mock_status["exit"].assert_called_once() - + captured = capsys.readouterr() plain_text = Text.from_ansi(captured.out).plain assert f"Finish {task} in" in plain_text assert "seconds" in plain_text # boundary condition test - def test_short_time(self, capsys): + def test_short_time(self, capsys) -> None: with Timer("short task"): # quit immediately pass - + captured = capsys.readouterr() assert "0.0000" in captured.out - def test_long_time(self, capsys): + def test_long_time(self, capsys) -> None: with patch("torchmeter.utils.perf_counter") as 
mock_time: # Simulate a one-year time difference - mock_time.side_effect = [0.0, 31536000.0] + mock_time.side_effect = [0.0, 31536000.0] with Timer("long task"): pass - + captured = capsys.readouterr() assert "31536000.0000" in captured.out # Error handling test - def test_exception_handling(self, mock_status): + def test_exception_handling(self, mock_status) -> NoReturn: class CustomError(Exception): ... - with pytest.raises(CustomError): - with Timer("error task"): - raise CustomError("test error") - + with pytest.raises(CustomError), Timer("error task"): + raise CustomError("test error") + assert mock_status["enter"].assert_called_once assert mock_status["exit"].assert_called_once # verify time accuracy - def test_time_accuracy(self, capsys): + def test_time_accuracy(self, capsys) -> None: with patch("torchmeter.utils.perf_counter") as mock_time: mock_time.side_effect = [100.0, 100.1234] with Timer("precision test"): pass - + captured = capsys.readouterr() - assert "0.1234" in captured.out \ No newline at end of file + assert "0.1234" in captured.out diff --git a/torchmeter/__cli__.py b/torchmeter/__cli__.py index 2e7fe9d..ce80358 100644 --- a/torchmeter/__cli__.py +++ b/torchmeter/__cli__.py @@ -1,9 +1,12 @@ from torchmeter import __version__ -def main(): + +def main() -> None: print(f"Sorry, TorchMeter {__version__} does not support command line interface yet.") - print("Please use it as a library, " - "or update to the newest version using `pip install -U torchmeter` and try again.") + print( + "Please use it as a library, or update to the newest version using `pip install -U torchmeter` and try again." 
+ ) + -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/torchmeter/__init__.py b/torchmeter/__init__.py index 0fcae1f..9165df4 100644 --- a/torchmeter/__init__.py +++ b/torchmeter/__init__.py @@ -15,25 +15,25 @@ ------------------------------------------------------------------------------------------------------------ Core Functionality: - 1. Parameter Analysis + 1. Parameter Analysis - Total/trainable parameter quantification - Layer-wise parameter distribution analysis - Gradient state tracking (requires_grad flags) - + 2. Computational Profiling - FLOPs/MACs precision calculation - Operation-wise calculation distribution analysis - Dynamic input/output detection (number, type, shape, ...) - - 3. Memory Diagnostics + + 3. Memory Diagnostics - Input/output tensor memory awareness - Hierarchical memory consumption analysis - + 4. Performance Benchmarking - Auto warm-up phase execution (eliminates cold-start bias) - Device-specific high-precision timing - Inference latency & Throughput Benchmarking - + 5. Visualization Engine - Centralized configuration management - Programmable tabular report @@ -44,7 +44,7 @@ - Rich-text hierarchical structure tree rendering 1. Style customization and real-time rendering 2. Smart module folding based on structural equivalence detection - + 6. 
Cross-Platform Support - Automatic model-data co-location - Seamless device transition (CPU/CUDA) @@ -58,4 +58,4 @@ from torchmeter.core import Meter from torchmeter.config import get_config -__all__ = ["Meter", "get_config"] \ No newline at end of file +__all__ = ["Meter", "get_config"] diff --git a/torchmeter/_stat_numeric.py b/torchmeter/_stat_numeric.py index ea3d69d..39f0014 100644 --- a/torchmeter/_stat_numeric.py +++ b/torchmeter/_stat_numeric.py @@ -1,181 +1,213 @@ from __future__ import annotations -from typing import TYPE_CHECKING from abc import ABC, abstractmethod +from typing import TYPE_CHECKING from functools import total_ordering import numpy as np -from torchmeter.unit import ( - auto_unit, - CountUnit, BinaryUnit, TimeUnit, SpeedUnit -) +from torchmeter.unit import TimeUnit, CountUnit, SpeedUnit, BinaryUnit, auto_unit if TYPE_CHECKING: - from typing import Optional, Tuple, Union - from typing import Type, Callable, Sequence + import sys + from typing import Type, Tuple, Union, Callable, Optional, Sequence + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self from numpy.typing import NDArray UNIT_TYPE = Optional[Union[Type[CountUnit], Type[BinaryUnit], Type[TimeUnit], - Type[SpeedUnit]]] + Type[SpeedUnit]]] # fmt: skip SEQ_DATA = NDArray[Union[np.int_, np.float_]] FLOAT = Union[float, np.float_] + NUMERIC_DATA_TYPE = Union[int, float] SEQ_FUNC = Callable[[SEQ_DATA], FLOAT] + @total_ordering class NumericData(ABC): - - __slots__:Sequence[str] = [] - + __slots__: Sequence[str] = [] + @property @abstractmethod - def raw_data(self) -> FLOAT: - ... + def raw_data(self) -> FLOAT: ... 
- def _numeric_op(self, other, op): + def _numeric_op(self, other: Union[object, int, float], op: Callable): # noqa: ANN202 other_data = other.raw_data if isinstance(other, NumericData) else other return op(self.raw_data, other_data) - + # required for generating other comparison operators - def __eq__(self, other): - return self._numeric_op(other, lambda s,o: s == o) + def __eq__(self, other: object) -> bool: + return self._numeric_op(other, lambda s, o: s == o) - def __lt__(self, other): - return self._numeric_op(other, lambda s,o: s < o) + def __lt__(self, other: NUMERIC_DATA_TYPE) -> bool: + return self._numeric_op(other, lambda s, o: s < o) # arithmetic operations - def __add__(self, other): return self._numeric_op(other, lambda s,o: s + o) + def __add__(self, other: NUMERIC_DATA_TYPE) -> Union[NUMERIC_DATA_TYPE, Self]: + return self._numeric_op(other, lambda s, o: s + o) + __radd__ = __add__ - - def __sub__(self, other): return self._numeric_op(other, lambda s,o: s - o) - __rsub__ = lambda self, other: self._numeric_op(other, lambda s,o: o - s) - - def __mul__(self, other): return self._numeric_op(other, lambda s,o: s * o) + + def __sub__(self, other: NUMERIC_DATA_TYPE) -> Union[NUMERIC_DATA_TYPE, Self]: + return self._numeric_op(other, lambda s, o: s - o) + + __rsub__ = lambda self, other: self._numeric_op(other, lambda s, o: o - s) + + def __mul__(self, other: NUMERIC_DATA_TYPE) -> Union[NUMERIC_DATA_TYPE, Self]: + return self._numeric_op(other, lambda s, o: s * o) + __rmul__ = __mul__ - - def __truediv__(self, other): return self._numeric_op(other, lambda s,o: s / o) - __rtruediv__ = lambda self, other: self._numeric_op(other, lambda s,o: o / s) - + + def __truediv__(self, other: NUMERIC_DATA_TYPE) -> Union[float, Self]: + return self._numeric_op(other, lambda s, o: s / o) + + __rtruediv__ = lambda self, other: self._numeric_op(other, lambda s, o: o / s) + # type conversion - def __float__(self): return float(self.raw_data) - def __int__(self): return 
int(self.raw_data) - def __round__(self, ndigits=None): return round(self.raw_data, ndigits) + def __float__(self) -> float: + return float(self.raw_data) -class UpperLinkData(NumericData): + def __int__(self) -> int: + return int(self.raw_data) + + def __round__(self, ndigits: Optional[int] = None) -> Union[int, FLOAT]: + return round(self.raw_data, ndigits) - __slots__ = ['val', 'none_str', - '__access_cnt', '__parent_data', '__unit_sys'] - def __init__(self, - val:Union[int, float]=0, parent_data:Optional[UpperLinkData]=None, - unit_sys:UNIT_TYPE=None, - none_str:str='-') -> None: - +class UpperLinkData(NumericData): + __slots__ = ["val", "none_str", "__access_cnt", "__parent_data", "__unit_sys"] + + def __init__( + self, + val: Union[int, float] = 0, + parent_data: Optional[UpperLinkData] = None, + unit_sys: UNIT_TYPE = None, + none_str: str = "-", + ) -> None: if not isinstance(val, (int, float)): raise TypeError(f"`val` must be `int` or `float`, but got `{type(val).__name__}`.") - + if not isinstance(parent_data, (UpperLinkData, type(None))): - raise TypeError("`parent_data` must be an instance of `UpperLinkData` or `None`, " + \ - f"but got `{type(parent_data).__name__}`.") - + raise TypeError( + "`parent_data` must be an instance of `UpperLinkData` or `None`, " + + f"but got `{type(parent_data).__name__}`." + ) + if unit_sys not in (None, CountUnit, BinaryUnit, TimeUnit, SpeedUnit): - raise TypeError("`unit_sys` must be `None` or one of `(CountUnit, BinaryUnit, TimeUnit, SpeedUnit)`, " + \ - f"but got `{type(unit_sys).__name__}`.") - + raise TypeError( + "`unit_sys` must be `None` or one of `(CountUnit, BinaryUnit, TimeUnit, SpeedUnit)`, " + + f"but got `{type(unit_sys).__name__}`." 
+ ) + if not isinstance(none_str, str): raise TypeError(f"`none_str` must be a string, but got `{type(none_str).__name__}`.") - + self.val = val - self.__parent_data = parent_data + self.__parent_data = parent_data self.__unit_sys = unit_sys self.__access_cnt = 1 - self.none_str = none_str # Use when there is a "None" in the column where this data is located while rendering the table. - + + # Use when there is a "None" in the column where this data is located while rendering the table + self.none_str = none_str + @property def raw_data(self) -> float: return float(self.val) - + def mark_access(self) -> None: self.__access_cnt += 1 - - def __iadd__(self, other) -> UpperLinkData: + + def __iadd__(self, other: NUMERIC_DATA_TYPE) -> UpperLinkData: if not isinstance(other, (int, float)): - raise TypeError(f"Instances of {self.__class__.__name__} can only be added in place with " + \ - f"`int` or `float` data, but provided `{type(other).__name__}`.") + raise TypeError( + f"Instances of {self.__class__.__name__} can only be added in place with " + + f"`int` or `float` data, but provided `{type(other).__name__}`." 
+ ) self.val += other self.__upper_update(other) # self.__access_cnt += 1 return self - - def __upper_update(self, other:Union[int, float]) -> None: + + def __upper_update(self, other: Union[int, float]) -> None: if self.__parent_data is not None: self.__parent_data += other - + def __repr__(self) -> str: if self.__unit_sys is not None: - base = auto_unit(self.val/self.__access_cnt, self.__unit_sys) + base = auto_unit(self.val / self.__access_cnt, self.__unit_sys) else: - base = str(self.val/self.__access_cnt) - return base + (f" [dim](Γ—{self.__access_cnt})[/]" if self.__access_cnt > 1 else "") + base = str(self.val / self.__access_cnt) + return base + (f" [dim](Γ—{self.__access_cnt})[/]" if self.__access_cnt > 1 else "") # noqa: RUF001 -class MetricsData(NumericData): - __slots__ = ['vals', 'none_str', - '__reduce_func', '__unit_sys', ] - - def __init__(self, - reduce_func:Optional[SEQ_FUNC]=np.mean, - unit_sys:UNIT_TYPE=CountUnit, - none_str:str='-') -> None: +class MetricsData(NumericData): + __slots__ = [ + "vals", + "none_str", + "__reduce_func", + "__unit_sys", + ] + + def __init__( + self, reduce_func: Optional[SEQ_FUNC] = np.mean, unit_sys: UNIT_TYPE = CountUnit, none_str: str = "-" + ) -> None: if reduce_func is not None and not callable(reduce_func): - raise TypeError("`reduce_func` must be a callable object, " + \ - f"but got `{type(reduce_func).__name__}`.") + raise TypeError("`reduce_func` must be a callable object, " + f"but got `{type(reduce_func).__name__}`.") elif reduce_func is not None: _ = reduce_func(np.array([1, 2, 3])) if not isinstance(_, (int, float)): - raise RuntimeError("The return type of `reduce_func` must be `int` or `float`, " + \ - f"but got `{type(_).__name__}`.") + raise RuntimeError( + "The return type of `reduce_func` must be `int` or `float`, " + f"but got `{type(_).__name__}`." 
+ ) if unit_sys not in (None, CountUnit, BinaryUnit, TimeUnit, SpeedUnit): - raise TypeError("`unit_sys` must be `None` or one of `(CountUnit, BinaryUnit, TimeUnit, SpeedUnit)`, " + \ - f"but got `{type(unit_sys).__name__}`.") - + raise TypeError( + "`unit_sys` must be `None` or one of `(CountUnit, BinaryUnit, TimeUnit, SpeedUnit)`, " + + f"but got `{type(unit_sys).__name__}`." + ) + if not isinstance(none_str, str): raise TypeError(f"`none_str` must be a string, but got `{type(none_str).__name__}`.") - - self.vals:SEQ_DATA = np.array([]) + + self.vals: SEQ_DATA = np.array([]) self.__reduce_func = reduce_func if reduce_func is not None else np.mean self.__unit_sys = unit_sys self.none_str = none_str @property def metrics(self) -> FLOAT: - return self.__reduce_func(self.vals) if self.vals.any() else 0. - + return self.__reduce_func(self.vals) if self.vals.any() else 0.0 + @property def iqr(self) -> FLOAT: if self.vals.any(): return np.percentile(self.vals, 75) - np.percentile(self.vals, 25) else: - return 0. - + return 0.0 + @property def val(self) -> Tuple[FLOAT, FLOAT]: return self.metrics, self.iqr - + @property def raw_data(self) -> FLOAT: return self.metrics - - def append(self, new_val:Union[int, FLOAT]) -> None: + + def append(self, new_val: Union[int, FLOAT]) -> None: if not isinstance(new_val, (int, float)): - raise TypeError(f"Instances of {self.__class__.__name__} can only be appended with `int` or `float` data, " + \ - f"but got `{type(new_val).__name__}`.") + raise TypeError( + f"Instances of {self.__class__.__name__} can only be appended with `int` or `float` data, " + + f"but got `{type(new_val).__name__}`." 
+ ) + self.vals = np.append(self.vals, new_val) def clear(self) -> None: @@ -183,7 +215,6 @@ def clear(self) -> None: def __repr__(self) -> str: if self.__unit_sys is not None: - return f"{auto_unit(self.metrics, self.__unit_sys)}" + ' Β± ' + \ - f"{auto_unit(self.iqr, self.__unit_sys)}" + return f"{auto_unit(self.metrics, self.__unit_sys)}" + " Β± " + f"{auto_unit(self.iqr, self.__unit_sys)}" else: return f"{self.metrics:.2f} Β± {self.iqr:.2f}" diff --git a/torchmeter/config.py b/torchmeter/config.py index 6a0fa0a..d8f89c2 100644 --- a/torchmeter/config.py +++ b/torchmeter/config.py @@ -1,11 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING import os import warnings -from threading import Lock from enum import Enum, unique from types import SimpleNamespace +from typing import TYPE_CHECKING +from threading import Lock import yaml from rich import box @@ -20,21 +20,24 @@ else: from typing_extensions import TypeAlias - from typing import Any, Dict, Sequence, Union, List - from typing import Callable, Optional + from typing import Any, Dict, List, Union, Callable, Optional, Sequence CFG_CONTENT_TYPE: TypeAlias = Union[ - int, float, str, bool, None, - Sequence["CFG_CONTENT_TYPE"], - Dict[str, "CFG_CONTENT_TYPE"] + int, float, str, bool, Sequence["CFG_CONTENT_TYPE"], Dict[str, "CFG_CONTENT_TYPE"], None ] __all__ = ["get_config", "Config"] -DEFAULT_FIELDS = ['render_interval', - 'tree_fold_repeat', 'tree_repeat_block_args', 'tree_levels_args', - 'table_column_args','table_display_args', - 'combine'] +DEFAULT_FIELDS = [ + "render_interval", + "tree_fold_repeat", + "tree_repeat_block_args", + "tree_levels_args", + "table_column_args", + "table_display_args", + "combine", +] + DEFAULT_CFG = """\ render_interval: 0.15 @@ -123,6 +126,7 @@ horizon_gap: 2 """ + @unique class BOX(Enum): ASCII = box.ASCII @@ -145,14 +149,13 @@ class BOX(Enum): SQUARE = box.SQUARE SQUARE_DOUBLE_HEAD = box.SQUARE_DOUBLE_HEAD + # all the keys should be str, while all the 
value should be enum # int each value, all the member's name should not be the same with its value's repr -UNSAFE_KV = { - 'box': BOX -} +UNSAFE_KV = {"box": BOX} + -def list_to_callbacklist(ls: List[Any], - callback_func: Callable[[], Any]=lambda: None) -> CallbackList: +def list_to_callbacklist(ls: List[Any], callback_func: Callable[[], Any] = lambda: None) -> CallbackList: _list: List[Any] = [] for item in ls: if isinstance(item, dict): @@ -163,15 +166,16 @@ def list_to_callbacklist(ls: List[Any], _list.append(item) return CallbackList(_list, callback_func=callback_func) + def dict_to_namespace(d: Dict[str, Any]) -> FlagNameSpace: """ Recursively converts a dictionary to a FlagNameSpace object. - """ + """ # noqa: DOC201, DOC501 if not isinstance(d, dict): raise TypeError(f"Input must be a dictionary, but got `{type(d).__name__}`") - + ns = FlagNameSpace() - + for k, v in d.items(): # overwrite the value of unsafe key to get the unrepresent value if k in UNSAFE_KV and isinstance(v, str): @@ -179,199 +183,229 @@ def dict_to_namespace(d: Dict[str, Any]) -> FlagNameSpace: if isinstance(v, dict): setattr(ns, k, dict_to_namespace(v)) - + elif isinstance(v, list): setattr(ns, k, list_to_callbacklist(v, callback_func=ns.mark_change)) - + elif isinstance(v, set): setattr(ns, k, CallbackSet(v, callback_func=ns.mark_change)) - + else: if not isinstance(k, str): raise TypeError(f"Attribute name must be a string, but got `{type(k).__name__}`") - + setattr(ns, k, v) - + return ns -def namespace_to_dict(ns, safe_resolve=False) -> Dict[str, CFG_CONTENT_TYPE]: + +def namespace_to_dict(ns: FlagNameSpace, safe_resolve: bool = False) -> Dict[str, CFG_CONTENT_TYPE]: """ Recursively converts a FlagNameSpace object to a dictionary. 
- """ - if not isinstance(ns, SimpleNamespace): - raise TypeError(f"Input must be an instance of SimpleNamespace, but got `{type(ns).__name__}`") + """ # noqa: DOC201, DOC501 + if not isinstance(ns, FlagNameSpace): + raise TypeError(f"Input must be an instance of FlagNameSpace, but got `{type(ns).__name__}`") - d:Dict[str, CFG_CONTENT_TYPE] = {} + d: Dict[str, CFG_CONTENT_TYPE] = {} for k, v in ns.data_dict.items(): # transform the unrepresent value to its name defined in corresponding Enum if k in UNSAFE_KV and safe_resolve: v = UNSAFE_KV[k](v).name - if isinstance(v, SimpleNamespace): + if isinstance(v, FlagNameSpace): d[k] = namespace_to_dict(v, safe_resolve=safe_resolve) - + elif isinstance(v, list): _list: List[CFG_CONTENT_TYPE] = [] for item in v: - if isinstance(item, SimpleNamespace): + if isinstance(item, FlagNameSpace): _list.append(namespace_to_dict(item, safe_resolve=safe_resolve)) else: _list.append(item) d[k] = _list - + else: d[k] = v - + return d -def get_config(config_path:Optional[str]=None) -> Config: - cfg_path = os.environ.get('TORCHMETER_CONFIG', config_path) - cfg = Config(config_path=cfg_path) # always exist an instance cause display.py and core.py depend on it + +def get_config(config_path: Optional[str] = None) -> Config: + cfg_path = os.environ.get("TORCHMETER_CONFIG", config_path) + cfg = Config(config_path=cfg_path) # always exist an instance cause display.py and core.py depend on it return cfg + class CallbackList(list): - def __init__(self, *args, callback_func=lambda:None, **kwargs) -> None: + def __init__(self, *args, callback_func: Callable = lambda: None, **kwargs) -> None: super().__init__(*args, **kwargs) - + self._callback_func = callback_func - self._register_callback("append", "extend", "insert", - "pop", "remove", "clear", - "reverse", "sort", - "__setitem__", "__delitem__", - "__iadd__", "__imul__") - + self._register_callback( + "append", + "extend", + "insert", + "pop", + "remove", + "clear", + "reverse", + "sort", + 
"__setitem__", + "__delitem__", + "__iadd__", + "__imul__", + ) + def _register_callback(self, *methods) -> None: for method_name in methods: orig_method = getattr(self.__class__, method_name) - def wrapped_method(*args, _method=orig_method, **kwargs): + + def wrapped_method(*args, _method: Callable = orig_method, **kwargs) -> Any: result = _method(*args, **kwargs) self._callback_func() return result - setattr(self.__class__, method_name, wrapped_method) + + setattr(self.__class__, method_name, wrapped_method) + class CallbackSet(set): - def __init__(self, *args, callback_func=lambda:None, **kwargs): + def __init__(self, *args, callback_func: Callable = lambda: None, **kwargs) -> None: super().__init__(*args, **kwargs) - + self._callback_func = callback_func - self._register_callback("add", "update", "difference_update", - "intersection_update", "symmetric_difference_update", - "discard", "pop", "remove", "clear", - "__isub__", "__iand__", - "__ixor__", "__ior__") - - def _register_callback(self, *methods): + self._register_callback( + "add", + "update", + "difference_update", + "intersection_update", + "symmetric_difference_update", + "discard", + "pop", + "remove", + "clear", + "__isub__", + "__iand__", + "__ixor__", + "__ior__", + ) + + def _register_callback(self, *methods) -> None: for method_name in methods: orig_method = getattr(self.__class__, method_name) - def wrapped_method(*args, _method=orig_method, **kwargs): + + def wrapped_method(*args, _method: Callable = orig_method, **kwargs) -> Any: result = _method(*args, **kwargs) self._callback_func() return result - setattr(self.__class__, method_name, wrapped_method) + + setattr(self.__class__, method_name, wrapped_method) + class FlagNameSpace(SimpleNamespace): - - __flag_key = '__FLAG' - + __flag_key = "__FLAG" + def __init__(self, **kwargs) -> None: - list(map(lambda x: setattr(self, x, kwargs[x]), kwargs)) + list(map(lambda x: setattr(self, x, kwargs[x]), kwargs)) self.mark_unchange() - - def 
__setattr__(self, key: str, value: Any) -> None:
+
+    def __setattr__(self, key: str, value: Any) -> None:
         if key in ("__flag_key", self.__flag_key):
-            raise AttributeError(f"`{key}` is preserved for internal use, " + \
-                                 "you should never try to set it to a new value.")
-
+            raise AttributeError(
+                f"`{key}` is preserved for internal use, " + "you should never try to set it to a new value."
+            )
+
         if isinstance(value, dict):
             value = dict_to_namespace(value)
         elif isinstance(value, list):
             value = list_to_callbacklist(value, callback_func=self.mark_change)
         elif isinstance(value, set):
             value = CallbackSet(value, callback_func=self.mark_change)
-
+
         super().__setattr__(key, value)
-
+
         self.mark_change()
-
-    def __delattr__(self, key):
+
+    def __delattr__(self, key: str) -> None:
         if key in ("__flag_key", self.__flag_key):
-            raise AttributeError(f"`{key}` is preserved for internal use, " + \
-                                 "you should never try to delete it.")
-
+            raise AttributeError(f"`{key}` is preserved for internal use, " + "you should never try to delete it.")
+
         super().__delattr__(key)
-
+
         self.mark_change()
-
+
     @property
-    def data_dict(self):
+    def data_dict(self) -> Dict[str, Any]:
         full_dict = self.__dict__.copy()
         del full_dict[self.__flag_key]
         return full_dict
-
-    def update(self, other:Union[dict, FlagNameSpace], *, replace:bool=False) -> None:
-        """`other` should keep a same hierarchy structure with `self`"""
-
+
+    def update(self, other: Union[dict, FlagNameSpace], *, replace: bool = False) -> None:
+        """`other` should keep a same hierarchy structure with `self`"""  # noqa: DOC501
+
         if not isinstance(other, (dict, FlagNameSpace)):
-            raise TypeError(f"Instance of `{self.__class__.__name__}` can only be updated with a dict or " + \
-                            f"another instance of `{self.__class__.__name__}`, but got `{type(other).__name__}`.")
-
+            raise TypeError(
+                f"Instance of `{self.__class__.__name__}` can only be updated with a dict or "
+                + f"another instance of `{self.__class__.__name__}`, but got `{type(other).__name__}`."
+            )
+
         if isinstance(other, dict):
             other = dict_to_namespace(other)
-
+
         if replace:
             replace_data = other.data_dict
             self.__dict__.update(replace_data)
-
+
             del_keys = set(self.data_dict.keys()) - set(replace_data.keys())
             list(map(lambda k: delattr(self, k), del_keys))
-
+
             self.mark_change()
-            return
-
-        for k, v in other.data_dict.items():
-
+            return
+
+        for k, v in other.data_dict.items():
             if k not in self.__dict__:
                 if isinstance(v, dict):
                     v = dict_to_namespace(v)
                 setattr(self, k, v)
                 continue
-
-            # if the value is a dict or a FlagNameSpace,
+
+            # if the value is a dict or a FlagNameSpace,
             # update the orgin namespace (origin must be a FlagNameSpace)
             origin_val_type = type(self.__dict__[k]).__name__
             new_val_type = type(v).__name__
             if origin_val_type == "FlagNameSpace":
                 if new_val_type not in ["dict", "FlagNameSpace"]:
-                    raise RuntimeError(f"Operation aborted: the origin value of `{k}` is of type " + \
-                                       "`FlagNameSpace` which has a inner structure, " + \
-                                       f"set to `{new_val_type}` will destroy it.")
+                    raise RuntimeError(
+                        f"Operation aborted: the origin value of `{k}` is of type "
+                        + "`FlagNameSpace` which has a inner structure, "
+                        + f"set to `{new_val_type}` will destroy it."
+                    )
                 else:
                     self.__dict__[k].update(v)
             else:
                 setattr(self, k, v)
-
+
     def is_change(self) -> bool:
-        res = getattr(self, self.__flag_key) or \
-              any(args.is_change() for args in self.__dict__.values()
-                  if isinstance(args, self.__class__))
+        res = getattr(self, self.__flag_key) or any(
+            args.is_change() for args in self.__dict__.values() if isinstance(args, self.__class__)
+        )
         self.__dict__[self.__flag_key] = res
         return res
-
+
     def mark_change(self) -> None:
         self.__dict__[self.__flag_key] = True
-
+
     def mark_unchange(self) -> None:
         self.__dict__[self.__flag_key] = False
-        list(map(lambda x: x.mark_unchange() if isinstance(x, self.__class__) else None,
-                 self.__dict__.values()))
-
+        list(map(lambda x: x.mark_unchange() if isinstance(x, self.__class__) else None, self.__dict__.values()))
+
+
 class ConfigMeta(type):
     """To achieve sigleton pattern"""
-
+
     __instances = None
     __thread_lock = Lock()

-    def __call__(cls, config_path:Optional[str]=None) -> Config:
+    def __call__(cls, config_path: Optional[str] = None) -> Config:
         with cls.__thread_lock:
             if cls.__instances is None:
                 instance = super().__call__(config_path)
@@ -381,10 +415,10 @@ def __call__(cls, config_path:Optional[str]=None) -> Config:
                 cls.__instances.config_file = config_path
         return cls.__instances

+
 class Config(metaclass=ConfigMeta):
-    """You can only read or write the predefined fields in the instance"""
-
+
     render_interval: float
     tree_fold_repeat: bool
     tree_repeat_block_args: FlagNameSpace
@@ -393,52 +427,54 @@ class Config(metaclass=ConfigMeta):
     table_display_args: FlagNameSpace
     combine: FlagNameSpace

-    __slots__ = DEFAULT_FIELDS + ['__cfg_file']
-
-    def __init__(self, config_path:Optional[str]=None) -> None:
+    __slots__ = [*DEFAULT_FIELDS, "__cfg_file"]
+
+    def __init__(self, config_path: Optional[str] = None) -> None:
         """Load default settings by default"""
         # init __cfg_file
         self.__cfg_file = None
-
+
         # set __cfg_file, load config file and check its integrity
         self.config_file = config_path
-
+
     @property
     def config_file(self) -> Optional[str]:
         return self.__cfg_file
-
+
     @config_file.setter
-    def config_file(self, file_path:Optional[str]=None) -> None:
+    def config_file(self, file_path: Optional[str] = None) -> None:
         if file_path is not None and not isinstance(file_path, str):
-            raise TypeError("You must pass in a string or None to change config or use the default config, " + \
-                            f"but got `{type(file_path).__name__}`.")
-
+            raise TypeError(
+                "You must pass in a string or None to change config or use the default config, "
+                + f"but got `{type(file_path).__name__}`."
+            )
+
         if file_path:
             file_path = os.path.abspath(file_path)
             if not os.path.isfile(file_path):
                 raise FileNotFoundError(f"Config file {file_path} does not exist.")
-            if not file_path.endswith('.yaml'):
+            if not file_path.endswith(".yaml"):
                 raise ValueError(f"Config file must be a yaml file, but got `{file_path}`")
-
+
         self.__cfg_file = file_path
         self.__load()
         self.check_integrity()
-
+
     def __load(self) -> None:
         if self.config_file is None:
             raw_data = yaml.safe_load(DEFAULT_CFG)
         else:
-            with open(self.config_file, 'r') as f:
+            with open(self.config_file, "r") as f:
                 raw_data = yaml.safe_load(f)
-
-        ns:FlagNameSpace = dict_to_namespace(raw_data)
-        for k,v in ns.data_dict.items():
+
+        ns: FlagNameSpace = dict_to_namespace(raw_data)
+        for k, v in ns.data_dict.items():
             is_reload = hasattr(self, k)
-
+
             if is_reload and isinstance(v, FlagNameSpace):
                 getattr(self, k).update(v, replace=True)
             else:
-                setattr(self, k, v)
+                setattr(self, k, v)

     def restore(self) -> None:
         self.__load()
@@ -448,39 +484,49 @@ def check_integrity(self) -> None:
         # no need to check integrity when loading default settings
         if self.config_file is None:
             return None
-
-        with open(self.config_file, 'r') as f:
+
+        with open(self.config_file, "r") as f:
             custom_cfg = yaml.safe_load(f)
-
+
         for field in DEFAULT_FIELDS:
             if field not in custom_cfg:
-                warnings.warn(message=f"Config file {self.config_file} does not contain '{field}' field, " + \
-                                      "using default settings instead.",
-                              category=UserWarning)
-
-    def asdict(self, safe_resolve=False) -> Dict[str, CFG_CONTENT_TYPE]:
-        d:Dict[str, CFG_CONTENT_TYPE] = {}
+                warnings.warn(
+                    category=UserWarning,
+                    message=f"Config file {self.config_file} does not contain '{field}' field, "
+                    + "using default settings instead.",
+                )
+
+    def asdict(self, safe_resolve: bool = False) -> Dict[str, CFG_CONTENT_TYPE]:
+        d: Dict[str, CFG_CONTENT_TYPE] = {}
+
         for field in DEFAULT_FIELDS:
             field_val = getattr(self, field)
-            if isinstance(field_val, SimpleNamespace):
+
+            if isinstance(field_val, FlagNameSpace):
                 d[field] = namespace_to_dict(field_val, safe_resolve=safe_resolve)
+
             elif isinstance(field_val, list):
-                d[field] = [namespace_to_dict(v, safe_resolve=safe_resolve) if isinstance(v, SimpleNamespace) else v
-                            for v in field_val]
+                d[field] = [
+                    namespace_to_dict(v, safe_resolve=safe_resolve) if isinstance(v, FlagNameSpace) else v
+                    for v in field_val
+                ]
+
             elif isinstance(field_val, dict):
-                d[field] = {k:namespace_to_dict(v, safe_resolve=safe_resolve) if isinstance(v, SimpleNamespace) else v
-                            for k,v in field_val.items()}
+                d[field] = {
+                    k: namespace_to_dict(v, safe_resolve=safe_resolve) if isinstance(v, FlagNameSpace) else v
+                    for k, v in field_val.items()
+                }
+
             else:
                 d[field] = field_val
+
         return d
-
-    def dump(self, save_path:str) -> None:
+
+    def dump(self, save_path: str) -> None:
         d = self.asdict(safe_resolve=True)
-        with open(save_path, 'w') as f:
-            yaml.safe_dump(d, f,
-                           indent=2, sort_keys=False,
-                           encoding='utf-8', allow_unicode=True)
+        with open(save_path, "w") as f:
+            yaml.safe_dump(d, f, indent=2, sort_keys=False, encoding="utf-8", allow_unicode=True)

     def __setattr__(self, name: str, value: Any) -> None:
         # the attribute is already exist
@@ -491,6 +537,7 @@ def __setattr__(self, name: str, value: Any) -> None:
                 origin_val.update(value)
             else:
                 super().__setattr__(name, value)
+
         # the first time to set the attribute
         except AttributeError:
             super().__setattr__(name, value)
@@ -502,35 +549,36 @@ def __delattr__(self, name: str) -> None:

     def __repr__(self) -> str:
         d = self.asdict(safe_resolve=True)

-        def simple_data_repr(val:Any) -> str:
-            val_repr = []
-
+        def simple_data_repr(val: Any) -> str:
+            val_repr = []
+
             if isinstance(val, dict):
                 val_repr.append("namespace{")
-                val_repr.extend([f"{k} = {simple_data_repr(v)}" for k,v in val.items()])
+                val_repr.extend([f"{k} = {simple_data_repr(v)}" for k, v in val.items()])
                 # val_repr[-1] += "}"
                 val_repr.append("}")
-
+
             elif isinstance(val, (tuple, list, set)):
                 val_repr.append(f"{type(val).__name__}(")
                 val_repr.extend([f"- {simple_data_repr(v)}" for v in val])
                 # val_repr[-1] += ")"
                 val_repr.append(")")
-
+
             else:
                 return f"{val} | <{type(val).__name__}>"

             return indent_str(val_repr, indent=4, process_first=False)
-
-        s = 'β€’ Config file: ' + (self.config_file if self.config_file else 'None(default setting below)') + "\n"*2
-        for field_name, field_vals in d.items():
-            s += f"β€’ {field_name}: " + simple_data_repr(field_vals) + "\n"*2
+
+        s = "β€’ Config file: " + (self.config_file if self.config_file else "None(default setting below)") + "\n" * 2
+        for field_name, field_vals in d.items():
+            s += f"β€’ {field_name}: " + simple_data_repr(field_vals) + "\n" * 2
         return s

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     default_cfg = Config()
     print(default_cfg)
     cfg1 = Config()
     cfg2 = Config()
     if id(cfg1) == id(cfg2):
-        print('Singleton Pattern Success.')
\ No newline at end of file
+        print("Singleton Pattern Success.")
diff --git a/torchmeter/core.py b/torchmeter/core.py
index 48cee06..8b319e5 100644
--- a/torchmeter/core.py
+++ b/torchmeter/core.py
@@ -1,28 +1,29 @@
 from __future__ import annotations
+
 from typing import TYPE_CHECKING

 import torch.nn as nn
 from rich import get_console
-from rich.columns import Columns
 from torch import Tensor
 from torch import device as tc_device
+from rich.columns import Columns

 from torchmeter.config import get_config
-from torchmeter.statistic import Statistics
 from torchmeter.display import render_perline
+from torchmeter.statistic import Statistics

 if TYPE_CHECKING:
     import sys
     from typing import Any, Dict, List, Tuple, Union, Optional
-
-    from rich.tree import Tree
+
+    from polars import DataFrame
     from rich.text import Text
+    from rich.tree import Tree
     from rich.table import Table
-    from polars import DataFrame

     from torchmeter.config import FlagNameSpace
-    from torchmeter.statistic import ParamsMeter, CalMeter, MemMeter, IttpMeter
-
+    from torchmeter.statistic import CalMeter, MemMeter, IttpMeter, ParamsMeter
+
     if sys.version_info >= (3, 8):
         from typing import TypedDict
     else:
@@ -32,64 +33,66 @@ class IPT_TYPE(TypedDict):
         args: Tuple[Any, ...]
         kwargs: Dict[str, Any]

+
 __all__ = ["Meter"]

 __cfg__ = get_config()

+
 class Meter:
     """A comprehensive instrumentation tool for PyTorch model performance analysis and visualization.
-
-    The `Meter` class provides end-to-end measurement capabilities for neural networks, including
-    parameter statistics, computational cost analysis, memory usage tracking, inference time and
-    throughput analysis. It serves as a wrapper around PyTorch modules while maintaining full compatibility
+
+    The `Meter` class provides end-to-end measurement capabilities for neural networks, including
+    parameter statistics, computational cost analysis, memory usage tracking, inference time and
+    throughput analysis. It serves as a wrapper around PyTorch modules while maintaining full compatibility
     with native model operations.

     Key Features: `easy-to-use`, `comprehensive`, and `flexible`
-    1. **Zero-Intrusion Proxy**
+    1. **Zero-Intrusion Proxy**
        - acts as drop-in decorator without any changes of the underlying model
        - Seamlessly integrates with PyTorch modules while preserving full compatibility (attributes and methods)
-    2. **Full-Stack Model Analytics**: Holistic performance analytics across 5 dimensions:
+    2. **Full-Stack Model Analytics**: Holistic performance analytics across 5 dimensions:
        - parameter distribution
        - calculation cost: FLOPs/MACs
        - memory access assessment
        - inference latency
        - throughput benchmarking
-
-    3. **Rich visualization**
-        - Programmable tabular reports with real-time rendering
+
+    3. **Rich visualization**
+        - Programmable tabular reports with real-time rendering
        - Hierarchical operation tree with smart folding of repeated blocks for model structure insights

-    4. **Fine-Grained Customization**
-        - Real-time hot-reload rendering: Dynamic adjustment of rendering configuration for operation trees,
+    4. **Fine-Grained Customization**
+        - Real-time hot-reload rendering: Dynamic adjustment of rendering configuration for operation trees,
          report tables and their nested components
        - Progressive update: Namespace assignment + dictionary batch update

-    5. **Config-Driven Runtime Management**
+    5. **Config-Driven Runtime Management**
        - Centralized control: Singleton-managed global configuration for dynamic behavior adjustment
        - Portable presets: Export/import YAML profiles for runtime behaviors, eliminating repetitive setup

-    6. **Portability and Practicality**
+    6. **Portability and Practicality**
        - Decoupled pipeline: Separation of data collection and visualization
        - Automatic device synchronization: Maintains production-ready status by keeping model and data co-located
-        - Dual-mode reporting with export flexibility:
+        - Dual-mode reporting with export flexibility:
            * Measurement units mode vs. raw data mode
            * Multi-format export (`CSV`/`Excel`) for analysis integration

     Core Functionality
-    1. Parameter Analysis
+    1. Parameter Analysis
        - Total/trainable parameter quantification
        - Layer-wise parameter distribution analysis
        - Gradient state tracking (requires_grad flags)
-
+
     2. Computational Profiling
        - FLOPs/MACs precision calculation
        - Operation-wise calculation distribution analysis
        - Dynamic input/output detection (number, type, shape, ...)
-
-    3. Memory Diagnostics
+
+    3. Memory Diagnostics
        - Input/output tensor memory awareness
        - Hierarchical memory consumption analysis
-
+
     4. Performance Benchmarking
        - Auto warm-up phase execution (eliminates cold-start bias)
        - Device-specific high-precision timing
@@ -105,7 +108,7 @@ class Meter:
        - Rich-text hierarchical structure tree rendering
            1. Style customization and real-time rendering
            2. Smart module folding based on structural equivalence detection
-
+
     6. Cross-Platform Support
        - Automatic model-data co-location
        - Seamless device transition (CPU/CUDA)
@@ -143,10 +146,10 @@ class Meter:
        stat_info: Generates a formatted summary of the specified statistics.
        overview: Generates an overview of all statistics in a formatted layout.
        rebase: Rebases the Meter instance to a specific node in the operation tree.
-
+
     Note:
        - Requires at least one forward pass before most measurements become available.
-        - Implements lazy evaluation and cache for most statistics (i.e. `param`, `cal`, `mem`).
+        - Implements lazy evaluation and cache for most statistics (i.e. `param`, `cal`, `mem`).

     Example:
        ```python
@@ -157,33 +160,31 @@ class Meter:
        model = models.resnet152()

        # wrap the model with Meter class
-        metered_model = Meter(model, device='cuda')
+        metered_model = Meter(model, device="cuda")

        # Basic usage
        input_tensor = torch.randn(1, 3, 224, 224).cuda()
        output = meter(input_tensor)  # Standard model execution
-
+
        # Performance analysis
        print(meter.structure)  # Visualize model hierarchy
-        print(meter.param)  # Show parameter statistics
-        meter.profile('cal')  # Display computational cost table
-
+        print(meter.param)  # Show parameter statistics
+        meter.profile("cal")  # Display computational cost table
+
        # Device management
-        meter.to('cpu')
+        meter.to("cpu")
        meter(input_tensor.cpu())
        ```
     """
-
-    def __init__(self,
-                 model: nn.Module,
-                 device:Optional[Union[str, tc_device]]=None) -> None:
+
+    def __init__(self, model: nn.Module, device: Optional[Union[str, tc_device]] = None) -> None:
        """Initialize a Meter instance for model performance measurement and visualization.

        Args:
            model (nn.Module): PyTorch model to be instrumented for measurement
            device (Optional[Union[str, torch.device]]): Target device for model execution and measurement.
-                                                         Accepts either device string (e.g., 'cuda:0') or
-                                                         torch.device object. If None, automatically detects
+                                                         Accepts either device string (e.g., 'cuda:0') or
+                                                         torch.device object. If None, automatically detects
                                                          model's current device via its parameters.

        Raises:
@@ -205,34 +206,34 @@ def __init__(self,
            - Resets measurement flags (`param`/`cal`/`mem`)
            - Sets default benchmark parameters (`ittp_warmup`=50, `ittp_benchmark_time`=100)
            - Initializes accuracy warning trackers (`_has_nocall_nodes`, `_has_not_support_nodes`)
-
+
        Example:
            ```python
            from torchmeter import Meter
            from torchvision import models
-
+
            model = models.resnet18()

            # auto detect device
            metered_model = Meter(model)

            # init a gpu model
-            metered_model = Meter(model, device='cuda')
-            metered_model = Meter(model, device='cuda:1')
+            metered_model = Meter(model, device="cuda")
+            metered_model = Meter(model, device="cuda:1")
            ```
        """
-
+
        from torchmeter.engine import OperationTree
        from torchmeter.display import TreeRenderer, TabularRenderer

        if not isinstance(model, nn.Module):
            raise TypeError(f"model must be a nn.Module, but got `{type(model).__name__}`.")
-
+
        device = device or self.__device_detect(model)
        self.__device = tc_device(device) if isinstance(device, str) else device
        self.model = model.to(self.__device)

-        self._ipt:IPT_TYPE = {'args':tuple(), 'kwargs':dict()}  # TODO: self.ipt_infer()
+        self._ipt: IPT_TYPE = {"args": tuple(), "kwargs": dict()}  # TODO: self.ipt_infer()

        self.optree = OperationTree(self.model)
@@ -245,8 +246,8 @@ def __init__(self,
        self.ittp_warmup = 50
        self.ittp_benchmark_time = 100

-        self.__has_nocall_nodes:Optional[bool] = None
-        self.__has_not_support_nodes:Optional[bool] = None
+        self.__has_nocall_nodes: Optional[bool] = None
+        self.__has_not_support_nodes: Optional[bool] = None

     def __call__(self, *args, **kwargs) -> Any:
        """Execute model inference while maintaining input and model device synchronization.
@@ -268,9 +269,9 @@ def __call__(self, *args, **kwargs) -> Any:
            (triggered by `_ipt2device()` method).

        Notes:
-            - From a macroscopic perspective, this is equivalent to direct model invocation:
+            - From a macroscopic perspective, this is equivalent to direct model invocation:
              `meter_instance(input)` == `model(input)`
-
+
            - You can safely input tensors from different devices; automatic synchronization is handled:
                - Moves all tensors in the input to current device via `_ipt2device()`
                - Ensures model is on current device before execution
@@ -278,51 +279,54 @@ def __call__(self, *args, **kwargs) -> Any:
            - Subsequent calls perform two key operations:
                1. Overwrite captured inputs, enabling `ipt` updates through normal model invocation
                2. Clear cached measurements when input differs (determined by `Meter.__is_ipt_changed()` rules)
-
-            - If there exists tensor data, its dimensions might directly impact the measurement results
-              of multiple statistics (e.g. `cal`, `mem`, `ittp`). For consistent and comparable results,
-              we recommend using **a single sample** for measuring all statistics. This can be achieved
-              by passing in a single batch of sample data whenever you want.
+
+            - If there exists tensor data, its dimensions might directly impact the measurement results
+              of multiple statistics (e.g. `cal`, `mem`, `ittp`). For consistent and comparable results,
+              we recommend using **a single sample** for measuring all statistics. This can be achieved
+              by passing in a single batch of sample data whenever you want.

        Example:
            ```python
            import torch
            import torch.nn as nn
            from torchmeter import Meter
-
+
+
            class MyModel(nn.Module):
                def __init__(self):
                    super(MyModel, self).__init__()
                    self.conv = nn.Conv2d(3, 10, 3)

+
                def forward(self, x, y=1):
-                    return self.conv(x) + y
+                    return self.conv(x) + y

+
            model = MyModel()
-            metered_model = Meter(model, device='cuda:0')
-
+            metered_model = Meter(model, device="cuda:0")
+
            # Standard invocation
            output = metered_model(torch.randn(1, 3, 224, 224))
-
+
            # Mixed argument types
            output = metered_model(torch.randn(1, 3, 224, 224), y=2)
            ```
        """
-        new_ipt:IPT_TYPE = {"args": args, "kwargs": kwargs}
+        new_ipt: IPT_TYPE = {"args": args, "kwargs": kwargs}
        if self.__is_ipt_changed(new_ipt):
            self.__measure_param = False
            self.__measure_cal = False
            self.__measure_mem = False
-
+
        self._ipt = new_ipt
        self._ipt2device()

        self.model.to(self.device)
-        return self.model(*self._ipt['args'], **self._ipt['kwargs'])
-
+        return self.model(*self._ipt["args"], **self._ipt["kwargs"])
+
     def __getattr__(self, name: str) -> Any:
        """Transparently proxy attribute access to the underlying model when not found in Meter instance

-        This method enables seamless attribute access to the wrapped model while maintaining Meter's
+        This method enables seamless attribute access to the wrapped model while maintaining Meter's
        own attributes. It follows these resolution rules:
        1. Directly returns Meter's own attributes if they exist
        2. For attributes prefixed with "ORIGIN_", returns the underlying model's attribute with the prefix removed
@@ -335,7 +339,7 @@ def __getattr__(self, name: str) -> Any:
            Any: The value of the requested attribute from either Meter instance or underlying model

        Raises:
-            AttributeError:
+            AttributeError:
                - When the attribute does not exist in both Meter instance and underlying model
                - When using "ORIGIN_" prefix with non-existent attribute in underlying model
@@ -346,7 +350,7 @@ def __getattr__(self, name: str) -> Any:
            - To bypass Meter's attributes and directly access model's attributes with same name:
              Use "ORIGIN_" prefix (e.g., `meter.ORIGIN_param` maps to `model.param`)

-            - This implementation ensures the Meter instance can be seamlessly used as a drop-in replacement
+            - This implementation ensures the Meter instance can be seamlessly used as a drop-in replacement
              for the underlying model without requiring code modifications
        """
@@ -356,10 +360,10 @@ def __getattr__(self, name: str) -> Any:
                name = name[7:]
                raise AttributeError
            return super().__getattribute__(name)
-
+
        except AttributeError:
            return getattr(self.model, name)
-
+
     def __setattr__(self, name: str, value: Any) -> None:
        """Prioritize setting attributes on Meter instance first, falling back to the underlying model.
@@ -377,16 +381,16 @@ def __setattr__(self, name: str, value: Any) -> None:
            - When attempting to set non-modifiable Meter class attributes.
            - When attribute assignment fails for both Meter instance and the underlying model

            For example, the attributes
            . All these attributes are
-
+
        Notes:
            - When encountering conflicting attribute names between Meter instance and the model:
              The Meter instance's attribute will be prioritized for assignment by default.
              To assign the underlying model's attribute with same name, prepend "ORIGIN_" prefix.
              Example: `meter_instance.ORIGIN_param = 1` will set model's `param` attribute to 1

-            - This implementation ensures the Meter instance can be seamlessly used as a drop-in replacement
+            - This implementation ensures the Meter instance can be seamlessly used as a drop-in replacement
              for the underlying model without requiring code modifications

            - Non-modifiable Meter class attributes are attributes defined by `@property` but without a setter.
@@ -400,25 +404,25 @@ def __setattr__(self, name: str, value: Any) -> None:
                7. `model_info`
                8. `subnodes`
        """
-
-        cls_attrs:Dict[str, bool] = self.__get_clsattr_with_settable_flag()
-        notchange_cls_attrs = [k for k,v in cls_attrs.items() if not v]
-
+
+        cls_attrs: Dict[str, bool] = self.__get_clsattr_with_settable_flag()
+        notchange_cls_attrs = [k for k, v in cls_attrs.items() if not v]
+
        if name in notchange_cls_attrs:
            raise AttributeError(f"`{name}` could never be set.")
-
+
        try:
            # set the property with same name defined in Meter from origin model
            if name.startswith("ORIGIN_"):
                name = name[7:]
                raise AttributeError
-
+
            super().__setattr__(name, value)
-
+
        except AttributeError:
            setattr(self.model, name, value)
-
-    def __delattr__(self, name:str) -> None:
+
+    def __delattr__(self, name: str) -> None:
        """Try to delete attributes from Meter instance first, fall back to underlying model if needed.

        This method ensures:
@@ -430,37 +434,37 @@ def __delattr__(self, name:str) -> None:
            name (str): Name of the attribute to delete

        Raises:
-            AttributeError:
+            AttributeError:
                - When trying to delete Meter's class attributes
                - When attempting to delete non-existent attributes
                - When failed to delete attribute from both Meter instance and the underlying model
-
+
        Notes:
            - When encountering conflicting attribute names between Meter instance and the model:
              The Meter instance's attribute will be prioritized for deletion by default.
              To delete the underlying model's attribute with same name, prepend "ORIGIN_" prefix.
              Example: `del meter_instance.ORIGIN_param` will delete model's `param` attribute

-            - This implementation ensures the Meter instance can be seamlessly used as a drop-in replacement
+            - This implementation ensures the Meter instance can be seamlessly used as a drop-in replacement
              for the underlying model without requiring code modifications
        """
-
-        cls_attrs:Dict[str, bool] = self.__get_clsattr_with_settable_flag()
-
+
+        cls_attrs: Dict[str, bool] = self.__get_clsattr_with_settable_flag()
+
        if name in cls_attrs:
            raise AttributeError(f"`{name}` could never be deleted.")
-
+
        try:
            # delete the property with same name defined in Meter from origin model
            if name.startswith("ORIGIN_"):
                name = name[7:]
                raise AttributeError
-
+
            super().__delattr__(name)
-
+
        except AttributeError:
            delattr(self.model, name)
-
+
     @property
     def ipt(self) -> IPT_TYPE:
        """Captured underlying model input dictionary.
@@ -475,17 +479,17 @@ def ipt(self) -> IPT_TYPE:
                - 'args' (tuple): Positional arguments passed to the `forward()` of the underlying model.
                - 'kwargs' (dict): Keyword arguments passed to the `forward()` of the underlying model

-            - Input can only be set/updated through `Meter` instance calls
+            - Input can only be set/updated through `Meter` instance calls
              (i.e., feed-forward inference of the origin model)
-
-            - If there exists tensor data, its dimensions might directly impact the measurement results
-              of multiple statistics (e.g. `cal`, `mem`, `ittp`). For consistent and comparable results,
-              we recommend using **a single sample** for measuring all statistics. This can be achieved
-              by providing a single-sample forward pass to the meter instance whenever you want.
+
+            - If there exists tensor data, its dimensions might directly impact the measurement results
+              of multiple statistics (e.g. `cal`, `mem`, `ittp`). For consistent and comparable results,
+              we recommend using **a single sample** for measuring all statistics. This can be achieved
+              by providing a single-sample forward pass to the meter instance whenever you want.
        """
        return self._ipt
-
+
     @property
     def device(self) -> tc_device:
        """The device where the model and all input tensors are currently located.
@@ -495,15 +499,15 @@ def device(self) -> tc_device:
        """
        return self.__device
-
+
     @device.setter
-    def device(self, new_device:Union[str, tc_device]) -> None:
+    def device(self, new_device: Union[str, tc_device]) -> None:
        """Moves the model and all tensors in captured input to the specified device.

        This setter updates the device for both the model and its input tensors (if available).

        Args:
-            new_device (Union[str, torch.device]): The target device, which can be a string (e.g., "cpu" or "cuda:0")
+            new_device (Union[str, torch.device]): The target device, which can be a string (e.g., "cpu" or "cuda:0")
                                                   or a torch.device object.

        Notes:
@@ -511,7 +515,7 @@ def device(self, new_device:Union[str, tc_device]) -> None:
            - If any tensors are present in `self._ipt`, they will also be moved to the new device.
            - Moves the model to the new device using `model.to()` in PyTorch.
        """
-
+
        self.__device = tc_device(new_device)
        self.model.to(self.__device)
        if not self._is_ipt_empty():
@@ -521,30 +525,30 @@ def tree_fold_repeat(self) -> bool:
        """Controls whether repeated tree blocks are rendered as collapsed panels.

-        This property directly binds to the `tree_fold_repeat` property in the global configuration.
-        When enabled, repeated operation blocks are collapsed into a single panel during tree rendering
+        This property directly binds to the `tree_fold_repeat` property in the global configuration.
+        When enabled, repeated operation blocks are collapsed into a single panel during tree rendering
        via `Meter.structure`.

        Returns:
            bool: True to collapse repeated blocks, False to expand them.
-
+
        Note:
            - Repeated blocks are identified only when two operations exhibit structural equivalence in:
-                1. Their own parameter signatures
+                1. Their own parameter signatures
                2. Their child operations' hierarchical parameters
                3. The execution order within the operation if it is a container.
-
-            - The folding feature activates exclusively for such validated repetitive patterns.
+
+            - The folding feature activates exclusively for such validated repetitive patterns.
              All other structures render sequentially following their topological order.

-            - If your model doesn't have the repeated blocks mentioned above (like `AlexNet`),
+            - If your model doesn't have the repeated blocks mentioned above (like `AlexNet`),
              setting this property True or False won't affect the output.
        """
-
+
        return __cfg__.tree_fold_repeat
-
+
     @tree_fold_repeat.setter
-    def tree_fold_repeat(self, enable:bool) -> None:
+    def tree_fold_repeat(self, enable: bool) -> None:
        """Control rendering of repeated tree blocks as a single collapsed panel.

        Args:
@@ -553,10 +557,10 @@ def tree_fold_repeat(self, enable:bool) -> None:
        Raises:
            TypeError: If value is not a boolean.

-        Notes:
-            This property is directly bound to the `tree_fold_repeat` property in the global configuration,
+        Notes:
+            This property is directly bound to the `tree_fold_repeat` property in the global configuration,
            so any change will be directly synchronized to the global settings.
-
+
        Example:
            ```python
            from rich import print
@@ -577,66 +581,69 @@ def tree_fold_repeat(self, enable:bool) -> None:
        """

        if not isinstance(enable, bool):
-            raise TypeError("The `tree_fold_repeat` property can only be rewritten with a boolean, " + \
-                            f"but got `{type(enable).__name__}`.")
+            raise TypeError(
+                "The `tree_fold_repeat` property can only be rewritten with a boolean, "
+                + f"but got `{type(enable).__name__}`."
+            )
+
        __cfg__.tree_fold_repeat = enable

     @property
     def tree_levels_args(self) -> FlagNameSpace:
        """Gets rendering configuration for various levels of rendered tree structure.

-        This property directly binds to `torchmeter.display.TreeRenderer.tree_levels_args`
-        to get rendering configuration (e.g., label, guide_style) for various levels of rendered
-        tree structure generated via `Meter.structure` property. The configuration persists across
+        This property directly binds to `torchmeter.display.TreeRenderer.tree_levels_args`
+        to get rendering configuration (e.g., label, guide_style) for various levels of rendered
+        tree structure generated via `Meter.structure` property. The configuration persists across
        all subsequent tree renderings until explicitly modified.

        Returns:
-            FlagNameSpace: A nested namespace where the outer-layer keys are the specific tree levels,
-                           and the values are the configuration namespaces for the corresponding levels.
-                           In each configuration namespace, the keys contain the specific configuration names,
+            FlagNameSpace: A nested namespace where the outer-layer keys are the specific tree levels,
+                           and the values are the configuration namespaces for the corresponding levels.
+                           In each configuration namespace, the keys contain the specific configuration names,
                           which match the valid parameters of `rich.tree.Tree`.
        """
        return self.tree_renderer.tree_levels_args
-
+
     @tree_levels_args.setter
-    def tree_levels_args(self, custom_args:Dict[str, Dict[str, Any]]) -> None:
+    def tree_levels_args(self, custom_args: Dict[str, Dict[str, Any]]) -> None:
        """Sets rendering configuration for various levels of rendered tree structure via a dictionary.

-        This property is bound to the `tree_levels_args` attribute of the internal `TreeRenderer` instance.
-        It allows users to batch configure the rendering configuration (e.g., label, guide_style) for tree
-        structure generated through the `Meter.structure` property. The provided dictionary
+        This property is bound to the `tree_levels_args` attribute of the internal `TreeRenderer` instance.
+        It allows users to batch configure the rendering configuration (e.g., label, guide_style) for tree
+        structure generated through the `Meter.structure` property. The provided dictionary maps configuration
        names to their values for fine-grained control over table rendering.

        Args:
-            custom_args (Dict[str, Dict[str, Any]]): A nested dictionary where the keys of the outer dictionary
-                                                     are tree level names (such as 0, 1, default), and the values
-                                                     are the inner configuration dictionaries for the corresponding
-                                                     levels. In the inner dictionary, the keys are the configuration
+            custom_args (Dict[str, Dict[str, Any]]): A nested dictionary where the keys of the outer dictionary
+                                                     are tree level names (such as 0, 1, default), and the values
+                                                     are the inner configuration dictionaries for the corresponding
+                                                     levels. In the inner dictionary, the keys are the configuration
                                                     names and the values are the corresponding configuration values.

        Raises:
-            UserWarning: If the input dictionary contains keys that are not valid level names, then the corresponding
+            UserWarning: If the input dictionary contains keys that are not valid level names, then the corresponding
                         configuration will be ignored.
-            TypeError: If the input is not a dictionary type.
+            TypeError: If the input is not a dictionary type.
            KeyError: If the input dictionary contains keys that are not valid arguments for `rich.tree.Tree`.
-
+
        Notes:
-            - Specified configurations will be updated, while unspecified ones remain unchanged. Therefore, users
+            - Specified configurations will be updated, while unspecified ones remain unchanged. Therefore, users
              can pass a partially configured dictionary.
-
+
            - If the keys in outer dictionary are invalid, then the configuration in its value will be ignored. Valid
              level names include (all are strings):
                1. non-negative integer: "0", "1", ...
                   Used to specify the configuration for the corresponding level
-                2. "default": The configuration applied when encountering a level with unspecified configuration during
+                2. "default": The configuration applied when encountering a level with unspecified configuration during
                   the rendering process.
                3. "all": The configuration will be used for all levels.
-
+
            - Supported configurations of inner configuration dictionary include:
                1. `label` (str): Node representation string, accept rich styling
                2. `guide_style` (str): Guide style of the node, execute `python -m rich.theme` to see more
                3. ... see more at https://rich.readthedocs.io/en/latest/reference/tree.html#rich.tree.Tree
-
+
        Example:
            ```python
            from torchmeter import Meter
@@ -649,51 +656,48 @@ def tree_levels_args(self, custom_args:Dict[str, Dict[str, Any]]) -> None:
            print(metered_model.tree_levels_args)

            # only update two configuration, other configuration remain unchanged.
-            metered_model.tree_levels_args = {
-                "default": {"guide_style": "red"},
-                "1": {"guide_style": "yellow"}
-            }
+            metered_model.tree_levels_args = {"default": {"guide_style": "red"}, "1": {"guide_style": "yellow"}}
            print(metered_model.tree_levels_args)
            ```
        """
-        self.tree_renderer.tree_levels_args = custom_args # type: ignore
+        self.tree_renderer.tree_levels_args = custom_args  # type: ignore

     @property
     def tree_repeat_block_args(self) -> FlagNameSpace:
        """Gets rendering configuration for repeated blocks of rendered tree structure.

-        This property directly binds to `torchmeter.display.TreeRenderer.repeat_block_args`
-        to get rendering configuration (e.g., style, highlight) for repeated blocks of rendered
-        tree structure generated via `Meter.structure` property. The configuration persists across
+        This property directly binds to `torchmeter.display.TreeRenderer.repeat_block_args`
+        to get rendering configuration (e.g., style, highlight) for repeated blocks of rendered
+        tree structure generated via `Meter.structure` property. The configuration persists across
        all subsequent tree renderings until explicitly modified.

        Returns:
-            FlagNameSpace: A namespace containing concrete configuration names.
+            FlagNameSpace: A namespace containing concrete configuration names.
                           Accessible keys match valid arguments of `rich.panel.Panel`.
        """
        return self.tree_renderer.repeat_block_args
-
+
     @tree_repeat_block_args.setter
-    def tree_repeat_block_args(self, custom_args:Dict[str, Any]) -> None:
+    def tree_repeat_block_args(self, custom_args: Dict[str, Any]) -> None:
        """Sets rendering configuration for repeated blocks of rendered tree structure via a dictionary.

-        This property is bound to the `repeat_block_args` attribute of the internal `TreeRenderer` instance.
-        It allows users to batch configure the rendering configuration (e.g., style, highlight) for tree
-        structure generated through the `Meter.structure` property. The provided dictionary maps configuration
+        This property is bound to the `repeat_block_args` attribute of the internal `TreeRenderer` instance.
+        It allows users to batch configure the rendering configuration (e.g., style, highlight) for tree
+        structure generated through the `Meter.structure` property. The provided dictionary maps configuration
        names to their values for fine-grained control over table rendering.

        Args:
-            custom_args (Dict[str, Any]): A dictionary where keys are configuration names and values are
-                                          the corresponding values to be set.
+            custom_args (Dict[str, Any]): A dictionary where keys are configuration names and values are
+                                          the corresponding values to be set.

        Raises:
            TypeError: If the input is not a dictionary type.
KeyError: If the input dictionary contains keys that are not valid arguments for `rich.panel.Panel`. Notes: - - Specified configurations will be updated, while unspecified ones remain unchanged. Therefore, users + - Specified configurations will be updated, while unspecified ones remain unchanged. Therefore, users can pass a partially configured dictionary. - Supported configuration include: @@ -716,49 +720,49 @@ def tree_repeat_block_args(self, custom_args:Dict[str, Any]) -> None: # only update two configuration, other configuration remain unchanged. metered_model.tree_repeat_block_args = { "title": "This block repeats for [[b][/b]] Times", - "title_align": "right" + "title_align": "right", } print(metered_model.tree_repeat_block_args) ``` """ - self.tree_renderer.repeat_block_args = custom_args # type: ignore + self.tree_renderer.repeat_block_args = custom_args # type: ignore @property def table_display_args(self) -> FlagNameSpace: """Gets comprehensive rendering configuration for rendered tables. - This property directly binds to `torchmeter.display.TabularRenderer.tb_args` - to get rendering configuration (e.g., style, highlight) for tables generated - via `Meter.profile()`. The configuration persists across all subsequent table + This property directly binds to `torchmeter.display.TabularRenderer.tb_args` + to get rendering configuration (e.g., style, highlight) for tables generated + via `Meter.profile()`. The configuration persists across all subsequent table renderings until explicitly modified. Returns: - FlagNameSpace: A namespace containing concrete configuration names. + FlagNameSpace: A namespace containing concrete configuration names. Accessible keys match valid arguments of `rich.table.Table`. 
""" return self.table_renderer.tb_args @table_display_args.setter - def table_display_args(self, custom_args:Dict[str, Any]) -> None: + def table_display_args(self, custom_args: Dict[str, Any]) -> None: """Sets comprehensive rendering configuration for rendered tables via a dictionary. - This property is bound to the `tb_args` attribute of the internal `TabularRenderer` instance. - It allows users to batch configure the comprehensive rendering configuration (e.g., style, highlight) - for tables generated through the `Meter.profile()` method with a dictionary. The provided dictionary maps + This property is bound to the `tb_args` attribute of the internal `TabularRenderer` instance. + It allows users to batch configure the comprehensive rendering configuration (e.g., style, highlight) + for tables generated through the `Meter.profile()` method with a dictionary. The provided dictionary maps configuration names to their values for fine-grained control over table rendering. Args: - custom_args (Dict[str, Any]): A dictionary where keys are configuration names and values are - the corresponding values to be set. + custom_args (Dict[str, Any]): A dictionary where keys are configuration names and values are + the corresponding values to be set. Raises: TypeError: If the input is not a dictionary type. KeyError: If the input dictionary contains keys that are not valid arguments for `rich.table.Table`. Notes: - - Specified configurations will be updated, while unspecified ones remain unchanged. Therefore, users + - Specified configurations will be updated, while unspecified ones remain unchanged. Therefore, users can pass a partially configured dictionary. - Supported configuration include: @@ -780,44 +784,41 @@ def table_display_args(self, custom_args:Dict[str, Any]) -> None: print(metered_model.table_display_args) # only update two configuration, other configuration remain unchanged. 
- metered_model.table_display_args = { - "style": "red", - "show_lines": True - } + metered_model.table_display_args = {"style": "red", "show_lines": True} print(metered_model.table_display_args) ``` """ - - self.table_renderer.tb_args = custom_args # type: ignore + + self.table_renderer.tb_args = custom_args # type: ignore @property def table_column_args(self) -> FlagNameSpace: """Gets column rendering configuration for rendered tables. - This property directly binds to `torchmeter.display.TabularRenderer.col_args` - to get column-level rendering configuration (e.g., style, justify) for tables - generated via `Meter.profile()`. The configuration persists across all subsequent + This property directly binds to `torchmeter.display.TabularRenderer.col_args` + to get column-level rendering configuration (e.g., style, justify) for tables + generated via `Meter.profile()`. The configuration persists across all subsequent table renderings until explicitly modified. Returns: - FlagNameSpace: A namespace containing concrete configuration names. + FlagNameSpace: A namespace containing concrete configuration names. Accessible keys match valid arguments of `rich.table.Column`. """ return self.table_renderer.col_args @table_column_args.setter - def table_column_args(self, custom_args:Dict[str, Any]) -> None: + def table_column_args(self, custom_args: Dict[str, Any]) -> None: """Sets column-level rendering configuration for rendered tables via a dictionary. - This property is bound to the `col_args` attribute of the internal `TabularRenderer` instance. - It allows users to batch configure column-specific rendering configuration (e.g., style, justify) - for tables generated through the `Meter.profile()` method. The provided dictionary maps configuration + This property is bound to the `col_args` attribute of the internal `TabularRenderer` instance. 
+ It allows users to batch configure column-specific rendering configuration (e.g., style, justify) + for tables generated through the `Meter.profile()` method. The provided dictionary maps configuration names to their values for fine-grained control over table rendering. Args: - custom_args (Dict[str, Any]): A dictionary where keys are configuration names and values are - the corresponding values to be set. + custom_args (Dict[str, Any]): A dictionary where keys are configuration names and values are + the corresponding values to be set. Raises: TypeError: If the input is not a dictionary type. @@ -826,7 +827,7 @@ def table_column_args(self, custom_args:Dict[str, Any]) -> None: Notes: - Configuration changes will be applied to **all** columns of the rendered table. - - Specified configurations will be updated, while unspecified ones remain unchanged. Therefore, users + - Specified configurations will be updated, while unspecified ones remain unchanged. Therefore, users can pass a partially configured dictionary. - Supported configuration include: @@ -847,40 +848,37 @@ def table_column_args(self, custom_args:Dict[str, Any]) -> None: print(metered_model.table_column_args) # only update two configuration, other configuration remain unchanged. - metered_model.table_column_args = { - "style": "bold green", - "justify": "left" - } + metered_model.table_column_args = {"style": "bold green", "justify": "left"} print(metered_model.table_column_args) ``` """ - - self.table_renderer.col_args = custom_args # type: ignore + + self.table_renderer.col_args = custom_args # type: ignore @property def structure(self) -> Tree: """Generate a stylized tree representation of the model's operation hierarchy. - This property renders the operation tree based on current configuration settings. - The rendering strategy (folded/unfolded) and customization options are determined - by the active configuration parameters. 
Caching is applied to optimize rendering + This property renders the operation tree based on current configuration settings. + The rendering strategy (folded/unfolded) and customization options are determined + by the active configuration parameters. Caching is applied to optimize rendering performance when configuration remains unchanged. Returns: Tree: A `rich.tree.Tree` object representing the hierarchical structure of model operations. - + Notes: - - Configuration parameters influence rendering behavior, you can access them directly + - Configuration parameters influence rendering behavior, you can access them directly by `metered_model.` - * `tree_fold_repeat`: Controls whether a repeated block is rendered as a single block. + * `tree_fold_repeat`: Controls whether a repeated block is rendered as a single block. Default to True. * `tree_levels_args`: Customizes rendering at different tree levels. * `tree_repeat_block_args`: Detailed parameters to control the rendering of the repeat blocks. - - - Caching mechanism: Reuses cached render result if all configuration parameters remain + + - Caching mechanism: Reuses cached render result if all configuration parameters remain unchanged since last render. - - - For information on repeated blocks identification and rendering, please refer to the + + - For information on repeated blocks identification and rendering, please refer to the description of `Meter.tree_fold_repeat` property. 
Example: @@ -888,44 +886,42 @@ def structure(self) -> Tree: from rich import print from torchmeter import Meter from torchvision import models - + model = models.vit_b_16() metered_model = Meter(model) - + # use the default configuration print(metered_model.structure) - + # reaccess the structure, will be quickly returned print(metered_model.structure) - + # use a custom configuration metered_model.tree_fold_repeat = False - metered_model.tree_levels_args = { - "default": {"guide_style": "red"} - } + metered_model.tree_levels_args = {"default": {"guide_style": "red"}} print(metered_model.structure) ``` """ fold_repeat = __cfg__.tree_fold_repeat - + is_rpbk_change = __cfg__.tree_repeat_block_args.is_change() - + is_level_change = __cfg__.tree_levels_args.is_change() - + if fold_repeat: cache_res = self.tree_renderer.render_fold_tree if not is_rpbk_change else None else: cache_res = self.tree_renderer.render_unfold_tree cache_res = cache_res if not is_level_change else None - + rendered_tree = self.tree_renderer() if cache_res is None else cache_res - + if is_rpbk_change and fold_repeat: __cfg__.tree_repeat_block_args.mark_unchange() if is_level_change: __cfg__.tree_levels_args.mark_unchange() - + # render_perline(renderable=rendered_tree) return rendered_tree @@ -933,28 +929,28 @@ def structure(self) -> Tree: def param(self) -> ParamsMeter: """Measures the number of model parameters. - This property calculates the parameter-related metrics (e.g., number of parameters, - trainable parameters) for each node in the operation tree. + This property calculates the parameter-related metrics (e.g., number of parameters, + trainable parameters) for each node in the operation tree. Returns: ParamsMeter: A ParamsMeter instance containing the measured parameter-related statistics. Notes: - The measurement is performed only once for each Meter instance. Subsequent accesses + The measurement is performed only once for each Meter instance. 
Subsequent accesses will return the cached result. """ - + if not self.__measure_param: list(map(lambda node: node.param.measure(), self.optree.all_nodes)) self.__measure_param = True return self.optree.root.param - + @property def cal(self) -> CalMeter: """Measures the calculation cost of the model during inference. - This property calculates the computational cost (i.e., FLOPs and MACs) for each node in the + This property calculates the computational cost (i.e., FLOPs and MACs) for each node in the operation tree during a feed-forward inference pass. Returns: @@ -965,39 +961,42 @@ def cal(self) -> CalMeter: Notes: - You must first invoke the Meter instance (via a forward pass) before accessing this property. - - - The measurement is performed only once for each Meter instance. Subsequent accesses + + - The measurement is performed only once for each Meter instance. Subsequent accesses will return the cached result. - - The measurement results depend on the model input, and different input tensor sizes - will lead to varying calculation costs, which is **normal**. For consistent and comparable + - The measurement results depend on the model input, and different input tensor sizes + will lead to varying calculation costs, which is **normal**. For consistent and comparable results, we recommend using **a single sample** for measuring all statistics including `cal`. - This can be achieved by providing a single-sample forward pass to the meter instance whenever + This can be achieved by providing a single-sample forward pass to the meter instance whenever you want. """ - + if not self.__measure_cal: if self._is_ipt_empty(): - raise RuntimeError("Input unknown! You should perform at least one feed-forward inference before measuring calculation!") + raise RuntimeError( + "Input unknown! " + + "You should perform at least one feed-forward inference before measuring calculation!" 
+                )

            hook_ls = [node.cal.measure() for node in self.optree.all_nodes]

            # feed forward
            self._ipt2device()
-            self.model(*self.ipt['args'], **self.ipt['kwargs'])
+            self.model(*self.ipt["args"], **self.ipt["kwargs"])

            # remove hooks after measurement
-            list(map(lambda x:x.remove() if x is not None else None, hook_ls))
+            list(map(lambda x: x.remove() if x is not None else None, hook_ls))

            self.__measure_cal = True
-
+
        return self.optree.root.cal

    @property
    def mem(self) -> MemMeter:
        """Measures the memory cost of the model during inference.

-        This property calculates the memory usage for each node in the operation tree during a
+        This property calculates the memory usage for each node in the operation tree during a
        feed-forward inference pass.

        Returns:
@@ -1008,30 +1007,32 @@ def mem(self) -> MemMeter:

        Notes:
            - You must first invoke the Meter instance (via a forward pass) before accessing this property.
-
-            - The measurement is performed only once for each Meter instance. Subsequent accesses
+
+            - The measurement is performed only once for each Meter instance. Subsequent accesses
              will return the cached result.
-
-            - The measurement results depend on the model input, and different input tensor sizes
-              will lead to varying memory costs, which is **normal**. For consistent and comparable
-              results, we recommend using **a single sample** for measuring all statistics including
-              `mem`. This can be achieved by providing a single-sample forward pass to the meter instance
+
+            - The measurement results depend on the model input, and different input tensor sizes
+              will lead to varying memory costs, which is **normal**. For consistent and comparable
+              results, we recommend using **a single sample** for measuring all statistics including
+              `mem`. This can be achieved by providing a single-sample forward pass to the meter instance
              whenever you want.
        """
-
+
        if not self.__measure_mem:
            if self._is_ipt_empty():
-                raise RuntimeError("Input unknown! 
You should perform at least one feed-forward inference " + \ - "before measuring the memory cost!") + raise RuntimeError( + "Input unknown! You should perform at least one feed-forward inference " + + "before measuring the memory cost!" + ) hook_ls = [node.mem.measure() for node in self.optree.all_nodes] # feed forward self._ipt2device() - self.model(*self.ipt['args'], **self.ipt['kwargs']) + self.model(*self.ipt["args"], **self.ipt["kwargs"]) # remove hooks after measurement - list(map(lambda x:x.remove() if x is not None else None, hook_ls)) + list(map(lambda x: x.remove() if x is not None else None, hook_ls)) self.__measure_mem = True @@ -1041,8 +1042,8 @@ def mem(self) -> MemMeter: def ittp(self) -> IttpMeter: """Measures the inference time and throughput of the model. - This property calculates the inference time and throughput for each node in the operation tree. - It performs a warm-up phase followed by a benchmark phase to ensure accurate measurements. + This property calculates the inference time and throughput for each node in the operation tree. + It performs a warm-up phase followed by a benchmark phase to ensure accurate measurements. The results are returned as an `IttpMeter` object. Returns: @@ -1055,53 +1056,58 @@ def ittp(self) -> IttpMeter: Notes: - You must first invoke the Meter instance (via a forward pass) before accessing this property. - + - The measurements are performed on the device specified by `meter_instance.device` !!! - + - The unit `IPS` means **Input Per Second**, which is the number of inferences with given input - per second. - + per second. + - Unlike other statistics, the measured result is **not** cached, so it will be re-measured every time `ittp` attribute is accessed. - + - The warm-up phase runs for `meter_instance.ittp_warmup` iterations to stabilize the measurements. - + - The benchmark phase runs for `meter_instance.ittp_benchmark_time` iterations per operation. 
- - - The measurement results depend on the model input, and different input tensor sizes will lead to - varying latencies and throughput, which is **normal**. For consistent and comparable results, we - recommend using **a single sample** for measuring all statistics including `ittp`. This can be + + - The measurement results depend on the model input, and different input tensor sizes will lead to + varying latencies and throughput, which is **normal**. For consistent and comparable results, we + recommend using **a single sample** for measuring all statistics including `ittp`. This can be achieved by providing a single-sample forward pass to the meter instance whenever you want. """ - + from tqdm import tqdm if self._is_ipt_empty(): - raise RuntimeError("Input unknown! " + \ - "You should perform at least one feed-forward inference before measuring the inference time or throughput!") + raise RuntimeError( + "Input unknown! " + + "You should perform at least one feed-forward inference " + + "before measuring the inference time or throughput!" 
+            )

        if not isinstance(self.ittp_warmup, int):
            raise TypeError(f"ittp_warmup must be an integer, but got `{type(self.ittp_warmup).__name__}`")

        if self.ittp_warmup < 0:
            raise ValueError(f"ittp_warmup must be greater than or equal to 0, but got `{self.ittp_warmup}`.")
-
+
        self._ipt2device()

-        for i in tqdm(range(self.ittp_warmup), desc='Warming Up'):
-            self.model(*self.ipt['args'], **self.ipt['kwargs'])
+        for i in tqdm(range(self.ittp_warmup), desc="Warming Up"):
+            self.model(*self.ipt["args"], **self.ipt["kwargs"])

-        pb = tqdm(total=self.ittp_benchmark_time*len(self.optree.all_nodes),
-                  desc='Benchmark Inference Time & Throughput',
-                  unit='module')
-        hook_ls = [node.ittp.measure(device=self.device,
-                                     repeat=self.ittp_benchmark_time,
-                                     global_process=pb)
-                   for node in self.optree.all_nodes]
+        pb = tqdm(
+            total=self.ittp_benchmark_time * len(self.optree.all_nodes),
+            desc="Benchmark Inference Time & Throughput",
+            unit="module",
+        )
+        hook_ls = [
+            node.ittp.measure(device=self.device, repeat=self.ittp_benchmark_time, global_process=pb)
+            for node in self.optree.all_nodes
+        ]

        # feed forward
-        self.model(*self.ipt['args'], **self.ipt['kwargs'])
+        self.model(*self.ipt["args"], **self.ipt["kwargs"])

        # remove hooks after measurement
-        list(map(lambda x:x.remove() if x is not None else None, hook_ls))
+        list(map(lambda x: x.remove() if x is not None else None, hook_ls))

        del pb

@@ -1111,41 +1117,42 @@ def ittp(self) -> IttpMeter:
    def model_info(self) -> Text:
        """Generates a formatted summary of the model's basic information.

-        This property provides a detailed summary of the model, including its name, device,
+        This property provides a detailed summary of the model, including its name, device,
        forward method signature, and structured input representation.

        Returns:
            Text: A `rich.Text` object containing the formatted model information.
Notes: - - If no input has been provided (i.e., `self._ipt` is empty), the input representation will + - If no input has been provided (i.e., `self._ipt` is empty), the input representation will indicate that it is not provided. - - - Otherwise, all the values in `self._ipt` will correspond to the formal arguments of the - `forward` method, and a structured input representation with type prompts will be generated + + - Otherwise, all the values in `self._ipt` will correspond to the formal arguments of the + `forward` method, and a structured input representation with type prompts will be generated through the `torchmeter.utils.data_repr` function. """ - + from inspect import signature - from torchmeter.utils import indent_str, data_repr - forward_args:List[str] = list(signature(self.model.forward).parameters.keys()) + from torchmeter.utils import data_repr, indent_str + + forward_args: List[str] = list(signature(self.model.forward).parameters.keys()) if self._is_ipt_empty(): ipt_repr = "[dim]Not Provided\n(give an inference first)[/]" else: - ipt_dict = {forward_args[args_idx]: anony_ipt for args_idx, anony_ipt in enumerate(self.ipt['args'])} - ipt_dict.update(self.ipt['kwargs']) - ipt_repr_ls = [f"{args_name} = {data_repr(args_val)}" for args_name, args_val in ipt_dict.items()] - ipt_repr = ',\n'.join(ipt_repr_ls) + ipt_dict = {forward_args[args_idx]: anony_ipt for args_idx, anony_ipt in enumerate(self.ipt["args"])} + ipt_dict.update(self.ipt["kwargs"]) + ipt_repr_ls = [f"{args_name} = {data_repr(args_val)}" for args_name, args_val in ipt_dict.items()] + ipt_repr = ",\n".join(ipt_repr_ls) - forward_args = ["self"] + forward_args - infos = '\n'.join([ + forward_args = ["self", *forward_args] + infos = "\n".join([ f"β€’ [b]Model :[/b] {self.optree.root.name}", f"β€’ [b]Device :[/b] {self.device}", f"β€’ [b]Signature:[/b] forward({', '.join(forward_args)})", - f"β€’ [b]Input :[/b] \n{indent_str(ipt_repr, indent=3, guideline=False)}" + f"β€’ [b]Input :[/b] 
\n{indent_str(ipt_repr, indent=3, guideline=False)}", ]) - + console = get_console() return console.render_str(infos) @@ -1153,44 +1160,44 @@ def model_info(self) -> Text: def subnodes(self) -> List[str]: """Retrieves a list of all nodes in the operation tree with their IDs and names. - This property returns a formatted list of all nodes in the operation tree, where each node is - represented by its ID and name. This is useful for identifying specific nodes when rebasing or + This property returns a formatted list of all nodes in the operation tree, where each node is + represented by its ID and name. This is useful for identifying specific nodes when rebasing or inspecting the tree structure. Returns: - List[str]: A list of strings, each formatted as `(node_id) node_name`, representing all nodes + List[str]: A list of strings, each formatted as `(node_id) node_name`, representing all nodes in the operation tree. """ return [f"({node.node_id}) {node.name}" for node in self.optree.all_nodes] - def to(self, new_device:Union[str, tc_device]) -> None: + def to(self, new_device: Union[str, tc_device]) -> None: """Move the model to the specified device while keeping input and model device synchronization. - Simulate the `to` method of pytorch model and use it to move model and all tensor data in + Simulate the `to` method of pytorch model and use it to move model and all tensor data in `self._ipt` to the specified device. Args: new_device (Union[str, torch.device]): Target device name or its corresponding torch.device object. 
- + Example: ```python import torch from torchmeter import Meter from torchvision import models - + model = models.resnet18() metered_model = Meter(model) - + # move to cuda:0 metered_model.to("cuda:0") - + # move to cpu metered_model.to(torch.device("cpu")) ``` """ - self.device = new_device # type: ignore + self.device = new_device # type: ignore - def rebase(self, node_id:str) -> Meter: + def rebase(self, node_id: str) -> Meter: """Rebases the Meter instance to a specific node in the operation tree. This method allows the Meter instance to focus on a specific node in the operation tree, @@ -1210,28 +1217,28 @@ def rebase(self, node_id:str) -> Meter: Notes: - Use `Meter(your_model).subnodes` to retrieve a list of valid node IDs. - If `node_id` is "0", the original Meter instance is returned without modification. - + Example: ```python from torchmeter import Meter from torchvision import models - + model = models.resnet18() metered_model = Meter(model) rebased_model = metered_model.rebase("5") - - print(metered_model) # Meter(model=0 ResNet: ResNet, device=cpu) - print(rebased_model) # Meter(model=0 Sequential: Sequential, device=cpu) + + print(metered_model) # Meter(model=0 ResNet: ResNet, device=cpu) + print(rebased_model) # Meter(model=0 Sequential: Sequential, device=cpu) ``` """ - + if not isinstance(node_id, str): raise TypeError(f"node_id must be a string, but got `{type(node_id).__name__}`.") - + if node_id == "0": return self - - id_generator = ( (node_idx, node.node_id) for node_idx, node in enumerate(self.optree.all_nodes) ) + + id_generator = ((node_idx, node.node_id) for node_idx, node in enumerate(self.optree.all_nodes)) for idx, valid_id in id_generator: if node_id == valid_id: @@ -1240,13 +1247,13 @@ def rebase(self, node_id:str) -> Meter: else: raise ValueError(f"Invalid node_id: {node_id}. 
Use `Meter(your_model).subnodes` to check valid ones.")

-    def stat_info(self, stat_or_statname:Union[str, Statistics], *, show_warning:bool=True) -> Text:
+    def stat_info(self, stat_or_statname: Union[str, Statistics], *, show_warning: bool = True) -> Text:  # noqa: C901
        """Generates a formatted summary of the specified statistics.

        This method provides a summary of the given statistics, including its name and the crucial data
-        about this statistics. However, sometimes there may exist some modules which is defined but not
-        explicitly called, or some modules that its calculation measurement logic is not defined in this
-        version. To prevent confusing user, we will show inaccuracies warnings in the summary. If you don't
+        about this statistic. However, sometimes there may be modules that are defined but never
+        explicitly called, or modules whose calculation measurement logic is not defined in this
+        version. To avoid confusing users, we show inaccuracy warnings in the summary. If you don't
        want to see these warnings, you can set `show_warning` to `False` manually.

        Args:
@@ -1260,14 +1267,14 @@ def stat_info(self, stat_or_statname:Union[str, Statistics], *, show_warning:boo
            TypeError: If `stat_or_statname` is neither a string nor a `Statistics` object.

        Notes:
-            - The main content will be obtained from the `crucial_data` property of the statistics object, which is defined
-              in the corresponding statistics class.
-
-            - For `ittp`, the number of repeated measurements, namely `Benchmark Times`, will be additionally displayed.
-              This value can be accessed or modified through the `ittp_benchmark_time' attribute.
-
+            - The main content will be obtained from the `crucial_data` property of the statistics object, which is
+              defined in the corresponding statistics class.
+
+            - For `ittp`, the number of repeated measurements, namely `Benchmark Times`, will be additionally
+              displayed. 
This value can be accessed or modified through the `ittp_benchmark_time` attribute.
+
            - The `show_warning` option is a keyword-only argument, so you should use it through its keyword name.
-
+
            - Warnings are only shown for the following two statistics: calculation (`cal`) and memory (`mem`).
              Because only these two statistics are affected by the uncalled modules or the unsupported modules.

        Example:
            ```python
            from torch import randn
            from torchmeter import Meter
            from torchvision import models
-
+
            from rich import print
-
+
            model = models.vit_b_16()
            metered_model = Meter(model)
            metered_model(randn(1, 3, 224, 224))
-
+
            # using statistics name
            print(metered_model.stat_info("param"))
-
+
            # using statistics object
            cal = metered_model.cal
            print(metered_model.stat_info(cal))
-
-            # not show warnings
+
+            # don't show warnings
            print(metered_model.stat_info("mem", show_warning=False))
            ```
        """
-
+
        if isinstance(stat_or_statname, str):
            stat = getattr(self, stat_or_statname)
        elif isinstance(stat_or_statname, Statistics):
            stat = stat_or_statname
        else:
-            raise TypeError(f"Invalid type for stat_or_statname: `{type(stat_or_statname).__name__}`. " + \
-                "Please pass in the statistics name or the statistics object itself.")
+            raise TypeError(
+                f"Invalid type for stat_or_statname: `{type(stat_or_statname).__name__}`. "
+                + "Please pass in the statistics name or the statistics object itself." 
+ ) stat_name = stat.name - infos_ls:List[str] = [f"β€’ [b]Statistics:[/b] {stat_name}"] - - if stat_name == 'ittp': + infos_ls: List[str] = [f"β€’ [b]Statistics:[/b] {stat_name}"] + + if stat_name == "ittp": infos_ls.append(f"β€’ [b]Benchmark Times:[/b] {self.ittp_benchmark_time}") - - infos_ls.extend([ - f"β€’ [b]{k}:[/b] {v}" for k, v in stat.crucial_data.items() - ]) - - ## warning field, only works when stat is "cal" or "mem" - if show_warning and stat_name not in ("param", "ittp"): + + infos_ls.extend([f"β€’ [b]{k}:[/b] {v}" for k, v in stat.crucial_data.items()]) + + # warning field, only works when stat is "cal" or "mem" + if show_warning and stat_name in ("cal", "mem"): # cache for __has_nocall_nodes if self.__has_nocall_nodes is None: from operator import attrgetter - + crucial_data_getter = attrgetter(f"{stat_name}.crucial_data") try: list(map(crucial_data_getter, self.optree.all_nodes)) self.__has_nocall_nodes = False except RuntimeError: - self.__has_nocall_nodes = True - + self.__has_nocall_nodes = True + # cache for __has_not_support_nodes if stat_name == "cal" and self.__has_not_support_nodes is None: - self.__has_not_support_nodes = any(n.cal.is_not_supported - for n in self.optree.all_nodes) - + self.__has_not_support_nodes = any(n.cal.is_not_supported for n in self.optree.all_nodes) + warns_ls = [] if self.__has_nocall_nodes: - warns_ls.append(" "*2 + "[dim yellow]:arrow_forward: Some nodes are defined but not called explicitly.[/]") + warns_ls.append( + " " * 2 + "[dim yellow]:arrow_forward: " + "Some nodes are defined but not called explicitly.[/]" + ) + if stat_name == "cal" and self.__has_not_support_nodes: - warns_ls.append(" "*2 + "[dim yellow]:arrow_forward: Some modules don't support calculation measurement yet.[/]") + warns_ls.append( + " " * 2 + + "[dim yellow]:arrow_forward: " + + "Some modules don't support calculation measurement yet.[/]" + ) + if warns_ls: warns_ls.insert(0, "[dim yellow]:warning: Warning: the result may be 
inaccurate because:[/]") - warns_ls.append(" "*2 + f"[dim cyan]:ballot_box_with_check: use `Meter(your_model).profile('{stat_name}')` to see more.[/]") - + warns_ls.append( + " " * 2 + + "[dim cyan]:ballot_box_with_check: " + + f"use `Meter(your_model).profile('{stat_name}')` to see more.[/]" + ) + infos_ls.extend(warns_ls) - - infos = '\n'.join(infos_ls) - + + infos = "\n".join(infos_ls) + console = get_console() return console.render_str(infos) - def overview(self, *order:str, show_warning:bool=True) -> Columns: + def overview(self, *order: str, show_warning: bool = True) -> Columns: """Generates an overview of all statistics in a formatted layout. - This method creates a visual overview of model statistics, including basic model - information and core data of each specified statistic. You can customize the statistics - contained in the rendering results and their order by passing in the statistics you want + This method creates a visual overview of model statistics, including basic model + information and core data of each specified statistic. You can choose which statistics + appear in the rendering results by passing in the ones you want + in the order you prefer. Args:
param, cal, mem, ittp) - metered_model.overview() - + # only overview `cal` and `param` # and the order is `cal` then `param` - metered_model.overview("cal", "param") + metered_model.overview("cal", "param") ``` """ - + from functools import partial - from rich.panel import Panel + from rich.box import HORIZONTALS + from rich.panel import Panel order = order or self.optree.root.statistics - + invalid_stat = tuple(filter(lambda x: x not in self.optree.root.statistics, order)) if len(invalid_stat) > 0: raise ValueError(f"Invalid statistics: {invalid_stat}") - - container = Columns(expand=True, align='center') + + container = Columns(expand=True, align="center") format_cell = partial(Panel, safe_box=True, expand=False, highlight=True, box=HORIZONTALS) - - container.add_renderable(format_cell(self.model_info, title='[b]Model INFO[/]', border_style='orange1')) - container.renderables.extend([format_cell(self.stat_info(stat_name, show_warning=show_warning), - title=f"[b]{stat_name.capitalize()} INFO[/]", - border_style='cyan') - for stat_name in order]) - + + container.add_renderable(format_cell(self.model_info, title="[b]Model INFO[/]", border_style="orange1")) + container.renderables.extend([ + format_cell( + self.stat_info(stat_name, show_warning=show_warning), + title=f"[b]{stat_name.capitalize()} INFO[/]", + border_style="cyan", + ) + for stat_name in order + ]) + return container - def table_cols(self, stat_name:str) -> Tuple[str, ...]: + def table_cols(self, stat_name: str) -> Tuple[str, ...]: """Get all column names of the backend dataframe for the specified statistics. - - This method returns the column names of the backend dataframe associated with the given statistics. + + This method returns the column names of the backend dataframe associated with the given statistics. If the dataframe is empty (i.e. `profile` has not been called yet), it falls back to the values of the `tb_fields` property of the corresponding statistics class.
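The empty-dataframe fallback that `table_cols` describes can be sketched in plain Python. This is an illustrative stand-in, not torchmeter's actual implementation: `FakeStat` and `FakeRenderer` are hypothetical names that mimic the roles of a statistics object (with default `tb_fields`) and the table-renderer backend (with per-statistics column data).

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class FakeStat:
    # stand-in for a statistics object: `tb_fields` is the default column
    # layout used before any profiling data exists
    tb_fields: Tuple[str, ...] = ("Operation_Id", "Operation_Name", "Operation_Type")


@dataclass
class FakeRenderer:
    # stand-in for the table-renderer backend: one column list per statistics,
    # empty until a profile has been generated
    stats_data: Dict[str, List[str]] = field(default_factory=lambda: {"param": []})


def table_cols(stat_name: str, renderer: FakeRenderer, stat: FakeStat) -> Tuple[str, ...]:
    if not isinstance(stat_name, str):
        raise TypeError(f"stat_name must be a string, but got `{type(stat_name).__name__}`.")
    if stat_name not in renderer.stats_data:
        raise KeyError(f"Statistics `{stat_name}` not in {tuple(renderer.stats_data.keys())}.")
    cols = renderer.stats_data[stat_name]
    # empty backend data -> fall back to the statically declared fields
    return stat.tb_fields if not cols else tuple(cols)


print(table_cols("param", FakeRenderer(), FakeStat()))
```

Once the backend holds real column data (i.e. after profiling), the same call returns those columns instead of the defaults.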
@@ -1423,29 +1446,29 @@ def table_cols(self, stat_name:str) -> Tuple[str, ...]: Raises: TypeError: If `stat_name` is not a string. KeyError: If `stat_name` is not found in the available statistics (i.e. `param`, `cal`, `mem`, `ittp`). - + Notes: default column names for each statistics: - - param: ("Operation_Id", "Operation_Name", "Operation_Type", + - param: ("Operation_Id", "Operation_Name", "Operation_Type", "Param_Name", "Requires_Grad", "Numeric_Num") - - - cal: ("Operation_Id", "Operation_Name", "Operation_Type", + + - cal: ("Operation_Id", "Operation_Name", "Operation_Type", "Kernel_Size", "Bias", "Input", "Output", "MACs", "FLOPs") - - - mem: ("Operation_Id", "Operation_Name", "Operation_Type", + + - mem: ("Operation_Id", "Operation_Name", "Operation_Type", "Param_Cost", "Buffer_Cost", "Output_Cost", "Total") - - - ittp: ("Operation_Id", "Operation_Name", "Operation_Type", + + - ittp: ("Operation_Id", "Operation_Name", "Operation_Type", "Infer_Time", "Throughput") - + Example: ```python from torchmeter import Meter from torchvision import models - + model = models.resnet18() metered_model = Meter(model) - + metered_model.table_cols("param") # ('Operation_Id', # 'Operation_Name', @@ -1453,7 +1476,7 @@ def table_cols(self, stat_name:str) -> Tuple[str, ...]: # 'Param_Name', # 'Requires_Grad', # 'Numeric_Num') - + metered_model.table_cols("cal") # ('Operation_Id', # 'Operation_Name', @@ -1466,28 +1489,25 @@ def table_cols(self, stat_name:str) -> Tuple[str, ...]: # 'FLOPs') ``` """ - + if not isinstance(stat_name, str): raise TypeError(f"stat_name must be a string, but got `{type(stat_name).__name__}`.") - - stats_data_dict:Dict[str, DataFrame] = self.table_renderer.stats_data - + + stats_data_dict: Dict[str, DataFrame] = self.table_renderer.stats_data + if stat_name not in stats_data_dict: raise KeyError(f"Statistics `{stat_name}` not in {tuple(stats_data_dict.keys())}.") - - stat_data:DataFrame = stats_data_dict[stat_name] - + + stat_data: DataFrame = 
stats_data_dict[stat_name] + if stat_data.is_empty(): - cols:Tuple[str, ...] = getattr(self.optree.root, stat_name).tb_fields + cols: Tuple[str, ...] = getattr(self.optree.root, stat_name).tb_fields else: cols = tuple(stat_data.columns) - + return cols - - def profile(self, - stat_name:str, - show:bool=True, no_tree:bool=False, - **tb_kwargs) -> Tuple[Table, DataFrame]: + + def profile(self, stat_name: str, show: bool = True, no_tree: bool = False, **tb_kwargs) -> Tuple[Table, DataFrame]: """Render a tabular report of the specified statistics with rich visualization. This method generates an interactive table visualization for the given statistical data, @@ -1496,63 +1516,63 @@ def profile(self, Args: stat_name (str): Name of the statistics to profile (i.e., 'param', 'cal', 'mem', 'ittp'). - + show (bool, optional): Whether to immediately render the visualization and display in terminal. Defaults to True. - + no_tree (bool, optional): Not to display the rendered tree when set to True. Defaults to False. - + **tb_kwargs: Additional table customization options: - - raw_data (bool): Use raw numerical data instead of formatted values with unit. + - raw_data (bool): Use raw numerical data instead of formatted values with unit. Defaults to False. - pick_cols (Sequence[str]): Whitelist of columns to display. Defaults to []. - exclude_cols (Sequence[str]): Blacklist of columns to hide. Defaults to []. - custom_cols (Dict[str, str]): Column rename mappings (original: new). Defaults to {}. - keep_custom_name (bool): Whether to keep custom names after this call. Defaults to False. - newcol_name (str): Name for new computed column. Defaults to ''. - - newcol_func (Callable[[DataFrame], ArrayLike]): Function to compute new column values. + - newcol_func (Callable[[DataFrame], ArrayLike]): Function to compute new column values. Defaults to lambda df: [None]*len(df). - newcol_type (Optional[PolarsDataType]): Explicit data type for new column. Defaults to None. 
- newcol_idx (int): Insertion position for new column (-1=append). Defaults to -1. - - keep_new_col (bool): Retain new columns in backend dataframe and subsequent renders. + - keep_new_col (bool): Retain new columns in backend dataframe and subsequent renders. Defaults to False. - save_to (Optional[str]): File path for data export; a non-None value triggers export. Defaults to None. - - save_format (Optional[str]): Export format, None to use the value in `save_to`. + - save_format (Optional[str]): Export format, None to infer from `save_to`. Currently only 'csv' and 'xlsx' files are supported. Defaults to None. Returns: - Tuple[rich.table.Table, polars.DataFrame]: The rendered `rich.table.Table` object and + Tuple[rich.table.Table, polars.DataFrame]: The rendered `rich.table.Table` object and underlying polars DataFrame. Raises: RuntimeWarning: If your model has some modules defined but not explicitly called. - + AttributeError: If `stat_name` is not a valid statistics name. - - ValueError: + + ValueError: - If the horizontal gap defined in the global config is negative while `no_tree` is disabled. - If you specify a nonexistent column name in `pick_cols` when `show` is enabled and `pick_cols` is given. - If you pass in a directory path as `save_to` without specifying `save_format`. - If you pass in a file path that is neither csv nor xlsx as `save_to`. - If you pass in an unsupported export format as `save_format`. - If `newcol_name` already exists in the backend dataframe. - - RuntimeError: + + RuntimeError: - If terminal width is insufficient for display when `show` is enabled. - - If no input data has been provided (i.e., `ipt` property is empty) and `stat_name` is + - If no input data has been provided (i.e., `ipt` property is empty) and `stat_name` is one of `cal`, `mem`, or `ittp`. - If no module is called (e.g. the underlying model's `forward` method is empty). - If the whole model is empty and has no sublayers.
- If using a single layer as a model - If `newcol_func` returns values with length mismatch to the underlying dataframe's row count - - TypeError: + + TypeError: - If `stat_name` is not a string - If `pick_cols` is not a list, tuple or set. - If `exclude_cols` is not a list, tuple or set. - If `custom_cols` is not a dict. - If `newcol_name` is not a string. - - If `newcol_func` is uncallable + - If `newcol_func` is uncallable - If `newcol_func` doesn't have exactly **1** formal parameter - If return value of `newcol_func` is not array-like - If `newcol_idx` is not an integer. @@ -1561,7 +1581,7 @@ def profile(self, - If `save_format` is not a string, neither None. Notes: - 1. Ensure at least one forward pass has been executed before accessing `cal`/`mem`/`ittp` statistics to + 1. Ensure at least one forward pass has been executed before accessing `cal`/`mem`/`ittp` statistics to guarantee valid input capture. 2. Table and tree rendering styles can be preconfigured through the properties: @@ -1570,15 +1590,16 @@ def profile(self, - `tree_fold_args` - `tree_levels_args` - `tree_repeat_block_args` - - 3. The rendering result will be progressively displayed line-by-line with a time interval. + + 3. The rendering result will be progressively displayed line-by-line with a time interval. You can configure this interval through the following steps (must be non-negative): ```python from torchmeter import get_config + cfg = get_config() - cfg.render_interval = 0.5 # unit second, should be non-negative + cfg.render_interval = 0.5 # unit second, should be non-negative ``` - + 4. Disable rendering (`show=False`) when only exporting data to reduce computational overhead. 5. Enable `no_tree` to: @@ -1591,28 +1612,28 @@ def profile(self, 7. 
When `raw_data=True` displays unformatted values: - `param`: Parameter counts - - `cal`: FLOPs/MACs counts + - `cal`: FLOPs/MACs counts - `mem`: Bytes consumed - `ittp`: Median inference time (seconds) and inferences per second per module - + 8. Column management: - - Use `pick_cols` to reorder columns (validate column names via + - Use `pick_cols` to reorder columns (validate column names via `metered_instance.table_cols(stat_name)`) - Processing order: `pick_cols` -> `exclude_cols` -> `custom_cols` -> `newcol` - - Conflicts: + - Conflicts: - picked columns override custom/newcol names - exclusions override picks - + 9. About `newcol_func`: - - must have exactly **1** formal parameter (name irrelevant) that will receive the + - must have exactly **1** formal parameter (name irrelevant) that will receive the underlying `polars.DataFrame` of specified statistics. - Implement logic using the incoming dataframe to return new column values (must be 1D array-like data - such as `Series`, `lists`, `ndarrays`, etc.). Note that you can use `val` property to access the - raw data for all statistics (for `ittp`, the return will be a tuple made up of the median and iqr of + such as `Series`, `lists`, `ndarrays`, etc.). Note that you can use `val` property to access the + raw data for all statistics (for `ittp`, the return will be a tuple made up of the median and iqr of the measurement data sequence). - - The example below demonstrates adding a percentage column of the `cal` statistics. Refer to + - The example below demonstrates adding a percentage column of the `cal` statistics. Refer to https://docs.pola.rs/api/python/stable/reference/dataframe/index.html for using `polars.Dataframe`. - + 10. The `newcol_idx` parameter mostly follows Python list insertion semantics: - Negative values count backward from end (`-1`=`append`) - `0` inserts at beginning @@ -1623,17 +1644,17 @@ def profile(self, 12. 
Session persistence: - `keep_new_col` retains created columns - `keep_custom_name` preserves renamed columns - + 13. Export paths: - Directory paths require explicit `save_format` - File paths auto-detect format from extension unless `save_format` overrides - + Example: ```python import torch from torchmeter import Meter from torchvision import models - + # wrap your model with Meter model = models.alexnet() metered_model = Meter(model) @@ -1643,36 +1664,33 @@ def profile(self, metered_model(input) # check column names of the cal table - print(metered_model.table_cols('cal')) - # ('Operation_Id', 'Operation_Name', 'Operation_Type', 'Kernel_Size', 'Bias', + print(metered_model.table_cols("cal")) + # ('Operation_Id', 'Operation_Name', 'Operation_Type', 'Kernel_Size', 'Bias', # 'Input', 'Output', 'MACs', 'FLOPs') + def newcol_logic(df): - flops_col = df['FLOPs'] - return flops_col.map_elements( - lambda x: f"{100 * x / metered_model.cal.Flops:.4f} %" - ) + flops_col = df["FLOPs"] + return flops_col.map_elements(lambda x: f"{100 * x / metered_model.cal.Flops:.4f} %") + # Customized profile with column operations metered_model.profile( - 'cal', - + "cal", # render and display immediately - show=True, + show=True, no_tree=True, - raw_data=False, - + raw_data=False, # column management - exclude_cols=['Kernel_Size', 'Bias'], - custom_cols={'Operation_Id': 'ID', 'Operation_Name': 'Module Name', 'Operation_Type': 'Module Type'}, - newcol_name='Percentage', + exclude_cols=["Kernel_Size", "Bias"], + custom_cols={"Operation_Id": "ID", "Operation_Name": "Module Name", "Operation_Type": "Module Type"}, + newcol_name="Percentage", newcol_func=newcol_logic, newcol_type=str, newcol_idx=-1, - # export - save_to='./cal_profile.xlsx', - save_format='xlsx', + save_to="./cal_profile.xlsx", + save_format="xlsx", ) ``` """ @@ -1684,82 +1702,91 @@ def newcol_logic(df): TREE_TABLE_GAP = __cfg__.combine.horizon_gap if not isinstance(stat_name, str): - raise TypeError(f"stat_name must be a
string, but got `{type(stat_name).__name__}`.") + raise TypeError(f"stat_name must be a string, but got `{type(stat_name).__name__}`.") if TREE_TABLE_GAP < 0: - raise ValueError("The gap between the rendered tree and the rendered table should be non-negative, " + \ - f"but got `{TREE_TABLE_GAP}`.") - + raise ValueError( + "The gap between the rendered tree and the rendered table should be non-negative, " + + f"but got `{TREE_TABLE_GAP}`." + ) + stat = getattr(self, stat_name) tb, data = self.table_renderer(stat_name=stat_name, **tb_kwargs) - + if not show: return tb, data - + tree = None if no_tree else self.structure - + console = get_console() - tree_width = console.measure(tree).maximum if not no_tree else 0 # type: ignore + tree_width = console.measure(tree).maximum if not no_tree else 0 # type: ignore desirable_tb_width = console.measure(tb).maximum actual_tb_width = min(desirable_tb_width, console.width - tree_width - TREE_TABLE_GAP) - - if actual_tb_width <= 5: # 5 is the minimum width of table - raise RuntimeError("The width of the terminal is too small, try to maximize the window or " + \ - "set a smaller `horizon_gap` value in config and try again.") - + + if actual_tb_width <= 5: # 5 is the minimum width of table + raise RuntimeError( + "The width of the terminal is too small, try to maximize the window or " + + "set a smaller `horizon_gap` value in config and try again." + ) + # when some cells in the table is overflown, we need to show a line between rows if actual_tb_width < desirable_tb_width: - tb.show_lines = True - + tb.show_lines = True + # get main content(i.e. 
tree & statistics table) if no_tree: - main_content = tb + main_content: Union[Table, Layout] = tb tree_height = 0 else: main_content = Layout() - main_content.split_row(Layout(tree, name='left', size=tree_width + TREE_TABLE_GAP), - Layout(tb, name='right', size=actual_tb_width)) - tree_height = len(console.render_lines(tree)) # type: ignore - - temp_options = console.options.update_width(actual_tb_width) + main_content.split_row( + Layout(tree, name="left", size=tree_width + TREE_TABLE_GAP), + Layout(tb, name="right", size=actual_tb_width), + ) + tree_height = len(console.render_lines(tree)) # type: ignore + + temp_options = console.options.update_width(actual_tb_width) tb_height = len(console.render_lines(tb, options=temp_options)) main_content_height = max(tree_height, tb_height) main_content_width = tree_width + actual_tb_width + (0 if no_tree else TREE_TABLE_GAP) # get footer content - footer = Columns(title=Rule('[gray54]s u m m a r y[/]', characters='-', style='gray54'), # type: ignore - padding=(1,1), - equal=True, - expand=True) - + footer = Columns( + title=Rule("[gray54]s u m m a r y[/]", characters="-", style="gray54"), # type: ignore + padding=(1, 1), + equal=True, + expand=True, + ) + model_info = self.model_info stat_info = self.stat_info(stat_or_statname=stat, show_warning=False) - model_info.style = 'dim' - stat_info.style = 'dim' + model_info.style = "dim" + stat_info.style = "dim" footer.add_renderable(model_info) footer.add_renderable(stat_info) temp_options = console.options.update_width(main_content_width) footer_height = len(console.render_lines(footer, options=temp_options)) - + # render profile canvas = Layout() - canvas.split_column(Layout(main_content, name='top', size=main_content_height), - Layout(footer, name='down', size=footer_height)) - + canvas.split_column( + Layout(main_content, name="top", size=main_content_height), Layout(footer, name="down", size=footer_height) + ) + origin_width = console.width origin_height = console.height 
console.width = main_content_width console.height = main_content_height + footer_height - - try: + + try: render_perline(renderable=canvas) finally: # if the user interrupts the rendering when render_interval > 0, # still restore the console size console.width = origin_width console.height = origin_height - + return tb, data def _is_ipt_empty(self) -> bool: @@ -1768,12 +1795,12 @@ def _is_ipt_empty(self) -> bool: Returns: bool: whether the input required for a feed-forward is missing """ - return not self._ipt['args'] and not self._ipt['kwargs'] - + return not self._ipt["args"] and not self._ipt["kwargs"] + def _ipt2device(self) -> None: """Moves all input tensors to the specified device. - This method checks if the input tensors are already on the specified device. + This method checks if the input tensors are already on the specified device. If not, it moves them to the device set in the Meter instance. Raises: @@ -1785,61 +1812,64 @@ def _ipt2device(self) -> None: """ from inspect import signature + forward_args = signature(self.model.forward).parameters if len(forward_args) and self._is_ipt_empty(): raise RuntimeError("No input data provided.") - devices = set(arg.device for arg in self._ipt['args'] if isinstance(arg, Tensor)) - devices.update(kwargs.device for kwargs in self._ipt['kwargs'].values() if isinstance(kwargs, Tensor)) + devices = set(arg.device for arg in self._ipt["args"] if isinstance(arg, Tensor)) + devices.update(kwargs.device for kwargs in self._ipt["kwargs"].values() if isinstance(kwargs, Tensor)) if not len(devices): return - + if len(devices) == 1 and next(iter(devices)) == self.device: return self._ipt = { - 'args': tuple(x.to(self.device) if isinstance(x, Tensor) else x - for x in self._ipt['args']), - 'kwargs': {k: (v.to(self.device) if isinstance(v, Tensor) else v) - for k, v in self._ipt['kwargs'].items()} - } + "args": tuple(x.to(self.device) if isinstance(x, Tensor) else x + for x in self._ipt["args"]), + "kwargs": {k: (v.to(self.device) if
isinstance(v, Tensor) else v) + for k, v in self._ipt["kwargs"].items()} + } # fmt: skip - def __device_detect(self, model) -> Union[str, tc_device]: + def __device_detect(self, model: nn.Module) -> Union[str, tc_device]: """Detects the device where the model is located via the model's parameters. - This method detects the model's device by checking its parameters' location. + This method detects the model's device by checking its parameters' location. If no parameters are found, it will issue a warning and move the model to CPU by default. Args: model (nn.Module): The model whose device is to be detected. Returns: - Union[str, torch.device]: The device where the model's parameters are located. If no parameters are found, + Union[str, torch.device]: The device where the model's parameters are located. If no parameters are found, returns 'cpu' as the default device. Raises: - UserWarning: If the model has no parameters, a warning is issued indicating that the model will be moved + UserWarning: If the model has no parameters, a warning is issued indicating that the model will be moved to CPU for subsequent analysis. """ - + import warnings - + try: model_first_param = next(model.parameters()) return model_first_param.device - + except StopIteration: - warnings.warn(category=UserWarning, message=\ - "We can't detect the device where your model is located because no parameter was found in your model. " + \ - "We'll move your model to CPU and do all subsequent analysis based on this CPU version. " + \ - "If this isn't what you want, set a specific device when initializing the `Meter` class, " + \ - "e.g. `Meter(your_model, device='cuda:0')`.") - + warnings.warn( + category=UserWarning, + message="We can't detect the device where your model is located because no parameter was found " + + "in your model. We'll move your model to CPU and do all subsequent analysis based on this CPU " + + "version. If this isn't what you want, change the device with the `to` method, " + + "e.g.
`metered_model.to('cuda')`.", + ) + return "cpu" - - def __is_ipt_changed(self, new_ipt:IPT_TYPE) -> bool: + + def __is_ipt_changed(self, new_ipt: IPT_TYPE) -> bool: # noqa: C901 """Determines if the new input differs from the current captured input. Compares both positional arguments (args) and keyword arguments (kwargs) between current and new input: @@ -1848,7 +1878,7 @@ def __is_ipt_changed(self, new_ipt:IPT_TYPE) -> bool: - Verifies argument structure consistency (same length for args, same keys for kwargs) Args: - new_ipt: New input arguments to compare against currently stored input. + new_ipt: New input arguments to compare against currently stored input. It is a dictionary with two keys: - `args`: A tuple containing all positional arguments. - `kwargs`: A dictionary containing all keyword arguments. @@ -1860,44 +1890,46 @@ def __is_ipt_changed(self, new_ipt:IPT_TYPE) -> bool: 3. Keyword arguments have different keys or values 4. Any argument value differs (non-Tensor) or tensor properties differ (Tensor) """ - + if self._is_ipt_empty(): return True - + is_changed = False - - if len(self._ipt["args"]) == len(new_ipt["args"]): - for origin, new in zip(self._ipt["args"], new_ipt["args"]): - if type(origin) is not type(new): - is_changed = True - elif isinstance(origin, Tensor): - is_changed = origin.shape != new.shape or origin.dtype != new.dtype - else: - is_changed = origin != new - - if is_changed: - return True - else: + + # check anonymous arguments + if len(self._ipt["args"]) != len(new_ipt["args"]): return True - if set(self._ipt["kwargs"].keys()) == set(new_ipt["kwargs"].keys()): - for k, origin in self._ipt["kwargs"].items(): - new = new_ipt["kwargs"][k] - - if type(origin) is not type(new): - is_changed = True - elif isinstance(origin, Tensor): - is_changed = origin.shape != new.shape or origin.dtype != new.dtype - else: - is_changed = origin != new - - if is_changed: - return True - else: + for origin, new in zip(self._ipt["args"], 
new_ipt["args"]): + if type(origin) is not type(new): + is_changed = True + elif isinstance(origin, Tensor): + is_changed = origin.shape != new.shape or origin.dtype != new.dtype + else: + is_changed = origin != new + + if is_changed: + return True + + # check named arguments + if set(self._ipt["kwargs"].keys()) != set(new_ipt["kwargs"].keys()): return True - + + for k, origin in self._ipt["kwargs"].items(): + new = new_ipt["kwargs"][k] + + if type(origin) is not type(new): + is_changed = True + elif isinstance(origin, Tensor): + is_changed = origin.shape != new.shape or origin.dtype != new.dtype + else: + is_changed = origin != new + + if is_changed: + return True + return False - + def __repr__(self) -> str: return f"Meter(model={self.optree}, device={self.device})" @@ -1913,7 +1945,6 @@ def __get_clsattr_with_settable_flag(cls) -> Dict[str, bool]: Dict[str, bool]: A dictionary where keys are attribute names and values indicate whether the attribute has a setter method (True if settable, False otherwise). 
""" - - return {k:v.fset is not None for k,v in cls.__dict__.items() - if isinstance(v, property)} - \ No newline at end of file + + return {k: v.fset is not None for k, v in cls.__dict__.items() + if isinstance(v, property)} # fmt: skip diff --git a/torchmeter/display.py b/torchmeter/display.py index fa28430..060baa2 100644 --- a/torchmeter/display.py +++ b/torchmeter/display.py @@ -1,172 +1,177 @@ from __future__ import annotations -from typing import TYPE_CHECKING -import re import os +import re import warnings from copy import copy, deepcopy +from typing import TYPE_CHECKING +from inspect import _empty, signature from collections import OrderedDict -from inspect import signature, _empty -from rich import print +from rich import print # noqa: A004 +from polars import Series, DataFrame from rich.tree import Tree from rich.panel import Panel from rich.table import Table, Column -from polars import DataFrame, Series from torchmeter.utils import dfs_task, resolve_savepath, match_polars_type -from torchmeter.config import get_config, dict_to_namespace, FlagNameSpace +from torchmeter.config import FlagNameSpace, get_config, dict_to_namespace if TYPE_CHECKING: - from typing import Any, Dict, List, Union, Sequence - from typing import Callable, Optional, NamedTuple - - from rich.segment import Segment + from typing import Any, Dict, List, Union, Callable, Optional, Sequence, NamedTuple + from rich.console import Console, RenderableType + from rich.segment import Segment from polars._typing import PolarsDataType from polars.series.series import ArrayLike from torchmeter.engine import OperationNode - - LAZY_STR_TYPE = Optional[Union[str, - Callable[[Dict[str, Any]], str]]] + + LAZY_STR_TYPE = Optional[Union[str, Callable[[Dict[str, Any]], str]]] __all__ = ["render_perline", "TreeRenderer", "TabularRenderer"] __cfg__ = get_config() -def apply_setting(obj:Any, - setting:FlagNameSpace, - omit:Optional[Union[str, Sequence[str]]]=None, - **extra_settings) -> Any: + +def 
apply_setting( # noqa: C901 + obj: Any, + setting: FlagNameSpace, + omit: Optional[Union[str, Sequence[str]]] = None, + **extra_settings, +) -> Any: """ This function adapts to third-party library API changes by applying the settings to the given object in place. - - we can not directly use `obj.__dict__.update(settings)`, cause sometime the + + we cannot directly use `obj.__dict__.update(settings)`, because sometimes the obj does not have a `__dict__` attribute. Additionally, the obj's property names may have no relationship with its initialization parameters. - - Note: + + Note: - important properties (such as `Tree.children`) should be in `omit`, or they will be reset. - - although all of this, if we want to omit some arguments, we still need to set omit to + - even so, if we want to omit some arguments, we still need to set `omit` to the inner related attribute name instead of the initialization parameter name. - """ - + """ # noqa: DOC201, DOC501 + # prepare the setting dict if isinstance(setting, FlagNameSpace): setting_dict = setting.data_dict.copy() elif isinstance(setting, dict): setting_dict = setting.copy() else: - raise TypeError("The `setting` argument should be a `FlagNameSpace` or a `dict`, " + \ - f"but got `{type(setting).__name__}`.") + raise TypeError( + f"The `setting` argument should be a `FlagNameSpace` or a `dict`, but got `{type(setting).__name__}`."
+ ) setting_dict.update(extra_settings) - + # prepare all the initialization arguments obj_cls = obj.__class__ variable_position_idx = None variable_keyword_argname = None - init_args:Dict[str, Any] = OrderedDict() + init_args: Dict[str, Any] = OrderedDict() for arg_idx, (arg_name, arg) in enumerate(signature(obj_cls).parameters.items()): arg_type = arg.kind.name - + if arg_type == "VAR_POSITIONAL": variable_position_idx = arg_idx variable_position_argname = arg_name init_args[arg_name] = list(setting_dict.get(arg_name, [])) - + elif arg_type == "VAR_KEYWORD": variable_keyword_argname = arg_name if arg_name in setting_dict: init_args[arg_name] = setting_dict[arg_name] else: - init_args[arg_name] = {k:v for k,v in setting_dict.items() - if k not in init_args} - + init_args[arg_name] = {k: v for k, v in setting_dict.items() if k not in init_args} + else: if arg_name not in setting_dict and arg.default is _empty: - try: # try to find the property with same name in the object + try: # try to find a property with the same name on the object init_args[arg_name] = getattr(obj, arg_name) - except AttributeError: # if not, this argment is required but absent - raise RuntimeError(f"A required argument `{arg_name}` unknown, " + \ - f"consider providing it via `{arg_name}=xxx` or adding it to config.") + except AttributeError: # if not, this argument is required but absent + raise RuntimeError( + f"A required argument `{arg_name}` is unknown, " + + f"consider providing it via `{arg_name}=xxx` or adding it to config."
+ ) init_args[arg_name] = setting_dict.get(arg_name, arg.default) - + # divide the arguments into position and keyword if variable_position_idx is not None: all_args = tuple(init_args.keys()) position_args = [init_args[k] for k in all_args[:variable_position_idx]] position_args.extend(init_args[variable_position_argname]) - keyword_args = {k:init_args[k] for k in all_args[variable_position_idx+1:]} + keyword_args = {k: init_args[k] for k in all_args[variable_position_idx + 1 :]} else: position_args = [] keyword_args = init_args - + if variable_keyword_argname is not None: keyword_args.update(keyword_args[variable_keyword_argname]) del keyword_args[variable_keyword_argname] - + # initialize a mirror object temp_obj = obj_cls(*position_args, **keyword_args) - + # prepare the state dict - if hasattr(temp_obj, '__dict__'): + if hasattr(temp_obj, "__dict__"): target_state = temp_obj.__dict__ else: all_states = temp_obj.__slots__ - + if not isinstance(all_states, (list, tuple, set)): all_states = [all_states] - all_states = [f"_{obj_cls.__name__}{p}" if p.startswith('__') else p - for p in all_states] - - target_state = {property: getattr(temp_obj, property) - for property in all_states} - + all_states = [f"_{obj_cls.__name__}{p}" if p.startswith("__") else p for p in all_states] + + target_state = {p: getattr(temp_obj, p) for p in all_states} + # filt out the omit items from the state dict if omit is not None: if not isinstance(omit, (str, list, tuple, set)): - raise TypeError("The `omit` argument should be a string, a list, a tuple or a set, " + \ - f"but got `{type(omit).__name__}`.") - + raise TypeError( + f"The `omit` argument should be a string, a list, a tuple or a set, but got `{type(omit).__name__}`." 
+ ) + if isinstance(omit, str): omit_items = set([omit]) else: - if any(not isinstance((inner:=_), str) for _ in omit): - raise TypeError(f"The `omit` argument receives a `{type(omit).__name__}` of `{type(inner).__name__}`, " + \ - "but expect the inner type to be str.") + if any(not isinstance((inner := _), str) for _ in omit): + raise TypeError( + f"The `omit` argument receives a `{type(omit).__name__}` of `{type(inner).__name__}`, " + + "but expect the inner type to be str." + ) omit_items = set(omit) else: omit_items = set() - - target_state = {k:v for k,v in target_state.items() if k not in omit_items} - + + target_state = {k: v for k, v in target_state.items() if k not in omit_items} + # update obj's state with the setting dict - if hasattr(obj, '__dict__'): + if hasattr(obj, "__dict__"): obj.__dict__.update(target_state) else: - list(map(lambda kv: setattr(obj,kv[0],kv[1]), target_state.items())) # type: ignore - + list(map(lambda kv: setattr(obj, kv[0], kv[1]), target_state.items())) # type: ignore + # return the origin object return obj + def render_perline(renderable: RenderableType) -> None: - from time import sleep + from rich import get_console - - time_sep:float = __cfg__.render_interval + + time_sep: float = __cfg__.render_interval if time_sep < 0: raise ValueError(f"The `render_interval` value defined in config must be non-negative, but got `{time_sep}`") - console:Console = get_console() + console: Console = get_console() if not time_sep: console.print(renderable) else: lines: List[List[Segment]] = console.render_lines(renderable, new_lines=True) - + # a fake implementation of `rich.print` console._buffer_index = 0 for line in lines: @@ -174,49 +179,48 @@ def render_perline(renderable: RenderableType) -> None: console._check_buffer() sleep(time_sep) + class TreeRenderer: - - loop_algebras:str = "xyijkabcdefghlmnopqrstuvwz" + "XYIJKABCDEFGHLMNOPQRSTUVWZ" - - def __init__(self, node:OperationNode) -> None: - if node.__class__.__name__ != 
'OperationNode': - raise TypeError("Expected `node` to be an instance of `OperationNode`, " + \ - f"but got `{type(node).__name__}`.") - + loop_algebras: str = "xyijkabcdefghlmnopqrstuvwz" + "XYIJKABCDEFGHLMNOPQRSTUVWZ" + + def __init__(self, node: OperationNode) -> None: + if node.__class__.__name__ != "OperationNode": + raise TypeError(f"Expected `node` to be an instance of `OperationNode`, but got `{type(node).__name__}`.") + self.opnode = node - self.render_unfold_tree:Optional[Tree] = None - self.render_fold_tree:Optional[Tree] = None - - def default_rpft(attr_dict:Dict[str, Any]) -> str: - """Must have only one args which accept an attribute dict""" - # basic format of footer in each repeat block - start_idx = attr_dict['node_id'].split('.')[-1] - - repeat_winsz = attr_dict['repeat_winsz'] + self.render_unfold_tree: Optional[Tree] = None + self.render_fold_tree: Optional[Tree] = None + + def default_rpft(attr_dict: Dict[str, Any]) -> str: + """Must have exactly one argument, which accepts an attribute dict""" # noqa: DOC201 + # basic format of footer in each repeat block + start_idx = attr_dict["node_id"].split(".")[-1] + + repeat_winsz = attr_dict["repeat_winsz"] if repeat_winsz == 1: - end_idx = int(start_idx) + attr_dict['repeat_time'] -1 + end_idx = int(start_idx) + attr_dict["repeat_time"] - 1 return f"Where ∈ [{start_idx}, {end_idx}]" else: - end_idx = int(start_idx) + attr_dict['repeat_time']*repeat_winsz -1 + end_idx = int(start_idx) + attr_dict["repeat_time"] * repeat_winsz - 1 valid_vals = list(map(str, range(int(start_idx), end_idx, repeat_winsz))) return f"Where = {', '.join(valid_vals)}" - + self.__rpft: LAZY_STR_TYPE = default_rpft - + @property def default_level_args(self) -> FlagNameSpace: - if not hasattr(self.tree_levels_args, 'default'): + if not hasattr(self.tree_levels_args, "default"): self.tree_levels_args.default = dict_to_namespace({ - 'label':'[b gray35]() [green][/green] [cyan][/]', # str | Callable - 'style':'tree', -
'guide_style':'light_coral', - 'highlight':True, - 'hide_root':False, - 'expanded':True + "label": "[b gray35]() [green][/green] [cyan][/]", # str | Callable + "style": "tree", + "guide_style": "light_coral", + "highlight": True, + "hide_root": False, + "expanded": True, }) return self.tree_levels_args.default - + @property def tree_levels_args(self) -> FlagNameSpace: return __cfg__.tree_levels_args @@ -230,114 +234,133 @@ def repeat_footer(self) -> LAZY_STR_TYPE: return self.__rpft @default_level_args.setter # type: ignore - def default_level_args(self, custom_args:Dict[str, Any]) -> None: + def default_level_args(self, custom_args: Dict[str, Any]) -> None: if not isinstance(custom_args, dict): - raise TypeError(f"You can only overwrite `{self.__class__.__name__}.default_level_args` with a dict, " + \ - f"but got `{type(custom_args).__name__}`.") - + raise TypeError( + f"You can only overwrite `{self.__class__.__name__}.default_level_args` with a dict, " + + f"but got `{type(custom_args).__name__}`." + ) + valid_setting_keys = set(signature(Tree).parameters.keys()) passin_keys = set(custom_args.keys()) invalid_keys = passin_keys - valid_setting_keys if invalid_keys: - raise KeyError(f"Keys {invalid_keys} is/are not accepted by `rich.tree.Tree`, refer to " + \ - "https://rich.readthedocs.io/en/latest/reference/tree.html#rich.tree.Tree " + \ - "for valid args.") + raise KeyError( + f"Keys {invalid_keys} is/are not accepted by `rich.tree.Tree`, refer to " + + "https://rich.readthedocs.io/en/latest/reference/tree.html#rich.tree.Tree " + + "for valid args." 
+ ) self.default_level_args.update(custom_args) - + self.default_level_args.mark_change() - @tree_levels_args.setter # type: ignore - def tree_levels_args(self, custom_args:Dict[Any, Dict[str, Any]]) -> None: + @tree_levels_args.setter # type: ignore + def tree_levels_args(self, custom_args: Dict[Any, Dict[str, Any]]) -> None: if not isinstance(custom_args, dict): - raise TypeError(f"You can only overwrite `{self.__class__.__name__}.tree_levels_args` with a dict, " + \ - f"but got `{type(custom_args).__name__}`.") - + raise TypeError( + f"You can only overwrite `{self.__class__.__name__}.tree_levels_args` with a dict, " + + f"but got `{type(custom_args).__name__}`." + ) + # filt out invalid level definations and invalid display settings valid_setting_keys = set(signature(Tree).parameters.keys()) for level, level_args_dict in custom_args.items(): # assure level is a non-negative integer, 'default' or 'all' level = level.lower() - if not level.isnumeric() and level not in ('default', 'all'): - warnings.warn(message="The `level` key should be numeric, `default` or `all`, " + \ - f"but got `{level}`.This setting will be ignored.", - category=UserWarning) + if not level.isnumeric() and level not in ("default", "all"): + warnings.warn( + category=UserWarning, + message="The `level` key should be numeric, `default` or `all`, " + + f"but got `{level}`.This setting will be ignored.", + ) continue - + passin_keys = set(level_args_dict.keys()) invalid_keys = passin_keys - valid_setting_keys if invalid_keys: - raise KeyError(f"Keys {invalid_keys} is/are not accepted by `rich.tree.Tree`, refer to " + \ - "https://rich.readthedocs.io/en/latest/reference/tree.html#rich.tree.Tree " + \ - "for valid args.") - - if level == 'default': - self.default_level_args = level_args_dict # type: ignore - elif level == 'all': - self.default_level_args = level_args_dict # type: ignore - # delete all levels settings - levels = [level for level in self.tree_levels_args.__dict__.keys() - if 
level.isnumeric()] - list(map(lambda level:delattr(self.tree_levels_args, level), levels)) # type: ignore + raise KeyError( + f"Keys {invalid_keys} is/are not accepted by `rich.tree.Tree`, refer to " + + "https://rich.readthedocs.io/en/latest/reference/tree.html#rich.tree.Tree " + + "for valid args." + ) + + if level == "default": + self.default_level_args = level_args_dict # type: ignore + elif level == "all": + self.default_level_args = level_args_dict # type: ignore + # delete all levels settings + levels = [level for level in self.tree_levels_args.__dict__ if level.isnumeric()] + list(map(lambda level: delattr(self.tree_levels_args, level), levels)) # type: ignore break else: - self.tree_levels_args.update({level:level_args_dict}) - + self.tree_levels_args.update({level: level_args_dict}) + self.tree_levels_args.mark_change() - @repeat_block_args.setter # type: ignore - def repeat_block_args(self, custom_args:Dict[str, Any]) -> None: + @repeat_block_args.setter # type: ignore + def repeat_block_args(self, custom_args: Dict[str, Any]) -> None: if not isinstance(custom_args, dict): - raise TypeError(f"You can only overwrite `{self.__class__.__name__}.repeat_block_args` with a dict, " + \ - f"but got `{type(custom_args).__name__}`.") - - footer_key = list(filter(lambda x: x.lower() == 'repeat_footer', custom_args.keys())) + raise TypeError( + f"You can only overwrite `{self.__class__.__name__}.repeat_block_args` with a dict, " + + f"but got `{type(custom_args).__name__}`." 
+ ) + + footer_key = list(filter(lambda x: x.lower() == "repeat_footer", custom_args.keys())) if footer_key: - self.repeat_footer = custom_args[footer_key[-1]] # type: ignore - del custom_args[footer_key[-1]] - + self.repeat_footer = custom_args[footer_key[-1]] # type: ignore + del custom_args[footer_key[-1]] + valid_setting_keys = set(signature(Panel).parameters.keys()) passin_keys = set(custom_args.keys()) invalid_keys = passin_keys - valid_setting_keys if invalid_keys: - raise KeyError(f"Keys {invalid_keys} is/are not accepted by `rich.panel.Panel`, refer to " + \ - "https://rich.readthedocs.io/en/latest/reference/panel.html#rich.panel.Panel " + \ - "for valid args.") + raise KeyError( + f"Keys {invalid_keys} is/are not accepted by `rich.panel.Panel`, refer to " + + "https://rich.readthedocs.io/en/latest/reference/panel.html#rich.panel.Panel " + + "for valid args." + ) self.repeat_block_args.update(custom_args) - + self.repeat_block_args.mark_change() - @repeat_footer.setter # type: ignore - def repeat_footer(self, custom_footer:LAZY_STR_TYPE) -> None: + @repeat_footer.setter # type: ignore + def repeat_footer(self, custom_footer: LAZY_STR_TYPE) -> None: from inspect import signature - + if callable(custom_footer): - func_args = signature(custom_footer).parameters - + if not len(func_args): - res = custom_footer() # type: ignore + res = custom_footer() # type: ignore if not isinstance(res, (type(None), str)): - raise RuntimeError("If `repeat_foot` is a parameterless function, its return value must be `str` or `None`, " + \ - f"but got a result of type `{type(res).__name__}`.") + raise RuntimeError( + "If `repeat_foot` is a parameterless function, its return value must be `str` or `None`, " + + f"but got a result of type `{type(res).__name__}`." 
+ ) self.__rpft = res - + elif len(func_args) == 1: self.__rpft = custom_footer - + else: - raise RuntimeError("If `repeat_footer` is a parameterized function, it must have exactly one parameter and will accept a `dict` as input, " + \ - f"but there are {len(func_args)} arguments in the passed-in function.") - + raise RuntimeError( + "If `repeat_footer` is a parameterized function, " + + "it must have exactly one parameter and will accept a `dict` as input, " + + f"but there are {len(func_args)} arguments in the passed-in function." + ) + elif isinstance(custom_footer, (type(None), str)): self.__rpft = custom_footer - + else: - raise RuntimeError("The `repeat_footer` can be None, string, a parameterless function, or a function with one argument, " + \ - f"but got `{type(custom_footer).__name__}`.") + raise RuntimeError( + "The `repeat_footer` can be None, string, a parameterless function, or a function with one argument, " + + f"but got `{type(custom_footer).__name__}`." + ) self.repeat_block_args.mark_change() - def resolve_attr(self, attr_val:Any) -> str: + def resolve_attr(self, attr_val: Any) -> str: """ Function to process the attribute value resolved by regex. @@ -348,29 +371,26 @@ def resolve_attr(self, attr_val:Any) -> str: str: the processed result. Must be a string! 
""" return str(attr_val) - - def __call__(self) -> Tree: - + + def __call__(self) -> Tree: # noqa: C901 from rich.rule import Rule from rich.console import Group - fold_repeat:bool = __cfg__.tree_fold_repeat - - copy_tree:OperationNode = deepcopy(self.opnode) - - # task_func for `dfs_task` - def __render_per_node(subject:OperationNode, - pre_res=None) -> None: + fold_repeat: bool = __cfg__.tree_fold_repeat + copy_tree: OperationNode = deepcopy(self.opnode) + + # task_func for `dfs_task` + def __render_per_node(subject: OperationNode, pre_res=None) -> None: # noqa: ANN001, ARG001, C901 # skip repeat nodes and folded nodes when enable `fold_repeat` - if fold_repeat and subject._is_folded: + if fold_repeat and subject._is_folded: return None if fold_repeat and not subject._render_when_repeat: return None - - display_root:Tree = subject.display_root - level = str(display_root.label) + display_root: Tree = subject.display_root + + level = str(display_root.label) # update display setting for the currently traversed node target_level_args = getattr(self.tree_levels_args, level, self.default_level_args) @@ -380,15 +400,15 @@ def __render_per_node(subject:OperationNode, if fold_repeat and int(level) > 1: subject.node_id = subject.parent.node_id + f".{subject.node_id.split('.')[-1]}" # type: ignore label = self.__resolve_argtext(text=getattr(target_level_args, 'label', self.default_level_args.label), - attr_owner=subject) + attr_owner=subject) # fmt: skip # apply display setting apply_setting(obj=display_root, setting=target_level_args, omit="children", - label=label) - - if fold_repeat: + label=label) # fmt: skip + + if fold_repeat: algebra = self.loop_algebras[0] use_algebra = False @@ -399,91 +419,108 @@ def __render_per_node(subject:OperationNode, # if the repeat body contains more than one operations # get a complete copy of the repeat body, so as to render repeat block more conveniently later. 
- if subject.repeat_winsz > 1: + if subject.repeat_winsz > 1: use_algebra = True - repeat_body_tree = Tree('.', hide_root=True) - + repeat_body_tree = Tree(".", hide_root=True) + for loop_idx, (node_id, node_name) in enumerate(subject._repeat_body): - repeat_op_node:OperationNode = subject.parent.childs[node_id] # type: ignore - + repeat_op_node: OperationNode = subject.parent.childs[node_id] # type: ignore + # update node_id with a algebraic expression which indicates the loop - if level != '1': - if loop_idx == 0: - repeat_op_node.node_id = repeat_op_node.parent.node_id + f".{algebra}" # type: ignore - else: - repeat_op_node.node_id = repeat_op_node.parent.node_id + f".({algebra}+{loop_idx})" # type: ignore + if level != "1": + if loop_idx == 0: + repeat_op_node.node_id = repeat_op_node.parent.node_id + f".{algebra}" # type: ignore + else: + repeat_op_node.node_id = repeat_op_node.parent.node_id + f".({algebra}+{loop_idx})" # type: ignore else: if loop_idx == 0: repeat_op_node.node_id = algebra else: repeat_op_node.node_id = f"{algebra}+{loop_idx}" - + # resolve label field for the `rich.Tree` object of the currently traversed node - label = self.__resolve_argtext(text=getattr(target_level_args, 'label', self.default_level_args.label), - attr_owner=repeat_op_node) - + label = self.__resolve_argtext( + text=getattr(target_level_args, "label", self.default_level_args.label), + attr_owner=repeat_op_node, + ) + # update display setting for the `rich.Tree` object of the currently traversed node - repeat_display_node:Tree = copy(repeat_op_node.display_root) + repeat_display_node: Tree = copy(repeat_op_node.display_root) apply_setting(obj=repeat_display_node, setting=target_level_args, omit="children", - label=label) - + label=label) # fmt: skip + # Delete repeat nodes and folded nodes (Note: operate in a copied tree) - repeat_display_node.children = [child.display_root for child in repeat_op_node.childs.values() - if child._render_when_repeat and not child._is_folded] 
- - repeat_body_tree.children.append(repeat_display_node) - + repeat_display_node.children = [ + child.display_root + for child in repeat_op_node.childs.values() + if child._render_when_repeat and not child._is_folded + ] + + repeat_body_tree.children.append(repeat_display_node) + display_root = repeat_body_tree else: - # for the case that the repeat body is only a single operation or the current node is just not a repeat node, - # just delete its repeat childs or the folded childs and need to do nothing more - display_root.children = [child.display_root for child in subject.childs.values() - if child._render_when_repeat and not child._is_folded] - + # for the case that the repeat body is only a single operation or the current node is not a + # repeat node, just delete its repeat childs or the folded childs and need to do nothing more + display_root.children = [ + child.display_root + for child in subject.childs.values() + if child._render_when_repeat and not child._is_folded + ] + # render the repeat body as a panel if subject.repeat_time > 1: use_algebra = True # update node_id with a algebraic expression which indicates the loop - if level != '1': - subject.node_id = subject.parent.node_id + f".{algebra}" # type: ignore + if level != "1": + subject.node_id = subject.parent.node_id + f".{algebra}" # type: ignore else: subject.node_id = algebra - display_root.label = self.__resolve_argtext(text=getattr(target_level_args, 'label', self.default_level_args.label), - attr_owner=subject) - - block_footer = self.__resolve_argtext(text=self.repeat_footer, attr_owner=subject, - loop_algebra=algebra, node_id=origin_node_id) + display_root.label = self.__resolve_argtext( + text=getattr(target_level_args, "label", self.default_level_args.label), + attr_owner=subject + ) # fmt: skip + + block_footer = self.__resolve_argtext( + text=self.repeat_footer, attr_owner=subject, + loop_algebra=algebra, node_id=origin_node_id + ) # fmt: skip if block_footer: - 
repeat_block_content:Union[Tree, Group] = Group( - copy(display_root), # the tree structure of the circulating body - Rule(characters='-', style='dim ' + getattr(self.repeat_block_args, 'style','')), # a separator made up of '-' + repeat_block_content: Union[Tree, Group] = Group( + # the tree structure of the circulating body + copy(display_root), + # a separator made up of '-' + Rule(characters="-", style="dim " + getattr(self.repeat_block_args, "style", "")), + # footer "[dim]" + block_footer + "[/]", - fit=True + fit=True, ) else: repeat_block_content = copy(display_root) - + # make a pannel to show repeat information - title = self.__resolve_argtext(text=getattr(self.repeat_block_args, 'title', ''), - attr_owner=subject, - loop_algebra=algebra) - + title = self.__resolve_argtext( + text=getattr(self.repeat_block_args, "title", ""), + attr_owner=subject, + loop_algebra=algebra + ) # fmt: skip + repeat_block = apply_setting( obj=Panel(renderable=repeat_block_content), setting=self.repeat_block_args, omit="renderable", title=title, - border_style=self.repeat_block_args.border_style + ' ' + self.repeat_block_args.style + border_style=self.repeat_block_args.border_style + " " + self.repeat_block_args.style, ) - - # overwrite the label of the first node in repeat block + + # overwrite the label of the first node in repeat block subject.display_root.label = repeat_block - # remove all children nodes of the first repeat item, + # remove all children nodes of the first repeat item, # so that only the rendered panel will be displayed subject.display_root.children = [] @@ -491,28 +528,30 @@ def __render_per_node(subject:OperationNode, self.loop_algebras = self.loop_algebras[1:] + algebra return None - + # apply display setting for each node by dfs traversal dfs_task(dfs_subject=copy_tree, - adj_func=lambda x:x.childs.values(), + adj_func=lambda x: x.childs.values(), task_func=__render_per_node, - visited=[]) - + visited=[]) # fmt: skip + # cache the rendered result if 
fold_repeat: self.render_fold_tree = copy_tree.display_root else: self.render_unfold_tree = copy_tree.display_root - - return copy_tree.display_root - def __resolve_argtext(self, - text:LAZY_STR_TYPE, - attr_owner:"OperationNode", # noqa # type: ignore - **kwargs) -> str: + return copy_tree.display_root + + def __resolve_argtext( + self, + text: LAZY_STR_TYPE, + attr_owner: "OperationNode", # type: ignore + **kwargs, + ) -> str: """ - Disolve all placeholders in form of `<·>` in `text`. If you do not want the content in `<·>` to - be resolved, you can use `\\<` or `\\>` to escape it. For example, `` will be replaced by + Resolve all placeholders in the form of `<·>` in `text`. If you do not want the content in `<·>` to + be resolved, you can use `\\<` or `\\>` to escape it. For example, `` will be replaced by the value of `attr_owner.name`, while `\\` will not be resolved. Args: @@ -525,83 +564,90 @@ def __resolve_argtext(self, """ attr_dict = copy(attr_owner.__dict__) attr_dict.update(kwargs) - + if callable(text): text = text(attr_dict) elif text is None: return "" - res_str = re.sub(pattern=r'(?', - repl=lambda match: str(self.resolve_attr(attr_dict.get(match.group(1), None))), - string=str(text)) - res_str = re.sub(pattern=r'\\<|\\>', - repl=lambda x: x.group().replace('\\', ''), - string=res_str) + res_str = re.sub( + pattern=r"(?", + repl=lambda match: str(self.resolve_attr(attr_dict.get(match.group(1), None))), + string=str(text), + ) + res_str = re.sub(pattern=r"\\<|\\>", repl=lambda x: x.group().replace("\\", ""), string=res_str) return res_str - + + class TabularRenderer: + def __init__(self, node: OperationNode) -> None: + if node.__class__.__name__ != "OperationNode": + raise TypeError(f"Expected `node` to be an instance of `OperationNode`, but got `{type(node).__name__}`.") - def __init__(self, node:OperationNode) -> None: - if node.__class__.__name__ != 'OperationNode': - raise TypeError("Expected `node` to be an instance of `OperationNode`, " + \ -
f"but got `{type(node).__name__}`.") - self.opnode = node # underlying data - self.__stats_data:Dict[str, DataFrame] = {stat_name:DataFrame() for stat_name in node.statistics} + self.__stats_data: Dict[str, DataFrame] = {stat_name: DataFrame() for stat_name in node.statistics} @property - def stats_data(self): + def stats_data(self) -> Dict[str, DataFrame]: return self.__stats_data @property def tb_args(self) -> FlagNameSpace: return __cfg__.table_display_args - + @property def col_args(self) -> FlagNameSpace: return __cfg__.table_column_args @property def valid_export_format(self) -> List[str]: - return ['csv', 'xlsx'] + return ["csv", "xlsx"] - @tb_args.setter # type: ignore - def tb_args(self, custom_args:Dict[str, Any]): + @tb_args.setter # type: ignore + def tb_args(self, custom_args: Dict[str, Any]) -> None: if not isinstance(custom_args, dict): - raise TypeError(f"You can only overwrite `{self.__class__.__name__}.tb_args` with a dict, " + \ - f"but got `{type(custom_args).__name__}`.") - + raise TypeError( + f"You can only overwrite `{self.__class__.__name__}.tb_args` with a dict, " + + f"but got `{type(custom_args).__name__}`." + ) + valid_setting_keys = set(signature(Table).parameters.keys()) passin_keys = set(custom_args.keys()) invalid_keys = passin_keys - valid_setting_keys if invalid_keys: - raise KeyError(f"Keys {invalid_keys} is/are not accepted by `rich.table.Table`, refer to " + \ - "https://rich.readthedocs.io/en/latest/reference/table.html#rich.table.Table " + \ - "for valid args.") + raise KeyError( + f"Keys {invalid_keys} is/are not accepted by `rich.table.Table`, refer to " + + "https://rich.readthedocs.io/en/latest/reference/table.html#rich.table.Table " + + "for valid args." 
+ ) self.tb_args.update(custom_args) - + self.tb_args.mark_change() - - @col_args.setter # type: ignore - def col_args(self, custom_args:Dict[str, Any]): + + @col_args.setter # type: ignore + def col_args(self, custom_args: Dict[str, Any]) -> None: if not isinstance(custom_args, dict): - raise TypeError(f"You can only overwrite `{self.__class__.__name__}.col_args` with a dict, " + \ - f"but got `{type(custom_args).__name__}`.") - + raise TypeError( + f"You can only overwrite `{self.__class__.__name__}.col_args` with a dict, " + + f"but got `{type(custom_args).__name__}`." + ) + valid_setting_keys = set(signature(Column).parameters.keys()) passin_keys = set(custom_args.keys()) invalid_keys = passin_keys - valid_setting_keys if invalid_keys: - raise KeyError(f"Keys {invalid_keys} is/are not accepted by `rich.table.Column`, refer to " + \ - "https://rich.readthedocs.io/en/latest/reference/table.html#rich.table.Column " + \ - "for valid args.") + raise KeyError( + f"Keys {invalid_keys} is/are not accepted by `rich.table.Column`, refer to " + + "https://rich.readthedocs.io/en/latest/reference/table.html#rich.table.Column " + + "for valid args." 
+ ) self.col_args.update(custom_args) - + self.col_args.mark_change() - def df2tb(self, df:DataFrame, show_raw:bool = False) -> Table: + def df2tb(self, df: DataFrame, show_raw: bool = False) -> Table: # create rich table tb_fields = df.columns tb = apply_setting( @@ -609,125 +655,139 @@ def df2tb(self, df:DataFrame, show_raw:bool = False) -> Table: setting=self.tb_args, omit="columns", headers=tb_fields - ) + ) # fmt: skip # apply column settings to all columns list(map(lambda tb_col: apply_setting(obj=tb_col, omit="header", setting=self.col_args, - highlight = self.tb_args.highlight), tb.columns)) + highlight=self.tb_args.highlight), + tb.columns + ) + ) # fmt: skip + if df.is_empty(): return tb - + # collect each column's none replacing string - col_none_str = {col_name:getattr(df[col_name].drop_nulls()[0], 'none_str', '-') - for col_name in df.schema.keys()} - + col_none_str = {col_name: getattr(df[col_name].drop_nulls()[0], "none_str", "-") + for col_name in df.schema} # fmt: skip + # fill table for vals_dict in df.iter_rows(named=True): str_vals = [] - for col_name,col_val in vals_dict.items(): + for col_name, col_val in vals_dict.items(): if col_val is None: str_vals.append(col_none_str[col_name]) elif show_raw: - str_vals.append(str(getattr(col_val, 'raw_data', col_val))) + str_vals.append(str(getattr(col_val, "raw_data", col_val))) else: str_vals.append(str(col_val)) - + tb.add_row(*str_vals) - + return tb - def clear(self, stat_name:Optional[str]=None) -> None: + def clear(self, stat_name: Optional[str] = None) -> None: if not isinstance(stat_name, (str, type(None))): raise TypeError(f"`stat_name` must be a string or None, but got `{type(stat_name).__name__}`.") - + valid_stat_name = self.opnode.statistics - if isinstance(stat_name,str): + if isinstance(stat_name, str): if stat_name not in valid_stat_name: raise ValueError(f"`{stat_name}` not in the supported statistics {valid_stat_name}.") self.__stats_data[stat_name] = DataFrame() else: - 
self.__stats_data = {stat_name:DataFrame() for stat_name in valid_stat_name} + self.__stats_data = {stat_name: DataFrame() for stat_name in valid_stat_name} def export(self, - df:DataFrame, - save_path:str, - format:Optional[str]=None, - file_suffix:str='', - raw_data:bool=False) -> None: - - from polars import col - from polars import Float64 as pl_float - from polars import String as pl_str + df: DataFrame, + save_path: str, + file_suffix: str = '', + ext: Optional[str] = None, + raw_data: bool = False) -> None: # fmt: skip from polars import List as pl_list from polars import Object as pl_object + from polars import String as pl_str + from polars import Float64 as pl_float + from polars import col save_path = os.path.abspath(save_path) - + # get save path - if format is None: - format = os.path.splitext(save_path)[-1] - if '.' not in format: - raise ValueError("File format unknown! Please specify a path to a file.\n" + \ - f"Or you can specify a file format using `format=xxx`, now we support exporting to {self.valid_export_format} file.") - - format = format.strip('.') - if format not in self.valid_export_format: - raise ValueError(f"`{format}` file is not supported, now we only support exporting to {self.valid_export_format} file.") - + if ext is None: + ext = os.path.splitext(save_path)[-1] + if "." not in ext: + raise ValueError( + "File ext unknown! Please specify a path to a file. " + + "Or you can specify a file extension using `ext=xxx`, " + + f"now we support exporting to {self.valid_export_format} file." + ) + + ext = ext.strip(".") + if ext not in self.valid_export_format: + raise ValueError( + f"`{ext}` file is not supported, now we only support exporting to {self.valid_export_format} file." 
+ ) + default_filename = f"{self.opnode.name}_{file_suffix}" if file_suffix else self.opnode.name - _, file_path = resolve_savepath(origin_path=save_path, - target_ext=format, - default_filename=default_filename) - + _, file_path = resolve_savepath(origin_path=save_path, target_ext=ext, default_filename=default_filename) + # deal with invalid data df = deepcopy(df) - - obj_cols:Dict[str, Any] = {col_name:df[col_name].drop_nulls()[0].__class__ - for col_name, col_type in df.schema.items() if col_type == pl_object} + + obj_cols: Dict[str, Any] = { + col_name: df[col_name].drop_nulls()[0].__class__ + for col_name, col_type in df.schema.items() + if col_type == pl_object + } df = df.with_columns([ - col(col_name).map_elements(lambda s: getattr(s,'raw_data',s.val) if raw_data else str(s), - return_dtype=pl_float if raw_data else pl_str) - for col_name in obj_cols.keys() - ]) - - # export - if format == 'csv': + col(col_name).map_elements( + lambda s: getattr(s, "raw_data", s.val) if raw_data else str(s), + return_dtype=pl_float if raw_data else pl_str, + ) + for col_name in obj_cols + ]) + + # export + if ext == "csv": # list column -> str ls_cols = [col_name for col_name, col_type in df.schema.items() if col_type == pl_list] df = df.with_columns([ - col(col_name).map_elements(lambda s: str(s.to_list()), return_dtype=pl_str) + col(col_name).map_elements(lambda s: str(s.to_list()), return_dtype=pl_str) for col_name in ls_cols - ]) + ]) # fmt: skip df.write_csv(file=file_path) - elif format == 'xlsx': + + elif ext == "xlsx": df.write_excel(workbook=file_path, autofit=True) - + # output saving message if file_suffix: print(f"{file_suffix.capitalize()} data saved to [b magenta]{file_path}[/]") else: print(f"Data saved to [b magenta]{file_path}[/]") - - def __call__(self, - stat_name:str, - *, - raw_data:bool=False, - pick_cols:Sequence[str]=[], - exclude_cols:Sequence[str]=[], - custom_cols:Dict[str, str]={}, - keep_custom_name:bool=False, - newcol_name:str='', - 
newcol_func:Callable[[DataFrame], ArrayLike]=lambda df: [None]*len(df),
-                 newcol_type:Optional[PolarsDataType]=None,
-                 newcol_idx:int=-1,
-                 keep_new_col:bool=False,
-                 save_to:Optional[str]=None,
-                 save_format:Optional[str]=None):
+
+    def __call__(  # noqa: C901
+        self,
+        stat_name: str,
+        *,
+        raw_data: bool = False,
+        pick_cols: Sequence[str] = [],
+        exclude_cols: Sequence[str] = [],
+        custom_cols: Dict[str, str] = {},
+        keep_custom_name: bool = False,
+        newcol_name: str = "",
+        newcol_func: Callable[[DataFrame], ArrayLike] = lambda df: [None] * len(df),
+        newcol_type: Optional[PolarsDataType] = None,
+        newcol_idx: int = -1,
+        keep_new_col: bool = False,
+        save_to: Optional[str] = None,
+        save_format: Optional[str] = None,
+    ) -> tuple[Table, DataFrame]:
        """render rich table according to the statistics dataframe.

        Note that `pick_cols` works before `custom_cols`
-        """
+        """  # noqa: DOC201, DOC501

        from collections import defaultdict

@@ -741,113 +801,126 @@ def __call__(self,
            raise TypeError(f"`exclude_cols` must be a list, tuple or set, but got `{type(exclude_cols).__name__}`.")
        if not isinstance(custom_cols, dict):
            raise TypeError(f"`custom_cols` must be a dict, but got `{type(custom_cols).__name__}`.")
-
-        data:DataFrame = self.__stats_data[stat_name]
+
+        data: DataFrame = self.__stats_data[stat_name]
        valid_fields = data.columns or getattr(self.opnode, stat_name).tb_fields
-
-        def __fill_cell(subject:OperationNode, pre_res=None):
-            nonlocal val_collector, nocall_nodes, col_sample_data # type: ignore
-            if subject.node_id == '0':
+        def __fill_cell(subject: OperationNode, pre_res: None = None) -> None:  # noqa: ARG001
+            nonlocal val_collector, nocall_nodes, col_sample_data  # type: ignore
+
+            if subject.node_id == "0":
                return

            node_stat = getattr(subject, stat_name)
-
            try:
-                stat_infos:List[NamedTuple] = node_stat.detail_val
+                stat_infos: List[NamedTuple] = node_stat.detail_val
                for info_nametuple in stat_infos:
                    info_dict = info_nametuple._asdict()
-                    val_collector = 
{k:val_collector[k] + [v]
-                                     for k,v in info_dict.items()}
-
+                    val_collector = {k: val_collector[k] + [v] for k, v in info_dict.items()}
+
                    if None in col_sample_data.values():
-                        col_sample_data = {k:col_sample_data[k] or v
-                                           for k,v in info_dict.items()}
+                        col_sample_data = {k: col_sample_data[k] or v
+                                           for k, v in info_dict.items()}  # fmt: skip

            except RuntimeError:
                nocall_nodes.append(f"({subject.node_id}){subject.name}")

        # only when the table is empty, then explore the data using dfs
-        if data.is_empty():
-            nocall_nodes:List[str] = []
-            val_collector:Dict[str, List[Any]] = defaultdict(list)
-            col_sample_data:Dict[str, Any] = {col_name:None for col_name in valid_fields}
+        if data.is_empty():
+            nocall_nodes: List[str] = []
+            val_collector: Dict[str, List[Any]] = defaultdict(list)
+            col_sample_data: Dict[str, Any] = {col_name: None for col_name in valid_fields}

            dfs_task(dfs_subject=self.opnode,
                     adj_func=lambda x: x.childs.values(),
                     task_func=__fill_cell,
-                     visited=[])
-
+                     visited=[])  # fmt: skip
+
            if not val_collector:
                raise RuntimeError(
-                    f"No {stat_name} data collected, the reasons are three-folds:\n" + \
-                    "1. No module is called, make sure that your model's `forward` method is not empty.\n" + \
-                    "2. The whole model is empty and has no sublayers.\n" + \
-                    "3. You use a single layer as a model, consider putting it in a class and try again.\n")
-
-            col_data:Dict[str, Series] = {col_name: Series(name=col_name, values=col_val,
-                                                           dtype=match_polars_type(col_sample_data[col_name]))
-                                          for col_name, col_val in val_collector.items()}
-
+                    f"No {stat_name} data collected, the reasons are threefold:\n"
+                    + "1. No module is called, make sure that your model's `forward` method is not empty.\n"
+                    + "2. The whole model is empty and has no sublayers.\n"
+                    + "3. 
You use a single layer as a model, consider putting it in a class and try again.\n"
+                )
+
+            col_data: Dict[str, Series] = {
+                col_name: Series(name=col_name, values=col_val, dtype=match_polars_type(col_sample_data[col_name]))
+                for col_name, col_val in val_collector.items()
+            }
+
            data = DataFrame(data=col_data)
            self.__stats_data[stat_name] = data
-
+
            if nocall_nodes:
-                warnings.warn(message=f"{', '.join(nocall_nodes)}\nThe modules above might be defined but not explicitly called. " + \
-                                      "They will be ignored in the measuring, so will not appear in the table below.",
-                              category=RuntimeWarning)
-
+                warnings.warn(
+                    category=RuntimeWarning,
+                    message=f"{', '.join(nocall_nodes)}\n"
+                    + "The modules above might be defined but not explicitly called. "
+                    + "They will be ignored in the measurement, so they will not appear in the table below.",
+                )
+
        # pick columns, order defined by `pick_cols`
        if pick_cols:
-            invalid_cols = tuple(filter(lambda col_name:col_name not in valid_fields, pick_cols))
+            invalid_cols = tuple(filter(lambda col_name: col_name not in valid_fields, pick_cols))
            if invalid_cols:
                raise ValueError(f"Column names {invalid_cols} not found in supported columns {data.columns}.")
        else:
            pick_cols = valid_fields  # a set is not used here, to keep order

-        final_cols = [col_name for col_name in pick_cols if col_name not in exclude_cols]
+        final_cols = [col_name for col_name in pick_cols
+                      if col_name not in exclude_cols]  # fmt: skip
        data = data.select(final_cols)
-
+
        # custom column names, order defined by `custom_cols`
        if custom_cols:
-            custom_cols = {k:v for k,v in custom_cols.items() if k in final_cols}
+            custom_cols = {k: v for k, v in custom_cols.items() if k in final_cols}
            data = data.rename(custom_cols)
            if keep_custom_name:
                self.__stats_data[stat_name] = data
-
+
        # add new column
        if newcol_name:
-            data = self.__new_col(df=data,
-                                  col_name=newcol_name,
-                                  col_func=newcol_func,
-                                  return_type=newcol_type,
-                                  col_idx=newcol_idx)
+            data = self.__new_col(
+                df=data,
+                
col_name=newcol_name,
+                col_func=newcol_func,
+                return_type=newcol_type,
+                col_idx=newcol_idx,
+            )
            if keep_new_col:
                self.__stats_data[stat_name] = data

        tb = self.df2tb(df=data, show_raw=raw_data)

        if save_to:
-            save_to = os.path.abspath(save_to)
-            if '.' not in os.path.basename(save_to):
-                if save_format not in self.valid_export_format:
-                    raise ValueError(f"Argument `save_format` must be one in {self.valid_export_format}, but got `{save_format}`. " + \
-                        "Alternatively, you can set `save_to` to a concrete file path, like `path/to/file.xlsx`")
-
-            self.export(df=data,
-                        save_path=save_to,
-                        format=save_format,
-                        file_suffix=stat_name,
-                        raw_data=raw_data)
+            save_to = os.path.abspath(save_to)
+
+            # when a dir path is received
+            if "." not in os.path.basename(save_to) and save_format not in self.valid_export_format:
+                raise ValueError(
+                    f"Argument `save_format` must be one of {self.valid_export_format}, but got `{save_format}`. "
+                    + "Alternatively, you can set `save_to` to a concrete file path, like `path/to/file.xlsx`"
+                )
+
+            self.export(
+                df=data,
+                save_path=save_to,
+                file_suffix=stat_name,
+                ext=save_format,
+                raw_data=raw_data,
+            )

        return tb, data
-
-    def __new_col(self,
-                  df:DataFrame,
-                  col_name:str,
-                  col_func:Callable[[DataFrame], ArrayLike],
-                  return_type=None,
-                  col_idx:int = -1) -> DataFrame:
+    def __new_col(
+        self,
+        df: DataFrame,
+        col_name: str,
+        col_func: Callable[[DataFrame], ArrayLike],
+        return_type: Optional[PolarsDataType] = None,
+        col_idx: int = -1,
+    ) -> DataFrame:
        from inspect import signature

        # validate col_name
@@ -856,35 +929,38 @@ def __new_col(self,
        if col_name in df.columns:
            raise ValueError(f"Column name `{col_name}` already exists in the table.")
-
+
        # validate col_func
        if not callable(col_func):
            raise TypeError(f"`col_func` must be a callable object, but got `{type(col_func).__name__}`.")
        else:
            col_func_args_num = len(signature(col_func).parameters)
            if col_func_args_num != 1:
-                raise TypeError("`col_func` must take exactly 
only one argument to receive " + \
-                                f"the backend dataframe, but got {col_func_args_num} instead.")
+                raise TypeError(
+                    "`col_func` must take exactly one argument to receive "
+                    + f"the backend dataframe, but got {col_func_args_num} instead."
+                )
            else:
                func_ret = col_func(df.clone())
                try:
                    col_data = Series(values=func_ret, dtype=return_type)
                except TypeError:
-                    raise TypeError("`col_func` must return an array-like object, " + \
-                        f"but got `{type(func_ret).__name__}`.")
-
+                    raise TypeError(
+                        f"`col_func` must return an array-like object, but got `{type(func_ret).__name__}`."
+                    )
+
                if len(col_data) != len(df):
-                    raise RuntimeError(f"The result length of `col_func` is {len(col_data)}, " + \
-                        f"not matchs the backend dataframe's length {len(df)}.")
-
+                    raise RuntimeError(
+                        f"The result length of `col_func` is {len(col_data)}, "
+                        + f"which does not match the backend dataframe's length {len(df)}."
+                    )
+
        # get new column position
        if col_idx < 0:
            col_idx = len(df.columns) + col_idx + 1 if abs(col_idx) <= len(df.columns) else 0
-
+
        final_cols = df.columns[:]
        final_cols.insert(col_idx, col_name)
-
+
        # create new column
-        return df.with_columns(
-            col_data.alias(col_name)
-        ).select(final_cols)
\ No newline at end of file
+        return df.with_columns(col_data.alias(col_name)).select(final_cols)
diff --git a/torchmeter/engine.py b/torchmeter/engine.py
index 620a4db..0d0b800 100644
--- a/torchmeter/engine.py
+++ b/torchmeter/engine.py
@@ -1,56 +1,61 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING
 import re
+from typing import TYPE_CHECKING
 from collections import OrderedDict

 import torch.nn as nn
 from rich.tree import Tree

-from torchmeter.utils import dfs_task, Timer
-from torchmeter.statistic import ParamsMeter, CalMeter, MemMeter, IttpMeter
+from torchmeter.utils import Timer, dfs_task
+from torchmeter.statistic import CalMeter, MemMeter, IttpMeter, ParamsMeter

 if TYPE_CHECKING:
-    from typing import List, Optional, Tuple
+    from typing import List, 
Tuple, Optional OPNODE_LIST = List["OperationNode"] __all__ = ["OperationNode", "OperationTree"] -class OperationNode: - - statistics:Tuple[str, ...] = ('param', 'cal', 'mem', 'ittp') # all statistics stored as attributes - - def __init__(self, - module:nn.Module, - name:Optional[str]=None, - node_id:str='0', - parent:Optional[OperationNode]=None): +class OperationNode: + statistics: Tuple[str, ...] = ("param", "cal", "mem", "ittp") # all statistics stored as attributes + + def __init__( + self, + module: nn.Module, + name: Optional[str] = None, + node_id: str = "0", + parent: Optional[OperationNode] = None, + ) -> None: if not isinstance(module, nn.Module): - raise TypeError(f"You must use an `nn.Module` instance to instantiate `{self.__class__.__name__}`, " + \ - f"but got `{type(module).__name__}`.") - + raise TypeError( + f"You must use an `nn.Module` instance to instantiate `{self.__class__.__name__}`, " + + f"but got `{type(module).__name__}`." + ) + # basic info self.operation = module - self.type:str = module.__class__.__name__ - self.name:str = name if name else self.type - self.node_id:str = node_id # index in the model tree, e.g. '1.2.1' - + self.type: str = module.__class__.__name__ + self.name: str = name if name else self.type + self.node_id: str = node_id # index in the model tree, e.g. '1.2.1' + # hierarchical info - self.parent:Optional[OperationNode] = parent - self.childs:OrderedDict[str, "OperationNode"] = OrderedDict() # e.g. {'1.2.1': OperationNode, ...} - self.is_leaf:bool = len(module._modules) == 0 - + self.parent: Optional[OperationNode] = parent + self.childs: OrderedDict[str, "OperationNode"] = OrderedDict() # e.g. 
{'1.2.1': OperationNode, ...}
+        self.is_leaf: bool = len(module._modules) == 0
+
        # repeat info
-        self.repeat_winsz:int = 1 # size of repeat block
-        self.repeat_time:int = 1
-        self._repeat_body:List[Tuple[str, str]] = [] # the ids and names of the nodes in the same repeat block
-
-        # display info
-        self.display_root:Tree # set in `OperationTree.__build()`
-        self._render_when_repeat:bool = False # whether to render when enable `fold_repeat`, set in `OperationTree.__build()`
-        self._is_folded = False # whether the node is folded in a repeat block, set in `OperationTree.__build()`
+        self.repeat_winsz: int = 1  # size of repeat block
+        self.repeat_time: int = 1
+        self._repeat_body: List[Tuple[str, str]] = []  # the ids and names of the nodes in the same repeat block
+
+        # display info
+        self.display_root: Tree  # set in `OperationTree.__build()`
+        # whether to render when the node is in the repeat window, set in `OperationTree.__build()`
+        self._render_when_repeat: bool = False
+        # whether the node is folded in a repeat block, set in `OperationTree.__build()`
+        self._is_folded = False

        self.module_repr = str(self.type) if not self.is_leaf else str(self.operation)

        # statistic info (all read-only)
@@ -70,145 +75,152 @@ def cal(self) -> CalMeter:
    @property
    def mem(self) -> MemMeter:
        return self.__mem
-
+
    @property
    def ittp(self) -> IttpMeter:
        return self.__ittp
-
+
    def __repr__(self) -> str:
        return f"{self.node_id} {self.name}: {self.module_repr}"
-
-class OperationTree:
-    def __init__(self, model:nn.Module) -> None:
-
+
+class OperationTree:
+    def __init__(self, model: nn.Module) -> None:
        if not isinstance(model, nn.Module):
-            raise TypeError(f"You must use an `nn.Module` instance to instantiate `{self.__class__.__name__}`, " + \
-                            f"but got `{type(model).__name__}`.")
-
+            raise TypeError(
+                f"You must use an `nn.Module` instance to instantiate `{self.__class__.__name__}`, "
+                + f"but got `{type(model).__name__}`."
+            )
+
        self.root = OperationNode(module=model)
        self.root._render_when_repeat = True
-
+
        with Timer(task_desc="Scanning model"):
-            nonroot_nodes, *_ = dfs_task(dfs_subject=self.root,
-                                         adj_func=lambda x:x.childs.values(),
-                                         task_func=OperationTree.__build,
-                                         visited=[])
+            nonroot_nodes, *_ = dfs_task(
+                dfs_subject=self.root,
+                adj_func=lambda x: x.childs.values(),
+                task_func=OperationTree.__build,
+                visited=[],
+            )
+
+        self.all_nodes: OPNODE_LIST = [self.root, *nonroot_nodes]
-        self.all_nodes:OPNODE_LIST = [self.root] + nonroot_nodes
-
    @staticmethod
-    def __build(subject:OperationNode,
-                pre_res:Tuple[Optional[OPNODE_LIST], Optional[Tree]]=(None, None)) \
-        -> Tuple[OPNODE_LIST, Optional[Tree]]:
+    def __build(
+        subject: OperationNode,
+        pre_res: Tuple[Optional[OPNODE_LIST], Optional[Tree]] = (None, None),
+    ) -> Tuple[OPNODE_LIST, Optional[Tree]]:
        """
        Private method.

-        This function will explore the model structure, unfold all multi-layers modules, and organize them
-        into a tree structure. Finally, it will build a display tree and a operation tree in one DFS recursion,
+        This function will explore the model structure, unfold all multi-layer modules, and organize them
+        into a tree structure. Finally, it will build a display tree and an operation tree in one DFS recursion,
        simultaneously. With the built display tree, the terminal display of the model structure can be achieved.
        With the built operation tree, the model statistics can be easily and quickly accessed.

        Args:
            subject (OperationNode): the node to be traversed in the operation tree.
pre_res (Tuple[Optional[OPNODE_LIST], Optional[Tree]]): the container of all nodes and the father node - of the display tree - + of the display tree + Returns: - Tuple[Optional[OPNODE_LIST], Optional[Tree]]: the current traversed node of the operation tree and + Tuple[Optional[OPNODE_LIST], Optional[Tree]]: the current traversed node of the operation tree and display tree Note: This function serves as the `task_func` in `torchmeter.utils.dfs_task` """ - + all_nodes, *display_parent = pre_res all_nodes = all_nodes if all_nodes else [] - + # build display tree if display_parent and display_parent[0] is not None: display_parent_node = display_parent[0] - + # create a tree node of rich.Tree, and record the level of the node in attribute 'label' - display_node = Tree(label=str(int(display_parent_node.label)+1)) # type: ignore - + display_node = Tree(label=str(int(display_parent_node.label) + 1)) # type: ignore + # add the current node to the father node - display_parent_node.children.append(display_node) # type: ignore + display_parent_node.children.append(display_node) # type: ignore else: # situation of root node - display_node = Tree(label='0') - + display_node = Tree(label="0") + # link the display node to the attribute `display_root` of operation node subject.display_root = display_node - + # build operation tree copy_childs, str_childs = [], [] for access_idx, (module_name, module) in enumerate(subject.operation._modules.items()): - module_idx = (subject.node_id if subject.node_id != '0' else '') + \ - ('.' if subject.node_id != '0' else '') + \ - str(access_idx+1) - - child = OperationNode(module=module, # type: ignore - name=module_name, - parent=subject, - node_id=module_idx) - + module_idx = ( + (subject.node_id if subject.node_id != "0" else "") + + ("." 
if subject.node_id != "0" else "")
+                + str(access_idx + 1)
+            )
+
+            child = OperationNode(
+                module=module,  # type: ignore
+                name=module_name,
+                parent=subject,
+                node_id=module_idx,
+            )
+
            all_nodes.append(child)
-
+
            subject.childs[module_idx] = child
            copy_childs.append(child)
-            str_childs.append(re.sub(r'\(.*?\):\s', repl='', string=str(module)))
-
+            str_childs.append(re.sub(r"\(.*?\):\s", repl="", string=str(module)))
+
        # find all potential maximum repeat blocks in the currently traversed level using a greedy strategy
        slide_start_idx = 0
        while slide_start_idx < len(str_childs):
            now_node = copy_childs[slide_start_idx]
            now_node._render_when_repeat = True & subject._render_when_repeat
-
+
            # find the maximum window size `m` that satisfies `str_childs[0:m] == str_childs[m:2*m]`
            exist_repeat = False
-            for win_size in range((len(str_childs)-slide_start_idx)//2, 0, -1):
-                win1_start_end = [slide_start_idx, slide_start_idx+win_size]
-                win2_start_end = [slide_start_idx+win_size, slide_start_idx+win_size*2]
-
+            for win_size in range((len(str_childs) - slide_start_idx) // 2, 0, -1):
+                win1_start_end = [slide_start_idx, slide_start_idx + win_size]
+                win2_start_end = [slide_start_idx + win_size, slide_start_idx + win_size * 2]
+
                if str_childs[slice(*win1_start_end)] == str_childs[slice(*win2_start_end)]:
-                    exist_repeat =True
+                    exist_repeat = True
                    break
-
+
            # if the maximum window size `m` does exist, then try to explore the repeat time of the window
            if exist_repeat:
                # whether all the modules in the window are the same
-                inner_repeat = len(set(str_childs[win1_start_end[0]:win2_start_end[1]])) == 1
-
+                inner_repeat = len(set(str_childs[win1_start_end[0] : win2_start_end[1]])) == 1
+
                # multiply the window size `m` by 2 if all the modules in the window are the same
-                repeat_time = win_size*2 if inner_repeat else 2
+                repeat_time = win_size * 2 if inner_repeat else 2
                win_size = 1 if inner_repeat else win_size
-
+
                win2_start_end[0] += win_size
                win2_start_end[1] += win_size
-
+
                while 
str_childs[slice(*win1_start_end)] == str_childs[slice(*win2_start_end)]:
                    repeat_time += 1
-
+
                    win2_start_end[0] += win_size
                    win2_start_end[1] += win_size
-
+
                now_node.repeat_winsz = win_size
                now_node.repeat_time = repeat_time
                for idx in range(slide_start_idx, slide_start_idx + win_size):
                    inwin_node = copy_childs[idx]
                    inwin_node._render_when_repeat = True & subject._render_when_repeat
-                    inwin_node._is_folded = True if idx-slide_start_idx else False
+                    inwin_node._is_folded = bool(idx - slide_start_idx)
                    now_node._repeat_body.append((inwin_node.node_id, inwin_node.name))

                # skip the modules that are repeated
                slide_start_idx += win_size * repeat_time
-
+
            else:
-                # if there isn't such a window, then the current module is unique,
+                # if there isn't such a window, then the current module is unique,
                # then jump to the adjacent module and repeat such a procedure
                slide_start_idx += 1
-
+
        return (all_nodes, display_node)
-
+
    def __repr__(self) -> str:
        return self.root.__repr__()
-    
\ No newline at end of file
diff --git a/torchmeter/statistic.py b/torchmeter/statistic.py
index aadb7b7..4802e9c 100644
--- a/torchmeter/statistic.py
+++ b/torchmeter/statistic.py
@@ -1,49 +1,51 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING
 import re
-from time import perf_counter
-from collections import namedtuple
 from abc import ABC, abstractmethod
-from operator import attrgetter, mul
+from time import perf_counter
+from typing import TYPE_CHECKING
+from operator import mul, attrgetter
 from functools import reduce, partial
+from collections import namedtuple

 import numpy as np
 import torch.nn as nn
-from pympler.asizeof import asizeof
-from torch import no_grad, Tensor
+from torch import Tensor, no_grad
 from torch.cuda import Event as cuda_event
 from torch.cuda import synchronize as cuda_sync
+from pympler.asizeof import asizeof

-from torchmeter._stat_numeric import (
-    UpperLinkData, MetricsData,
-    CountUnit, BinaryUnit, TimeUnit, SpeedUnit
-)
+from 
torchmeter._stat_numeric import TimeUnit, CountUnit, SpeedUnit, BinaryUnit, MetricsData, UpperLinkData if TYPE_CHECKING: - from typing import Dict, List - from typing import Optional, Tuple, NamedTuple, Sequence + from typing import Any, Dict, List, Tuple, Optional, Sequence, NamedTuple from tqdm import tqdm - from torch.cuda import Event from torch import device as tc_device + from torch.cuda import Event from torch.utils.hooks import RemovableHandle from torchmeter.engine import OperationNode __all__ = ["ParamsMeter", "CalMeter", "MemMeter", "IttpMeter"] -class Statistics(ABC): +class Statistics(ABC): detail_val_container: NamedTuple overview_val_container: NamedTuple - def __new__(cls, *args, **kwargs): - if not hasattr(cls, 'detail_val_container'): - raise AttributeError(f"Class '{cls.__name__}' must have the class attribute 'detail_val_container', which should be a NamedTuple") - if not hasattr(cls, 'overview_val_container'): - raise AttributeError(f"Class '{cls.__name__}' must have the class attribute 'overview_val_container', which should be a NamedTuple") - return super().__new__(cls) + def __new__(cls, *args, **kwargs) -> Statistics: # noqa: ARG004 + if not hasattr(cls, "detail_val_container"): + raise AttributeError( + f"Class '{cls.__name__}' must have the class attribute 'detail_val_container', " + + "which should be a NamedTuple" + ) + if not hasattr(cls, "overview_val_container"): + raise AttributeError( + f"Class '{cls.__name__}' must have the class attribute 'overview_val_container', " + + "which should be a NamedTuple" + ) + return super().__new__(cls) @property @abstractmethod @@ -56,7 +58,7 @@ def name(self) -> str: def val(self) -> NamedTuple: """A namedtuple which contains all the necessary information of the statistics""" ... - + @property @abstractmethod def detail_val(self) -> List[NamedTuple]: @@ -70,85 +72,90 @@ def crucial_data(self) -> Dict[str, str]: ... 
@abstractmethod - def measure(self, *args, **kwargs): + def measure(self, *args, **kwargs): # noqa: ANN202 """To measure the statistics""" ... @property - def tb_fields(self) -> Tuple[str,...]: + def tb_fields(self) -> Tuple[str, ...]: return self.detail_val_container._fields - + @property - def ov_fields(self) -> Tuple[str,...]: + def ov_fields(self) -> Tuple[str, ...]: return self.overview_val_container._fields - - def init_linkdata(self, - attr_name:str, - init_val:int=0, - opparent:Optional[OperationNode]=None, - **kwargs) -> UpperLinkData: + + def init_linkdata( + self, + attr_name: str, + init_val: int = 0, + opparent: Optional[OperationNode] = None, + **kwargs, + ) -> UpperLinkData: if opparent is None: link_data = UpperLinkData(val=init_val, **kwargs) else: upper_getter = attrgetter(f"{self.name}.{attr_name}") - link_data = UpperLinkData(val=init_val, - parent_data=upper_getter(opparent), - **kwargs) + link_data = UpperLinkData( + val=init_val, + parent_data=upper_getter(opparent), + **kwargs, + ) return link_data def __repr__(self) -> str: - repr_str = self.val.__class__.__name__ + '\n' + repr_str = self.val.__class__.__name__ + "\n" max_len = max(len(f) for f in self.ov_fields) for field in self.ov_fields: - field_val = getattr(self.val, field, 'N/A') + field_val = getattr(self.val, field, "N/A") if isinstance(field_val, UpperLinkData): - repr_str += 'β€’ ' + f"{field.rjust(max_len)} = {field_val.raw_data:.2f} = {field_val}\n" + repr_str += "β€’ " + f"{field.rjust(max_len)} = {field_val.raw_data:.2f} = {field_val}\n" else: - repr_str += 'β€’ ' + f"{field.rjust(max_len)} = {field_val}\n" + repr_str += "β€’ " + f"{field.rjust(max_len)} = {field_val}\n" return repr_str -class ParamsMeter(Statistics): - detail_val_container = namedtuple( # type: ignore - typename='Params_INFO', - field_names=['Operation_Id', 'Operation_Name', 'Operation_Type', - 'Param_Name', 'Requires_Grad', 'Numeric_Num'], - defaults=[None]*6 # type: ignore - ) - - overview_val_container = 
namedtuple( # type: ignore - typename='Params_INFO', - field_names=['Operation_Id', 'Operation_Name', 'Operation_Type', - 'Total_Params', 'Learnable_Params'], - defaults=[None]*5 # type: ignore - ) +class ParamsMeter(Statistics): + detail_val_container = namedtuple( # type: ignore + typename="Params_INFO", + field_names=["Operation_Id", "Operation_Name", "Operation_Type", + "Param_Name", "Requires_Grad", "Numeric_Num"], + defaults=[None] * 6, # type: ignore + ) # fmt: skip + + overview_val_container = namedtuple( # type: ignore + typename="Params_INFO", + field_names=["Operation_Id", "Operation_Name", "Operation_Type", + "Total_Params", "Learnable_Params"], + defaults=[None] * 5, # type: ignore + ) # fmt: skip def __init__(self, opnode: OperationNode) -> None: - if opnode.__class__.__name__ != 'OperationNode': - raise TypeError("Expected `opnode` to be an instance of `OperationNode`, " + \ - f"but got `{type(opnode).__name__}`.") - + if opnode.__class__.__name__ != "OperationNode": + raise TypeError( + f"Expected `opnode` to be an instance of `OperationNode`, but got `{type(opnode).__name__}`." 
+ ) + self._opnode = opnode - self._model:nn.Module = opnode.operation - - self.__stat_ls:List[NamedTuple] = [] # record all parameters' information - self.is_measured = False # used for cache + self._model: nn.Module = opnode.operation + + self.__stat_ls: List[NamedTuple] = [] # record all parameters' information + self.is_measured = False # used for cache - _opparent:Optional[OperationNode] = opnode.parent - self.__RegNum = self.init_linkdata(attr_name='RegNum', init_val=0, opparent=_opparent, unit_sys=CountUnit) - self.__TotalNum = self.init_linkdata(attr_name='TotalNum', init_val=0, opparent=_opparent, unit_sys=CountUnit) + _opparent: Optional[OperationNode] = opnode.parent + self.__RegNum = self.init_linkdata(attr_name="RegNum", init_val=0, opparent=_opparent, unit_sys=CountUnit) + self.__TotalNum = self.init_linkdata(attr_name="TotalNum", init_val=0, opparent=_opparent, unit_sys=CountUnit) @property def name(self) -> str: - return 'param' + return "param" @property - def RegNum(self) -> UpperLinkData : + def RegNum(self) -> UpperLinkData: return self.__RegNum - + @property def TotalNum(self) -> UpperLinkData: return self.__TotalNum @@ -157,107 +164,125 @@ def TotalNum(self) -> UpperLinkData: def detail_val(self) -> List[NamedTuple]: self.measure() return self.__stat_ls - + @property def val(self) -> NamedTuple: self.measure() return self.overview_val_container( - Operation_Id=self._opnode.node_id, # type: ignore - Operation_Name=self._opnode.name, # type: ignore - Operation_Type=self._opnode.type, # type: ignore - Total_Params=self.TotalNum, # type: ignore - Learnable_Params=self.RegNum) # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Total_Params=self.TotalNum, # type: ignore + Learnable_Params=self.RegNum, # type: ignore + ) @property def crucial_data(self) -> Dict[str, str]: self.measure() - - res_dict = {'Learnable Parameters Num': 
str(self.RegNum), - 'Total Parameters Num': str(self.TotalNum)} - max_keylen = max([len(key) for key in res_dict.keys()]) + + res_dict = { + "Learnable Parameters Num": str(self.RegNum), + "Total Parameters Num": str(self.TotalNum), + } + max_keylen = max([len(key) for key in res_dict]) res_dict = {key.ljust(max_keylen): value for key, value in res_dict.items()} return res_dict def measure(self) -> None: if self.is_measured: return - + if not self._model._parameters: - self.__stat_ls.append(self.detail_val_container( # type: ignore - Operation_Id=self._opnode.node_id, # type: ignore - Operation_Name=self._opnode.name, # type: ignore - Operation_Type=self._opnode.type, # type: ignore - Numeric_Num=UpperLinkData(val=0, unit_sys=CountUnit)) # type: ignore + self.__stat_ls.append( + self.detail_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Numeric_Num=UpperLinkData(val=0, unit_sys=CountUnit), # type: ignore + ) ) else: - for param_name, param_val in self._model._parameters.items(): + for param_name, param_val in self._model._parameters.items(): if param_val is None: continue p_num = param_val.numel() - + p_reg = False if param_val.requires_grad: p_reg = True self.__RegNum += p_num - self.__stat_ls.append(self.detail_val_container( # type: ignore - Operation_Id=self._opnode.node_id, # type: ignore - Operation_Name=self._opnode.name, # type: ignore - Operation_Type=self._opnode.type, # type: ignore - Param_Name=param_name, # type: ignore - Requires_Grad=p_reg, # type: ignore - Numeric_Num=UpperLinkData(val=p_num, unit_sys=CountUnit)) # type: ignore + self.__stat_ls.append( + self.detail_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Param_Name=param_name, # type: ignore + Requires_Grad=p_reg, # 
type: ignore + Numeric_Num=UpperLinkData(val=p_num, unit_sys=CountUnit), # type: ignore + ) ) - + self.__TotalNum += p_num - + self.is_measured = True -class CalMeter(Statistics): - detail_val_container:NamedTuple = namedtuple( # type: ignore - typename='Calculation_INFO', - field_names=['Operation_Id', 'Operation_Name', 'Operation_Type', - 'Kernel_Size', 'Bias', - 'Input', 'Output', - 'MACs', 'FLOPs'], - defaults=(None,)*9) # type: ignore - - overview_val_container:NamedTuple = namedtuple( # type: ignore - typename='Calculation_INFO', - field_names=['Operation_Id', 'Operation_Type', 'Operation_Name', - 'MACs', 'FLOPs'], - defaults=(None,)*5) # type: ignore +class CalMeter(Statistics): + detail_val_container: NamedTuple = namedtuple( # type: ignore + typename="Calculation_INFO", + field_names=[ + "Operation_Id", "Operation_Name", "Operation_Type", + "Kernel_Size", "Bias", + "Input", "Output", + "MACs", "FLOPs", + ], + defaults=(None,) * 9, # type: ignore + ) # fmt: skip + + overview_val_container: NamedTuple = namedtuple( # type: ignore + typename="Calculation_INFO", + field_names=[ + "Operation_Id", "Operation_Type", "Operation_Name", + "MACs", "FLOPs" + ], + defaults=(None,) * 5, # type: ignore + ) # fmt: skip def __init__(self, opnode: OperationNode) -> None: - if opnode.__class__.__name__ != 'OperationNode': - raise TypeError("Expected `opnode` to be an instance of `OperationNode`, " + \ - f"but got `{type(opnode).__name__}`.") + if opnode.__class__.__name__ != "OperationNode": + raise TypeError( + f"Expected `opnode` to be an instance of `OperationNode`, but got `{type(opnode).__name__}`." 
+ ) self._opnode = opnode - self._model:nn.Module = opnode.operation - - self.__stat_ls:List[NamedTuple] = [] # record the flops and macs information of each operation - self.is_measured = False + self._model: nn.Module = opnode.operation + + self.__stat_ls: List[NamedTuple] = [] # record the flops and macs information of each operation + self.is_measured = False self.__is_not_supported = False - _opparent:Optional[OperationNode] = opnode.parent - self.__Macs = self.init_linkdata(attr_name='Macs', init_val=0, opparent=_opparent, - unit_sys=CountUnit, none_str='Not Supported') - self.__Flops = self.init_linkdata(attr_name='Flops', init_val=0, opparent=_opparent, - unit_sys=CountUnit, none_str='Not Supported') + _opparent: Optional[OperationNode] = opnode.parent + self.__Macs = self.init_linkdata( + attr_name="Macs", init_val=0, opparent=_opparent, + unit_sys=CountUnit, none_str="Not Supported" + ) # fmt: skip + self.__Flops = self.init_linkdata( + attr_name="Flops", init_val=0, opparent=_opparent, + unit_sys=CountUnit, none_str="Not Supported" + ) # fmt: skip @property def name(self) -> str: - return 'cal' + return "cal" @property - def is_not_supported(self): + def is_not_supported(self) -> bool: return self.__is_not_supported @property - def Macs(self) -> UpperLinkData : + def Macs(self) -> UpperLinkData: return self.__Macs - + @property def Flops(self) -> UpperLinkData: return self.__Flops @@ -266,32 +291,34 @@ def Flops(self) -> UpperLinkData: def detail_val(self) -> List[NamedTuple]: self.__is_valid_access() return self.__stat_ls - + @property def val(self) -> NamedTuple: self.__is_valid_access() - return self.overview_val_container( # type: ignore - Operation_Id=self._opnode.node_id, # type: ignore - Operation_Type=self._opnode.type, # type: ignore - Operation_Name=self._opnode.name, # type: ignore - MACs=self.Macs, # type: ignore - FLOPs=self.Flops # type: ignore + return self.overview_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: 
ignore + Operation_Type=self._opnode.type, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + MACs=self.Macs, # type: ignore + FLOPs=self.Flops, # type: ignore ) @property def crucial_data(self) -> Dict[str, str]: self.__is_valid_access() - res_dict = {'FLOPs': str(self.Flops), - 'MACs(aka MACC, MADD)': str(self.Macs)} - max_keylen = max([len(key) for key in res_dict.keys()]) + res_dict = { + "FLOPs": str(self.Flops), + "MACs(aka MACC, MADD)": str(self.Macs), + } + max_keylen = max([len(key) for key in res_dict]) res_dict = {key.ljust(max_keylen): value for key, value in res_dict.items()} return res_dict def measure(self) -> Optional[RemovableHandle]: if self.is_measured: return None - - hook = self.__regist_hook(self._model) # torch.utils.hooks.RemovableHandle + + hook = self.__regist_hook(self._model) # torch.utils.hooks.RemovableHandle self.is_measured = True @@ -299,40 +326,44 @@ def measure(self) -> Optional[RemovableHandle]: def __is_valid_access(self) -> bool: if self.is_measured: - if not (self.Flops.val + self.Macs.val) and \ - not self.__stat_ls and \ - not isinstance(self._model, (nn.ModuleDict, nn.ModuleList)): + if ( + not (self.Flops.val + self.Macs.val) + and not self.__stat_ls + and not isinstance(self._model, (nn.ModuleDict, nn.ModuleList)) + ): raise RuntimeError("This module might be defined but not explicitly called, so no data is collected.") else: - raise AttributeError("You should never access this property on your own " + \ - "before accessing `Meter(your_model).cal`.") + raise AttributeError( + "You should never access this property on your own before accessing `Meter(your_model).cal`." 
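
> Reviewer note: the `crucial_data` properties in this patch all share one alignment idiom, padding every key to the longest key's length so the rendered report columns line up. A minimal torch-free sketch of that idiom (hypothetical metric values; `align_keys` is an illustrative name, not part of torchmeter):

```python
def align_keys(data: dict) -> dict:
    """Left-justify every key to the longest key length,
    as the crucial_data properties do before rendering."""
    max_keylen = max(len(key) for key in data)
    return {key.ljust(max_keylen): value for key, value in data.items()}

# hypothetical metrics, mirroring the FLOPs/MACs report keys
aligned = align_keys({"FLOPs": "1.2 M", "MACs(aka MACC, MADD)": "0.6 M"})
```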
+ ) return True - def __iopt_repr(self, iopt) -> str: - repr: Sequence[str] + def __iopt_repr(self, iopt: Any) -> str: + item_repr: Sequence[str] if isinstance(iopt, Tensor): return str(list(iopt.shape)) - + elif iopt is None: - return 'None' + return "None" elif isinstance(iopt, (tuple, list, set)): - repr = tuple(map(self.__iopt_repr, iopt)) - return '(' + ',\n '.join(repr) + ')' if len(repr) > 1 else repr[0] - + item_repr = tuple(map(self.__iopt_repr, iopt)) + return "(" + ",\n ".join(item_repr) + ")" if len(item_repr) > 1 else item_repr[0] + elif isinstance(iopt, dict): - repr = ["{}: {}".format(self.__iopt_repr(k), self.__iopt_repr(v)) - for k, v in iopt.items()] - return '{' + ',\n '.join(repr) + '}' - + item_repr = ["{}: {}".format(self.__iopt_repr(k), self.__iopt_repr(v)) + for k, v in iopt.items()] # fmt: skip + + return "{" + ",\n ".join(item_repr) + "}" + else: return type(iopt).__name__ - def __regist_hook(self, module): + def __regist_hook(self, module: nn.Module) -> RemovableHandle: if not self._opnode.is_leaf: h = module.register_forward_hook(self.__container_hook) - + elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): h = module.register_forward_hook(self.__conv_hook) @@ -352,42 +383,44 @@ def __regist_hook(self, module): h = module.register_forward_hook(self.__not_support_hook) return h - - def __conv_hook(self, module, input, output): - c_in = input[0].shape[1] + + def __conv_hook(self, module: nn.Module, ipt: Tuple[Tensor], opt: Tensor) -> None: + c_in = ipt[0].shape[1] n = c_in * reduce(mul, module.kernel_size) - m = output.numel() + m = opt.numel() is_bias = 1 if module.bias is not None else 0 - FLOPs = m*(2*n-1+is_bias) - MACs = m*n + FLOPs = m * (2 * n - 1 + is_bias) + MACs = m * n self.__Macs += MACs self.__Flops += FLOPs - + if len(self.__stat_ls): self.Macs.mark_access() self.Flops.mark_access() else: - self.__stat_ls.append(self.detail_val_container( - Operation_Id=self._opnode.node_id, - Operation_Name=self._opnode.name, - 
Operation_Type=self._opnode.type, - Kernel_Size=list(module.kernel_size), - Bias=bool(is_bias), - Input=self.__iopt_repr(input), - Output=self.__iopt_repr(output), - MACs=self.Macs, - FLOPs=self.Flops) + self.__stat_ls.append( + self.detail_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Kernel_Size=list(module.kernel_size), # type: ignore + Bias=bool(is_bias), # type: ignore + Input=self.__iopt_repr(ipt), # type: ignore + Output=self.__iopt_repr(opt), # type: ignore + MACs=self.Macs, # type: ignore + FLOPs=self.Flops, # type: ignore + ) ) - - def __linear_hook(self, module, input, output): + + def __linear_hook(self, module: nn.Module, ipt: Tuple[Tensor], opt: Tensor) -> None: k = module.in_features - l = module.out_features # noqa + l = module.out_features # noqa is_bias = 1 if module.bias is not None else 0 n = k - FLOPs = l*(2*n-1 + is_bias) - MACs = l*n + FLOPs = l * (2 * n - 1 + is_bias) + MACs = l * n self.__Macs += MACs self.__Flops += FLOPs @@ -395,20 +428,22 @@ def __linear_hook(self, module, input, output): self.Macs.mark_access() self.Flops.mark_access() else: - self.__stat_ls.append(self.detail_val_container( - Operation_Id=self._opnode.node_id, - Operation_Name=self._opnode.name, - Operation_Type=self._opnode.type, - Bias=bool(is_bias), - Input=self.__iopt_repr(input), - Output=self.__iopt_repr(output), - MACs=self.Macs, - FLOPs=self.Flops) + self.__stat_ls.append( + self.detail_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Bias=bool(is_bias), # type: ignore + Input=self.__iopt_repr(ipt), # type: ignore + Output=self.__iopt_repr(opt), # type: ignore + MACs=self.Macs, # type: ignore + FLOPs=self.Flops, # type: ignore + ) ) - def __BN_hook(self, module, input, output): - FLOPs = 
4*input[0].numel() - MACs = 0.5*FLOPs + def __BN_hook(self, module: nn.Module, ipt: Tuple[Tensor], opt: Tensor) -> None: # noqa: ARG002 + FLOPs = 4 * ipt[0].numel() + MACs = 0.5 * FLOPs self.__Macs += MACs self.__Flops += FLOPs @@ -416,33 +451,35 @@ def __BN_hook(self, module, input, output): self.Macs.mark_access() self.Flops.mark_access() else: - self.__stat_ls.append(self.detail_val_container( - Operation_Id=self._opnode.node_id, - Operation_Name=self._opnode.name, - Operation_Type=self._opnode.type, - Input=self.__iopt_repr(input), - Output=self.__iopt_repr(output), - MACs=self.Macs, - FLOPs=self.Flops) + self.__stat_ls.append( + self.detail_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Input=self.__iopt_repr(ipt), # type: ignore + Output=self.__iopt_repr(opt), # type: ignore + MACs=self.Macs, # type: ignore + FLOPs=self.Flops, # type: ignore + ) ) - def __activate_hook(self, module, input, output): - k = input[0].numel() + def __activate_hook(self, module: nn.Module, ipt: Tuple[Tensor], opt: Tensor) -> None: + k = ipt[0].numel() if isinstance(module, (nn.Sigmoid, nn.PReLU, nn.RReLU, nn.LeakyReLU)): - FLOPs = 4*k - MACs = 2*k + FLOPs = 4 * k + MACs = 2 * k elif isinstance(module, nn.Tanh): - FLOPs = 9*k - MACs = 5*k + FLOPs = 9 * k + MACs = 5 * k elif isinstance(module, (nn.ReLU, nn.ReLU6)): FLOPs = k MACs = k - else: # SiLU - FLOPs = 5*k - MACs = 3*k + else: # SiLU + FLOPs = 5 * k + MACs = 3 * k self.__Macs += MACs self.__Flops += FLOPs @@ -451,30 +488,32 @@ def __activate_hook(self, module, input, output): self.Macs.mark_access() self.Flops.mark_access() else: - self.__stat_ls.append(self.detail_val_container( - Operation_Id=self._opnode.node_id, - Operation_Name=self._opnode.name, - Operation_Type=self._opnode.type, - Input=self.__iopt_repr(input), - Output=self.__iopt_repr(output), - MACs=self.Macs, - FLOPs=self.Flops) 
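
> Reviewer note: the hooks in this hunk all derive FLOPs from a per-output-element MAC count. For a convolution, each of the `m` output elements costs `n = c_in * prod(kernel_size)` multiplications and `n - 1` additions, plus one addition when a bias is present. A torch-free sketch of that arithmetic with hypothetical layer sizes (`conv_counts` is an illustrative name, not torchmeter API):

```python
from functools import reduce
from operator import mul

def conv_counts(c_in: int, kernel_size: tuple, out_numel: int, has_bias: bool):
    """Mirror of the __conv_hook formulas: MACs = m*n, FLOPs = m*(2n-1+bias)."""
    n = c_in * reduce(mul, kernel_size)      # multiply-accumulates per output element
    m = out_numel
    flops = m * (2 * n - 1 + int(has_bias))  # n muls + (n-1) adds (+1 add for bias)
    macs = m * n
    return flops, macs

# hypothetical Conv2d: 16 input channels, 3x3 kernel, 32x30x30 output, bias on
flops, macs = conv_counts(16, (3, 3), 32 * 30 * 30, True)
```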
+ self.__stat_ls.append( + self.detail_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Input=self.__iopt_repr(ipt), # type: ignore + Output=self.__iopt_repr(opt), # type: ignore + MACs=self.Macs, # type: ignore + FLOPs=self.Flops, # type: ignore + ) ) - def __pool_hook(self, module, input, output): + def __pool_hook(self, module: nn.Module, ipt: Tuple[Tensor], opt: Tensor) -> None: k = module.kernel_size if isinstance(k, int): - dimension = int(re.findall(r'\d+', module.__class__.__name__)[0]) - k = (k,)*dimension + dimension = int(re.findall(r"\d+", module.__class__.__name__)[0]) + k = (k,) * dimension - n = reduce(mul, k)-1 - m = output.numel() + n = reduce(mul, k) - 1 + m = opt.numel() - if isinstance(module, (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)): - FLOPs = n*m - else: # avgpool - FLOPs = (2*n+1)*m - MACs = n*m + if isinstance(module, (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)): # noqa: SIM108 + FLOPs = n * m + else: # avgpool + FLOPs = (2 * n + 1) * m + MACs = n * m self.__Macs += MACs self.__Flops += FLOPs @@ -483,86 +522,107 @@ def __pool_hook(self, module, input, output): self.Macs.mark_access() self.Flops.mark_access() else: - self.__stat_ls.append(self.detail_val_container( - Operation_Id=self._opnode.node_id, - Operation_Name=self._opnode.name, - Operation_Type=self._opnode.type, - Kernel_Size=list(k), - Input=self.__iopt_repr(input), - Output=self.__iopt_repr(output), - MACs=self.Macs, - FLOPs=self.Flops) + self.__stat_ls.append( + self.detail_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Kernel_Size=list(k), # type: ignore + Input=self.__iopt_repr(ipt), # type: ignore + Output=self.__iopt_repr(opt), # type: ignore + MACs=self.Macs, # type: ignore + FLOPs=self.Flops, # type: 
ignore + ) ) - def __container_hook(self, module, input, output): + def __container_hook(self, module: nn.Module, ipt: Any, opt: Any) -> None: # noqa: ARG002 if len(self.__stat_ls): self.Macs.mark_access() self.Flops.mark_access() else: - self.__stat_ls.append(self.detail_val_container( - Operation_Id=self._opnode.node_id, - Operation_Name=self._opnode.name, - Operation_Type=self._opnode.type, - Input=self.__iopt_repr(input), - Output=self.__iopt_repr(output), - MACs=self.Macs, - FLOPs=self.Flops) + self.__stat_ls.append( + self.detail_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Input=self.__iopt_repr(ipt), # type: ignore + Output=self.__iopt_repr(opt), # type: ignore + MACs=self.Macs, # type: ignore + FLOPs=self.Flops, # type: ignore + ) ) - def __not_support_hook(self, module, input, output): + def __not_support_hook(self, module: nn.Module, ipt: Any, opt: Any) -> None: # noqa: ARG002 self.__is_not_supported = True - + if not len(self.__stat_ls): - self.__stat_ls.append(self.detail_val_container( - Operation_Id=self._opnode.node_id, - Operation_Name=self._opnode.name, - Operation_Type=self._opnode.type, - Input=self.__iopt_repr(input), - Output=self.__iopt_repr(output)) + self.__stat_ls.append( + self.detail_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Input=self.__iopt_repr(ipt), # type: ignore + Output=self.__iopt_repr(opt), # type: ignore + ) ) -class MemMeter(Statistics): - detail_val_container:NamedTuple = namedtuple( # type: ignore - typename='Memory_INFO', - field_names=['Operation_Id', 'Operation_Name', 'Operation_Type', - 'Param_Cost', 'Buffer_Cost', 'Output_Cost', - 'Total'], - defaults=(None,)*7) # type: ignore - - overview_val_container:NamedTuple = namedtuple( # type: 
ignore - typename='Memory_INFO', - field_names=['Operation_Id', 'Operation_Type', 'Operation_Name', - 'Param_Cost', 'Buffer_Cost', 'Output_Cost', - 'Total'], - defaults=(None,)*7) # type: ignore +class MemMeter(Statistics): + detail_val_container: NamedTuple = namedtuple( # type: ignore + typename="Memory_INFO", + field_names=[ + "Operation_Id", "Operation_Name", "Operation_Type", + "Param_Cost", "Buffer_Cost", "Output_Cost", + "Total", + ], + defaults=(None,) * 7, # type: ignore + ) # fmt: skip + + overview_val_container: NamedTuple = namedtuple( # type: ignore + typename="Memory_INFO", + field_names=[ + "Operation_Id", "Operation_Type", "Operation_Name", + "Param_Cost", "Buffer_Cost", "Output_Cost", + "Total", + ], + defaults=(None,) * 7, # type: ignore + ) # fmt: skip def __init__(self, opnode: OperationNode) -> None: - if opnode.__class__.__name__ != 'OperationNode': - raise TypeError("Expected `opnode` to be an instance of `OperationNode`, " + \ - f"but got `{type(opnode).__name__}`.") - + if opnode.__class__.__name__ != "OperationNode": + raise TypeError( + f"Expected `opnode` to be an instance of `OperationNode`, but got `{type(opnode).__name__}`." 
+ ) + self._opnode = opnode - self._model:nn.Module = opnode.operation - self.is_inplace:bool = getattr(self._model, 'inplace', False) - - self.__stat_ls:List[NamedTuple] = [] # record the flops and macs information of each operation - self.is_measured = False # used for cache + self._model: nn.Module = opnode.operation + self.is_inplace: bool = getattr(self._model, "inplace", False) - _opparent:Optional[OperationNode] = opnode.parent - self.__ParamCost = self.init_linkdata(attr_name='ParamCost', init_val=0, opparent=_opparent, unit_sys=BinaryUnit) - self.__BufferCost = self.init_linkdata(attr_name='BufferCost', init_val=0, opparent=_opparent, unit_sys=BinaryUnit) - self.__OutputCost = self.init_linkdata(attr_name='OutputCost', init_val=0, opparent=_opparent, unit_sys=BinaryUnit) - self.__TotalCost = self.init_linkdata(attr_name='TotalCost', init_val=0, opparent=_opparent, unit_sys=BinaryUnit) + self.__stat_ls: List[NamedTuple] = [] # record the flops and macs information of each operation + self.is_measured = False # used for cache + + _opparent: Optional[OperationNode] = opnode.parent + self.__ParamCost = self.init_linkdata( + attr_name="ParamCost", init_val=0, opparent=_opparent, unit_sys=BinaryUnit + ) + self.__BufferCost = self.init_linkdata( + attr_name="BufferCost", init_val=0, opparent=_opparent, unit_sys=BinaryUnit + ) + self.__OutputCost = self.init_linkdata( + attr_name="OutputCost", init_val=0, opparent=_opparent, unit_sys=BinaryUnit + ) + self.__TotalCost = self.init_linkdata( + attr_name="TotalCost", init_val=0, opparent=_opparent, unit_sys=BinaryUnit + ) @property def name(self) -> str: - return 'mem' + return "mem" @property - def ParamCost(self) -> UpperLinkData : + def ParamCost(self) -> UpperLinkData: return self.__ParamCost - + @property def BufferCost(self) -> UpperLinkData: return self.__BufferCost @@ -570,95 +630,103 @@ def BufferCost(self) -> UpperLinkData: @property def OutputCost(self) -> UpperLinkData: return self.__OutputCost - + 
@property def TotalCost(self) -> UpperLinkData: return self.__TotalCost - + @property def detail_val(self) -> List[NamedTuple]: self.__is_valid_access() return self.__stat_ls - + @property def val(self) -> NamedTuple: self.__is_valid_access() - return self.overview_val_container( # type: ignore - Operation_Id=self._opnode.node_id, # type: ignore - Operation_Type=self._opnode.type, # type: ignore - Operation_Name=self._opnode.name, # type: ignore - Param_Cost=self.ParamCost, # type: ignore - Buffer_Cost=self.BufferCost, # type: ignore - Output_Cost=self.OutputCost, # type: ignore - Total=self.TotalCost # type: ignore + return self.overview_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Param_Cost=self.ParamCost, # type: ignore + Buffer_Cost=self.BufferCost, # type: ignore + Output_Cost=self.OutputCost, # type: ignore + Total=self.TotalCost, # type: ignore ) @property def crucial_data(self) -> Dict[str, str]: self.__is_valid_access() - res_dict = {'[b]Parameters[/] Memory Cost': f"{self.ParamCost}, {self.ParamCost.val*100/self.TotalCost.val:.2f} %", - '[b]Buffers[/] Memory Cost': f"{self.BufferCost}, {self.BufferCost.val*100/self.TotalCost.val:.2f} %", - '[b]FeatureMap[/] Memory Cost': f"{self.OutputCost}, {self.OutputCost.val*100/self.TotalCost.val:.2f} %", - '[b]Total Memory Cost[/]': str(self.TotalCost)} - max_keylen = max([len(key) for key in res_dict.keys()]) + + total_cost = self.TotalCost.val + res_dict = { + "[b]Parameters[/] Memory Cost": f"{self.ParamCost}, {self.ParamCost.val * 100 / total_cost:.2f} %", + "[b]Buffers[/] Memory Cost": f"{self.BufferCost}, {self.BufferCost.val * 100 / total_cost:.2f} %", + "[b]FeatureMap[/] Memory Cost": f"{self.OutputCost}, {self.OutputCost.val * 100 / total_cost:.2f} %", + "[b]Total Memory Cost[/]": str(self.TotalCost), + } + + max_keylen = max([len(key) for key in res_dict]) res_dict = 
{key.ljust(max_keylen): value for key, value in res_dict.items()} return res_dict def measure(self) -> Optional[RemovableHandle]: if self.is_measured: return None - + hook = self._model.register_forward_hook(self.__hook_func) - + self.is_measured = True return hook - def __hook_func(self, module, input, output): + def __hook_func(self, module: nn.Module, ipt: Any, opt: Any) -> None: # noqa: ARG002, C901 opt_cost = 0 if self._opnode.is_leaf and not self.is_inplace: - output = output if isinstance(output, tuple) else (output,) - for opt in output: + outs = opt if isinstance(opt, tuple) else (opt,) + for opt in outs: if isinstance(opt, Tensor): - opt_cost += opt.numel() * opt.element_size() # byte + opt_cost += opt.numel() * opt.element_size() # byte elif isinstance(opt, np.ndarray): opt_cost += opt.nbytes elif isinstance(opt, str): - # Note: string storage is optimized after python 3.12, so the value changes with the version, + # Note: string storage is optimized after python 3.12, so the value changes with the version, # but it is not significantly changed, which does not affect the macro measurement results. 
opt_cost += opt.__sizeof__() else: opt_cost += asizeof(opt) self.__OutputCost += opt_cost - + if len(self.__stat_ls): # duplicated access self.OutputCost.mark_access() total_cost = opt_cost else: - param_cost = 0 # byte + param_cost = 0 # byte for param in module._parameters.values(): if param is None: continue param_cost += param.numel() * param.element_size() - self.__ParamCost += param_cost - - buffer_cost = 0 # byte + self.__ParamCost += param_cost + + buffer_cost = 0 # byte for buffer in module._buffers.values(): - buffer_cost += buffer.numel() * buffer.element_size() + if buffer is not None: + buffer_cost += buffer.numel() * buffer.element_size() self.__BufferCost += buffer_cost total_cost = param_cost + buffer_cost + opt_cost - self.__stat_ls.append(self.detail_val_container( - Operation_Id=self._opnode.node_id, - Operation_Name=self._opnode.name, - Operation_Type=self._opnode.type + ('(inplace)' if self.is_inplace else ''), - Param_Cost=None if self._opnode.is_leaf and not param_cost else self.ParamCost, - Buffer_Cost=None if self._opnode.is_leaf and not buffer_cost else self.BufferCost, - Output_Cost=None if self._opnode.is_leaf and not opt_cost else self.OutputCost, - Total=None if self._opnode.is_leaf and not total_cost else self.TotalCost) + self.__stat_ls.append( + self.detail_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Operation_Type=self._opnode.type + ("(inplace)" if self.is_inplace else ""), # type: ignore + Param_Cost=None if self._opnode.is_leaf and not param_cost else self.ParamCost, # type: ignore + Buffer_Cost=None if self._opnode.is_leaf and not buffer_cost else self.BufferCost, # type: ignore + Output_Cost=None if self._opnode.is_leaf and not opt_cost else self.OutputCost, # type: ignore + Total=None if self._opnode.is_leaf and not total_cost else self.TotalCost, # type: ignore + ) ) - + self.__TotalCost += total_cost def __is_valid_access(self) -> 
bool: @@ -666,33 +734,41 @@ def __is_valid_access(self) -> bool: if not self.__stat_ls and not isinstance(self._model, (nn.ModuleDict, nn.ModuleList)): raise RuntimeError("This module might be defined but not explicitly called, so no data is collected.") else: - raise AttributeError("You should never access this property on your own " + \ - "before accessing `Meter(your_model).mem`.") + raise AttributeError( + "You should never access this property on your own before accessing `Meter(your_model).mem`." + ) return True + class IttpMeter(Statistics): + detail_val_container: NamedTuple = namedtuple( # type: ignore + typename="InferTime_Throughput_INFO", + field_names=[ + "Operation_Id", "Operation_Name", "Operation_Type", + "Infer_Time", "Throughput" + ], + defaults=(None,) * 5, # type: ignore + ) # fmt: skip + + overview_val_container: NamedTuple = namedtuple( # type: ignore + typename="InferTime_Throughput_INFO", + field_names=[ + "Operation_Id", "Operation_Name", "Operation_Type", + "Infer_Time", "Throughput" + ], + defaults=(None,) * 5, # type: ignore + ) # fmt: skip - detail_val_container:NamedTuple = namedtuple( # type: ignore - typename='InferTime_Throughput_INFO', - field_names=['Operation_Id', 'Operation_Name', 'Operation_Type', - 'Infer_Time', 'Throughput'], - defaults=(None,)*5) # type: ignore - - overview_val_container:NamedTuple = namedtuple( # type: ignore - typename='InferTime_Throughput_INFO', - field_names=['Operation_Id', 'Operation_Name','Operation_Type', - 'Infer_Time', 'Throughput'], - defaults=(None,)*5,) # type: ignore - def __init__(self, opnode: OperationNode) -> None: - if opnode.__class__.__name__ != 'OperationNode': - raise TypeError("Expected `opnode` to be an instance of `OperationNode`, " + \ - f"but got `{type(opnode).__name__}`.") - + if opnode.__class__.__name__ != "OperationNode": + raise TypeError( + f"Expected `opnode` to be an instance of `OperationNode`, but got `{type(opnode).__name__}`." 
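
> Reviewer note: `MemMeter.__hook_func` above sizes each tensor as `numel * element_size` bytes and accumulates parameters, buffers, and outputs separately. A standalone sketch of that accounting with hypothetical shapes, assuming 4 bytes per element (float32); `tensor_bytes` is an illustrative helper, not torchmeter API:

```python
from functools import reduce
from operator import mul

def tensor_bytes(shape: tuple, element_size: int = 4) -> int:
    """Bytes for one tensor: numel * element_size, as in MemMeter.__hook_func."""
    return reduce(mul, shape, 1) * element_size

# hypothetical layer: a 64x3x7x7 float32 weight plus a 64-element float32 bias
param_cost = tensor_bytes((64, 3, 7, 7)) + tensor_bytes((64,))
```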
+ ) + self._opnode = opnode - self._model:nn.Module = opnode.operation - - self.__stat_ls:List[NamedTuple] = [] # record the inference time and throughput of each operation + self._model: nn.Module = opnode.operation + + self.__stat_ls: List[NamedTuple] = [] # record the inference time and throughput of each operation self.is_measured = False self.__InferTime = MetricsData(reduce_func=np.median, unit_sys=TimeUnit) @@ -700,12 +776,12 @@ def __init__(self, opnode: OperationNode) -> None: @property def name(self) -> str: - return 'ittp' + return "ittp" @property - def InferTime(self) -> MetricsData : + def InferTime(self) -> MetricsData: return self.__InferTime - + @property def Throughput(self) -> MetricsData: return self.__Throughput @@ -714,86 +790,99 @@ def Throughput(self) -> MetricsData: def detail_val(self) -> List[NamedTuple]: self.__is_valid_access() return self.__stat_ls - + @property def val(self) -> NamedTuple: self.__is_valid_access() - return self.overview_val_container( # type: ignore - Operation_Id=self._opnode.node_id, # type: ignore - Operation_Type=self._opnode.type, # type: ignore - Operation_Name=self._opnode.name, # type: ignore - Infer_Time=self.__InferTime, # type: ignore - Throughput=self.__Throughput # type: ignore + return self.overview_val_container( # type: ignore + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Infer_Time=self.__InferTime, # type: ignore + Throughput=self.__Throughput, # type: ignore ) @property def crucial_data(self) -> Dict[str, str]: self.__is_valid_access() - res_dict = {'Inference Elapse': str(self.InferTime), - 'Throughput': str(self.Throughput)} - max_keylen = max([len(key) for key in res_dict.keys()]) + res_dict = { + "Inference Elapse": str(self.InferTime), + "Throughput": str(self.Throughput), + } + max_keylen = max([len(key) for key in res_dict]) res_dict = {key.ljust(max_keylen): value for key, value in 
res_dict.items()} return res_dict - def measure(self, device:tc_device, repeat:int=50, global_process:Optional[tqdm]=None) -> RemovableHandle: - + def measure(self, device: tc_device, repeat: int = 50, global_process: Optional[tqdm] = None) -> RemovableHandle: self._model.to(device, non_blocking=True) hook = self._model.register_forward_hook( - partial(self.__hook_func, - device=device, - repeat=repeat, - global_process=global_process) + partial( + self.__hook_func, + device=device, + repeat=repeat, + global_process=global_process, + ) ) - + self.is_measured = True - + return hook - def __hook_func(self, module, input, output, device:tc_device, repeat:int=50, global_process:Optional[tqdm]=None): + def __hook_func( + self, + module: nn.Module, + ipt: Any, + opt: Any, # noqa: ARG002 + device: tc_device, + repeat: int = 50, + global_process: Optional[tqdm] = None, + ) -> None: self.__InferTime.clear() self.__Throughput.clear() self.__stat_ls.clear() - + module._forward_hooks.clear() module.eval() - if device.type == 'cpu': + if device.type == "cpu": cpu_timer = perf_counter - elif device.type == 'cuda': - start_event:Event = cuda_event(enable_timing=True) - end_event:Event = cuda_event(enable_timing=True) + elif device.type == "cuda": + start_event: Event = cuda_event(enable_timing=True) + end_event: Event = cuda_event(enable_timing=True) gpu_start_timer = start_event.record gpu_end_timer = end_event.record - cuda_sync() # WAIT FOR GPU SYNC - + cuda_sync() # WAIT FOR GPU SYNC + with no_grad(): for _ in range(repeat): - start_time = cpu_timer() if device.type == 'cpu' else gpu_start_timer() + start_time = cpu_timer() if device.type == "cpu" else gpu_start_timer() - module(*input) + module(*ipt) - end_time = cpu_timer() if device.type == 'cpu' else gpu_end_timer() + end_time = cpu_timer() if device.type == "cpu" else gpu_end_timer() - if device.type == 'cpu': - it = end_time-start_time + if device.type == "cpu": + it = end_time - start_time else: cuda_sync() # WAIT FOR 
GPU SYNC - it = start_event.elapsed_time(end_event)*1e-3 # ms -> s # type: ignore + it = start_event.elapsed_time(end_event) * 1e-3 # ms -> s # type: ignore - tp = input[0].shape[0]/it # TODO: batch infer - self.__InferTime.append(it) + tp = ipt[0].shape[0] / it # TODO: batch infer + self.__InferTime.append(it) self.__Throughput.append(tp) if global_process is not None: global_process.update(1) self.__stat_ls.append( - self.detail_val_container(Operation_Id=self._opnode.node_id, # type: ignore - Operation_Name=self._opnode.name, # type: ignore - Operation_Type=self._opnode.type, # type: ignore - Infer_Time=self.InferTime, # type: ignore - Throughput=self.Throughput) # type: ignore + self.detail_val_container( + Operation_Id=self._opnode.node_id, # type: ignore + Operation_Name=self._opnode.name, # type: ignore + Operation_Type=self._opnode.type, # type: ignore + Infer_Time=self.InferTime, # type: ignore + Throughput=self.Throughput, # type: ignore + ) ) def __is_valid_access(self) -> bool: @@ -801,6 +890,7 @@ def __is_valid_access(self) -> bool: if not self.__stat_ls and not isinstance(self._model, (nn.ModuleDict, nn.ModuleList)): raise RuntimeError("This module might be defined but not explicitly called, so no data is collected.") else: - raise AttributeError("You should never access this property on your own " + \ - "before accessing `Meter(your_model).ittp`.") + raise AttributeError( + "You should never access this property on your own before accessing `Meter(your_model).ittp`." 
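
> Reviewer note: on the CPU path, `IttpMeter.__hook_func` above times `repeat` forward passes with `perf_counter` and reduces them with a median, which is robust to scheduler noise. A minimal sketch of that loop for an arbitrary callable, with stdlib `statistics.median` standing in for `np.median` (`time_callable` is an illustrative name, not torchmeter API):

```python
from statistics import median
from time import perf_counter

def time_callable(fn, repeat: int = 50) -> float:
    """Median wall-clock seconds over `repeat` runs,
    mirroring the CPU branch of IttpMeter.__hook_func."""
    elapsed = []
    for _ in range(repeat):
        start = perf_counter()
        fn()
        elapsed.append(perf_counter() - start)
    return median(elapsed)

# hypothetical workload
t = time_callable(lambda: sum(range(10_000)), repeat=10)
```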
+ ) return True diff --git a/torchmeter/unit.py b/torchmeter/unit.py index c34893d..fefd37f 100644 --- a/torchmeter/unit.py +++ b/torchmeter/unit.py @@ -1,17 +1,24 @@ from __future__ import annotations -from typing import TYPE_CHECKING from enum import Enum, IntFlag, unique +from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Union + from typing import Type, Union import numpy as np FLOAT = Union[float, np.float_] -__all__ = ["CountUnit", "BinaryUnit", "TimeUnit", "SpeedUnit", - "auto_unit"] + UNITS = Union[ + Type["CountUnit"], + Type["BinaryUnit"], + Type["TimeUnit"], + Type["SpeedUnit"], + ] + +__all__ = ["CountUnit", "BinaryUnit", "TimeUnit", "SpeedUnit", "auto_unit"] + @unique class CountUnit(Enum): @@ -20,6 +27,7 @@ class CountUnit(Enum): M = 1e6 K = 1e3 + @unique class BinaryUnit(IntFlag): TiB = 2**40 @@ -28,15 +36,17 @@ class BinaryUnit(IntFlag): KiB = 2**10 B = 2**0 + @unique class TimeUnit(Enum): - h = 60**2 + h = 60**2 min = 60**1 - s = 60**0 + s = 60**0 ms = 1e-3 us = 1e-6 ns = 1e-9 + @unique class SpeedUnit(Enum): TIPS = 1e12 @@ -45,8 +55,11 @@ class SpeedUnit(Enum): KIPS = 1e3 IPS = 1e0 -def auto_unit(val:Union[int, FLOAT], unit_system=CountUnit) -> str: - for unit in list(unit_system): + +def auto_unit(val: Union[int, FLOAT], unit_system: UNITS = CountUnit) -> str: + unit: Enum + + for unit in list(unit_system): # type: ignore if val >= unit.value: if val % unit.value: return f"{val / unit.value:.2f} {unit.name}" diff --git a/torchmeter/utils.py b/torchmeter/utils.py index cf51468..f209049 100644 --- a/torchmeter/utils.py +++ b/torchmeter/utils.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING import os from time import perf_counter +from typing import TYPE_CHECKING from inspect import signature from functools import partial @@ -10,105 +10,114 @@ from rich.status import Status if TYPE_CHECKING: - from typing import Any, List, Tuple - from typing import Union, Optional, Callable, Iterable + 
 from types import TracebackType
+from typing import Any, List, Type, Tuple, Union, Callable, Iterable, Optional
 
 from polars import PolarsDataType
 
 __all__ = ["dfs_task", "data_repr", "Timer"]
 
 
-def resolve_savepath(origin_path:str,
-                     target_ext:str,
-                     default_filename:str='Data') -> Tuple[str, str]:
+
+def resolve_savepath(origin_path: str, target_ext: str, default_filename: str = "Data") -> Tuple[str, str]:
     origin_path = os.path.abspath(origin_path)
-    dir, file = os.path.split(origin_path)
-
+    directory, file = os.path.split(origin_path)
+
     # origin_path is a file path
-    if '.' in file:
-        os.makedirs(dir, exist_ok=True)
-        save_dir = dir
-        save_file = os.path.join(dir, os.path.splitext(file)[0]+f".{target_ext}")
-
+    if "." in file:
+        os.makedirs(directory, exist_ok=True)
+        save_dir = directory
+        save_file = os.path.join(directory, os.path.splitext(file)[0] + f".{target_ext}")
+
     # origin_path is a dir path
-    else:
+    else:
         os.makedirs(origin_path, exist_ok=True)
         save_dir = origin_path
         save_file = os.path.join(origin_path, f"{default_filename}.{target_ext}")
-
+
     return save_dir, save_file
 
 
-def hasargs(func:Callable, *required_args:str) -> None:
+
+def hasargs(func: Callable, *required_args: str) -> None:
     if not required_args:
-        return
-
+        return
+
     missing_args = [arg for arg in required_args
-                    if arg not in signature(func).parameters]
-
+                    if arg not in signature(func).parameters]  # fmt: skip
+
     if missing_args:
         raise RuntimeError(f"Function `{func.__name__}()` is missing following required args: {missing_args}.")
 
 
-def dfs_task(dfs_subject:Any,
-             adj_func:Callable[[Any], Iterable],
-             task_func:Callable[[Any, Any], Any],
-             visited_signal_func:Callable[[Any], Any]=lambda x:id(x),
-             *,
-             visited:Optional[List]=None) -> Any:
-    hasargs(task_func, 'subject', 'pre_res')
+
+def dfs_task(
+    dfs_subject: Any,
+    adj_func: Callable[[Any], Iterable],
+    task_func: Callable[[Any, Any], Any],
+    visited_signal_func: Callable[[Any], Any] = lambda x: id(x),
+    *,
+    visited: Optional[List] = None,
+) -> Any:
+    hasargs(task_func, "subject", "pre_res")
 
     visited_signal = visited_signal_func(dfs_subject)
-
+
     visited = visited or []
-
+
     if visited_signal not in visited:
         visited.append(visited_signal)
         try:
-            task_res = task_func(subject=dfs_subject) # type: ignore
+            task_res = task_func(subject=dfs_subject)  # type: ignore
         except TypeError:  # use empty list when no default value for `pre_res`
-            task_res = task_func(subject=dfs_subject, pre_res=[]) # type: ignore
-
-        for adj in adj_func(dfs_subject):
-            dfs_task(dfs_subject=adj,
-                     adj_func=adj_func,
-                     task_func=partial(task_func, pre_res=task_res), # type: ignore
-                     visited_signal_func=visited_signal_func,
-                     visited=visited)
-
+            task_res = task_func(subject=dfs_subject, pre_res=[])  # type: ignore
+
+        for adj in adj_func(dfs_subject):
+            dfs_task(
+                dfs_subject=adj,
+                adj_func=adj_func,
+                task_func=partial(task_func, pre_res=task_res),  # type: ignore
+                visited_signal_func=visited_signal_func,
+                visited=visited,
+            )
+
     try:
         return task_res
-    except UnboundLocalError: # revisit visited node
+    except UnboundLocalError:  # revisit visited node
         return None
 
 
-def indent_str(s:Union[str, Iterable[str]],
-               indent:int=4,
-               guideline:bool=True,
-               process_first:bool=True) -> str:
+
+def indent_str(
+    s: Union[str, Iterable[str]],
+    indent: int = 4,
+    guideline: bool = True,
+    process_first: bool = True,
+) -> str:
     if isinstance(s, str):
-        split_lines:List[str] = s.split("\n")
-
-    elif hasattr(s, '__iter__'):
+        split_lines: List[str] = s.split("\n")
+
+    elif hasattr(s, "__iter__"):
         split_lines = []
         for i in s:
-            if not isinstance(i,str):
-                raise TypeError("The input should be a string or an iterable object of strings, " + \
-                                f"but got `{type(i).__name__}` when travering input.")
+            if not isinstance(i, str):
+                raise TypeError(
+                    "The input should be a string or an iterable object of strings, "
+                    + f"but got `{type(i).__name__}` when traversing input."
+                )
             split_lines.extend(i.split("\n"))
-
+
     else:
-        raise TypeError("The input should be a string or a sequence of strings, " + \
-                        f"but got `{type(s).__name__}`.")
-
+        raise TypeError(f"The input should be a string or a sequence of strings, but got `{type(s).__name__}`.")
+
     if not isinstance(indent, int):
         raise TypeError(f"The indent should be an integer, but got `{type(indent).__name__}`")
     indent = max(indent, 0)
-
+
     res = []
-    guideline = not len(split_lines) == 1 and guideline
-
+    guideline = len(split_lines) != 1 and guideline
+
     if indent:
         for line in split_lines:
-            indent_line = "β”‚" if guideline else " "
-            indent_line += " "*(indent-1) + line
+            indent_line = "β”‚" if guideline else " "
+            indent_line += " " * (indent - 1) + line
             res.append(indent_line)
 
         if not process_first:
@@ -118,53 +127,67 @@ def indent_str(s:Union[str, Iterable[str]],
             res[-1] = "└─"[:indent].ljust(indent) + res[-1][indent:]
     else:
         res = split_lines
-
-    return '\n'.join(res)
-def data_repr(val:Any) -> str:
+    return "\n".join(res)
+
+
+def data_repr(val: Any) -> str:
     get_type = lambda val: type(val).__name__
-    item_repr = lambda val_type, val: (f"[dim]Shape[/]([b green]{list(val.shape)}[/])" if hasattr(val, 'shape') else f"[b green]{val}[/]") + f" [dim]<{val_type}>[/]"
+    item_repr = (
+        lambda val_type, val: (
+            f"[dim]Shape[/]([b green]{list(val.shape)}[/])" if hasattr(val, "shape") else f"[b green]{val}[/]"
+        )
+        + f" [dim]<{val_type}>[/]"
+    )
 
     val_type = get_type(val)
     if isinstance(val, (list, tuple, set, dict)) and len(val) > 0:
         if isinstance(val, dict):
-            inner_repr_parts = [(item_repr(get_type(k),k), data_repr(v)) for k, v in val.items()]
-            inner_repr:List[str] = [indent_str(f"{record[0]}: {record[1]}",
-                                               indent=2+Text.from_markup(record[0]).cell_len,
-                                               guideline=False, process_first=False)
-                                    for record in inner_repr_parts]
+            inner_repr_parts = [(item_repr(get_type(k), k), data_repr(v)) for k, v in val.items()]
+            inner_repr: List[str] = [
+                indent_str(
+                    f"{record[0]}: {record[1]}",
+                    indent=2 + Text.from_markup(record[0]).cell_len,
+                    guideline=False,
+                    process_first=False,
+                )
+                for record in inner_repr_parts
+            ]
         else:
             inner_repr = [data_repr(i) for i in val]
-
+
         res_repr = f"[dim]{val_type}[/]("
-        res_repr += ',\n'.join(inner_repr)
-        res_repr += ')'
+        res_repr += ",\n".join(inner_repr)
+        res_repr += ")"
         return indent_str(res_repr, indent=len(f"{val_type}("), process_first=False)
-
+
     elif "function" in val_type and callable(val):
         return f"[b green]{val.__name__}[/] [dim][/]"
-
-    elif hasattr(val, 'shape'):
+
+    elif hasattr(val, "shape"):
        if any(not isinstance(i, int) for i in list(val.shape)):
            return f"[b green]obj[/] [dim]<{val.__class__.__module__}.{val_type}>[/]"
        return item_repr(val_type, val)
-
+
     elif val.__class__.__module__ != "builtins":
         return f"[b green]obj[/] [dim]<{val.__class__.__module__}.{val_type}>[/]"
     else:
         return item_repr(val_type, val)
 
 
-def match_polars_type(ipt:Any, *,
-                      recheck:bool=False,
-                      pre_res:Optional[PolarsDataType]=None)-> PolarsDataType:
-
+
+def match_polars_type(
+    ipt: Any,
+    *,
+    recheck: bool = False,
+    pre_res: Optional[PolarsDataType] = None,
+) -> PolarsDataType:
     import numpy as np
     import polars as pl
-    from polars.datatypes._parse import parse_into_dtype
     from polars.series.series import _resolve_temporal_dtype
+    from polars.datatypes._parse import parse_into_dtype
 
     if not recheck and pre_res is not None:
         return pre_res
@@ -172,21 +195,21 @@ def match_polars_type(ipt:Any, *,
     try:
         pl_type = parse_into_dtype(type(ipt))
         if isinstance(ipt, (list, tuple)):
-            # TODO: inner type awareness (following type priority)
+            # TODO: inner type awareness (following type priority)
            inner_type = match_polars_type(ipt[0])
-            return pl_type(inner_type) # type: ignore
-
+            return pl_type(inner_type)  # type: ignore
+
         return pl_type
-
+
     except TypeError:
         if isinstance(ipt, dict):
-            fields = {k:match_polars_type(v) for k,v in ipt.items()}
+            fields = {k: match_polars_type(v) for k, v in ipt.items()}
             return pl.Struct(fields=fields)
-
+
         elif isinstance(ipt, (np.datetime64, np.timedelta64)):
-            pl_type = _resolve_temporal_dtype(None, np.dtype(ipt)) # type: ignore
+            pl_type = _resolve_temporal_dtype(None, np.dtype(ipt))  # type: ignore
             return pl_type or pl.Object
-
+
         elif isinstance(ipt, (np.integer, np.floating)):
             return {
                 np.int8: pl.Int8,
@@ -198,28 +221,33 @@ def match_polars_type(ipt:Any, *,
                 np.uint32: pl.UInt32,
                 np.uint64: pl.UInt64,
                 np.float32: pl.Float32,
-                np.float64: pl.Float64
+                np.float64: pl.Float64,
             }[type(ipt)]
-
+
         elif isinstance(ipt, np.ndarray):
             return pl.Series(ipt).dtype
-
+
         else:  # class instance
             return pl.Object
 
+
 class Timer(Status):
-    def __init__(self, task_desc:str,
-                 *args, **kwargs) -> None:
-        super(Timer, self).__init__(status=task_desc, *args, **kwargs) # type: ignore
+    def __init__(self, task_desc: str, *args, **kwargs) -> None:
+        super(Timer, self).__init__(status=task_desc, *args, **kwargs)  # type: ignore
        self.task_desc = task_desc
-
-    def __enter__(self):
+
+    def __enter__(self) -> "Timer":
         super().__enter__()
         self.__start_time = perf_counter()
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
        elapsed_time = perf_counter() - self.__start_time
        super().__exit__(exc_type, exc_val, exc_tb)
-        self.console.print(f"[b blue]Finish {self.task_desc} in [green]{elapsed_time:.4f}[/green] seconds[/]")
\ No newline at end of file
+        self.console.print(f"[b blue]Finish {self.task_desc} in [green]{elapsed_time:.4f}[/green] seconds[/]")
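Review note, outside the patch itself: the reshaped `dfs_task` signature is easiest to sanity-check with a toy traversal. The sketch below re-implements the same contract standalone rather than importing `torchmeter` (assumed unavailable here), and simplifies the revisit case to an explicit `return None` instead of the `UnboundLocalError` guard; the `tree` and `record` names are invented for illustration.

```python
from functools import partial
from typing import Any, Callable, Iterable, List, Optional


def dfs_task(
    dfs_subject: Any,
    adj_func: Callable[[Any], Iterable],
    task_func: Callable[..., Any],
    visited_signal_func: Callable[[Any], Any] = lambda x: id(x),
    *,
    visited: Optional[List] = None,
) -> Any:
    # Run task_func on every unvisited node, threading each parent's
    # result into its children as `pre_res` (mirrors the diffed logic).
    visited = visited or []
    visited_signal = visited_signal_func(dfs_subject)
    if visited_signal in visited:
        return None  # revisiting an already-handled node
    visited.append(visited_signal)
    try:
        task_res = task_func(subject=dfs_subject)
    except TypeError:  # task_func has no default for `pre_res`; seed it
        task_res = task_func(subject=dfs_subject, pre_res=[])
    for adj in adj_func(dfs_subject):
        dfs_task(
            dfs_subject=adj,
            adj_func=adj_func,
            task_func=partial(task_func, pre_res=task_res),
            visited_signal_func=visited_signal_func,
            visited=visited,
        )
    return task_res


# toy graph: collect the root-to-node path at every node
tree = {"a": ["b", "c"], "b": ["d"], "c": [], "d": []}
paths = []


def record(subject, pre_res):
    path = pre_res + [subject]
    paths.append(path)
    return path


root_res = dfs_task("a", adj_func=lambda n: tree[n], task_func=record,
                    visited_signal_func=lambda x: x)
# root_res is ["a"]; paths holds the path to every visited node
```

The first `task_func(subject=...)` call fails with `TypeError` because `record` has no default for `pre_res`, which is exactly the fallback path the patched code comments on.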
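Also worth verifying from the patch: the new `__enter__`/`__exit__` annotations on `Timer`. Below is a standalone sketch of the same context-manager shape, without the `rich` `Status` base class the real `Timer` inherits from (a plain `print` stands in for `self.console.print`; the `_start` and `elapsed` attribute names are invented here).

```python
from time import perf_counter, sleep
from types import TracebackType
from typing import Optional, Type


class Timer:
    def __init__(self, task_desc: str) -> None:
        self.task_desc = task_desc

    def __enter__(self) -> "Timer":
        # The return type is quoted: without `from __future__ import annotations`,
        # a bare `Timer` would be evaluated while the class is still being built
        # and raise NameError on Python 3.8, the patch's target version.
        self._start = perf_counter()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self.elapsed = perf_counter() - self._start
        print(f"Finish {self.task_desc} in {self.elapsed:.4f} seconds")


with Timer("napping") as t:
    sleep(0.01)
```

Returning `None` from `__exit__` (rather than `True`) means exceptions raised inside the `with` block still propagate, matching the patched signature.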