@@ -132,7 +132,7 @@ class InlineGlobalSlotsAnalysis : public DataFlowAnalysis {
132
132
public:
133
133
InlineGlobalSlotsAnalysis (DataFlowSolver &solver);
134
134
LogicalResult initialize (Operation *top) override ;
135
- LogicalResult visit (ProgramPoint point) override ;
135
+ LogicalResult visit (ProgramPoint * point) override ;
136
136
137
137
private:
138
138
// / The local transfer function determining the safety of `value`.
@@ -170,7 +170,7 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
170
170
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
171
171
initializeGlobalSlotsOp = initialize;
172
172
}
173
- if (failed (visit (op )))
173
+ if (failed (visit (getProgramPointAfter (op) )))
174
174
return WalkResult::interrupt ();
175
175
176
176
return WalkResult::advance ();
@@ -180,8 +180,11 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
180
180
return success ();
181
181
}
182
182
183
- LogicalResult InlineGlobalSlotsAnalysis::visit (ProgramPoint point) {
184
- if (auto op = dyn_cast<Operation *>(point)) {
183
+ LogicalResult InlineGlobalSlotsAnalysis::visit (ProgramPoint *point) {
184
+ if (point->isBlockStart ())
185
+ return success ();
186
+
187
+ if (auto op = point->getPrevOp ()) {
185
188
for (auto value : op->getResults ()) {
186
189
bool isSafe = isValueSafeTransferFunction (value);
187
190
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
@@ -196,7 +199,7 @@ LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
196
199
auto *flatSymbolRefPoint =
197
200
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
198
201
auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
199
- globalSlot, globalSlotGet.getResult ());
202
+ getProgramPointAfter ( globalSlot) , globalSlotGet.getResult ());
200
203
auto *globalState =
201
204
getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
202
205
propagateIfChanged (globalState,
@@ -223,7 +226,7 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
223
226
if ((op->hasTrait <Torch::OpTrait::ReadOnly>() || isMemoryEffectFree (op)) &&
224
227
llvm::all_of (op->getResults (), [&](Value result) {
225
228
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
226
- value.getDefiningOp (), result);
229
+ getProgramPointAfter ( value.getDefiningOp () ), result);
227
230
return state->isSafe ;
228
231
}))
229
232
continue ;
@@ -234,7 +237,7 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
234
237
SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(op, symName);
235
238
236
239
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
237
- value.getDefiningOp (),
240
+ getProgramPointAfter ( value.getDefiningOp () ),
238
241
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
239
242
if (state->isSafe )
240
243
continue ;
0 commit comments