// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4 | package ssa |
5 | |
6 | import ( |
7 | "fmt" |
8 | "go/types" |
9 | |
10 | "golang.org/x/tools/internal/typeparams" |
11 | ) |
12 | |
13 | // Type substituter for a fixed set of replacement types. |
14 | // |
15 | // A nil *subster is an valid, empty substitution map. It always acts as |
16 | // the identity function. This allows for treating parameterized and |
17 | // non-parameterized functions identically while compiling to ssa. |
18 | // |
19 | // Not concurrency-safe. |
20 | type subster struct { |
21 | // TODO(zpavlinovic): replacements can contain type params |
22 | // when generating instances inside of a generic function body. |
23 | replacements map[*typeparams.TypeParam]types.Type // values should contain no type params |
24 | cache map[types.Type]types.Type // cache of subst results |
25 | ctxt *typeparams.Context |
26 | debug bool // perform extra debugging checks |
27 | // TODO(taking): consider adding Pos |
28 | } |
29 | |
30 | // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache. |
31 | // targs should not contain any types in tparams. |
32 | func makeSubster(ctxt *typeparams.Context, tparams *typeparams.TypeParamList, targs []types.Type, debug bool) *subster { |
33 | assert(tparams.Len() == len(targs), "makeSubster argument count must match") |
34 | |
35 | subst := &subster{ |
36 | replacements: make(map[*typeparams.TypeParam]types.Type, tparams.Len()), |
37 | cache: make(map[types.Type]types.Type), |
38 | ctxt: ctxt, |
39 | debug: debug, |
40 | } |
41 | for i := 0; i < tparams.Len(); i++ { |
42 | subst.replacements[tparams.At(i)] = targs[i] |
43 | } |
44 | if subst.debug { |
45 | if err := subst.wellFormed(); err != nil { |
46 | panic(err) |
47 | } |
48 | } |
49 | return subst |
50 | } |
51 | |
52 | // wellFormed returns an error if subst was not properly initialized. |
53 | func (subst *subster) wellFormed() error { |
54 | if subst == nil || len(subst.replacements) == 0 { |
55 | return nil |
56 | } |
57 | // Check that all of the type params do not appear in the arguments. |
58 | s := make(map[types.Type]bool, len(subst.replacements)) |
59 | for tparam := range subst.replacements { |
60 | s[tparam] = true |
61 | } |
62 | for _, r := range subst.replacements { |
63 | if reaches(r, s) { |
64 | return fmt.Errorf("\n‰r %s s %v replacements %v\n", r, s, subst.replacements) |
65 | } |
66 | } |
67 | return nil |
68 | } |
69 | |
70 | // typ returns the type of t with the type parameter tparams[i] substituted |
71 | // for the type targs[i] where subst was created using tparams and targs. |
72 | func (subst *subster) typ(t types.Type) (res types.Type) { |
73 | if subst == nil { |
74 | return t // A nil subst is type preserving. |
75 | } |
76 | if r, ok := subst.cache[t]; ok { |
77 | return r |
78 | } |
79 | defer func() { |
80 | subst.cache[t] = res |
81 | }() |
82 | |
83 | // fall through if result r will be identical to t, types.Identical(r, t). |
84 | switch t := t.(type) { |
85 | case *typeparams.TypeParam: |
86 | r := subst.replacements[t] |
87 | assert(r != nil, "type param without replacement encountered") |
88 | return r |
89 | |
90 | case *types.Basic: |
91 | return t |
92 | |
93 | case *types.Array: |
94 | if r := subst.typ(t.Elem()); r != t.Elem() { |
95 | return types.NewArray(r, t.Len()) |
96 | } |
97 | return t |
98 | |
99 | case *types.Slice: |
100 | if r := subst.typ(t.Elem()); r != t.Elem() { |
101 | return types.NewSlice(r) |
102 | } |
103 | return t |
104 | |
105 | case *types.Pointer: |
106 | if r := subst.typ(t.Elem()); r != t.Elem() { |
107 | return types.NewPointer(r) |
108 | } |
109 | return t |
110 | |
111 | case *types.Tuple: |
112 | return subst.tuple(t) |
113 | |
114 | case *types.Struct: |
115 | return subst.struct_(t) |
116 | |
117 | case *types.Map: |
118 | key := subst.typ(t.Key()) |
119 | elem := subst.typ(t.Elem()) |
120 | if key != t.Key() || elem != t.Elem() { |
121 | return types.NewMap(key, elem) |
122 | } |
123 | return t |
124 | |
125 | case *types.Chan: |
126 | if elem := subst.typ(t.Elem()); elem != t.Elem() { |
127 | return types.NewChan(t.Dir(), elem) |
128 | } |
129 | return t |
130 | |
131 | case *types.Signature: |
132 | return subst.signature(t) |
133 | |
134 | case *typeparams.Union: |
135 | return subst.union(t) |
136 | |
137 | case *types.Interface: |
138 | return subst.interface_(t) |
139 | |
140 | case *types.Named: |
141 | return subst.named(t) |
142 | |
143 | default: |
144 | panic("unreachable") |
145 | } |
146 | } |
147 | |
148 | // types returns the result of {subst.typ(ts[i])}. |
149 | func (subst *subster) types(ts []types.Type) []types.Type { |
150 | res := make([]types.Type, len(ts)) |
151 | for i := range ts { |
152 | res[i] = subst.typ(ts[i]) |
153 | } |
154 | return res |
155 | } |
156 | |
157 | func (subst *subster) tuple(t *types.Tuple) *types.Tuple { |
158 | if t != nil { |
159 | if vars := subst.varlist(t); vars != nil { |
160 | return types.NewTuple(vars...) |
161 | } |
162 | } |
163 | return t |
164 | } |
165 | |
// varlist is an indexed sequence of variables. *types.Tuple satisfies it
// directly; struct fields can be exposed through an adapter type.
type varlist interface {
	Len() int
	At(i int) *types.Var
}
170 | |
// fieldlist adapts a *types.Struct to the varlist interface, exposing the
// struct's fields as an indexed list of variables.
type fieldlist struct {
	str *types.Struct
}

