Skip to content

Commit 26bff6b

Browse files
authored
Merge pull request #248 from snir911/disable_ssh_key
backport: automatically generate ssh keys if missing
2 parents 2071ba3 + d4d598c commit 26bff6b

File tree

3 files changed

+73
-10
lines changed

3 files changed

+73
-10
lines changed

src/cloud-providers/azure/manager.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func (_ *Manager) ParseCmd(flags *flag.FlagSet) {
2929
flags.StringVar(&azurecfg.Size, "instance-size", "Standard_DC2as_v5", "Instance size")
3030
flags.StringVar(&azurecfg.ImageId, "imageid", "", "Image Id")
3131
flags.StringVar(&azurecfg.SubscriptionId, "subscriptionid", "", "Subscription ID")
32-
flags.StringVar(&azurecfg.SSHKeyPath, "ssh-key-path", "$HOME/.ssh/id_rsa.pub", "Path to SSH public key")
32+
flags.StringVar(&azurecfg.SSHKeyPath, "ssh-key-path", "", "Path to SSH public key")
3333
flags.StringVar(&azurecfg.SSHUserName, "ssh-username", "peerpod", "SSH User Name")
3434
flags.BoolVar(&azurecfg.DisableCVM, "disable-cvm", false, "Use non-CVMs for peer pods")
3535
// Add a List parameter to indicate different types of instance sizes to be used for the Pod VMs

src/cloud-providers/azure/provider.go

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ package azure
55

66
import (
77
"context"
8+
"crypto/rand"
9+
"crypto/rsa"
810
"encoding/base64"
911
"errors"
1012
"fmt"
@@ -22,6 +24,7 @@ import (
2224
provider "github.com/confidential-containers/cloud-api-adaptor/src/cloud-providers"
2325
"github.com/confidential-containers/cloud-api-adaptor/src/cloud-providers/util"
2426
"github.com/confidential-containers/cloud-api-adaptor/src/cloud-providers/util/cloudinit"
27+
"golang.org/x/crypto/ssh"
2528
)
2629

2730
var logger = log.New(log.Writer(), "[adaptor/cloud/azure] ", log.LstdFlags|log.Lmsgprefix)
@@ -42,7 +45,9 @@ func NewProvider(config *Config) (provider.Provider, error) {
4245
logger.Printf("azure config %+v", config.Redact())
4346

4447
// Clean the config.SSHKeyPath to avoid bad paths
45-
config.SSHKeyPath = filepath.Clean(config.SSHKeyPath)
48+
if config.SSHKeyPath != "" {
49+
config.SSHKeyPath = filepath.Clean(config.SSHKeyPath)
50+
}
4651

4752
azureClient, err := NewAzureClient(*config)
4853
if err != nil {
@@ -74,6 +79,37 @@ func parseIP(addr string) (*netip.Addr, error) {
7479
return &ip, nil
7580
}
7681

82+
// generateSSHPublicKey generates a new RSA SSH key pair,
83+
// but doesn't save anything in the filesystem
84+
func generateSSHPublicKey() ([]byte, error) {
85+
logger.Printf("Generating a new in-memory SSH public key")
86+
87+
// Generate RSA private key
88+
bitSize := 4096
89+
privateKey, err := rsa.GenerateKey(rand.Reader, bitSize)
90+
if err != nil {
91+
return nil, fmt.Errorf("failed to generate RSA private key: %w", err)
92+
}
93+
94+
// Validate the private key
95+
err = privateKey.Validate()
96+
if err != nil {
97+
return nil, fmt.Errorf("failed to validate private key: %w", err)
98+
}
99+
100+
// Generate public key
101+
publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
102+
if err != nil {
103+
return nil, fmt.Errorf("failed to generate public key: %w", err)
104+
}
105+
106+
// Marshal public key in authorized_keys format
107+
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
108+
109+
logger.Printf("Successfully generated a new in-memory SSH public key")
110+
return publicKeyBytes, nil
111+
}
112+
77113
func (p *azureProvider) getIPs(ctx context.Context, vm *armcompute.VirtualMachine) ([]netip.Addr, error) {
78114
nicClient, err := armnetwork.NewInterfacesClient(p.serviceConfig.SubscriptionId, p.azureClient, nil)
79115
if err != nil {
@@ -218,20 +254,26 @@ func (p *azureProvider) CreateInstance(ctx context.Context, podName, sandboxID s
218254
diskName := fmt.Sprintf("%s-disk", instanceName)
219255
nicName := fmt.Sprintf("%s-net", instanceName)
220256

221-
// require ssh key for authentication on linux
222257
sshPublicKeyPath := os.ExpandEnv(p.serviceConfig.SSHKeyPath)
223258
var sshBytes []byte
224-
if _, err := os.Stat(sshPublicKeyPath); err == nil {
259+
if sshPublicKeyPath != "" {
260+
// SSH key path provided, read the key
261+
logger.Printf("Using existing SSH public key from %s", sshPublicKeyPath)
225262
sshBytes, err = os.ReadFile(sshPublicKeyPath)
226263
if err != nil {
227264
err = fmt.Errorf("reading ssh public key file: %w", err)
228265
logger.Printf("%v", err)
229266
return nil, err
230267
}
231268
} else {
232-
err = fmt.Errorf("ssh public key: %w", err)
233-
logger.Printf("%v", err)
234-
return nil, err
269+
// SSH key path is empty, generate a new key automatically in memory
270+
logger.Printf("SSH public key path is empty, generating new public key")
271+
sshBytes, err = generateSSHPublicKey()
272+
if err != nil {
273+
err = fmt.Errorf("failed to generate SSH public key: %w", err)
274+
logger.Printf("%v", err)
275+
return nil, err
276+
}
235277
}
236278

237279
imageId := p.serviceConfig.ImageId
@@ -307,9 +349,12 @@ func (p *azureProvider) ConfigVerifier() error {
307349
return fmt.Errorf("ImageId is empty")
308350
}
309351

310-
// Verify it's an SSH key file with the right permissions
311-
if err := provider.VerifySSHKeyFile(p.serviceConfig.SSHKeyPath); err != nil {
312-
return fmt.Errorf("SSH key is invalid: %s", err)
352+
// If defined, verify it's an SSH key file with the right permissions
353+
// If empty, it means the SSH key is generated in memory
354+
if p.serviceConfig.SSHKeyPath != "" {
355+
if err := provider.VerifySSHKeyFile(p.serviceConfig.SSHKeyPath); err != nil {
356+
return fmt.Errorf("SSH key is invalid: %s", err)
357+
}
313358
}
314359
return nil
315360
}

src/cloud-providers/azure/types_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"fmt"
88
"strings"
99
"testing"
10+
11+
"golang.org/x/crypto/ssh"
1012
)
1113

1214
func TestAzureMasking(t *testing.T) {
@@ -39,3 +41,19 @@ func TestAzureMasking(t *testing.T) {
3941
checkLine("%v")
4042
checkLine("%s")
4143
}
44+
45+
func TestGenerateSSHKeyPair(t *testing.T) {
46+
publicKeyBytes, err := generateSSHPublicKey()
47+
if err != nil {
48+
t.Fatalf("Failed to generate SSH key pair: %v", err)
49+
}
50+
51+
if len(publicKeyBytes) == 0 {
52+
t.Error("Generated public key bytes are empty")
53+
}
54+
55+
_, _, _, _, err = ssh.ParseAuthorizedKey(publicKeyBytes)
56+
if err != nil {
57+
t.Errorf("Failed to parse generated public key: %v", err)
58+
}
59+
}

0 commit comments

Comments
 (0)