Skip to content

Commit 2ba8d67

Browse files
committed
Add input validation and write errors to stderr
1 parent 3d8db5e commit 2ba8d67

3 files changed

Lines changed: 52 additions & 25 deletions

File tree

cmd/mcpsnag/main.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func main() {
113113
}
114114

115115
url := flag.Arg(0)
116-
printer := output.NewPrinter(os.Stdout, compact, verbose)
116+
printer := output.NewPrinter(os.Stdout, os.Stderr, compact, verbose)
117117

118118
if !initOnly && data == "" {
119119
fmt.Fprintln(os.Stderr, "error: -d/--data is required (or use --init-only)")
@@ -124,9 +124,11 @@ func main() {
124124
headerMap := make(map[string]string)
125125
for _, h := range headers {
126126
parts := strings.SplitN(h, ":", 2)
127-
if len(parts) == 2 {
128-
headerMap[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
127+
if len(parts) != 2 {
128+
fmt.Fprintf(os.Stderr, "warning: invalid header format %q (expected 'Key: Value')\n", h)
129+
continue
129130
}
131+
headerMap[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
130132
}
131133

132134
c := client.New(client.Options{
@@ -192,6 +194,11 @@ func runRequest(c *client.Client, printer *output.Printer, data string) {
192194
os.Exit(1)
193195
}
194196

197+
if userReq.Method == "" {
198+
printer.PrintError(fmt.Errorf("missing 'method' field in request"))
199+
os.Exit(1)
200+
}
201+
195202
resp, err := c.Request(userReq.Method, userReq.Params, func(r protocol.Response) error {
196203
if r.Result != nil {
197204
return printer.PrintRawJSON(r.Result)

internal/output/printer.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@ import (
1010

1111
type Printer struct {
1212
out io.Writer
13+
errOut io.Writer
1314
compact bool
1415
verbose bool
1516
}
1617

17-
func NewPrinter(out io.Writer, compact, verbose bool) *Printer {
18+
func NewPrinter(out, errOut io.Writer, compact, verbose bool) *Printer {
1819
return &Printer{
1920
out: out,
21+
errOut: errOut,
2022
compact: compact,
2123
verbose: verbose,
2224
}
@@ -101,11 +103,11 @@ func (p *Printer) PrintVerbose(format string, args ...any) {
101103
if !p.verbose {
102104
return
103105
}
104-
fmt.Fprintf(p.out, format+"\n", args...)
106+
fmt.Fprintf(p.errOut, format+"\n", args...)
105107
}
106108

107109
func (p *Printer) PrintError(err error) {
108-
fmt.Fprintf(p.out, "error: %v\n", err)
110+
fmt.Fprintf(p.errOut, "error: %v\n", err)
109111
}
110112

111113
func (p *Printer) PrintSessionInfo(sessionID string) {

internal/output/printer_test.go

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ package output
33
import (
44
"bytes"
55
"encoding/json"
6+
"errors"
67
"strings"
78
"testing"
89
)
910

1011
func TestPrinterPrintJSON(t *testing.T) {
1112
var buf bytes.Buffer
12-
p := NewPrinter(&buf, false, false)
13+
p := NewPrinter(&buf, &bytes.Buffer{}, false, false)
1314

1415
data := map[string]string{"key": "value"}
1516
err := p.PrintJSON(data)
@@ -28,7 +29,7 @@ func TestPrinterPrintJSON(t *testing.T) {
2829

2930
func TestPrinterPrintJSONCompact(t *testing.T) {
3031
var buf bytes.Buffer
31-
p := NewPrinter(&buf, true, false)
32+
p := NewPrinter(&buf, &bytes.Buffer{}, true, false)
3233

3334
data := map[string]string{"key": "value"}
3435
err := p.PrintJSON(data)
@@ -45,7 +46,7 @@ func TestPrinterPrintJSONCompact(t *testing.T) {
4546

4647
func TestPrinterPrintJSONPretty(t *testing.T) {
4748
var buf bytes.Buffer
48-
p := NewPrinter(&buf, false, false)
49+
p := NewPrinter(&buf, &bytes.Buffer{}, false, false)
4950

5051
data := map[string]string{"key": "value"}
5152
err := p.PrintJSON(data)
@@ -61,7 +62,7 @@ func TestPrinterPrintJSONPretty(t *testing.T) {
6162

6263
func TestPrinterPrintRawJSON(t *testing.T) {
6364
var buf bytes.Buffer
64-
p := NewPrinter(&buf, false, false)
65+
p := NewPrinter(&buf, &bytes.Buffer{}, false, false)
6566

6667
raw := json.RawMessage(`{"tools":[{"name":"test"}]}`)
6768
err := p.PrintRawJSON(raw)
@@ -77,7 +78,7 @@ func TestPrinterPrintRawJSON(t *testing.T) {
7778

7879
func TestPrinterPrintRawJSONCompact(t *testing.T) {
7980
var buf bytes.Buffer
80-
p := NewPrinter(&buf, true, false)
81+
p := NewPrinter(&buf, &bytes.Buffer{}, true, false)
8182

8283
raw := json.RawMessage(`{"key":"value"}`)
8384
err := p.PrintRawJSON(raw)
@@ -94,7 +95,7 @@ func TestPrinterPrintRawJSONCompact(t *testing.T) {
9495

9596
func TestPrinterPrintRawJSONInvalid(t *testing.T) {
9697
var buf bytes.Buffer
97-
p := NewPrinter(&buf, false, false)
98+
p := NewPrinter(&buf, &bytes.Buffer{}, false, false)
9899

99100
raw := json.RawMessage(`not valid json`)
100101
err := p.PrintRawJSON(raw)
@@ -109,43 +110,60 @@ func TestPrinterPrintRawJSONInvalid(t *testing.T) {
109110
}
110111

111112
func TestPrinterPrintVerbose(t *testing.T) {
112-
var buf bytes.Buffer
113-
p := NewPrinter(&buf, false, true)
113+
var errBuf bytes.Buffer
114+
p := NewPrinter(&bytes.Buffer{}, &errBuf, false, true)
114115

115116
p.PrintVerbose("test message %s", "arg")
116117

117-
output := buf.String()
118+
output := errBuf.String()
118119
if !strings.Contains(output, "test message arg") {
119120
t.Errorf("expected verbose message, got %s", output)
120121
}
121122
}
122123

123124
func TestPrinterPrintVerboseDisabled(t *testing.T) {
124-
var buf bytes.Buffer
125-
p := NewPrinter(&buf, false, false)
125+
var errBuf bytes.Buffer
126+
p := NewPrinter(&bytes.Buffer{}, &errBuf, false, false)
126127

127128
p.PrintVerbose("test message")
128129

129-
output := buf.String()
130+
output := errBuf.String()
130131
if output != "" {
131132
t.Errorf("expected no output when verbose disabled, got %s", output)
132133
}
133134
}
134135

135136
func TestPrinterPrintError(t *testing.T) {
136-
var buf bytes.Buffer
137-
p := NewPrinter(&buf, false, false)
137+
var errBuf bytes.Buffer
138+
p := NewPrinter(&bytes.Buffer{}, &errBuf, false, false)
138139

139-
p.PrintError(nil)
140-
output := buf.String()
140+
p.PrintError(errors.New("test error"))
141+
output := errBuf.String()
141142
if !strings.Contains(output, "error:") {
142143
t.Errorf("expected error prefix, got %s", output)
143144
}
145+
if !strings.Contains(output, "test error") {
146+
t.Errorf("expected error message, got %s", output)
147+
}
148+
}
149+
150+
func TestPrinterPrintErrorToStderr(t *testing.T) {
151+
var outBuf, errBuf bytes.Buffer
152+
p := NewPrinter(&outBuf, &errBuf, false, false)
153+
154+
p.PrintError(errors.New("test error"))
155+
156+
if outBuf.String() != "" {
157+
t.Errorf("expected no output to stdout, got %s", outBuf.String())
158+
}
159+
if !strings.Contains(errBuf.String(), "error:") {
160+
t.Errorf("expected error to stderr, got %s", errBuf.String())
161+
}
144162
}
145163

146164
func TestPrinterPrintSessionInfo(t *testing.T) {
147165
var buf bytes.Buffer
148-
p := NewPrinter(&buf, false, false)
166+
p := NewPrinter(&buf, &bytes.Buffer{}, false, false)
149167

150168
p.PrintSessionInfo("test-session-123")
151169

@@ -160,7 +178,7 @@ func TestPrinterPrintSessionInfo(t *testing.T) {
160178

161179
func TestPrinterPrintRequestVerbose(t *testing.T) {
162180
var buf bytes.Buffer
163-
p := NewPrinter(&buf, false, true)
181+
p := NewPrinter(&buf, &bytes.Buffer{}, false, true)
164182

165183
headers := map[string]string{"Authorization": "Bearer token"}
166184
body := []byte(`{"method":"test"}`)
@@ -178,7 +196,7 @@ func TestPrinterPrintRequestVerbose(t *testing.T) {
178196

179197
func TestPrinterPrintRequestNotVerbose(t *testing.T) {
180198
var buf bytes.Buffer
181-
p := NewPrinter(&buf, false, false)
199+
p := NewPrinter(&buf, &bytes.Buffer{}, false, false)
182200

183201
p.PrintRequest("POST", "http://localhost/mcp", nil, nil)
184202

0 commit comments

Comments
 (0)