Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Azure] Private Endpoint support #465

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added dump.rdb
Binary file not shown.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ require (
cloud.google.com/go/iam v1.1.8 // indirect
cloud.google.com/go/longrunning v0.5.7 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.2.0 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.0.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.0.0/go.mod h1:243D9iHbcQXoFUtgHJwL7gl2zx1aDuDMjvBZVGr2uW0=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4 v4.3.0 h1:bXwSugBiSbgtz7rOtbfGf+woewp4f06orW9OP5BjHLA=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4 v4.3.0/go.mod h1:Y/HgrePTmGy9HjdSGTqZNa+apUpTVIEVKXJyARP2lrk=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.2.0 h1:9Eih8XcEeQnFD0ntMlUDleKMzfeCeUfa+VbnDCI4AZs=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.2.0/go.mod h1:wGPyTi+aURdqPAGMZDQqnNs9IrShADF8w2WZb6bKeq0=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0 h1:Dd+RhdJn0OTtVGaeDLZpcumkIVCtA/3/Fo42+eoYvVM=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0/go.mod h1:5kakwfW5CjC9KK+Q4wjXAg+ShuIm2mBMua0ZFj2C8PE=
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU=
Expand Down
8 changes: 8 additions & 0 deletions pkg/azure/naming.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,11 @@ func getLocalNetworkGatewayName(namespace string, cloud string, idx int) string
func getVirtualNetworkGatewayConnectionName(namespace string, cloud string, idx int) string {
return getParagliderNamespacePrefix(namespace) + "-" + cloud + "-conn-" + strconv.Itoa(idx)
}

func getPrivateEndpointName(namespace string, resourceName string) string {
return getParagliderNamespacePrefix(namespace) + "-" + resourceName + "-pe"
}

func getPrivateLinkConnectionName(namespace string, resourceName string) string {
return getParagliderNamespacePrefix(namespace) + "-" + resourceName + "-link-connection"
}
108 changes: 108 additions & 0 deletions pkg/azure/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package azure
import (
"context"
"encoding/json"
"fmt"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
Expand Down Expand Up @@ -744,3 +745,110 @@ func TestAttachResource(t *testing.T) {
require.Nil(t, resp)
})
}

