@@ -23,6 +23,7 @@ import (
2323 "go/constant"
2424 "go/token"
2525 "go/types"
26+ "iter"
2627
2728 "golang.org/x/tools/go/analysis"
2829 "golang.org/x/tools/go/analysis/passes/buildssa"
@@ -89,6 +90,10 @@ func run(pass *analysis.Pass) (any, error) {
8990 }
9091 }
9192 }
93+ isYieldCall := func (v ssa.Value ) bool {
94+ call , ok := v .(* ssa.Call )
95+ return ok && ssaYieldCalls [call ] != nil
96+ }
9297
9398 // Now search for a control path from the instruction after a
9499 // yield call to another yield call--possible the same one,
@@ -115,96 +120,68 @@ func run(pass *analysis.Pass) (any, error) {
115120 // visit visits the instructions of a block (or a suffix if start > 0).
116121 var visit func (b * ssa.BasicBlock , start int )
117122 visit = func (b * ssa.BasicBlock , start int ) {
118- if ! visited [b .Index ] {
119- if start == 0 {
120- visited [b .Index ] = true
121- }
122- for _ , instr := range b .Instrs [start :] {
123- switch instr := instr .(type ) {
124- case * ssa.Call :
125-
126- // Precondition: v has a pos within a CallExpr.
127- enclosingCall := func (v ssa.Value ) ast.Node {
128- pos := v .Pos ()
129- cur , ok := inspector .Root ().FindByPos (pos , pos )
130- if ! ok {
131- panic (fmt .Sprintf ("can't find node at %v" , safetoken .StartPosition (pass .Fset , pos )))
132- }
133- call , ok := moreiters .First (cur .Enclosing ((* ast .CallExpr )(nil )))
134- if ! ok {
135- panic (fmt .Sprintf ("no call enclosing %v" , safetoken .StartPosition (pass .Fset , pos )))
136- }
137- return call .Node ()
138- }
123+ if visited [b .Index ] {
124+ return
125+ }
126+ if start == 0 {
127+ visited [b .Index ] = true
128+ }
129+ for _ , instr := range b .Instrs [start :] {
130+ switch instr := instr .(type ) {
131+ case * ssa.Call :
139132
140- if ! info .reported && ssaYieldCalls [instr ] != nil {
141- info .reported = true
142- var (
143- where = "" // "" => same yield call (a loop)
144- related []analysis.RelatedInformation
145- )
146- // Also report location of reached yield call, if distinct.
147- if instr != call {
148- otherLine := safetoken .StartPosition (pass .Fset , instr .Pos ()).Line
149- where = fmt .Sprintf ("(on L%d) " , otherLine )
150- otherCallExpr := enclosingCall (instr )
151- related = []analysis.RelatedInformation {{
152- Pos : otherCallExpr .Pos (),
153- End : otherCallExpr .End (),
154- Message : "other call here" ,
155- }}
156- }
157- callExpr := enclosingCall (call )
158- pass .Report (analysis.Diagnostic {
159- Pos : callExpr .Pos (),
160- End : callExpr .End (),
161- Message : fmt .Sprintf ("yield may be called again %safter returning false" , where ),
162- Related : related ,
163- })
133+ // Precondition: v has a pos within a CallExpr.
134+ enclosingCall := func (v ssa.Value ) ast.Node {
135+ pos := v .Pos ()
136+ cur , ok := inspector .Root ().FindByPos (pos , pos )
137+ if ! ok {
138+ panic (fmt .Sprintf ("can't find node at %v" , safetoken .StartPosition (pass .Fset , pos )))
164139 }
165- case * ssa.If :
166- // Visit both successors, unless cond is yield() or its negation.
167- // In that case visit only the "if !yield()" block.
168- cond := instr .Cond
169- t , f := b .Succs [0 ], b .Succs [1 ]
170-
171- // Strip off any NOT operator.
172- cond , t , f = unnegate (cond , t , f )
173-
174- // As a peephole optimization for this special case:
175- // ok := yield()
176- // ok = ok && yield()
177- // ok = ok && yield()
178- // which in SSA becomes:
179- // yield()
180- // phi(false, yield())
181- // phi(false, yield())
182- // we reduce a cond of phi(false, x) to just x.
183- if phi , ok := cond .(* ssa.Phi ); ok {
184- var nonFalse []ssa.Value
185- for _ , v := range phi .Edges {
186- if c , ok := v .(* ssa.Const ); ok &&
187- ! constant .BoolVal (c .Value ) {
188- continue // constant false
189- }
190- nonFalse = append (nonFalse , v )
191- }
192- if len (nonFalse ) == 1 {
193- cond = nonFalse [0 ]
194- cond , t , f = unnegate (cond , t , f )
195- }
140+ call , ok := moreiters .First (cur .Enclosing ((* ast .CallExpr )(nil )))
141+ if ! ok {
142+ panic (fmt .Sprintf ("no call enclosing %v" , safetoken .StartPosition (pass .Fset , pos )))
196143 }
144+ return call .Node ()
145+ }
197146
198- if cond , ok := cond .(* ssa.Call ); ok && ssaYieldCalls [cond ] != nil {
199- // Skip the successor reached by "if yield() { ... }".
200- } else {
201- visit (t , 0 )
147+ if ! info .reported && ssaYieldCalls [instr ] != nil {
148+ info .reported = true
149+ var (
150+ where = "" // "" => same yield call (a loop)
151+ related []analysis.RelatedInformation
152+ )
153+ // Also report location of reached yield call, if distinct.
154+ if instr != call {
155+ otherLine := safetoken .StartPosition (pass .Fset , instr .Pos ()).Line
156+ where = fmt .Sprintf ("(on L%d) " , otherLine )
157+ otherCallExpr := enclosingCall (instr )
158+ related = []analysis.RelatedInformation {{
159+ Pos : otherCallExpr .Pos (),
160+ End : otherCallExpr .End (),
161+ Message : "other call here" ,
162+ }}
202163 }
203- visit (f , 0 )
204-
205- case * ssa.Jump :
164+ callExpr := enclosingCall (call )
165+ pass .Report (analysis.Diagnostic {
166+ Pos : callExpr .Pos (),
167+ End : callExpr .End (),
168+ Message : fmt .Sprintf ("yield may be called again %safter returning false" , where ),
169+ Related : related ,
170+ })
171+ }
172+ case * ssa.If :
173+ // Visit both successors, unless cond is yield() or its negation.
174+ // In that case visit only the "if !yield()" block.
175+ t , f := reachableSuccs (instr .Cond , isYieldCall )
176+ if t {
206177 visit (b .Succs [0 ], 0 )
207178 }
179+ if f {
180+ visit (b .Succs [1 ], 0 )
181+ }
182+
183+ case * ssa.Jump :
184+ visit (b .Succs [0 ], 0 )
208185 }
209186 }
210187 }
@@ -217,9 +194,129 @@ func run(pass *analysis.Pass) (any, error) {
217194 return nil , nil
218195}
219196
220- func unnegate (cond ssa.Value , t , f * ssa.BasicBlock ) (_ ssa.Value , _ , _ * ssa.BasicBlock ) {
221- if unop , ok := cond .(* ssa.UnOp ); ok && unop .Op == token .NOT {
222- return unop .X , f , t
197+ // reachableSuccs reports whether the (true, false) outcomes of the
198+ // condition are possible.
199+ func reachableSuccs (cond ssa.Value , isYieldCall func (ssa.Value ) bool ) (_t , _f bool ) {
200+ // If the condition is...
201+ //
202+ // ...a constant, we know only one successor is reachable.
203+ //
204+ // ...a yield call, we assume that it returned false,
205+ // and treat it like a constant.
206+ //
207+ // ...a negation !v, we strip the negation and flip the sense
208+ // of the result.
209+ //
210+ // ...a phi node, we recursively find all non-phi leaves
211+ // of the phi graph and treat them like a conjunction,
212+ // e.g. if false || true || yield || yield { ... }.
213+ //
214+ // (We don't actually analyze || and && in this way,
215+ // but we could do them too.)
216+
217+ // This logic addresses cases where conditions are
218+ // materialized as booleans such as this
219+ //
220+ // ok := yield()
221+ // ok = ok && yield()
222+ // ok = ok && yield()
223+ //
224+ // which in SSA becomes:
225+ //
226+ // yield()
227+ // phi(false, yield())
228+ // phi(false, yield())
229+ //
230+ // and we can reduce each phi(false, x) to just x.
231+ //
232+ // Similarly this case:
233+ //
234+ // var ok bool
235+ // if foo { ok = yield() }
236+ // else { ok = yield() }
237+ // if ok { ... }
238+ //
239+ // can be analyzed as "if yield || yield".
240+
241+ // all[false] => all cases are false
242+ // all[true] => all cases are true
243+ all := [2 ]bool {true , true }
244+ for v := range unphi (cond ) {
245+ sense := 1 // 0=false 1=true
246+
247+ // Strip off any NOT operators.
248+ for {
249+ unop , ok := v .(* ssa.UnOp )
250+ if ! (ok && unop .Op == token .NOT ) {
251+ break
252+ }
253+ v = unop .X
254+ sense = 1 - sense
255+ }
256+
257+ switch {
258+ case is [* ssa.Const ](v ):
259+ // "if false" means not all cases are true,
260+ // and vice versa.
261+ if constant .BoolVal (v .(* ssa.Const ).Value ) {
262+ sense = 1 - sense
263+ }
264+ all [sense ] = false
265+
266+ case isYieldCall (v ):
267+ // "if yield" is assumed to be false.
268+ all [sense ] = false // ¬ all cases are true
269+
270+ default :
271+ // Unknown condition:
272+ // ¬ all cases are false
273+ // ¬ all cases are true
274+ return true , true
275+ }
276+ }
277+ if all [0 ] && all [1 ] {
278+ panic ("unphi returned empty sequence" )
279+ }
280+ return ! all [0 ], ! all [1 ]
281+ }
282+
283+ func is [T any ](x any ) bool {
284+ _ , ok := x .(T )
285+ return ok
286+ }
287+
288+ // -- SSA helpers --
289+
290+ // unphi returns the sequence of values formed by recursively
291+ // replacing phi nodes in v by their non-phi operands.
292+ func unphi (v ssa.Value ) iter.Seq [ssa.Value ] {
293+ return func (yield func (ssa.Value ) bool ) {
294+ _ = every (v , yield )
295+ }
296+ }
297+
298+ // every reports whether predicate f is true of each value in the
299+ // sequence formed by recursively replacing phi nodes in v by their
300+ // operands.
301+ func every (v ssa.Value , f func (ssa.Value ) bool ) bool {
302+ var seen map [* ssa.Phi ]bool
303+ var visit func (v ssa.Value ) bool
304+ visit = func (v ssa.Value ) bool {
305+ if phi , ok := v .(* ssa.Phi ); ok {
306+ if ! seen [phi ] {
307+ if seen == nil {
308+ seen = make (map [* ssa.Phi ]bool )
309+ }
310+ seen [phi ] = true
311+ for _ , edge := range phi .Edges {
312+ if ! visit (edge ) {
313+ return false
314+ }
315+ }
316+ }
317+ return true
318+ }
319+ return f (v )
223320 }
224- return cond , t , f
321+ return visit ( v )
225322}
0 commit comments