Skip to content

Commit

Permalink
feat: add additional restrictions for login server
Browse files Browse the repository at this point in the history
Signed-off-by: Alan Tang <[email protected]>
  • Loading branch information
Standing-Man committed Jan 1, 2025
1 parent c901e34 commit d6daef3
Showing 1 changed file with 82 additions and 16 deletions.
98 changes: 82 additions & 16 deletions pkg/utils/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net"
"net/url"
"regexp"
"strconv"
"strings"
Expand Down Expand Up @@ -42,37 +43,102 @@ func FormatUrl(url string) string {

// ValidateDomain validates subdomain, IP, or top-level domain formats
func ValidateDomain(domain string) error {
if strings.Contains(domain, ":") {
parts := strings.Split(domain, ":")
if len(parts) != 2 {
return errors.New("invalid domain format: too many colons")
}
port, err := strconv.Atoi(parts[1])
if err != nil || port < 0 || port > 65535 {
return errors.New("invalid port number")
}
domain = parts[0]
url := FormatUrl(domain)

err := isValidURL(url)
if err != nil {
return err
}
return nil
}

func isValidURL(rawURL string) error {
parsedURL, err := url.Parse(rawURL)
if err != nil {
return errors.New("invaild url scheme")
}

if net.ParseIP(domain) != nil {
return nil
if parsedURL.Host == "" {
return errors.New("missing domain")
}

host, port, err := net.SplitHostPort(parsedURL.Host)
if err != nil {
host = parsedURL.Host
}

if err := isValidHost(host); err != nil {
return err
}

if port != "" && !isValidPort(port) {
return errors.New("invalid port")
}

return nil
}

// Validate whether the port is valid
func isValidPort(port string) bool {
portRegex := `^([1-9][0-9]{0,4})$`
match, _ := regexp.MatchString(portRegex, port)
if !match {
return false
}
portNumber := 0
fmt.Sscanf(port, "%d", &portNumber)
return portNumber >= 1 && portNumber <= 65535
}

func extractIP(host string) string {
ipv4Regex := `(?:\d{1,3}\.){3}\d{1,3}`

re := regexp.MustCompile(ipv4Regex)
ipv4Matches := re.FindString(host)

if ipv4Matches != "" {
return ipv4Matches
}

parts := strings.Split(domain, ".")
return ""
}

// Validate whether the host is valid
func isValidHost(host string) error {
ip := extractIP(host)
if ip != "" {
return isValidIPv4(host)
}

parts := strings.Split(host, ".")
if len(parts) < 2 {
return errors.New("invalid domain: must have at least one dot")
return errors.New("invalid host: must have at least one dot")
}

for _, part := range parts {
if !isValidLabel(part) {
return fmt.Errorf("invalid domain label: %s", part)
return fmt.Errorf("invalid host label: %s", part)
}
}

if len(parts[len(parts)-1]) < 2 {
return errors.New("invalid top-level domain: must be at least 2 characters")
return errors.New("invalid top-level host: must be at least 2 characters")
}
return nil
}

func isValidIPv4(ip string) error {
octets := strings.Split(ip, ".")
if len(octets) != 4 {
return errors.New("IP: consists of more than three octets")
}

for _, octet := range octets {
num, err := strconv.Atoi(octet)
if err != nil || num < 0 || num > 255 {
return errors.New("IP: octet exceeds range")
}
}
return nil
}

Expand Down

0 comments on commit d6daef3

Please sign in to comment.