Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 6e9cc4f

Browse files
authored
Add framework fallback ability to execute in sparseml (#413) (#422)
* Add framework fallback ability to execute in sparseml * remove unused variable * decrease complexity of falling back on framework execution * quality and test fixes * update docs * fix tests
1 parent 3f165e6 commit 6e9cc4f

File tree

3 files changed

+94
-55
lines changed

3 files changed

+94
-55
lines changed

src/sparseml/base.py

Lines changed: 90 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import importlib
1717
import logging
18+
from collections import OrderedDict
1819
from enum import Enum
1920
from typing import Any, List, Optional
2021

@@ -25,6 +26,7 @@
2526

2627
__all__ = [
2728
"Framework",
29+
"detect_frameworks",
2830
"detect_framework",
2931
"execute_in_sparseml_framework",
3032
"get_version",
@@ -48,61 +50,109 @@ class Framework(Enum):
4850
tensorflow_v1 = "tensorflow_v1"
4951

5052

51-
def detect_framework(item: Any) -> Framework:
53+
def _execute_sparseml_package_function(
54+
framework: Framework, function_name: str, *args, **kwargs
55+
):
56+
try:
57+
module = importlib.import_module(f"sparseml.{framework.value}")
58+
function = getattr(module, function_name)
59+
except Exception as err:
60+
raise ValueError(
61+
f"unknown or unsupported framework {framework}, "
62+
f"cannot call function {function_name}: {err}"
63+
)
64+
65+
return function(*args, **kwargs)
66+
67+
68+
def detect_frameworks(item: Any) -> List[Framework]:
5269
"""
53-
Detect the supported ML framework for a given item.
70+
Detects the supported ML frameworks for a given item.
5471
Supported input types are the following:
5572
- A Framework enum
5673
- A string of any case representing the name of the framework
5774
(deepsparse, onnx, keras, pytorch, tensorflow_v1)
5875
- A supported file type within the framework such as model files:
5976
(onnx, pth, h5, pb)
6077
- An object from a supported ML framework such as a model instance
61-
If the framework cannot be determined, will return Framework.unknown
78+
If the framework cannot be determined, an empty list will be returned
79+
6280
:param item: The item to detect the ML framework for
6381
:type item: Any
64-
:return: The detected framework from the given item
65-
:rtype: Framework
82+
:return: The detected ML frameworks from the given item
83+
:rtype: List[Framework]
6684
"""
67-
_LOGGER.debug("detecting framework for %s", item)
68-
framework = Framework.unknown
85+
_LOGGER.debug("detecting frameworks for %s", item)
86+
frameworks = []
87+
88+
if isinstance(item, str) and item.lower().strip() in Framework.__members__:
89+
_LOGGER.debug("framework detected from Framework string instance")
90+
item = Framework[item.lower().strip()]
6991

7092
if isinstance(item, Framework):
7193
_LOGGER.debug("framework detected from Framework instance")
72-
framework = item
73-
elif isinstance(item, str) and item.lower().strip() in Framework.__members__:
74-
_LOGGER.debug("framework detected from Framework string instance")
75-
framework = Framework[item.lower().strip()]
94+
95+
if item != Framework.unknown:
96+
frameworks.append(item)
7697
else:
77-
_LOGGER.debug("detecting framework by calling into supported frameworks")
98+
_LOGGER.debug("detecting frameworks by calling into supported frameworks")
99+
frameworks = []
78100

79101
for test in Framework:
102+
if test == Framework.unknown:
103+
continue
104+
80105
try:
81-
framework = execute_in_sparseml_framework(
106+
detected = _execute_sparseml_package_function(
82107
test, "detect_framework", item
83108
)
109+
frameworks.append(detected)
84110
except Exception as err:
85111
# errors are expected if the framework is not installed, log as debug
86-
_LOGGER.debug(f"error while calling detect_framework for {test}: {err}")
112+
_LOGGER.debug(
113+
"error while calling detect_framework for %s: %s", test, err
114+
)
115+
116+
_LOGGER.info("detected frameworks of %s from %s", frameworks, item)
117+
118+
return frameworks
87119

88-
if framework != Framework.unknown:
89-
break
90120

91-
_LOGGER.info("detected framework of %s from %s", framework, item)
121+
def detect_framework(item: Any) -> Framework:
122+
"""
123+
Detect the supported ML framework for a given item.
124+
Supported input types are the following:
125+
- A Framework enum
126+
- A string of any case representing the name of the framework
127+
(deepsparse, onnx, keras, pytorch, tensorflow_v1)
128+
- A supported file type within the framework such as model files:
129+
(onnx, pth, h5, pb)
130+
- An object from a supported ML framework such as a model instance
131+
If the framework cannot be determined, will return Framework.unknown
132+
133+
:param item: The item to detect the ML framework for
134+
:type item: Any
135+
:return: The detected framework from the given item
136+
:rtype: Framework
137+
"""
138+
_LOGGER.debug("detecting framework for %s", item)
139+
frameworks = detect_frameworks(item)
92140

93-
return framework
141+
return frameworks[0] if len(frameworks) > 0 else Framework.unknown
94142

95143

96144
def execute_in_sparseml_framework(
97-
framework: Framework, function_name: str, *args, **kwargs
145+
framework: Any, function_name: str, *args, **kwargs
98146
) -> Any:
99147
"""
100148
Execute a general function that is callable from the root of the frameworks
101149
package under SparseML such as sparseml.pytorch.
102150
Useful for benchmarking, analyzing, etc.
103151
Will pass the args and kwargs to the callable function.
104-
:param framework: The ML framework to run the function under in SparseML.
105-
:type framework: Framework
152+
153+
:param framework: The item to detect the ML framework for to run the function under,
154+
see detect_frameworks for more details on acceptible inputs
155+
:type framework: Any
106156
:param function_name: The name of the function in SparseML that should be run
107157
with the given args and kwargs.
108158
:type function_name: str
@@ -119,25 +169,28 @@ def execute_in_sparseml_framework(
119169
kwargs,
120170
)
121171

122-
if not isinstance(framework, Framework):
123-
framework = detect_framework(framework)
172+
framework_errs = OrderedDict()
173+
test_frameworks = detect_frameworks(framework)
124174

125-
if framework == Framework.unknown:
126-
raise ValueError(
127-
f"unknown or unsupported framework {framework}, "
128-
f"cannot call function {function_name}"
129-
)
175+
for test_framework in test_frameworks:
176+
try:
177+
module = importlib.import_module(f"sparseml.{test_framework.value}")
178+
function = getattr(module, function_name)
130179

131-
try:
132-
module = importlib.import_module(f"sparseml.{framework.value}")
133-
function = getattr(module, function_name)
134-
except Exception as err:
135-
raise ValueError(
136-
f"could not find function_name {function_name} in framework {framework}: "
137-
f"{err}"
138-
)
180+
return function(*args, **kwargs)
181+
except Exception as err:
182+
framework_errs[framework] = err
139183

140-
return function(*args, **kwargs)
184+
if len(framework_errs) == 1:
185+
raise list(framework_errs.values())[0]
186+
187+
if len(framework_errs) > 1:
188+
raise RuntimeError(str(framework_errs))
189+
190+
raise ValueError(
191+
f"unknown or unsupported framework {framework}, "
192+
f"cannot call function {function_name}"
193+
)
141194

142195

143196
def get_version(

src/sparseml/benchmark/info.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696

9797
from tqdm import auto
9898

99-
from sparseml.base import Framework, detect_framework, execute_in_sparseml_framework
99+
from sparseml.base import Framework, execute_in_sparseml_framework
100100
from sparseml.benchmark.serialization import (
101101
BatchBenchmarkResult,
102102
BenchmarkConfig,
@@ -369,13 +369,8 @@ def save_benchmark_results(
369369
pass to the runner
370370
:param show_progress: True to show a tqdm bar when running, False otherwise
371371
"""
372-
if framework is None:
373-
framework = detect_framework(model)
374-
else:
375-
framework = detect_framework(framework)
376-
377372
results = execute_in_sparseml_framework(
378-
framework,
373+
framework if framework is not None else model,
379374
"run_benchmark",
380375
model,
381376
data,
@@ -442,18 +437,9 @@ def load_and_run_benchmark(
442437
:param save_path: path to save the new benchmark results
443438
"""
444439
_LOGGER.info(f"rerunning benchmark {load}")
445-
446440
info = load_benchmark_info(load)
447-
448-
framework = info.framework
449-
450-
if framework is None:
451-
framework = detect_framework(model)
452-
else:
453-
framework = detect_framework(framework)
454-
455441
save_benchmark_results(
456-
model,
442+
info.framework if info.framework is not None else model,
457443
data,
458444
batch_size=info.config.batch_size,
459445
iterations=info.config.iterations,

tests/sparseml/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_execute_in_sparseml_framework():
6262
with pytest.raises(ValueError):
6363
execute_in_sparseml_framework(Framework.unknown, "unknown")
6464

65-
with pytest.raises(ValueError):
65+
with pytest.raises(Exception):
6666
execute_in_sparseml_framework(Framework.onnx, "unknown")
6767

6868
# TODO: fill in with sample functions to execute in frameworks once available

0 commit comments

Comments
 (0)