1 | // Copyright 2019 The Go Authors. All rights reserved. |
---|---|
2 | // Use of this source code is governed by a BSD-style |
3 | // license that can be found in the LICENSE file. |
4 | |
5 | // Package memoize defines a "promise" abstraction that enables |
6 | // memoization of the result of calling an expensive but idempotent |
7 | // function. |
8 | // |
// Call p = NewPromise(debug, f) to obtain a promise for the future result
// of calling f, and call p.Get(ctx, arg) to obtain that result. All calls
// to p.Get return the result of a single call of f.
// Get blocks if the function has not finished (or started).
13 | // |
// A Store is a map of arbitrary keys to promises. Use Store.Promise
// to create a promise in the store. All calls to Store.Promise with the
// same key return the same promise as long as it is in the store.
// These promises are reference-counted and must be explicitly released.
// Once the last reference is released, the promise is removed from the
// store (unless the store was created with the NeverEvict policy).
19 | package memoize |
20 | |
21 | import ( |
22 | "context" |
23 | "fmt" |
24 | "reflect" |
25 | "runtime/trace" |
26 | "sync" |
27 | "sync/atomic" |
28 | |
29 | "golang.org/x/tools/internal/xcontext" |
30 | ) |
31 | |
// Function is the type of a function that can be memoized.
//
// If the arg is a RefCounted, its Acquire method is called before the
// computation is started, and the release function returned by Acquire
// is called when the computation goroutine finishes (whether or not the
// Function was actually invoked).
//
// The argument must not materially affect the result of the function
// in ways that are not captured by the promise's key, since if
// Promise.Get is called twice concurrently, with the same (implicit)
// key but different arguments, the Function is called only once but
// its result must be suitable for both callers.
//
// The main purpose of the argument is to avoid the Function closure
// needing to retain large objects (in practice: the snapshot) in
// memory that can be supplied at call time by any caller.
//
// The ctx passed to the Function is derived, via xcontext.Detach, from the
// context of the Get call that triggered the computation. The promise
// cancels it only when every concurrent Get call has been cancelled.
type Function func(ctx context.Context, arg interface{}) interface{}
46 | |
// A RefCounted is a value whose functional lifetime is determined by
// reference counting.
//
// Its Acquire method is called before the Function is invoked, and the
// release function returned by Acquire is called when the computation
// finishes. Usually both events happen within a single call to Get, so
// Get would be fine with a "borrowed" reference. However, if the context
// is cancelled, Get may return before the Function is complete. That
// causes the argument to escape, and the value may be destroyed
// prematurely. For a reference-counted type, this requires a pair of
// increment/decrement operations to extend its life.
type RefCounted interface {
	// Acquire prevents the value from being destroyed until the
	// returned function is called.
	Acquire() func()
}
63 | |
// A Promise represents the future result of a call to a function.
//
// A Promise contains a sync.Mutex and is always used through a pointer.
type Promise struct {
	// debug classifies the promise for observability (run uses it to name
	// its trace region). It is not modified after NewPromise in this file.
	debug string

	// refcount is the reference count in the containing Store, used by
	// Store.Promise. It is guarded by Store.promisesMu on the containing Store.
	refcount int32

	// mu guards all of the fields below.
	mu sync.Mutex

	// A Promise starts out IDLE, waiting for something to demand
	// its evaluation. It then transitions into RUNNING state.
	//
	// While RUNNING, waiters tracks the number of Get calls
	// waiting for a result, and the done channel is used to
	// notify waiters of the next state transition. Once
	// evaluation finishes, value is set, state changes to
	// COMPLETED, and done is closed, unblocking waiters.
	//
	// Alternatively, as Get calls are cancelled, they decrement
	// waiters. If it drops to zero, the inner context is
	// cancelled, computation is abandoned, and state resets to
	// IDLE to start the process over again.
	state state
	// done is set in running state, and closed when exiting it.
	done chan struct{}
	// cancel is set on entering running state; calling it cancels the
	// context passed to the function. It is cleared when a run is abandoned.
	cancel context.CancelFunc
	// waiters counts Get calls that started waiting on the current run and
	// have not been cancelled. It is only decremented on cancellation and
	// is not consulted once the promise has completed.
	waiters uint
	// function is the function that will be used to populate the value.
	// It is cleared once the value has been computed, to aid GC.
	function Function
	// value is set in completed state.
	value interface{}
}
99 | |
100 | // NewPromise returns a promise for the future result of calling the |
101 | // specified function. |
102 | // |
103 | // The debug string is used to classify promises in logs and metrics. |
104 | // It should be drawn from a small set. |
105 | func NewPromise(debug string, function Function) *Promise { |
106 | if function == nil { |
107 | panic("nil function") |
108 | } |
109 | return &Promise{ |
110 | debug: debug, |
111 | function: function, |
112 | } |
113 | } |
114 | |
// state is the lifecycle stage of a Promise; see the comment on the
// Promise.state field for the transitions between stages.
type state int