func TestRun(t *testing.T) {
subscriptionId := GetAzureSubscriptionId()
resourceGroupName := "rg-2"
resourceName := "vm6"
namespace := "default"
location := "global"
server := InitializeServer("localhost:50051")
ctx := context.Background()
// storageAccountID := "/subscriptions/2051cdb8-80a4-48a4-840c-ac989eb486a2/resourceGroups/julian-rg/providers/Microsoft.Storage/storageAccounts/julianstorage1"
storageAccountID := "/subscriptions/2051cdb8-80a4-48a4-840c-ac989eb486a2/resourceGroups/julian-rg/providers/Microsoft.Storage/storageAccounts/julianteststorage1"
storageAccountName := "julianteststorage1"
fmt.Println("Storage Account ID: ", storageAccountID)
resourceInfo := ResourceIDInfo{SubscriptionID: subscriptionId, ResourceGroupName: resourceGroupName, ResourceName: resourceName}
handler, err := server.setupAzureHandler(resourceInfo, namespace)

// Call the function you want to test
if err != nil {
// Handle errors
fmt.Println("Error in setupAzureHandler ")
}

fmt.Println("Starting server")
fmt.Println("Subscription ID: ", subscriptionId)
resourceID := getVmUri(subscriptionId, resourceGroupName, resourceName)
fmt.Println("Resource ID: ", resourceID)

// // attach a resource
// attachReq := &paragliderpb.AttachResourceRequest{
// Namespace: namespace,
// Resource: getVmUri(subscriptionId, resourceGroupName, resourceName),
// }

// _, err = server.AttachResource(ctx, attachReq)
// if err != nil {
// // Handle errors
// fmt.Println("Error in AttachResource, ", err)
// return
// }

// Create dns zone
dnsZoneName := namespace + ".vnet.paraglider.com"
dnsZoneParams := getPrivateDNSZoneParams(location)
dnsZone, err := handler.CreatePrivateDNSZone(ctx, dnsZoneName, dnsZoneParams)
if err != nil {
// Handle errors
fmt.Println("Error in CreatePrivateDNSZone: ", err)
return
}
fmt.Println("Private DNS Zone Name: ", *dnsZone.Name)
fmt.Println("Private DNS Zone ID: ", *dnsZone.ID)

netInfo, err := GetNetworkInfoFromResource(ctx, handler, resourceID)
if err != nil {
// Handle errors
fmt.Println("Error in GetNetworkInfoFromResource")
return
}

vnetName := getVnetFromSubnetId(netInfo.SubnetID)
vnet, err := handler.GetVirtualNetwork(ctx, vnetName)
if err != nil {
// Handle errors
fmt.Println("Error in GetVirtualNetwork")
return
}

// todo: assert that the state of the link created is completed
// Create virtual Network Link
vnetLinkParams := getVirtualNetworkLinkParams(*vnet.ID, location)
_, err = handler.CreateVirtualNetworkLink(ctx, dnsZoneName, "vnet-link", vnetLinkParams)
if err != nil {
fmt.Println("Error in CreateVirtualNetworkLink: ", err)
return
}

peName := getPrivateEndpointName(namespace, storageAccountName)
plName := getPrivateLinkConnectionName(namespace, storageAccountName)
// create a private endpoint
params := getPrivateEndpointParams(plName, netInfo.SubnetID, storageAccountID, "eastus", "blob")
privEndpoint, err := handler.CreatePrivateEndpoint(ctx, peName, params)
if err != nil {
fmt.Println("Error in CreatePrivateEndpoint: ", err)
return
}
fmt.Println("Private Endpoint ID: ", *privEndpoint.ID)
privEndpointIp := *privEndpoint.Properties.CustomDNSConfigs[0].IPAddresses[0]
if len(privEndpoint.Properties.IPConfigurations) > 0 {
fmt.Println("Private Endpoint IP: ", *privEndpoint.Properties.IPConfigurations[0].Properties.PrivateIPAddress)
}
if (len(privEndpoint.Properties.PrivateLinkServiceConnections)) > 0 {
fmt.Println("Private Link Service Connection ID: ", *privEndpoint.Properties.PrivateLinkServiceConnections[0].ID)
}
if (len(privEndpoint.Properties.CustomDNSConfigs)) > 0 {
fmt.Println("Custom DNS Config: ", *privEndpoint.Properties.CustomDNSConfigs[0].IPAddresses[0])
}
fqdn := *privEndpoint.Properties.CustomDNSConfigs[0].Fqdn
privateIps := []string{privEndpointIp}
recordSetParams := getDnsRecordSetParams(privateIps)
_, err = handler.CreateDnsRecordSet(ctx, dnsZoneName, fqdn, recordSetParams)
if err != nil {
fmt.Println("Error in CreateDnsRecordSet: ", err)
return
}

fmt.Println("DONE")
}
11 changes: 11 additions & 0 deletions pkg/azure/resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ func getResourceHandler(resourceID string) (AzureResourceHandler, error) {
return &azureResourceHandlerVM{}, nil
} else if strings.Contains(resourceID, managedClusterTypeName) {
return &azureResourceHandlerAKS{}, nil
} else if strings.Contains(resourceID, privateEndpointTypeName) {
return &azureResourceHandlePrivateEndpoint{}, nil
} else {
return nil, fmt.Errorf("resource type %s is not supported", resourceID)
}
Expand Down Expand Up @@ -488,3 +490,12 @@ func (r *azureResourceHandlerAKS) fromResourceDecription(resourceDesc []byte) (*

return aks, nil
}

// Private Endpoint implementation of the NewAzureResourceHandler interface
type azureResourceHandlePrivateEndpoint struct {
AzureResourceHandler
}

func getNetworkInfo(ctx context.Context, endpointUri string) {

}
109 changes: 109 additions & 0 deletions pkg/azure/sdk_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v4"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
)

