GoPLS Viewer

Home|gopls/go/ssa/subst.go
1// Copyright 2022 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4package ssa
5
6import (
7    "fmt"
8    "go/types"
9
10    "golang.org/x/tools/internal/typeparams"
11)
12
13// Type substituter for a fixed set of replacement types.
14//
15// A nil *subster is an valid, empty substitution map. It always acts as
16// the identity function. This allows for treating parameterized and
17// non-parameterized functions identically while compiling to ssa.
18//
19// Not concurrency-safe.
20type subster struct {
21    // TODO(zpavlinovic): replacements can contain type params
22    // when generating instances inside of a generic function body.
23    replacements map[*typeparams.TypeParam]types.Type // values should contain no type params
24    cache        map[types.Type]types.Type            // cache of subst results
25    ctxt         *typeparams.Context
26    debug        bool // perform extra debugging checks
27    // TODO(taking): consider adding Pos
28}
29
30// Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache.
31// targs should not contain any types in tparams.
32func makeSubster(ctxt *typeparams.Contexttparams *typeparams.TypeParamListtargs []types.Typedebug bool) *subster {
33    assert(tparams.Len() == len(targs), "makeSubster argument count must match")
34
35    subst := &subster{
36        replacementsmake(map[*typeparams.TypeParam]types.Typetparams.Len()),
37        cache:        make(map[types.Type]types.Type),
38        ctxt:         ctxt,
39        debug:        debug,
40    }
41    for i := 0i < tparams.Len(); i++ {
42        subst.replacements[tparams.At(i)] = targs[i]
43    }
44    if subst.debug {
45        if err := subst.wellFormed(); err != nil {
46            panic(err)
47        }
48    }
49    return subst
50}
51
52// wellFormed returns an error if subst was not properly initialized.
53func (subst *substerwellFormed() error {
54    if subst == nil || len(subst.replacements) == 0 {
55        return nil
56    }
57    // Check that all of the type params do not appear in the arguments.
58    s := make(map[types.Type]boollen(subst.replacements))
59    for tparam := range subst.replacements {
60        s[tparam] = true
61    }
62    for _r := range subst.replacements {
63        if reaches(rs) {
64            return fmt.Errorf("\n‰r %s s %v replacements %v\n"rssubst.replacements)
65        }
66    }
67    return nil
68}
69
70// typ returns the type of t with the type parameter tparams[i] substituted
71// for the type targs[i] where subst was created using tparams and targs.
72func (subst *substertyp(t types.Type) (res types.Type) {
73    if subst == nil {
74        return t // A nil subst is type preserving.
75    }
76    if rok := subst.cache[t]; ok {
77        return r
78    }
79    defer func() {
80        subst.cache[t] = res
81    }()
82
83    // fall through if result r will be identical to t, types.Identical(r, t).
84    switch t := t.(type) {
85    case *typeparams.TypeParam:
86        r := subst.replacements[t]
87        assert(r != nil"type param without replacement encountered")
88        return r
89
90    case *types.Basic:
91        return t
92
93    case *types.Array:
94        if r := subst.typ(t.Elem()); r != t.Elem() {
95            return types.NewArray(rt.Len())
96        }
97        return t
98
99    case *types.Slice:
100        if r := subst.typ(t.Elem()); r != t.Elem() {
101            return types.NewSlice(r)
102        }
103        return t
104
105    case *types.Pointer:
106        if r := subst.typ(t.Elem()); r != t.Elem() {
107            return types.NewPointer(r)
108        }
109        return t
110
111    case *types.Tuple:
112        return subst.tuple(t)
113
114    case *types.Struct:
115        return subst.struct_(t)
116
117    case *types.Map:
118        key := subst.typ(t.Key())
119        elem := subst.typ(t.Elem())
120        if key != t.Key() || elem != t.Elem() {
121            return types.NewMap(keyelem)
122        }
123        return t
124
125    case *types.Chan:
126        if elem := subst.typ(t.Elem()); elem != t.Elem() {
127            return types.NewChan(t.Dir(), elem)
128        }
129        return t
130
131    case *types.Signature:
132        return subst.signature(t)
133
134    case *typeparams.Union:
135        return subst.union(t)
136
137    case *types.Interface:
138        return subst.interface_(t)
139
140    case *types.Named:
141        return subst.named(t)
142
143    default:
144        panic("unreachable")
145    }
146}
147
148// types returns the result of {subst.typ(ts[i])}.
149func (subst *substertypes(ts []types.Type) []types.Type {
150    res := make([]types.Typelen(ts))
151    for i := range ts {
152        res[i] = subst.typ(ts[i])
153    }
154    return res
155}
156
157func (subst *substertuple(t *types.Tuple) *types.Tuple {
158    if t != nil {
159        if vars := subst.varlist(t); vars != nil {
160            return types.NewTuple(vars...)
161        }
162    }
163    return t
164}
165
166type varlist interface {
167    At(i int) *types.Var
168    Len() int
169}
170
171// fieldlist is an adapter for structs for the varlist interface.
172type fieldlist struct {
173    str *types.Struct
174}
175
176func (fl fieldlistAt(i int) *types.Var { return fl.str.Field(i) }
177func (fl fieldlistLen() int            { return fl.str.NumFields() }
178
179func (subst *substerstruct_(t *types.Struct) *types.Struct {
180    if t != nil {
181        if fields := subst.varlist(fieldlist{t}); fields != nil {
182            tags := make([]stringt.NumFields())
183            for in := 0t.NumFields(); i < ni++ {
184                tags[i] = t.Tag(i)
185            }
186            return types.NewStruct(fieldstags)
187        }
188    }
189    return t
190}
191
192// varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i.
193func (subst *substervarlist(in varlist) []*types.Var {
194    var out []*types.Var // nil => no updates
195    for in := 0in.Len(); i < ni++ {
196        v := in.At(i)
197        w := subst.var_(v)
198        if v != w && out == nil {
199            out = make([]*types.Varn)
200            for j := 0j < ij++ {
201                out[j] = in.At(j)
202            }
203        }
204        if out != nil {
205            out[i] = w
206        }
207    }
208    return out
209}
210
211func (subst *substervar_(v *types.Var) *types.Var {
212    if v != nil {
213        if typ := subst.typ(v.Type()); typ != v.Type() {
214            if v.IsField() {
215                return types.NewField(v.Pos(), v.Pkg(), v.Name(), typv.Embedded())
216            }
217            return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ)
218        }
219    }
220    return v
221}
222
223func (subst *substerunion(u *typeparams.Union) *typeparams.Union {
224    var out []*typeparams.Term // nil => no updates
225
226    for in := 0u.Len(); i < ni++ {
227        t := u.Term(i)
228        r := subst.typ(t.Type())
229        if r != t.Type() && out == nil {
230            out = make([]*typeparams.Termn)
231            for j := 0j < ij++ {
232                out[j] = u.Term(j)
233            }
234        }
235        if out != nil {
236            out[i] = typeparams.NewTerm(t.Tilde(), r)
237        }
238    }
239
240    if out != nil {
241        return typeparams.NewUnion(out)
242    }
243    return u
244}
245
246func (subst *substerinterface_(iface *types.Interface) *types.Interface {
247    if iface == nil {
248        return nil
249    }
250
251    // methods for the interface. Initially nil if there is no known change needed.
252    // Signatures for the method where recv is nil. NewInterfaceType fills in the recievers.
253    var methods []*types.Func
254    initMethods := func(n int) { // copy first n explicit methods
255        methods = make([]*types.Funciface.NumExplicitMethods())
256        for i := 0i < ni++ {
257            f := iface.ExplicitMethod(i)
258            norecv := changeRecv(f.Type().(*types.Signature), nil)
259            methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv)
260        }
261    }
262    for i := 0i < iface.NumExplicitMethods(); i++ {
263        f := iface.ExplicitMethod(i)
264        // On interfaces, we need to cycle break on anonymous interface types
265        // being in a cycle with their signatures being in cycles with their recievers
266        // that do not go through a Named.
267        norecv := changeRecv(f.Type().(*types.Signature), nil)
268        sig := subst.typ(norecv)
269        if sig != norecv && methods == nil {
270            initMethods(i)
271        }
272        if methods != nil {
273            methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature))
274        }
275    }
276
277    var embeds []types.Type
278    initEmbeds := func(n int) { // copy first n embedded types
279        embeds = make([]types.Typeiface.NumEmbeddeds())
280        for i := 0i < ni++ {
281            embeds[i] = iface.EmbeddedType(i)
282        }
283    }
284    for i := 0i < iface.NumEmbeddeds(); i++ {
285        e := iface.EmbeddedType(i)
286        r := subst.typ(e)
287        if e != r && embeds == nil {
288            initEmbeds(i)
289        }
290        if embeds != nil {
291            embeds[i] = r
292        }
293    }
294
295    if methods == nil && embeds == nil {
296        return iface
297    }
298    if methods == nil {
299        initMethods(iface.NumExplicitMethods())
300    }
301    if embeds == nil {
302        initEmbeds(iface.NumEmbeddeds())
303    }
304    return types.NewInterfaceType(methodsembeds).Complete()
305}
306
307func (subst *substernamed(t *types.Namedtypes.Type {
308    // A name type may be:
309    // (1) ordinary (no type parameters, no type arguments),
310    // (2) generic (type parameters but no type arguments), or
311    // (3) instantiated (type parameters and type arguments).
312    tparams := typeparams.ForNamed(t)
313    if tparams.Len() == 0 {
314        // case (1) ordinary
315
316        // Note: If Go allows for local type declarations in generic
317        // functions we may need to descend into underlying as well.
318        return t
319    }
320    targs := typeparams.NamedTypeArgs(t)
321
322    // insts are arguments to instantiate using.
323    insts := make([]types.Typetparams.Len())
324
325    // case (2) generic ==> targs.Len() == 0
326    // Instantiating a generic with no type arguments should be unreachable.
327    // Please report a bug if you encounter this.
328    assert(targs.Len() != 0"substition into a generic Named type is currently unsupported")
329
330    // case (3) instantiated.
331    // Substitute into the type arguments and instantiate the replacements/
332    // Example:
333    //    type N[A any] func() A
334    //    func Foo[T](g N[T]) {}
335    //  To instantiate Foo[string], one goes through {T->string}. To get the type of g
336    //  one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} }
337    //  to get {N with TypeArgs == {string} and typeparams == {A} }.
338    assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present")
339    for in := 0targs.Len(); i < ni++ {
340        inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion
341        insts[i] = inst
342    }
343    rerr := typeparams.Instantiate(subst.ctxttypeparams.NamedTypeOrigin(t), instsfalse)
344    assert(err == nil"failed to Instantiate Named type")
345    return r
346}
347
348func (subst *substersignature(t *types.Signaturetypes.Type {
349    tparams := typeparams.ForSignature(t)
350
351    // We are choosing not to support tparams.Len() > 0 until a need has been observed in practice.
352    //
353    // There are some known usages for types.Types coming from types.{Eval,CheckExpr}.
354    // To support tparams.Len() > 0, we just need to do the following [psuedocode]:
355    //   targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false)
356
357    assert(tparams.Len() == 0"Substituting types.Signatures with generic functions are currently unsupported.")
358
359    // Either:
360    // (1)non-generic function.
361    //    no type params to substitute
362    // (2)generic method and recv needs to be substituted.
363
364    // Recievers can be either:
365    // named
366    // pointer to named
367    // interface
368    // nil
369    // interface is the problematic case. We need to cycle break there!
370    recv := subst.var_(t.Recv())
371    params := subst.tuple(t.Params())
372    results := subst.tuple(t.Results())
373    if recv != t.Recv() || params != t.Params() || results != t.Results() {
374        return typeparams.NewSignatureType(recvnilnilparamsresultst.Variadic())
375    }
376    return t
377}
378
379// reaches returns true if a type t reaches any type t' s.t. c[t'] == true.
380// Updates c to cache results.
381func reaches(t types.Typec map[types.Type]bool) (res bool) {
382    if cok := c[t]; ok {
383        return c
384    }
385    c[t] = false // prevent cycles
386    defer func() {
387        c[t] = res
388    }()
389
390    switch t := t.(type) {
391    case *typeparams.TypeParam, *types.Basic:
392        // no-op => c == false
393    case *types.Array:
394        return reaches(t.Elem(), c)
395    case *types.Slice:
396        return reaches(t.Elem(), c)
397    case *types.Pointer:
398        return reaches(t.Elem(), c)
399    case *types.Tuple:
400        for i := 0i < t.Len(); i++ {
401            if reaches(t.At(i).Type(), c) {
402                return true
403            }
404        }
405    case *types.Struct:
406        for i := 0i < t.NumFields(); i++ {
407            if reaches(t.Field(i).Type(), c) {
408                return true
409            }
410        }
411    case *types.Map:
412        return reaches(t.Key(), c) || reaches(t.Elem(), c)
413    case *types.Chan:
414        return reaches(t.Elem(), c)
415    case *types.Signature:
416        if t.Recv() != nil && reaches(t.Recv().Type(), c) {
417            return true
418        }
419        return reaches(t.Params(), c) || reaches(t.Results(), c)
420    case *typeparams.Union:
421        for i := 0i < t.Len(); i++ {
422            if reaches(t.Term(i).Type(), c) {
423                return true
424            }
425        }
426    case *types.Interface:
427        for i := 0i < t.NumEmbeddeds(); i++ {
428            if reaches(t.Embedded(i), c) {
429                return true
430            }
431        }
432        for i := 0i < t.NumExplicitMethods(); i++ {
433            if reaches(t.ExplicitMethod(i).Type(), c) {
434                return true
435            }
436        }
437    case *types.Named:
438        return reaches(t.Underlying(), c)
439    default:
440        panic("unreachable")
441    }
442    return false
443}
444
MembersX
subster.tuple.BlockStmt.vars
makeSubster.targs
subster.typ.BlockStmt.r
subster.interface_.BlockStmt.r
subster.types.RangeStmt_3999.i
subster.tuple
subster.var_.v
subster.named.t
reaches
subster.interface_.iface
subster.tuple.subst
subster.varlist
subster.union.n
subster.signature.t
subster
makeSubster.debug
subster.struct_.subst
subster.varlist.i
subster.interface_.BlockStmt.f
subster.named
subster.signature.subst
makeSubster
fieldlist.str
subster.varlist.subst
subster.types.subst
subster.struct_.BlockStmt.BlockStmt.i
subster.union.subst
subster.named.i
subster.typ.subst
subster.typ.t
subster.union.BlockStmt.t
subster.varlist.in
subster.typ.BlockStmt.elem
reaches.BlockStmt.i
subster.interface_.methods
subster.interface_.embeds
subster.named.targs
subster.cache
subster.wellFormed
subster.named.err
subster.types.res
subster.varlist.n
varlist
fieldlist.Len.fl
subster.struct_
subster.debug
subster.named.r
subster.interface_.BlockStmt.norecv
subster.interface_.BlockStmt.e
subster.signature
subster.struct_.BlockStmt.BlockStmt.n
subster.var_.BlockStmt.typ
subster.var_.subst
subster.union.i
subster.struct_.BlockStmt.BlockStmt.tags
subster.union.out
subster.named.insts
subster.named.n
subster.varlist.BlockStmt.v
subster.wellFormed.s
fieldlist.At.i
subster.struct_.t
makeSubster.tparams
makeSubster.subst
subster.interface_.i
subster.union
subster.interface_
fieldlist
fieldlist.At
subster.signature.params
subster.typ
subster.typ.BlockStmt.key
subster.varlist.BlockStmt.BlockStmt.j
subster.var_
makeSubster.ctxt
subster.types.ts
subster.interface_.BlockStmt.sig
subster.wellFormed.RangeStmt_2074.r
subster.union.u
subster.interface_.BlockStmt.i
subster.ctxt
fieldlist.At.fl
subster.struct_.BlockStmt.fields
subster.union.BlockStmt.BlockStmt.j
reaches.t
subster.wellFormed.RangeStmt_2010.tparam
subster.varlist.out
subster.named.subst
subster.interface_.BlockStmt.BlockStmt.norecv
subster.types
subster.signature.results
reaches.c
fieldlist.Len
subster.varlist.BlockStmt.w
makeSubster.i
makeSubster.BlockStmt.err
subster.wellFormed.subst
subster.typ.res
subster.tuple.t
subster.replacements
subster.named.tparams
subster.union.BlockStmt.r
subster.interface_.subst
subster.signature.tparams
subster.signature.recv
reaches.res
subster.interface_.BlockStmt.BlockStmt.f
subster.named.BlockStmt.inst
Members
X