|
| 1 | +package taskgraph |
| 2 | + |
| 3 | +import ( |
| 4 | + "errors" |
| 5 | + "fmt" |
| 6 | + "sync" |
| 7 | + |
| 8 | + set "github.com/deckarep/golang-set/v2" |
| 9 | +) |
| 10 | + |
| 11 | +var ( |
| 12 | + // ErrDuplicateBinding is returned when Binder.Store is called with a binding whose ID has already |
| 13 | + // been stored (which implies that the graph being executed contains multiple tasks producing |
| 14 | + // bindings for the same Key). |
| 15 | + ErrDuplicateBinding = errors.New("duplicate binding") |
| 16 | +) |
| 17 | + |
| 18 | +// BindStatus represents the tristate of a Binding. |
| 19 | +type BindStatus int |
| 20 | + |
| 21 | +const ( |
| 22 | + // Pending represents where the key is unbound (i.e. that no task has yet provided a binding for |
| 23 | + // the key, and no input binding was provided). |
| 24 | + Pending BindStatus = iota |
| 25 | + |
| 26 | + // Absent represents where the key is explicitly unbound (i.e. that the task which provides it was |
| 27 | + // unable to provide a value). The binding will contain an error, which is ErrIsAbsent by default |
| 28 | + // but can be another error to propagate information between tasks. This allows for errors which |
| 29 | + // do not terminate the execution of a graph. |
| 30 | + Absent |
| 31 | + |
| 32 | + // Present represents where the key is bound to a valid value (i.e. the task was able to provide a |
| 33 | + // value, or the value was bound as an input). |
| 34 | + Present |
| 35 | +) |
| 36 | + |
| 37 | +func (bs BindStatus) String() string { |
| 38 | + return map[BindStatus]string{ |
| 39 | + Pending: "PENDING", |
| 40 | + Absent: "ABSENT", |
| 41 | + Present: "PRESENT", |
| 42 | + }[bs] |
| 43 | +} |
| 44 | + |
| 45 | +// A Binding is a tristate wrapper around a key ID and an optional value or error. See the |
| 46 | +// documentation for BindStatus for details of the 3 states. Bindings are produced by calling the |
| 47 | +// Bind, BindAbsent, or BindError methods on a Key. |
| 48 | +type Binding interface { |
| 49 | + // ID returns the ID of the key which is bound by this Binding. |
| 50 | + ID() ID |
| 51 | + |
| 52 | + // Status returns the status of this binding. |
| 53 | + Status() BindStatus |
| 54 | + |
| 55 | + // Value returns the value bound to the key. This should only be called if Status() returns Present. |
| 56 | + Value() any |
| 57 | + |
| 58 | + // Value returns the error bound to the key. This should only be called if Status() returns Absent. |
| 59 | + Error() error |
| 60 | +} |
| 61 | + |
| 62 | +type binding struct { |
| 63 | + id ID |
| 64 | + status BindStatus |
| 65 | + value any |
| 66 | + err error |
| 67 | +} |
| 68 | + |
| 69 | +func (b *binding) ID() ID { |
| 70 | + return b.id |
| 71 | +} |
| 72 | + |
| 73 | +func (b *binding) Status() BindStatus { |
| 74 | + return b.status |
| 75 | +} |
| 76 | + |
| 77 | +func (b *binding) Value() any { |
| 78 | + return b.value |
| 79 | +} |
| 80 | + |
| 81 | +func (b *binding) Error() error { |
| 82 | + return b.err |
| 83 | +} |
| 84 | + |
| 85 | +func (b *binding) String() string { |
| 86 | + if b.status == Present { |
| 87 | + return fmt.Sprintf("%s(%s -> %v)", b.status, b.id, b.value) |
| 88 | + } |
| 89 | + return fmt.Sprintf("%s(%s)", b.status, b.id) |
| 90 | +} |
| 91 | + |
| 92 | +// bind a value to a key ID. |
| 93 | +func bind(id ID, value any) Binding { |
| 94 | + return &binding{ |
| 95 | + id: id, |
| 96 | + status: Present, |
| 97 | + value: value, |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +// bindAbsent produces an absent binding for the given Key ID. |
| 102 | +func bindAbsent(id ID) Binding { |
| 103 | + return bindAbsentWithError(id, ErrIsAbsent) |
| 104 | +} |
| 105 | + |
| 106 | +// bindAbsent produces an absent binding for the given Key ID. |
| 107 | +func bindAbsentWithError(id ID, err error) Binding { |
| 108 | + return &binding{ |
| 109 | + id: id, |
| 110 | + status: Absent, |
| 111 | + err: err, |
| 112 | + } |
| 113 | +} |
| 114 | + |
| 115 | +// bindPending produces a pending binding; this is only ever used when calling Get on a Binder. |
| 116 | +func bindPending(id ID) Binding { |
| 117 | + return &binding{ |
| 118 | + id: id, |
| 119 | + status: Pending, |
| 120 | + } |
| 121 | +} |
| 122 | + |
| 123 | +// A Binder is the state store for tasks in a graph. |
| 124 | +type Binder interface { |
| 125 | + // Store adds bindings to the binder that can be retrieved with Get(). |
| 126 | + Store(...Binding) error |
| 127 | + |
| 128 | + // Returns whether the given IDs have all been bound (as Present or Absent). |
| 129 | + Has(...ID) bool |
| 130 | + |
| 131 | + // Get a previously stored binding. If no binding with the given ID has yet been stored, a binding with Status() = Pending is generated. |
| 132 | + Get(ID) Binding |
| 133 | + |
| 134 | + // GetAll returns all stored bindings. This is typically used only for tests. |
| 135 | + GetAll() []Binding |
| 136 | +} |
| 137 | + |
| 138 | +type binder struct { |
| 139 | + // Protects against concurrent access to the map |
| 140 | + sync.RWMutex |
| 141 | + |
| 142 | + bindings map[ID]Binding |
| 143 | +} |
| 144 | + |
| 145 | +func (b *binder) Store(bs ...Binding) error { |
| 146 | + b.Lock() |
| 147 | + defer b.Unlock() |
| 148 | + |
| 149 | + for _, binding := range bs { |
| 150 | + if _, ok := b.bindings[binding.ID()]; ok { |
| 151 | + return wrapStackErrorf("%w: %q", ErrDuplicateBinding, binding.ID()) |
| 152 | + } |
| 153 | + b.bindings[binding.ID()] = binding |
| 154 | + } |
| 155 | + |
| 156 | + return nil |
| 157 | +} |
| 158 | + |
| 159 | +func (b *binder) Has(ids ...ID) bool { |
| 160 | + b.RLock() |
| 161 | + defer b.RUnlock() |
| 162 | + |
| 163 | + for _, id := range ids { |
| 164 | + if _, ok := b.bindings[id]; !ok { |
| 165 | + return false |
| 166 | + } |
| 167 | + } |
| 168 | + return true |
| 169 | +} |
| 170 | + |
| 171 | +func (b *binder) Get(id ID) Binding { |
| 172 | + b.RLock() |
| 173 | + defer b.RUnlock() |
| 174 | + |
| 175 | + if binding, ok := b.bindings[id]; ok { |
| 176 | + return binding |
| 177 | + } |
| 178 | + return bindPending(id) |
| 179 | +} |
| 180 | + |
| 181 | +func (b *binder) GetAll() []Binding { |
| 182 | + b.RLock() |
| 183 | + defer b.RUnlock() |
| 184 | + |
| 185 | + res := make([]Binding, 0, len(b.bindings)) |
| 186 | + for _, binding := range b.bindings { |
| 187 | + res = append(res, binding) |
| 188 | + } |
| 189 | + return res |
| 190 | +} |
| 191 | + |
| 192 | +// NewBinder returns a new binder. |
| 193 | +func NewBinder() Binder { |
| 194 | + return &binder{ |
| 195 | + bindings: map[ID]Binding{}, |
| 196 | + } |
| 197 | +} |
| 198 | + |
| 199 | +// overlayBinder implements Binder to provide an overlay over an existing binder, such that newly |
| 200 | +// stored keys are added to the overlay only, but bindings can still be read from the base. This |
| 201 | +// still does not allow duplicate bindings (attempting to Store() a binding already present in the |
| 202 | +// base will return an error) |
| 203 | +type overlayBinder struct { |
| 204 | + base, overlay Binder |
| 205 | +} |
| 206 | + |
| 207 | +func (ob *overlayBinder) Store(bindings ...Binding) error { |
| 208 | + for _, b := range bindings { |
| 209 | + if ob.base.Has(b.ID()) { |
| 210 | + return wrapStackErrorf("%w: %q", ErrDuplicateBinding, b.ID()) |
| 211 | + } |
| 212 | + } |
| 213 | + return ob.overlay.Store(bindings...) |
| 214 | +} |
| 215 | + |
| 216 | +func (ob *overlayBinder) Has(ids ...ID) bool { |
| 217 | + for _, id := range ids { |
| 218 | + if !ob.overlay.Has(id) && !ob.base.Has(id) { |
| 219 | + return false |
| 220 | + } |
| 221 | + } |
| 222 | + return true |
| 223 | +} |
| 224 | + |
| 225 | +func (ob *overlayBinder) Get(id ID) Binding { |
| 226 | + if b := ob.overlay.Get(id); b.Status() != Pending { |
| 227 | + return b |
| 228 | + } |
| 229 | + return ob.base.Get(id) |
| 230 | +} |
| 231 | + |
| 232 | +func (ob *overlayBinder) GetAll() []Binding { |
| 233 | + return append(ob.base.GetAll(), ob.overlay.GetAll()...) |
| 234 | +} |
| 235 | + |
| 236 | +// NewOverlayBinder creates a new overlay binder. |
| 237 | +func NewOverlayBinder(base, overlay Binder) Binder { |
| 238 | + return &overlayBinder{ |
| 239 | + base: base, |
| 240 | + overlay: overlay, |
| 241 | + } |
| 242 | +} |
| 243 | + |
| 244 | +// graphTaskBinder implements Binder to run a Graph as a task. Any bindings for keys that should be |
| 245 | +// exposed are immediately added to the Binder of the parent graph, so that dependent tasks outside |
| 246 | +// this graph do not have to wait for every task in this graph to complete. |
| 247 | +type graphTaskBinder struct { |
| 248 | + internal, external Binder |
| 249 | + exposeKeys set.Set[ID] |
| 250 | +} |
| 251 | + |
| 252 | +func (gtb *graphTaskBinder) Store(bindings ...Binding) error { |
| 253 | + for _, binding := range bindings { |
| 254 | + if gtb.exposeKeys.Contains(binding.ID()) { |
| 255 | + if err := gtb.external.Store(binding); err != nil { |
| 256 | + return err |
| 257 | + } |
| 258 | + } else { |
| 259 | + if err := gtb.internal.Store(binding); err != nil { |
| 260 | + return err |
| 261 | + } |
| 262 | + } |
| 263 | + } |
| 264 | + return nil |
| 265 | +} |
| 266 | + |
| 267 | +func (gtb *graphTaskBinder) Has(ids ...ID) bool { |
| 268 | + for _, id := range ids { |
| 269 | + if !gtb.internal.Has(id) && !gtb.external.Has(id) { |
| 270 | + return false |
| 271 | + } |
| 272 | + } |
| 273 | + return true |
| 274 | +} |
| 275 | + |
| 276 | +func (gtb *graphTaskBinder) Get(id ID) Binding { |
| 277 | + if ib := gtb.internal.Get(id); ib.Status() != Pending { |
| 278 | + return ib |
| 279 | + } |
| 280 | + return gtb.external.Get(id) |
| 281 | +} |
| 282 | + |
| 283 | +func (gtb *graphTaskBinder) GetAll() []Binding { |
| 284 | + return append(gtb.internal.GetAll(), gtb.external.GetAll()...) |
| 285 | +} |
| 286 | + |
| 287 | +// TestOnlyNewGraphTaskBinder creates a new graph task binder. This is exported for testing, and |
| 288 | +// should not be called in production code. |
| 289 | +func TestOnlyNewGraphTaskBinder(internal, external Binder, exposeKeys set.Set[ID]) Binder { |
| 290 | + return &graphTaskBinder{ |
| 291 | + internal: internal, |
| 292 | + external: external, |
| 293 | + exposeKeys: exposeKeys, |
| 294 | + } |
| 295 | +} |
0 commit comments