@@ -5,6 +5,8 @@ package azure
55
66import (
77 "context"
8+ "crypto/rand"
9+ "crypto/rsa"
810 "encoding/base64"
911 "errors"
1012 "fmt"
@@ -22,6 +24,7 @@ import (
2224 provider "github.com/confidential-containers/cloud-api-adaptor/src/cloud-providers"
2325 "github.com/confidential-containers/cloud-api-adaptor/src/cloud-providers/util"
2426 "github.com/confidential-containers/cloud-api-adaptor/src/cloud-providers/util/cloudinit"
27+ "golang.org/x/crypto/ssh"
2528)
2629
2730var logger = log .New (log .Writer (), "[adaptor/cloud/azure] " , log .LstdFlags | log .Lmsgprefix )
@@ -42,7 +45,9 @@ func NewProvider(config *Config) (provider.Provider, error) {
4245 logger .Printf ("azure config %+v" , config .Redact ())
4346
4447 // Clean the config.SSHKeyPath to avoid bad paths
45- config .SSHKeyPath = filepath .Clean (config .SSHKeyPath )
48+ if config .SSHKeyPath != "" {
49+ config .SSHKeyPath = filepath .Clean (config .SSHKeyPath )
50+ }
4651
4752 azureClient , err := NewAzureClient (* config )
4853 if err != nil {
@@ -74,6 +79,37 @@ func parseIP(addr string) (*netip.Addr, error) {
7479 return & ip , nil
7580}
7681
82+ // generateSSHPublicKey generates a new RSA SSH key pair,
83+ // but doesn't save anything in the filesystem
84+ func generateSSHPublicKey () ([]byte , error ) {
85+ logger .Printf ("Generating a new in-memory SSH public key" )
86+
87+ // Generate RSA private key
88+ bitSize := 4096
89+ privateKey , err := rsa .GenerateKey (rand .Reader , bitSize )
90+ if err != nil {
91+ return nil , fmt .Errorf ("failed to generate RSA private key: %w" , err )
92+ }
93+
94+ // Validate the private key
95+ err = privateKey .Validate ()
96+ if err != nil {
97+ return nil , fmt .Errorf ("failed to validate private key: %w" , err )
98+ }
99+
100+ // Generate public key
101+ publicKey , err := ssh .NewPublicKey (& privateKey .PublicKey )
102+ if err != nil {
103+ return nil , fmt .Errorf ("failed to generate public key: %w" , err )
104+ }
105+
106+ // Marshal public key in authorized_keys format
107+ publicKeyBytes := ssh .MarshalAuthorizedKey (publicKey )
108+
109+ logger .Printf ("Successfully generated a new in-memory SSH public key" )
110+ return publicKeyBytes , nil
111+ }
112+
77113func (p * azureProvider ) getIPs (ctx context.Context , vm * armcompute.VirtualMachine ) ([]netip.Addr , error ) {
78114 nicClient , err := armnetwork .NewInterfacesClient (p .serviceConfig .SubscriptionId , p .azureClient , nil )
79115 if err != nil {
@@ -218,20 +254,26 @@ func (p *azureProvider) CreateInstance(ctx context.Context, podName, sandboxID s
218254 diskName := fmt .Sprintf ("%s-disk" , instanceName )
219255 nicName := fmt .Sprintf ("%s-net" , instanceName )
220256
221- // require ssh key for authentication on linux
222257 sshPublicKeyPath := os .ExpandEnv (p .serviceConfig .SSHKeyPath )
223258 var sshBytes []byte
224- if _ , err := os .Stat (sshPublicKeyPath ); err == nil {
259+ if sshPublicKeyPath != "" {
260+ // SSH key path provided, read the key
261+ logger .Printf ("Using existing SSH public key from %s" , sshPublicKeyPath )
225262 sshBytes , err = os .ReadFile (sshPublicKeyPath )
226263 if err != nil {
227264 err = fmt .Errorf ("reading ssh public key file: %w" , err )
228265 logger .Printf ("%v" , err )
229266 return nil , err
230267 }
231268 } else {
232- err = fmt .Errorf ("ssh public key: %w" , err )
233- logger .Printf ("%v" , err )
234- return nil , err
269+ // SSH key path is empty, generate a new key automatically in memory
270+ logger .Printf ("SSH public key path is empty, generating new public key" )
271+ sshBytes , err = generateSSHPublicKey ()
272+ if err != nil {
273+ err = fmt .Errorf ("failed to generate SSH public key: %w" , err )
274+ logger .Printf ("%v" , err )
275+ return nil , err
276+ }
235277 }
236278
237279 imageId := p .serviceConfig .ImageId
@@ -307,9 +349,12 @@ func (p *azureProvider) ConfigVerifier() error {
307349 return fmt .Errorf ("ImageId is empty" )
308350 }
309351
310- // Verify it's an SSH key file with the right permissions
311- if err := provider .VerifySSHKeyFile (p .serviceConfig .SSHKeyPath ); err != nil {
312- return fmt .Errorf ("SSH key is invalid: %s" , err )
352+ // If defined, verify it's an SSH key file with the right permissions
353+ // If empty, it means the SSH key is generated in memory
354+ if p .serviceConfig .SSHKeyPath != "" {
355+ if err := provider .VerifySSHKeyFile (p .serviceConfig .SSHKeyPath ); err != nil {
356+ return fmt .Errorf ("SSH key is invalid: %s" , err )
357+ }
313358 }
314359 return nil
315360}
0 commit comments