Skip to content

Commit

Permalink
Merge pull request #52 from OpenCHAMI/41-bug-wireguard-endpoints-and-…
Browse files Browse the repository at this point in the history
…tunnels-are-not-cleaned-up-properly

41 bug wireguard endpoints and tunnels are not cleaned up properly
travisbcotton authored Jan 23, 2025
2 parents 4d68fcd + 7019900 commit f8f34b5
Showing 5 changed files with 67 additions and 18 deletions.
21 changes: 18 additions & 3 deletions cmd/cloud-init-server/handlers.go
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ import (

"github.com/OpenCHAMI/cloud-init/internal/smdclient"
"github.com/OpenCHAMI/cloud-init/pkg/cistore"
"github.com/OpenCHAMI/cloud-init/pkg/wgtunnel"
"github.com/go-chi/chi/v5"
"github.com/rs/zerolog/log"
)
@@ -129,7 +130,7 @@ func InstanceInfoHandler(sm smdclient.SMDClientInterface, store cistore.Store) h
}

// Phone home should be a POST request x-www-form-urlencoded like this: pub_key_rsa=rsa_contents&pub_key_ecdsa=ecdsa_contents&pub_key_ed25519=ed25519_contents&instance_id=i-87018aed&hostname=myhost&fqdn=myhost.internal
func PhoneHomeHandler(store cistore.Store) http.HandlerFunc {
func PhoneHomeHandler(wg *wgtunnel.InterfaceManager, sm smdclient.SMDClientInterface) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
@@ -139,7 +140,15 @@ func PhoneHomeHandler(store cistore.Store) http.HandlerFunc {
log.Info().Msgf("Phone home request from %s", ip)
// TODO: validate the request IP against the SMD client and reject if needed

err := r.ParseForm()
id, err := sm.IDfromIP(ip)
if err != nil {
log.Error().Msgf("Error getting ID from IP: %v", err)
}
peerName, err := sm.IPfromID(id)
if err != nil {
log.Error().Msgf("Error getting IP from ID: %v", err)
}
err = r.ParseForm()
if err != nil {
log.Error().Msgf("Error parsing form data: %v", err)
w.WriteHeader(http.StatusBadRequest)
@@ -163,6 +172,12 @@ func PhoneHomeHandler(store cistore.Store) http.HandlerFunc {
Msgf("Received phone home data: pub_key_rsa=%s, pub_key_ecdsa=%s, pub_key_ed25519=%s, instance_id=%s, hostname=%s, fqdn=%s",
pubKeyRsa, pubKeyEcdsa, pubKeyEd25519, instanceId, hostname, fqdn)

w.WriteHeader(http.StatusOK)
if wg != nil {
go func() {
wg.RemovePeer(peerName)
}()

w.WriteHeader(http.StatusOK)
}
}
}
15 changes: 7 additions & 8 deletions cmd/cloud-init-server/main.go
Original file line number Diff line number Diff line change
@@ -168,7 +168,7 @@ func main() {

router_client := chi.NewRouter()
initCiClientRouter(router_client, ciHandler, wgInterfaceManager)
router.Mount("/cloud-init", router_client)
router.Mount("/cloud-init", router_client)

router_admin := chi.NewRouter()
if secureRouteEnable {
@@ -185,15 +185,14 @@ func main() {
log.Fatal().Err(http.ListenAndServe(ciEndpoint, router)).Msg("Server closed")
}


func initCiClientRouter(router chi.Router, handler *CiHandler, wgInterfaceManager *wgtunnel.InterfaceManager) {
// Add cloud-init endpoints to router
router.With(wireGuardMiddleware).Get("/user-data", UserDataHandler)
router.With(wireGuardMiddleware).Get("/meta-data", MetaDataHandler(handler.sm, handler.store))
router.With(wireGuardMiddleware).Get("/vendor-data", VendorDataHandler(handler.sm, handler.store))
router.With(wireGuardMiddleware).Get("/{group}.yaml", GroupUserDataHandler(handler.sm, handler.store))
router.Post("/phone-home/{id}", PhoneHomeHandler(handler.store))
router.Post("/wg-init", wgtunnel.AddClientHandler(wgInterfaceManager, handler.sm))
router.With(wireGuardMiddleware).Get("/user-data", UserDataHandler)
router.With(wireGuardMiddleware).Get("/meta-data", MetaDataHandler(handler.sm, handler.store))
router.With(wireGuardMiddleware).Get("/vendor-data", VendorDataHandler(handler.sm, handler.store))
router.With(wireGuardMiddleware).Get("/{group}.yaml", GroupUserDataHandler(handler.sm, handler.store))
router.Post("/phone-home/{id}", PhoneHomeHandler(wgInterfaceManager, handler.sm))
router.Post("/wg-init", wgtunnel.AddClientHandler(wgInterfaceManager, handler.sm))
}

func initCiAdminRouter(router chi.Router, handler *CiHandler) {
12 changes: 9 additions & 3 deletions cmd/cloud-init-server/metadata_test.go
Original file line number Diff line number Diff line change
@@ -9,6 +9,12 @@ import (
)

func TestGenerateHostname(t *testing.T) {
clusterDefaults := cistore.ClusterDefaults{
ClusterName: "cluster",
ShortName: "cl",
NidLength: 4,
}

tests := []struct {
clusterName string
component cistore.OpenCHAMIComponent
@@ -32,7 +38,7 @@ func TestGenerateHostname(t *testing.T) {
NID: json.Number("12"),
},
},
expected: "cl-io12",
expected: "cl0012",
},
{
clusterName: "cluster",
@@ -42,7 +48,7 @@ func TestGenerateHostname(t *testing.T) {
NID: json.Number("34"),
},
},
expected: "cl-fe34",
expected: "cl0034",
},
{
clusterName: "cluster",
@@ -67,7 +73,7 @@ func TestGenerateHostname(t *testing.T) {

for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
got := generateHostname(tt.clusterName, tt.component)
got := generateHostname(tt.clusterName, clusterDefaults.ShortName, clusterDefaults.NidLength, tt.component)
if got != tt.expected {
t.Errorf("generateHostname() = %v, want %v", got, tt.expected)
}
8 changes: 4 additions & 4 deletions pkg/wgtunnel/handlers.go
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ type PublicKeyRequest struct {
}

// addClientHandler handles adding a WireGuard client.
func AddClientHandler(st Store, smdClient smdclient.SMDClientInterface) http.HandlerFunc {
func AddClientHandler(im *InterfaceManager, smdClient smdclient.SMDClientInterface) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Only POST requests are allowed", http.StatusMethodNotAllowed)
@@ -57,7 +57,7 @@ func AddClientHandler(st Store, smdClient smdclient.SMDClientInterface) http.Han
log.Info().Msgf("Received request: PublicKey=%s, ClientIP=%s\n", publicKey, clientIP)

// Assign a unique IP for the client.
clientVPNIP := st.IpForPeer(clientIP, publicKey)
clientVPNIP := im.IpForPeer(clientIP, publicKey)
if clientVPNIP == "" {
http.Error(w, "Failed to allocate client IP", http.StatusInternalServerError)
return
@@ -76,12 +76,12 @@ func AddClientHandler(st Store, smdClient smdclient.SMDClientInterface) http.Han

// Add the client to the WireGuard configuration.
log.Info().Msgf("Adding WireGuard peer: PublicKey=%s, ClientVPNIP=%s, ClientIP=%s\n", publicKey, clientVPNIP, clientIP)
if err := AddWireGuardPeer(st.GetInterfaceName(), publicKey, clientVPNIP, clientIP); err != nil {
if err := im.AddPeer(im.GetInterfaceName(), publicKey, clientVPNIP, clientIP); err != nil {
http.Error(w, "Failed to configure WireGuard tunnel: "+err.Error(), http.StatusInternalServerError)
return
}

serverConfig, err := st.GetServerConfig()
serverConfig, err := im.GetServerConfig()
if err != nil {
http.Error(w, "Failed to get server configuration: "+err.Error(), http.StatusInternalServerError)
return
29 changes: 29 additions & 0 deletions pkg/wgtunnel/tunnels.go
Original file line number Diff line number Diff line change
@@ -230,6 +230,35 @@ func (m *InterfaceManager) StartServer() error {

}

func (m *InterfaceManager) StopServer() error {
// Step 1: Bring the interface down
if err := exec.Command("ip", "link", "set", "down", "dev", m.interfaceName).Run(); err != nil {
return fmt.Errorf("failed to bring down the WireGuard interface: %v", err)
}

// Step 2: Delete the WireGuard interface
if err := exec.Command("ip", "link", "delete", "dev", m.interfaceName).Run(); err != nil {
return fmt.Errorf("failed to delete the WireGuard interface: %v", err)
}

return nil
}

func (m *InterfaceManager) AddPeer(peerName, publicKey, vpnIP, clientIP string) error {
m.peersMutex.RLock()
defer m.peersMutex.RUnlock()

// Add the peer to the WireGuard configuration
if err := AddWireGuardPeer(m.interfaceName, publicKey, vpnIP, clientIP); err != nil {
return err
}
m.peers[peerName] = PeerConfig{
PublicKey: publicKey,
IP: net.IPAddr{IP: net.ParseIP(vpnIP), Zone: ""},
}
return nil
}

// AddWireGuardPeer adds a peer to the WireGuard configuration.
func AddWireGuardPeer(interfaceID, publicKey, vpnIP, clientIP string) error {
allowedIPs := vpnIP + "/32"

0 comments on commit f8f34b5

Please sign in to comment.