-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtape.go
162 lines (142 loc) · 3.88 KB
/
tape.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
package router
import (
"encoding/json"
"errors"
"reflect"
"github.com/Postcord/rest"
"github.com/stretchr/testify/require"
)
// tapeItem is one recorded call on a tape: the name of the function that was
// called and its JSON-encoded inputs, followed by its JSON-encoded results and
// the error it returned (if any). encoding/json emits fields in declaration
// order, so the field order here is the key order of recorded tape entries.
type tapeItem struct {
	// Input
	// FuncName is the recorded function's name; match fails the test when a
	// replayed call uses a different name.
	FuncName string `json:"func_name"`
	// Params holds one JSON value per input. For variadic calls, write stores
	// each element of the variadic slice as its own entry after the fixed
	// inputs.
	Params []json.RawMessage `json:"params"`
	// Output
	// Results holds one JSON value per non-error result; end fills it and
	// match decodes each entry into the caller's corresponding output pointer.
	Results []json.RawMessage `json:"results"`
	// GenericError is the message of a returned error that was not a
	// *rest.ErrorREST; omitted from JSON when empty.
	GenericError string `json:"generic_error,omitempty"`
	// RESTError is the returned error when it was a *rest.ErrorREST; omitted
	// from JSON when nil.
	RESTError *rest.ErrorREST `json:"rest_error,omitempty"`
}
// mustMarshal JSON-encodes item and reports any encoding error through
// require.NoError on t.
//
// When indent is true the encoding comes from json.MarshalIndent with an
// empty prefix and " " as the per-level indent; otherwise it is the compact
// output of json.Marshal. If encoding fails and t does not stop execution,
// the (nil) encoding is still returned.
func mustMarshal(t TestingT, indent bool, item any) []byte {
	encode := func() ([]byte, error) {
		if indent {
			return json.MarshalIndent(item, "", " ")
		}
		return json.Marshal(item)
	}

	encoded, err := encode()
	require.NoError(t, err)
	return encoded
}
// match checks that a live call lines up with this recorded tape item and, if
// it does, writes the recorded outputs back through the caller's pointers.
//
// Parameters:
//   - t: receives a fatal/require failure when the call does not match.
//   - funcName: the function actually being called; it must equal FuncName.
//   - isVard: whether the called function is variadic. When true, items[inCount-1]
//     is the variadic slice, and its elements are compared one by one with the
//     trailing entries of Params (write records variadic calls in that
//     expanded form).
//   - inCount: how many leading entries of items are inputs.
//   - items: the inputs, then one pointer per recorded result (each decoded
//     with json.Unmarshal), optionally followed by a *error that receives the
//     recorded error.
//
// Each Fatalf is followed by a return so that a TestingT whose Fatalf does not
// stop execution (as in unit tests) does not run on into out-of-range indexing.
func (i *tapeItem) match(t TestingT, funcName string, isVard bool, inCount int, items ...any) {
	// Check the right function is called.
	if funcName != i.FuncName {
		t.Fatalf("wrong function called: expected %s, got %s", i.FuncName, funcName)
		return // Here for unit tests - in production this will never be hit.
	}

	// fixedCount is the number of non-variadic inputs. For a variadic call the
	// last input is the variadic slice itself, which write expanded into one
	// recorded param per element.
	fixedCount := inCount
	var variadic reflect.Value
	if isVard {
		fixedCount = inCount - 1
		variadic = reflect.ValueOf(items[fixedCount])
		got := fixedCount
		// An untyped nil variadic argument yields an invalid Value; treat it as
		// zero elements instead of letting Len panic.
		if variadic.IsValid() {
			got += variadic.Len()
		}
		if got != len(i.Params) {
			t.Fatalf("wrong number of inputs: expected %d, got %d", len(i.Params), got)
			return // Here for unit tests - in production this will never be hit.
		}
	} else if inCount != len(i.Params) {
		t.Fatalf("wrong number of inputs: expected %d, got %d", len(i.Params), inCount)
		return // Here for unit tests - in production this will never be hit.
	}

	// Check all the params are equal. Fixed params map to items directly;
	// the remaining recorded params map to elements of the variadic slice.
	for x, p := range i.Params {
		var actual any
		if x < fixedCount {
			actual = items[x]
		} else {
			actual = variadic.Index(x - fixedCount).Interface()
		}
		require.JSONEq(t, string(p), string(mustMarshal(t, false, actual)))
	}

	// Get the count of outputs (everything after the inputs).
	outCount := len(items) - inCount

	// A trailing *error receives the recorded error instead of a result, so it
	// is not counted as an output. When nothing was recorded, *ptr is left as is.
	if outCount > 0 {
		ptr, _ := items[len(items)-1].(*error)
		if ptr != nil {
			if i.GenericError != "" {
				*ptr = errors.New(i.GenericError)
			} else if i.RESTError != nil {
				*ptr = i.RESTError
			}
			outCount--
		}
	}

	// Check the output count is equal to the number of outputs.
	if outCount != len(i.Results) {
		t.Fatalf("wrong number of outputs: expected %d, got %d", len(i.Results), outCount)
		return // Here for unit tests - in production this will never be hit.
	}

	// Decode each recorded result into the caller's corresponding pointer.
	for j, item := range i.Results {
		require.NoError(t, json.Unmarshal(item, items[inCount+j]))
	}
}
// tape is the ordered list of recorded calls; write appends one tapeItem per
// recorded call.
type tape []*tapeItem
// write records the start of a call to funcName by appending a new tapeItem to
// the tape, and returns that item so the call's results can later be stored on
// it with end.
//
// Every param is stored as its JSON encoding. When isVard is true, the last
// element of params must be the variadic argument (a slice). It is not stored
// as a single JSON array: each of its elements is stored as its own entry,
// after the fixed params. A nil variadic argument is recorded as zero elements.
//
// write panics if a param cannot be JSON-encoded, or if isVard is true but
// params is empty (there is no variadic argument to expand).
func (t *tape) write(funcName string, isVard bool, params ...any) *tapeItem {
	marshal := func(v any) json.RawMessage {
		b, err := json.Marshal(v)
		if err != nil {
			panic(err)
		}
		return b
	}

	fixed := params
	var variadic reflect.Value
	if isVard {
		if len(params) == 0 {
			panic("tape: variadic call recorded without its variadic argument")
		}
		fixed = params[:len(params)-1]
		variadic = reflect.ValueOf(params[len(params)-1])
	}

	// An untyped nil variadic argument gives an invalid Value, on which Len
	// would panic; treat it as an empty slice.
	variadicLen := 0
	if variadic.IsValid() {
		variadicLen = variadic.Len()
	}

	p := make([]json.RawMessage, 0, len(fixed)+variadicLen)
	for _, x := range fixed {
		p = append(p, marshal(x))
	}
	for j := 0; j < variadicLen; j++ {
		p = append(p, marshal(variadic.Index(j).Interface()))
	}

	x := &tapeItem{
		FuncName: funcName,
		Params:   p,
	}
	*t = append(*t, x)
	return x
}
// end stores the results of the call described by i.
//
// items are the call's return values in order. If the last one is an error
// value, or is nil, it is taken to be the error result and removed. The
// remaining values are JSON-encoded into i.Results. A non-nil error is stored
// in i.RESTError when it is a *rest.ErrorREST, and otherwise its message is
// stored in i.GenericError.
//
// NOTE(review): since a trailing nil is always treated as the error slot, a
// function whose last return value is a nil non-error (e.g. a nil pointer)
// has that value dropped from Results — confirm callers never hit this.
//
// end panics if any result cannot be JSON-encoded.
func (i *tapeItem) end(items ...any) {
	var callErr error
	if n := len(items); n > 0 {
		last := items[n-1]
		if e, isErr := last.(error); isErr || last == nil {
			callErr = e
			items = items[:n-1]
		}
	}

	results := make([]json.RawMessage, 0, len(items))
	for _, item := range items {
		encoded, marshalErr := json.Marshal(item)
		if marshalErr != nil {
			panic(marshalErr)
		}
		results = append(results, encoded)
	}
	i.Results = results

	if callErr == nil {
		return
	}
	if restErr, isREST := callErr.(*rest.ErrorREST); isREST {
		i.RESTError = restErr
		return
	}
	i.GenericError = callErr.Error()
}