Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 113 additions & 104 deletions cmd/allowlist/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"bytes"
"flag"
"fmt"
yaml "gopkg.in/yaml.v3"
"os"
"slices"
"strings"

yaml "gopkg.in/yaml.v3"

"github.com/winhowes/AuthTranslator/cmd/allowlist/plugins"
)

Expand Down Expand Up @@ -69,86 +71,31 @@ func addEntry(args []string) {
fs.Usage()
return
}
var params map[string]interface{}
if *paramList != "" {
params = make(map[string]interface{})
for _, kv := range strings.Split(*paramList, ",") {
kv = strings.TrimSpace(kv)
if kv == "" {
continue
}
parts := strings.SplitN(kv, "=", 2)
if len(parts) == 2 {
k := strings.TrimSpace(parts[0])
v := strings.TrimSpace(parts[1])
params[k] = v
}
}
}
params := parseParams(*paramList)

// load file
data, err := os.ReadFile(*file)
if err != nil && !os.IsNotExist(err) {
entries, err := loadAllowlist(true)
if err != nil {
fmt.Fprintln(os.Stderr, err)
return
}
var entries []plugins.AllowlistEntry
if len(data) > 0 {
if err := yaml.Unmarshal(data, &entries); err != nil {
fmt.Fprintln(os.Stderr, err)
return
}
}
// find integration

wantName := strings.ToLower(*integ)
var entry *plugins.AllowlistEntry
for i := range entries {
if strings.ToLower(entries[i].Integration) == wantName {
entry = &entries[i]
break
}
}
if entry == nil {
entryIdx := findIntegration(entries, wantName)
if entryIdx == -1 {
entries = append(entries, plugins.AllowlistEntry{Integration: wantName})
entry = &entries[len(entries)-1]
entryIdx = len(entries) - 1
} else {
entry.Integration = wantName
}
// find caller
var callerCfg *plugins.CallerConfig
for i := range entry.Callers {
if entry.Callers[i].ID == *caller {
callerCfg = &entry.Callers[i]
break
}
}
if callerCfg == nil {
entry.Callers = append(entry.Callers, plugins.CallerConfig{ID: *caller})
callerCfg = &entry.Callers[len(entry.Callers)-1]
}
replaced := false
for i := range callerCfg.Capabilities {
if callerCfg.Capabilities[i].Name == *capName {
callerCfg.Capabilities[i].Params = params
replaced = true
break
}
}
if !replaced {
callerCfg.Capabilities = append(callerCfg.Capabilities, plugins.CapabilityConfig{Name: *capName, Params: params})
entries[entryIdx].Integration = wantName
}

out, err := yamlMarshal(entries)
if err != nil {
fmt.Fprintln(os.Stderr, err)
exitFunc(1)
callerIdx := findCaller(entries[entryIdx].Callers, *caller)
if callerIdx == -1 {
entries[entryIdx].Callers = append(entries[entryIdx].Callers, plugins.CallerConfig{ID: *caller})
callerIdx = len(entries[entryIdx].Callers) - 1
}
out = bytes.ReplaceAll(out, []byte("params: {}"), []byte("params: null"))

if err := writeFile(*file, out, 0644); err != nil {
fmt.Fprintln(os.Stderr, err)
exitFunc(1)
}
entries[entryIdx].Callers[callerIdx].Capabilities = upsertCapability(entries[entryIdx].Callers[callerIdx].Capabilities, *capName, params)
saveAllowlist(entries)
}

func removeEntry(args []string) {
Expand All @@ -168,59 +115,121 @@ func removeEntry(args []string) {
return
}

data, err := os.ReadFile(*file)
entries, err := loadAllowlist(false)
if err != nil {
fmt.Fprintln(os.Stderr, err)
return
}

wantName := strings.ToLower(*integ)
entryIdx := findIntegration(entries, wantName)
if entryIdx != -1 {
entries[entryIdx].Integration = wantName
callerIdx := findCaller(entries[entryIdx].Callers, *caller)
if callerIdx != -1 {
entries[entryIdx].Callers[callerIdx].Capabilities = trimCapabilities(entries[entryIdx].Callers[callerIdx].Capabilities, *capName)
if len(entries[entryIdx].Callers[callerIdx].Capabilities) == 0 {
entries[entryIdx].Callers = slices.Delete(entries[entryIdx].Callers, callerIdx, callerIdx+1)
}
if len(entries[entryIdx].Callers) == 0 {
entries = slices.Delete(entries, entryIdx, entryIdx+1)
}
}
}

saveAllowlist(entries)
}

func parseParams(paramList string) map[string]interface{} {
params := make(map[string]interface{})
for _, kv := range strings.Split(paramList, ",") {
kv = strings.TrimSpace(kv)
if kv == "" {
continue
}
key, value, ok := strings.Cut(kv, "=")
if !ok {
continue
}
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key == "" {
continue
}
params[key] = value
}
if len(params) == 0 {
return nil
}
return params
}

func loadAllowlist(allowMissing bool) ([]plugins.AllowlistEntry, error) {
data, err := os.ReadFile(*file)
if err != nil {
if allowMissing && os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
if len(data) == 0 {
return nil, nil
}
var entries []plugins.AllowlistEntry
if err := yaml.Unmarshal(data, &entries); err != nil {
fmt.Fprintln(os.Stderr, err)
return
return nil, err
}
return entries, nil
}

wantName := strings.ToLower(*integ)
for ei := range entries {
if strings.ToLower(entries[ei].Integration) != wantName {
continue
func findIntegration(entries []plugins.AllowlistEntry, name string) int {
for i := range entries {
if strings.ToLower(entries[i].Integration) == name {
return i
}
entries[ei].Integration = wantName
for ci := range entries[ei].Callers {
if entries[ei].Callers[ci].ID != *caller {
continue
}
caps := entries[ei].Callers[ci].Capabilities
for i := 0; i < len(caps); i++ {
if caps[i].Name == *capName {
caps = append(caps[:i], caps[i+1:]...)
i--
continue
}
}
if len(caps) == 0 {
entries[ei].Callers = append(entries[ei].Callers[:ci], entries[ei].Callers[ci+1:]...)
} else {
for i := range caps {
if len(caps[i].Params) == 0 {
caps[i].Params = nil
}
}
entries[ei].Callers[ci].Capabilities = caps
}
break
}
return -1
}

func findCaller(callers []plugins.CallerConfig, id string) int {
for i := range callers {
if callers[i].ID == id {
return i
}
if len(entries[ei].Callers) == 0 {
entries = append(entries[:ei], entries[ei+1:]...)
}
return -1
}

func upsertCapability(caps []plugins.CapabilityConfig, name string, params map[string]interface{}) []plugins.CapabilityConfig {
for i := range caps {
if caps[i].Name == name {
caps[i].Params = params
return caps
}
break
}
return append(caps, plugins.CapabilityConfig{Name: name, Params: params})
}

func trimCapabilities(caps []plugins.CapabilityConfig, name string) []plugins.CapabilityConfig {
trimmed := slices.DeleteFunc(caps, func(cap plugins.CapabilityConfig) bool {
return cap.Name == name
})
for i := range trimmed {
if len(trimmed[i].Params) == 0 {
trimmed[i].Params = nil
}
}
return trimmed
}

func saveAllowlist(entries []plugins.AllowlistEntry) {
out, err := yamlMarshal(entries)
if err != nil {
fmt.Fprintln(os.Stderr, err)
exitFunc(1)
}
out = bytes.ReplaceAll(out, []byte("params: {}"), []byte("params: null"))

if err := writeFile(*file, out, 0644); err != nil {
fmt.Fprintln(os.Stderr, err)
exitFunc(1)
Expand Down
Loading