|
11 | 11 |
|
12 | 12 | import cgen as c |
13 | 13 | from sympy import IndexedBase |
| 14 | +from sympy.core.function import Application |
14 | 15 |
|
15 | 16 | from devito.exceptions import VisitorException |
16 | 17 | from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle, |
|
24 | 25 | IndexedData, DeviceMap) |
25 | 26 |
|
26 | 27 |
|
27 | | -__all__ = ['FindNodes', 'FindSections', 'FindSymbols', 'MapExprStmts', 'MapNodes', |
28 | | - 'IsPerfectIteration', 'printAST', 'CGen', 'CInterface', 'Transformer', |
29 | | - 'Uxreplace'] |
| 28 | +__all__ = ['FindApplications', 'FindNodes', 'FindSections', 'FindSymbols', |
| 29 | + 'MapExprStmts', 'MapNodes', 'IsPerfectIteration', 'printAST', 'CGen', |
| 30 | + 'CInterface', 'Transformer', 'Uxreplace'] |
30 | 31 |
|
31 | 32 |
|
32 | 33 | class Visitor(GenericVisitor): |
@@ -953,6 +954,55 @@ def visit_Node(self, o, ret=None): |
953 | 954 | return ret |
954 | 955 |
|
955 | 956 |
|
| 957 | +class FindApplications(Visitor): |
| 958 | + |
| 959 | + """ |
| 960 | + Find all SymPy applied functions (aka, `Application`s). The user may refine |
| 961 | + the search by supplying a different target class. |
| 962 | + """ |
| 963 | + |
| 964 | + def __init__(self, cls=Application): |
| 965 | + super().__init__() |
| 966 | + self.match = lambda i: isinstance(i, cls) and not isinstance(i, Basic) |
| 967 | + |
| 968 | + @classmethod |
| 969 | + def default_retval(cls): |
| 970 | + return set() |
| 971 | + |
| 972 | + def visit_object(self, o, **kwargs): |
| 973 | + return self.default_retval() |
| 974 | + |
| 975 | + def visit_tuple(self, o, ret=None): |
| 976 | + ret = ret or self.default_retval() |
| 977 | + for i in o: |
| 978 | + ret.update(self._visit(i, ret=ret)) |
| 979 | + return ret |
| 980 | + |
| 981 | + def visit_Node(self, o, ret=None): |
| 982 | + ret = ret or self.default_retval() |
| 983 | + for i in o.children: |
| 984 | + ret.update(self._visit(i, ret=ret)) |
| 985 | + return ret |
| 986 | + |
| 987 | + def visit_Expression(self, o, **kwargs): |
| 988 | + return o.expr.find(self.match) |
| 989 | + |
| 990 | + def visit_Iteration(self, o, **kwargs): |
| 991 | + ret = self._visit(o.children) or self.default_retval() |
| 992 | + ret.update(o.symbolic_min.find(self.match)) |
| 993 | + ret.update(o.symbolic_max.find(self.match)) |
| 994 | + return ret |
| 995 | + |
| 996 | + def visit_Call(self, o, **kwargs): |
| 997 | + ret = self.default_retval() |
| 998 | + for i in o.arguments: |
| 999 | + try: |
| 1000 | + ret.update(i.find(self.match)) |
| 1001 | + except (AttributeError, TypeError): |
| 1002 | + continue |
| 1003 | + return ret |
| 1004 | + |
| 1005 | + |
956 | 1006 | class IsPerfectIteration(Visitor): |
957 | 1007 |
|
958 | 1008 | """ |
|
0 commit comments