Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions osv/osv.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
type Source struct {
httpClient *http.Client
baseURL string
userAgent string
}

// Option configures a Source.
Expand All @@ -42,11 +43,19 @@ func WithBaseURL(url string) Option {
}
}

// WithUserAgent sets the User-Agent header for API requests.
func WithUserAgent(ua string) Option {
return func(s *Source) {
s.userAgent = ua
}
}

// New creates a new OSV source.
func New(opts ...Option) *Source {
s := &Source{
httpClient: &http.Client{Timeout: DefaultTimeout},
baseURL: DefaultAPIURL,
userAgent: "vulns",
}
for _, opt := range opts {
opt(s)
Expand Down Expand Up @@ -82,6 +91,7 @@ func (s *Source) Query(ctx context.Context, p *purl.PURL) ([]vulns.Vulnerability
return nil, fmt.Errorf("creating request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("User-Agent", s.userAgent)

resp, err := s.httpClient.Do(httpReq)
if err != nil {
Expand Down Expand Up @@ -141,6 +151,7 @@ func (s *Source) QueryBatch(ctx context.Context, purls []*purl.PURL) ([][]vulns.
return nil, fmt.Errorf("creating request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("User-Agent", s.userAgent)

resp, err := s.httpClient.Do(httpReq)
if err != nil {
Expand Down Expand Up @@ -174,6 +185,7 @@ func (s *Source) Get(ctx context.Context, id string) (*vulns.Vulnerability, erro
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
httpReq.Header.Set("User-Agent", s.userAgent)

resp, err := s.httpClient.Do(httpReq)
if err != nil {
Expand Down
52 changes: 52 additions & 0 deletions osv/osv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,58 @@ func TestName(t *testing.T) {
}
}

func TestDefaultUserAgent(t *testing.T) {
var gotUA string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUA = r.Header.Get("User-Agent")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"vulns": []}`))
}))
defer server.Close()

source := New(WithBaseURL(server.URL))
p := purl.MakePURL("npm", "test", "1.0.0")
_, _ = source.Query(context.Background(), p)

if gotUA != "vulns" {
t.Errorf("default User-Agent = %q, want %q", gotUA, "vulns")
}
}

func TestCustomUserAgent(t *testing.T) {
var gotUA string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUA = r.Header.Get("User-Agent")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"vulns": []}`))
}))
defer server.Close()

source := New(WithBaseURL(server.URL), WithUserAgent("git-pkgs/1.0"))
p := purl.MakePURL("npm", "test", "1.0.0")
_, _ = source.Query(context.Background(), p)

if gotUA != "git-pkgs/1.0" {
t.Errorf("User-Agent = %q, want %q", gotUA, "git-pkgs/1.0")
}
}

func TestUserAgentOnGet(t *testing.T) {
var gotUA string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUA = r.Header.Get("User-Agent")
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()

source := New(WithBaseURL(server.URL), WithUserAgent("test-agent"))
_, _ = source.Get(context.Background(), "GHSA-test")

if gotUA != "test-agent" {
t.Errorf("Get User-Agent = %q, want %q", gotUA, "test-agent")
}
}

func loadFixture(t *testing.T, name string) []byte {
t.Helper()
path := filepath.Join("testdata", name)
Expand Down