From 0b9ab6b70788b683a0e9c59ac17f694331acb673 Mon Sep 17 00:00:00 2001 From: lakshmimsft Date: Wed, 28 Jan 2026 00:13:09 -0800 Subject: [PATCH 1/3] adding encryption/decryption logic Signed-off-by: lakshmimsft --- pkg/crypto/encryption/encryption.go | 269 +++++++ pkg/crypto/encryption/encryption_test.go | 594 +++++++++++++++ pkg/crypto/encryption/keyprovider.go | 159 ++++ pkg/crypto/encryption/keyprovider_test.go | 239 ++++++ pkg/crypto/encryption/sensitive.go | 611 ++++++++++++++++ pkg/crypto/encryption/sensitive_test.go | 844 ++++++++++++++++++++++ 6 files changed, 2716 insertions(+) create mode 100644 pkg/crypto/encryption/encryption.go create mode 100644 pkg/crypto/encryption/encryption_test.go create mode 100644 pkg/crypto/encryption/keyprovider.go create mode 100644 pkg/crypto/encryption/keyprovider_test.go create mode 100644 pkg/crypto/encryption/sensitive.go create mode 100644 pkg/crypto/encryption/sensitive_test.go diff --git a/pkg/crypto/encryption/encryption.go b/pkg/crypto/encryption/encryption.go new file mode 100644 index 0000000000..0780435592 --- /dev/null +++ b/pkg/crypto/encryption/encryption.go @@ -0,0 +1,269 @@ +/* +Copyright 2023 The Radius Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryption + +import ( + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + + "golang.org/x/crypto/chacha20poly1305" +) + +const ( + // KeySize is the required size for ChaCha20-Poly1305 keys (256 bits). + KeySize = chacha20poly1305.KeySize + + // NonceSize is the size of the nonce for ChaCha20-Poly1305. + NonceSize = chacha20poly1305.NonceSize +) + +var ( + // ErrInvalidKeySize is returned when the encryption key is not the correct size. + ErrInvalidKeySize = errors.New("encryption key must be 32 bytes (256 bits)") + + // ErrEncryptionFailed is returned when encryption fails. + ErrEncryptionFailed = errors.New("encryption failed") + + // ErrDecryptionFailed is returned when decryption fails. + ErrDecryptionFailed = errors.New("decryption failed") + + // ErrInvalidEncryptedData is returned when the encrypted data format is invalid. + ErrInvalidEncryptedData = errors.New("invalid encrypted data format") + + // ErrEmptyPlaintext is returned when attempting to encrypt empty data. + ErrEmptyPlaintext = errors.New("plaintext cannot be empty") + + // ErrAssociatedDataMismatch is returned when the associated data provided during + // decryption does not match what was used during encryption. + ErrAssociatedDataMismatch = errors.New("associated data mismatch") +) + +// EncryptedData represents the structure for storing encrypted data. +// It contains the base64-encoded ciphertext and nonce, plus optional associated data hash. +type EncryptedData struct { + // Encrypted contains the base64-encoded ciphertext. + Encrypted string `json:"encrypted"` + // Nonce contains the base64-encoded nonce used for encryption. + Nonce string `json:"nonce"` + // AD contains a hash of the associated data used during encryption (optional). + // This is stored for verification purposes - the actual AD must be provided during decryption. + // The hash allows detection of AD mismatches without exposing the AD value. + AD string `json:"ad,omitempty"` +} + +// Encryptor provides methods for encrypting and decrypting data using ChaCha20-Poly1305. +type Encryptor struct { + aead cipher.AEAD +} + +// NewEncryptor creates a new Encryptor with the provided 256-bit key. +// Returns an error if the key is not exactly 32 bytes. +func NewEncryptor(key []byte) (*Encryptor, error) { + if len(key) != KeySize { + return nil, ErrInvalidKeySize + } + + aead, err := chacha20poly1305.New(key) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, err) + } + + return &Encryptor{aead: aead}, nil +} + +// Encrypt encrypts the plaintext using ChaCha20-Poly1305 with Associated Data (AD). +// The AD provides authentication for contextual data (like resource ID or field path) without +// encrypting it. This binds the ciphertext to its context, preventing an attacker from +// moving encrypted values between different resources or fields. +// +// The AD is authenticated but NOT encrypted - it must be provided again during decryption. +// A hash of the AD is stored in the encrypted data structure to allow early detection of mismatches. +// +// Example AD values: +// - Resource ID: "/planes/radius/local/resourceGroups/test/providers/Foo.Bar/myResources/test" +// - Field path: "credentials.password" +// - Combined: resourceID + ":" + fieldPath +// +// Pass nil for associatedData if no context binding is needed (not recommended for sensitive data). +func (e *Encryptor) Encrypt(plaintext []byte, associatedData []byte) ([]byte, error) { + if len(plaintext) == 0 { + return nil, ErrEmptyPlaintext + } + + // Generate a unique nonce for this encryption operation + nonce, err := generateNonce(e.aead.NonceSize()) + if err != nil { + return nil, fmt.Errorf("%w: failed to generate nonce: %v", ErrEncryptionFailed, err) + } + + // Encrypt the plaintext with associated data + // The AD is authenticated (included in the auth tag) but not encrypted + ciphertext := e.aead.Seal(nil, nonce, plaintext, associatedData) + + // Create the encrypted data structure + encryptedData := EncryptedData{ + Encrypted: base64.StdEncoding.EncodeToString(ciphertext), + Nonce: base64.StdEncoding.EncodeToString(nonce), + } + + // Store a hash of the AD if provided (for verification during decryption) + if len(associatedData) > 0 { + encryptedData.AD = hashAD(associatedData) + } + + // Marshal to JSON + result, err := json.Marshal(encryptedData) + if err != nil { + return nil, fmt.Errorf("%w: failed to marshal encrypted data: %v", ErrEncryptionFailed, err) + } + + return result, nil +} + +// Decrypt decrypts the data that was encrypted using the Encrypt method. +// The associatedData must match what was provided during encryption; if the AD +// was used during encryption, it must be provided here for successful decryption. +// The input should be JSON-encoded EncryptedData. +func (e *Encryptor) Decrypt(data []byte, associatedData []byte) ([]byte, error) { + if len(data) == 0 { + return nil, ErrInvalidEncryptedData + } + + // Parse the encrypted data structure + var encryptedData EncryptedData + if err := json.Unmarshal(data, &encryptedData); err != nil { + return nil, fmt.Errorf("%w: failed to parse encrypted data: %v", ErrInvalidEncryptedData, err) + } + + // Verify AD hash matches if AD was used during encryption + if encryptedData.AD != "" { + if len(associatedData) == 0 { + return nil, fmt.Errorf("%w: encrypted data requires associated data but none provided", ErrAssociatedDataMismatch) + } + if hashAD(associatedData) != encryptedData.AD { + return nil, fmt.Errorf("%w: provided associated data does not match", ErrAssociatedDataMismatch) + } + } + + // Decode the base64-encoded ciphertext + ciphertext, err := base64.StdEncoding.DecodeString(encryptedData.Encrypted) + if err != nil { + return nil, fmt.Errorf("%w: failed to decode ciphertext: %v", ErrInvalidEncryptedData, err) + } + + // Decode the base64-encoded nonce + nonce, err := base64.StdEncoding.DecodeString(encryptedData.Nonce) + if err != nil { + return nil, fmt.Errorf("%w: failed to decode nonce: %v", ErrInvalidEncryptedData, err) + } + + // Validate nonce size + if len(nonce) != e.aead.NonceSize() { + return nil, fmt.Errorf("%w: invalid nonce size", ErrInvalidEncryptedData) + } + + // Decrypt the ciphertext with the same associated data + plaintext, err := e.aead.Open(nil, nonce, ciphertext, associatedData) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrDecryptionFailed, err) + } + + return plaintext, nil +} + +// EncryptString encrypts a string with associated data and returns the JSON-encoded encrypted data as a string. +func (e *Encryptor) EncryptString(plaintext string, associatedData []byte) (string, error) { + encrypted, err := e.Encrypt([]byte(plaintext), associatedData) + if err != nil { + return "", err + } + return string(encrypted), nil +} + +// DecryptString decrypts the JSON-encoded encrypted data with associated data and returns the original string. +func (e *Encryptor) DecryptString(data string, associatedData []byte) (string, error) { + decrypted, err := e.Decrypt([]byte(data), associatedData) + if err != nil { + return "", err + } + return string(decrypted), nil +} + +// hashAD creates a truncated SHA-256 hash of the associated data for storage. +// This allows verification that the correct AD is provided during decryption +// without storing the actual AD value. +func hashAD(ad []byte) string { + hash := sha256.Sum256(ad) + // Use first 16 bytes (128 bits) - sufficient for verification, saves storage + return base64.StdEncoding.EncodeToString(hash[:16]) +} + +// generateNonce generates a cryptographically secure random nonce. +func generateNonce(size int) ([]byte, error) { + nonce := make([]byte, size) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + return nonce, nil +} + +// IsEncryptedData checks if the given data appears to be in the encrypted data format. +// It validates that the data is valid JSON with non-empty encrypted and nonce fields, +// and that both fields contain valid base64-encoded data with appropriate nonce size. +func IsEncryptedData(data []byte) bool { + var encryptedData EncryptedData + if err := json.Unmarshal(data, &encryptedData); err != nil { + return false + } + + if encryptedData.Encrypted == "" || encryptedData.Nonce == "" { + return false + } + + // Validate base64 encoding of ciphertext + if _, err := base64.StdEncoding.DecodeString(encryptedData.Encrypted); err != nil { + return false + } + + // Validate base64 encoding and size of nonce + nonce, err := base64.StdEncoding.DecodeString(encryptedData.Nonce) + if err != nil { + return false + } + + // ChaCha20-Poly1305 nonce must be 12 bytes + if len(nonce) != NonceSize { + return false + } + + return true +} + +// GenerateKey generates a new random 256-bit encryption key. +func GenerateKey() ([]byte, error) { + key := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return nil, fmt.Errorf("failed to generate encryption key: %w", err) + } + return key, nil +} diff --git a/pkg/crypto/encryption/encryption_test.go b/pkg/crypto/encryption/encryption_test.go new file mode 100644 index 0000000000..863e89c1e6 --- /dev/null +++ b/pkg/crypto/encryption/encryption_test.go @@ -0,0 +1,594 @@ +/* +Copyright 2023 The Radius Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryption + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewEncryptor(t *testing.T) { + tests := []struct { + name string + key []byte + wantErr error + }{ + { + name: "valid-32-byte-key", + key: make([]byte, 32), + wantErr: nil, + }, + { + name: "invalid-key-too-short", + key: make([]byte, 16), + wantErr: ErrInvalidKeySize, + }, + { + name: "invalid-key-too-long", + key: make([]byte, 64), + wantErr: ErrInvalidKeySize, + }, + { + name: "invalid-empty-key", + key: []byte{}, + wantErr: ErrInvalidKeySize, + }, + { + name: "invalid-nil-key", + key: nil, + wantErr: ErrInvalidKeySize, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + enc, err := NewEncryptor(tt.key) + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + require.Nil(t, enc) + } else { + require.NoError(t, err) + require.NotNil(t, enc) + } + }) + } +} + +func TestEncrypt(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + enc, err := NewEncryptor(key) + require.NoError(t, err) + + tests := []struct { + name string + plaintext []byte + wantErr error + }{ + { + name: "encrypt-simple-text", + plaintext: []byte("hello world"), + wantErr: nil, + }, + { + name: "encrypt-json-data", + plaintext: []byte(`{"password": "secret123", "token": "abc-xyz"}`), + wantErr: nil, + }, + { + name: "encrypt-binary-data", + plaintext: []byte{0x00, 0x01, 0x02, 0xff, 0xfe, 0xfd}, + wantErr: nil, + }, + { + name: "encrypt-long-text", + plaintext: make([]byte, 10000), + wantErr: nil, + }, + { + name: "encrypt-empty-plaintext", + plaintext: []byte{}, + wantErr: ErrEmptyPlaintext, + }, + { + name: "encrypt-nil-plaintext", + plaintext: nil, + wantErr: ErrEmptyPlaintext, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + encrypted, err := enc.Encrypt(tt.plaintext, nil) + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + require.Nil(t, encrypted) + } else { + require.NoError(t, err) + require.NotNil(t, encrypted) + + // Verify the encrypted data is valid JSON with expected structure + var encData EncryptedData + err = json.Unmarshal(encrypted, &encData) + require.NoError(t, err) + require.NotEmpty(t, encData.Encrypted) + require.NotEmpty(t, encData.Nonce) + } + }) + } +} + +func TestDecrypt(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + enc, err := NewEncryptor(key) + require.NoError(t, err) + + // Create valid encrypted data for testing + validPlaintext := []byte("secret data") + validEncrypted, err := enc.Encrypt(validPlaintext, nil) + require.NoError(t, err) + + tests := []struct { + name string + data []byte + wantErr error + }{ + { + name: "decrypt-valid-data", + data: validEncrypted, + wantErr: nil, + }, + { + name: "decrypt-empty-data", + data: []byte{}, + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-nil-data", + data: nil, + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-invalid-json", + data: []byte("not json"), + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-missing-encrypted-field", + data: []byte(`{"nonce": "dGVzdA=="}`), + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-missing-nonce-field", + data: []byte(`{"encrypted": "dGVzdA=="}`), + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-invalid-base64-ciphertext", + data: []byte(`{"encrypted": "not-valid-base64!!!", "nonce": "dGVzdA=="}`), + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-invalid-base64-nonce", + data: []byte(`{"encrypted": "dGVzdA==", "nonce": "not-valid-base64!!!"}`), + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-wrong-nonce-size", + data: []byte(`{"encrypted": "dGVzdA==", "nonce": "dGVzdA=="}`), + wantErr: ErrInvalidEncryptedData, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decrypted, err := enc.Decrypt(tt.data, nil) + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + } else { + require.NoError(t, err) + require.Equal(t, validPlaintext, decrypted) + } + }) + } +} + +func TestEncryptDecryptRoundTrip(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + testCases := []struct { + name string + plaintext []byte + }{ + { + name: "simple-text", + plaintext: []byte("hello world"), + }, + { + name: "json-secret", + plaintext: []byte(`{"password": "super-secret-password", "apiKey": "xyz-123-abc"}`), + }, + { + name: "unicode-text", + plaintext: []byte("Hello δΈ–η•Œ! πŸ”"), + }, + { + name: "binary-data", + plaintext: []byte{0x00, 0x01, 0x02, 0x03, 0xff, 0xfe, 0xfd, 0xfc}, + }, + { + name: "large-data", + plaintext: make([]byte, 65536), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Encrypt + encrypted, err := enc.Encrypt(tc.plaintext, nil) + require.NoError(t, err) + require.NotEqual(t, tc.plaintext, encrypted, "encrypted data should differ from plaintext") + + // Decrypt + decrypted, err := enc.Decrypt(encrypted, nil) + require.NoError(t, err) + require.Equal(t, tc.plaintext, decrypted, "decrypted data should match original plaintext") + }) + } +} + +func TestEncryptStringDecryptString(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + testCases := []string{ + "simple password", + "API-KEY-12345", + `{"token": "secret"}`, + "Unicode: ζ—₯本θͺž πŸ”‘", + } + + for _, plaintext := range testCases { + t.Run(plaintext, func(t *testing.T) { + encrypted, err := enc.EncryptString(plaintext, nil) + require.NoError(t, err) + require.NotEqual(t, plaintext, encrypted) + + decrypted, err := enc.DecryptString(encrypted, nil) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted) + }) + } +} + +func TestUniqueNoncesPerEncryption(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + plaintext := []byte("same plaintext") + nonces := make(map[string]bool) + + // Encrypt the same plaintext multiple times + for i := 0; i < 100; i++ { + encrypted, err := enc.Encrypt(plaintext, nil) + require.NoError(t, err) + + var encData EncryptedData + err = json.Unmarshal(encrypted, &encData) + require.NoError(t, err) + + // Each nonce should be unique + require.False(t, nonces[encData.Nonce], "nonce should be unique for each encryption") + nonces[encData.Nonce] = true + } +} + +func TestDifferentKeysCannotDecrypt(t *testing.T) { + key1, err := GenerateKey() + require.NoError(t, err) + + key2, err := GenerateKey() + require.NoError(t, err) + + enc1, err := NewEncryptor(key1) + require.NoError(t, err) + + enc2, err := NewEncryptor(key2) + require.NoError(t, err) + + plaintext := []byte("secret message") + + // Encrypt with key1 + encrypted, err := enc1.Encrypt(plaintext, nil) + require.NoError(t, err) + + // Try to decrypt with key2 - should fail + _, err = enc2.Decrypt(encrypted, nil) + require.ErrorIs(t, err, ErrDecryptionFailed) +} + +func TestGenerateKey(t *testing.T) { + key1, err := GenerateKey() + require.NoError(t, err) + require.Len(t, key1, KeySize) + + key2, err := GenerateKey() + require.NoError(t, err) + require.Len(t, key2, KeySize) + + // Keys should be different + require.NotEqual(t, key1, key2) +} + +func TestIsEncryptedData(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + encrypted, err := enc.Encrypt([]byte("test"), nil) + require.NoError(t, err) + + // Valid 12-byte nonce encoded in base64 for manual test cases + validNonceBase64 := "AAAAAAAAAAAAAAAA" // 12 bytes of zeros in base64 + + tests := []struct { + name string + data []byte + want bool + }{ + { + name: "valid-encrypted-data", + data: encrypted, + want: true, + }, + { + name: "valid-format-manual", + data: []byte(`{"encrypted": "YWJjZGVm", "nonce": "` + validNonceBase64 + `"}`), + want: true, + }, + { + name: "invalid-base64-encrypted", + data: []byte(`{"encrypted": "not-valid-base64!!!", "nonce": "` + validNonceBase64 + `"}`), + want: false, + }, + { + name: "invalid-base64-nonce", + data: []byte(`{"encrypted": "YWJjZGVm", "nonce": "not-valid-base64!!!"}`), + want: false, + }, + { + name: "invalid-nonce-size", + data: []byte(`{"encrypted": "YWJjZGVm", "nonce": "YWJj"}`), // "abc" = 3 bytes, not 12 + want: false, + }, + { + name: "missing-encrypted-field", + data: []byte(`{"nonce": "` + validNonceBase64 + `"}`), + want: false, + }, + { + name: "missing-nonce-field", + data: []byte(`{"encrypted": "YWJjZGVm"}`), + want: false, + }, + { + name: "empty-fields", + data: []byte(`{"encrypted": "", "nonce": ""}`), + want: false, + }, + { + name: "invalid-json", + data: []byte("not json"), + want: false, + }, + { + name: "empty-data", + data: []byte{}, + want: false, + }, + { + name: "nil-data", + data: nil, + want: false, + }, + { + name: "plain-text", + data: []byte("just plain text"), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsEncryptedData(tt.data) + require.Equal(t, tt.want, got) + }) + } +} + +func TestTamperDetection(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + plaintext := []byte("sensitive data") + encrypted, err := enc.Encrypt(plaintext, nil) + require.NoError(t, err) + + // Parse the encrypted data + var encData EncryptedData + err = json.Unmarshal(encrypted, &encData) + require.NoError(t, err) + + // Tamper with the ciphertext by modifying the base64 string + // We'll change the first character + tampered := encData.Encrypted + if tampered[0] == 'A' { + tampered = "B" + tampered[1:] + } else { + tampered = "A" + tampered[1:] + } + encData.Encrypted = tampered + + tamperedJSON, err := json.Marshal(encData) + require.NoError(t, err) + + // Decryption should fail due to authentication tag mismatch + _, err = enc.Decrypt(tamperedJSON, nil) + require.Error(t, err) +} + +// Tests for Associated Data (AD) functionality + +func TestEncryptWithAssociatedData(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + plaintext := []byte("secret data") + resourceID := "/planes/radius/local/resourceGroups/test/providers/Foo.Bar/myResources/test" + fieldPath := "credentials.password" + ad := []byte(resourceID + ":" + fieldPath) + + // Encrypt with AD + encrypted, err := enc.Encrypt(plaintext, ad) + require.NoError(t, err) + + // Verify AD hash is stored + var encData EncryptedData + err = json.Unmarshal(encrypted, &encData) + require.NoError(t, err) + require.NotEmpty(t, encData.AD, "AD hash should be stored") + + // Decrypt with same AD should succeed + decrypted, err := enc.Decrypt(encrypted, ad) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted) +} + +func TestDecryptWithWrongAssociatedData(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + plaintext := []byte("secret data") + ad1 := []byte("/resource/1:password") + ad2 := []byte("/resource/2:password") + + // Encrypt with AD1 + encrypted, err := enc.Encrypt(plaintext, ad1) + require.NoError(t, err) + + // Decrypt with different AD should fail with mismatch error + _, err = enc.Decrypt(encrypted, ad2) + require.ErrorIs(t, err, ErrAssociatedDataMismatch) +} + +func TestDecryptWithMissingAssociatedData(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + plaintext := []byte("secret data") + ad := []byte("/resource/1:password") + + // Encrypt with AD + encrypted, err := enc.Encrypt(plaintext, ad) + require.NoError(t, err) + + // Decrypt without AD when AD was used should fail + _, err = enc.Decrypt(encrypted, nil) + require.ErrorIs(t, err, ErrAssociatedDataMismatch) +} + +func TestEncryptWithoutAssociatedDataDecryptWithAD(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + plaintext := []byte("secret data") + + // Encrypt without AD + encrypted, err := enc.Encrypt(plaintext, nil) + require.NoError(t, err) + + // Verify no AD hash is stored + var encData EncryptedData + err = json.Unmarshal(encrypted, &encData) + require.NoError(t, err) + require.Empty(t, encData.AD, "AD hash should not be stored when no AD provided") + + // Decrypt without AD should succeed + decrypted, err := enc.Decrypt(encrypted, nil) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted) + + // Decrypt with AD when no AD was used - AEAD will fail because the auth tag won't match + _, err = enc.Decrypt(encrypted, []byte("unexpected-ad")) + require.ErrorIs(t, err, ErrDecryptionFailed) +} + +func TestAssociatedDataPreventsContextSwitch(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + // Simulate encrypting a password for resource1 + password := []byte("super-secret-password") + resource1AD := []byte("/resource/1:password") + resource2AD := []byte("/resource/2:password") + + // Encrypt password for resource1 + encryptedForResource1, err := enc.Encrypt(password, resource1AD) + require.NoError(t, err) + + // Attacker tries to use this encrypted value for resource2 + // This should fail because the AD is different + _, err = enc.Decrypt(encryptedForResource1, resource2AD) + require.Error(t, err, "should not be able to decrypt with different resource context") +} diff --git a/pkg/crypto/encryption/keyprovider.go b/pkg/crypto/encryption/keyprovider.go new file mode 100644 index 0000000000..ed302ca7e2 --- /dev/null +++ b/pkg/crypto/encryption/keyprovider.go @@ -0,0 +1,159 @@ +/* +Copyright 2023 The Radius Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryption + +import ( + "context" + "errors" + "fmt" + + corev1 "k8s.io/api/core/v1" + k8s_error "k8s.io/apimachinery/pkg/api/errors" + controller_runtime "sigs.k8s.io/controller-runtime/pkg/client" +) + +const ( + // DefaultEncryptionKeySecretName is the default name of the Kubernetes Secret containing the encryption key. + + // This is the Secret's name, not actual credentials. + DefaultEncryptionKeySecretName = "radius-encryption-key" //nolint:gosec // This is a Secret name, not credentials + + // DefaultEncryptionKeySecretKey is the key within the Secret that contains the encryption key. + DefaultEncryptionKeySecretKey = "key" + + // RadiusNamespace is the namespace where Radius secrets are stored. + RadiusNamespace = "radius-system" +) + +var ( + // ErrKeyNotFound is returned when the encryption key is not found. + ErrKeyNotFound = errors.New("encryption key not found") + + // ErrKeyLoadFailed is returned when loading the encryption key fails. + ErrKeyLoadFailed = errors.New("failed to load encryption key") +) + +// KeyProvider defines the interface for retrieving encryption keys. +// +//go:generate mockgen -typed -destination=./mock_keyprovider.go -package=encryption -self_package github.com/radius-project/radius/pkg/crypto/encryption github.com/radius-project/radius/pkg/crypto/encryption KeyProvider +type KeyProvider interface { + // GetKey retrieves the encryption key. + // Returns ErrKeyNotFound if the key does not exist. + GetKey(ctx context.Context) ([]byte, error) +} + +// KubernetesKeyProvider implements KeyProvider by loading the encryption key from a Kubernetes Secret. +type KubernetesKeyProvider struct { + client controller_runtime.Client + secretName string + secretKey string + namespace string +} + +// KubernetesKeyProviderOptions contains options for creating a KubernetesKeyProvider. +type KubernetesKeyProviderOptions struct { + // SecretName is the name of the Kubernetes Secret containing the encryption key. + // Defaults to DefaultEncryptionKeySecretName if not specified. + SecretName string + + // SecretKey is the key within the Secret that contains the encryption key. + // Defaults to DefaultEncryptionKeySecretKey if not specified. + SecretKey string + + // Namespace is the namespace where the Secret is located. + // Defaults to RadiusNamespace if not specified. + Namespace string +} + +// NewKubernetesKeyProvider creates a new KubernetesKeyProvider with the given Kubernetes client and options. +func NewKubernetesKeyProvider(client controller_runtime.Client, opts *KubernetesKeyProviderOptions) *KubernetesKeyProvider { + secretName := DefaultEncryptionKeySecretName + secretKey := DefaultEncryptionKeySecretKey + namespace := RadiusNamespace + + if opts != nil { + if opts.SecretName != "" { + secretName = opts.SecretName + } + if opts.SecretKey != "" { + secretKey = opts.SecretKey + } + if opts.Namespace != "" { + namespace = opts.Namespace + } + } + + return &KubernetesKeyProvider{ + client: client, + secretName: secretName, + secretKey: secretKey, + namespace: namespace, + } +} + +// GetKey retrieves the encryption key from the Kubernetes Secret. +func (p *KubernetesKeyProvider) GetKey(ctx context.Context) ([]byte, error) { + secret := &corev1.Secret{} + objectKey := controller_runtime.ObjectKey{ + Name: p.secretName, + Namespace: p.namespace, + } + + if err := p.client.Get(ctx, objectKey, secret); err != nil { + if k8s_error.IsNotFound(err) { + return nil, fmt.Errorf("%w: secret %s/%s not found", ErrKeyNotFound, p.namespace, p.secretName) + } + return nil, fmt.Errorf("%w: %v", ErrKeyLoadFailed, err) + } + + key, ok := secret.Data[p.secretKey] + if !ok { + return nil, fmt.Errorf("%w: key %q not found in secret %s/%s", ErrKeyNotFound, p.secretKey, p.namespace, p.secretName) + } + + if len(key) != KeySize { + return nil, fmt.Errorf("%w: key in secret %s/%s has invalid size (expected %d bytes, got %d)", ErrKeyLoadFailed, p.namespace, p.secretName, KeySize, len(key)) + } + + return key, nil +} + +// InMemoryKeyProvider implements KeyProvider with an in-memory key. +// This is useful for testing or development environments. +type InMemoryKeyProvider struct { + key []byte +} + +// NewInMemoryKeyProvider creates a new InMemoryKeyProvider with the given key. +func NewInMemoryKeyProvider(key []byte) (*InMemoryKeyProvider, error) { + if len(key) != KeySize { + return nil, ErrInvalidKeySize + } + keyCopy := make([]byte, KeySize) + copy(keyCopy, key) + return &InMemoryKeyProvider{key: keyCopy}, nil +} + +// GetKey returns a copy of the in-memory encryption key. +// A copy is returned to prevent callers from mutating the provider's internal state. +func (p *InMemoryKeyProvider) GetKey(ctx context.Context) ([]byte, error) { + if p.key == nil { + return nil, ErrKeyNotFound + } + // Return a copy to prevent mutation of the internal key + return append([]byte(nil), p.key...), nil +} diff --git a/pkg/crypto/encryption/keyprovider_test.go b/pkg/crypto/encryption/keyprovider_test.go new file mode 100644 index 0000000000..7398c66ff4 --- /dev/null +++ b/pkg/crypto/encryption/keyprovider_test.go @@ -0,0 +1,239 @@ +/* +Copyright 2023 The Radius Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryption + +import ( + "context" + "testing" + + "github.com/radius-project/radius/test/k8sutil" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/kubectl/pkg/scheme" + controller_runtime "sigs.k8s.io/controller-runtime/pkg/client" +) + +func TestKubernetesKeyProvider_GetKey(t *testing.T) { + ctx := context.Background() + validKey := make([]byte, KeySize) + for i := range validKey { + validKey[i] = byte(i) + } + + tests := []struct { + name string + setupFunc func(k8sClient controller_runtime.Client) + opts *KubernetesKeyProviderOptions + wantErr error + wantKey []byte + wantErrMsg string + }{ + { + name: "success-with-default-options", + setupFunc: func(k8sClient controller_runtime.Client) { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultEncryptionKeySecretName, + Namespace: RadiusNamespace, + }, + Data: map[string][]byte{ + DefaultEncryptionKeySecretKey: validKey, + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + opts: nil, + wantKey: validKey, + }, + { + name: "success-with-custom-options", + setupFunc: func(k8sClient controller_runtime.Client) { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "custom-secret", + Namespace: "custom-namespace", + }, + Data: map[string][]byte{ + "custom-key": validKey, + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + opts: &KubernetesKeyProviderOptions{ + SecretName: "custom-secret", + SecretKey: "custom-key", + Namespace: "custom-namespace", + }, + wantKey: validKey, + }, + { + name: "error-secret-not-found", + setupFunc: func(k8sClient controller_runtime.Client) {}, + opts: nil, + wantErr: ErrKeyNotFound, + wantErrMsg: "not found", + }, + { + name: "error-key-not-in-secret", + setupFunc: func(k8sClient controller_runtime.Client) { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultEncryptionKeySecretName, + Namespace: RadiusNamespace, + }, + Data: map[string][]byte{ + "wrong-key": validKey, + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + opts: nil, + wantErr: ErrKeyNotFound, + wantErrMsg: "not found in secret", + }, + { + name: "error-invalid-key-size", + setupFunc: func(k8sClient controller_runtime.Client) { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultEncryptionKeySecretName, + Namespace: RadiusNamespace, + }, + Data: map[string][]byte{ + DefaultEncryptionKeySecretKey: make([]byte, 16), // Too short + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + opts: nil, + wantErr: ErrKeyLoadFailed, + wantErrMsg: "invalid size", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k8sClient := k8sutil.NewFakeKubeClient(scheme.Scheme) + tt.setupFunc(k8sClient) + + provider := NewKubernetesKeyProvider(k8sClient, tt.opts) + key, err := provider.GetKey(ctx) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + if tt.wantErrMsg != "" { + require.Contains(t, err.Error(), tt.wantErrMsg) + } + require.Nil(t, key) + } else { + require.NoError(t, err) + require.Equal(t, tt.wantKey, key) + } + }) + } +} + +func TestNewKubernetesKeyProvider_DefaultOptions(t *testing.T) { + k8sClient := k8sutil.NewFakeKubeClient(scheme.Scheme) + + // Test with nil options + provider := NewKubernetesKeyProvider(k8sClient, nil) + require.Equal(t, DefaultEncryptionKeySecretName, provider.secretName) + require.Equal(t, DefaultEncryptionKeySecretKey, provider.secretKey) + require.Equal(t, RadiusNamespace, provider.namespace) + + // Test with empty options + provider = NewKubernetesKeyProvider(k8sClient, &KubernetesKeyProviderOptions{}) + require.Equal(t, DefaultEncryptionKeySecretName, provider.secretName) + require.Equal(t, DefaultEncryptionKeySecretKey, provider.secretKey) + require.Equal(t, RadiusNamespace, provider.namespace) +} + +func TestInMemoryKeyProvider(t *testing.T) { + ctx := context.Background() + validKey := make([]byte, KeySize) + for i := range validKey { + validKey[i] = byte(i) + } + + t.Run("success", func(t *testing.T) { + provider, err := NewInMemoryKeyProvider(validKey) + require.NoError(t, err) + + key, err := provider.GetKey(ctx) + require.NoError(t, err) + require.Equal(t, validKey, key) + }) + + t.Run("error-invalid-key-size", func(t *testing.T) { + _, err := NewInMemoryKeyProvider(make([]byte, 16)) + require.ErrorIs(t, err, ErrInvalidKeySize) + }) + + t.Run("key-is-copied", func(t *testing.T) { + originalKey := make([]byte, KeySize) + for i := range originalKey { + originalKey[i] = byte(i) + } + + provider, err := NewInMemoryKeyProvider(originalKey) + require.NoError(t, err) + + // Modify the original key + originalKey[0] = 0xff + + // The provider's key should not be affected + key, err := provider.GetKey(ctx) + require.NoError(t, err) + require.NotEqual(t, originalKey[0], key[0]) + require.Equal(t, byte(0), key[0]) + }) +} + +func TestKeyProviderIntegration(t *testing.T) { + ctx := context.Background() + + // Generate a key + key, err := GenerateKey() + require.NoError(t, err) + + // Create an in-memory provider + provider, err := NewInMemoryKeyProvider(key) + require.NoError(t, err) + + // Get the key from the provider + retrievedKey, err := provider.GetKey(ctx) + require.NoError(t, err) + + // Create an encryptor with the retrieved key + enc, err := NewEncryptor(retrievedKey) + require.NoError(t, err) + + // Test encryption/decryption + plaintext := []byte("secret data from key provider") + encrypted, err := enc.Encrypt(plaintext, nil) + require.NoError(t, err) + + decrypted, err := enc.Decrypt(encrypted, nil) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted) +} diff --git a/pkg/crypto/encryption/sensitive.go b/pkg/crypto/encryption/sensitive.go new file mode 100644 index 0000000000..ade7c6f8ff --- /dev/null +++ b/pkg/crypto/encryption/sensitive.go @@ -0,0 +1,611 @@ +/* +Copyright 2023 The Radius Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryption + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" +) + +var ( + // ErrFieldNotFound is returned when a field path cannot be found in the data. + ErrFieldNotFound = errors.New("field not found") + + // ErrInvalidFieldPath is returned when a field path is invalid. + ErrInvalidFieldPath = errors.New("invalid field path") + + // ErrFieldEncryptionFailed is returned when encryption of a field fails. + ErrFieldEncryptionFailed = errors.New("field encryption failed") + + // ErrFieldDecryptionFailed is returned when decryption of a field fails. + ErrFieldDecryptionFailed = errors.New("field decryption failed") +) + +// SensitiveDataHandler provides methods for encrypting and decrypting sensitive fields +// in data structures based on field paths marked with x-radius-sensitive annotation. +type SensitiveDataHandler struct { + encryptor *Encryptor +} + +// NewSensitiveDataHandler creates a new SensitiveDataHandler with the provided encryptor. +func NewSensitiveDataHandler(encryptor *Encryptor) *SensitiveDataHandler { + return &SensitiveDataHandler{encryptor: encryptor} +} + +// NewSensitiveDataHandlerFromKey creates a new SensitiveDataHandler from a raw encryption key. +func NewSensitiveDataHandlerFromKey(key []byte) (*SensitiveDataHandler, error) { + encryptor, err := NewEncryptor(key) + if err != nil { + return nil, err + } + return &SensitiveDataHandler{encryptor: encryptor}, nil +} + +// NewSensitiveDataHandlerFromProvider creates a new SensitiveDataHandler using a key provider. +func NewSensitiveDataHandlerFromProvider(ctx context.Context, provider KeyProvider) (*SensitiveDataHandler, error) { + key, err := provider.GetKey(ctx) + if err != nil { + return nil, err + } + return NewSensitiveDataHandlerFromKey(key) +} + +// EncryptSensitiveFields encrypts all sensitive fields in the data based on the provided field paths. +// The data is modified in place. Field paths support dot notation and [*] for arrays/maps. +// Examples: "credentials.password", "secrets[*].value", "config[*]" +// +// The resourceID is used as Associated Data (AD) for context binding. This prevents encrypted +// values from being moved between different resources. The resourceID should be the full +// resource ID (e.g., "/planes/radius/local/resourceGroups/test/providers/Foo.Bar/myResources/test"). +// +// Returns an error if any field encryption fails. In case of error, partial encryption may have occurred. +func (h *SensitiveDataHandler) EncryptSensitiveFields(data map[string]any, sensitiveFieldPaths []string, resourceID string) error { + for _, path := range sensitiveFieldPaths { + // Build associated data from resource ID and field path + ad := buildAssociatedData(resourceID, path) + if err := h.encryptFieldAtPath(data, path, ad); err != nil { + return fmt.Errorf("%w: path %q: %v", ErrFieldEncryptionFailed, path, err) + } + } + return nil +} + +// DecryptSensitiveFields decrypts all sensitive fields in the data based on the provided field paths. +// The data is modified in place. Field paths support dot notation and [*] for arrays/maps. +// +// The resourceID must match what was provided during encryption for successful decryption. +// +// Note: This method does not use schema information for type restoration. Numbers in decrypted +// objects will be returned as float64 (standard Go JSON behavior). For accurate type restoration, +// use DecryptSensitiveFieldsWithSchema instead. +// +// Returns an error if any field decryption fails. In case of error, partial decryption may have occurred. +func (h *SensitiveDataHandler) DecryptSensitiveFields(data map[string]any, sensitiveFieldPaths []string, resourceID string) error { + for _, path := range sensitiveFieldPaths { + ad := buildAssociatedData(resourceID, path) + if err := h.decryptFieldAtPath(data, path, nil, ad); err != nil { + // Skip fields that are not found - they may not exist in this resource instance + if errors.Is(err, ErrFieldNotFound) { + continue + } + return fmt.Errorf("%w: path %q: %v", ErrFieldDecryptionFailed, path, err) + } + } + return nil +} + +// DecryptSensitiveFieldsWithSchema decrypts all sensitive fields in the data using schema information +// for accurate type restoration. The schema should be the OpenAPI schema for the resource type. +// The data is modified in place. Field paths support dot notation and [*] for arrays/maps. +// +// The resourceID must match what was provided during encryption for successful decryption. +// The schema is used to restore the correct types for fields within encrypted objects (e.g., integers +// that would otherwise be decoded as float64). +// +// Returns an error if any field decryption fails. In case of error, partial decryption may have occurred. +func (h *SensitiveDataHandler) DecryptSensitiveFieldsWithSchema(data map[string]any, sensitiveFieldPaths []string, resourceID string, schema map[string]any) error { + for _, path := range sensitiveFieldPaths { + // Get the schema for this specific field path + fieldSchema := getSchemaForPath(schema, path) + ad := buildAssociatedData(resourceID, path) + if err := h.decryptFieldAtPath(data, path, fieldSchema, ad); err != nil { + // Skip fields that are not found - they may not exist in this resource instance + if errors.Is(err, ErrFieldNotFound) { + continue + } + return fmt.Errorf("%w: path %q: %v", ErrFieldDecryptionFailed, path, err) + } + } + return nil +} + +// encryptFieldAtPath encrypts the value at the given field path in the data. +func (h *SensitiveDataHandler) encryptFieldAtPath(data map[string]any, path string, associatedData []byte) error { + processor := func(value any) (any, error) { + return h.encryptValue(value, associatedData) + } + return h.processFieldAtPath(data, path, processor) +} + +// decryptFieldAtPath decrypts the value at the given field path in the data. +// If fieldSchema is provided, it will be used for type restoration. +func (h *SensitiveDataHandler) decryptFieldAtPath(data map[string]any, path string, fieldSchema map[string]any, associatedData []byte) error { + processor := func(value any) (any, error) { + return h.decryptValue(value, fieldSchema, associatedData) + } + return h.processFieldAtPath(data, path, processor) +} + +// processFieldAtPath traverses the data structure and applies the processor function to the field at the path. +func (h *SensitiveDataHandler) processFieldAtPath(data map[string]any, path string, processor func(any) (any, error)) error { + if path == "" { + return ErrInvalidFieldPath + } + + segments := parseFieldPath(path) + if len(segments) == 0 { + return ErrInvalidFieldPath + } + + return h.processPathSegments(data, segments, processor) +} + +// processPathSegments recursively processes path segments to find and transform the target field. +func (h *SensitiveDataHandler) processPathSegments(current any, segments []pathSegment, processor func(any) (any, error)) error { + if len(segments) == 0 { + return nil + } + + segment := segments[0] + remainingSegments := segments[1:] + + switch segment.segmentType { + case segmentTypeField: + return h.processFieldSegment(current, segment.value, remainingSegments, processor) + case segmentTypeWildcard: + return h.processWildcardSegment(current, remainingSegments, processor) + case segmentTypeIndex: + return h.processIndexSegment(current, segment.value, remainingSegments, processor) + default: + return ErrInvalidFieldPath + } +} + +// processFieldSegment handles a regular field name segment in the path. +func (h *SensitiveDataHandler) processFieldSegment(current any, fieldName string, remainingSegments []pathSegment, processor func(any) (any, error)) error { + dataMap, ok := current.(map[string]any) + if !ok { + return ErrFieldNotFound + } + + value, exists := dataMap[fieldName] + if !exists { + return ErrFieldNotFound + } + + // If this is the last segment, process the value + if len(remainingSegments) == 0 { + processed, err := processor(value) + if err != nil { + return err + } + dataMap[fieldName] = processed + return nil + } + + // Continue traversing + return h.processPathSegments(value, remainingSegments, processor) +} + +// processWildcardSegment handles [*] segments for arrays and maps. +func (h *SensitiveDataHandler) processWildcardSegment(current any, remainingSegments []pathSegment, processor func(any) (any, error)) error { + // Handle array + if arr, ok := current.([]any); ok { + for i := range arr { + if len(remainingSegments) == 0 { + // Process each array element + processed, err := processor(arr[i]) + if err != nil { + return fmt.Errorf("index %d: %w", i, err) + } + arr[i] = processed + } else { + // Continue traversing into each element + if err := h.processPathSegments(arr[i], remainingSegments, processor); err != nil { + // Skip elements that don't have the field + if !errors.Is(err, ErrFieldNotFound) { + return fmt.Errorf("index %d: %w", i, err) + } + } + } + } + return nil + } + + // Handle map + if dataMap, ok := current.(map[string]any); ok { + for key := range dataMap { + if len(remainingSegments) == 0 { + // Process each map value + processed, err := processor(dataMap[key]) + if err != nil { + return fmt.Errorf("key %q: %w", key, err) + } + dataMap[key] = processed + } else { + // Continue traversing into each value + if err := h.processPathSegments(dataMap[key], remainingSegments, processor); err != nil { + // Skip elements that don't have the field + if !errors.Is(err, ErrFieldNotFound) { + return fmt.Errorf("key %q: %w", key, err) + } + } + } + } + return nil + } + + return ErrFieldNotFound +} + +// processIndexSegment handles specific index segments like [0], [1], etc. +func (h *SensitiveDataHandler) processIndexSegment(current any, indexStr string, remainingSegments []pathSegment, processor func(any) (any, error)) error { + arr, ok := current.([]any) + if !ok { + return ErrFieldNotFound + } + + index, err := strconv.Atoi(indexStr) + if err != nil { + return fmt.Errorf("%w: invalid index %q", ErrInvalidFieldPath, indexStr) + } + + if index < 0 || index >= len(arr) { + return ErrFieldNotFound + } + + if len(remainingSegments) == 0 { + processed, err := processor(arr[index]) + if err != nil { + return err + } + arr[index] = processed + return nil + } + + return h.processPathSegments(arr[index], remainingSegments, processor) +} + +// encryptValue encrypts a single value, handling different types appropriately. +func (h *SensitiveDataHandler) encryptValue(value any, associatedData []byte) (any, error) { + if value == nil { + return nil, nil + } + + var dataToEncrypt []byte + var err error + + switch v := value.(type) { + case string: + if v == "" { + return v, nil + } + dataToEncrypt = []byte(v) + case map[string]any, []any: + dataToEncrypt, err = json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("failed to marshal value: %w", err) + } + default: + // For other types, convert to JSON + dataToEncrypt, err = json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("failed to marshal value: %w", err) + } + } + + encrypted, err := h.encryptor.Encrypt(dataToEncrypt, associatedData) + if err != nil { + return nil, err + } + + // Return as a structured object + var result map[string]any + if err := json.Unmarshal(encrypted, &result); err != nil { + return nil, err + } + + return result, nil +} + +// decryptValue decrypts a single value, restoring the original type using schema information if provided. +func (h *SensitiveDataHandler) decryptValue(value any, fieldSchema map[string]any, associatedData []byte) (any, error) { + if value == nil { + return nil, nil + } + + // Check if this looks like encrypted data + encMap, ok := value.(map[string]any) + if !ok { + // Not encrypted data, return as-is + return value, nil + } + + _, hasEncrypted := encMap["encrypted"].(string) + _, hasNonce := encMap["nonce"].(string) + + if !hasEncrypted || !hasNonce { + // Not our encrypted format, return as-is + return value, nil + } + + // Convert back to JSON for decryption + encryptedJSON, err := json.Marshal(encMap) + if err != nil { + return nil, err + } + + decrypted, err := h.encryptor.Decrypt(encryptedJSON, associatedData) + if err != nil { + return nil, err + } + + // Determine the expected type from schema + expectedType := getSchemaType(fieldSchema) + + // For string type, return as string directly + if expectedType == "string" { + return string(decrypted), nil + } + + // For object/array types, unmarshal and apply type coercion based on schema + var result any + if err := json.Unmarshal(decrypted, &result); err != nil { + // If not valid JSON, return as string + return string(decrypted), nil + } + + // Apply schema-based type coercion if schema is available + if fieldSchema != nil && expectedType == "object" { + if resultMap, ok := result.(map[string]any); ok { + coerceTypesFromSchema(resultMap, fieldSchema) + } + } + + return result, nil +} + +// buildAssociatedData constructs the associated data for AEAD encryption from the resource ID and field path. +// This binds the ciphertext to its context, preventing encrypted values from being moved between +// different resources or fields. +func buildAssociatedData(resourceID, fieldPath string) []byte { + if resourceID == "" && fieldPath == "" { + return nil + } + // Combine resource ID and field path with a separator + // Format: "resourceID:fieldPath" + return []byte(resourceID + ":" + fieldPath) +} + +// getSchemaType returns the type from a schema, or empty string if not specified. +func getSchemaType(schema map[string]any) string { + if schema == nil { + return "" + } + if t, ok := schema["type"].(string); ok { + return t + } + return "" +} + +// getSchemaForPath retrieves the schema definition for a specific field path. +// It navigates through the schema following the path segments (supporting nested properties, +// array items via [*], and additionalProperties for maps). +func getSchemaForPath(schema map[string]any, path string) map[string]any { + if schema == nil || path == "" { + return nil + } + + segments := parseFieldPath(path) + current := schema + + for _, segment := range segments { + switch segment.segmentType { + case segmentTypeField: + // Navigate to properties -> fieldName + properties, ok := current["properties"].(map[string]any) + if !ok { + return nil + } + fieldSchema, ok := properties[segment.value].(map[string]any) + if !ok { + return nil + } + current = fieldSchema + + case segmentTypeWildcard: + // Could be array items or additionalProperties + if items, ok := current["items"].(map[string]any); ok { + current = items + } else if addProps, ok := current["additionalProperties"].(map[string]any); ok { + current = addProps + } else { + return nil + } + + case segmentTypeIndex: + // Specific array index - use items schema + if items, ok := current["items"].(map[string]any); ok { + current = items + } else { + return nil + } + } + } + + return current +} + +// coerceTypesFromSchema recursively walks through a data map and coerces types +// to match the schema definition. This primarily handles converting float64 to int64 +// for integer fields. +func coerceTypesFromSchema(data map[string]any, schema map[string]any) { + if schema == nil { + return + } + + properties, ok := schema["properties"].(map[string]any) + if !ok { + return + } + + for fieldName, fieldValue := range data { + fieldSchema, ok := properties[fieldName].(map[string]any) + if !ok { + continue + } + + fieldType := getSchemaType(fieldSchema) + + switch fieldType { + case "integer": + // Coerce float64 to int64 + if f, ok := fieldValue.(float64); ok { + data[fieldName] = int64(f) + } + + case "object": + // Recursively coerce nested objects + if nestedMap, ok := fieldValue.(map[string]any); ok { + coerceTypesFromSchema(nestedMap, fieldSchema) + } + + case "array": + // Coerce array items if they have a schema + if arr, ok := fieldValue.([]any); ok { + itemSchema, _ := fieldSchema["items"].(map[string]any) + if itemSchema != nil { + itemType := getSchemaType(itemSchema) + for i, item := range arr { + if itemType == "integer" { + if f, ok := item.(float64); ok { + arr[i] = int64(f) + } + } else if itemType == "object" { + if itemMap, ok := item.(map[string]any); ok { + coerceTypesFromSchema(itemMap, itemSchema) + } + } + } + } + } + } + + // Handle additionalProperties for map types + if addPropsSchema, ok := fieldSchema["additionalProperties"].(map[string]any); ok { + if nestedMap, ok := fieldValue.(map[string]any); ok { + addPropsType := getSchemaType(addPropsSchema) + for key, val := range nestedMap { + if addPropsType == "integer" { + if f, ok := val.(float64); ok { + nestedMap[key] = int64(f) + } + } else if addPropsType == "object" { + if valMap, ok := val.(map[string]any); ok { + coerceTypesFromSchema(valMap, addPropsSchema) + } + } + } + } + } + } +} + +// pathSegment represents a segment of a field path. +type pathSegment struct { + segmentType segmentType + value string +} + +type segmentType int + +const ( + segmentTypeField segmentType = iota + segmentTypeWildcard + segmentTypeIndex +) + +// parseFieldPath parses a field path into segments. +// Examples: +// - "credentials.password" -> [field:credentials, field:password] +// - "secrets[*].value" -> [field:secrets, wildcard, field:value] +// - "config[*]" -> [field:config, wildcard] +// - "items[0].name" -> [field:items, index:0, field:name] +func parseFieldPath(path string) []pathSegment { + var segments []pathSegment + var current strings.Builder + + i := 0 + for i < len(path) { + ch := path[i] + + switch ch { + case '.': + if current.Len() > 0 { + segments = append(segments, pathSegment{segmentType: segmentTypeField, value: current.String()}) + current.Reset() + } + i++ + + case '[': + if current.Len() > 0 { + segments = append(segments, pathSegment{segmentType: segmentTypeField, value: current.String()}) + current.Reset() + } + + // Find the closing bracket + end := strings.Index(path[i:], "]") + if end == -1 { + // Invalid path - unterminated bracket, return nil to signal error + return nil + } + + bracketContent := path[i+1 : i+end] + if bracketContent == "*" { + segments = append(segments, pathSegment{segmentType: segmentTypeWildcard}) + } else { + segments = append(segments, pathSegment{segmentType: segmentTypeIndex, value: bracketContent}) + } + i += end + 1 + + default: + current.WriteByte(ch) + i++ + } + } + + // Don't forget the last segment + if current.Len() > 0 { + segments = append(segments, pathSegment{segmentType: segmentTypeField, value: current.String()}) + } + + return segments +} diff --git a/pkg/crypto/encryption/sensitive_test.go b/pkg/crypto/encryption/sensitive_test.go new file mode 100644 index 0000000000..75efbbb3fb --- /dev/null +++ b/pkg/crypto/encryption/sensitive_test.go @@ -0,0 +1,844 @@ +/* +Copyright 2023 The Radius Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryption + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +const ( + // testResourceID is a sample resource ID used for testing associated data + testResourceID = "/planes/radius/local/resourceGroups/test/providers/Test.Resource/testResources/myResource" +) + +func TestParseFieldPath(t *testing.T) { + tests := []struct { + name string + path string + expected []pathSegment + }{ + { + name: "simple-field", + path: "password", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "password"}, + }, + }, + { + name: "nested-field", + path: "credentials.password", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "credentials"}, + {segmentType: segmentTypeField, value: "password"}, + }, + }, + { + name: "deeply-nested-field", + path: "config.database.connection.password", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "config"}, + {segmentType: segmentTypeField, value: "database"}, + {segmentType: segmentTypeField, value: "connection"}, + {segmentType: segmentTypeField, value: "password"}, + }, + }, + { + name: "array-wildcard", + path: "secrets[*].value", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "secrets"}, + {segmentType: segmentTypeWildcard}, + {segmentType: segmentTypeField, value: "value"}, + }, + }, + { + name: "map-wildcard", + path: "config[*]", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "config"}, + {segmentType: segmentTypeWildcard}, + }, + }, + { + name: "specific-index", + path: "items[0].name", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "items"}, + {segmentType: segmentTypeIndex, value: "0"}, + {segmentType: segmentTypeField, value: "name"}, + }, + }, + { + name: "multiple-wildcards", + path: "data[*].secrets[*].value", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "data"}, + {segmentType: segmentTypeWildcard}, + {segmentType: segmentTypeField, value: "secrets"}, + {segmentType: segmentTypeWildcard}, + {segmentType: segmentTypeField, value: "value"}, + }, + }, + { + name: "unterminated-bracket", + path: "secrets[*", + expected: nil, + }, + { + name: "unterminated-bracket-with-index", + path: "items[0", + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseFieldPath(tt.path) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestSensitiveDataHandler_EncryptDecrypt_SimpleField(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "username": "admin", + "password": "super-secret-password", + } + + // Encrypt + err = handler.EncryptSensitiveFields(data, []string{"password"}, testResourceID) + require.NoError(t, err) + + // Verify password is encrypted + password := data["password"] + encMap, ok := password.(map[string]any) + require.True(t, ok, "password should be encrypted map") + require.NotEmpty(t, encMap["encrypted"]) + require.NotEmpty(t, encMap["nonce"]) + + // Username should be unchanged + require.Equal(t, "admin", data["username"]) + + // Decrypt + err = handler.DecryptSensitiveFields(data, []string{"password"}, testResourceID) + require.NoError(t, err) + + // Verify password is decrypted + require.Equal(t, "super-secret-password", data["password"]) + require.Equal(t, "admin", data["username"]) +} + +func TestSensitiveDataHandler_EncryptDecrypt_NestedField(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "name": "test-resource", + "credentials": map[string]any{ + "username": "admin", + "password": "nested-secret", + "apiKey": "key-12345", + }, + } + + // Encrypt password and apiKey + err = handler.EncryptSensitiveFields(data, []string{"credentials.password", "credentials.apiKey"}, testResourceID) + require.NoError(t, err) + + // Verify encrypted fields + creds := data["credentials"].(map[string]any) + require.Equal(t, "admin", creds["username"]) + + _, passwordIsEncrypted := creds["password"].(map[string]any) + require.True(t, passwordIsEncrypted) + + _, apiKeyIsEncrypted := creds["apiKey"].(map[string]any) + require.True(t, apiKeyIsEncrypted) + + // Decrypt + err = handler.DecryptSensitiveFields(data, []string{"credentials.password", "credentials.apiKey"}, testResourceID) + require.NoError(t, err) + + creds = data["credentials"].(map[string]any) + require.Equal(t, "nested-secret", creds["password"]) + require.Equal(t, "key-12345", creds["apiKey"]) +} + +func TestSensitiveDataHandler_EncryptDecrypt_ArrayWildcard(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "name": "test-resource", + "secrets": []any{ + map[string]any{"name": "secret1", "value": "value1"}, + map[string]any{"name": "secret2", "value": "value2"}, + map[string]any{"name": "secret3", "value": "value3"}, + }, + } + + // Encrypt all secret values + err = handler.EncryptSensitiveFields(data, []string{"secrets[*].value"}, testResourceID) + require.NoError(t, err) + + // Verify all values are encrypted + secrets := data["secrets"].([]any) + for i, s := range secrets { + secret := s.(map[string]any) + require.Equal(t, "secret"+string(rune('1'+i)), secret["name"]) + + _, valueIsEncrypted := secret["value"].(map[string]any) + require.True(t, valueIsEncrypted, "secret[%d].value should be encrypted", i) + } + + // Decrypt + err = handler.DecryptSensitiveFields(data, []string{"secrets[*].value"}, testResourceID) + require.NoError(t, err) + + secrets = data["secrets"].([]any) + for i, s := range secrets { + secret := s.(map[string]any) + require.Equal(t, "value"+string(rune('1'+i)), secret["value"]) + } +} + +func TestSensitiveDataHandler_EncryptDecrypt_MapWildcard(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "name": "test-resource", + "config": map[string]any{ + "database_password": "db-secret", + "api_key": "api-secret", + "token": "token-secret", + }, + } + + // Encrypt all config values + err = handler.EncryptSensitiveFields(data, []string{"config[*]"}, testResourceID) + require.NoError(t, err) + + // Verify all config values are encrypted + config := data["config"].(map[string]any) + for key, value := range config { + _, isEncrypted := value.(map[string]any) + require.True(t, isEncrypted, "config[%s] should be encrypted", key) + } + + // Decrypt + err = handler.DecryptSensitiveFields(data, []string{"config[*]"}, testResourceID) + require.NoError(t, err) + + config = data["config"].(map[string]any) + require.Equal(t, "db-secret", config["database_password"]) + require.Equal(t, "api-secret", config["api_key"]) + require.Equal(t, "token-secret", config["token"]) +} + +func TestSensitiveDataHandler_EncryptDecrypt_ObjectValue(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "name": "test-resource", + "sensitiveConfig": map[string]any{ + "password": "secret-pass", + "privateKey": "-----BEGIN PRIVATE KEY-----\nMIIE...", + "nested": map[string]any{ + "deep": "value", + }, + }, + } + + // Encrypt entire object + err = handler.EncryptSensitiveFields(data, []string{"sensitiveConfig"}, testResourceID) + require.NoError(t, err) + + // Verify the entire object is encrypted + _, isEncrypted := data["sensitiveConfig"].(map[string]any) + require.True(t, isEncrypted) + + encData := data["sensitiveConfig"].(map[string]any) + require.NotEmpty(t, encData["encrypted"]) + require.NotEmpty(t, encData["nonce"]) + + // Decrypt + err = handler.DecryptSensitiveFields(data, []string{"sensitiveConfig"}, testResourceID) + require.NoError(t, err) + + // Verify decrypted object + config := data["sensitiveConfig"].(map[string]any) + require.Equal(t, "secret-pass", config["password"]) + require.Equal(t, "-----BEGIN PRIVATE KEY-----\nMIIE...", config["privateKey"]) + + nested := config["nested"].(map[string]any) + require.Equal(t, "value", nested["deep"]) +} + +func TestSensitiveDataHandler_FieldNotFound(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "username": "admin", + } + + // Encrypting non-existent field should return error + err = handler.EncryptSensitiveFields(data, []string{"password"}, testResourceID) + require.Error(t, err) + require.ErrorIs(t, err, ErrFieldEncryptionFailed) + + // Decrypting non-existent field should be skipped (no error) + err = handler.DecryptSensitiveFields(data, []string{"password"}, testResourceID) + require.NoError(t, err) +} + +func TestSensitiveDataHandler_EmptyValue(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "password": "", + } + + // Empty string should remain empty + err = handler.EncryptSensitiveFields(data, []string{"password"}, testResourceID) + require.NoError(t, err) + require.Equal(t, "", data["password"]) +} + +func TestSensitiveDataHandler_NilValue(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "password": nil, + } + + // Nil should remain nil + err = handler.EncryptSensitiveFields(data, []string{"password"}, testResourceID) + require.NoError(t, err) + require.Nil(t, data["password"]) +} + +func TestSensitiveDataHandler_InvalidFieldPath(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "password": "secret", + } + + // Empty path should return error + err = handler.EncryptSensitiveFields(data, []string{""}, testResourceID) + require.Error(t, err) +} + +func TestSensitiveDataHandler_RoundTrip_ComplexStructure(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + original := map[string]any{ + "name": "my-application", + "database": map[string]any{ + "host": "localhost", + "port": 5432, + "password": "db-password-123", + }, + "secrets": []any{ + map[string]any{"key": "API_KEY", "value": "api-secret-value"}, + map[string]any{"key": "AUTH_TOKEN", "value": "auth-token-value"}, + }, + "config": map[string]any{ + "public_setting": "visible", + "private_key": "secret-key-data", + }, + } + + sensitivePaths := []string{ + "database.password", + "secrets[*].value", + "config.private_key", + } + + // Make a copy to encrypt + data := deepCopyMap(original) + + // Encrypt + err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + // Verify sensitive fields are encrypted + dbPassword := data["database"].(map[string]any)["password"] + _, isEncrypted := dbPassword.(map[string]any) + require.True(t, isEncrypted, "database.password should be encrypted") + + secrets := data["secrets"].([]any) + for i, s := range secrets { + secret := s.(map[string]any) + _, valueEncrypted := secret["value"].(map[string]any) + require.True(t, valueEncrypted, "secrets[%d].value should be encrypted", i) + } + + configPrivateKey := data["config"].(map[string]any)["private_key"] + _, isEncrypted = configPrivateKey.(map[string]any) + require.True(t, isEncrypted, "config.private_key should be encrypted") + + // Verify non-sensitive fields are unchanged + require.Equal(t, "my-application", data["name"]) + require.Equal(t, "localhost", data["database"].(map[string]any)["host"]) + require.Equal(t, 5432, data["database"].(map[string]any)["port"]) + require.Equal(t, "visible", data["config"].(map[string]any)["public_setting"]) + + // Decrypt + err = handler.DecryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + // Verify values are restored + require.Equal(t, "db-password-123", data["database"].(map[string]any)["password"]) + require.Equal(t, "api-secret-value", data["secrets"].([]any)[0].(map[string]any)["value"]) + require.Equal(t, "auth-token-value", data["secrets"].([]any)[1].(map[string]any)["value"]) + require.Equal(t, "secret-key-data", data["config"].(map[string]any)["private_key"]) +} + +func TestSensitiveDataHandler_FromProvider(t *testing.T) { + ctx := context.Background() + + key, err := GenerateKey() + require.NoError(t, err) + + provider, err := NewInMemoryKeyProvider(key) + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromProvider(ctx, provider) + require.NoError(t, err) + require.NotNil(t, handler) + + // Test basic functionality + data := map[string]any{ + "secret": "my-secret", + } + + err = handler.EncryptSensitiveFields(data, []string{"secret"}, testResourceID) + require.NoError(t, err) + + _, isEncrypted := data["secret"].(map[string]any) + require.True(t, isEncrypted) + + err = handler.DecryptSensitiveFields(data, []string{"secret"}, testResourceID) + require.NoError(t, err) + require.Equal(t, "my-secret", data["secret"]) +} + +func TestSensitiveDataHandler_SpecificIndex(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "items": []any{ + map[string]any{"value": "public"}, + map[string]any{"value": "secret"}, + map[string]any{"value": "public2"}, + }, + } + + // Only encrypt second item + err = handler.EncryptSensitiveFields(data, []string{"items[1].value"}, testResourceID) + require.NoError(t, err) + + items := data["items"].([]any) + + // First and third should remain strings + require.Equal(t, "public", items[0].(map[string]any)["value"]) + require.Equal(t, "public2", items[2].(map[string]any)["value"]) + + // Second should be encrypted + _, isEncrypted := items[1].(map[string]any)["value"].(map[string]any) + require.True(t, isEncrypted) + + // Decrypt + err = handler.DecryptSensitiveFields(data, []string{"items[1].value"}, testResourceID) + require.NoError(t, err) + + require.Equal(t, "secret", items[1].(map[string]any)["value"]) +} + +func TestSensitiveDataHandler_DecryptWithSchema_IntegerRestoration(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + // Data with integer fields inside a sensitive object + data := map[string]any{ + "name": "test-resource", + "sensitiveConfig": map[string]any{ + "port": 5432, + "timeout": 30, + "password": "secret", + "enabled": true, + }, + } + + // Schema that describes the sensitive field + schema := map[string]any{ + "properties": map[string]any{ + "sensitiveConfig": map[string]any{ + "type": "object", + "x-radius-sensitive": true, + "properties": map[string]any{ + "port": map[string]any{ + "type": "integer", + }, + "timeout": map[string]any{ + "type": "integer", + }, + "password": map[string]any{ + "type": "string", + }, + "enabled": map[string]any{ + "type": "boolean", + }, + }, + }, + }, + } + + sensitivePaths := []string{"sensitiveConfig"} + + // Encrypt + err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + // Verify it's encrypted + _, isEncrypted := data["sensitiveConfig"].(map[string]any)["encrypted"] + require.True(t, isEncrypted) + + // Decrypt WITH schema + err = handler.DecryptSensitiveFieldsWithSchema(data, sensitivePaths, testResourceID, schema) + require.NoError(t, err) + + // Verify types are correctly restored + config := data["sensitiveConfig"].(map[string]any) + + // Integers should be int64, not float64 + port, ok := config["port"].(int64) + require.True(t, ok, "port should be int64, got %T", config["port"]) + require.Equal(t, int64(5432), port) + + timeout, ok := config["timeout"].(int64) + require.True(t, ok, "timeout should be int64, got %T", config["timeout"]) + require.Equal(t, int64(30), timeout) + + // String should remain string + password, ok := config["password"].(string) + require.True(t, ok, "password should be string") + require.Equal(t, "secret", password) + + // Boolean should remain boolean + enabled, ok := config["enabled"].(bool) + require.True(t, ok, "enabled should be bool") + require.True(t, enabled) +} + +func TestSensitiveDataHandler_DecryptWithSchema_NestedObjects(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + // Data with nested objects containing integers + data := map[string]any{ + "credentials": map[string]any{ + "database": map[string]any{ + "host": "localhost", + "port": 5432, + "maxConns": 100, + }, + "apiKey": "secret-key", + }, + } + + schema := map[string]any{ + "properties": map[string]any{ + "credentials": map[string]any{ + "type": "object", + "x-radius-sensitive": true, + "properties": map[string]any{ + "database": map[string]any{ + "type": "object", + "properties": map[string]any{ + "host": map[string]any{ + "type": "string", + }, + "port": map[string]any{ + "type": "integer", + }, + "maxConns": map[string]any{ + "type": "integer", + }, + }, + }, + "apiKey": map[string]any{ + "type": "string", + }, + }, + }, + }, + } + + sensitivePaths := []string{"credentials"} + + // Encrypt and decrypt with schema + err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + err = handler.DecryptSensitiveFieldsWithSchema(data, sensitivePaths, testResourceID, schema) + require.NoError(t, err) + + // Verify nested integers are restored + creds := data["credentials"].(map[string]any) + db := creds["database"].(map[string]any) + + require.Equal(t, "localhost", db["host"]) + require.Equal(t, int64(5432), db["port"]) + require.Equal(t, int64(100), db["maxConns"]) + require.Equal(t, "secret-key", creds["apiKey"]) +} + +func TestSensitiveDataHandler_DecryptWithSchema_ArrayWithIntegers(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + // Sensitive object containing an array with integers + data := map[string]any{ + "config": map[string]any{ + "ports": []any{80, 443, 8080}, + "name": "my-config", + }, + } + + schema := map[string]any{ + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "x-radius-sensitive": true, + "properties": map[string]any{ + "ports": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "integer", + }, + }, + "name": map[string]any{ + "type": "string", + }, + }, + }, + }, + } + + sensitivePaths := []string{"config"} + + err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + err = handler.DecryptSensitiveFieldsWithSchema(data, sensitivePaths, testResourceID, schema) + require.NoError(t, err) + + config := data["config"].(map[string]any) + ports := config["ports"].([]any) + + require.Len(t, ports, 3) + require.Equal(t, int64(80), ports[0]) + require.Equal(t, int64(443), ports[1]) + require.Equal(t, int64(8080), ports[2]) + require.Equal(t, "my-config", config["name"]) +} + +func TestSensitiveDataHandler_DecryptWithoutSchema_NoTypeCoercion(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + // Data with integer fields + data := map[string]any{ + "sensitiveConfig": map[string]any{ + "port": 5432, + }, + } + + sensitivePaths := []string{"sensitiveConfig"} + + err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + // Decrypt WITHOUT schema + err = handler.DecryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + config := data["sensitiveConfig"].(map[string]any) + + // Without schema, integer should be float64 (standard JSON behavior) + _, isFloat := config["port"].(float64) + require.True(t, isFloat, "without schema, port should be float64, got %T", config["port"]) +} + +func TestGetSchemaForPath(t *testing.T) { + schema := map[string]any{ + "properties": map[string]any{ + "credentials": map[string]any{ + "type": "object", + "properties": map[string]any{ + "password": map[string]any{ + "type": "string", + }, + }, + }, + "secrets": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "value": map[string]any{ + "type": "string", + }, + }, + }, + }, + "config": map[string]any{ + "type": "object", + "additionalProperties": map[string]any{ + "type": "string", + }, + }, + }, + } + + tests := []struct { + name string + path string + expectedType string + }{ + { + name: "simple-nested-field", + path: "credentials.password", + expectedType: "string", + }, + { + name: "array-wildcard-nested", + path: "secrets[*].value", + expectedType: "string", + }, + { + name: "map-wildcard", + path: "config[*]", + expectedType: "string", + }, + { + name: "object-field", + path: "credentials", + expectedType: "object", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getSchemaForPath(schema, tt.path) + require.NotNil(t, result) + require.Equal(t, tt.expectedType, result["type"]) + }) + } +} + +// Helper function to deep copy a map for testing +func deepCopyMap(original map[string]any) map[string]any { + result := make(map[string]any) + for key, value := range original { + switch v := value.(type) { + case map[string]any: + result[key] = deepCopyMap(v) + case []any: + result[key] = deepCopySlice(v) + default: + result[key] = value + } + } + return result +} + +func deepCopySlice(original []any) []any { + result := make([]any, len(original)) + for i, value := range original { + switch v := value.(type) { + case map[string]any: + result[i] = deepCopyMap(v) + case []any: + result[i] = deepCopySlice(v) + default: + result[i] = value + } + } + return result +} From 5abb5f48dc0a74da9b5a9d1c379bbc428109d47a Mon Sep 17 00:00:00 2001 From: lakshmimsft Date: Thu, 29 Jan 2026 10:53:41 -0800 Subject: [PATCH 2/3] updating code to use version in keys Signed-off-by: lakshmimsft --- pkg/crypto/encryption/encryption.go | 31 +- pkg/crypto/encryption/encryption_test.go | 218 ++++++++++++ pkg/crypto/encryption/keyprovider.go | 196 +++++++++-- pkg/crypto/encryption/keyprovider_test.go | 398 ++++++++++++++++++++-- pkg/crypto/encryption/sensitive.go | 83 ++++- pkg/crypto/encryption/sensitive_test.go | 223 +++++++++++- 6 files changed, 1073 insertions(+), 76 deletions(-) diff --git a/pkg/crypto/encryption/encryption.go b/pkg/crypto/encryption/encryption.go index 0780435592..119323742a 100644 --- a/pkg/crypto/encryption/encryption.go +++ b/pkg/crypto/encryption/encryption.go @@ -61,6 +61,9 @@ var ( // EncryptedData represents the structure for storing encrypted data. // It contains the base64-encoded ciphertext and nonce, plus optional associated data hash. type EncryptedData struct { + // Version is the key version used for encryption. + // This allows decryption to use the correct key when multiple versions exist. + Version int `json:"version,omitempty"` // Encrypted contains the base64-encoded ciphertext. Encrypted string `json:"encrypted"` // Nonce contains the base64-encoded nonce used for encryption. @@ -73,12 +76,20 @@ type EncryptedData struct { // Encryptor provides methods for encrypting and decrypting data using ChaCha20-Poly1305. type Encryptor struct { - aead cipher.AEAD + aead cipher.AEAD + keyVersion int } // NewEncryptor creates a new Encryptor with the provided 256-bit key. // Returns an error if the key is not exactly 32 bytes. +// The key version defaults to 0 (unversioned). Use NewEncryptorWithVersion for versioned keys. func NewEncryptor(key []byte) (*Encryptor, error) { + return NewEncryptorWithVersion(key, 0) +} + +// NewEncryptorWithVersion creates a new Encryptor with the provided key and version. +// The version is stored in encrypted data to enable decryption with the correct key. +func NewEncryptorWithVersion(key []byte, version int) (*Encryptor, error) { if len(key) != KeySize { return nil, ErrInvalidKeySize } @@ -88,7 +99,7 @@ func NewEncryptor(key []byte) (*Encryptor, error) { return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, err) } - return &Encryptor{aead: aead}, nil + return &Encryptor{aead: aead, keyVersion: version}, nil } // Encrypt encrypts the plaintext using ChaCha20-Poly1305 with Associated Data (AD). @@ -122,6 +133,7 @@ func (e *Encryptor) Encrypt(plaintext []byte, associatedData []byte) ([]byte, er // Create the encrypted data structure encryptedData := EncryptedData{ + Version: e.keyVersion, Encrypted: base64.StdEncoding.EncodeToString(ciphertext), Nonce: base64.StdEncoding.EncodeToString(nonce), } @@ -267,3 +279,18 @@ func GenerateKey() ([]byte, error) { } return key, nil } + +// GetEncryptedDataVersion extracts the key version from encrypted data without decrypting. +// Returns 0 if the version is not present (for backwards compatibility with unversioned data). +func GetEncryptedDataVersion(data []byte) (int, error) { + if len(data) == 0 { + return 0, ErrInvalidEncryptedData + } + + var encryptedData EncryptedData + if err := json.Unmarshal(data, &encryptedData); err != nil { + return 0, fmt.Errorf("%w: failed to parse encrypted data: %v", ErrInvalidEncryptedData, err) + } + + return encryptedData.Version, nil +} diff --git a/pkg/crypto/encryption/encryption_test.go b/pkg/crypto/encryption/encryption_test.go index 863e89c1e6..2631bd31d4 100644 --- a/pkg/crypto/encryption/encryption_test.go +++ b/pkg/crypto/encryption/encryption_test.go @@ -592,3 +592,221 @@ func TestAssociatedDataPreventsContextSwitch(t *testing.T) { _, err = enc.Decrypt(encryptedForResource1, resource2AD) require.Error(t, err, "should not be able to decrypt with different resource context") } + +// Tests for versioned key support + +func TestNewEncryptorWithVersion(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + + tests := []struct { + name string + key []byte + version int + wantErr error + }{ + { + name: "valid-key-with-version-1", + key: key, + version: 1, + wantErr: nil, + }, + { + name: "valid-key-with-version-2", + key: key, + version: 2, + wantErr: nil, + }, + { + name: "valid-key-with-version-0", + key: key, + version: 0, + wantErr: nil, + }, + { + name: "invalid-key-size", + key: make([]byte, 16), + version: 1, + wantErr: ErrInvalidKeySize, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + enc, err := NewEncryptorWithVersion(tt.key, tt.version) + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + require.Nil(t, enc) + } else { + require.NoError(t, err) + require.NotNil(t, enc) + require.Equal(t, tt.version, enc.keyVersion) + } + }) + } +} + +func TestEncryptedDataContainsVersion(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + testCases := []int{0, 1, 2, 99} + + for _, version := range testCases { + t.Run("version-"+string(rune('0'+version)), func(t *testing.T) { + enc, err := NewEncryptorWithVersion(key, version) + require.NoError(t, err) + + encrypted, err := enc.Encrypt([]byte("secret"), nil) + require.NoError(t, err) + + // Parse and verify version + var encData EncryptedData + err = json.Unmarshal(encrypted, &encData) + require.NoError(t, err) + require.Equal(t, version, encData.Version) + }) + } +} + +func TestGetEncryptedDataVersion(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + tests := []struct { + name string + data []byte + wantVersion int + wantErr error + }{ + { + name: "valid-version-1", + data: func() []byte { + enc, _ := NewEncryptorWithVersion(key, 1) + data, _ := enc.Encrypt([]byte("test"), nil) + return data + }(), + wantVersion: 1, + }, + { + name: "valid-version-2", + data: func() []byte { + enc, _ := NewEncryptorWithVersion(key, 2) + data, _ := enc.Encrypt([]byte("test"), nil) + return data + }(), + wantVersion: 2, + }, + { + name: "unversioned-data-version-0", + data: func() []byte { + enc, _ := NewEncryptor(key) // Default encryptor has version 0 + data, _ := enc.Encrypt([]byte("test"), nil) + return data + }(), + wantVersion: 0, + }, + { + name: "empty-data", + data: []byte{}, + wantErr: ErrInvalidEncryptedData, + }, + { + name: "invalid-json", + data: []byte("not json"), + wantErr: ErrInvalidEncryptedData, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + version, err := GetEncryptedDataVersion(tt.data) + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + } else { + require.NoError(t, err) + require.Equal(t, tt.wantVersion, version) + } + }) + } +} + +func TestVersionedEncryptionDecryption(t *testing.T) { + // Generate two different keys + key1, err := GenerateKey() + require.NoError(t, err) + key2, err := GenerateKey() + require.NoError(t, err) + + // Create encryptors with different versions + enc1, err := NewEncryptorWithVersion(key1, 1) + require.NoError(t, err) + enc2, err := NewEncryptorWithVersion(key2, 2) + require.NoError(t, err) + + // Encrypt with version 1 + plaintext1 := []byte("secret encrypted with key v1") + encrypted1, err := enc1.Encrypt(plaintext1, nil) + require.NoError(t, err) + + // Encrypt with version 2 + plaintext2 := []byte("secret encrypted with key v2") + encrypted2, err := enc2.Encrypt(plaintext2, nil) + require.NoError(t, err) + + // Verify versions are stored correctly + v1, err := GetEncryptedDataVersion(encrypted1) + require.NoError(t, err) + require.Equal(t, 1, v1) + + v2, err := GetEncryptedDataVersion(encrypted2) + require.NoError(t, err) + require.Equal(t, 2, v2) + + // Decrypt with correct keys + decrypted1, err := enc1.Decrypt(encrypted1, nil) + require.NoError(t, err) + require.Equal(t, plaintext1, decrypted1) + + decrypted2, err := enc2.Decrypt(encrypted2, nil) + require.NoError(t, err) + require.Equal(t, plaintext2, decrypted2) + + // Cross-decryption should fail (wrong key) + _, err = enc1.Decrypt(encrypted2, nil) + require.ErrorIs(t, err, ErrDecryptionFailed) + + _, err = enc2.Decrypt(encrypted1, nil) + require.ErrorIs(t, err, ErrDecryptionFailed) +} + +func TestBackwardsCompatibilityWithUnversionedData(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + // Create unversioned encryptor (simulates old code) + unversionedEnc, err := NewEncryptor(key) + require.NoError(t, err) + + // Create versioned encryptor with same key + versionedEnc, err := NewEncryptorWithVersion(key, 0) + require.NoError(t, err) + + plaintext := []byte("data from old system") + + // Encrypt with unversioned encryptor + encrypted, err := unversionedEnc.Encrypt(plaintext, nil) + require.NoError(t, err) + + // Should have version 0 + version, err := GetEncryptedDataVersion(encrypted) + require.NoError(t, err) + require.Equal(t, 0, version) + + // Should be decryptable by versioned encryptor with same key + decrypted, err := versionedEnc.Decrypt(encrypted, nil) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted) +} diff --git a/pkg/crypto/encryption/keyprovider.go b/pkg/crypto/encryption/keyprovider.go index ed302ca7e2..607cd8dd6c 100644 --- a/pkg/crypto/encryption/keyprovider.go +++ b/pkg/crypto/encryption/keyprovider.go @@ -18,8 +18,11 @@ package encryption import ( "context" + "encoding/base64" + "encoding/json" "errors" "fmt" + "strconv" corev1 "k8s.io/api/core/v1" k8s_error "k8s.io/apimachinery/pkg/api/errors" @@ -28,32 +31,63 @@ import ( const ( // DefaultEncryptionKeySecretName is the default name of the Kubernetes Secret containing the encryption key. - + // This is the Secret's name, not actual credentials. DefaultEncryptionKeySecretName = "radius-encryption-key" //nolint:gosec // This is a Secret name, not credentials - // DefaultEncryptionKeySecretKey is the key within the Secret that contains the encryption key. - DefaultEncryptionKeySecretKey = "key" + // DefaultEncryptionKeySecretKey is the key within the Secret that contains the versioned key store JSON. + DefaultEncryptionKeySecretKey = "keys.json" // RadiusNamespace is the namespace where Radius secrets are stored. RadiusNamespace = "radius-system" ) +// KeyStore represents a versioned key store containing multiple encryption keys. +// This structure matches the format used by the key rotation CronJob. +type KeyStore struct { + // CurrentVersion is the version number of the key to use for encryption. + CurrentVersion int `json:"currentVersion"` + // Keys is a map of version number (as string) to key data. + Keys map[string]KeyData `json:"keys"` +} + +// KeyData represents a single encryption key with its metadata. +type KeyData struct { + // Key is the base64-encoded encryption key. + Key string `json:"key"` + // Version is the version number of this key. + Version int `json:"version"` + // CreatedAt is the timestamp when this key was created (RFC3339 format). + CreatedAt string `json:"createdAt"` + // ExpiresAt is the timestamp when this key expires (RFC3339 format). + ExpiresAt string `json:"expiresAt"` +} + var ( // ErrKeyNotFound is returned when the encryption key is not found. ErrKeyNotFound = errors.New("encryption key not found") // ErrKeyLoadFailed is returned when loading the encryption key fails. ErrKeyLoadFailed = errors.New("failed to load encryption key") + + // ErrKeyVersionNotFound is returned when a specific key version is not found. + ErrKeyVersionNotFound = errors.New("key version not found") ) // KeyProvider defines the interface for retrieving encryption keys. +// It supports versioned keys to enable key rotation without data loss. // //go:generate mockgen -typed -destination=./mock_keyprovider.go -package=encryption -self_package github.com/radius-project/radius/pkg/crypto/encryption github.com/radius-project/radius/pkg/crypto/encryption KeyProvider type KeyProvider interface { - // GetKey retrieves the encryption key. - // Returns ErrKeyNotFound if the key does not exist. - GetKey(ctx context.Context) ([]byte, error) + // GetCurrentKey retrieves the current (latest) encryption key for encrypting new data. + // Returns the key bytes, the version number, and any error. + // Returns ErrKeyNotFound if no key exists. + GetCurrentKey(ctx context.Context) (key []byte, version int, err error) + + // GetKeyByVersion retrieves a specific key version for decryption. + // This is used when decrypting data that was encrypted with an older key. + // Returns ErrKeyVersionNotFound if the specified version does not exist. + GetKeyByVersion(ctx context.Context, version int) ([]byte, error) } // KubernetesKeyProvider implements KeyProvider by loading the encryption key from a Kubernetes Secret. @@ -105,8 +139,8 @@ func NewKubernetesKeyProvider(client controller_runtime.Client, opts *Kubernetes } } -// GetKey retrieves the encryption key from the Kubernetes Secret. -func (p *KubernetesKeyProvider) GetKey(ctx context.Context) ([]byte, error) { +// loadKeyStore loads and parses the key store from the Kubernetes Secret. +func (p *KubernetesKeyProvider) loadKeyStore(ctx context.Context) (*KeyStore, error) { secret := &corev1.Secret{} objectKey := controller_runtime.ObjectKey{ Name: p.secretName, @@ -120,40 +154,162 @@ func (p *KubernetesKeyProvider) GetKey(ctx context.Context) ([]byte, error) { return nil, fmt.Errorf("%w: %v", ErrKeyLoadFailed, err) } - key, ok := secret.Data[p.secretKey] + keysJSON, ok := secret.Data[p.secretKey] if !ok { return nil, fmt.Errorf("%w: key %q not found in secret %s/%s", ErrKeyNotFound, p.secretKey, p.namespace, p.secretName) } + var keyStore KeyStore + if err := json.Unmarshal(keysJSON, &keyStore); err != nil { + return nil, fmt.Errorf("%w: failed to parse key store JSON: %v", ErrKeyLoadFailed, err) + } + + return &keyStore, nil +} + +// GetCurrentKey retrieves the current encryption key from the Kubernetes Secret. +// Returns the key bytes, version number, and any error. +func (p *KubernetesKeyProvider) GetCurrentKey(ctx context.Context) ([]byte, int, error) { + keyStore, err := p.loadKeyStore(ctx) + if err != nil { + return nil, 0, err + } + + versionStr := strconv.Itoa(keyStore.CurrentVersion) + keyData, ok := keyStore.Keys[versionStr] + if !ok { + return nil, 0, fmt.Errorf("%w: current version %d not found in key store", ErrKeyVersionNotFound, keyStore.CurrentVersion) + } + + key, err := base64.StdEncoding.DecodeString(keyData.Key) + if err != nil { + return nil, 0, fmt.Errorf("%w: failed to decode key: %v", ErrKeyLoadFailed, err) + } + if len(key) != KeySize { - return nil, fmt.Errorf("%w: key in secret %s/%s has invalid size (expected %d bytes, got %d)", ErrKeyLoadFailed, p.namespace, p.secretName, KeySize, len(key)) + return nil, 0, fmt.Errorf("%w: key version %d has invalid size (expected %d bytes, got %d)", ErrKeyLoadFailed, keyStore.CurrentVersion, KeySize, len(key)) + } + + return key, keyStore.CurrentVersion, nil +} + +// GetKeyByVersion retrieves a specific key version from the Kubernetes Secret. +func (p *KubernetesKeyProvider) GetKeyByVersion(ctx context.Context, version int) ([]byte, error) { + keyStore, err := p.loadKeyStore(ctx) + if err != nil { + return nil, err + } + + versionStr := strconv.Itoa(version) + keyData, ok := keyStore.Keys[versionStr] + if !ok { + return nil, fmt.Errorf("%w: version %d not found in key store", ErrKeyVersionNotFound, version) + } + + key, err := base64.StdEncoding.DecodeString(keyData.Key) + if err != nil { + return nil, fmt.Errorf("%w: failed to decode key version %d: %v", ErrKeyLoadFailed, version, err) + } + + if len(key) != KeySize { + return nil, fmt.Errorf("%w: key version %d has invalid size (expected %d bytes, got %d)", ErrKeyLoadFailed, version, KeySize, len(key)) } return key, nil } -// InMemoryKeyProvider implements KeyProvider with an in-memory key. -// This is useful for testing or development environments. +// InMemoryKeyProvider implements KeyProvider with in-memory versioned keys. +// This is useful for testing environments. type InMemoryKeyProvider struct { - key []byte + keys map[int][]byte + currentVersion int } -// NewInMemoryKeyProvider creates a new InMemoryKeyProvider with the given key. +// NewInMemoryKeyProvider creates a new InMemoryKeyProvider with a single key at version 1. func NewInMemoryKeyProvider(key []byte) (*InMemoryKeyProvider, error) { if len(key) != KeySize { return nil, ErrInvalidKeySize } keyCopy := make([]byte, KeySize) copy(keyCopy, key) - return &InMemoryKeyProvider{key: keyCopy}, nil + return &InMemoryKeyProvider{ + keys: map[int][]byte{1: keyCopy}, + currentVersion: 1, + }, nil } -// GetKey returns a copy of the in-memory encryption key. -// A copy is returned to prevent callers from mutating the provider's internal state. -func (p *InMemoryKeyProvider) GetKey(ctx context.Context) ([]byte, error) { - if p.key == nil { +// NewInMemoryKeyProviderWithVersions creates a new InMemoryKeyProvider with multiple versioned keys. +func NewInMemoryKeyProviderWithVersions(keys map[int][]byte, currentVersion int) (*InMemoryKeyProvider, error) { + if len(keys) == 0 { return nil, ErrKeyNotFound } + + keysCopy := make(map[int][]byte, len(keys)) + for version, key := range keys { + if len(key) != KeySize { + return nil, fmt.Errorf("%w: key version %d", ErrInvalidKeySize, version) + } + keyCopy := make([]byte, KeySize) + copy(keyCopy, key) + keysCopy[version] = keyCopy + } + + if _, ok := keysCopy[currentVersion]; !ok { + return nil, fmt.Errorf("%w: current version %d", ErrKeyVersionNotFound, currentVersion) + } + + return &InMemoryKeyProvider{ + keys: keysCopy, + currentVersion: currentVersion, + }, nil +} + +// GetCurrentKey returns the current encryption key and its version. +func (p *InMemoryKeyProvider) GetCurrentKey(ctx context.Context) ([]byte, int, error) { + if len(p.keys) == 0 { + return nil, 0, ErrKeyNotFound + } + + key, ok := p.keys[p.currentVersion] + if !ok { + return nil, 0, fmt.Errorf("%w: current version %d", ErrKeyVersionNotFound, p.currentVersion) + } + + // Return a copy to prevent mutation of the internal key + return append([]byte(nil), key...), p.currentVersion, nil +} + +// GetKeyByVersion returns the key for a specific version. +func (p *InMemoryKeyProvider) GetKeyByVersion(ctx context.Context, version int) ([]byte, error) { + if p.keys == nil { + return nil, ErrKeyNotFound + } + + key, ok := p.keys[version] + if !ok { + return nil, fmt.Errorf("%w: version %d", ErrKeyVersionNotFound, version) + } + // Return a copy to prevent mutation of the internal key - return append([]byte(nil), p.key...), nil + return append([]byte(nil), key...), nil +} + +// AddKey adds a new key version to the provider (useful for testing rotation). +func (p *InMemoryKeyProvider) AddKey(version int, key []byte) error { + if len(key) != KeySize { + return ErrInvalidKeySize + } + keyCopy := make([]byte, KeySize) + copy(keyCopy, key) + p.keys[version] = keyCopy + return nil +} + +// SetCurrentVersion sets the current version (useful for testing rotation). +func (p *InMemoryKeyProvider) SetCurrentVersion(version int) error { + if _, ok := p.keys[version]; !ok { + return fmt.Errorf("%w: version %d", ErrKeyVersionNotFound, version) + } + p.currentVersion = version + return nil } diff --git a/pkg/crypto/encryption/keyprovider_test.go b/pkg/crypto/encryption/keyprovider_test.go index 7398c66ff4..9597a91758 100644 --- a/pkg/crypto/encryption/keyprovider_test.go +++ b/pkg/crypto/encryption/keyprovider_test.go @@ -18,6 +18,8 @@ package encryption import ( "context" + "encoding/base64" + "encoding/json" "testing" "github.com/radius-project/radius/test/k8sutil" @@ -28,7 +30,30 @@ import ( controller_runtime "sigs.k8s.io/controller-runtime/pkg/client" ) -func TestKubernetesKeyProvider_GetKey(t *testing.T) { +// createTestKeyStore creates a KeyStore JSON for testing +func createTestKeyStore(t *testing.T, keys map[int][]byte, currentVersion int) []byte { + keyStore := KeyStore{ + CurrentVersion: currentVersion, + Keys: make(map[string]KeyData), + } + for version, key := range keys { + versionStr := string(rune('0' + version)) + if version >= 10 { + versionStr = string(rune('0'+version/10)) + string(rune('0'+version%10)) + } + keyStore.Keys[versionStr] = KeyData{ + Key: base64.StdEncoding.EncodeToString(key), + Version: version, + CreatedAt: "2024-01-01T00:00:00Z", + ExpiresAt: "2024-04-01T00:00:00Z", + } + } + data, err := json.Marshal(keyStore) + require.NoError(t, err) + return data +} + +func TestKubernetesKeyProvider_GetCurrentKey(t *testing.T) { ctx := context.Background() validKey := make([]byte, KeySize) for i := range validKey { @@ -36,41 +61,77 @@ func TestKubernetesKeyProvider_GetKey(t *testing.T) { } tests := []struct { - name string - setupFunc func(k8sClient controller_runtime.Client) - opts *KubernetesKeyProviderOptions - wantErr error - wantKey []byte - wantErrMsg string + name string + setupFunc func(k8sClient controller_runtime.Client) + opts *KubernetesKeyProviderOptions + wantErr error + wantKey []byte + wantVersion int + wantErrMsg string }{ { name: "success-with-default-options", setupFunc: func(k8sClient controller_runtime.Client) { + keyStoreJSON := createTestKeyStore(t, map[int][]byte{1: validKey}, 1) + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultEncryptionKeySecretName, + Namespace: RadiusNamespace, + }, + Data: map[string][]byte{ + DefaultEncryptionKeySecretKey: keyStoreJSON, + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + opts: nil, + wantKey: validKey, + wantVersion: 1, + }, + { + name: "success-with-multiple-versions", + setupFunc: func(k8sClient controller_runtime.Client) { + key1 := make([]byte, KeySize) + key2 := make([]byte, KeySize) + for i := range key1 { + key1[i] = byte(i) + key2[i] = byte(i + 100) + } + keyStoreJSON := createTestKeyStore(t, map[int][]byte{1: key1, 2: key2}, 2) secret := &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: DefaultEncryptionKeySecretName, Namespace: RadiusNamespace, }, Data: map[string][]byte{ - DefaultEncryptionKeySecretKey: validKey, + DefaultEncryptionKeySecretKey: keyStoreJSON, }, } err := k8sClient.Create(ctx, secret) require.NoError(t, err) }, - opts: nil, - wantKey: validKey, + opts: nil, + wantKey: func() []byte { + key := make([]byte, KeySize) + for i := range key { + key[i] = byte(i + 100) + } + return key + }(), + wantVersion: 2, }, { name: "success-with-custom-options", setupFunc: func(k8sClient controller_runtime.Client) { + keyStoreJSON := createTestKeyStore(t, map[int][]byte{1: validKey}, 1) secret := &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: "custom-secret", Namespace: "custom-namespace", }, Data: map[string][]byte{ - "custom-key": validKey, + "custom-key": keyStoreJSON, }, } err := k8sClient.Create(ctx, secret) @@ -81,7 +142,8 @@ func TestKubernetesKeyProvider_GetKey(t *testing.T) { SecretKey: "custom-key", Namespace: "custom-namespace", }, - wantKey: validKey, + wantKey: validKey, + wantVersion: 1, }, { name: "error-secret-not-found", @@ -99,7 +161,7 @@ func TestKubernetesKeyProvider_GetKey(t *testing.T) { Namespace: RadiusNamespace, }, Data: map[string][]byte{ - "wrong-key": validKey, + "wrong-key": []byte("{}"), }, } err := k8sClient.Create(ctx, secret) @@ -109,16 +171,57 @@ func TestKubernetesKeyProvider_GetKey(t *testing.T) { wantErr: ErrKeyNotFound, wantErrMsg: "not found in secret", }, + { + name: "error-invalid-json", + setupFunc: func(k8sClient controller_runtime.Client) { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultEncryptionKeySecretName, + Namespace: RadiusNamespace, + }, + Data: map[string][]byte{ + DefaultEncryptionKeySecretKey: []byte("not-valid-json"), + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + opts: nil, + wantErr: ErrKeyLoadFailed, + wantErrMsg: "failed to parse key store JSON", + }, + { + name: "error-current-version-not-found", + setupFunc: func(k8sClient controller_runtime.Client) { + keyStoreJSON := createTestKeyStore(t, map[int][]byte{1: validKey}, 99) + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultEncryptionKeySecretName, + Namespace: RadiusNamespace, + }, + Data: map[string][]byte{ + DefaultEncryptionKeySecretKey: keyStoreJSON, + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + opts: nil, + wantErr: ErrKeyVersionNotFound, + wantErrMsg: "current version 99 not found", + }, { name: "error-invalid-key-size", setupFunc: func(k8sClient controller_runtime.Client) { + shortKey := make([]byte, 16) // Too short + keyStoreJSON := createTestKeyStore(t, map[int][]byte{1: shortKey}, 1) secret := &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: DefaultEncryptionKeySecretName, Namespace: RadiusNamespace, }, Data: map[string][]byte{ - DefaultEncryptionKeySecretKey: make([]byte, 16), // Too short + DefaultEncryptionKeySecretKey: keyStoreJSON, }, } err := k8sClient.Create(ctx, secret) @@ -136,7 +239,107 @@ func TestKubernetesKeyProvider_GetKey(t *testing.T) { tt.setupFunc(k8sClient) provider := NewKubernetesKeyProvider(k8sClient, tt.opts) - key, err := provider.GetKey(ctx) + key, version, err := provider.GetCurrentKey(ctx) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + if tt.wantErrMsg != "" { + require.Contains(t, err.Error(), tt.wantErrMsg) + } + require.Nil(t, key) + } else { + require.NoError(t, err) + require.Equal(t, tt.wantKey, key) + require.Equal(t, tt.wantVersion, version) + } + }) + } +} + +func TestKubernetesKeyProvider_GetKeyByVersion(t *testing.T) { + ctx := context.Background() + key1 := make([]byte, KeySize) + key2 := make([]byte, KeySize) + for i := range key1 { + key1[i] = byte(i) + key2[i] = byte(i + 100) + } + + tests := []struct { + name string + setupFunc func(k8sClient controller_runtime.Client) + version int + wantErr error + wantKey []byte + wantErrMsg string + }{ + { + name: "success-get-version-1", + setupFunc: func(k8sClient controller_runtime.Client) { + keyStoreJSON := createTestKeyStore(t, map[int][]byte{1: key1, 2: key2}, 2) + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultEncryptionKeySecretName, + Namespace: RadiusNamespace, + }, + Data: map[string][]byte{ + DefaultEncryptionKeySecretKey: keyStoreJSON, + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + version: 1, + wantKey: key1, + }, + { + name: "success-get-version-2", + setupFunc: func(k8sClient controller_runtime.Client) { + keyStoreJSON := createTestKeyStore(t, map[int][]byte{1: key1, 2: key2}, 2) + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultEncryptionKeySecretName, + Namespace: RadiusNamespace, + }, + Data: map[string][]byte{ + DefaultEncryptionKeySecretKey: keyStoreJSON, + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + version: 2, + wantKey: key2, + }, + { + name: "error-version-not-found", + setupFunc: func(k8sClient controller_runtime.Client) { + keyStoreJSON := createTestKeyStore(t, map[int][]byte{1: key1}, 1) + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultEncryptionKeySecretName, + Namespace: RadiusNamespace, + }, + Data: map[string][]byte{ + DefaultEncryptionKeySecretKey: keyStoreJSON, + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + version: 99, + wantErr: ErrKeyVersionNotFound, + wantErrMsg: "version 99 not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k8sClient := k8sutil.NewFakeKubeClient(scheme.Scheme) + tt.setupFunc(k8sClient) + + provider := NewKubernetesKeyProvider(k8sClient, nil) + key, err := provider.GetKeyByVersion(ctx, tt.version) if tt.wantErr != nil { require.ErrorIs(t, err, tt.wantErr) @@ -175,11 +378,21 @@ func TestInMemoryKeyProvider(t *testing.T) { validKey[i] = byte(i) } - t.Run("success", func(t *testing.T) { + t.Run("success-get-current-key", func(t *testing.T) { + provider, err := NewInMemoryKeyProvider(validKey) + require.NoError(t, err) + + key, version, err := provider.GetCurrentKey(ctx) + require.NoError(t, err) + require.Equal(t, validKey, key) + require.Equal(t, 1, version) + }) + + t.Run("success-get-key-by-version", func(t *testing.T) { provider, err := NewInMemoryKeyProvider(validKey) require.NoError(t, err) - key, err := provider.GetKey(ctx) + key, err := provider.GetKeyByVersion(ctx, 1) require.NoError(t, err) require.Equal(t, validKey, key) }) @@ -189,6 +402,14 @@ func TestInMemoryKeyProvider(t *testing.T) { require.ErrorIs(t, err, ErrInvalidKeySize) }) + t.Run("error-version-not-found", func(t *testing.T) { + provider, err := NewInMemoryKeyProvider(validKey) + require.NoError(t, err) + + _, err = provider.GetKeyByVersion(ctx, 99) + require.ErrorIs(t, err, ErrKeyVersionNotFound) + }) + t.Run("key-is-copied", func(t *testing.T) { originalKey := make([]byte, KeySize) for i := range originalKey { @@ -202,30 +423,128 @@ func TestInMemoryKeyProvider(t *testing.T) { originalKey[0] = 0xff // The provider's key should not be affected - key, err := provider.GetKey(ctx) + key, _, err := provider.GetCurrentKey(ctx) require.NoError(t, err) require.NotEqual(t, originalKey[0], key[0]) require.Equal(t, byte(0), key[0]) }) } +func TestInMemoryKeyProviderWithVersions(t *testing.T) { + ctx := context.Background() + key1 := make([]byte, KeySize) + key2 := make([]byte, KeySize) + for i := range key1 { + key1[i] = byte(i) + key2[i] = byte(i + 100) + } + + t.Run("success-multiple-versions", func(t *testing.T) { + provider, err := NewInMemoryKeyProviderWithVersions(map[int][]byte{1: key1, 2: key2}, 2) + require.NoError(t, err) + + // Current key should be version 2 + key, version, err := provider.GetCurrentKey(ctx) + require.NoError(t, err) + require.Equal(t, key2, key) + require.Equal(t, 2, version) + + // Should be able to get version 1 + key, err = provider.GetKeyByVersion(ctx, 1) + require.NoError(t, err) + require.Equal(t, key1, key) + + // Should be able to get version 2 + key, err = provider.GetKeyByVersion(ctx, 2) + require.NoError(t, err) + require.Equal(t, key2, key) + }) + + t.Run("error-empty-keys", func(t *testing.T) { + _, err := NewInMemoryKeyProviderWithVersions(map[int][]byte{}, 1) + require.ErrorIs(t, err, ErrKeyNotFound) + }) + + t.Run("error-current-version-not-in-keys", func(t *testing.T) { + _, err := NewInMemoryKeyProviderWithVersions(map[int][]byte{1: key1}, 99) + require.ErrorIs(t, err, ErrKeyVersionNotFound) + }) + + t.Run("error-invalid-key-size-in-map", func(t *testing.T) { + _, err := NewInMemoryKeyProviderWithVersions(map[int][]byte{1: make([]byte, 16)}, 1) + require.ErrorIs(t, err, ErrInvalidKeySize) + }) +} + +func TestInMemoryKeyProvider_AddKeyAndSetVersion(t *testing.T) { + ctx := context.Background() + key1 := make([]byte, KeySize) + key2 := make([]byte, KeySize) + for i := range key1 { + key1[i] = byte(i) + key2[i] = byte(i + 100) + } + + provider, err := NewInMemoryKeyProvider(key1) + require.NoError(t, err) + + // Initial state: version 1 + key, version, err := provider.GetCurrentKey(ctx) + require.NoError(t, err) + require.Equal(t, key1, key) + require.Equal(t, 1, version) + + // Add version 2 + err = provider.AddKey(2, key2) + require.NoError(t, err) + + // Version 2 should be accessible but not current + key, err = provider.GetKeyByVersion(ctx, 2) + require.NoError(t, err) + require.Equal(t, key2, key) + + _, version, err = provider.GetCurrentKey(ctx) + require.NoError(t, err) + require.Equal(t, 1, version) // Still version 1 + + // Set current to version 2 + err = provider.SetCurrentVersion(2) + require.NoError(t, err) + + key, version, err = provider.GetCurrentKey(ctx) + require.NoError(t, err) + require.Equal(t, key2, key) + require.Equal(t, 2, version) + + // Error: set version that doesn't exist + err = provider.SetCurrentVersion(99) + require.ErrorIs(t, err, ErrKeyVersionNotFound) + + // Error: add key with invalid size + err = provider.AddKey(3, make([]byte, 16)) + require.ErrorIs(t, err, ErrInvalidKeySize) +} + func TestKeyProviderIntegration(t *testing.T) { ctx := context.Background() - // Generate a key - key, err := GenerateKey() + // Generate keys + key1, err := GenerateKey() + require.NoError(t, err) + key2, err := GenerateKey() require.NoError(t, err) - // Create an in-memory provider - provider, err := NewInMemoryKeyProvider(key) + // Create an in-memory provider with multiple versions + provider, err := NewInMemoryKeyProviderWithVersions(map[int][]byte{1: key1, 2: key2}, 2) require.NoError(t, err) - // Get the key from the provider - retrievedKey, err := provider.GetKey(ctx) + // Get the current key from the provider + retrievedKey, version, err := provider.GetCurrentKey(ctx) require.NoError(t, err) + require.Equal(t, 2, version) - // Create an encryptor with the retrieved key - enc, err := NewEncryptor(retrievedKey) + // Create an encryptor with the retrieved key and version + enc, err := NewEncryptorWithVersion(retrievedKey, version) require.NoError(t, err) // Test encryption/decryption @@ -233,7 +552,34 @@ func TestKeyProviderIntegration(t *testing.T) { encrypted, err := enc.Encrypt(plaintext, nil) require.NoError(t, err) + // Verify the encrypted data contains the version + encVersion, err := GetEncryptedDataVersion(encrypted) + require.NoError(t, err) + require.Equal(t, 2, encVersion) + + // Decrypt with the same encryptor decrypted, err := enc.Decrypt(encrypted, nil) require.NoError(t, err) require.Equal(t, plaintext, decrypted) + + // Simulate decryption with old key version + // Get the old key + oldKey, err := provider.GetKeyByVersion(ctx, 1) + require.NoError(t, err) + oldEnc, err := NewEncryptorWithVersion(oldKey, 1) + require.NoError(t, err) + + // Encrypt with old key + oldEncrypted, err := oldEnc.Encrypt([]byte("old secret"), nil) + require.NoError(t, err) + + // Verify version is 1 + oldVersion, err := GetEncryptedDataVersion(oldEncrypted) + require.NoError(t, err) + require.Equal(t, 1, oldVersion) + + // Decrypt with old key + oldDecrypted, err := oldEnc.Decrypt(oldEncrypted, nil) + require.NoError(t, err) + require.Equal(t, []byte("old secret"), oldDecrypted) } diff --git a/pkg/crypto/encryption/sensitive.go b/pkg/crypto/encryption/sensitive.go index ade7c6f8ff..9e8016f161 100644 --- a/pkg/crypto/encryption/sensitive.go +++ b/pkg/crypto/encryption/sensitive.go @@ -42,15 +42,20 @@ var ( // SensitiveDataHandler provides methods for encrypting and decrypting sensitive fields // in data structures based on field paths marked with x-radius-sensitive annotation. type SensitiveDataHandler struct { - encryptor *Encryptor + encryptor *Encryptor + keyProvider KeyProvider } // NewSensitiveDataHandler creates a new SensitiveDataHandler with the provided encryptor. +// Note: This constructor does not support versioned key rotation for decryption. +// Use NewSensitiveDataHandlerFromProvider for full versioned key support. func NewSensitiveDataHandler(encryptor *Encryptor) *SensitiveDataHandler { return &SensitiveDataHandler{encryptor: encryptor} } // NewSensitiveDataHandlerFromKey creates a new SensitiveDataHandler from a raw encryption key. +// Note: This constructor does not support versioned key rotation for decryption. +// Use NewSensitiveDataHandlerFromProvider for full versioned key support. func NewSensitiveDataHandlerFromKey(key []byte) (*SensitiveDataHandler, error) { encryptor, err := NewEncryptor(key) if err != nil { @@ -59,13 +64,23 @@ func NewSensitiveDataHandlerFromKey(key []byte) (*SensitiveDataHandler, error) { return &SensitiveDataHandler{encryptor: encryptor}, nil } -// NewSensitiveDataHandlerFromProvider creates a new SensitiveDataHandler using a key provider. +// NewSensitiveDataHandlerFromProvider creates a new SensitiveDataHandler using a versioned key provider. +// This is the recommended constructor as it supports key rotation: +// - Encryption uses the current key version +// - Decryption reads the version from encrypted data and fetches the appropriate key func NewSensitiveDataHandlerFromProvider(ctx context.Context, provider KeyProvider) (*SensitiveDataHandler, error) { - key, err := provider.GetKey(ctx) + key, version, err := provider.GetCurrentKey(ctx) if err != nil { return nil, err } - return NewSensitiveDataHandlerFromKey(key) + encryptor, err := NewEncryptorWithVersion(key, version) + if err != nil { + return nil, err + } + return &SensitiveDataHandler{ + encryptor: encryptor, + keyProvider: provider, + }, nil } // EncryptSensitiveFields encrypts all sensitive fields in the data based on the provided field paths. @@ -77,11 +92,17 @@ func NewSensitiveDataHandlerFromProvider(ctx context.Context, provider KeyProvid // resource ID (e.g., "/planes/radius/local/resourceGroups/test/providers/Foo.Bar/myResources/test"). // // Returns an error if any field encryption fails. In case of error, partial encryption may have occurred. +// Fields that are not found are skipped - this allows optional sensitive fields to be absent. func (h *SensitiveDataHandler) EncryptSensitiveFields(data map[string]any, sensitiveFieldPaths []string, resourceID string) error { for _, path := range sensitiveFieldPaths { // Build associated data from resource ID and field path ad := buildAssociatedData(resourceID, path) if err := h.encryptFieldAtPath(data, path, ad); err != nil { + // Skip fields that are not found - they may not exist in this resource instance + // (e.g., optional sensitive properties) + if errors.Is(err, ErrFieldNotFound) { + continue + } return fmt.Errorf("%w: path %q: %v", ErrFieldEncryptionFailed, path, err) } } @@ -92,16 +113,17 @@ func (h *SensitiveDataHandler) EncryptSensitiveFields(data map[string]any, sensi // The data is modified in place. Field paths support dot notation and [*] for arrays/maps. // // The resourceID must match what was provided during encryption for successful decryption. +// The context is used to fetch versioned keys from the key provider when needed. // // Note: This method does not use schema information for type restoration. Numbers in decrypted // objects will be returned as float64 (standard Go JSON behavior). For accurate type restoration, // use DecryptSensitiveFieldsWithSchema instead. // // Returns an error if any field decryption fails. In case of error, partial decryption may have occurred. -func (h *SensitiveDataHandler) DecryptSensitiveFields(data map[string]any, sensitiveFieldPaths []string, resourceID string) error { +func (h *SensitiveDataHandler) DecryptSensitiveFields(ctx context.Context, data map[string]any, sensitiveFieldPaths []string, resourceID string) error { for _, path := range sensitiveFieldPaths { ad := buildAssociatedData(resourceID, path) - if err := h.decryptFieldAtPath(data, path, nil, ad); err != nil { + if err := h.decryptFieldAtPath(ctx, data, path, nil, ad); err != nil { // Skip fields that are not found - they may not exist in this resource instance if errors.Is(err, ErrFieldNotFound) { continue @@ -117,16 +139,17 @@ func (h *SensitiveDataHandler) DecryptSensitiveFields(data map[string]any, sensi // The data is modified in place. Field paths support dot notation and [*] for arrays/maps. // // The resourceID must match what was provided during encryption for successful decryption. +// The context is used to fetch versioned keys from the key provider when needed. // The schema is used to restore the correct types for fields within encrypted objects (e.g., integers // that would otherwise be decoded as float64). // // Returns an error if any field decryption fails. In case of error, partial decryption may have occurred. -func (h *SensitiveDataHandler) DecryptSensitiveFieldsWithSchema(data map[string]any, sensitiveFieldPaths []string, resourceID string, schema map[string]any) error { +func (h *SensitiveDataHandler) DecryptSensitiveFieldsWithSchema(ctx context.Context, data map[string]any, sensitiveFieldPaths []string, resourceID string, schema map[string]any) error { for _, path := range sensitiveFieldPaths { // Get the schema for this specific field path fieldSchema := getSchemaForPath(schema, path) ad := buildAssociatedData(resourceID, path) - if err := h.decryptFieldAtPath(data, path, fieldSchema, ad); err != nil { + if err := h.decryptFieldAtPath(ctx, data, path, fieldSchema, ad); err != nil { // Skip fields that are not found - they may not exist in this resource instance if errors.Is(err, ErrFieldNotFound) { continue @@ -137,6 +160,35 @@ func (h *SensitiveDataHandler) DecryptSensitiveFieldsWithSchema(data map[string] return nil } +// getEncryptorForDecryption returns the appropriate encryptor for decrypting data. +// If a keyProvider is available and the data contains a version, it fetches the versioned key. +// Otherwise, it falls back to the default encryptor. +func (h *SensitiveDataHandler) getEncryptorForDecryption(ctx context.Context, encryptedJSON []byte) (*Encryptor, error) { + // If no key provider, use the default encryptor + if h.keyProvider == nil { + return h.encryptor, nil + } + + // Extract the version from the encrypted data + version, err := GetEncryptedDataVersion(encryptedJSON) + if err != nil { + return nil, err + } + + // If version is 0 (unversioned/legacy data), use the default encryptor + if version == 0 { + return h.encryptor, nil + } + + // Fetch the key for this specific version + key, err := h.keyProvider.GetKeyByVersion(ctx, version) + if err != nil { + return nil, fmt.Errorf("failed to get key for version %d: %w", version, err) + } + + return NewEncryptorWithVersion(key, version) +} + // encryptFieldAtPath encrypts the value at the given field path in the data. func (h *SensitiveDataHandler) encryptFieldAtPath(data map[string]any, path string, associatedData []byte) error { processor := func(value any) (any, error) { @@ -147,9 +199,9 @@ func (h *SensitiveDataHandler) encryptFieldAtPath(data map[string]any, path stri // decryptFieldAtPath decrypts the value at the given field path in the data. // If fieldSchema is provided, it will be used for type restoration. -func (h *SensitiveDataHandler) decryptFieldAtPath(data map[string]any, path string, fieldSchema map[string]any, associatedData []byte) error { +func (h *SensitiveDataHandler) decryptFieldAtPath(ctx context.Context, data map[string]any, path string, fieldSchema map[string]any, associatedData []byte) error { processor := func(value any) (any, error) { - return h.decryptValue(value, fieldSchema, associatedData) + return h.decryptValue(ctx, value, fieldSchema, associatedData) } return h.processFieldAtPath(data, path, processor) } @@ -337,7 +389,8 @@ func (h *SensitiveDataHandler) encryptValue(value any, associatedData []byte) (a } // decryptValue decrypts a single value, restoring the original type using schema information if provided. -func (h *SensitiveDataHandler) decryptValue(value any, fieldSchema map[string]any, associatedData []byte) (any, error) { +// If a keyProvider is available and the encrypted data contains a version, it will fetch the appropriate key. +func (h *SensitiveDataHandler) decryptValue(ctx context.Context, value any, fieldSchema map[string]any, associatedData []byte) (any, error) { if value == nil { return nil, nil } @@ -363,7 +416,13 @@ func (h *SensitiveDataHandler) decryptValue(value any, fieldSchema map[string]an return nil, err } - decrypted, err := h.encryptor.Decrypt(encryptedJSON, associatedData) + // Get the appropriate encryptor based on the key version in the encrypted data + encryptor, err := h.getEncryptorForDecryption(ctx, encryptedJSON) + if err != nil { + return nil, err + } + + decrypted, err := encryptor.Decrypt(encryptedJSON, associatedData) if err != nil { return nil, err } diff --git a/pkg/crypto/encryption/sensitive_test.go b/pkg/crypto/encryption/sensitive_test.go index 75efbbb3fb..990656cb93 100644 --- a/pkg/crypto/encryption/sensitive_test.go +++ b/pkg/crypto/encryption/sensitive_test.go @@ -143,7 +143,7 @@ func TestSensitiveDataHandler_EncryptDecrypt_SimpleField(t *testing.T) { require.Equal(t, "admin", data["username"]) // Decrypt - err = handler.DecryptSensitiveFields(data, []string{"password"}, testResourceID) + err = handler.DecryptSensitiveFields(context.Background(), data, []string{"password"}, testResourceID) require.NoError(t, err) // Verify password is decrypted @@ -182,7 +182,7 @@ func TestSensitiveDataHandler_EncryptDecrypt_NestedField(t *testing.T) { require.True(t, apiKeyIsEncrypted) // Decrypt - err = handler.DecryptSensitiveFields(data, []string{"credentials.password", "credentials.apiKey"}, testResourceID) + err = handler.DecryptSensitiveFields(context.Background(), data, []string{"credentials.password", "credentials.apiKey"}, testResourceID) require.NoError(t, err) creds = data["credentials"].(map[string]any) @@ -221,7 +221,7 @@ func TestSensitiveDataHandler_EncryptDecrypt_ArrayWildcard(t *testing.T) { } // Decrypt - err = handler.DecryptSensitiveFields(data, []string{"secrets[*].value"}, testResourceID) + err = handler.DecryptSensitiveFields(context.Background(), data, []string{"secrets[*].value"}, testResourceID) require.NoError(t, err) secrets = data["secrets"].([]any) @@ -259,7 +259,7 @@ func TestSensitiveDataHandler_EncryptDecrypt_MapWildcard(t *testing.T) { } // Decrypt - err = handler.DecryptSensitiveFields(data, []string{"config[*]"}, testResourceID) + err = handler.DecryptSensitiveFields(context.Background(), data, []string{"config[*]"}, testResourceID) require.NoError(t, err) config = data["config"].(map[string]any) @@ -299,7 +299,7 @@ func TestSensitiveDataHandler_EncryptDecrypt_ObjectValue(t *testing.T) { require.NotEmpty(t, encData["nonce"]) // Decrypt - err = handler.DecryptSensitiveFields(data, []string{"sensitiveConfig"}, testResourceID) + err = handler.DecryptSensitiveFields(context.Background(), data, []string{"sensitiveConfig"}, testResourceID) require.NoError(t, err) // Verify decrypted object @@ -322,16 +322,58 @@ func TestSensitiveDataHandler_FieldNotFound(t *testing.T) { "username": "admin", } - // Encrypting non-existent field should return error + // Encrypting non-existent field should be skipped (no error) - supports optional sensitive fields err = handler.EncryptSensitiveFields(data, []string{"password"}, testResourceID) - require.Error(t, err) - require.ErrorIs(t, err, ErrFieldEncryptionFailed) + require.NoError(t, err) // Decrypting non-existent field should be skipped (no error) - err = handler.DecryptSensitiveFields(data, []string{"password"}, testResourceID) + err = handler.DecryptSensitiveFields(context.Background(), data, []string{"password"}, testResourceID) require.NoError(t, err) } +func TestSensitiveDataHandler_OptionalSensitiveFields(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + // Simulate a resource with optional sensitive fields - only some are present + data := map[string]any{ + "name": "my-resource", + "password": "secret-password", // present + // "apiKey" is absent (optional) + // "credentials.token" is absent (optional nested) + } + + sensitivePaths := []string{"password", "apiKey", "credentials.token"} + + // Encrypt should succeed even though some fields are missing + err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + // The present field should be encrypted + _, isEncrypted := data["password"].(map[string]any) + require.True(t, isEncrypted, "password should be encrypted") + + // Name should be unchanged + require.Equal(t, "my-resource", data["name"]) + + // Missing fields should not be added + _, hasAPIKey := data["apiKey"] + require.False(t, hasAPIKey, "apiKey should not be added") + + _, hasCredentials := data["credentials"] + require.False(t, hasCredentials, "credentials should not be added") + + // Decrypt should also succeed + err = handler.DecryptSensitiveFields(context.Background(), data, sensitivePaths, testResourceID) + require.NoError(t, err) + + // Verify decryption worked for the present field + require.Equal(t, "secret-password", data["password"]) +} + func TestSensitiveDataHandler_EmptyValue(t *testing.T) { key, err := GenerateKey() require.NoError(t, err) @@ -442,7 +484,7 @@ func TestSensitiveDataHandler_RoundTrip_ComplexStructure(t *testing.T) { require.Equal(t, "visible", data["config"].(map[string]any)["public_setting"]) // Decrypt - err = handler.DecryptSensitiveFields(data, sensitivePaths, testResourceID) + err = handler.DecryptSensitiveFields(context.Background(), data, sensitivePaths, testResourceID) require.NoError(t, err) // Verify values are restored @@ -476,7 +518,7 @@ func TestSensitiveDataHandler_FromProvider(t *testing.T) { _, isEncrypted := data["secret"].(map[string]any) require.True(t, isEncrypted) - err = handler.DecryptSensitiveFields(data, []string{"secret"}, testResourceID) + err = handler.DecryptSensitiveFields(ctx, data, []string{"secret"}, testResourceID) require.NoError(t, err) require.Equal(t, "my-secret", data["secret"]) } @@ -511,7 +553,7 @@ func TestSensitiveDataHandler_SpecificIndex(t *testing.T) { require.True(t, isEncrypted) // Decrypt - err = handler.DecryptSensitiveFields(data, []string{"items[1].value"}, testResourceID) + err = handler.DecryptSensitiveFields(context.Background(), data, []string{"items[1].value"}, testResourceID) require.NoError(t, err) require.Equal(t, "secret", items[1].(map[string]any)["value"]) @@ -570,7 +612,7 @@ func TestSensitiveDataHandler_DecryptWithSchema_IntegerRestoration(t *testing.T) require.True(t, isEncrypted) // Decrypt WITH schema - err = handler.DecryptSensitiveFieldsWithSchema(data, sensitivePaths, testResourceID, schema) + err = handler.DecryptSensitiveFieldsWithSchema(context.Background(), data, sensitivePaths, testResourceID, schema) require.NoError(t, err) // Verify types are correctly restored @@ -649,7 +691,7 @@ func TestSensitiveDataHandler_DecryptWithSchema_NestedObjects(t *testing.T) { err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) require.NoError(t, err) - err = handler.DecryptSensitiveFieldsWithSchema(data, sensitivePaths, testResourceID, schema) + err = handler.DecryptSensitiveFieldsWithSchema(context.Background(), data, sensitivePaths, testResourceID, schema) require.NoError(t, err) // Verify nested integers are restored @@ -702,7 +744,7 @@ func TestSensitiveDataHandler_DecryptWithSchema_ArrayWithIntegers(t *testing.T) err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) require.NoError(t, err) - err = handler.DecryptSensitiveFieldsWithSchema(data, sensitivePaths, testResourceID, schema) + err = handler.DecryptSensitiveFieldsWithSchema(context.Background(), data, sensitivePaths, testResourceID, schema) require.NoError(t, err) config := data["config"].(map[string]any) @@ -735,7 +777,7 @@ func TestSensitiveDataHandler_DecryptWithoutSchema_NoTypeCoercion(t *testing.T) require.NoError(t, err) // Decrypt WITHOUT schema - err = handler.DecryptSensitiveFields(data, sensitivePaths, testResourceID) + err = handler.DecryptSensitiveFields(context.Background(), data, sensitivePaths, testResourceID) require.NoError(t, err) config := data["sensitiveConfig"].(map[string]any) @@ -812,6 +854,155 @@ func TestGetSchemaForPath(t *testing.T) { } } +// Test versioned key rotation support +func TestSensitiveDataHandler_VersionedKeyRotation(t *testing.T) { + ctx := context.Background() + + // Generate two different keys + key1, err := GenerateKey() + require.NoError(t, err) + key2, err := GenerateKey() + require.NoError(t, err) + + // Create a provider with both keys, version 1 is current + provider, err := NewInMemoryKeyProviderWithVersions(map[int][]byte{1: key1, 2: key2}, 1) + require.NoError(t, err) + + // Create handler with version 1 as current + handler1, err := NewSensitiveDataHandlerFromProvider(ctx, provider) + require.NoError(t, err) + + // Encrypt data with version 1 + data := map[string]any{ + "password": "secret-v1", + } + err = handler1.EncryptSensitiveFields(data, []string{"password"}, testResourceID) + require.NoError(t, err) + + // Verify it's encrypted + encData := data["password"].(map[string]any) + require.NotEmpty(t, encData["encrypted"]) + require.Equal(t, float64(1), encData["version"]) // JSON unmarshals to float64 + + // Simulate key rotation: set version 2 as current + err = provider.SetCurrentVersion(2) + require.NoError(t, err) + + // Create new handler (would happen on pod restart after rotation) + handler2, err := NewSensitiveDataHandlerFromProvider(ctx, provider) + require.NoError(t, err) + + // Handler2 should be able to decrypt data encrypted with version 1 + // because it has access to all keys via the provider + err = handler2.DecryptSensitiveFields(ctx, data, []string{"password"}, testResourceID) + require.NoError(t, err) + require.Equal(t, "secret-v1", data["password"]) + + // Now encrypt new data with handler2 (should use version 2) + data2 := map[string]any{ + "password": "secret-v2", + } + err = handler2.EncryptSensitiveFields(data2, []string{"password"}, testResourceID) + require.NoError(t, err) + + encData2 := data2["password"].(map[string]any) + require.Equal(t, float64(2), encData2["version"]) // Should be version 2 + + // Decrypt the version 2 data + err = handler2.DecryptSensitiveFields(ctx, data2, []string{"password"}, testResourceID) + require.NoError(t, err) + require.Equal(t, "secret-v2", data2["password"]) +} + +func TestSensitiveDataHandler_DecryptWithOldKeyVersion(t *testing.T) { + ctx := context.Background() + + // Generate keys + key1, err := GenerateKey() + require.NoError(t, err) + key2, err := GenerateKey() + require.NoError(t, err) + key3, err := GenerateKey() + require.NoError(t, err) + + // Create provider with all three keys + provider, err := NewInMemoryKeyProviderWithVersions(map[int][]byte{1: key1, 2: key2, 3: key3}, 3) + require.NoError(t, err) + + // Create handlers for each version to encrypt data + err = provider.SetCurrentVersion(1) + require.NoError(t, err) + handler1, err := NewSensitiveDataHandlerFromProvider(ctx, provider) + require.NoError(t, err) + + err = provider.SetCurrentVersion(2) + require.NoError(t, err) + handler2, err := NewSensitiveDataHandlerFromProvider(ctx, provider) + require.NoError(t, err) + + err = provider.SetCurrentVersion(3) + require.NoError(t, err) + handler3, err := NewSensitiveDataHandlerFromProvider(ctx, provider) + require.NoError(t, err) + + // Encrypt data with each version + dataV1 := map[string]any{"secret": "encrypted-with-v1"} + dataV2 := map[string]any{"secret": "encrypted-with-v2"} + dataV3 := map[string]any{"secret": "encrypted-with-v3"} + + err = handler1.EncryptSensitiveFields(dataV1, []string{"secret"}, testResourceID) + require.NoError(t, err) + err = handler2.EncryptSensitiveFields(dataV2, []string{"secret"}, testResourceID) + require.NoError(t, err) + err = handler3.EncryptSensitiveFields(dataV3, []string{"secret"}, testResourceID) + require.NoError(t, err) + + // Using handler3 (current), should be able to decrypt all versions + err = handler3.DecryptSensitiveFields(ctx, dataV1, []string{"secret"}, testResourceID) + require.NoError(t, err) + require.Equal(t, "encrypted-with-v1", dataV1["secret"]) + + err = handler3.DecryptSensitiveFields(ctx, dataV2, []string{"secret"}, testResourceID) + require.NoError(t, err) + require.Equal(t, "encrypted-with-v2", dataV2["secret"]) + + err = handler3.DecryptSensitiveFields(ctx, dataV3, []string{"secret"}, testResourceID) + require.NoError(t, err) + require.Equal(t, "encrypted-with-v3", dataV3["secret"]) +} + +func TestSensitiveDataHandler_DecryptWithMissingKeyVersion(t *testing.T) { + ctx := context.Background() + + // Generate keys + key1, err := GenerateKey() + require.NoError(t, err) + key2, err := GenerateKey() + require.NoError(t, err) + + // Create provider with only key version 2 (simulates key 1 was removed after grace period) + provider, err := NewInMemoryKeyProviderWithVersions(map[int][]byte{2: key2}, 2) + require.NoError(t, err) + + // Create handler with only version 2 + handler, err := NewSensitiveDataHandlerFromProvider(ctx, provider) + require.NoError(t, err) + + // Create data that was encrypted with version 1 (using a separate handler) + tempProvider, _ := NewInMemoryKeyProviderWithVersions(map[int][]byte{1: key1}, 1) + tempHandler, _ := NewSensitiveDataHandlerFromProvider(ctx, tempProvider) + + dataV1 := map[string]any{"secret": "old-secret"} + err = tempHandler.EncryptSensitiveFields(dataV1, []string{"secret"}, testResourceID) + require.NoError(t, err) + + // Trying to decrypt with handler that doesn't have version 1 should fail + err = handler.DecryptSensitiveFields(ctx, dataV1, []string{"secret"}, testResourceID) + require.Error(t, err) + require.ErrorIs(t, err, ErrFieldDecryptionFailed) + require.Contains(t, err.Error(), "key version not found") +} + // Helper function to deep copy a map for testing func deepCopyMap(original map[string]any) map[string]any { result := make(map[string]any) From d3e8f3380555e4a03a94221b27fa2c84c6d8d9f8 Mon Sep 17 00:00:00 2001 From: lakshmimsft Date: Thu, 29 Jan 2026 12:26:36 -0800 Subject: [PATCH 3/3] updating per comment Signed-off-by: lakshmimsft --- pkg/crypto/encryption/keyprovider_test.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pkg/crypto/encryption/keyprovider_test.go b/pkg/crypto/encryption/keyprovider_test.go index 9597a91758..7f83c8381a 100644 --- a/pkg/crypto/encryption/keyprovider_test.go +++ b/pkg/crypto/encryption/keyprovider_test.go @@ -20,6 +20,7 @@ import ( "context" "encoding/base64" "encoding/json" + "strconv" "testing" "github.com/radius-project/radius/test/k8sutil" @@ -37,10 +38,7 @@ func createTestKeyStore(t *testing.T, keys map[int][]byte, currentVersion int) [ Keys: make(map[string]KeyData), } for version, key := range keys { - versionStr := string(rune('0' + version)) - if version >= 10 { - versionStr = string(rune('0'+version/10)) + string(rune('0'+version%10)) - } + versionStr := strconv.Itoa(version) keyStore.Keys[versionStr] = KeyData{ Key: base64.StdEncoding.EncodeToString(key), Version: version,