1 | // Copyright 2016 The Go Authors. All rights reserved. |
---|---|
2 | // Use of this source code is governed by a BSD-style |
3 | // license that can be found in the LICENSE file. |
4 | |
5 | // Package lostcancel defines an Analyzer that checks for failure to |
6 | // call a context cancellation function. |
7 | package lostcancel |
8 | |
9 | import ( |
10 | "fmt" |
11 | "go/ast" |
12 | "go/types" |
13 | |
14 | "golang.org/x/tools/go/analysis" |
15 | "golang.org/x/tools/go/analysis/passes/ctrlflow" |
16 | "golang.org/x/tools/go/analysis/passes/inspect" |
17 | "golang.org/x/tools/go/ast/inspector" |
18 | "golang.org/x/tools/go/cfg" |
19 | ) |
20 | |
21 | const Doc = `check cancel func returned by context.WithCancel is called |
22 | |
23 | The cancellation function returned by context.WithCancel, WithTimeout, |
24 | and WithDeadline must be called or the new context will remain live |
25 | until its parent context is cancelled. |
26 | (The background context is never cancelled.)` |
27 | |
28 | var Analyzer = &analysis.Analyzer{ |
29 | Name: "lostcancel", |
30 | Doc: Doc, |
31 | Run: run, |
32 | Requires: []*analysis.Analyzer{ |
33 | inspect.Analyzer, |
34 | ctrlflow.Analyzer, |
35 | }, |
36 | } |
37 | |
38 | const debug = false |
39 | |
40 | var contextPackage = "context" |
41 | |
42 | // checkLostCancel reports a failure to the call the cancel function |
43 | // returned by context.WithCancel, either because the variable was |
44 | // assigned to the blank identifier, or because there exists a |
45 | // control-flow path from the call to a return statement and that path |
46 | // does not "use" the cancel function. Any reference to the variable |
47 | // counts as a use, even within a nested function literal. |
48 | // If the variable's scope is larger than the function |
49 | // containing the assignment, we assume that other uses exist. |
50 | // |
51 | // checkLostCancel analyzes a single named or literal function. |
52 | func run(pass *analysis.Pass) (interface{}, error) { |
53 | // Fast path: bypass check if file doesn't use context.WithCancel. |
54 | if !hasImport(pass.Pkg, contextPackage) { |
55 | return nil, nil |
56 | } |
57 | |
58 | // Call runFunc for each Func{Decl,Lit}. |
59 | inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) |
60 | nodeTypes := []ast.Node{ |
61 | (*ast.FuncLit)(nil), |
62 | (*ast.FuncDecl)(nil), |
63 | } |
64 | inspect.Preorder(nodeTypes, func(n ast.Node) { |
65 | runFunc(pass, n) |
66 | }) |
67 | return nil, nil |
68 | } |
69 | |
70 | func runFunc(pass *analysis.Pass, node ast.Node) { |
71 | // Find scope of function node |
72 | var funcScope *types.Scope |
73 | switch v := node.(type) { |
74 | case *ast.FuncLit: |
75 | funcScope = pass.TypesInfo.Scopes[v.Type] |
76 | case *ast.FuncDecl: |
77 | funcScope = pass.TypesInfo.Scopes[v.Type] |
78 | } |
79 | |
80 | // Maps each cancel variable to its defining ValueSpec/AssignStmt. |
81 | cancelvars := make(map[*types.Var]ast.Node) |
82 | |
83 | // TODO(adonovan): opt: refactor to make a single pass |
84 | // over the AST using inspect.WithStack and node types |
85 | // {FuncDecl,FuncLit,CallExpr,SelectorExpr}. |
86 | |
87 | // Find the set of cancel vars to analyze. |
88 | stack := make([]ast.Node, 0, 32) |
89 | ast.Inspect(node, func(n ast.Node) bool { |
90 | switch n.(type) { |
91 | case *ast.FuncLit: |
92 | if len(stack) > 0 { |
93 | return false // don't stray into nested functions |
94 | } |
95 | case nil: |
96 | stack = stack[:len(stack)-1] // pop |
97 | return true |
98 | } |
99 | stack = append(stack, n) // push |
100 | |
101 | // Look for [{AssignStmt,ValueSpec} CallExpr SelectorExpr]: |
102 | // |
103 | // ctx, cancel := context.WithCancel(...) |
104 | // ctx, cancel = context.WithCancel(...) |
105 | // var ctx, cancel = context.WithCancel(...) |
106 | // |
107 | if !isContextWithCancel(pass.TypesInfo, n) || !isCall(stack[len(stack)-2]) { |
108 | return true |
109 | } |
110 | var id *ast.Ident // id of cancel var |
111 | stmt := stack[len(stack)-3] |
112 | switch stmt := stmt.(type) { |
113 | case *ast.ValueSpec: |
114 | if len(stmt.Names) > 1 { |
115 | id = stmt.Names[1] |
116 | } |
117 | case *ast.AssignStmt: |
118 | if len(stmt.Lhs) > 1 { |
119 | id, _ = stmt.Lhs[1].(*ast.Ident) |
120 | } |
121 | } |
122 | if id != nil { |
123 | if id.Name == "_" { |
124 | pass.ReportRangef(id, |
125 | "the cancel function returned by context.%s should be called, not discarded, to avoid a context leak", |
126 | n.(*ast.SelectorExpr).Sel.Name) |
127 | } else if v, ok := pass.TypesInfo.Uses[id].(*types.Var); ok { |
128 | // If the cancel variable is defined outside function scope, |
129 | // do not analyze it. |
130 | if funcScope.Contains(v.Pos()) { |
131 | cancelvars[v] = stmt |
132 | } |
133 | } else if v, ok := pass.TypesInfo.Defs[id].(*types.Var); ok { |
134 | cancelvars[v] = stmt |
135 | } |
136 | } |
137 | return true |
138 | }) |
139 | |
140 | if len(cancelvars) == 0 { |
141 | return // no need to inspect CFG |
142 | } |
143 | |
144 | // Obtain the CFG. |
145 | cfgs := pass.ResultOf[ctrlflow.Analyzer].(*ctrlflow.CFGs) |
146 | var g *cfg.CFG |
147 | var sig *types.Signature |
148 | switch node := node.(type) { |
149 | case *ast.FuncDecl: |
150 | sig, _ = pass.TypesInfo.Defs[node.Name].Type().(*types.Signature) |
151 | if node.Name.Name == "main" && sig.Recv() == nil && pass.Pkg.Name() == "main" { |
152 | // Returning from main.main terminates the process, |
153 | // so there's no need to cancel contexts. |
154 | return |
155 | } |
156 | g = cfgs.FuncDecl(node) |
157 | |
158 | case *ast.FuncLit: |
159 | sig, _ = pass.TypesInfo.Types[node.Type].Type.(*types.Signature) |
160 | g = cfgs.FuncLit(node) |
161 | } |
162 | if sig == nil { |
163 | return // missing type information |
164 | } |
165 | |
166 | // Print CFG. |
167 | if debug { |
168 | fmt.Println(g.Format(pass.Fset)) |
169 | } |
170 | |
171 | // Examine the CFG for each variable in turn. |
172 | // (It would be more efficient to analyze all cancelvars in a |
173 | // single pass over the AST, but seldom is there more than one.) |
174 | for v, stmt := range cancelvars { |
175 | if ret := lostCancelPath(pass, g, v, stmt, sig); ret != nil { |
176 | lineno := pass.Fset.Position(stmt.Pos()).Line |
177 | pass.ReportRangef(stmt, "the %s function is not used on all paths (possible context leak)", v.Name()) |
178 | pass.ReportRangef(ret, "this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno) |
179 | } |
180 | } |
181 | } |
182 | |
// isCall reports whether n is a call expression (*ast.CallExpr),
// including a typed nil *ast.CallExpr.
func isCall(n ast.Node) bool {
	switch n.(type) {
	case *ast.CallExpr:
		return true
	default:
		return false
	}
}
184 | |
// hasImport reports whether pkg.Imports() contains a package whose import
// path equals path.
func hasImport(pkg *types.Package, path string) bool {
	imported := pkg.Imports()
	for i := range imported {
		if path == imported[i].Path() {
			return true
		}
	}
	return false
}
193 | |
194 | // isContextWithCancel reports whether n is one of the qualified identifiers |
195 | // context.With{Cancel,Timeout,Deadline}. |
196 | func isContextWithCancel(info *types.Info, n ast.Node) bool { |
197 | sel, ok := n.(*ast.SelectorExpr) |
198 | if !ok { |
199 | return false |
200 | } |
201 | switch sel.Sel.Name { |
202 | case "WithCancel", "WithTimeout", "WithDeadline": |
203 | default: |
204 | return false |
205 | } |
206 | if x, ok := sel.X.(*ast.Ident); ok { |
207 | if pkgname, ok := info.Uses[x].(*types.PkgName); ok { |
208 | return pkgname.Imported().Path() == contextPackage |
209 | } |
210 | // Import failed, so we can't check package path. |
211 | // Just check the local package name (heuristic). |
212 | return x.Name == "context" |
213 | } |
214 | return false |
215 | } |
216 | |
217 | // lostCancelPath finds a path through the CFG, from stmt (which defines |
218 | // the 'cancel' variable v) to a return statement, that doesn't "use" v. |
219 | // If it finds one, it returns the return statement (which may be synthetic). |
220 | // sig is the function's type, if known. |
221 | func lostCancelPath(pass *analysis.Pass, g *cfg.CFG, v *types.Var, stmt ast.Node, sig *types.Signature) *ast.ReturnStmt { |
222 | vIsNamedResult := sig != nil && tupleContains(sig.Results(), v) |
223 | |
224 | // uses reports whether stmts contain a "use" of variable v. |
225 | uses := func(pass *analysis.Pass, v *types.Var, stmts []ast.Node) bool { |
226 | found := false |
227 | for _, stmt := range stmts { |
228 | ast.Inspect(stmt, func(n ast.Node) bool { |
229 | switch n := n.(type) { |
230 | case *ast.Ident: |
231 | if pass.TypesInfo.Uses[n] == v { |
232 | found = true |
233 | } |
234 | case *ast.ReturnStmt: |
235 | // A naked return statement counts as a use |
236 | // of the named result variables. |
237 | if n.Results == nil && vIsNamedResult { |
238 | found = true |
239 | } |
240 | } |
241 | return !found |
242 | }) |
243 | } |
244 | return found |
245 | } |
246 | |
247 | // blockUses computes "uses" for each block, caching the result. |
248 | memo := make(map[*cfg.Block]bool) |
249 | blockUses := func(pass *analysis.Pass, v *types.Var, b *cfg.Block) bool { |
250 | res, ok := memo[b] |
251 | if !ok { |
252 | res = uses(pass, v, b.Nodes) |
253 | memo[b] = res |
254 | } |
255 | return res |
256 | } |
257 | |
258 | // Find the var's defining block in the CFG, |
259 | // plus the rest of the statements of that block. |
260 | var defblock *cfg.Block |
261 | var rest []ast.Node |
262 | outer: |
263 | for _, b := range g.Blocks { |
264 | for i, n := range b.Nodes { |
265 | if n == stmt { |
266 | defblock = b |
267 | rest = b.Nodes[i+1:] |
268 | break outer |
269 | } |
270 | } |
271 | } |
272 | if defblock == nil { |
273 | panic("internal error: can't find defining block for cancel var") |
274 | } |
275 | |
276 | // Is v "used" in the remainder of its defining block? |
277 | if uses(pass, v, rest) { |
278 | return nil |
279 | } |
280 | |
281 | // Does the defining block return without using v? |
282 | if ret := defblock.Return(); ret != nil { |
283 | return ret |
284 | } |
285 | |
286 | // Search the CFG depth-first for a path, from defblock to a |
287 | // return block, in which v is never "used". |
288 | seen := make(map[*cfg.Block]bool) |
289 | var search func(blocks []*cfg.Block) *ast.ReturnStmt |
290 | search = func(blocks []*cfg.Block) *ast.ReturnStmt { |
291 | for _, b := range blocks { |
292 | if seen[b] { |
293 | continue |
294 | } |
295 | seen[b] = true |
296 | |
297 | // Prune the search if the block uses v. |
298 | if blockUses(pass, v, b) { |
299 | continue |
300 | } |
301 | |
302 | // Found path to return statement? |
303 | if ret := b.Return(); ret != nil { |
304 | if debug { |
305 | fmt.Printf("found path to return in block %s\n", b) |
306 | } |
307 | return ret // found |
308 | } |
309 | |
310 | // Recur |
311 | if ret := search(b.Succs); ret != nil { |
312 | if debug { |
313 | fmt.Printf(" from block %s\n", b) |
314 | } |
315 | return ret |
316 | } |
317 | } |
318 | return nil |
319 | } |
320 | return search(defblock.Succs) |
321 | } |
322 | |
// tupleContains reports whether v is one of the variables of tuple,
// compared by pointer identity.
func tupleContains(tuple *types.Tuple, v *types.Var) bool {
	for i, n := 0, tuple.Len(); i < n; i++ {
		if v == tuple.At(i) {
			return true
		}
	}
	return false
}
331 |
Members