Expand All @@ -45,6 +46,7 @@ type AzureSDKHandler struct {
computeClientFactory *armcompute.ClientFactory
networkClientFactory *armnetwork.ClientFactory
containerServiceClientFactory *armcontainerservice.ClientFactory
dnsClientFactory *armprivatedns.ClientFactory
securityGroupsClient *armnetwork.SecurityGroupsClient
interfacesClient *armnetwork.InterfacesClient
securityRulesClient *armnetwork.SecurityRulesClient
Expand All @@ -59,6 +61,10 @@ type AzureSDKHandler struct {
subnetsClient *armnetwork.SubnetsClient
virtualNetworkGatewayConnectionsClient *armnetwork.VirtualNetworkGatewayConnectionsClient
localNetworkGatewaysClient *armnetwork.LocalNetworkGatewaysClient
privateEndpointClient *armnetwork.PrivateEndpointsClient
privateDNSZoneClient *armprivatedns.PrivateZonesClient
virtualNetworkLinkClient *armprivatedns.VirtualNetworkLinksClient
recordSetClient *armprivatedns.RecordSetsClient
subscriptionID string
resourceGroupName string
paragliderNamespace string
Expand Down Expand Up @@ -133,6 +139,11 @@ func (h *AzureSDKHandler) InitializeClients(cred azcore.TokenCredential) error {
return err
}

h.dnsClientFactory, err = armprivatedns.NewClientFactory(h.subscriptionID, cred, nil)
if err != nil {
return err
}

h.securityGroupsClient = h.networkClientFactory.NewSecurityGroupsClient()
h.interfacesClient = h.networkClientFactory.NewInterfacesClient()
h.networkPeeringClient = h.networkClientFactory.NewVirtualNetworkPeeringsClient()
Expand All @@ -147,6 +158,10 @@ func (h *AzureSDKHandler) InitializeClients(cred azcore.TokenCredential) error {
h.resourcesClient = h.resourcesClientFactory.NewClient()
h.virtualMachinesClient = h.computeClientFactory.NewVirtualMachinesClient()
h.managedClustersClient = h.containerServiceClientFactory.NewManagedClustersClient()
h.privateEndpointClient = h.networkClientFactory.NewPrivateEndpointsClient()
h.privateDNSZoneClient = h.dnsClientFactory.NewPrivateZonesClient()
h.virtualNetworkLinkClient = h.dnsClientFactory.NewVirtualNetworkLinksClient()
h.recordSetClient = h.dnsClientFactory.NewRecordSetsClient()

return nil
}
Expand Down Expand Up @@ -607,6 +622,100 @@ func (h *AzureSDKHandler) GetVirtualNetwork(ctx context.Context, name string) (*
return &resp.VirtualNetwork, nil
}

func (h *AzureSDKHandler) CreatePrivateEndpoint(ctx context.Context, privateEndpointName string, parameters armnetwork.PrivateEndpoint) (*armnetwork.PrivateEndpoint, error) {
pollerResponse, err := h.privateEndpointClient.BeginCreateOrUpdate(ctx, h.resourceGroupName, privateEndpointName, parameters, nil)
if err != nil {
return nil, err
}
resp, err := pollerResponse.PollUntilDone(ctx, nil)
if err != nil {
return nil, err
}
return &resp.PrivateEndpoint, nil
}

func (h *AzureSDKHandler) GetAllPrivateEndpoints(ctx context.Context) map[string]*armnetwork.PrivateEndpoint {
endpoints := make(map[string]*armnetwork.PrivateEndpoint)
pager := h.privateEndpointClient.NewListPager(h.resourceGroupName, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
utils.Log.Printf("Failed to get private endpoints: %v", err)
return nil
}
for _, v := range page.Value {
endpoints[*v.Name] = v
}
}

return endpoints
}

func (h *AzureSDKHandler) GetPrivateEndpoint(ctx context.Context, privateEndpointName string) (*armnetwork.PrivateEndpoint, error) {
resp, err := h.privateEndpointClient.Get(ctx, h.resourceGroupName, privateEndpointName, nil)
if err != nil {
return nil, err
}
return &resp.PrivateEndpoint, nil
}

func (h *AzureSDKHandler) CreatePrivateDNSZone(ctx context.Context, privateDNSZoneName string, parameters armprivatedns.PrivateZone) (*armprivatedns.PrivateZone, error) {
pollerResponse, err := h.privateDNSZoneClient.BeginCreateOrUpdate(ctx, h.resourceGroupName, privateDNSZoneName, parameters, nil)
if err != nil {
return nil, err
}
resp, err := pollerResponse.PollUntilDone(ctx, nil)
if err != nil {
return nil, err
}
return &resp.PrivateZone, nil
}

func (h AzureSDKHandler) GetPrivateDNSZone(ctx context.Context, privateDNSZoneName string) (*armprivatedns.PrivateZone, error) {
resp, err := h.privateDNSZoneClient.Get(ctx, h.resourceGroupName, privateDNSZoneName, nil)
if err != nil {
return nil, err
}
return &resp.PrivateZone, nil
}

func (h *AzureSDKHandler) CreateVirtualNetworkLink(ctx context.Context, privateZoneName, virtualNetworkLinkName string, parameters armprivatedns.VirtualNetworkLink) (*armprivatedns.VirtualNetworkLink, error) {
pollerResponse, err := h.virtualNetworkLinkClient.BeginCreateOrUpdate(ctx, h.resourceGroupName, privateZoneName, virtualNetworkLinkName, parameters, nil)
if err != nil {
return nil, err
}
resp, err := pollerResponse.PollUntilDone(ctx, nil)
if err != nil {
return nil, err
}
return &resp.VirtualNetworkLink, nil
}

func (h *AzureSDKHandler) GetVirtualNetworkLink(ctx context.Context, privateZoneName, virtualNetworkLinkName string) (*armprivatedns.VirtualNetworkLink, error) {
resp, err := h.virtualNetworkLinkClient.Get(ctx, h.resourceGroupName, privateZoneName, virtualNetworkLinkName, nil)
if err != nil {
return nil, err
}
return &resp.VirtualNetworkLink, nil
}

func (h *AzureSDKHandler) CreateDnsRecordSet(ctx context.Context, privateZoneName string, recordSetName string, parameters armprivatedns.RecordSet) (*armprivatedns.RecordSet, error) {
// Record type "A" maps a name to an IPv4 address
resp, err := h.recordSetClient.CreateOrUpdate(ctx, h.resourceGroupName, privateZoneName, armprivatedns.RecordTypeA, recordSetName, parameters, nil)
if err != nil {
return nil, err
}
return &resp.RecordSet, nil
}

func (h *AzureSDKHandler) GetDnsRecordSet(ctx context.Context, privateZoneName string, recordType armprivatedns.RecordType, recordSetName string) (*armprivatedns.RecordSet, error) {
resp, err := h.recordSetClient.Get(ctx, h.resourceGroupName, privateZoneName, recordType, recordSetName, nil)
if err != nil {
return nil, err
}
return &resp.RecordSet, nil
}

func (h *AzureSDKHandler) CreateSecurityGroup(ctx context.Context, resourceName string, location string, allowedCIDRs map[string]string) (*armnetwork.SecurityGroup, error) {
nsgParameters := armnetwork.SecurityGroup{
Location: to.Ptr(location),
Expand Down
59 changes: 59 additions & 0 deletions pkg/azure/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v4"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/paraglider-project/paraglider/pkg/paragliderpb"
utils "github.com/paraglider-project/paraglider/pkg/utils"
Expand All @@ -44,6 +45,7 @@ const (
virtualNetworkGatewayTypeName = "Microsoft.Network/virtualNetworkGateways"
virtualNetworkTypeName = "Microsoft.Network/virtualNetworks"
networkWatcherTypeName = "Microsoft.Network/networkWatchers"
privateEndpointTypeName = "Microsoft.Network/" // todo @J-467: update this
)

// Gets subscription ID defined in environment variable
Expand Down Expand Up @@ -554,3 +556,60 @@ func getVirtualNetworkParameters(location string, addressSpace string) armnetwor
},
}
}

