1919#include " llvm/ADT/PostOrderIterator.h"
2020#include " llvm/IR/Function.h"
2121#include " llvm/IR/Instructions.h"
22- #include " llvm/IR/IntrinsicInst.h"
2322#include " llvm/IR/Module.h"
2423#include " llvm/Support/Debug.h"
2524
@@ -44,50 +43,76 @@ void VisitorTemplate::setStrategy(VisitorStrategy strategy) {
4443 m_strategy = strategy;
4544}
4645
47- void VisitorTemplate::add (VisitorKey key, VisitorCallback *fn,
48- VisitorCallbackData data,
49- VisitorHandler::Projection projection) {
50- VisitorHandler handler;
51- handler.callback = fn;
52- handler.data = data;
53- handler.projection = projection;
54-
55- m_handlers.emplace_back (handler);
46+ void VisitorTemplate::storeHandlersInOpMap (
47+ const VisitorKey &key, unsigned handlerIdx,
48+ VisitorCallbackType visitorCallbackTy) {
49+ const auto HandlerList =
50+ [&](const OpDescription &opDescription) -> llvm::SmallVector<unsigned > & {
51+ if (visitorCallbackTy == VisitorCallbackType::PreVisit)
52+ return m_opMap[opDescription].PreVisitHandlers ;
5653
57- const unsigned handlerIdx = m_handlers.size () - 1 ;
54+ return m_opMap[opDescription].VisitHandlers ;
55+ };
5856
5957 if (key.m_kind == VisitorKey::Kind::Intrinsic) {
60- m_opMap[ OpDescription::fromIntrinsic (key.m_intrinsicId )]. push_back (
61- handlerIdx);
58+ HandlerList ( OpDescription::fromIntrinsic (key.m_intrinsicId ))
59+ . push_back ( handlerIdx);
6260 } else if (key.m_kind == VisitorKey::Kind::OpDescription) {
6361 const OpDescription *opDesc = key.m_description ;
6462
6563 if (opDesc->isCoreOp ()) {
6664 for (const unsigned op : opDesc->getOpcodes ())
67- m_opMap[ OpDescription::fromCoreOp (op)] .push_back (handlerIdx);
65+ HandlerList ( OpDescription::fromCoreOp (op)) .push_back (handlerIdx);
6866 } else if (opDesc->isIntrinsic ()) {
6967 for (const unsigned op : opDesc->getOpcodes ())
70- m_opMap[ OpDescription::fromIntrinsic (op)] .push_back (handlerIdx);
68+ HandlerList ( OpDescription::fromIntrinsic (op)) .push_back (handlerIdx);
7169 } else {
72- m_opMap[ *opDesc] .push_back (handlerIdx);
70+ HandlerList ( *opDesc) .push_back (handlerIdx);
7371 }
7472 } else if (key.m_kind == VisitorKey::Kind::OpSet) {
7573 const OpSet *opSet = key.m_set ;
7674
75+ if (visitorCallbackTy == VisitorCallbackType::PreVisit && opSet->empty ()) {
76+ // This adds a handler for every stored op.
77+ // Note: should be used with caution.
78+ for (auto it : m_opMap)
79+ it.second .PreVisitHandlers .push_back (handlerIdx);
80+
81+ return ;
82+ }
83+
7784 for (unsigned opcode : opSet->getCoreOpcodes ())
78- m_opMap[ OpDescription::fromCoreOp (opcode)] .push_back (handlerIdx);
85+ HandlerList ( OpDescription::fromCoreOp (opcode)) .push_back (handlerIdx);
7986
8087 for (unsigned intrinsicID : opSet->getIntrinsicIDs ())
81- m_opMap[OpDescription::fromIntrinsic (intrinsicID)].push_back (handlerIdx);
88+ HandlerList (OpDescription::fromIntrinsic (intrinsicID))
89+ .push_back (handlerIdx);
8290
83- for (const auto &dialectOpPair : opSet->getDialectOps ()) {
84- m_opMap[ OpDescription::fromDialectOp (dialectOpPair.isOverload ,
85- dialectOpPair.mnemonic )]
91+ for (const auto &dialectOpPair : opSet->getDialectOps ())
92+ HandlerList ( OpDescription::fromDialectOp (dialectOpPair.isOverload ,
93+ dialectOpPair.mnemonic ))
8694 .push_back (handlerIdx);
87- }
8895 }
8996}
9097
98+ void VisitorTemplate::add (VisitorKey key, VisitorCallback *fn,
99+ VisitorCallbackData data,
100+ VisitorHandler::Projection projection,
101+ VisitorCallbackType visitorCallbackTy) {
102+ assert (visitorCallbackTy != VisitorCallbackType::PreVisit || key.m_set );
103+
104+ VisitorHandler handler;
105+ handler.callback = fn;
106+ handler.data = data;
107+ handler.projection = projection;
108+
109+ m_handlers.emplace_back (handler);
110+
111+ const unsigned handlerIdx = m_handlers.size () - 1 ;
112+
113+ storeHandlersInOpMap (key, handlerIdx, visitorCallbackTy);
114+ }
115+
91116VisitorBuilderBase::VisitorBuilderBase () : m_template(&m_ownedTemplate) {}
92117
93118VisitorBuilderBase::VisitorBuilderBase (VisitorBuilderBase *parent,
@@ -144,6 +169,13 @@ void VisitorBuilderBase::setStrategy(VisitorStrategy strategy) {
144169 m_template->setStrategy (strategy);
145170}
146171
172+ void VisitorBuilderBase::addPreVisitCallback (VisitorKey key,
173+ VisitorCallback *fn,
174+ VisitorCallbackData data) {
175+ m_template->add (key, fn, data, m_projection,
176+ VisitorTemplate::VisitorCallbackType::PreVisit);
177+ }
178+
147179void VisitorBuilderBase::add (VisitorKey key, VisitorCallback *fn,
148180 VisitorCallbackData data) {
149181 m_template->add (key, fn, data, m_projection);
@@ -192,9 +224,12 @@ VisitorBase::VisitorBase(VisitorTemplate &&templ)
192224 BuildHelper helper (*this , templ.m_handlers );
193225
194226 m_opMap.reserve (templ.m_opMap );
195-
196- for (auto it : templ.m_opMap )
197- m_opMap[it.first ] = helper.mapHandlers (it.second );
227+ for (auto it : templ.m_opMap ) {
228+ m_opMap[it.first ].PreVisitCallbacks =
229+ helper.mapHandlers (it.second .PreVisitHandlers );
230+ m_opMap[it.first ].VisitCallbacks =
231+ helper.mapHandlers (it.second .VisitHandlers );
232+ }
198233}
199234
200235void VisitorBase::call (HandlerRange handlers, void *payload,
@@ -223,11 +258,14 @@ VisitorResult VisitorBase::call(const VisitorHandler &handler, void *payload,
223258}
224259
225260void VisitorBase::visit (void *payload, Instruction &inst) const {
226- auto handlers = m_opMap.find (inst);
227- if (!handlers )
261+ auto mappedHandlers = m_opMap.find (inst);
262+ if (!mappedHandlers )
228263 return ;
229264
230- call (*handlers.val (), payload, inst);
265+ auto &callbacks = *mappedHandlers.val ();
266+
267+ call (callbacks.PreVisitCallbacks , payload, inst);
268+ call (callbacks.VisitCallbacks , payload, inst);
231269}
232270
233271template <typename FilterT>
@@ -241,19 +279,23 @@ void VisitorBase::visitByDeclarations(void *payload, llvm::Module &module,
241279
242280 LLVM_DEBUG (dbgs () << " visit " << decl.getName () << ' \n ' );
243281
244- auto handlers = m_opMap.find (decl);
245- if (!handlers ) {
282+ auto mappedHandlers = m_opMap.find (decl);
283+ if (!mappedHandlers ) {
246284 // Neither a matched intrinsic nor a matched dialect op; skip.
247285 continue ;
248286 }
249287
288+ auto &callbacks = *mappedHandlers.val ();
289+
250290 for (Use &use : make_early_inc_range (decl.uses ())) {
251291 if (auto *inst = dyn_cast<Instruction>(use.getUser ())) {
252292 if (!filter (*inst))
253293 continue ;
254294 if (auto *callInst = dyn_cast<CallInst>(inst)) {
255- if (&use == &callInst->getCalledOperandUse ())
256- call (*handlers.val (), payload, *callInst);
295+ if (&use == &callInst->getCalledOperandUse ()) {
296+ call (callbacks.PreVisitCallbacks , payload, *callInst);
297+ call (callbacks.VisitCallbacks , payload, *callInst);
298+ }
257299 }
258300 }
259301 }
0 commit comments