From d062a5a127368c8833eb18b963375adc229ee448 Mon Sep 17 00:00:00 2001
From: Liang Yuxuan
Date: Tue, 17 Dec 2024 15:00:01 +0800
Subject: [PATCH] Fix issue https://github.com/yuin/gopher-lua/issues/214:
 replace Go's native map with a swiss map for table hash parts

---
 bits.go       |  46 ++++++
 bits_amd64.go |  37 +++++
 bits_amd64.s  |  24 +++
 swiss_map.go  | 378 ++++++++++++++++++++++++++++++++++++++++++++++++++
 table.go      | 163 ++++++++++++---------
 table_test.go |   7 +-
 value.go      |   6 +-
 7 files changed, 591 insertions(+), 70 deletions(-)
 create mode 100644 bits.go
 create mode 100644 bits_amd64.go
 create mode 100644 bits_amd64.s
 create mode 100644 swiss_map.go

diff --git a/bits.go b/bits.go
new file mode 100644
index 00000000..8bc35281
--- /dev/null
+++ b/bits.go
@@ -0,0 +1,46 @@
+// This is a port of the bits.go file from the dolthub/swiss repository.
+// The original source code is licensed under the Apache License, Version 2.0.
+// The original source code can be found at:
+// https://github.com/dolthub/swiss
+
+//go:build !amd64 || nosimd
+
+package lua
+
+import (
+	"math/bits"
+	"unsafe"
+)
+
+const (
+	groupSize       = 8
+	maxAvgGroupLoad = 7
+
+	loBits uint64 = 0x0101010101010101
+	hiBits uint64 = 0x8080808080808080
+)
+
+type bitset uint64
+
+func metaMatchH2(m *metadata, h h2) bitset {
+	// https://graphics.stanford.edu/~seander/bithacks.html#ValueInWord
+	return hasZeroByte(castUint64(m) ^ (loBits * uint64(h)))
+}
+
+func metaMatchEmpty(m *metadata) bitset {
+	return hasZeroByte(castUint64(m) ^ hiBits)
+}
+
+func nextMatch(b *bitset) uint32 {
+	s := uint32(bits.TrailingZeros64(uint64(*b)))
+	*b &= ^(1 << s) // clear bit |s|
+	return s >> 3   // div by 8
+}
+
+func hasZeroByte(x uint64) bitset {
+	return bitset(((x - loBits) & ^(x)) & hiBits)
+}
+
+func castUint64(m *metadata) uint64 {
+	return *(*uint64)((unsafe.Pointer)(m))
+}
diff --git a/bits_amd64.go b/bits_amd64.go
new file mode 100644
index 00000000..53891cc7
--- /dev/null
+++ b/bits_amd64.go
@@ -0,0 +1,37 @@
+// Code generated by command: go run asm.go -out match.s -stubs match_amd64.go. DO NOT EDIT.
+
+//go:build amd64 && !nosimd
+
+package lua
+
+import (
+	"math/bits"
+	_ "unsafe"
+)
+
+const (
+	groupSize       = 16
+	maxAvgGroupLoad = 14
+)
+
+type bitset uint16
+
+func metaMatchH2(m *metadata, h h2) bitset {
+	b := matchMetadata((*[16]int8)(m), int8(h))
+	return bitset(b)
+}
+
+func metaMatchEmpty(m *metadata) bitset {
+	b := matchMetadata((*[16]int8)(m), empty)
+	return bitset(b)
+}
+
+func nextMatch(b *bitset) (s uint32) {
+	s = uint32(bits.TrailingZeros16(uint16(*b)))
+	*b &= ^(1 << s) // clear bit |s|
+	return
+}
+
+// matchMetadata performs a 16-way probe of |metadata| using SSE instructions.
+// NB: |metadata| must be an aligned pointer.
+func matchMetadata(metadata *[16]int8, hash int8) uint16
diff --git a/bits_amd64.s b/bits_amd64.s
new file mode 100644
index 00000000..c8a67866
--- /dev/null
+++ b/bits_amd64.s
@@ -0,0 +1,24 @@
+// Code generated by command: go run asm.go -out match.s -stubs match_amd64.go. DO NOT EDIT.
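+//
+// matchMetadata broadcasts |hash| into every byte of an XMM register
+// (MOVD, then PSHUFB with a zeroed shuffle mask), compares it byte-wise
+// against the group's 16 metadata bytes (PCMPEQB), and extracts the
+// per-byte results as a 16-bit mask (PMOVMSKB).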
+
+//go:build amd64 && !nosimd
+
+#include "textflag.h"
+
+// func matchMetadata(metadata *[16]int8, hash int8) uint16
+// Requires: SSE2, SSSE3
+TEXT ·matchMetadata(SB), NOSPLIT, $0-18
+	MOVQ     metadata+0(FP), AX
+	MOVBLSX  hash+8(FP), CX
+	MOVD     CX, X0
+	PXOR     X1, X1
+	PSHUFB   X1, X0
+	MOVOU    (AX), X1
+	PCMPEQB  X1, X0
+	PMOVMSKB X0, AX
+	MOVW     AX, ret+16(FP)
+	RET
diff --git a/swiss_map.go b/swiss_map.go
new file mode 100644
index 00000000..ec01d30b
--- /dev/null
+++ b/swiss_map.go
@@ -0,0 +1,378 @@
+// This is a port of the swiss_map.go file from the dolthub/swiss repository.
+// The original source code is licensed under the Apache License, Version 2.0.
+// The original source code can be found at:
+// https://github.com/dolthub/swiss
+// Generic functions and types were removed to adapt the port to this
+// project, along with functions and types that are not used here.
+
+package lua
+
+import (
+	"math/rand"
+	"unsafe"
+)
+
+type keyKind int
+
+// Currently only two key kinds are supported: string keys and LValue keys.
+const (
+	keyKindStr keyKind = iota
+	keyKindIntr
+)
+
+type hashFunc func(unsafe.Pointer, uintptr) uintptr
+
+//go:linkname strhash runtime.strhash
+func strhash(p unsafe.Pointer, s uintptr) uintptr
+
+//go:linkname interhash runtime.interhash
+func interhash(p unsafe.Pointer, s uintptr) uintptr
+
+// baseMap is an open-addressing hash map
+// based on Abseil's flat_hash_map.
+type baseMap struct {
+	kind     keyKind
+	ctrl     []metadata
+	groups   []group
+	randSeed uint64
+	hashFn   hashFunc
+	resident uint32
+	dead     uint32
+	limit    uint32
+}
+
+// metadata is the h2 metadata array for a group.
+// Find operations first probe the control bytes
+// to filter candidates before matching keys.
+type metadata [groupSize]int8
+
+type groupKey [2]uintptr
+
+// group is a group of groupSize key-value slots.
+type group struct {
+	keys   [groupSize]groupKey
+	values [groupSize]LValue
+}
+
+const (
+	h1Mask    uint64 = 0xffff_ffff_ffff_ff80
+	h2Mask    uint64 = 0x0000_0000_0000_007f
+	empty     int8   = -128 // 0b1000_0000
+	tombstone int8   = -2   // 0b1111_1110
+)
+
+// h1 is a 57 bit hash prefix
+type h1 uint64
+
+// h2 is a 7 bit hash suffix
+type h2 int8
+
+// newMap constructs a baseMap sized for |sz| elements.
+func newMap(kind keyKind, sz uint32) (m *baseMap) {
+	var hashFn hashFunc
+	if kind == keyKindStr {
+		hashFn = strhash
+	} else {
+		hashFn = interhash
+	}
+	groups := numGroups(sz)
+	m = &baseMap{
+		kind:     kind,
+		ctrl:     make([]metadata, groups),
+		groups:   make([]group, groups),
+		randSeed: rand.Uint64(),
+		hashFn:   hashFn,
+		limit:    groups * maxAvgGroupLoad,
+	}
+	for i := range m.ctrl {
+		m.ctrl[i] = newEmptyMetadata()
+	}
+	return
+}
+
+func (m *baseMap) hash(key unsafe.Pointer) uint64 {
+	return uint64(m.hashFn(key, uintptr(m.randSeed)))
+}
+
+func (m *baseMap) keyEqualGroupKey(key unsafe.Pointer, k *groupKey) bool {
+	if m.kind == keyKindStr {
+		keyStrPtr := (*string)(key)
+		kStrPtr := (*string)(unsafe.Pointer(k))
+		return *keyStrPtr == *kStrPtr
+	}
+	keyIntPtr := (*LValue)(key)
+	kIntPtr := (*LValue)(unsafe.Pointer(k))
+	return *keyIntPtr == *kIntPtr
+}
+
+// get returns the |value| mapped by |key| if one exists.
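+// Lookup splits the hash: the 57-bit h1 picks the starting probe group
+// and the 7-bit h2 is matched against the group's metadata bytes, so
+// full key comparisons only run on candidate slots.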
+func (m *baseMap) get(key unsafe.Pointer) (value LValue, ok bool) {
+	if m == nil {
+		return
+	}
+	hi, lo := splitHash(m.hash(key))
+	g := probeStart(hi, len(m.groups))
+	for { // inlined find loop
+		matches := metaMatchH2(&m.ctrl[g], lo)
+		for matches != 0 {
+			s := nextMatch(&matches)
+			if m.keyEqualGroupKey(key, &m.groups[g].keys[s]) {
+				value, ok = m.groups[g].values[s], true
+				return
+			}
+		}
+		// |key| is not in group |g|,
+		// stop probing if we see an empty slot
+		matches = metaMatchEmpty(&m.ctrl[g])
+		if matches != 0 {
+			ok = false
+			return
+		}
+		g += 1 // linear probing
+		if g >= uint32(len(m.groups)) {
+			g = 0
+		}
+	}
+}
+
+// first returns the first key-value pair in the Map.
+func (m *baseMap) first() (key unsafe.Pointer, value LValue, ok bool) {
+	for g, c := range m.ctrl {
+		for s := range c {
+			if c[s] != empty && c[s] != tombstone {
+				return (unsafe.Pointer)(&m.groups[g].keys[s]), m.groups[g].values[s], true
+			}
+		}
+	}
+	return
+}
+
+// put attempts to insert |key| and |value|.
+func (m *baseMap) put(key unsafe.Pointer, value LValue) {
+	if m.resident >= m.limit {
+		m.rehash(m.nextSize())
+	}
+
+	hi, lo := splitHash(m.hash(key))
+	g := probeStart(hi, len(m.groups))
+	for { // inlined find loop
+		matches := metaMatchH2(&m.ctrl[g], lo)
+		for matches != 0 {
+			s := nextMatch(&matches)
+			if m.keyEqualGroupKey(key, &m.groups[g].keys[s]) { // update
+				m.groups[g].keys[s] = *(*groupKey)(key)
+				m.groups[g].values[s] = value
+				return
+			}
+		}
+		// |key| is not in group |g|,
+		// stop probing if we see an empty slot
+		matches = metaMatchEmpty(&m.ctrl[g])
+		if matches != 0 { // insert
+			s := nextMatch(&matches)
+			m.groups[g].keys[s] = *(*groupKey)(key)
+			m.groups[g].values[s] = value
+			m.ctrl[g][s] = int8(lo)
+			m.resident++
+			return
+		}
+		g += 1 // linear probing
+		if g >= uint32(len(m.groups)) {
+			g = 0
+		}
+	}
+}
+
+// delete attempts to remove |key|, returns true if successful.
+func (m *baseMap) delete(key unsafe.Pointer) (ok bool) {
+	hi, lo := splitHash(m.hash(key))
+	g := probeStart(hi, len(m.groups))
+	for {
+		matches := metaMatchH2(&m.ctrl[g], lo)
+		for matches != 0 {
+			s := nextMatch(&matches)
+			if m.keyEqualGroupKey(key, &m.groups[g].keys[s]) {
+				ok = true
+				// optimization: if |m.ctrl[g]| contains any empty
+				// metadata bytes, we can physically delete |key|
+				// rather than placing a tombstone.
+				// The observation is that any probes into group |g|
+				// would already be terminated by the existing empty
+				// slot, and therefore reclaiming slot |s| will not
+				// cause premature termination of probes into |g|.
+				if metaMatchEmpty(&m.ctrl[g]) != 0 {
+					m.ctrl[g][s] = empty
+					m.resident--
+				} else {
+					m.ctrl[g][s] = tombstone
+					m.dead++
+				}
+				var k groupKey
+				var v LValue
+				m.groups[g].keys[s] = k
+				m.groups[g].values[s] = v
+				return
+			}
+		}
+		// |key| is not in group |g|,
+		// stop probing if we see an empty slot
+		matches = metaMatchEmpty(&m.ctrl[g])
+		if matches != 0 { // |key| absent
+			ok = false
+			return
+		}
+		g += 1 // linear probing
+		if g >= uint32(len(m.groups)) {
+			g = 0
+		}
+	}
+}
+
+// iter iterates the elements of the Map, passing them to the callback.
+// It guarantees that any key in the Map will be visited only once, and
+// for un-mutated Maps, every key will be visited once. If the Map is
+// mutated during iteration, mutations will be reflected on return from
+// iter, but the set of keys visited by iter is non-deterministic.
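+// Iteration starts at a randomly chosen group, so callers must not rely
+// on a stable order between calls (Lua's next() guarantees no order).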
+func (m *baseMap) iter(cb func(k unsafe.Pointer, v LValue) (stop bool)) {
+	// take a consistent view of the table in case
+	// we rehash during iteration
+	ctrl, groups := m.ctrl, m.groups
+	// pick a random starting group
+	g := rand.Intn(len(groups))
+	for n := 0; n < len(groups); n++ {
+		for s, c := range ctrl[g] {
+			if c == empty || c == tombstone {
+				continue
+			}
+			k, v := groups[g].keys[s], groups[g].values[s]
+			if stop := cb((unsafe.Pointer)(&k), v); stop {
+				return
+			}
+		}
+		g++
+		if g >= len(groups) {
+			g = 0
+		}
+	}
+}
+
+// count returns the number of elements in the Map.
+func (m *baseMap) count() int {
+	return int(m.resident - m.dead)
+}
+
+// // capacity returns the number of additional elements
+// // that can be added to the Map before resizing.
+// func (m *baseMap) capacity() int {
+// 	return int(m.limit - m.resident)
+// }
+
+// findNext returns the next key-value pair in the Map after |key|.
+func (m *baseMap) findNext(key unsafe.Pointer) (retKey unsafe.Pointer, retValue LValue, ok bool) {
+	hi, lo := splitHash(m.hash(key))
+	startG := probeStart(hi, len(m.groups))
+	g := startG
+	for { // inlined find loop
+		matches := metaMatchH2(&m.ctrl[g], lo)
+		for matches != 0 {
+			s := nextMatch(&matches)
+			if m.keyEqualGroupKey(key, &m.groups[g].keys[s]) {
+				// move to the next key
+				for {
+					s++
+					if s >= groupSize {
+						s = 0
+						g++
+						if g >= uint32(len(m.groups)) {
+							// end of the table
+							ok = false
+							return
+						}
+					}
+					if m.ctrl[g][s] != empty && m.ctrl[g][s] != tombstone {
+						retKey = (unsafe.Pointer)(&m.groups[g].keys[s])
+						retValue = m.groups[g].values[s]
+						ok = true
+						return
+					}
+				}
+			}
+		}
+		// |key| is not in group |g|,
+		// stop probing if we see an empty slot
+		matches = metaMatchEmpty(&m.ctrl[g])
+		if matches != 0 {
+			ok = false
+			return
+		}
+		g += 1 // linear probing
+		if g >= uint32(len(m.groups)) {
+			g = 0
+		}
+	}
+}
+
+func (m *baseMap) nextSize() (n uint32) {
+	n = uint32(len(m.groups)) * 2
+	if m.dead >= (m.resident / 2) {
+		n = uint32(len(m.groups))
+	}
+	return
+}
+
+func (m *baseMap) rehash(n uint32) {
+	groups, ctrl := m.groups, m.ctrl
+	m.groups = make([]group, n)
+	m.ctrl = make([]metadata, n)
+	for i := range m.ctrl {
+		m.ctrl[i] = newEmptyMetadata()
+	}
+	m.randSeed = rand.Uint64()
+	m.limit = n * maxAvgGroupLoad
+	m.resident, m.dead = 0, 0
+	for g := range ctrl {
+		for s := range ctrl[g] {
+			c := ctrl[g][s]
+			if c == empty || c == tombstone {
+				continue
+			}
+			m.put((unsafe.Pointer)(&groups[g].keys[s]), groups[g].values[s])
+		}
+	}
+}
+
+// numGroups returns the minimum number of groups needed to store |n| elems.
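+// For example, on amd64 (groupSize = 16, maxAvgGroupLoad = 14),
+// numGroups(100) = ceil(100/14) = 8 groups.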
+func numGroups(n uint32) (groups uint32) {
+	groups = (n + maxAvgGroupLoad - 1) / maxAvgGroupLoad
+	if groups == 0 {
+		groups = 1
+	}
+	return
+}
+
+func newEmptyMetadata() (meta metadata) {
+	for i := range meta {
+		meta[i] = empty
+	}
+	return
+}
+
+func splitHash(h uint64) (h1, h2) {
+	return h1((h & h1Mask) >> 7), h2(h & h2Mask)
+}
+
+func probeStart(hi h1, groups int) uint32 {
+	return fastModN(uint32(hi), uint32(groups))
+}
+
+// lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
+func fastModN(x, n uint32) uint32 {
+	return uint32((uint64(x) * uint64(n)) >> 32)
+}
diff --git a/table.go b/table.go
index ddf14dd8..b300457a 100644
--- a/table.go
+++ b/table.go
@@ -1,5 +1,7 @@
 package lua
 
+import "unsafe"
+
 const defaultArrayCap = 32
 const defaultHashCap = 32
 
@@ -41,7 +43,7 @@
 		tb.array = make([]LValue, 0, acap)
 	}
 	if hcap != 0 {
-		tb.strdict = make(map[string]LValue, hcap)
+		tb.strdict = newMap(keyKindStr, uint32(hcap))
 	}
 	return tb
 }
@@ -157,10 +159,26 @@
 	case index == alen:
 		tb.array = append(tb.array, value)
 	case index > alen:
-		for i := 0; i < (index - alen); i++ {
-			tb.array = append(tb.array, LNil)
+		if int(v) > cap(tb.array) {
+			// Allocate the needed capacity up front: growing through
+			// repeated appends after a large index jump causes heavy GC churn.
+			capSize := int(float64(v) * 1.25)
+			if capSize > MaxArrayIndex {
+				capSize = MaxArrayIndex
+			}
+			temp := make([]LValue, int(v), capSize)
+			copy(temp, tb.array)
+			for i := alen; i < index; i++ {
+				temp[i] = LNil
+			}
+			temp[index] = value
+			tb.array = temp
+		} else {
+			for i := 0; i < (index - alen); i++ {
+				tb.array = append(tb.array, LNil)
+			}
+			tb.array = append(tb.array, value)
 		}
-		tb.array = append(tb.array, value)
 	case index < alen:
 		tb.array[index] = value
 	}
@@ -189,10 +207,26 @@
 	case index == alen:
 		tb.array = append(tb.array, value)
 	case index > alen:
-		for i := 0; i < (index - alen); i++ {
-			tb.array = append(tb.array, LNil)
+		if key > cap(tb.array) {
+			// Allocate the needed capacity up front: growing through
+			// repeated appends after a large index jump causes heavy GC churn.
+			capSize := int(float64(key) * 1.25)
+			if capSize > MaxArrayIndex {
+				capSize = MaxArrayIndex
+			}
+			temp := make([]LValue, key, capSize)
+			copy(temp, tb.array)
+			for i := alen; i < index; i++ {
+				temp[i] = LNil
+			}
+			temp[index] = value
+			tb.array = temp
+		} else {
+			for i := 0; i < (index - alen); i++ {
+				tb.array = append(tb.array, LNil)
+			}
+			tb.array = append(tb.array, value)
 		}
-		tb.array = append(tb.array, value)
 	case index < alen:
 		tb.array[index] = value
 	}
@@ -201,23 +235,16 @@
 // RawSetString sets a given LValue to a given string index without the __newindex metamethod.
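+// With the swiss-map hash part, setting a key to LNil removes the key
+// itself, so dead string keys are no longer retained (this replaces the
+// old keys/k2i bookkeeping and its TODO below).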
 func (tb *LTable) RawSetString(key string, value LValue) {
 	if tb.strdict == nil {
-		tb.strdict = make(map[string]LValue, defaultHashCap)
-	}
-	if tb.keys == nil {
-		tb.keys = []LValue{}
-		tb.k2i = map[LValue]int{}
+		tb.strdict = newMap(keyKindStr, defaultHashCap)
 	}
 
 	if value == LNil {
-		// TODO tb.keys and tb.k2i should also be removed
-		delete(tb.strdict, key)
+		tb.strdict.delete(unsafe.Pointer(&key))
 	} else {
-		tb.strdict[key] = value
-		lkey := LString(key)
-		if _, ok := tb.k2i[lkey]; !ok {
-			tb.k2i[lkey] = len(tb.keys)
-			tb.keys = append(tb.keys, lkey)
-		}
+		tb.strdict.put(unsafe.Pointer(&key), value)
 	}
 }
@@ -228,22 +255,13 @@
 		return
 	}
 	if tb.dict == nil {
-		tb.dict = make(map[LValue]LValue, len(tb.strdict))
-	}
-	if tb.keys == nil {
-		tb.keys = []LValue{}
-		tb.k2i = map[LValue]int{}
+		tb.dict = newMap(keyKindIntr, defaultHashCap)
 	}
 
 	if value == LNil {
-		// TODO tb.keys and tb.k2i should also be removed
-		delete(tb.dict, key)
+		tb.dict.delete(unsafe.Pointer(&key))
 	} else {
-		tb.dict[key] = value
-		if _, ok := tb.k2i[key]; !ok {
-			tb.k2i[key] = len(tb.keys)
-			tb.keys = append(tb.keys, key)
-		}
+		tb.dict.put(unsafe.Pointer(&key), value)
 	}
 }
@@ -265,7 +283,7 @@
 		if tb.strdict == nil {
 			return LNil
 		}
-		if ret, ok := tb.strdict[string(v)]; ok {
+		if ret, ok := tb.strdict.get((unsafe.Pointer)(&v)); ok {
 			return ret
 		}
 		return LNil
@@ -273,7 +291,7 @@
 		if tb.dict == nil {
 			return LNil
 		}
-		if v, ok := tb.dict[key]; ok {
+		if v, ok := tb.dict.get((unsafe.Pointer)(&key)); ok {
 			return v
 		}
 		return LNil
@@ -297,7 +315,7 @@
 		if tb.strdict == nil {
 			return LNil
 		}
-		if v, vok := tb.strdict[string(s)]; vok {
+		if v, vok := tb.strdict.get((unsafe.Pointer)(&s)); vok {
 			return v
 		}
 		return LNil
@@ -305,7 +323,7 @@
 		if tb.dict == nil {
 			return LNil
 		}
-		if v, ok := tb.dict[key]; ok {
+		if v, ok := tb.dict.get((unsafe.Pointer)(&key)); ok {
 			return v
 		}
 		return LNil
@@ -316,7 +334,7 @@
 	if tb.strdict == nil {
 		return LNil
 	}
-	if v, vok := tb.strdict[string(key)]; vok {
+	if v, vok := tb.strdict.get((unsafe.Pointer)(&key)); vok {
 		return v
 	}
 	return LNil
@@ -332,18 +350,20 @@
 		}
 	}
 	if tb.strdict != nil {
-		for k, v := range tb.strdict {
+		tb.strdict.iter(func(k unsafe.Pointer, v LValue) bool {
 			if v != LNil {
-				cb(LString(k), v)
+				cb(LString(*(*string)(k)), v)
 			}
-		}
+			return false
+		})
 	}
 	if tb.dict != nil {
-		for k, v := range tb.dict {
+		tb.dict.iter(func(k unsafe.Pointer, v LValue) bool {
 			if v != LNil {
-				cb(k, v)
+				cb(*(*LValue)(k), v)
 			}
-		}
+			return false
+		})
 	}
 }
@@ -351,36 +371,51 @@
 func (tb *LTable) Next(key LValue) (LValue, LValue) {
 	init := false
 	if key == LNil {
-		key = LNumber(0)
 		init = true
 	}
-	if init || key != LNumber(0) {
-		if kv, ok := key.(LNumber); ok && isInteger(kv) && int(kv) >= 0 && kv < LNumber(MaxArrayIndex) {
-			index := int(kv)
-			if tb.array != nil {
-				for ; index < len(tb.array); index++ {
-					if v := tb.array[index]; v != LNil {
-						return LNumber(index + 1), v
-					}
+	if kv, ok := key.(LNumber); (ok && isInteger(kv) && int(kv) > 0 && kv < LNumber(MaxArrayIndex)) || init {
+		index := int(kv)
+		if tb.array != nil {
+			for ; index < len(tb.array); index++ {
+				if v := tb.array[index]; v != LNil {
+					return LNumber(index + 1), v
 				}
 			}
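+			// every array slot at or after |index| was LNil;
+			// fall through to the hash parts below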
-			if tb.array == nil || index == len(tb.array) {
-				if (tb.dict == nil || len(tb.dict) == 0) && (tb.strdict == nil || len(tb.strdict) == 0) {
-					return LNil, LNil
-				}
-				key = tb.keys[0]
-				if v := tb.RawGetH(key); v != LNil {
-					return key, v
-				}
+		}
+
+		if tb.strdict != nil && tb.strdict.count() > 0 {
+			key, val, _ := tb.strdict.first()
+			return LString(*(*string)(key)), val
+		} else if tb.dict != nil && tb.dict.count() > 0 {
+			key, val, _ := tb.dict.first()
+			return *(*LValue)(key), val
+		} else {
+			return LNil, LNil
+		}
+	}
+
+	if kstr, ok := key.(LString); ok {
+		if tb.strdict != nil {
+			nextKey, nextVal, ok2 := tb.strdict.findNext((unsafe.Pointer)(&kstr))
+			if ok2 {
+				return LString(*(*string)(nextKey)), nextVal
 			}
 		}
+		if tb.dict != nil && tb.dict.count() > 0 {
+			key, val, _ := tb.dict.first()
+			return *(*LValue)(key), val
+		} else {
+			return LNil, LNil
+		}
 	}
-	for i := tb.k2i[key] + 1; i < len(tb.keys); i++ {
-		key := tb.keys[i]
-		if v := tb.RawGetH(key); v != LNil {
-			return key, v
+	if tb.dict != nil {
+		nextKey, nextVal, ok := tb.dict.findNext((unsafe.Pointer)(&key))
+		if ok {
+			return *(*LValue)(nextKey), nextVal
 		}
 	}
 	return LNil, LNil
diff --git a/table_test.go b/table_test.go
index 6acbbb2c..8e8db793 100644
--- a/table_test.go
+++ b/table_test.go
@@ -2,6 +2,7 @@ package lua
 
 import (
 	"testing"
+	"unsafe"
 )
 
 func TestTableNewLTable(t *testing.T) {
@@ -156,12 +157,14 @@
 	tbl := newLTable(0, 0)
 	tbl.RawSetH(LString("key"), LTrue)
 	tbl.RawSetH(LString("key"), LNil)
-	_, found := tbl.dict[LString("key")]
+	var val LValue = LString("key")
+	_, found := tbl.dict.get(unsafe.Pointer(&val))
 	errorIfNotEqual(t, false, found)
 
 	tbl.RawSetH(LTrue, LTrue)
 	tbl.RawSetH(LTrue, LNil)
-	_, foundb := tbl.dict[LTrue]
+	val = LTrue
+	_, foundb := tbl.dict.get(unsafe.Pointer(&val))
 	errorIfNotEqual(t, false, foundb)
 }
diff --git a/value.go b/value.go
index 4156e9d5..066e8903 100644
--- a/value.go
+++ b/value.go
@@ -144,10 +144,8 @@
 type LTable struct {
 	Metatable LValue
 	array     []LValue
-	dict      map[LValue]LValue
-	strdict   map[string]LValue
-	keys      []LValue
-	k2i       map[LValue]int
+	dict      *baseMap
+	strdict   *baseMap
 }
 
 func (tb *LTable) String() string { return fmt.Sprintf("table: %p", tb) }
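
A minimal end-to-end check of the new hash parts. This is a sketch, not part
of the patch: the test name and placement are hypothetical, but it only uses
identifiers that already appear in the diff above (newLTable, RawSetString,
RawGetString, Next, LString, LNumber, LNil).

package lua

import "testing"

func TestTableSwissMapRoundTrip(t *testing.T) {
	tbl := newLTable(0, 0)
	tbl.RawSetString("alpha", LNumber(1))
	tbl.RawSetString("beta", LNumber(2))
	if v := tbl.RawGetString("alpha"); v != LNumber(1) {
		t.Fatalf("want 1, got %v", v)
	}
	// assigning LNil deletes the key in the swiss map (empty slot or
	// tombstone), instead of leaving a dead keys/k2i entry behind
	tbl.RawSetString("alpha", LNil)
	if v := tbl.RawGetString("alpha"); v != LNil {
		t.Fatalf("want nil after delete, got %v", v)
	}
	// Next must enumerate each live key exactly once; order is unspecified
	seen := 0
	for k, _ := tbl.Next(LNil); k != LNil; k, _ = tbl.Next(k) {
		seen++
	}
	if seen != 1 {
		t.Fatalf("want 1 live key, got %d", seen)
	}
}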