func getPrivateEndpointParams(connectionName string, subnetID string, privateLinkServiceID string, location string, subresource string) armnetwork.PrivateEndpoint {
return armnetwork.PrivateEndpoint{
Location: to.Ptr(location),
Properties: &armnetwork.PrivateEndpointProperties{
PrivateLinkServiceConnections: []*armnetwork.PrivateLinkServiceConnection{
{
Name: to.Ptr(connectionName),
Properties: &armnetwork.PrivateLinkServiceConnectionProperties{
PrivateLinkServiceID: to.Ptr(privateLinkServiceID),
GroupIDs: []*string{to.Ptr(subresource)},
RequestMessage: to.Ptr("Connection request from Paraglider"),
},
},
},
Subnet: &armnetwork.Subnet{
ID: to.Ptr(subnetID),
},
},
}
}

func getPrivateDNSZoneParams(location string) armprivatedns.PrivateZone {
return armprivatedns.PrivateZone{
Location: to.Ptr(location),
}
}

func getVirtualNetworkLinkParams(virtualNetworkID string, location string) armprivatedns.VirtualNetworkLink {
return armprivatedns.VirtualNetworkLink{
Location: to.Ptr(location),
Properties: &armprivatedns.VirtualNetworkLinkProperties{
RegistrationEnabled: to.Ptr(true),
VirtualNetwork: &armprivatedns.SubResource{
ID: to.Ptr(virtualNetworkID),
},
},
}
}


// Returns parameters for A record set with the specified IP address
func getDnsRecordSetParams(ipAddresses []string) armprivatedns.RecordSet {
aRecordAddresses := make([]*armprivatedns.ARecord, len(ipAddresses))
for i, ipAddress := range ipAddresses {
aRecordAddresses[i] = &armprivatedns.ARecord{
IPv4Address: to.Ptr(ipAddress),
}
}

return armprivatedns.RecordSet{
Properties: &armprivatedns.RecordSetProperties{
ARecords: aRecordAddresses,
TTL: to.Ptr[int64](3600), // 1 hour cache by clients
},
}
}