Skip to content

Commit 95b6f6f

Browse files
committed
Fix race condition, input validation, and code quality issues
- Add sync.RWMutex to protect global searchCache from concurrent access
- Replace O(n²) duplicate skill check with map-based O(1) lookup
- Add git ref validation to prevent malformed references in Checkout
- Add path traversal protection in buildRepoURL/buildRepoName
- Replace io.ReadAtLeast with io.ReadAll+LimitReader for safer reads
- Consolidate hardcoded HTTP timeouts into named constants
- Add OS signal handling for graceful sync cancellation
- Remove redundant helper functions, use strings stdlib instead
1 parent d4cd0f9 commit 95b6f6f

4 files changed

Lines changed: 121 additions & 83 deletions

File tree

cmd/install.go

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -139,16 +139,14 @@ func runInstall(cmd *cobra.Command, args []string) {
139139
count++
140140
}
141141
// Add from legacy skills list if not duplicate
142+
seen := make(map[string]bool, len(skillArgs))
143+
for _, existing := range skillArgs {
144+
seen[existing] = true
145+
}
142146
for _, s := range cfg.Skills {
143-
exists := false
144-
for _, existing := range skillArgs {
145-
if existing == s {
146-
exists = true
147-
break
148-
}
149-
}
150-
if !exists {
147+
if !seen[s] {
151148
skillArgs = append(skillArgs, s)
149+
seen[s] = true
152150
count++
153151
}
154152
}

cmd/sync.go

Lines changed: 36 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"os"
7+
"os/signal"
78
"strings"
89
"sync"
910
"time"
@@ -80,7 +81,9 @@ If no repo name is specified, syncs all configured repositories.`,
8081
bar := ui.NewProgressBar(len(targetRepos), "Syncing repositories")
8182

8283
// Use errgroup for parallel syncing with limit
83-
ctx := context.Background()
84+
// Support cancellation via OS signals (Ctrl+C)
85+
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
86+
defer stop()
8487
g, ctx := errgroup.WithContext(ctx)
8588
g.SetLimit(5) // Limit concurrency to 5
8689

@@ -169,28 +172,48 @@ If no repo name is specified, syncs all configured repositories.`,
169172
},
170173
}
171174

