diff --git a/pkg/utils/helper.go b/pkg/utils/helper.go index 7919e2fd..a17cd987 100644 --- a/pkg/utils/helper.go +++ b/pkg/utils/helper.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net" + "net/url" "regexp" "strconv" "strings" @@ -42,37 +43,102 @@ func FormatUrl(url string) string { // ValidateDomain validates subdomain, IP, or top-level domain formats func ValidateDomain(domain string) error { - if strings.Contains(domain, ":") { - parts := strings.Split(domain, ":") - if len(parts) != 2 { - return errors.New("invalid domain format: too many colons") - } - port, err := strconv.Atoi(parts[1]) - if err != nil || port < 0 || port > 65535 { - return errors.New("invalid port number") - } - domain = parts[0] + url := FormatUrl(domain) + + err := isValidURL(url) + if err != nil { + return err + } + return nil +} + +func isValidURL(rawURL string) error { + parsedURL, err := url.Parse(rawURL) + if err != nil { + return errors.New("invaild url scheme") } - if net.ParseIP(domain) != nil { - return nil + if parsedURL.Host == "" { + return errors.New("missing domain") + } + + host, port, err := net.SplitHostPort(parsedURL.Host) + if err != nil { + host = parsedURL.Host + } + + if err := isValidHost(host); err != nil { + return err + } + + if port != "" && !isValidPort(port) { + return errors.New("invalid port") + } + + return nil +} + +// Validate whether the port is valid +func isValidPort(port string) bool { + portRegex := `^([1-9][0-9]{0,4})$` + match, _ := regexp.MatchString(portRegex, port) + if !match { + return false + } + portNumber := 0 + fmt.Sscanf(port, "%d", &portNumber) + return portNumber >= 1 && portNumber <= 65535 +} + +func extractIP(host string) string { + ipv4Regex := `(?:\d{1,3}\.){3}\d{1,3}` + + re := regexp.MustCompile(ipv4Regex) + ipv4Matches := re.FindString(host) + + if ipv4Matches != "" { + return ipv4Matches } - parts := strings.Split(domain, ".") + return "" +} + +// Validate whether the host is valid +func isValidHost(host string) error { + ip := extractIP(host) + if ip != "" { + return isValidIPv4(host) + } + + parts := strings.Split(host, ".") if len(parts) < 2 { - return errors.New("invalid domain: must have at least one dot") + return errors.New("invalid host: must have at least one dot") } for _, part := range parts { if !isValidLabel(part) { - return fmt.Errorf("invalid domain label: %s", part) + return fmt.Errorf("invalid host label: %s", part) } } if len(parts[len(parts)-1]) < 2 { - return errors.New("invalid top-level domain: must be at least 2 characters") + return errors.New("invalid top-level host: must be at least 2 characters") } + return nil +} +func isValidIPv4(ip string) error { + octets := strings.Split(ip, ".") + if len(octets) != 4 { + return errors.New("IP: consists of more than three octets") + } + + for _, octet := range octets { + num, err := strconv.Atoi(octet) + if err != nil || num < 0 || num > 255 { + return errors.New("IP: octet exceeds range") + } + } return nil }