1 | // Copyright 2012 The Go Authors. All rights reserved. |
---|---|
2 | // Use of this source code is governed by a BSD-style |
3 | // license that can be found in the LICENSE file. |
4 | |
5 | // Package loopclosure defines an Analyzer that checks for references to |
6 | // enclosing loop variables from within nested functions. |
7 | package loopclosure |
8 | |
9 | import ( |
10 | "go/ast" |
11 | "go/types" |
12 | |
13 | "golang.org/x/tools/go/analysis" |
14 | "golang.org/x/tools/go/analysis/passes/inspect" |
15 | "golang.org/x/tools/go/ast/inspector" |
16 | "golang.org/x/tools/go/types/typeutil" |
17 | ) |
18 | |
// Doc is the documentation text for the loopclosure analyzer; Analyzer.Doc
// is set to this string. The first line is a one-line summary and the
// remainder explains the check, with examples of incorrect and corrected
// code.
const Doc = `check references to loop variables from within nested functions

This analyzer reports places where a function literal references the
iteration variable of an enclosing loop, and the loop calls the function
in such a way (e.g. with go or defer) that it may outlive the loop
iteration and possibly observe the wrong value of the variable.

In this example, all the deferred functions run after the loop has
completed, so all observe the final value of v.

for _, v := range list {
defer func() {
use(v) // incorrect
}()
}

One fix is to create a new variable for each iteration of the loop:

for _, v := range list {
v := v // new var per iteration
defer func() {
use(v) // ok
}()
}

The next example uses a go statement and has a similar problem.
In addition, it has a data race because the loop updates v
concurrent with the goroutines accessing it.

for _, v := range elem {
go func() {
use(v) // incorrect, and a data race
}()
}

A fix is the same as before. The checker also reports problems
in goroutines started by golang.org/x/sync/errgroup.Group.
A hard-to-spot variant of this form is common in parallel tests:

func Test(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
use(test) // incorrect, and a data race
})
}
}

The t.Parallel() call causes the rest of the function to execute
concurrent with the loop.

The analyzer reports references only in the last statement,
as it is not deep enough to understand the effects of subsequent
statements that might render the reference benign.
("Last statement" is defined recursively in compound
statements such as if, switch, and select.)

See: https://golang.org/doc/go_faq.html#closures_and_goroutines`
77 | |
78 | var Analyzer = &analysis.Analyzer{ |
79 | Name: "loopclosure", |
80 | Doc: Doc, |
81 | Requires: []*analysis.Analyzer{inspect.Analyzer}, |
82 | Run: run, |
83 | } |
84 | |
85 | func run(pass *analysis.Pass) (interface{}, error) { |
86 | inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) |
87 | |
88 | nodeFilter := []ast.Node{ |
89 | (*ast.RangeStmt)(nil), |
90 | (*ast.ForStmt)(nil), |
91 | } |
92 | inspect.Preorder(nodeFilter, func(n ast.Node) { |
93 | // Find the variables updated by the loop statement. |
94 | var vars []types.Object |
95 | addVar := func(expr ast.Expr) { |
96 | if id, _ := expr.(*ast.Ident); id != nil { |
97 | if obj := pass.TypesInfo.ObjectOf(id); obj != nil { |
98 | vars = append(vars, obj) |
99 | } |
100 | } |
101 | } |
102 | var body *ast.BlockStmt |
103 | switch n := n.(type) { |
104 | case *ast.RangeStmt: |
105 | body = n.Body |
106 | addVar(n.Key) |
107 | addVar(n.Value) |
108 | case *ast.ForStmt: |
109 | body = n.Body |
110 | switch post := n.Post.(type) { |
111 | case *ast.AssignStmt: |
112 | // e.g. for p = head; p != nil; p = p.next |
113 | for _, lhs := range post.Lhs { |
114 | addVar(lhs) |
115 | } |
116 | case *ast.IncDecStmt: |
117 | // e.g. for i := 0; i < n; i++ |
118 | addVar(post.X) |
119 | } |
120 | } |
121 | if vars == nil { |
122 | return |
123 | } |
124 | |
125 | // Inspect statements to find function literals that may be run outside of |
126 | // the current loop iteration. |
127 | // |
128 | // For go, defer, and errgroup.Group.Go, we ignore all but the last |
129 | // statement, because it's hard to prove go isn't followed by wait, or |
130 | // defer by return. "Last" is defined recursively. |
131 | // |
132 | // TODO: consider allowing the "last" go/defer/Go statement to be followed by |
133 | // N "trivial" statements, possibly under a recursive definition of "trivial" |
134 | // so that that checker could, for example, conclude that a go statement is |
135 | // followed by an if statement made of only trivial statements and trivial expressions, |
136 | // and hence the go statement could still be checked. |
137 | forEachLastStmt(body.List, func(last ast.Stmt) { |
138 | var stmts []ast.Stmt |
139 | switch s := last.(type) { |
140 | case *ast.GoStmt: |
141 | stmts = litStmts(s.Call.Fun) |
142 | case *ast.DeferStmt: |
143 | stmts = litStmts(s.Call.Fun) |
144 | case *ast.ExprStmt: // check for errgroup.Group.Go |
145 | if call, ok := s.X.(*ast.CallExpr); ok { |
146 | stmts = litStmts(goInvoke(pass.TypesInfo, call)) |
147 | } |
148 | } |
149 | for _, stmt := range stmts { |
150 | reportCaptured(pass, vars, stmt) |
151 | } |
152 | }) |
153 | |
154 | // Also check for testing.T.Run (with T.Parallel). |
155 | // We consider every t.Run statement in the loop body, because there is |
156 | // no commonly used mechanism for synchronizing parallel subtests. |
157 | // It is of course theoretically possible to synchronize parallel subtests, |
158 | // though such a pattern is likely to be exceedingly rare as it would be |
159 | // fighting against the test runner. |
160 | for _, s := range body.List { |
161 | switch s := s.(type) { |
162 | case *ast.ExprStmt: |
163 | if call, ok := s.X.(*ast.CallExpr); ok { |
164 | for _, stmt := range parallelSubtest(pass.TypesInfo, call) { |
165 | reportCaptured(pass, vars, stmt) |
166 | } |
167 | |
168 | } |
169 | } |
170 | } |
171 | }) |
172 | return nil, nil |
173 | } |
174 | |
175 | // reportCaptured reports a diagnostic stating a loop variable |
176 | // has been captured by a func literal if checkStmt has escaping |
177 | // references to vars. vars is expected to be variables updated by a loop statement, |
178 | // and checkStmt is expected to be a statements from the body of a func literal in the loop. |
179 | func reportCaptured(pass *analysis.Pass, vars []types.Object, checkStmt ast.Stmt) { |
180 | ast.Inspect(checkStmt, func(n ast.Node) bool { |
181 | id, ok := n.(*ast.Ident) |
182 | if !ok { |
183 | return true |
184 | } |
185 | obj := pass.TypesInfo.Uses[id] |
186 | if obj == nil { |
187 | return true |
188 | } |
189 | for _, v := range vars { |
190 | if v == obj { |
191 | pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name) |
192 | } |
193 | } |
194 | return true |
195 | }) |
196 | } |
197 | |
// forEachLastStmt calls onLast on each "last" statement in a list of statements.
//
// "Last" is defined recursively: if the final statement of stmts is an if,
// for, range, switch, type switch, or select statement, the final
// statements of each of its bodies/clauses are examined in turn instead of
// the compound statement itself. Any other final statement is passed to
// onLast directly. An empty list results in no calls.
func forEachLastStmt(stmts []ast.Stmt, onLast func(last ast.Stmt)) {
	if len(stmts) == 0 {
		return
	}

	s := stmts[len(stmts)-1]
	switch s := s.(type) {
	case *ast.IfStmt:
		// Walk the if / else-if / else chain iteratively.
	loop:
		for {
			forEachLastStmt(s.Body.List, onLast)
			switch e := s.Else.(type) {
			case *ast.BlockStmt:
				forEachLastStmt(e.List, onLast)
				break loop
			case *ast.IfStmt:
				s = e
			default:
				// No else branch, or an unexpected node such as the
				// *ast.BadStmt the parser produces for a malformed else.
				// Stop here; without this case the loop would never
				// terminate for such a node.
				break loop
			}
		}
	case *ast.ForStmt:
		forEachLastStmt(s.Body.List, onLast)
	case *ast.RangeStmt:
		forEachLastStmt(s.Body.List, onLast)
	case *ast.SwitchStmt:
		for _, c := range s.Body.List {
			cc := c.(*ast.CaseClause)
			forEachLastStmt(cc.Body, onLast)
		}
	case *ast.TypeSwitchStmt:
		for _, c := range s.Body.List {
			cc := c.(*ast.CaseClause)
			forEachLastStmt(cc.Body, onLast)
		}
	case *ast.SelectStmt:
		for _, c := range s.Body.List {
			cc := c.(*ast.CommClause)
			forEachLastStmt(cc.Body, onLast)
		}
	default:
		onLast(s)
	}
}
246 | |
// litStmts returns the statements that make up the body of the function
// literal fun.
//
// It returns nil when fun is not a (non-nil) *ast.FuncLit.
func litStmts(fun ast.Expr) []ast.Stmt {
	if lit, ok := fun.(*ast.FuncLit); ok && lit != nil {
		return lit.Body.List
	}
	return nil
}
258 | |
259 | // goInvoke returns a function expression that would be called asynchronously |
260 | // (but not awaited) in another goroutine as a consequence of the call. |
261 | // For example, given the g.Go call below, it returns the function literal expression. |
262 | // |
263 | // import "sync/errgroup" |
264 | // var g errgroup.Group |
265 | // g.Go(func() error { ... }) |
266 | // |
267 | // Currently only "golang.org/x/sync/errgroup.Group()" is considered. |
268 | func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr { |
269 | if !isMethodCall(info, call, "golang.org/x/sync/errgroup", "Group", "Go") { |
270 | return nil |
271 | } |
272 | return call.Args[0] |
273 | } |
274 | |
275 | // parallelSubtest returns statements that can be easily proven to execute |
276 | // concurrently via the go test runner, as t.Run has been invoked with a |
277 | // function literal that calls t.Parallel. |
278 | // |
279 | // In practice, users rely on the fact that statements before the call to |
280 | // t.Parallel are synchronous. For example by declaring test := test inside the |
281 | // function literal, but before the call to t.Parallel. |
282 | // |
283 | // Therefore, we only flag references in statements that are obviously |
284 | // dominated by a call to t.Parallel. As a simple heuristic, we only consider |
285 | // statements following the final labeled statement in the function body, to |
286 | // avoid scenarios where a jump would cause either the call to t.Parallel or |
287 | // the problematic reference to be skipped. |
288 | // |
289 | // import "testing" |
290 | // |
291 | // func TestFoo(t *testing.T) { |
292 | // tests := []int{0, 1, 2} |
293 | // for i, test := range tests { |
294 | // t.Run("subtest", func(t *testing.T) { |
295 | // println(i, test) // OK |
296 | // t.Parallel() |
297 | // println(i, test) // Not OK |
298 | // }) |
299 | // } |
300 | // } |
301 | func parallelSubtest(info *types.Info, call *ast.CallExpr) []ast.Stmt { |
302 | if !isMethodCall(info, call, "testing", "T", "Run") { |
303 | return nil |
304 | } |
305 | |
306 | lit, _ := call.Args[1].(*ast.FuncLit) |
307 | if lit == nil { |
308 | return nil |
309 | } |
310 | |
311 | // Capture the *testing.T object for the first argument to the function |
312 | // literal. |
313 | if len(lit.Type.Params.List[0].Names) == 0 { |
314 | return nil |
315 | } |
316 | |
317 | tObj := info.Defs[lit.Type.Params.List[0].Names[0]] |
318 | if tObj == nil { |
319 | return nil |
320 | } |
321 | |
322 | // Match statements that occur after a call to t.Parallel following the final |
323 | // labeled statement in the function body. |
324 | // |
325 | // We iterate over lit.Body.List to have a simple, fast and "frequent enough" |
326 | // dominance relationship for t.Parallel(): lit.Body.List[i] dominates |
327 | // lit.Body.List[j] for i < j unless there is a jump. |
328 | var stmts []ast.Stmt |
329 | afterParallel := false |
330 | for _, stmt := range lit.Body.List { |
331 | stmt, labeled := unlabel(stmt) |
332 | if labeled { |
333 | // Reset: naively we don't know if a jump could have caused the |
334 | // previously considered statements to be skipped. |
335 | stmts = nil |
336 | afterParallel = false |
337 | } |
338 | |
339 | if afterParallel { |
340 | stmts = append(stmts, stmt) |
341 | continue |
342 | } |
343 | |
344 | // Check if stmt is a call to t.Parallel(), for the correct t. |
345 | exprStmt, ok := stmt.(*ast.ExprStmt) |
346 | if !ok { |
347 | continue |
348 | } |
349 | expr := exprStmt.X |
350 | if isMethodCall(info, expr, "testing", "T", "Parallel") { |
351 | call, _ := expr.(*ast.CallExpr) |
352 | if call == nil { |
353 | continue |
354 | } |
355 | x, _ := call.Fun.(*ast.SelectorExpr) |
356 | if x == nil { |
357 | continue |
358 | } |
359 | id, _ := x.X.(*ast.Ident) |
360 | if id == nil { |
361 | continue |
362 | } |
363 | if info.Uses[id] == tObj { |
364 | afterParallel = true |
365 | } |
366 | } |
367 | } |
368 | |
369 | return stmts |
370 | } |
371 | |
// unlabel strips every enclosing *ast.LabeledStmt wrapper (labels may be
// nested) from stmt and returns the statement found underneath.
//
// The second result is true if at least one label was removed, that is,
// if stmt itself was an *ast.LabeledStmt.
func unlabel(stmt ast.Stmt) (ast.Stmt, bool) {
	wasLabeled := false
	for ls, ok := stmt.(*ast.LabeledStmt); ok; ls, ok = stmt.(*ast.LabeledStmt) {
		wasLabeled = true
		stmt = ls.Stmt
	}
	return stmt, wasLabeled
}
387 | |
388 | // isMethodCall reports whether expr is a method call of |
389 | // <pkgPath>.<typeName>.<method>. |
390 | func isMethodCall(info *types.Info, expr ast.Expr, pkgPath, typeName, method string) bool { |
391 | call, ok := expr.(*ast.CallExpr) |
392 | if !ok { |
393 | return false |
394 | } |
395 | |
396 | // Check that we are calling a method <method> |
397 | f := typeutil.StaticCallee(info, call) |
398 | if f == nil || f.Name() != method { |
399 | return false |
400 | } |
401 | recv := f.Type().(*types.Signature).Recv() |
402 | if recv == nil { |
403 | return false |
404 | } |
405 | |
406 | // Check that the receiver is a <pkgPath>.<typeName> or |
407 | // *<pkgPath>.<typeName>. |
408 | rtype := recv.Type() |
409 | if ptr, ok := recv.Type().(*types.Pointer); ok { |
410 | rtype = ptr.Elem() |
411 | } |
412 | named, ok := rtype.(*types.Named) |
413 | if !ok { |
414 | return false |
415 | } |
416 | if named.Obj().Name() != typeName { |
417 | return false |
418 | } |
419 | pkg := f.Pkg() |
420 | if pkg == nil { |
421 | return false |
422 | } |
423 | if pkg.Path() != pkgPath { |
424 | return false |
425 | } |
426 | |
427 | return true |
428 | } |
429 |
Members