GoPLS Viewer

Home|gopls/go/analysis/passes/loopclosure/loopclosure.go
1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package loopclosure defines an Analyzer that checks for references to
6// enclosing loop variables from within nested functions.
7package loopclosure
8
9import (
10    "go/ast"
11    "go/types"
12
13    "golang.org/x/tools/go/analysis"
14    "golang.org/x/tools/go/analysis/passes/inspect"
15    "golang.org/x/tools/go/ast/inspector"
16    "golang.org/x/tools/go/types/typeutil"
17)
18
19const Doc = `check references to loop variables from within nested functions
20
21This analyzer reports places where a function literal references the
22iteration variable of an enclosing loop, and the loop calls the function
23in such a way (e.g. with go or defer) that it may outlive the loop
24iteration and possibly observe the wrong value of the variable.
25
26In this example, all the deferred functions run after the loop has
27completed, so all observe the final value of v.
28
29    for _, v := range list {
30        defer func() {
31            use(v) // incorrect
32        }()
33    }
34
35One fix is to create a new variable for each iteration of the loop:
36
37    for _, v := range list {
38        v := v // new var per iteration
39        defer func() {
40            use(v) // ok
41        }()
42    }
43
44The next example uses a go statement and has a similar problem.
45In addition, it has a data race because the loop updates v
46concurrent with the goroutines accessing it.
47
48    for _, v := range elem {
49        go func() {
50            use(v)  // incorrect, and a data race
51        }()
52    }
53
54A fix is the same as before. The checker also reports problems
55in goroutines started by golang.org/x/sync/errgroup.Group.
56A hard-to-spot variant of this form is common in parallel tests:
57
58    func Test(t *testing.T) {
59        for _, test := range tests {
60            t.Run(test.name, func(t *testing.T) {
61                t.Parallel()
62                use(test) // incorrect, and a data race
63            })
64        }
65    }
66
67The t.Parallel() call causes the rest of the function to execute
68concurrent with the loop.
69
70The analyzer reports references only in the last statement,
71as it is not deep enough to understand the effects of subsequent
72statements that might render the reference benign.
73("Last statement" is defined recursively in compound
74statements such as if, switch, and select.)
75
76See: https://golang.org/doc/go_faq.html#closures_and_goroutines`
77
78var Analyzer = &analysis.Analyzer{
79    Name:     "loopclosure",
80    Doc:      Doc,
81    Requires: []*analysis.Analyzer{inspect.Analyzer},
82    Run:      run,
83}
84
85func run(pass *analysis.Pass) (interface{}, error) {
86    inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
87
88    nodeFilter := []ast.Node{
89        (*ast.RangeStmt)(nil),
90        (*ast.ForStmt)(nil),
91    }
92    inspect.Preorder(nodeFilter, func(n ast.Node) {
93        // Find the variables updated by the loop statement.
94        var vars []types.Object
95        addVar := func(expr ast.Expr) {
96            if id_ := expr.(*ast.Ident); id != nil {
97                if obj := pass.TypesInfo.ObjectOf(id); obj != nil {
98                    vars = append(varsobj)
99                }
100            }
101        }
102        var body *ast.BlockStmt
103        switch n := n.(type) {
104        case *ast.RangeStmt:
105            body = n.Body
106            addVar(n.Key)
107            addVar(n.Value)
108        case *ast.ForStmt:
109            body = n.Body
110            switch post := n.Post.(type) {
111            case *ast.AssignStmt:
112                // e.g. for p = head; p != nil; p = p.next
113                for _lhs := range post.Lhs {
114                    addVar(lhs)
115                }
116            case *ast.IncDecStmt:
117                // e.g. for i := 0; i < n; i++
118                addVar(post.X)
119            }
120        }
121        if vars == nil {
122            return
123        }
124
125        // Inspect statements to find function literals that may be run outside of
126        // the current loop iteration.
127        //
128        // For go, defer, and errgroup.Group.Go, we ignore all but the last
129        // statement, because it's hard to prove go isn't followed by wait, or
130        // defer by return. "Last" is defined recursively.
131        //
132        // TODO: consider allowing the "last" go/defer/Go statement to be followed by
133        // N "trivial" statements, possibly under a recursive definition of "trivial"
134        // so that that checker could, for example, conclude that a go statement is
135        // followed by an if statement made of only trivial statements and trivial expressions,
136        // and hence the go statement could still be checked.
137        forEachLastStmt(body.List, func(last ast.Stmt) {
138            var stmts []ast.Stmt
139            switch s := last.(type) {
140            case *ast.GoStmt:
141                stmts = litStmts(s.Call.Fun)
142            case *ast.DeferStmt:
143                stmts = litStmts(s.Call.Fun)
144            case *ast.ExprStmt// check for errgroup.Group.Go
145                if callok := s.X.(*ast.CallExpr); ok {
146                    stmts = litStmts(goInvoke(pass.TypesInfocall))
147                }
148            }
149            for _stmt := range stmts {
150                reportCaptured(passvarsstmt)
151            }
152        })
153
154        // Also check for testing.T.Run (with T.Parallel).
155        // We consider every t.Run statement in the loop body, because there is
156        // no commonly used mechanism for synchronizing parallel subtests.
157        // It is of course theoretically possible to synchronize parallel subtests,
158        // though such a pattern is likely to be exceedingly rare as it would be
159        // fighting against the test runner.
160        for _s := range body.List {
161            switch s := s.(type) {
162            case *ast.ExprStmt:
163                if callok := s.X.(*ast.CallExpr); ok {
164                    for _stmt := range parallelSubtest(pass.TypesInfocall) {
165                        reportCaptured(passvarsstmt)
166                    }
167
168                }
169            }
170        }
171    })
172    return nilnil
173}
174
175// reportCaptured reports a diagnostic stating a loop variable
176// has been captured by a func literal if checkStmt has escaping
177// references to vars. vars is expected to be variables updated by a loop statement,
178// and checkStmt is expected to be a statements from the body of a func literal in the loop.
179func reportCaptured(pass *analysis.Passvars []types.ObjectcheckStmt ast.Stmt) {
180    ast.Inspect(checkStmt, func(n ast.Nodebool {
181        idok := n.(*ast.Ident)
182        if !ok {
183            return true
184        }
185        obj := pass.TypesInfo.Uses[id]
186        if obj == nil {
187            return true
188        }
189        for _v := range vars {
190            if v == obj {
191                pass.ReportRangef(id"loop variable %s captured by func literal"id.Name)
192            }
193        }
194        return true
195    })
196}
197
198// forEachLastStmt calls onLast on each "last" statement in a list of statements.
199// "Last" is defined recursively so, for example, if the last statement is
200// a switch statement, then each switch case is also visited to examine
201// its last statements.
202func forEachLastStmt(stmts []ast.StmtonLast func(last ast.Stmt)) {
203    if len(stmts) == 0 {
204        return
205    }
206
207    s := stmts[len(stmts)-1]
208    switch s := s.(type) {
209    case *ast.IfStmt:
210    loop:
211        for {
212            forEachLastStmt(s.Body.ListonLast)
213            switch e := s.Else.(type) {
214            case *ast.BlockStmt:
215                forEachLastStmt(e.ListonLast)
216                break loop
217            case *ast.IfStmt:
218                s = e
219            case nil:
220                break loop
221            }
222        }
223    case *ast.ForStmt:
224        forEachLastStmt(s.Body.ListonLast)
225    case *ast.RangeStmt:
226        forEachLastStmt(s.Body.ListonLast)
227    case *ast.SwitchStmt:
228        for _c := range s.Body.List {
229            cc := c.(*ast.CaseClause)
230            forEachLastStmt(cc.BodyonLast)
231        }
232    case *ast.TypeSwitchStmt:
233        for _c := range s.Body.List {
234            cc := c.(*ast.CaseClause)
235            forEachLastStmt(cc.BodyonLast)
236        }
237    case *ast.SelectStmt:
238        for _c := range s.Body.List {
239            cc := c.(*ast.CommClause)
240            forEachLastStmt(cc.BodyonLast)
241        }
242    default:
243        onLast(s)
244    }
245}
246
247// litStmts returns all statements from the function body of a function
248// literal.
249//
250// If fun is not a function literal, it returns nil.
251func litStmts(fun ast.Expr) []ast.Stmt {
252    lit_ := fun.(*ast.FuncLit)
253    if lit == nil {
254        return nil
255    }
256    return lit.Body.List
257}
258
259// goInvoke returns a function expression that would be called asynchronously
260// (but not awaited) in another goroutine as a consequence of the call.
261// For example, given the g.Go call below, it returns the function literal expression.
262//
263//    import "sync/errgroup"
264//    var g errgroup.Group
265//    g.Go(func() error { ... })
266//
267// Currently only "golang.org/x/sync/errgroup.Group()" is considered.
268func goInvoke(info *types.Infocall *ast.CallExprast.Expr {
269    if !isMethodCall(infocall"golang.org/x/sync/errgroup""Group""Go") {
270        return nil
271    }
272    return call.Args[0]
273}
274
275// parallelSubtest returns statements that can be easily proven to execute
276// concurrently via the go test runner, as t.Run has been invoked with a
277// function literal that calls t.Parallel.
278//
279// In practice, users rely on the fact that statements before the call to
280// t.Parallel are synchronous. For example by declaring test := test inside the
281// function literal, but before the call to t.Parallel.
282//
283// Therefore, we only flag references in statements that are obviously
284// dominated by a call to t.Parallel. As a simple heuristic, we only consider
285// statements following the final labeled statement in the function body, to
286// avoid scenarios where a jump would cause either the call to t.Parallel or
287// the problematic reference to be skipped.
288//
289//    import "testing"
290//
291//    func TestFoo(t *testing.T) {
292//        tests := []int{0, 1, 2}
293//        for i, test := range tests {
294//            t.Run("subtest", func(t *testing.T) {
295//                println(i, test) // OK
296//                 t.Parallel()
297//                println(i, test) // Not OK
298//            })
299//        }
300//    }
301func parallelSubtest(info *types.Infocall *ast.CallExpr) []ast.Stmt {
302    if !isMethodCall(infocall"testing""T""Run") {
303        return nil
304    }
305
306    lit_ := call.Args[1].(*ast.FuncLit)
307    if lit == nil {
308        return nil
309    }
310
311    // Capture the *testing.T object for the first argument to the function
312    // literal.
313    if len(lit.Type.Params.List[0].Names) == 0 {
314        return nil
315    }
316
317    tObj := info.Defs[lit.Type.Params.List[0].Names[0]]
318    if tObj == nil {
319        return nil
320    }
321
322    // Match statements that occur after a call to t.Parallel following the final
323    // labeled statement in the function body.
324    //
325    // We iterate over lit.Body.List to have a simple, fast and "frequent enough"
326    // dominance relationship for t.Parallel(): lit.Body.List[i] dominates
327    // lit.Body.List[j] for i < j unless there is a jump.
328    var stmts []ast.Stmt
329    afterParallel := false
330    for _stmt := range lit.Body.List {
331        stmtlabeled := unlabel(stmt)
332        if labeled {
333            // Reset: naively we don't know if a jump could have caused the
334            // previously considered statements to be skipped.
335            stmts = nil
336            afterParallel = false
337        }
338
339        if afterParallel {
340            stmts = append(stmtsstmt)
341            continue
342        }
343
344        // Check if stmt is a call to t.Parallel(), for the correct t.
345        exprStmtok := stmt.(*ast.ExprStmt)
346        if !ok {
347            continue
348        }
349        expr := exprStmt.X
350        if isMethodCall(infoexpr"testing""T""Parallel") {
351            call_ := expr.(*ast.CallExpr)
352            if call == nil {
353                continue
354            }
355            x_ := call.Fun.(*ast.SelectorExpr)
356            if x == nil {
357                continue
358            }
359            id_ := x.X.(*ast.Ident)
360            if id == nil {
361                continue
362            }
363            if info.Uses[id] == tObj {
364                afterParallel = true
365            }
366        }
367    }
368
369    return stmts
370}
371
372// unlabel returns the inner statement for the possibly labeled statement stmt,
373// stripping any (possibly nested) *ast.LabeledStmt wrapper.
374//
375// The second result reports whether stmt was an *ast.LabeledStmt.
376func unlabel(stmt ast.Stmt) (ast.Stmtbool) {
377    labeled := false
378    for {
379        labelStmtok := stmt.(*ast.LabeledStmt)
380        if !ok {
381            return stmtlabeled
382        }
383        labeled = true
384        stmt = labelStmt.Stmt
385    }
386}
387
388// isMethodCall reports whether expr is a method call of
389// <pkgPath>.<typeName>.<method>.
390func isMethodCall(info *types.Infoexpr ast.ExprpkgPathtypeNamemethod stringbool {
391    callok := expr.(*ast.CallExpr)
392    if !ok {
393        return false
394    }
395
396    // Check that we are calling a method <method>
397    f := typeutil.StaticCallee(infocall)
398    if f == nil || f.Name() != method {
399        return false
400    }
401    recv := f.Type().(*types.Signature).Recv()
402    if recv == nil {
403        return false
404    }
405
406    // Check that the receiver is a <pkgPath>.<typeName> or
407    // *<pkgPath>.<typeName>.
408    rtype := recv.Type()
409    if ptrok := recv.Type().(*types.Pointer); ok {
410        rtype = ptr.Elem()
411    }
412    namedok := rtype.(*types.Named)
413    if !ok {
414        return false
415    }
416    if named.Obj().Name() != typeName {
417        return false
418    }
419    pkg := f.Pkg()
420    if pkg == nil {
421        return false
422    }
423    if pkg.Path() != pkgPath {
424        return false
425    }
426
427    return true
428}
429
MembersX
isMethodCall.pkgPath
run.pass
forEachLastStmt
parallelSubtest
analysis
reportCaptured.pass
goInvoke.info
parallelSubtest.RangeStmt_9957.BlockStmt.stmt
isMethodCall.method
isMethodCall.pkg
run.BlockStmt.body
run.BlockStmt.RangeStmt_5084.s
forEachLastStmt.stmts
parallelSubtest.info
parallelSubtest.stmts
parallelSubtest.RangeStmt_9957.BlockStmt.labeled
unlabel.stmt
isMethodCall.expr
inspect
reportCaptured.checkStmt
forEachLastStmt.BlockStmt.RangeStmt_7009.c
parallelSubtest.RangeStmt_9957.BlockStmt.expr
ast
run.BlockStmt.BlockStmt.stmts
forEachLastStmt.BlockStmt.RangeStmt_7135.c
parallelSubtest.RangeStmt_9957.stmt
forEachLastStmt.BlockStmt.RangeStmt_6879.c
run
reportCaptured.vars
forEachLastStmt.onLast
run.BlockStmt.vars
litStmts
inspector
isMethodCall.info
isMethodCall.f
isMethodCall.recv
typeutil
isMethodCall.rtype
isMethodCall
run.BlockStmt.BlockStmt.BlockStmt.obj
reportCaptured
unlabel.labeled
reportCaptured.BlockStmt.RangeStmt_5927.v
litStmts.fun
goInvoke.call
run.BlockStmt.RangeStmt_5084.BlockStmt.BlockStmt.BlockStmt.RangeStmt_5213.stmt
goInvoke
isMethodCall.typeName
types
parallelSubtest.call
parallelSubtest.afterParallel
run.BlockStmt.BlockStmt.RangeStmt_4617.stmt
unlabel
Doc
run.nodeFilter
run.BlockStmt.BlockStmt.BlockStmt.RangeStmt_3359.lhs
Members
X