Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func main() {
fmt.Println("Newt version " + newtVersion)
os.Exit(0)
} else {
logger.Info("Newt version " + newtVersion)
logger.Info("Newt version %s", newtVersion)
}

if err := updates.CheckForUpdate("fosrl", "newt", newtVersion); err != nil {
Expand Down Expand Up @@ -1138,9 +1138,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
}

if err := healthMonitor.EnableTarget(requestData.ID); err != nil {
logger.Error("Failed to enable health check target %s: %v", requestData.ID, err)
logger.Error("Failed to enable health check target %d: %v", requestData.ID, err)
} else {
logger.Info("Enabled health check target: %s", requestData.ID)
logger.Info("Enabled health check target: %d", requestData.ID)
}
})

Expand All @@ -1163,9 +1163,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
}

if err := healthMonitor.DisableTarget(requestData.ID); err != nil {
logger.Error("Failed to disable health check target %s: %v", requestData.ID, err)
logger.Error("Failed to disable health check target %d: %v", requestData.ID, err)
} else {
logger.Info("Disabled health check target: %s", requestData.ID)
logger.Info("Disabled health check target: %d", requestData.ID)
}
})

Expand Down
174 changes: 50 additions & 124 deletions network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,17 @@ import (
"net"
"time"

"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/vishvananda/netlink"
"golang.org/x/net/bpf"
"golang.org/x/net/ipv4"
)

const (
udpProtocol = 17
// EmptyUDPSize is the size of an empty UDP packet
EmptyUDPSize = 28
timeout = time.Second * 10
timeout = time.Second * 10
)

// Server stores data relating to the server
type Server struct {
Hostname string
Addr *net.IPAddr
Addr net.IP
Port uint16
}

Expand All @@ -37,159 +30,92 @@ type PeerNet struct {
NewtID string
}

// GetClientIP gets source ip address that will be used when sending data to dstIP
// GetClientIP gets the source IP address for a destination.
func GetClientIP(dstIP net.IP) net.IP {
routes, err := netlink.RouteGet(dstIP)
if err != nil {
if err != nil || len(routes) == 0 {
log.Fatalln("Error getting route:", err)
}
return routes[0].Src
}

// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr
func HostToAddr(hostStr string) *net.IPAddr {
remoteAddrs, err := net.LookupHost(hostStr)
// HostToAddr resolves a hostname, preferring IPv4.
func HostToAddr(hostStr string) net.IP {
ips, err := net.LookupIP(hostStr)
if err != nil {
log.Fatalln("Error parsing remote address:", err)
log.Fatalf("Error looking up host %s: %v", hostStr, err)
}

for _, addrStr := range remoteAddrs {
if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil {
return remoteAddr
for _, ip := range ips {
if ip.To4() != nil {
return ip
}
}
if len(ips) > 0 {
return ips[0]
}
log.Fatalf("No IP address found for host: %s", hostStr)
return nil
}

// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering
func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn {
packetConn, err := net.ListenPacket("ip4:udp", client.IP.String())
if err != nil {
log.Fatalln("Error creating packetConn:", err)
// SetupConn creates a standard UDP connection for the appropriate IP family.
// The BPF and raw socket logic has been removed for compatibility.
func SetupConn(client *PeerNet) net.PacketConn {
var networkType string
var localAddr string

if client.IP.To4() != nil {
networkType = "udp4"
localAddr = fmt.Sprintf("%s:%d", client.IP.String(), client.Port)
} else if client.IP.To16() != nil {
networkType = "udp6"
localAddr = fmt.Sprintf("[%s]:%d", client.IP.String(), client.Port)
} else {
log.Fatalln("Client IP is not a valid IPv4 or IPv6 address")
}

rawConn, err := ipv4.NewRawConn(packetConn)
conn, err := net.ListenPacket(networkType, localAddr)
if err != nil {
log.Fatalln("Error creating rawConn:", err)
log.Fatalf("Error creating packetConn for %s: %v", localAddr, err)
}

ApplyBPF(rawConn, server, client)

return rawConn
return conn
}

// ApplyBPF constructs a BPF program and applies it to the RawConn
func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) {
const ipv4HeaderLen = 20
const srcIPOffset = 12
const srcPortOffset = ipv4HeaderLen + 0
const dstPortOffset = ipv4HeaderLen + 2

ipArr := []byte(server.Addr.IP.To4())
ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3])

bpfRaw, err := bpf.Assemble([]bpf.Instruction{
bpf.LoadAbsolute{Off: srcIPOffset, Size: 4},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0},

bpf.LoadAbsolute{Off: srcPortOffset, Size: 2},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0},

bpf.LoadAbsolute{Off: dstPortOffset, Size: 2},
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0},

bpf.RetConstant{Val: 1<<(8*4) - 1},
bpf.RetConstant{Val: 0},
})

if err != nil {
log.Fatalln("Error assembling BPF:", err)
}

err = rawConn.SetBPF(bpfRaw)
// SendDataPacket sends a JSON payload to the Server using a standard UDP socket.
func SendDataPacket(data interface{}, conn net.PacketConn, server *Server, client *PeerNet) error {
jsonData, err := json.Marshal(data)
if err != nil {
log.Fatalln("Error setting BPF:", err)
}
}

// MakePacket constructs a request packet to send to the server
func MakePacket(payload []byte, server *Server, client *PeerNet) []byte {
buf := gopacket.NewSerializeBuffer()

opts := gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}

ipHeader := layers.IPv4{
SrcIP: client.IP,
DstIP: server.Addr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
return fmt.Errorf("failed to marshal payload: %v", err)
}

udpHeader := layers.UDP{
SrcPort: layers.UDPPort(client.Port),
DstPort: layers.UDPPort(server.Port),
destAddr := &net.UDPAddr{
IP: server.Addr,
Port: int(server.Port),
}

payloadLayer := gopacket.Payload(payload)

udpHeader.SetNetworkLayerForChecksum(&ipHeader)

gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer)

return buf.Bytes()
}

// SendPacket sends packet to the Server
func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
fullPacket := MakePacket(packet, server, client)
_, err := conn.WriteToIP(fullPacket, server.Addr)
_, err = conn.WriteTo(jsonData, destAddr)
return err
}

// SendDataPacket sends a JSON payload to the Server
func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
jsonData, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}

return SendPacket(jsonData, conn, server, client)
}

// RecvPacket receives a UDP packet from server
func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) {
// RecvDataPacket receives a JSON packet from the server.
func RecvDataPacket(conn net.PacketConn) ([]byte, error) {
err := conn.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
return nil, 0, err
return nil, err
}

response := make([]byte, 4096)
n, err := conn.Read(response)
if err != nil {
return nil, n, err
}
return response, n, nil
}

// RecvDataPacket receives and unmarshals a JSON packet from server
func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) {
response, n, err := RecvPacket(conn, server, client)
n, _, err := conn.ReadFrom(response)
if err != nil {
return nil, err
}

// Extract payload from UDP packet
payload := response[EmptyUDPSize:n]
return payload, nil
return response[:n], nil
}

// ParseResponse takes a response packet and parses it into an IP and port
// ParseResponse takes a response packet and parses it into an IP and port.
func ParseResponse(response []byte) (net.IP, uint16) {
if len(response) < 6 {
return nil, 0
}
ip := net.IP(response[:4])
port := binary.BigEndian.Uint16(response[4:6])
return ip, port
}
}
Loading