// Len reports the number of fields in the struct.
func (fl fieldlist) Len() int {
	return fl.str.NumFields()
}

// At returns the i'th field of the struct.
func (fl fieldlist) At(i int) *types.Var {
	return fl.str.Field(i)
}
178 | |
179 | func (subst *subster) struct_(t *types.Struct) *types.Struct { |
180 | if t != nil { |
181 | if fields := subst.varlist(fieldlist{t}); fields != nil { |
182 | tags := make([]string, t.NumFields()) |
183 | for i, n := 0, t.NumFields(); i < n; i++ { |
184 | tags[i] = t.Tag(i) |
185 | } |
186 | return types.NewStruct(fields, tags) |
187 | } |
188 | } |
189 | return t |
190 | } |
191 | |
192 | // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i. |
193 | func (subst *subster) varlist(in varlist) []*types.Var { |
194 | var out []*types.Var // nil => no updates |
195 | for i, n := 0, in.Len(); i < n; i++ { |
196 | v := in.At(i) |
197 | w := subst.var_(v) |
198 | if v != w && out == nil { |
199 | out = make([]*types.Var, n) |
200 | for j := 0; j < i; j++ { |
201 | out[j] = in.At(j) |
202 | } |
203 | } |
204 | if out != nil { |
205 | out[i] = w |
206 | } |
207 | } |
208 | return out |
209 | } |
210 | |
211 | func (subst *subster) var_(v *types.Var) *types.Var { |
212 | if v != nil { |
213 | if typ := subst.typ(v.Type()); typ != v.Type() { |
214 | if v.IsField() { |
215 | return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded()) |
216 | } |
217 | return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ) |
218 | } |
219 | } |
220 | return v |
221 | } |
222 | |
223 | func (subst *subster) union(u *typeparams.Union) *typeparams.Union { |
224 | var out []*typeparams.Term // nil => no updates |
225 | |
226 | for i, n := 0, u.Len(); i < n; i++ { |
227 | t := u.Term(i) |
228 | r := subst.typ(t.Type()) |
229 | if r != t.Type() && out == nil { |
230 | out = make([]*typeparams.Term, n) |
231 | for j := 0; j < i; j++ { |
232 | out[j] = u.Term(j) |
233 | } |
234 | } |
235 | if out != nil { |
236 | out[i] = typeparams.NewTerm(t.Tilde(), r) |
237 | } |
238 | } |
239 | |
240 | if out != nil { |
241 | return typeparams.NewUnion(out) |
242 | } |
243 | return u |
244 | } |
245 | |
246 | func (subst *subster) interface_(iface *types.Interface) *types.Interface { |
247 | if iface == nil { |
248 | return nil |
249 | } |
250 | |
251 | // methods for the interface. Initially nil if there is no known change needed. |
252 | // Signatures for the method where recv is nil. NewInterfaceType fills in the recievers. |
253 | var methods []*types.Func |
254 | initMethods := func(n int) { // copy first n explicit methods |
255 | methods = make([]*types.Func, iface.NumExplicitMethods()) |
256 | for i := 0; i < n; i++ { |
257 | f := iface.ExplicitMethod(i) |
258 | norecv := changeRecv(f.Type().(*types.Signature), nil) |
259 | methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv) |
260 | } |
261 | } |
262 | for i := 0; i < iface.NumExplicitMethods(); i++ { |
263 | f := iface.ExplicitMethod(i) |
264 | // On interfaces, we need to cycle break on anonymous interface types |
265 | // being in a cycle with their signatures being in cycles with their recievers |
266 | // that do not go through a Named. |
267 | norecv := changeRecv(f.Type().(*types.Signature), nil) |
268 | sig := subst.typ(norecv) |
269 | if sig != norecv && methods == nil { |
270 | initMethods(i) |
271 | } |
272 | if methods != nil { |
273 | methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature)) |
274 | } |
275 | } |
276 | |
277 | var embeds []types.Type |
278 | initEmbeds := func(n int) { // copy first n embedded types |
279 | embeds = make([]types.Type, iface.NumEmbeddeds()) |
280 | for i := 0; i < n; i++ { |
281 | embeds[i] = iface.EmbeddedType(i) |
282 | } |
283 | } |
284 | for i := 0; i < iface.NumEmbeddeds(); i++ { |
285 | e := iface.EmbeddedType(i) |
286 | r := subst.typ(e) |
287 | if e != r && embeds == nil { |
288 | initEmbeds(i) |
289 | } |
290 | if embeds != nil { |
291 | embeds[i] = r |
292 | } |
293 | } |
294 | |
295 | if methods == nil && embeds == nil { |
296 | return iface |
297 | } |
298 | if methods == nil { |
299 | initMethods(iface.NumExplicitMethods()) |
300 | } |
301 | if embeds == nil { |
302 | initEmbeds(iface.NumEmbeddeds()) |
303 | } |
304 | return types.NewInterfaceType(methods, embeds).Complete() |
305 | } |
306 | |
307 | func (subst *subster) named(t *types.Named) types.Type { |
308 | // A name type may be: |
309 | // (1) ordinary (no type parameters, no type arguments), |
310 | // (2) generic (type parameters but no type arguments), or |
311 | // (3) instantiated (type parameters and type arguments). |
312 | tparams := typeparams.ForNamed(t) |
313 | if tparams.Len() == 0 { |
314 | // case (1) ordinary |
315 | |
316 | // Note: If Go allows for local type declarations in generic |
317 | // functions we may need to descend into underlying as well. |
318 | return t |
319 | } |
320 | targs := typeparams.NamedTypeArgs(t) |
321 | |
322 | // insts are arguments to instantiate using. |
323 | insts := make([]types.Type, tparams.Len()) |
324 | |
325 | // case (2) generic ==> targs.Len() == 0 |
326 | // Instantiating a generic with no type arguments should be unreachable. |
327 | // Please report a bug if you encounter this. |
328 | assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported") |
329 | |
330 | // case (3) instantiated. |
331 | // Substitute into the type arguments and instantiate the replacements/ |
332 | // Example: |
333 | // type N[A any] func() A |
334 | // func Foo[T](g N[T]) {} |
335 | // To instantiate Foo[string], one goes through {T->string}. To get the type of g |
336 | // one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} } |
337 | // to get {N with TypeArgs == {string} and typeparams == {A} }. |
338 | assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present") |
339 | for i, n := 0, targs.Len(); i < n; i++ { |
340 | inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion |
341 | insts[i] = inst |
342 | } |
343 | r, err := typeparams.Instantiate(subst.ctxt, typeparams.NamedTypeOrigin(t), insts, false) |
344 | assert(err == nil, "failed to Instantiate Named type") |
345 | return r |
346 | } |
347 | |
348 | func (subst *subster) signature(t *types.Signature) types.Type { |
349 | tparams := typeparams.ForSignature(t) |
350 | |
351 | // We are choosing not to support tparams.Len() > 0 until a need has been observed in practice. |
352 | // |
353 | // There are some known usages for types.Types coming from types.{Eval,CheckExpr}. |
354 | // To support tparams.Len() > 0, we just need to do the following [psuedocode]: |
355 | // targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false) |
356 | |
357 | assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.") |
358 | |
359 | // Either: |
360 | // (1)non-generic function. |
361 | // no type params to substitute |
362 | // (2)generic method and recv needs to be substituted. |
363 | |
364 | // Recievers can be either: |
365 | // named |
366 | // pointer to named |
367 | // interface |
368 | // nil |
369 | // interface is the problematic case. We need to cycle break there! |
370 | recv := subst.var_(t.Recv()) |
371 | params := subst.tuple(t.Params()) |
372 | results := subst.tuple(t.Results()) |
373 | if recv != t.Recv() || params != t.Params() || results != t.Results() { |
374 | return typeparams.NewSignatureType(recv, nil, nil, params, results, t.Variadic()) |
375 | } |
376 | return t |
377 | } |
378 | |
379 | // reaches returns true if a type t reaches any type t' s.t. c[t'] == true. |
380 | // Updates c to cache results. |
381 | func reaches(t types.Type, c map[types.Type]bool) (res bool) { |
382 | if c, ok := c[t]; ok { |
383 | return c |
384 | } |
385 | c[t] = false // prevent cycles |
386 | defer func() { |
387 | c[t] = res |
388 | }() |
389 | |
390 | switch t := t.(type) { |
391 | case *typeparams.TypeParam, *types.Basic: |
392 | // no-op => c == false |
393 | case *types.Array: |
394 | return reaches(t.Elem(), c) |
395 | case *types.Slice: |
396 | return reaches(t.Elem(), c) |
397 | case *types.Pointer: |
398 | return reaches(t.Elem(), c) |
399 | case *types.Tuple: |
400 | for i := 0; i < t.Len(); i++ { |
401 | if reaches(t.At(i).Type(), c) { |
402 | return true |
403 | } |
404 | } |
405 | case *types.Struct: |
406 | for i := 0; i < t.NumFields(); i++ { |
407 | if reaches(t.Field(i).Type(), c) { |
408 | return true |
409 | } |
410 | } |
411 | case *types.Map: |
412 | return reaches(t.Key(), c) || reaches(t.Elem(), c) |
413 | case *types.Chan: |
414 | return reaches(t.Elem(), c) |
415 | case *types.Signature: |
416 | if t.Recv() != nil && reaches(t.Recv().Type(), c) { |
417 | return true |
418 | } |
419 | return reaches(t.Params(), c) || reaches(t.Results(), c) |
420 | case *typeparams.Union: |
421 | for i := 0; i < t.Len(); i++ { |
422 | if reaches(t.Term(i).Type(), c) { |
423 | return true |
424 | } |
425 | } |
426 | case *types.Interface: |
427 | for i := 0; i < t.NumEmbeddeds(); i++ { |
428 | if reaches(t.Embedded(i), c) { |
429 | return true |
430 | } |
431 | } |
432 | for i := 0; i < t.NumExplicitMethods(); i++ { |
433 | if reaches(t.ExplicitMethod(i).Type(), c) { |
434 | return true |
435 | } |
436 | } |
437 | case *types.Named: |
438 | return reaches(t.Underlying(), c) |
439 | default: |
440 | panic("unreachable") |
441 | } |
442 | return false |
443 | } |
444 |
Members