// The constants are typed as state (not untyped ints) so that they can only
// be used where a state is expected.
const (
	stateIdle      state = iota // newly constructed, or last waiter was cancelled
	stateRunning                // run was called and not cancelled
	stateCompleted              // function call ran to completion
)
122 | |
123 | // Cached returns the value associated with a promise. |
124 | // |
125 | // It will never cause the value to be generated. |
126 | // It will return the cached value, if present. |
127 | func (p *Promise) Cached() interface{} { |
128 | p.mu.Lock() |
129 | defer p.mu.Unlock() |
130 | if p.state == stateCompleted { |
131 | return p.value |
132 | } |
133 | return nil |
134 | } |
135 | |
136 | // Get returns the value associated with a promise. |
137 | // |
138 | // All calls to Promise.Get on a given promise return the |
139 | // same result but the function is called (to completion) at most once. |
140 | // |
141 | // If the value is not yet ready, the underlying function will be invoked. |
142 | // |
143 | // If ctx is cancelled, Get returns (nil, Canceled). |
144 | // If all concurrent calls to Get are cancelled, the context provided |
145 | // to the function is cancelled. A later call to Get may attempt to |
146 | // call the function again. |
147 | func (p *Promise) Get(ctx context.Context, arg interface{}) (interface{}, error) { |
148 | if ctx.Err() != nil { |
149 | return nil, ctx.Err() |
150 | } |
151 | p.mu.Lock() |
152 | switch p.state { |
153 | case stateIdle: |
154 | return p.run(ctx, arg) |
155 | case stateRunning: |
156 | return p.wait(ctx) |
157 | case stateCompleted: |
158 | defer p.mu.Unlock() |
159 | return p.value, nil |
160 | default: |
161 | panic("unknown state") |
162 | } |
163 | } |
164 | |
165 | // run starts p.function and returns the result. p.mu must be locked. |
166 | func (p *Promise) run(ctx context.Context, arg interface{}) (interface{}, error) { |
167 | childCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) |
168 | p.cancel = cancel |
169 | p.state = stateRunning |
170 | p.done = make(chan struct{}) |
171 | function := p.function // Read under the lock |
172 | |
173 | // Make sure that the argument isn't destroyed while we're running in it. |
174 | release := func() {} |
175 | if rc, ok := arg.(RefCounted); ok { |
176 | release = rc.Acquire() |
177 | } |
178 | |
179 | go func() { |
180 | trace.WithRegion(childCtx, fmt.Sprintf("Promise.run %s", p.debug), func() { |
181 | defer release() |
182 | // Just in case the function does something expensive without checking |
183 | // the context, double-check we're still alive. |
184 | if childCtx.Err() != nil { |
185 | return |
186 | } |
187 | v := function(childCtx, arg) |
188 | if childCtx.Err() != nil { |
189 | return |
190 | } |
191 | |
192 | p.mu.Lock() |
193 | defer p.mu.Unlock() |
194 | // It's theoretically possible that the promise has been cancelled out |
195 | // of the run that started us, and then started running again since we |
196 | // checked childCtx above. Even so, that should be harmless, since each |
197 | // run should produce the same results. |
198 | if p.state != stateRunning { |
199 | return |
200 | } |
201 | |
202 | p.value = v |
203 | p.function = nil // aid GC |
204 | p.state = stateCompleted |
205 | close(p.done) |
206 | }) |
207 | }() |
208 | |
209 | return p.wait(ctx) |
210 | } |
211 | |
212 | // wait waits for the value to be computed, or ctx to be cancelled. p.mu must be locked. |
213 | func (p *Promise) wait(ctx context.Context) (interface{}, error) { |
214 | p.waiters++ |
215 | done := p.done |
216 | p.mu.Unlock() |
217 | |
218 | select { |
219 | case <-done: |
220 | p.mu.Lock() |
221 | defer p.mu.Unlock() |
222 | if p.state == stateCompleted { |
223 | return p.value, nil |
224 | } |
225 | return nil, nil |
226 | case <-ctx.Done(): |
227 | p.mu.Lock() |
228 | defer p.mu.Unlock() |
229 | p.waiters-- |
230 | if p.waiters == 0 && p.state == stateRunning { |
231 | p.cancel() |
232 | close(p.done) |
233 | p.state = stateIdle |
234 | p.done = nil |
235 | p.cancel = nil |
236 | } |
237 | return nil, ctx.Err() |
238 | } |
239 | } |
240 | |
// An EvictionPolicy controls the eviction behavior of keys in a Store when
// they no longer have any references.
type EvictionPolicy int

