@@ -26,23 +26,6 @@ type Stats struct {
2626 Total time.Duration
2727}
2828
29- // httpEndpointConfig represents the configuration for an HTTP endpoint.
30- type httpEndpointConfig struct {
31- client * http.Client
32- url string
33- }
34-
35- // sgAuthTransport is an http.RoundTripper that adds an Authorization header to requests.
36- // It is used to add the Sourcegraph access token to requests to Sourcegraph endpoints.
37- type sgAuthTransport struct {
38- token string
39- base http.RoundTripper
40- }
41- func (t * sgAuthTransport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
42- req .Header .Add ("Authorization" , "token " + t .token )
43- return t .base .RoundTrip (req )
44- }
45-
4629func init () {
4730 usage := `
4831'src gateway benchmark' runs performance benchmarks against Cody Gateway endpoints.
@@ -53,10 +36,10 @@ Usage:
5336
5437Examples:
5538
56- $ src gateway benchmark
57- $ src gateway benchmark --requests 50
58- $ src gateway benchmark --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgp sgp_***** --requests 50
59- $ src gateway benchmark --requests 50 --csv results.csv
39+ $ src gateway benchmark --sgp <token>
40+ $ src gateway benchmark --requests 50 --sgp <token>
41+ $ src gateway benchmark --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgp <token>
42+ $ src gateway benchmark --requests 50 --csv results.csv --sgp <token>
6043`
6144
6245 flagSet := flag .NewFlagSet ("benchmark" , flag .ExitOnError )
@@ -66,7 +49,7 @@ Examples:
6649 csvOutput = flagSet .String ("csv" , "" , "Export results to CSV file (provide filename)" )
6750 gatewayEndpoint = flagSet .String ("gateway" , "https://cody-gateway.sourcegraph.com" , "Cody Gateway endpoint" )
6851 sgEndpoint = flagSet .String ("sourcegraph" , "https://sourcegraph.com" , "Sourcegraph endpoint" )
69- sgpToken = flagSet .String ("sgp" , "sgp_***** " , "Sourcegraph personal access token for the called instance" )
52+ sgpToken = flagSet .String ("sgp" , "" , "Sourcegraph personal access token for the called instance" )
7053 )
7154
7255 handler := func (args []string ) error {
@@ -79,78 +62,49 @@ Examples:
7962 }
8063
8164 var (
82- gatewayWebsocket , sourcegraphWebsocket * websocket.Conn
83- err error
84- gatewayClient = & http.Client {}
85- sourcegraphClient = & http.Client {}
86- endpoints = map [string ]any {} // Values: URL `string`s or `*websocket.Conn`s
65+ httpClient = & http.Client {}
66+ endpoints = map [string ]any {} // Values: URL `string`s or `*webSocketClient`s
8767 )
88-
89- // Connect to endpoints
9068 if * gatewayEndpoint != "" {
9169 fmt .Println ("Benchmarking Cody Gateway instance:" , * gatewayEndpoint )
92- wsURL := strings .Replace (fmt .Sprint (* gatewayEndpoint , "/v2/websocket" ), "http" , "ws" , 1 )
93- fmt .Println ("Connecting to Cody Gateway via WebSocket.." , wsURL )
94- gatewayWebsocket , _ , err = websocket .DefaultDialer .Dial (wsURL , nil )
95- if err != nil {
96- return fmt .Errorf ("WebSocket dial(%s): %v" , wsURL , err )
97- }
98- fmt .Println ("Connected!" )
99- endpoints ["ws(s): gateway" ] = gatewayWebsocket
100- endpoints ["http(s): gateway" ] = & httpEndpointConfig {
101- client : gatewayClient ,
102- url : fmt .Sprint (* gatewayEndpoint , "/v2/http" ),
70+ endpoints ["ws(s): gateway" ] = & webSocketClient {
71+ conn : nil ,
72+ URL : strings .Replace (fmt .Sprint (* gatewayEndpoint , "/v2/websocket" ), "http" , "ws" , 1 ),
10373 }
74+ endpoints ["http(s): gateway" ] = fmt .Sprint (* gatewayEndpoint , "/v2/http" )
10475 } else {
10576 fmt .Println ("warning: not benchmarking Cody Gateway (-gateway endpoint not provided)" )
10677 }
10778 if * sgEndpoint != "" {
108- // Add auth header to sourcegraphClient transport
109- if * sgpToken != "" {
110- sourcegraphClient .Transport = & sgAuthTransport {
111- token : * sgpToken ,
112- base : http .DefaultTransport ,
113- }
79+ if * sgpToken == "" {
80+ return cmderrors .Usage ("must specify --sgp <Sourcegraph personal access token>" )
11481 }
11582 fmt .Println ("Benchmarking Sourcegraph instance:" , * sgEndpoint )
116- wsURL := strings .Replace (fmt .Sprint (* sgEndpoint , "/.api/gateway/websocket" ), "http" , "ws" , 1 )
117- header := http.Header {}
118- header .Add ("Authorization" , "token " + * sgpToken )
119- fmt .Println ("Connecting to Sourcegraph instance via WebSocket.." , wsURL )
120- sourcegraphWebsocket , _ , err = websocket .DefaultDialer .Dial (wsURL , header )
121- if err != nil {
122- return fmt .Errorf ("WebSocket dial(%s): %v" , wsURL , err )
123- }
124- fmt .Println ("Connected!" )
125-
126- endpoints ["ws(s): sourcegraph" ] = sourcegraphWebsocket
127- endpoints ["http(s): sourcegraph" ] = & httpEndpointConfig {
128- client : sourcegraphClient ,
129- url : fmt .Sprint (* sgEndpoint , "/.api/gateway/http" ),
130- }
131- endpoints ["http(s): http-then-ws" ] = & httpEndpointConfig {
132- client : sourcegraphClient ,
133- url : fmt .Sprint (* sgEndpoint , "/.api/gateway/http-then-websocket" ),
83+ endpoints ["ws(s): sourcegraph" ] = & webSocketClient {
84+ conn : nil ,
85+ URL : strings .Replace (fmt .Sprint (* sgEndpoint , "/.api/gateway/websocket" ), "http" , "ws" , 1 ),
13486 }
87+ endpoints ["http(s): sourcegraph" ] = fmt .Sprint (* sgEndpoint , "/.api/gateway/http" )
88+ endpoints ["http(s): http-then-ws" ] = fmt .Sprint (* sgEndpoint , "/.api/gateway/http-then-websocket" )
13589 } else {
13690 fmt .Println ("warning: not benchmarking Sourcegraph instance (-sourcegraph endpoint not provided)" )
13791 }
13892
13993 fmt .Printf ("Starting benchmark with %d requests per endpoint...\n " , * requestCount )
14094
14195 var results []endpointResult
142- for name , clientOrEndpointConfig := range endpoints {
96+ for name , clientOrURL := range endpoints {
14397 durations := make ([]time.Duration , 0 , * requestCount )
14498 fmt .Printf ("\n Testing %s..." , name )
14599
146100 for i := 0 ; i < * requestCount ; i ++ {
147- if ws , ok := clientOrEndpointConfig .(* websocket. Conn ); ok {
101+ if ws , ok := clientOrURL .(* webSocketClient ); ok {
148102 duration := benchmarkEndpointWebSocket (ws )
149103 if duration > 0 {
150104 durations = append (durations , duration )
151105 }
152- } else if epConf , ok := clientOrEndpointConfig .( * httpEndpointConfig ); ok {
153- duration := benchmarkEndpointHTTP (epConf )
106+ } else if url , ok := clientOrURL .( string ); ok {
107+ duration := benchmarkEndpointHTTP (httpClient , url , * sgpToken )
154108 if duration > 0 {
155109 durations = append (durations , duration )
156110 }
@@ -200,6 +154,26 @@ Examples:
200154 })
201155}
202156
157+ type webSocketClient struct {
158+ conn * websocket.Conn
159+ URL string
160+ }
161+
162+ func (c * webSocketClient ) reconnect () error {
163+ if c .conn != nil {
164+ c .conn .Close () // don't leak connections
165+ }
166+ fmt .Println ("Connecting to WebSocket.." , c .URL )
167+ var err error
168+ c .conn , _ , err = websocket .DefaultDialer .Dial (c .URL , nil )
169+ if err != nil {
170+ c .conn = nil // retry again later
171+ return fmt .Errorf ("WebSocket dial(%s): %v" , c .URL , err )
172+ }
173+ fmt .Println ("Connected!" )
174+ return nil
175+ }
176+
203177type endpointResult struct {
204178 name string
205179 avg time.Duration
@@ -212,11 +186,18 @@ type endpointResult struct {
212186 successful int
213187}
214188
215- func benchmarkEndpointHTTP (epConfig * httpEndpointConfig ) time.Duration {
189+ func benchmarkEndpointHTTP (client * http. Client , url , accessToken string ) time.Duration {
216190 start := time .Now ()
217- resp , err := epConfig .client .Post (epConfig .url , "application/json" , strings .NewReader ("ping" ))
191+ req , err := http .NewRequest ("POST" , url , strings .NewReader ("ping" ))
192+ if err != nil {
193+ fmt .Printf ("Error creating request: %v\n " , err )
194+ return 0
195+ }
196+ req .Header .Set ("Content-Type" , "application/json" )
197+ req .Header .Set ("Authorization" , "token " + accessToken )
198+ resp , err := client .Do (req )
218199 if err != nil {
219- fmt .Printf ("Error calling %s: %v\n " , epConfig . url , err )
200+ fmt .Printf ("Error calling %s: %v\n " , url , err )
220201 return 0
221202 }
222203 defer func () {
@@ -242,20 +223,39 @@ func benchmarkEndpointHTTP(epConfig *httpEndpointConfig) time.Duration {
242223 return time .Since (start )
243224}
244225
245- func benchmarkEndpointWebSocket (conn * websocket.Conn ) time.Duration {
226+ func benchmarkEndpointWebSocket (client * webSocketClient ) time.Duration {
227+ // Perform initial websocket connection, if needed.
228+ if client .conn == nil {
229+ if err := client .reconnect (); err != nil {
230+ fmt .Printf ("Error reconnecting: %v\n " , err )
231+ return 0
232+ }
233+ }
234+
235+ // Perform the benchmarked request using the websocket.
246236 start := time .Now ()
247- err := conn .WriteMessage (websocket .TextMessage , []byte ("ping" ))
237+ err := client . conn .WriteMessage (websocket .TextMessage , []byte ("ping" ))
248238 if err != nil {
249239 fmt .Printf ("WebSocket write error: %v\n " , err )
240+ if err := client .reconnect (); err != nil {
241+ fmt .Printf ("Error reconnecting: %v\n " , err )
242+ }
250243 return 0
251244 }
252- _ , message , err := conn .ReadMessage ()
245+ _ , message , err := client .conn .ReadMessage ()
246+
253247 if err != nil {
254248 fmt .Printf ("WebSocket read error: %v\n " , err )
249+ if err := client .reconnect (); err != nil {
250+ fmt .Printf ("Error reconnecting: %v\n " , err )
251+ }
255252 return 0
256253 }
257254 if string (message ) != "pong" {
258255 fmt .Printf ("Expected 'pong' response, got: %q\n " , string (message ))
256+ if err := client .reconnect (); err != nil {
257+ fmt .Printf ("Error reconnecting: %v\n " , err )
258+ }
259259 return 0
260260 }
261261 return time .Since (start )
0 commit comments