diff --git a/osv/osv.go b/osv/osv.go index 4802c67..86d3464 100644 --- a/osv/osv.go +++ b/osv/osv.go @@ -23,6 +23,7 @@ const ( type Source struct { httpClient *http.Client baseURL string + userAgent string } // Option configures a Source. @@ -42,11 +43,19 @@ func WithBaseURL(url string) Option { } } +// WithUserAgent sets the User-Agent header for API requests. +func WithUserAgent(ua string) Option { + return func(s *Source) { + s.userAgent = ua + } +} + // New creates a new OSV source. func New(opts ...Option) *Source { s := &Source{ httpClient: &http.Client{Timeout: DefaultTimeout}, baseURL: DefaultAPIURL, + userAgent: "vulns", } for _, opt := range opts { opt(s) @@ -82,6 +91,7 @@ func (s *Source) Query(ctx context.Context, p *purl.PURL) ([]vulns.Vulnerability return nil, fmt.Errorf("creating request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("User-Agent", s.userAgent) resp, err := s.httpClient.Do(httpReq) if err != nil { @@ -141,6 +151,7 @@ func (s *Source) QueryBatch(ctx context.Context, purls []*purl.PURL) ([][]vulns. return nil, fmt.Errorf("creating request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("User-Agent", s.userAgent) resp, err := s.httpClient.Do(httpReq) if err != nil { @@ -174,6 +185,7 @@ func (s *Source) Get(ctx context.Context, id string) (*vulns.Vulnerability, erro if err != nil { return nil, fmt.Errorf("creating request: %w", err) } + httpReq.Header.Set("User-Agent", s.userAgent) resp, err := s.httpClient.Do(httpReq) if err != nil { diff --git a/osv/osv_test.go b/osv/osv_test.go index fd79bf6..68d7d4a 100644 --- a/osv/osv_test.go +++ b/osv/osv_test.go @@ -168,6 +168,58 @@ func TestName(t *testing.T) { } } +func TestDefaultUserAgent(t *testing.T) { + var gotUA string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"vulns": []}`)) + })) + defer server.Close() + + source := New(WithBaseURL(server.URL)) + p := purl.MakePURL("npm", "test", "1.0.0") + _, _ = source.Query(context.Background(), p) + + if gotUA != "vulns" { + t.Errorf("default User-Agent = %q, want %q", gotUA, "vulns") + } +} + +func TestCustomUserAgent(t *testing.T) { + var gotUA string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"vulns": []}`)) + })) + defer server.Close() + + source := New(WithBaseURL(server.URL), WithUserAgent("git-pkgs/1.0")) + p := purl.MakePURL("npm", "test", "1.0.0") + _, _ = source.Query(context.Background(), p) + + if gotUA != "git-pkgs/1.0" { + t.Errorf("User-Agent = %q, want %q", gotUA, "git-pkgs/1.0") + } +} + +func TestUserAgentOnGet(t *testing.T) { + var gotUA string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + source := New(WithBaseURL(server.URL), WithUserAgent("test-agent")) + _, _ = source.Get(context.Background(), "GHSA-test") + + if gotUA != "test-agent" { + t.Errorf("Get User-Agent = %q, want %q", gotUA, "test-agent") + } +} + func loadFixture(t *testing.T, name string) []byte { t.Helper() path := filepath.Join("testdata", name)