1515
1616import importlib
1717import logging
18+ from collections import OrderedDict
1819from enum import Enum
1920from typing import Any , List , Optional
2021
2526
2627__all__ = [
2728 "Framework" ,
29+ "detect_frameworks" ,
2830 "detect_framework" ,
2931 "execute_in_sparseml_framework" ,
3032 "get_version" ,
@@ -48,61 +50,109 @@ class Framework(Enum):
4850 tensorflow_v1 = "tensorflow_v1"
4951
5052
51- def detect_framework (item : Any ) -> Framework :
53+ def _execute_sparseml_package_function (
54+ framework : Framework , function_name : str , * args , ** kwargs
55+ ):
56+ try :
57+ module = importlib .import_module (f"sparseml.{ framework .value } " )
58+ function = getattr (module , function_name )
59+ except Exception as err :
60+ raise ValueError (
61+ f"unknown or unsupported framework { framework } , "
62+ f"cannot call function { function_name } : { err } "
63+ )
64+
65+ return function (* args , ** kwargs )
66+
67+
68+ def detect_frameworks (item : Any ) -> List [Framework ]:
5269 """
53- Detect the supported ML framework for a given item.
70+ Detects the supported ML frameworks for a given item.
5471 Supported input types are the following:
5572 - A Framework enum
5673 - A string of any case representing the name of the framework
5774 (deepsparse, onnx, keras, pytorch, tensorflow_v1)
5875 - A supported file type within the framework such as model files:
5976 (onnx, pth, h5, pb)
6077 - An object from a supported ML framework such as a model instance
61- If the framework cannot be determined, will return Framework.unknown
78+ If the framework cannot be determined, an empty list will be returned
79+
6280 :param item: The item to detect the ML framework for
6381 :type item: Any
64- :return: The detected framework from the given item
65- :rtype: Framework
82+ :return: The detected ML frameworks from the given item
83+ :rtype: List[ Framework]
6684 """
67- _LOGGER .debug ("detecting framework for %s" , item )
68- framework = Framework .unknown
85+ _LOGGER .debug ("detecting frameworks for %s" , item )
86+ frameworks = []
87+
88+ if isinstance (item , str ) and item .lower ().strip () in Framework .__members__ :
89+ _LOGGER .debug ("framework detected from Framework string instance" )
90+ item = Framework [item .lower ().strip ()]
6991
7092 if isinstance (item , Framework ):
7193 _LOGGER .debug ("framework detected from Framework instance" )
72- framework = item
73- elif isinstance (item , str ) and item .lower ().strip () in Framework .__members__ :
74- _LOGGER .debug ("framework detected from Framework string instance" )
75- framework = Framework [item .lower ().strip ()]
94+
95+ if item != Framework .unknown :
96+ frameworks .append (item )
7697 else :
77- _LOGGER .debug ("detecting framework by calling into supported frameworks" )
98+ _LOGGER .debug ("detecting frameworks by calling into supported frameworks" )
99+ frameworks = []
78100
79101 for test in Framework :
102+ if test == Framework .unknown :
103+ continue
104+
80105 try :
81- framework = execute_in_sparseml_framework (
106+ detected = _execute_sparseml_package_function (
82107 test , "detect_framework" , item
83108 )
109+ frameworks .append (detected )
84110 except Exception as err :
85111 # errors are expected if the framework is not installed, log as debug
86- _LOGGER .debug (f"error while calling detect_framework for { test } : { err } " )
112+ _LOGGER .debug (
113+ "error while calling detect_framework for %s: %s" , test , err
114+ )
115+
116+ _LOGGER .info ("detected frameworks of %s from %s" , frameworks , item )
117+
118+ return frameworks
87119
88- if framework != Framework .unknown :
89- break
90120
91- _LOGGER .info ("detected framework of %s from %s" , framework , item )
121+ def detect_framework (item : Any ) -> Framework :
122+ """
123+ Detect the supported ML framework for a given item.
124+ Supported input types are the following:
125+ - A Framework enum
126+ - A string of any case representing the name of the framework
127+ (deepsparse, onnx, keras, pytorch, tensorflow_v1)
128+ - A supported file type within the framework such as model files:
129+ (onnx, pth, h5, pb)
130+ - An object from a supported ML framework such as a model instance
131+ If the framework cannot be determined, will return Framework.unknown
132+
133+ :param item: The item to detect the ML framework for
134+ :type item: Any
135+ :return: The detected framework from the given item
136+ :rtype: Framework
137+ """
138+ _LOGGER .debug ("detecting framework for %s" , item )
139+ frameworks = detect_frameworks (item )
92140
93- return framework
141+ return frameworks [ 0 ] if len ( frameworks ) > 0 else Framework . unknown
94142
95143
96144def execute_in_sparseml_framework (
97- framework : Framework , function_name : str , * args , ** kwargs
145+ framework : Any , function_name : str , * args , ** kwargs
98146) -> Any :
99147 """
100148 Execute a general function that is callable from the root of the frameworks
101149 package under SparseML such as sparseml.pytorch.
102150 Useful for benchmarking, analyzing, etc.
103151 Will pass the args and kwargs to the callable function.
104- :param framework: The ML framework to run the function under in SparseML.
105- :type framework: Framework
152+
153+ :param framework: The item to detect the ML framework for to run the function under,
154+ see detect_frameworks for more details on acceptible inputs
155+ :type framework: Any
106156 :param function_name: The name of the function in SparseML that should be run
107157 with the given args and kwargs.
108158 :type function_name: str
@@ -119,25 +169,28 @@ def execute_in_sparseml_framework(
119169 kwargs ,
120170 )
121171
122- if not isinstance ( framework , Framework ):
123- framework = detect_framework (framework )
172+ framework_errs = OrderedDict ()
173+ test_frameworks = detect_frameworks (framework )
124174
125- if framework == Framework .unknown :
126- raise ValueError (
127- f"unknown or unsupported framework { framework } , "
128- f"cannot call function { function_name } "
129- )
175+ for test_framework in test_frameworks :
176+ try :
177+ module = importlib .import_module (f"sparseml.{ test_framework .value } " )
178+ function = getattr (module , function_name )
130179
131- try :
132- module = importlib .import_module (f"sparseml.{ framework .value } " )
133- function = getattr (module , function_name )
134- except Exception as err :
135- raise ValueError (
136- f"could not find function_name { function_name } in framework { framework } : "
137- f"{ err } "
138- )
180+ return function (* args , ** kwargs )
181+ except Exception as err :
182+ framework_errs [framework ] = err
139183
140- return function (* args , ** kwargs )
184+ if len (framework_errs ) == 1 :
185+ raise list (framework_errs .values ())[0 ]
186+
187+ if len (framework_errs ) > 1 :
188+ raise RuntimeError (str (framework_errs ))
189+
190+ raise ValueError (
191+ f"unknown or unsupported framework { framework } , "
192+ f"cannot call function { function_name } "
193+ )
141194
142195
143196def get_version (
0 commit comments