diff --git a/products/ddr/client_test.go b/products/ddr/client_test.go index 8a9c5fa..a0f1a00 100644 --- a/products/ddr/client_test.go +++ b/products/ddr/client_test.go @@ -2,29 +2,40 @@ package ddr import ( "context" + "io" "net/http" - "net/http/httptest" + "strings" "testing" ) -func TestClientInjectsAuthHeaders(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("Authorization"); got != "test-token" { - t.Fatalf("Authorization = %q, want %q", got, "test-token") - } - if got := r.Header.Get("X-CS-Header-Company"); got != "company-1" { - t.Fatalf("X-CS-Header-Company = %q, want %q", got, "company-1") - } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"code":0,"msg":"ok","data":{"items":[]}}`)) - })) - defer server.Close() +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} +func TestClientInjectsAuthHeaders(t *testing.T) { client := NewClient(&Config{ - URL: server.URL, + URL: "https://example.test", APIKey: "test-token", CompanyID: "company-1", }, nil, false) + client.httpClient = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if got := req.Header.Get("Authorization"); got != "Serval test-token" { + t.Fatalf("Authorization = %q, want %q", got, "Serval test-token") + } + if got := req.Header.Get("X-CS-Header-Company"); got != "company-1" { + t.Fatalf("X-CS-Header-Company = %q, want %q", got, "company-1") + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"code":0,"msg":"ok","data":{"items":[]}}`)), + Request: req, + }, nil + }), + } var result map[string]interface{} if err := client.Do(context.Background(), http.MethodGet, "/health", nil, nil, nil, &result); err != nil { diff --git a/products/ddr/get_api_token_test.go b/products/ddr/get_api_token_test.go index 620c60e..5c7498e 100644 --- a/products/ddr/get_api_token_test.go +++ b/products/ddr/get_api_token_test.go @@ -3,8 +3,8 @@ package ddr import ( "context" "encoding/base64" + "io" "net/http" - "net/http/httptest" "path/filepath" "strings" "testing" @@ -24,7 +24,7 @@ func TestParseJWTClaims(t *testing.T) { func TestBuildServalToken(t *testing.T) { got := buildServalToken("69e35caa09f496d33065033a") - want := "Serval " + base64.StdEncoding.EncodeToString([]byte("serval:69e35caa09f496d33065033a")) + want := base64.StdEncoding.EncodeToString([]byte("serval:69e35caa09f496d33065033a")) if got != want { t.Fatalf("buildServalToken() = %q, want %q", got, want) } @@ -61,29 +61,34 @@ func TestCreateAndPersistAPIToken(t *testing.T) { var seenCreatePath string var seenAttrPath string - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/qzh/api/auth/v1/access_key/batch": - seenCreatePath = r.URL.Path - seenCreateAuth = r.Header.Get("Authorization") - seenCreateUser = r.Header.Get("User") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"data":{"data":[{"access_key":{"access_key":"ak-1","secret_key":"sk-1"}}]}}`)) - case "/qzh/api/auth/v1/ns/attributes": - seenAttrPath = r.URL.Path - seenAttrAuth = r.Header.Get("Authorization") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"data":{"attributes":[{"k":"corp_name","v":"corp-1"}]}}`)) - default: - t.Fatalf("unexpected path: %s", r.URL.Path) - } - })) - defer server.Close() - + baseURL := "https://example.test" client := NewClient(&Config{ - URL: server.URL, + URL: baseURL, }, nil, false) - client.httpClient = server.Client() + client.httpClient = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + var body string + switch req.URL.Path { + case "/qzh/api/auth/v1/access_key/batch": + seenCreatePath = req.URL.Path + seenCreateAuth = req.Header.Get("Authorization") + seenCreateUser = req.Header.Get("User") + body = `{"data":{"data":[{"access_key":{"access_key":"ak-1","secret_key":"sk-1"}}]}}` + case "/qzh/api/auth/v1/ns/attributes": + seenAttrPath = req.URL.Path + seenAttrAuth = req.Header.Get("Authorization") + body = `{"data":{"attributes":[{"k":"corp_name","v":"corp-1"}]}}` + default: + t.Fatalf("unexpected path: %s", req.URL.Path) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(body)), + Request: req, + }, nil + }), + } result, err := createAndPersistAPIToken(context.Background(), client, configPath, "Bearer e30.eyJVc2VySUQiOiIxIiwiVXNlck5TIjoiZGVmYXVsdCJ9.sig") if err != nil { @@ -101,8 +106,10 @@ func TestCreateAndPersistAPIToken(t *testing.T) { } wantToken := base64.StdEncoding.EncodeToString([]byte("serval:ak-1")) - if seenAttrAuth != wantToken { - t.Fatalf("ns attributes Authorization = %q, want %q", seenAttrAuth, wantToken) + // injectHeaders automatically adds "Serval " prefix + wantAuth := "Serval " + wantToken + if seenAttrAuth != wantAuth { + t.Fatalf("ns attributes Authorization = %q, want %q", seenAttrAuth, wantAuth) } if seenAttrPath != "/qzh/api/auth/v1/ns/attributes" { t.Fatalf("attributes path = %q", seenAttrPath) @@ -121,7 +128,7 @@ func TestCreateAndPersistAPIToken(t *testing.T) { if err := node.Decode(&saved); err != nil { t.Fatalf("Decode() error = %v", err) } - if saved.URL != server.URL+"/qzh/api/v1" || saved.APIKey != wantToken || saved.CompanyID != "corp-1" { + if saved.URL != baseURL+"/qzh/api/v1" || saved.APIKey != wantToken || saved.CompanyID != "corp-1" { t.Fatalf("unexpected saved config: %+v", saved) } }