diff --git a/cmd/allowlist/main.go b/cmd/allowlist/main.go index 7391f5a..11d3f30 100644 --- a/cmd/allowlist/main.go +++ b/cmd/allowlist/main.go @@ -4,10 +4,12 @@ import ( "bytes" "flag" "fmt" - yaml "gopkg.in/yaml.v3" "os" + "slices" "strings" + yaml "gopkg.in/yaml.v3" + "github.com/winhowes/AuthTranslator/cmd/allowlist/plugins" ) @@ -69,86 +71,31 @@ func addEntry(args []string) { fs.Usage() return } - var params map[string]interface{} - if *paramList != "" { - params = make(map[string]interface{}) - for _, kv := range strings.Split(*paramList, ",") { - kv = strings.TrimSpace(kv) - if kv == "" { - continue - } - parts := strings.SplitN(kv, "=", 2) - if len(parts) == 2 { - k := strings.TrimSpace(parts[0]) - v := strings.TrimSpace(parts[1]) - params[k] = v - } - } - } + params := parseParams(*paramList) - // load file - data, err := os.ReadFile(*file) - if err != nil && !os.IsNotExist(err) { + entries, err := loadAllowlist(true) + if err != nil { fmt.Fprintln(os.Stderr, err) return } - var entries []plugins.AllowlistEntry - if len(data) > 0 { - if err := yaml.Unmarshal(data, &entries); err != nil { - fmt.Fprintln(os.Stderr, err) - return - } - } - // find integration + wantName := strings.ToLower(*integ) - var entry *plugins.AllowlistEntry - for i := range entries { - if strings.ToLower(entries[i].Integration) == wantName { - entry = &entries[i] - break - } - } - if entry == nil { + entryIdx := findIntegration(entries, wantName) + if entryIdx == -1 { entries = append(entries, plugins.AllowlistEntry{Integration: wantName}) - entry = &entries[len(entries)-1] + entryIdx = len(entries) - 1 } else { - entry.Integration = wantName - } - // find caller - var callerCfg *plugins.CallerConfig - for i := range entry.Callers { - if entry.Callers[i].ID == *caller { - callerCfg = &entry.Callers[i] - break - } - } - if callerCfg == nil { - entry.Callers = append(entry.Callers, plugins.CallerConfig{ID: *caller}) - callerCfg = &entry.Callers[len(entry.Callers)-1] - } - replaced := false - for i := range callerCfg.Capabilities { - if callerCfg.Capabilities[i].Name == *capName { - callerCfg.Capabilities[i].Params = params - replaced = true - break - } - } - if !replaced { - callerCfg.Capabilities = append(callerCfg.Capabilities, plugins.CapabilityConfig{Name: *capName, Params: params}) + entries[entryIdx].Integration = wantName } - out, err := yamlMarshal(entries) - if err != nil { - fmt.Fprintln(os.Stderr, err) - exitFunc(1) + callerIdx := findCaller(entries[entryIdx].Callers, *caller) + if callerIdx == -1 { + entries[entryIdx].Callers = append(entries[entryIdx].Callers, plugins.CallerConfig{ID: *caller}) + callerIdx = len(entries[entryIdx].Callers) - 1 } - out = bytes.ReplaceAll(out, []byte("params: {}"), []byte("params: null")) - if err := writeFile(*file, out, 0644); err != nil { - fmt.Fprintln(os.Stderr, err) - exitFunc(1) - } + entries[entryIdx].Callers[callerIdx].Capabilities = upsertCapability(entries[entryIdx].Callers[callerIdx].Capabilities, *capName, params) + saveAllowlist(entries) } func removeEntry(args []string) { @@ -168,59 +115,121 @@ func removeEntry(args []string) { return } - data, err := os.ReadFile(*file) + entries, err := loadAllowlist(false) if err != nil { fmt.Fprintln(os.Stderr, err) return } + + wantName := strings.ToLower(*integ) + entryIdx := findIntegration(entries, wantName) + if entryIdx != -1 { + entries[entryIdx].Integration = wantName + callerIdx := findCaller(entries[entryIdx].Callers, *caller) + if callerIdx != -1 { + entries[entryIdx].Callers[callerIdx].Capabilities = trimCapabilities(entries[entryIdx].Callers[callerIdx].Capabilities, *capName) + if len(entries[entryIdx].Callers[callerIdx].Capabilities) == 0 { + entries[entryIdx].Callers = slices.Delete(entries[entryIdx].Callers, callerIdx, callerIdx+1) + } + if len(entries[entryIdx].Callers) == 0 { + entries = slices.Delete(entries, entryIdx, entryIdx+1) + } + } + } + + saveAllowlist(entries) +} + +func parseParams(paramList string) map[string]interface{} { + params := make(map[string]interface{}) + for _, kv := range strings.Split(paramList, ",") { + kv = strings.TrimSpace(kv) + if kv == "" { + continue + } + key, value, ok := strings.Cut(kv, "=") + if !ok { + continue + } + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + if key == "" { + continue + } + params[key] = value + } + if len(params) == 0 { + return nil + } + return params +} + +func loadAllowlist(allowMissing bool) ([]plugins.AllowlistEntry, error) { + data, err := os.ReadFile(*file) + if err != nil { + if allowMissing && os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + if len(data) == 0 { + return nil, nil + } var entries []plugins.AllowlistEntry if err := yaml.Unmarshal(data, &entries); err != nil { - fmt.Fprintln(os.Stderr, err) - return + return nil, err } + return entries, nil +} - wantName := strings.ToLower(*integ) - for ei := range entries { - if strings.ToLower(entries[ei].Integration) != wantName { - continue +func findIntegration(entries []plugins.AllowlistEntry, name string) int { + for i := range entries { + if strings.ToLower(entries[i].Integration) == name { + return i } - entries[ei].Integration = wantName - for ci := range entries[ei].Callers { - if entries[ei].Callers[ci].ID != *caller { - continue - } - caps := entries[ei].Callers[ci].Capabilities - for i := 0; i < len(caps); i++ { - if caps[i].Name == *capName { - caps = append(caps[:i], caps[i+1:]...) - i-- - continue - } - } - if len(caps) == 0 { - entries[ei].Callers = append(entries[ei].Callers[:ci], entries[ei].Callers[ci+1:]...) - } else { - for i := range caps { - if len(caps[i].Params) == 0 { - caps[i].Params = nil - } - } - entries[ei].Callers[ci].Capabilities = caps - } - break + } + return -1 +} + +func findCaller(callers []plugins.CallerConfig, id string) int { + for i := range callers { + if callers[i].ID == id { + return i } - if len(entries[ei].Callers) == 0 { - entries = append(entries[:ei], entries[ei+1:]...) + } + return -1 +} + +func upsertCapability(caps []plugins.CapabilityConfig, name string, params map[string]interface{}) []plugins.CapabilityConfig { + for i := range caps { + if caps[i].Name == name { + caps[i].Params = params + return caps } - break } + return append(caps, plugins.CapabilityConfig{Name: name, Params: params}) +} +func trimCapabilities(caps []plugins.CapabilityConfig, name string) []plugins.CapabilityConfig { + trimmed := slices.DeleteFunc(caps, func(cap plugins.CapabilityConfig) bool { + return cap.Name == name + }) + for i := range trimmed { + if len(trimmed[i].Params) == 0 { + trimmed[i].Params = nil + } + } + return trimmed +} + +func saveAllowlist(entries []plugins.AllowlistEntry) { out, err := yamlMarshal(entries) if err != nil { fmt.Fprintln(os.Stderr, err) exitFunc(1) } out = bytes.ReplaceAll(out, []byte("params: {}"), []byte("params: null")) + if err := writeFile(*file, out, 0644); err != nil { fmt.Fprintln(os.Stderr, err) exitFunc(1)