diff --git a/intra/netstack/icmp.go b/intra/netstack/icmp.go index 7da06220..9e5f43e0 100644 --- a/intra/netstack/icmp.go +++ b/intra/netstack/icmp.go @@ -7,6 +7,7 @@ package netstack import ( + "encoding/binary" "math" "github.com/celzero/firestack/intra/core" @@ -19,6 +20,79 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) +type ICMPHackTarget struct{} + +func (t *ICMPHackTarget) Action(pkt *stack.PacketBuffer, hook stack.Hook, r *stack.Route, _ stack.AddressableEndpoint) (stack.RuleVerdict, int) { + transportHdr := pkt.TransportHeader() + if len(transportHdr.Slice()) < 8 { + return stack.RuleDrop, 0 + } + switch pkt.TransportProtocolNumber { + case header.ICMPv6ProtocolNumber: + icmp6Hdr := header.ICMPv6(transportHdr.Slice()) + if icmp6Hdr.Type() == header.ICMPv6EchoRequest { + //https://www.rfc-editor.org/rfc/rfc4443.html#section-2.1 + //200 Private experimentation + icmp6Hdr.SetType(200) + icmp6Hdr.SetChecksum(0) + + ipv6Hdr := pkt.NetworkHeader().Slice() + icmp6Msg := stack.PayloadSince(transportHdr).AsSlice() + t1 := checksum.Checksumer{} + t1.Add(ipv6Hdr[8:40]) + var t2 [8]byte + binary.BigEndian.PutUint32(t2[:4], uint32(len(icmp6Msg))) + t2[7] = uint8(header.ICMPv6ProtocolNumber) + t1.Add(t2[:]) + t1.Add(icmp6Msg) + + icmp6Hdr.SetChecksum(t1.Checksum()) + } + case header.ICMPv4ProtocolNumber: + icmp4Hdr := header.ICMPv4(transportHdr.Slice()) + if icmp4Hdr.Type() == header.ICMPv4Echo { + //https://www.rfc-editor.org/rfc/rfc4727.html#section-4 + //253 RFC3692-style Experiment 1 + icmp4Hdr.SetType(253) + icmp4Hdr.SetChecksum(0) + icmp4Msg := stack.PayloadSince(transportHdr).AsSlice() + icmp4Hdr.SetChecksum(checksum.Checksum(icmp4Msg, 0)) + } + } + return stack.RuleAccept, 0 +} + +func restoreICMPv6Type(h header.ICMPv6) { + if h.Type() == 200 { + h.SetType(header.ICMPv6EchoRequest) + } +} +func restoreICMPv4Type(h header.ICMPv4) { + if h.Type() == 253 { + h.SetType(header.ICMPv4Echo) + } +} + +func ICMPHack(s *stack.Stack, target stack.Target) { + ipt := s.IPTables() + + table := ipt.GetTable(stack.MangleID, true) + index := table.BuiltinChains[stack.Prerouting] + rules := table.Rules + rules[index].Filter.Protocol = header.ICMPv6ProtocolNumber + rules[index].Filter.CheckProtocol = true + rules[index].Target = target + ipt.ReplaceTable(stack.MangleID, table, true) + + table = ipt.GetTable(stack.MangleID, false) + index = table.BuiltinChains[stack.Prerouting] + rules = table.Rules + rules[index].Filter.Protocol = header.ICMPv4ProtocolNumber + rules[index].Filter.CheckProtocol = true + rules[index].Target = target + ipt.ReplaceTable(stack.MangleID, table, false) +} + type GICMPHandler interface { GBaseConnHandler GEchoConnHandler @@ -33,6 +107,8 @@ type icmpForwarder struct { // github.com/google/gvisor/blob/738e1d995f/pkg/tcpip/network/ipv4/icmp.go // github.com/google/gvisor/blob/738e1d995f/pkg/tcpip/network/ipv6/icmp.go func OutboundICMP(id string, s *stack.Stack, hdl GICMPHandler) { + ICMPHack(s, &ICMPHackTarget{}) + // remove default handlers s.SetTransportProtocolHandler(icmp.ProtocolNumber4, nil) s.SetTransportProtocolHandler(icmp.ProtocolNumber6, nil) @@ -69,6 +145,7 @@ func (f *icmpForwarder) reply4(id stack.TransportEndpointID, pkt *stack.PacketBu // ref: github.com/google/gvisor/blob/acf460d0d735/pkg/tcpip/stack/conntrack.go#L933 hdr := header.ICMPv4(pkt.TransportHeader().Slice()) + restoreICMPv4Type(hdr) if hdr.Type() != header.ICMPv4Echo { // netstack handles other msgs except echo / ping log.D("icmp: v4: %s: type %v passthrough", f.o, hdr.Type()) @@ -166,6 +243,7 @@ func (f *icmpForwarder) reply6(id stack.TransportEndpointID, pkt *stack.PacketBu } hdr := header.ICMPv6(pkt.TransportHeader().Slice()) + restoreICMPv6Type(hdr) if hdr.Type() != header.ICMPv6EchoRequest { log.D("icmp: v6: %s: type %v/%v passthrough", f.o, hdr.Type(), hdr.Code()) return // netstack to handle other msgs except echo / ping