Skip to content

Commit e2b989d

Browse files
Copilotsawka
andcommitted
fix: harden multi-arg adapter handling and centralize data type fallback
Co-authored-by: sawka <2722291+sawka@users.noreply.github.com>
1 parent 7d26197 commit e2b989d

5 files changed

Lines changed: 21 additions & 13 deletions

File tree

pkg/gogen/gogen.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,7 @@ func GenMethod_ResponseStream(buf *strings.Builder, methodDecl *wshrpc.WshRpcMet
107107
}
108108

109109
func getWshMethodDataParamsAndExpr(methodDecl *wshrpc.WshRpcMethodDecl) (string, string) {
110-
dataTypes := methodDecl.CommandDataTypes
111-
if len(dataTypes) == 0 && methodDecl.CommandDataType != nil {
112-
dataTypes = []reflect.Type{methodDecl.CommandDataType}
113-
}
110+
dataTypes := methodDecl.GetCommandDataTypes()
114111
if len(dataTypes) == 0 {
115112
return "", "nil"
116113
}

pkg/tsgen/tsgen.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,10 +497,7 @@ func generateWshClientApiMethod_Call(methodDecl *wshrpc.WshRpcMethodDecl, tsType
497497
}
498498

499499
func getTsWshMethodDataParamsAndExpr(methodDecl *wshrpc.WshRpcMethodDecl, tsTypesMap map[reflect.Type]string) (string, string) {
500-
dataTypes := methodDecl.CommandDataTypes
501-
if len(dataTypes) == 0 && methodDecl.CommandDataType != nil {
502-
dataTypes = []reflect.Type{methodDecl.CommandDataType}
503-
}
500+
dataTypes := methodDecl.GetCommandDataTypes()
504501
if len(dataTypes) == 0 {
505502
return "", "null"
506503
}

pkg/wshrpc/wshrpcmeta.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@ type WshRpcMethodDecl struct {
2020
DefaultResponseDataType reflect.Type
2121
}
2222

23+
func (decl *WshRpcMethodDecl) GetCommandDataTypes() []reflect.Type {
24+
if len(decl.CommandDataTypes) > 0 {
25+
return decl.CommandDataTypes
26+
}
27+
if decl.CommandDataType != nil {
28+
return []reflect.Type{decl.CommandDataType}
29+
}
30+
return nil
31+
}
32+
2333
var contextRType = reflect.TypeOf((*context.Context)(nil)).Elem()
2434
var wshRpcInterfaceRType = reflect.TypeOf((*WshRpcInterface)(nil)).Elem()
2535

pkg/wshrpc/wshrpcmeta_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,7 @@ func TestGenerateWshCommandDecl_MultiArgs(t *testing.T) {
3434
if decl.CommandDataType == nil || decl.CommandDataType.Kind() != reflect.String {
3535
t.Fatalf("expected legacy single data type to remain first arg type, got %v", decl.CommandDataType)
3636
}
37+
if len(decl.GetCommandDataTypes()) != 2 {
38+
t.Fatalf("expected helper to return two command data types")
39+
}
3740
}

pkg/wshutil/wshadapter.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,7 @@ func serverImplAdapter(impl any) func(*RpcResponseHandler) bool {
100100
implMethod := reflect.ValueOf(impl).MethodByName(rmethod.Name)
101101
var callParams []reflect.Value
102102
callParams = append(callParams, reflect.ValueOf(handler.Context()))
103-
commandDataTypes := methodDecl.CommandDataTypes
104-
if len(commandDataTypes) == 0 && methodDecl.CommandDataType != nil {
105-
commandDataTypes = []reflect.Type{methodDecl.CommandDataType}
106-
}
103+
commandDataTypes := methodDecl.GetCommandDataTypes()
107104
if len(commandDataTypes) == 1 {
108105
cmdData, err := recodeCommandData(cmd, handler.GetCommandRawData(), commandDataTypes[0])
109106
if err != nil {
@@ -117,7 +114,11 @@ func serverImplAdapter(impl any) func(*RpcResponseHandler) bool {
117114
handler.SendResponseError(err)
118115
return true
119116
}
120-
multiArg := multiArgAny.(wshrpc.MultiArg)
117+
multiArg, ok := multiArgAny.(wshrpc.MultiArg)
118+
if !ok {
119+
handler.SendResponseError(fmt.Errorf("command %q invalid multi arg payload", cmd))
120+
return true
121+
}
121122
if len(multiArg.Args) != len(commandDataTypes) {
122123
handler.SendResponseError(fmt.Errorf("command %q expected %d args, got %d", cmd, len(commandDataTypes), len(multiArg.Args)))
123124
return true

0 commit comments

Comments
 (0)