|
| 1 | +// Copyright 2025 The Go Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +package unify |
| 6 | + |
| 7 | +import ( |
| 8 | + "fmt" |
| 9 | + "iter" |
| 10 | + "maps" |
| 11 | + "slices" |
| 12 | +) |
| 13 | + |
| 14 | +type Closure struct { |
| 15 | + val *Value |
| 16 | + env nonDetEnv |
| 17 | +} |
| 18 | + |
| 19 | +func NewSum(vs ...*Value) Closure { |
| 20 | + id := &ident{name: "sum"} |
| 21 | + return Closure{NewValue(Var{id}), topEnv.bind(id, vs...)} |
| 22 | +} |
| 23 | + |
| 24 | +// IsBottom returns whether c consists of no values. |
| 25 | +func (c Closure) IsBottom() bool { |
| 26 | + return c.val.Domain == nil |
| 27 | +} |
| 28 | + |
| 29 | +// Summands returns the top-level Values of c. This assumes the top-level of c |
| 30 | +// was constructed as a sum, and is mostly useful for debugging. |
| 31 | +func (c Closure) Summands() iter.Seq[*Value] { |
| 32 | + if v, ok := c.val.Domain.(Var); ok { |
| 33 | + parts := c.env.partitionBy(v.id) |
| 34 | + return func(yield func(*Value) bool) { |
| 35 | + for _, part := range parts { |
| 36 | + if !yield(part.value) { |
| 37 | + return |
| 38 | + } |
| 39 | + } |
| 40 | + } |
| 41 | + } |
| 42 | + return func(yield func(*Value) bool) { |
| 43 | + yield(c.val) |
| 44 | + } |
| 45 | +} |
| 46 | + |
| 47 | +// All enumerates all possible concrete values of c by substituting variables |
| 48 | +// from the environment. |
| 49 | +// |
| 50 | +// E.g., enumerating this Value |
| 51 | +// |
| 52 | +// a: !sum [1, 2] |
| 53 | +// b: !sum [3, 4] |
| 54 | +// |
| 55 | +// results in |
| 56 | +// |
| 57 | +// - {a: 1, b: 3} |
| 58 | +// - {a: 1, b: 4} |
| 59 | +// - {a: 2, b: 3} |
| 60 | +// - {a: 2, b: 4} |
| 61 | +func (c Closure) All() iter.Seq[*Value] { |
| 62 | + // In order to enumerate all concrete values under all possible variable |
| 63 | + // bindings, we use a "non-deterministic continuation passing style" to |
| 64 | + // implement this. We use CPS to traverse the Value tree, threading the |
| 65 | + // (possibly narrowing) environment through that CPS following an Euler |
| 66 | + // tour. Where the environment permits multiple choices, we invoke the same |
| 67 | + // continuation for each choice. Similar to a yield function, the |
| 68 | + // continuation can return false to stop the non-deterministic walk. |
| 69 | + return func(yield func(*Value) bool) { |
| 70 | + c.val.all1(c.env, func(v *Value, e nonDetEnv) bool { |
| 71 | + return yield(v) |
| 72 | + }) |
| 73 | + } |
| 74 | +} |
| 75 | + |
| 76 | +func (v *Value) all1(e nonDetEnv, cont func(*Value, nonDetEnv) bool) bool { |
| 77 | + switch d := v.Domain.(type) { |
| 78 | + default: |
| 79 | + panic(fmt.Sprintf("unknown domain type %T", d)) |
| 80 | + |
| 81 | + case nil: |
| 82 | + return true |
| 83 | + |
| 84 | + case Top, String: |
| 85 | + return cont(v, e) |
| 86 | + |
| 87 | + case Def: |
| 88 | + fields := d.keys() |
| 89 | + // We can reuse this parts slice because we're doing a DFS through the |
| 90 | + // state space. (Otherwise, we'd have to do some messy threading of an |
| 91 | + // immutable slice-like value through allElt.) |
| 92 | + parts := make(map[string]*Value, len(fields)) |
| 93 | + |
| 94 | + // TODO: If there are no Vars or Sums under this Def, then nothing can |
| 95 | + // change the Value or env, so we could just cont(v, e). |
| 96 | + var allElt func(elt int, e nonDetEnv) bool |
| 97 | + allElt = func(elt int, e nonDetEnv) bool { |
| 98 | + if elt == len(fields) { |
| 99 | + // Build a new Def from the concrete parts. Clone parts because |
| 100 | + // we may reuse it on other non-deterministic branches. |
| 101 | + nVal := newValueFrom(Def{maps.Clone(parts)}, v) |
| 102 | + return cont(nVal, e) |
| 103 | + } |
| 104 | + |
| 105 | + return d.fields[fields[elt]].all1(e, func(v *Value, e nonDetEnv) bool { |
| 106 | + parts[fields[elt]] = v |
| 107 | + return allElt(elt+1, e) |
| 108 | + }) |
| 109 | + } |
| 110 | + return allElt(0, e) |
| 111 | + |
| 112 | + case Tuple: |
| 113 | + // Essentially the same as Def. |
| 114 | + if d.repeat != nil { |
| 115 | + // There's nothing we can do with this. |
| 116 | + return cont(v, e) |
| 117 | + } |
| 118 | + parts := make([]*Value, len(d.vs)) |
| 119 | + var allElt func(elt int, e nonDetEnv) bool |
| 120 | + allElt = func(elt int, e nonDetEnv) bool { |
| 121 | + if elt == len(d.vs) { |
| 122 | + // Build a new tuple from the concrete parts. Clone parts because |
| 123 | + // we may reuse it on other non-deterministic branches. |
| 124 | + nVal := newValueFrom(Tuple{vs: slices.Clone(parts)}, v) |
| 125 | + return cont(nVal, e) |
| 126 | + } |
| 127 | + |
| 128 | + return d.vs[elt].all1(e, func(v *Value, e nonDetEnv) bool { |
| 129 | + parts[elt] = v |
| 130 | + return allElt(elt+1, e) |
| 131 | + }) |
| 132 | + } |
| 133 | + return allElt(0, e) |
| 134 | + |
| 135 | + case Var: |
| 136 | + // Go each way this variable can be bound. |
| 137 | + for _, ePart := range e.partitionBy(d.id) { |
| 138 | + // d.id is no longer bound in this environment partition. We'll may |
| 139 | + // need it later in the Euler tour, so bind it back to this single |
| 140 | + // value. |
| 141 | + env := ePart.env.bind(d.id, ePart.value) |
| 142 | + if !ePart.value.all1(env, cont) { |
| 143 | + return false |
| 144 | + } |
| 145 | + } |
| 146 | + return true |
| 147 | + } |
| 148 | +} |
0 commit comments