const (
	// ImmediatelyEvict evicts keys as soon as they no longer have references.
	// It is the zero value, and hence the policy of a zero-value Store.
	ImmediatelyEvict EvictionPolicy = iota

	// NeverEvict does not evict keys: a promise stays in the Store, with a
	// reference count of zero, after its last reference is released.
	NeverEvict
)
252 | |
// A Store maps arbitrary keys to reference-counted promises.
//
// The zero value is a valid Store, though a store may also be created via
// NewStore if a custom EvictionPolicy is required.
type Store struct {
	// evictionPolicy decides whether a promise is deleted from promises
	// when its reference count drops to zero (see Store.Promise).
	evictionPolicy EvictionPolicy

	// promisesMu guards promises and the refcount field of every Promise
	// held in it.
	promisesMu sync.Mutex
	// promises is allocated lazily by Store.Promise, so it may be nil.
	promises map[interface{}]*Promise
}
263 | |
264 | // NewStore creates a new store with the given eviction policy. |
265 | func NewStore(policy EvictionPolicy) *Store { |
266 | return &Store{evictionPolicy: policy} |
267 | } |
268 | |
269 | // Promise returns a reference-counted promise for the future result of |
270 | // calling the specified function. |
271 | // |
272 | // Calls to Promise with the same key return the same promise, incrementing its |
273 | // reference count. The caller must call the returned function to decrement |
274 | // the promise's reference count when it is no longer needed. The returned |
275 | // function must not be called more than once. |
276 | // |
277 | // Once the last reference has been released, the promise is removed from the |
278 | // store. |
279 | func (store *Store) Promise(key interface{}, function Function) (*Promise, func()) { |
280 | store.promisesMu.Lock() |
281 | p, ok := store.promises[key] |
282 | if !ok { |
283 | p = NewPromise(reflect.TypeOf(key).String(), function) |
284 | if store.promises == nil { |
285 | store.promises = map[interface{}]*Promise{} |
286 | } |
287 | store.promises[key] = p |
288 | } |
289 | p.refcount++ |
290 | store.promisesMu.Unlock() |
291 | |
292 | var released int32 |
293 | release := func() { |
294 | if !atomic.CompareAndSwapInt32(&released, 0, 1) { |
295 | panic("release called more than once") |
296 | } |
297 | store.promisesMu.Lock() |
298 | |
299 | p.refcount-- |
300 | if p.refcount == 0 && store.evictionPolicy != NeverEvict { |
301 | // Inv: if p.refcount > 0, then store.promises[key] == p. |
302 | delete(store.promises, key) |
303 | } |
304 | store.promisesMu.Unlock() |
305 | } |
306 | |
307 | return p, release |
308 | } |
309 | |
310 | // Stats returns the number of each type of key in the store. |
311 | func (s *Store) Stats() map[reflect.Type]int { |
312 | result := map[reflect.Type]int{} |
313 | |
314 | s.promisesMu.Lock() |
315 | defer s.promisesMu.Unlock() |
316 | |
317 | for k := range s.promises { |
318 | result[reflect.TypeOf(k)]++ |
319 | } |
320 | return result |
321 | } |
322 | |
323 | // DebugOnlyIterate iterates through the store and, for each completed |
324 | // promise, calls f(k, v) for the map key k and function result v. It |
325 | // should only be used for debugging purposes. |
326 | func (s *Store) DebugOnlyIterate(f func(k, v interface{})) { |
327 | s.promisesMu.Lock() |
328 | defer s.promisesMu.Unlock() |
329 | |
330 | for k, p := range s.promises { |
331 | if v := p.Cached(); v != nil { |
332 | f(k, v) |
333 | } |
334 | } |
335 | } |
336 |
Members