172-
// buildRepoURL constructs the git clone URL from repo config
173-
func buildRepoURL(url string) string {
175+
// buildRepoURL constructs the git clone URL from repo config.
176+
// Validates that owner/repo parts do not contain path traversal patterns.
177+
func buildRepoURL(repoURL string) string {
174178
// Handle owner/repo format
175-
if !strings.HasPrefix(url, "http") && !strings.HasPrefix(url, "git@") {
179+
if !strings.HasPrefix(repoURL, "http") && !strings.HasPrefix(repoURL, "git@") {
176180
// Extract owner/repo from path like "anthropics/skills/skills"
177-
parts := strings.Split(url, "/")
181+
parts := strings.Split(repoURL, "/")
178182
if len(parts) >= 2 {
179-
return fmt.Sprintf("https://github.com/%s/%s.git", parts[0], parts[1])
183+
owner, repo := parts[0], parts[1]
184+
// Reject path traversal or empty segments
185+
if owner == ".." || repo == ".." || owner == "." || repo == "." || owner == "" || repo == "" {
186+
return ""
187+
}
188+
return fmt.Sprintf("https://github.com/%s/%s.git", owner, repo)
189+
}
190+
if repoURL == ".." || repoURL == "." || repoURL == "" {
191+
return ""
180192
}
181-
return "https://github.com/" + url + ".git"
193+
return "https://github.com/" + repoURL + ".git"
182194
}
183-
return url
195+
return repoURL
184196
}
185197

186-
// buildRepoName constructs a filesystem-safe name from repo URL
187-
func buildRepoName(url string) string {
198+
// buildRepoName constructs a filesystem-safe name from repo URL.
199+
// Strips path traversal patterns to prevent directory escape.
200+
func buildRepoName(repoURL string) string {
188201
// Handle owner/repo/path format
189-
parts := strings.Split(url, "/")
202+
parts := strings.Split(repoURL, "/")
190203
if len(parts) >= 2 {
191-
return parts[0] + "-" + parts[1]
204+
owner := strings.ReplaceAll(parts[0], "..", "")
205+
repo := strings.ReplaceAll(parts[1], "..", "")
206+
if owner == "" || repo == "" {
207+
return "unknown-repo"
208+
}
209+
return owner + "-" + repo
210+
}
211+
name := strings.ReplaceAll(repoURL, "/", "-")
212+
name = strings.ReplaceAll(name, "..", "")
213+
if name == "" {
214+
return "unknown-repo"
192215
}
193-
return strings.ReplaceAll(url, "/", "-")
216+
return name
194217
}
195218

196219
func init() {

internal/git/git.go

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -172,15 +172,36 @@ func GetLatestTag(repoPath string) (string, error) {
172172
return strings.TrimSpace(string(output)), nil
173173
}
174174

175-
// Checkout checks out a specific tag or branch
175+
// Checkout checks out a specific tag or branch.
176+
// The ref is validated to prevent unexpected git behavior from malformed references.
176177
func Checkout(repoPath, ref string) error {
178+
if err := validateGitRef(ref); err != nil {
179+
return fmt.Errorf("invalid git ref: %w", err)
180+
}
177181
cmd := exec.Command("git", "checkout", ref)
178182
cmd.Dir = repoPath
179183
cmd.Stdout = os.Stdout
180184
cmd.Stderr = os.Stderr
181185
return cmd.Run()
182186
}
183187

188+
// validateGitRef checks that a git reference string is safe to use.
189+
func validateGitRef(ref string) error {
190+
if ref == "" {
191+
return fmt.Errorf("ref cannot be empty")
192+
}
193+
if strings.Contains(ref, "..") {
194+
return fmt.Errorf("ref cannot contain '..'")
195+
}
196+
if strings.ContainsAny(ref, " \t\n\r~^:?*[\\") {
197+
return fmt.Errorf("ref contains invalid characters")
198+
}
199+
if strings.HasPrefix(ref, "-") {
200+
return fmt.Errorf("ref cannot start with '-'")
201+
}
202+
return nil
203+
}
204+
184205
// GetCurrentCommit returns the current commit hash of the repository
185206
func GetCurrentCommit(repoPath string) (string, error) {
186207
cmd := exec.Command("git", "rev-parse", "HEAD")

internal/github/client.go

Lines changed: 57 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ import (
99
"net/url"
1010
"os"
1111
"strings"
12+
"sync"
1213
"time"
1314

1415
"github.com/yeasy/ask/internal/cache"
@@ -20,10 +21,20 @@ const (
2021
SkillTopic = "agent-skill"
2122
// APIURL is the GitHub API endpoint for searching repositories
2223
APIURL = "https://api.github.com/search/repositories"
24+
25+
// httpTimeoutDefault is the default timeout for GitHub API requests
26+
httpTimeoutDefault = 10 * time.Second
27+
// httpTimeoutShort is a shorter timeout for non-critical requests like fetching descriptions
28+
httpTimeoutShort = 5 * time.Second
29+
// maxDescriptionReadBytes limits how much of SKILL.md we read for description extraction
30+
maxDescriptionReadBytes = 4096
2331
)
2432

25-
// Global cache instance
26-
var searchCache *cache.Cache
33+
// Global cache instance, protected by cacheMu for concurrent access
34+
var (
35+
searchCache *cache.Cache
36+
cacheMu sync.RWMutex
37+
)
2738

2839
// OfflineMode returns whether the application is in offline mode.
2940
// Delegates to config.OfflineMode as the single source of truth.
@@ -41,6 +52,25 @@ func init() {
4152
}
4253
}
4354

55+
// cacheGet safely reads from the global cache under a read lock.
56+
func cacheGet(key string, dest interface{}) bool {
57+
cacheMu.RLock()
58+
defer cacheMu.RUnlock()
59+
if searchCache == nil {
60+
return false
61+
}
62+
return searchCache.Get(key, dest)
63+
}
64+
65+
// cacheSet safely writes to the global cache under a write lock.
66+
func cacheSet(key string, value interface{}) {
67+
cacheMu.Lock()
68+
defer cacheMu.Unlock()
69+
if searchCache != nil {
70+
_ = searchCache.Set(key, value)
71+
}
72+
}
73+
4474
// SearchResult represents the response from GitHub search API
4575
type SearchResult struct {
4676
TotalCount int `json:"total_count"`
@@ -97,11 +127,9 @@ func SearchTopic(topic, keyword string) ([]Repository, error) {
97127

98128
// Try cache first
99129
// In offline mode, we MUST find it in cache or return error
100-
if searchCache != nil {
101-
var cached []Repository
102-
if searchCache.Get(cacheKey, &cached) {
103-
return cached, nil
104-
}
130+
var cached []Repository
131+
if cacheGet(cacheKey, &cached) {
132+
return cached, nil
105133
}
106134

107135
if isOffline() {
@@ -129,7 +157,7 @@ func SearchTopic(topic, keyword string) ([]Repository, error) {
129157
req.Header.Set("Accept", "application/vnd.github.v3+json")
130158
req.Header.Set("User-Agent", "ask-cli")
131159

132-
client := &http.Client{Timeout: 10 * time.Second}
160+
client := &http.Client{Timeout: httpTimeoutDefault}
133161
resp, err := client.Do(req)
134162
if err != nil {
135163
return nil, err
@@ -146,9 +174,7 @@ func SearchTopic(topic, keyword string) ([]Repository, error) {
146174
}
147175

148176
// Cache the result
149-
if searchCache != nil {
150-
_ = searchCache.Set(cacheKey, result.Items)
151-
}
177+
cacheSet(cacheKey, result.Items)
152178

153179
return result.Items, nil
154180
}
@@ -165,11 +191,9 @@ func SearchDir(owner, repo, path string) ([]Repository, error) {
165191
cacheKey := fmt.Sprintf("dir:%s/%s/%s", owner, repo, path)
166192

167193
// Try cache first
168-
if searchCache != nil {
169-
var cached []Repository
170-
if searchCache.Get(cacheKey, &cached) {
171-
return cached, nil
172-
}
194+
var cached []Repository
195+
if cacheGet(cacheKey, &cached) {
196+
return cached, nil
173197
}
174198

175199
if isOffline() {
@@ -190,7 +214,7 @@ func SearchDir(owner, repo, path string) ([]Repository, error) {
190214
req.Header.Set("Accept", "application/vnd.github.v3+json")
191215
req.Header.Set("User-Agent", "ask-cli")
192216

193-
client := &http.Client{Timeout: 10 * time.Second}
217+
client := &http.Client{Timeout: httpTimeoutDefault}
194218
resp, err := client.Do(req)
195219
if err != nil {
196220
return nil, err
@@ -240,9 +264,7 @@ func SearchDir(owner, repo, path string) ([]Repository, error) {
240264
}
241265

242266
// Cache the result
243-
if searchCache != nil {
244-
_ = searchCache.Set(cacheKey, skills)
245-
}
267+
cacheSet(cacheKey, skills)
246268

247269
return skills, nil
248270
}
@@ -251,11 +273,9 @@ func SearchDir(owner, repo, path string) ([]Repository, error) {
251273
func fetchSkillDescription(owner, repo, skillPath string) string {
252274
// Check cache first
253275
cacheKey := fmt.Sprintf("skill-desc:%s/%s/%s", owner, repo, skillPath)
254-
if searchCache != nil {
255-
var cached string
256-
if searchCache.Get(cacheKey, &cached) {
257-
return cached
258-
}
276+
var cached string
277+
if cacheGet(cacheKey, &cached) {
278+
return cached
259279
}
260280

261281
// Fetch SKILL.md content
@@ -273,7 +293,7 @@ func fetchSkillDescription(owner, repo, skillPath string) string {
273293
req.Header.Set("Accept", "application/vnd.github.v3.raw") // Get raw file content
274294
req.Header.Set("User-Agent", "ask-cli")
275295

276-
client := &http.Client{Timeout: 5 * time.Second}
296+
client := &http.Client{Timeout: httpTimeoutShort}
277297
resp, err := client.Do(req)
278298
if err != nil {
279299
return ""
@@ -284,20 +304,19 @@ func fetchSkillDescription(owner, repo, skillPath string) string {
284304
return ""
285305
}
286306

287-
// Read the content (limit to 4KB to avoid huge files)
288-
buf := make([]byte, 4096)
289-
n, err := io.ReadAtLeast(resp.Body, buf, 1)
290-
if err != nil && n == 0 {
307+
// Read the content (limit to maxDescriptionReadBytes to avoid huge files)
308+
data, err := io.ReadAll(io.LimitReader(resp.Body, maxDescriptionReadBytes))
309+
if err != nil || len(data) == 0 {
291310
return ""
292311
}
293-
content := string(buf[:n])
312+
content := string(data)
294313

295314
// Parse description from SKILL.md (check both frontmatter and first paragraph)
296315
desc := parseDescriptionFromSkillMD(content)
297316

298317
// Cache the description
299-
if searchCache != nil && desc != "" {
300-
_ = searchCache.Set(cacheKey, desc)
318+
if desc != "" {
319+
cacheSet(cacheKey, desc)
301320
}
302321

303322
return desc
@@ -328,7 +347,7 @@ func parseDescriptionFromSkillMD(content string) string {
328347

329348
// If no frontmatter description, look for first non-empty non-heading line
330349
for _, line := range lines {
331-
line = trimSpace(line)
350+
line = strings.TrimSpace(line)
332351
if line == "" || line == "---" {
333352
continue
334353
}
@@ -342,36 +361,13 @@ func parseDescriptionFromSkillMD(content string) string {
342361
return ""
343362
}
344363

345-
// Helper functions to avoid importing strings package
364+
// splitLines splits a string into lines by newline character.
346365
func splitLines(s string) []string {
347-
var lines []string
348-
start := 0
349-
for i := 0; i < len(s); i++ {
350-
if s[i] == '\n' {
351-
lines = append(lines, s[start:i])
352-
start = i + 1
353-
}
354-
}
355-
if start < len(s) {
356-
lines = append(lines, s[start:])
357-
}
358-
return lines
359-
}
360-
361-
func trimSpace(s string) string {
362-
start := 0
363-
end := len(s)
364-
for start < end && (s[start] == ' ' || s[start] == '\t' || s[start] == '\r') {
365-
start++
366-
}
367-
for end > start && (s[end-1] == ' ' || s[end-1] == '\t' || s[end-1] == '\r') {
368-
end--
369-
}
370-
return s[start:end]
366+
return strings.Split(s, "\n")
371367
}
372368

373369
func trimQuotes(s string) string {
374-
s = trimSpace(s)
370+
s = strings.TrimSpace(s)
375371
if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) {
376372
return s[1 : len(s)-1]
377373
}
@@ -511,7 +507,7 @@ func FetchRepoDetails(owner, repo string) (*Repository, error) {
511507
req.Header.Set("Accept", "application/vnd.github.v3+json")
512508
req.Header.Set("User-Agent", "ask-cli")
513509

514-
client := &http.Client{Timeout: 10 * time.Second}
510+
client := &http.Client{Timeout: httpTimeoutDefault}
515511
resp, err := client.Do(req)
516512
if err != nil {
517513
return nil, err

0 commit comments

Comments
 (0)