From 3568e87aba57e12aeea094ee7ef16ec652f47647 Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Tue, 28 Jan 2025 00:24:51 -0500 Subject: [PATCH 01/24] empty commit From 00d8ac99cd17dff9e6bc7c34185776e49d89b732 Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Tue, 28 Jan 2025 01:05:36 -0500 Subject: [PATCH 02/24] fix fastmath for aamp._compute_diagonal --- stumpy/aamp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/aamp.py b/stumpy/aamp.py index 1e4879bcc..ff4a672b4 100644 --- a/stumpy/aamp.py +++ b/stumpy/aamp.py @@ -13,7 +13,7 @@ @njit( # "(f8[:], f8[:], i8, b1[:], b1[:], f8, i8[:], i8, i8, i8, f8[:, :, :]," # "f8[:, :], f8[:, :], i8[:, :, :], i8[:, :], i8[:, :], b1)", - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _compute_diagonal( T_A, From df9381b614fc475c7224d50095d73e473bb28b1c Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Tue, 28 Jan 2025 01:06:35 -0500 Subject: [PATCH 03/24] fix fastmath for aamp._aamp --- stumpy/aamp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/aamp.py b/stumpy/aamp.py index ff4a672b4..74236a7bf 100644 --- a/stumpy/aamp.py +++ b/stumpy/aamp.py @@ -186,7 +186,7 @@ def _compute_diagonal( @njit( # "(f8[:], f8[:], i8, b1[:], b1[:], i8[:], b1, i8)", parallel=True, - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _aamp( T_A, From 6abe8de27527a5cb8e2fee8db692410b2902b814 Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Tue, 28 Jan 2025 01:29:32 -0500 Subject: [PATCH 04/24] fix fastmath for core._calculate_squared_distance_profile --- stumpy/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/core.py b/stumpy/core.py index a7758c2fd..4bd546d18 100644 --- a/stumpy/core.py +++ b/stumpy/core.py @@ -1111,7 +1111,7 @@ def _calculate_squared_distance( @njit( # "f8[:](i8, f8[:], f8, f8, f8[:], f8[:])", - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _calculate_squared_distance_profile( m, QT, μ_Q, σ_Q, M_T, Σ_T, Q_subseq_isconstant, T_subseq_isconstant From a29e57601fa3331f8d8fb68e46e7bf071773cd35 Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Tue, 28 Jan 2025 01:37:27 -0500 Subject: [PATCH 05/24] fix fastmath for core.calculate_distance_profile --- stumpy/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/core.py b/stumpy/core.py index 4bd546d18..2787647cb 100644 --- a/stumpy/core.py +++ b/stumpy/core.py @@ -1177,7 +1177,7 @@ def _calculate_squared_distance_profile( @njit( # "f8[:](i8, f8[:], f8, f8, f8[:], f8[:])", - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def calculate_distance_profile( m, QT, μ_Q, σ_Q, M_T, Σ_T, Q_subseq_isconstant, T_subseq_isconstant From 2581e5bc6f4bddb43c6de17faa53145ca5a2957e Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Tue, 28 Jan 2025 01:47:54 -0500 Subject: [PATCH 06/24] fix fastmath for core._apply_exclusion_zone --- stumpy/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/core.py b/stumpy/core.py index 2787647cb..98c15d0c9 100644 --- a/stumpy/core.py +++ b/stumpy/core.py @@ -1979,7 +1979,7 @@ def _get_QT(start, T_A, T_B, m): @njit( # ["(f8[:], i8, i8)", "(f8[:, :], i8, i8)"], - fastmath=config.STUMPY_FASTMATH_TRUE + fastmath=config.STUMPY_FASTMATH_FLAGS ) def _apply_exclusion_zone(a, idx, excl_zone, val): """ From 2ab8b328f34c9de99586f101afc2276d0e868406 Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Thu, 30 Jan 2025 11:12:42 -0500 Subject: [PATCH 07/24] fix fastmath for mstump._compute_multi_D --- stumpy/mstump.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/mstump.py b/stumpy/mstump.py index c4b7ed2c9..35d58b130 100644 --- a/stumpy/mstump.py +++ b/stumpy/mstump.py @@ -811,7 +811,7 @@ def _get_multi_QT(start, T, m): # "(i8, i8, i8, f8[:, :], f8[:, :], i8, i8, f8[:, :], f8[:, :], f8[:, :]," # "f8[:, :], f8[:, :], f8[:, :], f8[:, :])", parallel=True, - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _compute_multi_D( d, From 7b4f69478b2bec8331ad36ea26a2a613fc0886e8 Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Thu, 30 Jan 2025 18:45:48 -0500 Subject: [PATCH 08/24] fix fastmath for scraamp._compute_PI --- stumpy/scraamp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/scraamp.py b/stumpy/scraamp.py index 56d83f6b6..e2475579a 100644 --- a/stumpy/scraamp.py +++ b/stumpy/scraamp.py @@ -83,7 +83,7 @@ def _preprocess_prescraamp(T_A, m, T_B=None, s=None): return (T_A, T_B, T_A_subseq_isfinite, T_B_subseq_isfinite, indices, s, excl_zone) -@njit(fastmath=config.STUMPY_FASTMATH_TRUE) +@njit(fastmath=config.STUMPY_FASTMATH_FLAGS) def _compute_PI( T_A, T_B, From 0d67a859bd32a0b07036bbacd19a918ef07cb559 Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Thu, 30 Jan 2025 20:03:52 -0500 Subject: [PATCH 09/24] fix fastmath for scraamp._prescraamp --- stumpy/scraamp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/scraamp.py b/stumpy/scraamp.py index e2475579a..682a83405 100644 --- a/stumpy/scraamp.py +++ b/stumpy/scraamp.py @@ -286,7 +286,7 @@ def _compute_PI( # "(f8[:], f8[:], i8, b1[:], b1[:], f8, i8, i8, f8[:], f8[:]," # "i8[:], optional(i8))", parallel=True, - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _prescraamp( T_A, From e09a766ba86a59a44169f59b130186045b2558eb Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Thu, 30 Jan 2025 20:11:36 -0500 Subject: [PATCH 10/24] fix fastmath for scrump._compute_PI --- stumpy/scrump.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/scrump.py b/stumpy/scrump.py index dd5617480..5cc12f3b1 100644 --- a/stumpy/scrump.py +++ b/stumpy/scrump.py @@ -133,7 +133,7 @@ def _preprocess_prescrump( ) -@njit(fastmath=config.STUMPY_FASTMATH_TRUE) +@njit(fastmath=config.STUMPY_FASTMATH_FLAGS) def _compute_PI( T_A, T_B, From c895270fe1fee90f1b60372f441cc2b71061233c Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Thu, 30 Jan 2025 20:13:22 -0500 Subject: [PATCH 11/24] fix fastmath for scrump._prescrump --- stumpy/scrump.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/scrump.py b/stumpy/scrump.py index 5cc12f3b1..b9894770d 100644 --- a/stumpy/scrump.py +++ b/stumpy/scrump.py @@ -384,7 +384,7 @@ def _compute_PI( # "(f8[:], f8[:], i8, f8[:], f8[:], f8[:], f8[:], f8[:], i8, i8, f8[:], f8[:]," # "i8[:], optional(i8))", parallel=True, - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _prescrump( T_A, From 2d3a6ac14a628ed87fe9db08c16ad9c2a07b3c4b Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Thu, 30 Jan 2025 20:16:10 -0500 Subject: [PATCH 12/24] fix fastmath for stump._compute_diagonal --- stumpy/stump.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/stump.py b/stumpy/stump.py index 18409c6e1..9d4ade6a3 100644 --- a/stumpy/stump.py +++ b/stumpy/stump.py @@ -15,7 +15,7 @@ # "(f8[:], f8[:], i8, f8[:], f8[:], f8[:], f8[:], f8[:], f8[:], f8[:], f8[:]," # "b1[:], b1[:], b1[:], b1[:], i8[:], i8, i8, i8, f8[:, :, :], f8[:, :]," # "f8[:, :], i8[:, :, :], i8[:, :], i8[:, :], b1)", - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _compute_diagonal( T_A, From c24f0a9effff36b6080bf9ee7fc72504ed8b679f Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Thu, 30 Jan 2025 20:17:28 -0500 Subject: [PATCH 13/24] fix fastmath for stump._stump --- stumpy/stump.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/stump.py b/stumpy/stump.py index 9d4ade6a3..10e4d0e3f 100644 --- a/stumpy/stump.py +++ b/stumpy/stump.py @@ -247,7 +247,7 @@ def _compute_diagonal( # "(f8[:], f8[:], i8, f8[:], f8[:], f8[:], f8[:], f8[:], f8[:], b1[:], b1[:]," # "b1[:], b1[:], i8[:], b1, i8)", parallel=True, - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _stump( T_A, From ad2c4f23658ea52d3f721a9c2582c89d5793681b Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Fri, 31 Jan 2025 22:45:21 -0500 Subject: [PATCH 14/24] temp commit --- stumpy/utils.py | 225 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 stumpy/utils.py diff --git a/stumpy/utils.py b/stumpy/utils.py new file mode 100644 index 000000000..0366921a1 --- /dev/null +++ b/stumpy/utils.py @@ -0,0 +1,225 @@ +import ast + +import pathlib + +from stumpy import cache + +def check_fastmath(decorator): + """ + For the given `decorator` node with type `ast.Call`, + return the value of the `fastmath` argument if it exists. + Otherwise, return `None`. + """ + fastmath_value = None + for n in ast.iter_child_nodes(decorator): + if isinstance(n, ast.keyword) and n.arg == "fastmath": + if isinstance(n.value, ast.Constant): + fastmath_value = n.value.value + elif isinstance(n.value, ast.Set): + fastmath_value = set(item.value for item in n.value.elts) + else: + pass + break + + return fastmath_value + + +def check_njit(fd): + """ + For the given `fd` node with type `ast.FunctionDef`, + return the node of the `njit` decorator if it exists. + Otherwise, return `None`. + """ + decorator_node = None + for decorator in fd.decorator_list: + if not isinstance(decorator, ast.Call): + continue + + obj = decorator.func + if isinstance(obj, ast.Attribute): + name = obj.attr + elif isinstance(obj, ast.Subscript): + name = obj.value.id + elif isinstance(obj, ast.Name): + name = obj.id + else: + msg = f"The type {type(obj)} is not supported." + raise ValueError(msg) + + if name == "njit": + decorator_node = decorator + break + + return decorator_node + + +def check_functions(filepath): + """ + For the given `filepath`, return the function names, + whether the function is decorated with `@njit`, + and the value of the `fastmath` argument if it exists + + Parameters + ---------- + filepath : str + The path to the file + + Returns + ------- + func_names : list + List of function names + + is_njit : list + List of boolean values indicating whether the function is decorated with `@njit` + + fastmath_value : list + List of values of the `fastmath` argument if it exists + """ + file_contents = "" + with open(filepath, encoding="utf8") as f: + file_contents = f.read() + module = ast.parse(file_contents) + + function_definitions = [ + node for node in module.body if isinstance(node, ast.FunctionDef) + ] + + func_names = [fd.name for fd in function_definitions] + + njit_nodes = [check_njit(fd) for fd in function_definitions] + is_njit = [node is not None for node in njit_nodes] + + fastmath_values = [None] * len(njit_nodes) + for i, node in enumerate(njit_nodes): + if node is not None: + fastmath_values[i] = check_fastmath(node) + + return func_names, is_njit, fastmath_values + + +def _get_callees(node, all_functions): + for n in ast.iter_child_nodes(node): + if isinstance(n, ast.Call): + obj = n.func + if isinstance(obj, ast.Attribute): + name = obj.attr + elif isinstance(obj, ast.Subscript): + name = obj.value.id + elif isinstance(obj, ast.Name): + name = obj.id + else: + msg = f"The type {type(obj)} is not supported" + raise ValueError(msg) + + all_functions.append(name) + + _get_callees(n, all_functions) + + +def get_all_callees(fd): + """ + For a given node of type ast.FunctionDef, visit all of its child nodes, + and return a list of all of its callees + """ + all_functions = [] + _get_callees(fd, all_functions) + + return all_functions + + +def check_callees(filepath): + """ + For the given `filepath`, return a dictionary with the key + being the function name and the value being a set of function names + that are called by the function + """ + file_contents = "" + with open(filepath, encoding="utf8") as f: + file_contents = f.read() + module = ast.parse(file_contents) + + function_definitions = [ + node for node in module.body if isinstance(node, ast.FunctionDef) + ] + + callees = {} + for fd in function_definitions: + callees[fd.name] = set(get_all_callees(fd)) + + return callees + + +stumpy_path = pathlib.Path(__file__).parent # / "stumpy" +filepaths = sorted(f for f in pathlib.Path(stumpy_path).iterdir() if f.is_file()) + +all_functions = {} + +ignore = ["__init__.py", "__pycache__"] +for filepath in filepaths: + file_name = filepath.name + if file_name not in ignore and str(filepath).endswith(".py"): + prefix = file_name.replace(".py", "") + + func_names, is_njit, fastmath_values = check_functions(filepath) + func_names = [f"{prefix}.{fn}" for fn in func_names] + + all_functions[file_name] = { + "func_names": func_names, + "is_njit": is_njit, + "fastmath_values": fastmath_values, + } + +all_stumpy_functions = set() +for file_name, file_functions_metadata in all_functions.items(): + all_stumpy_functions.update(file_functions_metadata["func_names"]) + +all_stumpy_functions = list(all_stumpy_functions) +all_stumpy_functions_no_prefix = [f.split(".")[-1] for f in all_stumpy_functions] + + +# output 1: func_metadata +func_metadata = {} +for file_name, file_functions_metadata in all_functions.items(): + for i, f in enumerate(file_functions_metadata["func_names"]): + is_njit = file_functions_metadata["is_njit"][i] + fastmath_value = file_functions_metadata["fastmath_values"][i] + func_metadata[f] = [is_njit, fastmath_value] + + +# output 2: func_callers +func_callers = {} +for f in func_metadata.keys(): + func_callers[f] = [] + +for filepath in filepaths: + file_name = filepath.name + if file_name in ignore or not str(filepath).endswith(".py"): + continue + + prefix = file_name.replace(".py", "") + callees = check_callees(filepath) + + current_callers = set(callees.keys()) + for caller, callee_set in callees.items(): + s = list(callee_set.intersection(all_stumpy_functions_no_prefix)) + if len(s) == 0: + continue + + for c in s: + if c in current_callers: + c_name = prefix + "." + c + else: + idx = all_stumpy_functions_no_prefix.index(c) + c_name = all_stumpy_functions[idx] + + func_callers[c_name].append(f"{prefix}.{caller}") + + +for f, callers in func_callers.items(): + func_callers[f] = list(set(callers)) + + + +for modue_name, func_name in cache.get_njit_funcs(): + f = f"{modue_name}.{func_name}" + print(f, func_callers[f]) \ No newline at end of file From 3c79603fa921cd2c0028c77087a2196420d721ac Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Sat, 1 Feb 2025 22:51:07 -0500 Subject: [PATCH 15/24] fix fastmath for maamp._compute_multi_p_norm --- stumpy/maamp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/maamp.py b/stumpy/maamp.py index dad6748c3..a216f9fc4 100644 --- a/stumpy/maamp.py +++ b/stumpy/maamp.py @@ -592,7 +592,7 @@ def _get_multi_p_norm(start, T, m, p=2.0): # "(i8, i8, i8, f8[:, :], f8[:, :], i8, i8, b1[:, :], b1[:, :], f8," # "f8[:, :], f8[:, :], f8[:, :])", parallel=True, - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _compute_multi_p_norm( d, From b4c56a0edf57a0ee58d8ead56cae79666b70193f Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Sat, 1 Feb 2025 22:51:34 -0500 Subject: [PATCH 16/24] Add note to docstring for case p=np.inf --- stumpy/core.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stumpy/core.py b/stumpy/core.py index 98c15d0c9..e5e6912a2 100644 --- a/stumpy/core.py +++ b/stumpy/core.py @@ -1251,6 +1251,10 @@ def _p_norm_distance_profile(Q, T, p=2.0): ------- output : numpy.ndarray p-normalized distance profile between `Q` and `T` + + Notes + ----- + The special case `p==inf` is not supported. """ m = Q.shape[0] l = T.shape[0] - m + 1 From da39fa3dead04a8578d665115850506c4bdab462 Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Sat, 1 Feb 2025 22:53:27 -0500 Subject: [PATCH 17/24] deleted wrong file --- stumpy/utils.py | 225 ------------------------------------------------ 1 file changed, 225 deletions(-) delete mode 100644 stumpy/utils.py diff --git a/stumpy/utils.py b/stumpy/utils.py deleted file mode 100644 index 0366921a1..000000000 --- a/stumpy/utils.py +++ /dev/null @@ -1,225 +0,0 @@ -import ast - -import pathlib - -from stumpy import cache - -def check_fastmath(decorator): - """ - For the given `decorator` node with type `ast.Call`, - return the value of the `fastmath` argument if it exists. - Otherwise, return `None`. - """ - fastmath_value = None - for n in ast.iter_child_nodes(decorator): - if isinstance(n, ast.keyword) and n.arg == "fastmath": - if isinstance(n.value, ast.Constant): - fastmath_value = n.value.value - elif isinstance(n.value, ast.Set): - fastmath_value = set(item.value for item in n.value.elts) - else: - pass - break - - return fastmath_value - - -def check_njit(fd): - """ - For the given `fd` node with type `ast.FunctionDef`, - return the node of the `njit` decorator if it exists. - Otherwise, return `None`. - """ - decorator_node = None - for decorator in fd.decorator_list: - if not isinstance(decorator, ast.Call): - continue - - obj = decorator.func - if isinstance(obj, ast.Attribute): - name = obj.attr - elif isinstance(obj, ast.Subscript): - name = obj.value.id - elif isinstance(obj, ast.Name): - name = obj.id - else: - msg = f"The type {type(obj)} is not supported." - raise ValueError(msg) - - if name == "njit": - decorator_node = decorator - break - - return decorator_node - - -def check_functions(filepath): - """ - For the given `filepath`, return the function names, - whether the function is decorated with `@njit`, - and the value of the `fastmath` argument if it exists - - Parameters - ---------- - filepath : str - The path to the file - - Returns - ------- - func_names : list - List of function names - - is_njit : list - List of boolean values indicating whether the function is decorated with `@njit` - - fastmath_value : list - List of values of the `fastmath` argument if it exists - """ - file_contents = "" - with open(filepath, encoding="utf8") as f: - file_contents = f.read() - module = ast.parse(file_contents) - - function_definitions = [ - node for node in module.body if isinstance(node, ast.FunctionDef) - ] - - func_names = [fd.name for fd in function_definitions] - - njit_nodes = [check_njit(fd) for fd in function_definitions] - is_njit = [node is not None for node in njit_nodes] - - fastmath_values = [None] * len(njit_nodes) - for i, node in enumerate(njit_nodes): - if node is not None: - fastmath_values[i] = check_fastmath(node) - - return func_names, is_njit, fastmath_values - - -def _get_callees(node, all_functions): - for n in ast.iter_child_nodes(node): - if isinstance(n, ast.Call): - obj = n.func - if isinstance(obj, ast.Attribute): - name = obj.attr - elif isinstance(obj, ast.Subscript): - name = obj.value.id - elif isinstance(obj, ast.Name): - name = obj.id - else: - msg = f"The type {type(obj)} is not supported" - raise ValueError(msg) - - all_functions.append(name) - - _get_callees(n, all_functions) - - -def get_all_callees(fd): - """ - For a given node of type ast.FunctionDef, visit all of its child nodes, - and return a list of all of its callees - """ - all_functions = [] - _get_callees(fd, all_functions) - - return all_functions - - -def check_callees(filepath): - """ - For the given `filepath`, return a dictionary with the key - being the function name and the value being a set of function names - that are called by the function - """ - file_contents = "" - with open(filepath, encoding="utf8") as f: - file_contents = f.read() - module = ast.parse(file_contents) - - function_definitions = [ - node for node in module.body if isinstance(node, ast.FunctionDef) - ] - - callees = {} - for fd in function_definitions: - callees[fd.name] = set(get_all_callees(fd)) - - return callees - - -stumpy_path = pathlib.Path(__file__).parent # / "stumpy" -filepaths = sorted(f for f in pathlib.Path(stumpy_path).iterdir() if f.is_file()) - -all_functions = {} - -ignore = ["__init__.py", "__pycache__"] -for filepath in filepaths: - file_name = filepath.name - if file_name not in ignore and str(filepath).endswith(".py"): - prefix = file_name.replace(".py", "") - - func_names, is_njit, fastmath_values = check_functions(filepath) - func_names = [f"{prefix}.{fn}" for fn in func_names] - - all_functions[file_name] = { - "func_names": func_names, - "is_njit": is_njit, - "fastmath_values": fastmath_values, - } - -all_stumpy_functions = set() -for file_name, file_functions_metadata in all_functions.items(): - all_stumpy_functions.update(file_functions_metadata["func_names"]) - -all_stumpy_functions = list(all_stumpy_functions) -all_stumpy_functions_no_prefix = [f.split(".")[-1] for f in all_stumpy_functions] - - -# output 1: func_metadata -func_metadata = {} -for file_name, file_functions_metadata in all_functions.items(): - for i, f in enumerate(file_functions_metadata["func_names"]): - is_njit = file_functions_metadata["is_njit"][i] - fastmath_value = file_functions_metadata["fastmath_values"][i] - func_metadata[f] = [is_njit, fastmath_value] - - -# output 2: func_callers -func_callers = {} -for f in func_metadata.keys(): - func_callers[f] = [] - -for filepath in filepaths: - file_name = filepath.name - if file_name in ignore or not str(filepath).endswith(".py"): - continue - - prefix = file_name.replace(".py", "") - callees = check_callees(filepath) - - current_callers = set(callees.keys()) - for caller, callee_set in callees.items(): - s = list(callee_set.intersection(all_stumpy_functions_no_prefix)) - if len(s) == 0: - continue - - for c in s: - if c in current_callers: - c_name = prefix + "." + c - else: - idx = all_stumpy_functions_no_prefix.index(c) - c_name = all_stumpy_functions[idx] - - func_callers[c_name].append(f"{prefix}.{caller}") - - -for f, callers in func_callers.items(): - func_callers[f] = list(set(callers)) - - - -for modue_name, func_name in cache.get_njit_funcs(): - f = f"{modue_name}.{func_name}" - print(f, func_callers[f]) \ No newline at end of file From 69995ca1fb74733c8764f60dd540e0b2cc0218b6 Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Mon, 31 Mar 2025 10:37:16 -0400 Subject: [PATCH 18/24] Add check for fastmath flags of callstacks --- fastmath.py | 379 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 379 insertions(+) diff --git a/fastmath.py b/fastmath.py index fe7dc0b56..ff7c371c6 100755 --- a/fastmath.py +++ b/fastmath.py @@ -89,6 +89,384 @@ def check_fastmath(pkg_dir, pkg_name): return +class FunctionCallVisitor(ast.NodeVisitor): + """ + A class to traverse the AST of modules of a package to collect the call stacks + of njit functions. + + Parameters + ---------- + pkg_dir : str + The path to the package directory containing some .py files. + + pkg_name : str + The name of the package. + + Attributes + ---------- + module_names : list + A list of module names to track the modules as the visitor traverses their AST + + call_stack : list + A list of function calls made in the current module + + out : list + A list of unique function call stacks. + + njit_funcs : list + A list of njit functions in STUMPY. Each element is a tuple of the form + (module_name, func_name). + + njit_modules : set + A set of module names, where each contains at least one njit function. + + njit_nodes : dict + A dictionary mapping njit function names to their corresponding AST nodes. + A key is of the form "module_name.func_name", and its corresponding value + is the AST node- with type ast.FunctionDef- of that njit function + + ast_modules : dict + A dictionary mapping module names to their corresponding AST objects. A key + is of the form "module_name", and its corresponding value is the content of + the module as an AST object. + + Methods + ------- + push_module(module_name) + Push a module name onto the stack of module names. + + pop_module() + Pop the last module name from the stack of module names. + + push_call_stack(module_name, func_name) + Push a function call onto the stack of function calls. + + pop_call_stack() + Pop the last function call from the stack of function calls. + + goto_deeper_func(node) + Calls the visit method from class `ast.NodeVisitor` on all children of the node. + + goto_next_func(node) + Calls the visit method from class `ast.NodeVisitor` on all children of the node. + + push_out() + Push the current function call stack onto the output list if it is not + included in one of the existing call stacks in `self.out`. + + visit_Call(node) + Visit an AST node of type `ast.Call`. This method is called when the visitor + encounters a function call in the AST. It checks if the called function is + a njit function and, if so, traverses its AST to collect its call stack. + """ + + def __init__(self, pkg_dir, pkg_name): + """ + Initialize the FunctionCallVisitor class. This method sets up the necessary + attributes and prepares the visitor for traversing the AST of STUMPY's modules. + + Parameters + ---------- + pkg_dir : str + The path to the package directory containing some .py files. + + pkg_name : str + The name of the package. + + Returns + ------- + None + """ + super().__init__() + self.module_names = [] + self.call_stack = [] + self.out = [] + + # Setup lists, dicts, and ast objects + self.njit_funcs = get_njit_funcs(pkg_dir) + self.njit_modules = set(mod_name for mod_name, func_name in self.njit_funcs) + self.njit_nodes = {} + self.ast_modules = {} + + filepaths = sorted(f for f in pathlib.Path(pkg_dir).iterdir() if f.is_file()) + ignore = ["__init__.py", "__pycache__"] + + for filepath in filepaths: + file_name = filepath.name + if ( + file_name not in ignore + and not file_name.startswith("gpu") + and str(filepath).endswith(".py") + ): + module_name = file_name.replace(".py", "") + file_contents = "" + with open(filepath, encoding="utf8") as f: + file_contents = f.read() + self.ast_modules[module_name] = ast.parse(file_contents) + + for node in self.ast_modules[module_name].body: + if isinstance(node, ast.FunctionDef): + func_name = node.name + if (module_name, func_name) in self.njit_funcs: + self.njit_nodes[f"{module_name}.{func_name}"] = node + + def push_module(self, module_name): + """ + Push a module name onto the stack of module names. + + Parameters + ---------- + module_name : str + The name of the module to be pushed onto the stack. + + Returns + ------- + None + """ + self.module_names.append(module_name) + + return + + def pop_module(self): + """ + Pop the last module name from the stack of module names. + + Parameters + ---------- + None + + Returns + ------- + None + """ + if self.module_names: + self.module_names.pop() + + return + + def push_call_stack(self, module_name, func_name): + """ + Push a function call onto the stack of function calls. + + Parameters + ---------- + module_name : str + The name of the module containing the function being called. + + func_name : str + The name of the function being called. + + Returns + ------- + None + """ + self.call_stack.append(f"{module_name}.{func_name}") + + return + + def pop_call_stack(self): + """ + Pop the last function call from the stack of function calls. + + Parameters + ---------- + None + + Returns + ------- + None + """ + if self.call_stack: + self.call_stack.pop() + + return + + def goto_deeper_func(self, node): + """ + Calls the visit method from class `ast.NodeVisitor` on + all children of the node. + + Parameters + ---------- + node : ast.AST + The AST node to be visited. + + Returns + ------- + None + """ + self.generic_visit(node) + + return + + def goto_next_func(self, node): + """ + Calls the visit method from class `ast.NodeVisitor` on + all children of the node. + + Parameters + ---------- + node : ast.AST + The AST node to be visited. + + Returns + ------- + None + """ + self.generic_visit(node) + + return + + def push_out(self): + """ + Push the current function call stack onto the output list if it is not + included in one of the existing call stacks in `self.out`. + + Parameters + ---------- + None + + Returns + ------- + None + """ + unique = True + for cs in self.out: + if " ".join(self.call_stack) in " ".join(cs): + unique = False + break + + if unique: + self.out.append(self.call_stack.copy()) + + return + + def visit_Call(self, node): + """ + Visit an AST node of type `ast.Call`. + + Parameters + ---------- + node : ast.Call + The AST node representing a function call. + + Returns + ------- + None + """ + callee_name = ast.unparse(node.func) + + module_changed = False + if "." in callee_name: + new_module_name, new_func_name = callee_name.split(".")[:2] + + if new_module_name in self.njit_modules: + self.push_module(new_module_name) + module_changed = True + else: + if self.module_names: + new_module_name = self.module_names[-1] + new_func_name = callee_name + callee_name = f"{new_module_name}.{new_func_name}" + + if callee_name in self.njit_nodes.keys(): + callee_node = self.njit_nodes[callee_name] + self.push_call_stack(new_module_name, new_func_name) + self.goto_deeper_func(callee_node) + if module_changed: + self.pop_module() + self.push_out() + self.pop_call_stack() + + self.goto_next_func(node) + + return + + +def get_njit_call_stacks(pkg_dir, pkg_name): + """ + Get the call stacks of all njit functions in STUMPY. + This function traverses the AST of each module in STUMPY and returns + a list of unique function call stacks. + + Parameters + ---------- + pkg_dir : str + The path to the package directory containing some .py files + + pkg_name : str + The name of the package + + Returns + ------- + out : list + A list of unique function call stacks. Each element is a list of strings, + where each string represents a function call in the stack. + """ + visitor = FunctionCallVisitor(pkg_dir, pkg_name) + + for module_name in visitor.njit_modules: + visitor.push_module(module_name) + + for node in visitor.ast_modules[module_name].body: + if isinstance(node, ast.FunctionDef): + func_name = node.name + if (module_name, func_name) in visitor.njit_funcs: + visitor.push_call_stack(module_name, func_name) + visitor.visit(node) + visitor.pop_call_stack() + + visitor.pop_module() + + return visitor.out + + +def check_fastmath_callstack(pkg_dir, pkg_name): + """ + Check if all njit functions in a callstack have the same `fastmath` flag. + This function raises a ValueError if it finds any inconsistencies in the + `fastmath` flags across the call stacks of njit functions. + + Parameters + ---------- + pkg_dir : str + The path to the package directory containing some .py files + + pkg_name : str + The name of the package + + Returns + ------- + None + """ + out = get_njit_call_stacks(pkg_dir, pkg_name) + + fastmath_is_inconsistent = [] + for cs in out: + module_name, func_name = cs[0].split(".") + module = importlib.import_module(f".{module_name}", package="stumpy") + func = getattr(module, func_name) + flag = func.targetoptions["fastmath"] + + for item in cs[1:]: + module_name, func_name = cs[0].split(".") + module = importlib.import_module(f".{module_name}", package="stumpy") + func = getattr(module, func_name) + func_flag = func.targetoptions["fastmath"] + if func_flag != flag: + fastmath_is_inconsistent.append(cs) + break + + if len(fastmath_is_inconsistent) > 0: + msg = ( + "Found at least one callstack that have inconsistent `fastmath` flags. " + + f"The functions are:\n {fastmath_is_inconsistent}\n" + ) + raise ValueError(msg) + + return + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--check", dest="pkg_dir") @@ -98,3 +476,4 @@ def check_fastmath(pkg_dir, pkg_name): pkg_dir = pathlib.Path(args.pkg_dir) pkg_name = pkg_dir.name check_fastmath(str(pkg_dir), pkg_name) + check_fastmath_callstack(str(pkg_dir), pkg_name) From 08b46c65d784df3ab83e4e712eec6bc4fb779e8d Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Mon, 31 Mar 2025 11:17:53 -0400 Subject: [PATCH 19/24] minor changes --- fastmath.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/fastmath.py b/fastmath.py index ff7c371c6..c955e2a39 100755 --- a/fastmath.py +++ b/fastmath.py @@ -152,7 +152,7 @@ class FunctionCallVisitor(ast.NodeVisitor): push_out() Push the current function call stack onto the output list if it is not - included in one of the existing call stacks in `self.out`. + included in one of the so-far-collected call stacks. visit_Call(node) Visit an AST node of type `ast.Call`. This method is called when the visitor @@ -421,16 +421,16 @@ def get_njit_call_stacks(pkg_dir, pkg_name): return visitor.out -def check_fastmath_callstack(pkg_dir, pkg_name): +def check_call_stack_fastmath(pkg_dir, pkg_name): """ - Check if all njit functions in a callstack have the same `fastmath` flag. + Check if all njit functions in a call stack have the same `fastmath` flag. This function raises a ValueError if it finds any inconsistencies in the - `fastmath` flags across the call stacks of njit functions. + `fastmath` flags in any call stack of njit functions. Parameters ---------- pkg_dir : str - The path to the package directory containing some .py files + The path to the directory containing some .py files pkg_name : str The name of the package @@ -441,26 +441,27 @@ def check_fastmath_callstack(pkg_dir, pkg_name): """ out = get_njit_call_stacks(pkg_dir, pkg_name) - fastmath_is_inconsistent = [] + inconsitent_call_stacks = [] for cs in out: + # Set the fastmath flag of the first function in the call stack + # as the reference flag module_name, func_name = cs[0].split(".") module = importlib.import_module(f".{module_name}", package="stumpy") func = getattr(module, func_name) - flag = func.targetoptions["fastmath"] + flag_ref = func.targetoptions["fastmath"] for item in cs[1:]: module_name, func_name = cs[0].split(".") module = importlib.import_module(f".{module_name}", package="stumpy") func = getattr(module, func_name) - func_flag = func.targetoptions["fastmath"] - if func_flag != flag: - fastmath_is_inconsistent.append(cs) + if func.targetoptions["fastmath"] != flag_ref: + inconsitent_call_stacks.append(cs) break - if len(fastmath_is_inconsistent) > 0: + if len(inconsitent_call_stacks) > 0: msg = ( "Found at least one callstack that have inconsistent `fastmath` flags. " - + f"The functions are:\n {fastmath_is_inconsistent}\n" + + f"The functions are:\n {inconsitent_call_stacks}\n" ) raise ValueError(msg) @@ -476,4 +477,4 @@ def check_fastmath_callstack(pkg_dir, pkg_name): pkg_dir = pathlib.Path(args.pkg_dir) pkg_name = pkg_dir.name check_fastmath(str(pkg_dir), pkg_name) - check_fastmath_callstack(str(pkg_dir), pkg_name) + check_call_stack_fastmath(str(pkg_dir), pkg_name) From dd3490b08443dd0faa8077cf095c3d7b8c0611a0 Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Sat, 5 Apr 2025 01:23:45 -0400 Subject: [PATCH 20/24] minor changes and fixes --- fastmath.py | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/fastmath.py b/fastmath.py index c955e2a39..e9784a3e9 100755 --- a/fastmath.py +++ b/fastmath.py @@ -91,8 +91,8 @@ def check_fastmath(pkg_dir, pkg_name): class FunctionCallVisitor(ast.NodeVisitor): """ - A class to traverse the AST of modules of a package to collect the call stacks - of njit functions. + A class to traverse the AST of the modules of a package to collect + the call stacks of njit functions. Parameters ---------- @@ -108,51 +108,54 @@ class FunctionCallVisitor(ast.NodeVisitor): A list of module names to track the modules as the visitor traverses their AST call_stack : list - A list of function calls made in the current module + A list of njit functions, representing a chain of function calls, + where each element is a string of the form "module_name.func_name". out : list - A list of unique function call stacks. + A list of unique function `call_stack`s. njit_funcs : list - A list of njit functions in STUMPY. Each element is a tuple of the form - (module_name, func_name). + A list of all njit functions in `pkg_dir`'s modules. Each element is a tuple + of the form `(module_name, func_name)`. njit_modules : set - A set of module names, where each contains at least one njit function. + A set that contains the names of all modules, each of which contains at least + one njit function. njit_nodes : dict A dictionary mapping njit function names to their corresponding AST nodes. - A key is of the form "module_name.func_name", and its corresponding value - is the AST node- with type ast.FunctionDef- of that njit function + A key is a string, and it is of the form "module_name.func_name", and its + corresponding value is the AST node- with type ast.FunctionDef- of that + function. ast_modules : dict A dictionary mapping module names to their corresponding AST objects. A key - is of the form "module_name", and its corresponding value is the content of - the module as an AST object. + is the name of a module, and its corresponding value is the content of that + module as an AST object. Methods ------- push_module(module_name) - Push a module name onto the stack of module names. + Push the name of a module onto the stack `module_names`. pop_module() - Pop the last module name from the stack of module names. + Pop the last module name from the stack `module_names`. push_call_stack(module_name, func_name) - Push a function call onto the stack of function calls. + Push a function call onto the stack of function calls, `call_stack`. pop_call_stack() - Pop the last function call from the stack of function calls. + Pop the last function call from the stack of function calls, `call_stack` goto_deeper_func(node) - Calls the visit method from class `ast.NodeVisitor` on all children of the node. + Calls the visit method from class `ast.NodeVisitor` on all children of the `node`. goto_next_func(node) - Calls the visit method from class `ast.NodeVisitor` on all children of the node. + Calls the visit method from class `ast.NodeVisitor` on all children of the `node`. push_out() - Push the current function call stack onto the output list if it is not - included in one of the so-far-collected call stacks. + Push the current function call stack, `call_stack`, onto the output list, `out`, + unless it is already included in one of the so-far-collected call stacks. visit_Call(node) Visit an AST node of type `ast.Call`. This method is called when the visitor From d6740ed223c9e308def83c067b467187f55c4d53 Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Sat, 5 Apr 2025 01:24:36 -0400 Subject: [PATCH 21/24] fix black and flake8 --- fastmath.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/fastmath.py b/fastmath.py index e9784a3e9..d11d9822e 100755 --- a/fastmath.py +++ b/fastmath.py @@ -108,7 +108,7 @@ class FunctionCallVisitor(ast.NodeVisitor): A list of module names to track the modules as the visitor traverses their AST call_stack : list - A list of njit functions, representing a chain of function calls, + A list of njit functions, representing a chain of function calls, where each element is a string of the form "module_name.func_name". out : list @@ -148,10 +148,12 @@ class FunctionCallVisitor(ast.NodeVisitor): Pop the last function call from the stack of function calls, `call_stack` goto_deeper_func(node) - Calls the visit method from class `ast.NodeVisitor` on all children of the `node`. + Calls the visit method from class `ast.NodeVisitor` on all children of + the `node`. goto_next_func(node) - Calls the visit method from class `ast.NodeVisitor` on all children of the `node`. + Calls the visit method from class `ast.NodeVisitor` on all children of + the `node`. push_out() Push the current function call stack, `call_stack`, onto the output list, `out`, From 5fa5e1fd5499cdf2960c77606104f4a4f9e18548 Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Sat, 5 Apr 2025 02:44:50 -0400 Subject: [PATCH 22/24] minor changes --- fastmath.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/fastmath.py b/fastmath.py index d11d9822e..ce5424d3f 100755 --- a/fastmath.py +++ b/fastmath.py @@ -105,7 +105,7 @@ class FunctionCallVisitor(ast.NodeVisitor): Attributes ---------- module_names : list - A list of module names to track the modules as the visitor traverses their AST + A list of module names to track the modules as the visitor traverses their AST. call_stack : list A list of njit functions, representing a chain of function calls, @@ -289,7 +289,7 @@ def pop_call_stack(self): def goto_deeper_func(self, node): """ Calls the visit method from class `ast.NodeVisitor` on - all children of the node. + all children of the `node`. Parameters ---------- @@ -324,8 +324,9 @@ def goto_next_func(self, node): def push_out(self): """ - Push the current function call stack onto the output list if it is not - included in one of the existing call stacks in `self.out`. + Push the current function call stack onto the output list unless it + is already included in one of the so-far-collected call stacks. + Parameters ---------- @@ -348,7 +349,7 @@ def push_out(self): def visit_Call(self, node): """ - Visit an AST node of type `ast.Call`. + Called when visiting an AST node of type `ast.Call`. Parameters ---------- @@ -378,10 +379,10 @@ def visit_Call(self, node): callee_node = self.njit_nodes[callee_name] self.push_call_stack(new_module_name, new_func_name) self.goto_deeper_func(callee_node) - if module_changed: - self.pop_module() self.push_out() self.pop_call_stack() + if module_changed: + self.pop_module() self.goto_next_func(node) @@ -390,9 +391,9 @@ def visit_Call(self, node): def get_njit_call_stacks(pkg_dir, pkg_name): """ - Get the call stacks of all njit functions in STUMPY. - This function traverses the AST of each module in STUMPY and returns - a list of unique function call stacks. + Get the call stacks of all njit functions in `pkg_dir`. + This function traverses the AST of each module in `pkg_dir` + and returns a list of unique function call stacks. Parameters ---------- @@ -405,8 +406,8 @@ def get_njit_call_stacks(pkg_dir, pkg_name): Returns ------- out : list - A list of unique function call stacks. Each element is a list of strings, - where each string represents a function call in the stack. + A list of unique function call stacks. Each item is of type list, + representing a chain of function calls. """ visitor = FunctionCallVisitor(pkg_dir, pkg_name) @@ -430,7 +431,7 @@ def check_call_stack_fastmath(pkg_dir, pkg_name): """ Check if all njit functions in a call stack have the same `fastmath` flag. This function raises a ValueError if it finds any inconsistencies in the - `fastmath` flags in any call stack of njit functions. + `fastmath` flags in at lease one call stack of njit functions. Parameters ---------- @@ -444,10 +445,10 @@ def check_call_stack_fastmath(pkg_dir, pkg_name): ------- None """ - out = get_njit_call_stacks(pkg_dir, pkg_name) - inconsitent_call_stacks = [] - for cs in out: + + njit_call_stacks = get_njit_call_stacks(pkg_dir, pkg_name) + for cs in njit_call_stacks: # Set the fastmath flag of the first function in the call stack # as the reference flag module_name, func_name = cs[0].split(".") @@ -459,14 +460,15 @@ def check_call_stack_fastmath(pkg_dir, pkg_name): module_name, func_name = cs[0].split(".") module = importlib.import_module(f".{module_name}", package="stumpy") func = getattr(module, func_name) - if func.targetoptions["fastmath"] != flag_ref: + flag = func.targetoptions["fastmath"] + if flag != flag_ref: inconsitent_call_stacks.append(cs) break if len(inconsitent_call_stacks) > 0: msg = ( "Found at least one callstack that have inconsistent `fastmath` flags. " - + f"The functions are:\n {inconsitent_call_stacks}\n" + + f"Those call stacks are:\n {inconsitent_call_stacks}\n" ) raise ValueError(msg) From 0908600f6268073546c88831b0ee89157b5d765b Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Mon, 7 Apr 2025 00:30:45 -0400 Subject: [PATCH 23/24] fixed typo and add comment --- fastmath.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/fastmath.py b/fastmath.py index ce5424d3f..986eadf59 100755 --- a/fastmath.py +++ b/fastmath.py @@ -445,7 +445,8 @@ def check_call_stack_fastmath(pkg_dir, pkg_name): ------- None """ - inconsitent_call_stacks = [] + # List of call stacks with inconsistent fastmath flags + inconsistent_call_stacks = [] njit_call_stacks = get_njit_call_stacks(pkg_dir, pkg_name) for cs in njit_call_stacks: @@ -462,13 +463,13 @@ def check_call_stack_fastmath(pkg_dir, pkg_name): func = getattr(module, func_name) flag = func.targetoptions["fastmath"] if flag != flag_ref: - inconsitent_call_stacks.append(cs) + inconsistent_call_stacks.append(cs) break - if len(inconsitent_call_stacks) > 0: + if len(inconsistent_call_stacks) > 0: msg = ( - "Found at least one callstack that have inconsistent `fastmath` flags. " - + f"Those call stacks are:\n {inconsitent_call_stacks}\n" + "Found at least one call stack that has inconsistent `fastmath` flags. " + + f"Those call stacks are:\n {inconsistent_call_stacks}\n" ) raise ValueError(msg) From 8520288cfcdda980c1253ae1c6de6e47442e012e Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Mon, 7 Apr 2025 00:56:03 -0400 Subject: [PATCH 24/24] minor changes --- fastmath.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/fastmath.py b/fastmath.py index 986eadf59..b46fd5e37 100755 --- a/fastmath.py +++ b/fastmath.py @@ -105,14 +105,14 @@ class FunctionCallVisitor(ast.NodeVisitor): Attributes ---------- module_names : list - A list of module names to track the modules as the visitor traverses their AST. + A list of module names to track the modules as the visitor traverses them. call_stack : list A list of njit functions, representing a chain of function calls, where each element is a string of the form "module_name.func_name". out : list - A list of unique function `call_stack`s. + A list of unique `call_stack`s. njit_funcs : list A list of all njit functions in `pkg_dir`'s modules. Each element is a tuple @@ -160,9 +160,9 @@ class FunctionCallVisitor(ast.NodeVisitor): unless it is already included in one of the so-far-collected call stacks. visit_Call(node) - Visit an AST node of type `ast.Call`. This method is called when the visitor - encounters a function call in the AST. It checks if the called function is - a njit function and, if so, traverses its AST to collect its call stack. + This method is called when the visitor encounters a function call in the AST. It + checks if the called function is a njit function and, if so, traverses its AST + to collect its call stack. """ def __init__(self, pkg_dir, pkg_name): @@ -256,10 +256,10 @@ def push_call_stack(self, module_name, func_name): Parameters ---------- module_name : str - The name of the module containing the function being called. + A module's name func_name : str - The name of the function being called. + A function's name Returns ------- @@ -391,9 +391,7 @@ def visit_Call(self, node): def get_njit_call_stacks(pkg_dir, pkg_name): """ - Get the call stacks of all njit functions in `pkg_dir`. - This function traverses the AST of each module in `pkg_dir` - and returns a list of unique function call stacks. + Get the call stacks of all njit functions in `pkg_dir` Parameters ----------