diff --git a/main.go b/main.go index 12849b1..064fdcb 100644 --- a/main.go +++ b/main.go @@ -291,7 +291,7 @@ func main() { fmt.Println("Newt version " + newtVersion) os.Exit(0) } else { - logger.Info("Newt version " + newtVersion) + logger.Info("Newt version %s", newtVersion) } if err := updates.CheckForUpdate("fosrl", "newt", newtVersion); err != nil { @@ -1138,9 +1138,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub } if err := healthMonitor.EnableTarget(requestData.ID); err != nil { - logger.Error("Failed to enable health check target %s: %v", requestData.ID, err) + logger.Error("Failed to enable health check target %d: %v", requestData.ID, err) } else { - logger.Info("Enabled health check target: %s", requestData.ID) + logger.Info("Enabled health check target: %d", requestData.ID) } }) @@ -1163,9 +1163,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub } if err := healthMonitor.DisableTarget(requestData.ID); err != nil { - logger.Error("Failed to disable health check target %s: %v", requestData.ID, err) + logger.Error("Failed to disable health check target %d: %v", requestData.ID, err) } else { - logger.Info("Disabled health check target: %s", requestData.ID) + logger.Info("Disabled health check target: %d", requestData.ID) } }) diff --git a/network/network.go b/network/network.go index e359219..1ea47bf 100644 --- a/network/network.go +++ b/network/network.go @@ -8,24 +8,17 @@ import ( "net" "time" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" "github.com/vishvananda/netlink" - "golang.org/x/net/bpf" - "golang.org/x/net/ipv4" ) const ( - udpProtocol = 17 - // EmptyUDPSize is the size of an empty UDP packet - EmptyUDPSize = 28 - timeout = time.Second * 10 + timeout = time.Second * 10 ) // Server stores data relating to the server type Server struct { Hostname string - Addr *net.IPAddr + Addr net.IP Port uint16 } @@ -37,159 +30,92 @@ type PeerNet struct { NewtID string } -// GetClientIP gets source ip address that will be used when sending data to dstIP +// GetClientIP gets the source IP address for a destination. func GetClientIP(dstIP net.IP) net.IP { routes, err := netlink.RouteGet(dstIP) - if err != nil { + if err != nil || len(routes) == 0 { log.Fatalln("Error getting route:", err) } return routes[0].Src } -// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr -func HostToAddr(hostStr string) *net.IPAddr { - remoteAddrs, err := net.LookupHost(hostStr) +// HostToAddr resolves a hostname, preferring IPv4. +func HostToAddr(hostStr string) net.IP { + ips, err := net.LookupIP(hostStr) if err != nil { - log.Fatalln("Error parsing remote address:", err) + log.Fatalf("Error looking up host %s: %v", hostStr, err) } - - for _, addrStr := range remoteAddrs { - if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil { - return remoteAddr + for _, ip := range ips { + if ip.To4() != nil { + return ip } } + if len(ips) > 0 { + return ips[0] + } + log.Fatalf("No IP address found for host: %s", hostStr) return nil } -// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering -func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn { - packetConn, err := net.ListenPacket("ip4:udp", client.IP.String()) - if err != nil { - log.Fatalln("Error creating packetConn:", err) +// SetupConn creates a standard UDP connection for the appropriate IP family. +// The BPF and raw socket logic has been removed for compatibility. +func SetupConn(client *PeerNet) net.PacketConn { + var networkType string + var localAddr string + + if client.IP.To4() != nil { + networkType = "udp4" + localAddr = fmt.Sprintf("%s:%d", client.IP.String(), client.Port) + } else if client.IP.To16() != nil { + networkType = "udp6" + localAddr = fmt.Sprintf("[%s]:%d", client.IP.String(), client.Port) + } else { + log.Fatalln("Client IP is not a valid IPv4 or IPv6 address") } - rawConn, err := ipv4.NewRawConn(packetConn) + conn, err := net.ListenPacket(networkType, localAddr) if err != nil { - log.Fatalln("Error creating rawConn:", err) + log.Fatalf("Error creating packetConn for %s: %v", localAddr, err) } - - ApplyBPF(rawConn, server, client) - - return rawConn + return conn } -// ApplyBPF constructs a BPF program and applies it to the RawConn -func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) { - const ipv4HeaderLen = 20 - const srcIPOffset = 12 - const srcPortOffset = ipv4HeaderLen + 0 - const dstPortOffset = ipv4HeaderLen + 2 - - ipArr := []byte(server.Addr.IP.To4()) - ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3]) - - bpfRaw, err := bpf.Assemble([]bpf.Instruction{ - bpf.LoadAbsolute{Off: srcIPOffset, Size: 4}, - bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0}, - - bpf.LoadAbsolute{Off: srcPortOffset, Size: 2}, - bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0}, - - bpf.LoadAbsolute{Off: dstPortOffset, Size: 2}, - bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0}, - - bpf.RetConstant{Val: 1<<(8*4) - 1}, - bpf.RetConstant{Val: 0}, - }) - - if err != nil { - log.Fatalln("Error assembling BPF:", err) - } - - err = rawConn.SetBPF(bpfRaw) +// SendDataPacket sends a JSON payload to the Server using a standard UDP socket. +func SendDataPacket(data interface{}, conn net.PacketConn, server *Server, client *PeerNet) error { + jsonData, err := json.Marshal(data) if err != nil { - log.Fatalln("Error setting BPF:", err) - } -} - -// MakePacket constructs a request packet to send to the server -func MakePacket(payload []byte, server *Server, client *PeerNet) []byte { - buf := gopacket.NewSerializeBuffer() - - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - ipHeader := layers.IPv4{ - SrcIP: client.IP, - DstIP: server.Addr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, + return fmt.Errorf("failed to marshal payload: %v", err) } - udpHeader := layers.UDP{ - SrcPort: layers.UDPPort(client.Port), - DstPort: layers.UDPPort(server.Port), + destAddr := &net.UDPAddr{ + IP: server.Addr, + Port: int(server.Port), } - payloadLayer := gopacket.Payload(payload) - - udpHeader.SetNetworkLayerForChecksum(&ipHeader) - - gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer) - - return buf.Bytes() -} - -// SendPacket sends packet to the Server -func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error { - fullPacket := MakePacket(packet, server, client) - _, err := conn.WriteToIP(fullPacket, server.Addr) + _, err = conn.WriteTo(jsonData, destAddr) return err } -// SendDataPacket sends a JSON payload to the Server -func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error { - jsonData, err := json.Marshal(data) - if err != nil { - return fmt.Errorf("failed to marshal payload: %v", err) - } - - return SendPacket(jsonData, conn, server, client) -} - -// RecvPacket receives a UDP packet from server -func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) { +// RecvDataPacket receives a JSON packet from the server. +func RecvDataPacket(conn net.PacketConn) ([]byte, error) { err := conn.SetReadDeadline(time.Now().Add(timeout)) if err != nil { - return nil, 0, err + return nil, err } - response := make([]byte, 4096) - n, err := conn.Read(response) - if err != nil { - return nil, n, err - } - return response, n, nil -} - -// RecvDataPacket receives and unmarshals a JSON packet from server -func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) { - response, n, err := RecvPacket(conn, server, client) + n, _, err := conn.ReadFrom(response) if err != nil { return nil, err } - - // Extract payload from UDP packet - payload := response[EmptyUDPSize:n] - return payload, nil + return response[:n], nil } -// ParseResponse takes a response packet and parses it into an IP and port +// ParseResponse takes a response packet and parses it into an IP and port. func ParseResponse(response []byte) (net.IP, uint16) { + if len(response) < 6 { + return nil, 0 + } ip := net.IP(response[:4]) port := binary.BigEndian.Uint16(response[4:6]) return ip, port -} +} \ No newline at end of file diff --git a/wg/wg.go b/wg/wg.go index 3cee1a9..2eaeea8 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -517,9 +517,32 @@ func (s *WireGuardService) addPeer(peer Peer) error { var peerConfig wgtypes.PeerConfig if peer.Endpoint != "" { - endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint) + // This logic correctly handles IPv4, IPv6, and hostnames. + formattedEndpoint := peer.Endpoint + host, _, err := net.SplitHostPort(formattedEndpoint) + if err == nil { + // It's a host:port string, check if the host is a literal IPv6 + ip := net.ParseIP(host) + if ip != nil && ip.To4() == nil { // It is a literal IPv6 + // Already correctly formatted by SplitHostPort logic, do nothing + } + } else { + // Not a standard host:port string, could be IPv6 without brackets. + // Let's try to parse it as such. + lastColon := strings.LastIndex(formattedEndpoint, ":") + if lastColon != -1 { + host := formattedEndpoint[:lastColon] + port := formattedEndpoint[lastColon+1:] + ip := net.ParseIP(host) + if ip != nil && ip.To4() == nil { // It is a literal IPv6 + formattedEndpoint = fmt.Sprintf("[%s]:%s", host, port) + } + } + } + + endpoint, err := net.ResolveUDPAddr("udp", formattedEndpoint) if err != nil { - return fmt.Errorf("failed to resolve endpoint address: %w", err) + return fmt.Errorf("failed to resolve endpoint address '%s': %w", formattedEndpoint, err) } peerConfig = wgtypes.PeerConfig{ @@ -539,6 +562,7 @@ func (s *WireGuardService) addPeer(peer Peer) error { config := wgtypes.Config{ Peers: []wgtypes.PeerConfig{peerConfig}, + ReplacePeers: false, } if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { @@ -550,6 +574,7 @@ func (s *WireGuardService) addPeer(peer Peer) error { return nil } + func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } @@ -599,114 +624,114 @@ func (s *WireGuardService) removePeer(publicKey string) error { } func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - // Define a struct to match the incoming message structure with optional fields - type UpdatePeerRequest struct { - PublicKey string `json:"publicKey"` - AllowedIPs []string `json:"allowedIps,omitempty"` - Endpoint string `json:"endpoint,omitempty"` - } - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - var request UpdatePeerRequest - if err := json.Unmarshal(jsonData, &request); err != nil { - logger.Info("Error unmarshaling peer data: %v", err) - return - } - // First, get the current peer configuration to preserve any unmodified fields - device, err := s.wgClient.Device(s.interfaceName) - if err != nil { - logger.Info("Error getting WireGuard device: %v", err) - return - } - pubKey, err := wgtypes.ParseKey(request.PublicKey) - if err != nil { - logger.Info("Error parsing public key: %v", err) - return - } - // Find the existing peer configuration - var currentPeer *wgtypes.Peer - for _, p := range device.Peers { - if p.PublicKey == pubKey { - currentPeer = &p - break - } - } - if currentPeer == nil { - logger.Info("Peer %s not found, cannot update", request.PublicKey) - return - } - // Create the update peer config - peerConfig := wgtypes.PeerConfig{ - PublicKey: pubKey, - UpdateOnly: true, - } - // Keep the default persistent keepalive of 1 second - keepalive := time.Second - peerConfig.PersistentKeepaliveInterval = &keepalive - - // Handle Endpoint field special case - // If Endpoint is included in the request but empty, we want to remove the endpoint - // If Endpoint is not included, we don't modify it - endpointSpecified := false - for key := range msg.Data.(map[string]interface{}) { - if key == "endpoint" { - endpointSpecified = true - break - } - } - - // Only update AllowedIPs if provided in the request - if len(request.AllowedIPs) > 0 { - var allowedIPs []net.IPNet - for _, ipStr := range request.AllowedIPs { - _, ipNet, err := net.ParseCIDR(ipStr) - if err != nil { - logger.Info("Error parsing allowed IP %s: %v", ipStr, err) - return - } - allowedIPs = append(allowedIPs, *ipNet) - } - peerConfig.AllowedIPs = allowedIPs - peerConfig.ReplaceAllowedIPs = true - logger.Info("Updating AllowedIPs for peer %s", request.PublicKey) - } else if endpointSpecified && request.Endpoint == "" { - peerConfig.ReplaceAllowedIPs = false - } - - if endpointSpecified { - if request.Endpoint != "" { - // Update to new endpoint - endpoint, err := net.ResolveUDPAddr("udp", request.Endpoint) - if err != nil { - logger.Info("Error resolving endpoint address %s: %v", request.Endpoint, err) - return - } - peerConfig.Endpoint = endpoint - logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint) - } else { - // specify any address to listen for any incoming packets - peerConfig.Endpoint = &net.UDPAddr{ - IP: net.IPv4(127, 0, 0, 1), - } - logger.Info("Removing Endpoint for peer %s", request.PublicKey) - } - } - - // Apply the configuration update - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peerConfig}, - } - if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { - logger.Info("Error updating peer configuration: %v", err) - return - } - logger.Info("Peer %s updated successfully", request.PublicKey) + logger.Debug("Received message: %v", msg.Data) + // Define a struct to match the incoming message structure with optional fields + type UpdatePeerRequest struct { + PublicKey string `json:"publicKey"` + AllowedIPs []string `json:"allowedIps,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + } + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + var request UpdatePeerRequest + if err := json.Unmarshal(jsonData, &request); err != nil { + logger.Info("Error unmarshaling peer data: %v", err) + return + } + // First, get the current peer configuration to preserve any unmodified fields + device, err := s.wgClient.Device(s.interfaceName) + if err != nil { + logger.Info("Error getting WireGuard device: %v", err) + return + } + pubKey, err := wgtypes.ParseKey(request.PublicKey) + if err != nil { + logger.Info("Error parsing public key: %v", err) + return + } + // Find the existing peer configuration + var currentPeer *wgtypes.Peer + for i := range device.Peers { + if device.Peers[i].PublicKey == pubKey { + currentPeer = &device.Peers[i] + break + } + } + if currentPeer == nil { + logger.Info("Peer %s not found, cannot update", request.PublicKey) + return + } + // Create the update peer config + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKey, + UpdateOnly: true, + } + // Keep the default persistent keepalive + keepalive := 25 * time.Second + peerConfig.PersistentKeepaliveInterval = &keepalive + + // Handle Endpoint field special case + endpointSpecified := false + if rawData, ok := msg.Data.(map[string]interface{}); ok { + _, endpointSpecified = rawData["endpoint"] + } + + // Only update AllowedIPs if provided in the request + if len(request.AllowedIPs) > 0 { + var allowedIPs []net.IPNet + for _, ipStr := range request.AllowedIPs { + _, ipNet, err := net.ParseCIDR(ipStr) + if err != nil { + logger.Info("Error parsing allowed IP %s: %v", ipStr, err) + return + } + allowedIPs = append(allowedIPs, *ipNet) + } + peerConfig.AllowedIPs = allowedIPs + peerConfig.ReplaceAllowedIPs = true + logger.Info("Updating AllowedIPs for peer %s", request.PublicKey) + } + + if endpointSpecified { + if request.Endpoint != "" { + // Update to new endpoint using the robust formatting logic + formattedEndpoint := request.Endpoint + host, port, err := net.SplitHostPort(request.Endpoint) + if err == nil { + ip := net.ParseIP(host) + if ip != nil && ip.To4() == nil { + formattedEndpoint = fmt.Sprintf("[%s]:%s", host, port) + } + } + endpoint, err := net.ResolveUDPAddr("udp", formattedEndpoint) + if err != nil { + logger.Info("Error resolving endpoint address %s: %v", formattedEndpoint, err) + return + } + peerConfig.Endpoint = endpoint + logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, formattedEndpoint) + } else { + // Set a valid "any" IP address instead of a nil one. + logger.Info("Removing Endpoint for peer %s", request.PublicKey) + peerConfig.Endpoint = &net.UDPAddr{IP: net.IPv4zero, Port: 0} + } + } + + // Apply the configuration update + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { + logger.Info("Error updating peer configuration: %v", err) + return + } + logger.Info("Peer %s updated successfully", request.PublicKey) } + func (s *WireGuardService) periodicBandwidthCheck() { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() @@ -738,15 +763,13 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { LastChecked: now, } - var bytesInDiff, bytesOutDiff float64 lastReading, exists := s.lastReadings[publicKey] if exists { timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() if timeDiff > 0 { - // Calculate bytes transferred since last reading - bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived) - bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted) + bytesInDiff := float64(currentReading.BytesReceived - lastReading.BytesReceived) + bytesOutDiff := float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted) // Handle counter wraparound (if the counter resets or overflows) if bytesInDiff < 0 { @@ -765,37 +788,17 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { BytesIn: bytesInMB, BytesOut: bytesOutMB, }) - } else { - // If readings are too close together or time hasn't passed, report 0 - peerBandwidths = append(peerBandwidths, PeerBandwidth{ - PublicKey: publicKey, - BytesIn: 0, - BytesOut: 0, - }) } - } else { - // For first reading of a peer, report 0 to establish baseline - peerBandwidths = append(peerBandwidths, PeerBandwidth{ - PublicKey: publicKey, - BytesIn: 0, - BytesOut: 0, - }) } - - // Update the last reading s.lastReadings[publicKey] = currentReading } - // Clean up old peers + activePeers := make(map[string]struct{}) + for _, peer := range device.Peers { + activePeers[peer.PublicKey.String()] = struct{}{} + } for publicKey := range s.lastReadings { - found := false - for _, peer := range device.Peers { - if peer.PublicKey.String() == publicKey { - found = true - break - } - } - if !found { + if _, found := activePeers[publicKey]; !found { delete(s.lastReadings, publicKey) } } @@ -809,6 +812,10 @@ func (s *WireGuardService) reportPeerBandwidth() error { return fmt.Errorf("failed to calculate peer bandwidth: %v", err) } + if len(bandwidths) == 0 { + return nil + } + err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{ "bandwidthData": bandwidths, }) @@ -826,26 +833,29 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { return nil } - // Parse server address - serverSplit := strings.Split(serverAddr, ":") - if len(serverSplit) < 2 { - return fmt.Errorf("invalid server address format, expected hostname:port") + serverHostname, serverPortStr, err := net.SplitHostPort(serverAddr) + if err != nil { + return fmt.Errorf("failed to parse server address '%s': %v", serverAddr, err) } - serverHostname := serverSplit[0] - serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16) + serverPort, err := strconv.ParseUint(serverPortStr, 10, 16) if err != nil { - return fmt.Errorf("failed to parse server port: %v", err) + return fmt.Errorf("failed to parse server port from '%s': %v", serverPortStr, err) } - // Resolve server hostname to IP - serverIPAddr := network.HostToAddr(serverHostname) - if serverIPAddr == nil { - return fmt.Errorf("failed to resolve server hostname") + var serverIPAddr net.IP + ip := net.ParseIP(serverHostname) + + if ip != nil { + serverIPAddr = ip + } else { + serverIPAddr = network.HostToAddr(serverHostname) + if serverIPAddr == nil { + return fmt.Errorf("failed to resolve server hostname: %s", serverHostname) + } } - // Get client IP based on route to server - clientIP := network.GetClientIP(serverIPAddr.IP) + clientIP := network.GetClientIP(serverIPAddr) // Create server and client configs server := &network.Server{ @@ -860,9 +870,8 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { NewtID: s.newtId, } - // Setup raw connection with BPF filtering - rawConn := network.SetupRawConn(server, client) - defer rawConn.Close() + conn := network.SetupConn(client) + defer conn.Close() // Create JSON payload payload := struct { @@ -886,7 +895,7 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { } // Send the encrypted packet using the raw connection - err = network.SendDataPacket(encryptedPayload, rawConn, server, client) + err = network.SendDataPacket(encryptedPayload, conn, server, client) if err != nil { return fmt.Errorf("failed to send UDP packet: %v", err) } @@ -973,6 +982,11 @@ func (s *WireGuardService) removeInterface() error { // Remove the WireGuard interface link, err := netlink.LinkByName(s.interfaceName) if err != nil { + // If the link is not found, we can consider it as successfully removed + if _, ok := err.(netlink.LinkNotFoundError); ok { + logger.Info("WireGuard interface %s already removed", s.interfaceName) + return nil + } return fmt.Errorf("failed to get interface: %v", err) } diff --git a/wgnetstack/wgnetstack.go b/wgnetstack/wgnetstack.go index 6684c40..797f5fd 100644 --- a/wgnetstack/wgnetstack.go +++ b/wgnetstack/wgnetstack.go @@ -605,10 +605,19 @@ func (s *WireGuardService) addPeerToDevice(peer Peer) error { // Add endpoint if specified if peer.Endpoint != "" { - config += fmt.Sprintf("\nendpoint=%s", peer.Endpoint) + // Reformat the endpoint if it's a non-bracketed IPv6 address + formattedEndpoint := peer.Endpoint + if strings.Contains(formattedEndpoint, ":") && !strings.HasPrefix(formattedEndpoint, "[") { + lastColon := strings.LastIndex(formattedEndpoint, ":") + if strings.Count(formattedEndpoint, ":") > 1 { + host := formattedEndpoint[:lastColon] + port := formattedEndpoint[lastColon+1:] + formattedEndpoint = fmt.Sprintf("[%s]:%s", host, port) + } + } + config += fmt.Sprintf("\nendpoint=%s", formattedEndpoint) } - // Add persistent keepalive config += "\npersistent_keepalive_interval=25" // Apply the configuration @@ -935,33 +944,37 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { return nil } - // Parse server address - serverSplit := strings.Split(serverAddr, ":") - if len(serverSplit) < 2 { - return fmt.Errorf("invalid server address format, expected hostname:port") + serverHostname, serverPortStr, err := net.SplitHostPort(serverAddr) + if err != nil { + return fmt.Errorf("failed to parse server address '%s': %v", serverAddr, err) } - serverHostname := serverSplit[0] - serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16) + serverPort, err := strconv.ParseUint(serverPortStr, 10, 16) if err != nil { - return fmt.Errorf("failed to parse server port: %v", err) + return fmt.Errorf("failed to parse server port from '%s': %v", serverPortStr, err) } - // Resolve server hostname to IP - serverIPAddr := network.HostToAddr(serverHostname) - if serverIPAddr == nil { - return fmt.Errorf("failed to resolve server hostname") + var serverIPAddr net.IP + ip := net.ParseIP(serverHostname) + + if ip != nil { + serverIPAddr = ip + } else { + serverIPAddr = network.HostToAddr(serverHostname) + if serverIPAddr == nil { + return fmt.Errorf("failed to resolve server hostname: %s", serverHostname) + } } - // Create local UDP address using the same port as WireGuard + // Create a local UDP address that is protocol-agnostic localAddr := &net.UDPAddr{ - IP: net.IPv4zero, Port: int(s.Port), + // IP is left nil (unspecified), letting the OS choose correctly (0.0.0.0 for v4, :: for v6) } // Create remote server address remoteAddr := &net.UDPAddr{ - IP: serverIPAddr.IP, + IP: serverIPAddr, Port: int(serverPort), }