diff --git a/pkg/crypto/encryption/encryption.go b/pkg/crypto/encryption/encryption.go new file mode 100644 index 0000000000..0780435592 --- /dev/null +++ b/pkg/crypto/encryption/encryption.go @@ -0,0 +1,269 @@ +/* +Copyright 2023 The Radius Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryption + +import ( + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + + "golang.org/x/crypto/chacha20poly1305" +) + +const ( + // KeySize is the required size for ChaCha20-Poly1305 keys (256 bits). + KeySize = chacha20poly1305.KeySize + + // NonceSize is the size of the nonce for ChaCha20-Poly1305. + NonceSize = chacha20poly1305.NonceSize +) + +var ( + // ErrInvalidKeySize is returned when the encryption key is not the correct size. + ErrInvalidKeySize = errors.New("encryption key must be 32 bytes (256 bits)") + + // ErrEncryptionFailed is returned when encryption fails. + ErrEncryptionFailed = errors.New("encryption failed") + + // ErrDecryptionFailed is returned when decryption fails. + ErrDecryptionFailed = errors.New("decryption failed") + + // ErrInvalidEncryptedData is returned when the encrypted data format is invalid. + ErrInvalidEncryptedData = errors.New("invalid encrypted data format") + + // ErrEmptyPlaintext is returned when attempting to encrypt empty data. + ErrEmptyPlaintext = errors.New("plaintext cannot be empty") + + // ErrAssociatedDataMismatch is returned when the associated data provided during + // decryption does not match what was used during encryption. + ErrAssociatedDataMismatch = errors.New("associated data mismatch") +) + +// EncryptedData represents the structure for storing encrypted data. +// It contains the base64-encoded ciphertext and nonce, plus optional associated data hash. +type EncryptedData struct { + // Encrypted contains the base64-encoded ciphertext. + Encrypted string `json:"encrypted"` + // Nonce contains the base64-encoded nonce used for encryption. + Nonce string `json:"nonce"` + // AD contains a hash of the associated data used during encryption (optional). + // This is stored for verification purposes - the actual AD must be provided during decryption. + // The hash allows detection of AD mismatches without exposing the AD value. + AD string `json:"ad,omitempty"` +} + +// Encryptor provides methods for encrypting and decrypting data using ChaCha20-Poly1305. +type Encryptor struct { + aead cipher.AEAD +} + +// NewEncryptor creates a new Encryptor with the provided 256-bit key. +// Returns an error if the key is not exactly 32 bytes. +func NewEncryptor(key []byte) (*Encryptor, error) { + if len(key) != KeySize { + return nil, ErrInvalidKeySize + } + + aead, err := chacha20poly1305.New(key) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, err) + } + + return &Encryptor{aead: aead}, nil +} + +// Encrypt encrypts the plaintext using ChaCha20-Poly1305 with Associated Data (AD). +// The AD provides authentication for contextual data (like resource ID or field path) without +// encrypting it. This binds the ciphertext to its context, preventing an attacker from +// moving encrypted values between different resources or fields. +// +// The AD is authenticated but NOT encrypted - it must be provided again during decryption. +// A hash of the AD is stored in the encrypted data structure to allow early detection of mismatches. +// +// Example AD values: +// - Resource ID: "/planes/radius/local/resourceGroups/test/providers/Foo.Bar/myResources/test" +// - Field path: "credentials.password" +// - Combined: resourceID + ":" + fieldPath +// +// Pass nil for associatedData if no context binding is needed (not recommended for sensitive data). +func (e *Encryptor) Encrypt(plaintext []byte, associatedData []byte) ([]byte, error) { + if len(plaintext) == 0 { + return nil, ErrEmptyPlaintext + } + + // Generate a unique nonce for this encryption operation + nonce, err := generateNonce(e.aead.NonceSize()) + if err != nil { + return nil, fmt.Errorf("%w: failed to generate nonce: %v", ErrEncryptionFailed, err) + } + + // Encrypt the plaintext with associated data + // The AD is authenticated (included in the auth tag) but not encrypted + ciphertext := e.aead.Seal(nil, nonce, plaintext, associatedData) + + // Create the encrypted data structure + encryptedData := EncryptedData{ + Encrypted: base64.StdEncoding.EncodeToString(ciphertext), + Nonce: base64.StdEncoding.EncodeToString(nonce), + } + + // Store a hash of the AD if provided (for verification during decryption) + if len(associatedData) > 0 { + encryptedData.AD = hashAD(associatedData) + } + + // Marshal to JSON + result, err := json.Marshal(encryptedData) + if err != nil { + return nil, fmt.Errorf("%w: failed to marshal encrypted data: %v", ErrEncryptionFailed, err) + } + + return result, nil +} + +// Decrypt decrypts the data that was encrypted using the Encrypt method. +// The associatedData must match what was provided during encryption; if the AD +// was used during encryption, it must be provided here for successful decryption. +// The input should be JSON-encoded EncryptedData. +func (e *Encryptor) Decrypt(data []byte, associatedData []byte) ([]byte, error) { + if len(data) == 0 { + return nil, ErrInvalidEncryptedData + } + + // Parse the encrypted data structure + var encryptedData EncryptedData + if err := json.Unmarshal(data, &encryptedData); err != nil { + return nil, fmt.Errorf("%w: failed to parse encrypted data: %v", ErrInvalidEncryptedData, err) + } + + // Verify AD hash matches if AD was used during encryption + if encryptedData.AD != "" { + if len(associatedData) == 0 { + return nil, fmt.Errorf("%w: encrypted data requires associated data but none provided", ErrAssociatedDataMismatch) + } + if hashAD(associatedData) != encryptedData.AD { + return nil, fmt.Errorf("%w: provided associated data does not match", ErrAssociatedDataMismatch) + } + } + + // Decode the base64-encoded ciphertext + ciphertext, err := base64.StdEncoding.DecodeString(encryptedData.Encrypted) + if err != nil { + return nil, fmt.Errorf("%w: failed to decode ciphertext: %v", ErrInvalidEncryptedData, err) + } + + // Decode the base64-encoded nonce + nonce, err := base64.StdEncoding.DecodeString(encryptedData.Nonce) + if err != nil { + return nil, fmt.Errorf("%w: failed to decode nonce: %v", ErrInvalidEncryptedData, err) + } + + // Validate nonce size + if len(nonce) != e.aead.NonceSize() { + return nil, fmt.Errorf("%w: invalid nonce size", ErrInvalidEncryptedData) + } + + // Decrypt the ciphertext with the same associated data + plaintext, err := e.aead.Open(nil, nonce, ciphertext, associatedData) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrDecryptionFailed, err) + } + + return plaintext, nil +} + +// EncryptString encrypts a string with associated data and returns the JSON-encoded encrypted data as a string. +func (e *Encryptor) EncryptString(plaintext string, associatedData []byte) (string, error) { + encrypted, err := e.Encrypt([]byte(plaintext), associatedData) + if err != nil { + return "", err + } + return string(encrypted), nil +} + +// DecryptString decrypts the JSON-encoded encrypted data with associated data and returns the original string. +func (e *Encryptor) DecryptString(data string, associatedData []byte) (string, error) { + decrypted, err := e.Decrypt([]byte(data), associatedData) + if err != nil { + return "", err + } + return string(decrypted), nil +} + +// hashAD creates a truncated SHA-256 hash of the associated data for storage. +// This allows verification that the correct AD is provided during decryption +// without storing the actual AD value. +func hashAD(ad []byte) string { + hash := sha256.Sum256(ad) + // Use first 16 bytes (128 bits) - sufficient for verification, saves storage + return base64.StdEncoding.EncodeToString(hash[:16]) +} + +// generateNonce generates a cryptographically secure random nonce. +func generateNonce(size int) ([]byte, error) { + nonce := make([]byte, size) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + return nonce, nil +} + +// IsEncryptedData checks if the given data appears to be in the encrypted data format. +// It validates that the data is valid JSON with non-empty encrypted and nonce fields, +// and that both fields contain valid base64-encoded data with appropriate nonce size. +func IsEncryptedData(data []byte) bool { + var encryptedData EncryptedData + if err := json.Unmarshal(data, &encryptedData); err != nil { + return false + } + + if encryptedData.Encrypted == "" || encryptedData.Nonce == "" { + return false + } + + // Validate base64 encoding of ciphertext + if _, err := base64.StdEncoding.DecodeString(encryptedData.Encrypted); err != nil { + return false + } + + // Validate base64 encoding and size of nonce + nonce, err := base64.StdEncoding.DecodeString(encryptedData.Nonce) + if err != nil { + return false + } + + // ChaCha20-Poly1305 nonce must be 12 bytes + if len(nonce) != NonceSize { + return false + } + + return true +} + +// GenerateKey generates a new random 256-bit encryption key. +func GenerateKey() ([]byte, error) { + key := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return nil, fmt.Errorf("failed to generate encryption key: %w", err) + } + return key, nil +} diff --git a/pkg/crypto/encryption/encryption_test.go b/pkg/crypto/encryption/encryption_test.go new file mode 100644 index 0000000000..863e89c1e6 --- /dev/null +++ b/pkg/crypto/encryption/encryption_test.go @@ -0,0 +1,594 @@ +/* +Copyright 2023 The Radius Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryption + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewEncryptor(t *testing.T) { + tests := []struct { + name string + key []byte + wantErr error + }{ + { + name: "valid-32-byte-key", + key: make([]byte, 32), + wantErr: nil, + }, + { + name: "invalid-key-too-short", + key: make([]byte, 16), + wantErr: ErrInvalidKeySize, + }, + { + name: "invalid-key-too-long", + key: make([]byte, 64), + wantErr: ErrInvalidKeySize, + }, + { + name: "invalid-empty-key", + key: []byte{}, + wantErr: ErrInvalidKeySize, + }, + { + name: "invalid-nil-key", + key: nil, + wantErr: ErrInvalidKeySize, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + enc, err := NewEncryptor(tt.key) + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + require.Nil(t, enc) + } else { + require.NoError(t, err) + require.NotNil(t, enc) + } + }) + } +} + +func TestEncrypt(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + enc, err := NewEncryptor(key) + require.NoError(t, err) + + tests := []struct { + name string + plaintext []byte + wantErr error + }{ + { + name: "encrypt-simple-text", + plaintext: []byte("hello world"), + wantErr: nil, + }, + { + name: "encrypt-json-data", + plaintext: []byte(`{"password": "secret123", "token": "abc-xyz"}`), + wantErr: nil, + }, + { + name: "encrypt-binary-data", + plaintext: []byte{0x00, 0x01, 0x02, 0xff, 0xfe, 0xfd}, + wantErr: nil, + }, + { + name: "encrypt-long-text", + plaintext: make([]byte, 10000), + wantErr: nil, + }, + { + name: "encrypt-empty-plaintext", + plaintext: []byte{}, + wantErr: ErrEmptyPlaintext, + }, + { + name: "encrypt-nil-plaintext", + plaintext: nil, + wantErr: ErrEmptyPlaintext, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + encrypted, err := enc.Encrypt(tt.plaintext, nil) + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + require.Nil(t, encrypted) + } else { + require.NoError(t, err) + require.NotNil(t, encrypted) + + // Verify the encrypted data is valid JSON with expected structure + var encData EncryptedData + err = json.Unmarshal(encrypted, &encData) + require.NoError(t, err) + require.NotEmpty(t, encData.Encrypted) + require.NotEmpty(t, encData.Nonce) + } + }) + } +} + +func TestDecrypt(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + enc, err := NewEncryptor(key) + require.NoError(t, err) + + // Create valid encrypted data for testing + validPlaintext := []byte("secret data") + validEncrypted, err := enc.Encrypt(validPlaintext, nil) + require.NoError(t, err) + + tests := []struct { + name string + data []byte + wantErr error + }{ + { + name: "decrypt-valid-data", + data: validEncrypted, + wantErr: nil, + }, + { + name: "decrypt-empty-data", + data: []byte{}, + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-nil-data", + data: nil, + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-invalid-json", + data: []byte("not json"), + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-missing-encrypted-field", + data: []byte(`{"nonce": "dGVzdA=="}`), + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-missing-nonce-field", + data: []byte(`{"encrypted": "dGVzdA=="}`), + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-invalid-base64-ciphertext", + data: []byte(`{"encrypted": "not-valid-base64!!!", "nonce": "dGVzdA=="}`), + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-invalid-base64-nonce", + data: []byte(`{"encrypted": "dGVzdA==", "nonce": "not-valid-base64!!!"}`), + wantErr: ErrInvalidEncryptedData, + }, + { + name: "decrypt-wrong-nonce-size", + data: []byte(`{"encrypted": "dGVzdA==", "nonce": "dGVzdA=="}`), + wantErr: ErrInvalidEncryptedData, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decrypted, err := enc.Decrypt(tt.data, nil) + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + } else { + require.NoError(t, err) + require.Equal(t, validPlaintext, decrypted) + } + }) + } +} + +func TestEncryptDecryptRoundTrip(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + testCases := []struct { + name string + plaintext []byte + }{ + { + name: "simple-text", + plaintext: []byte("hello world"), + }, + { + name: "json-secret", + plaintext: []byte(`{"password": "super-secret-password", "apiKey": "xyz-123-abc"}`), + }, + { + name: "unicode-text", + plaintext: []byte("Hello δΈ–η•Œ! πŸ”"), + }, + { + name: "binary-data", + plaintext: []byte{0x00, 0x01, 0x02, 0x03, 0xff, 0xfe, 0xfd, 0xfc}, + }, + { + name: "large-data", + plaintext: make([]byte, 65536), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Encrypt + encrypted, err := enc.Encrypt(tc.plaintext, nil) + require.NoError(t, err) + require.NotEqual(t, tc.plaintext, encrypted, "encrypted data should differ from plaintext") + + // Decrypt + decrypted, err := enc.Decrypt(encrypted, nil) + require.NoError(t, err) + require.Equal(t, tc.plaintext, decrypted, "decrypted data should match original plaintext") + }) + } +} + +func TestEncryptStringDecryptString(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + testCases := []string{ + "simple password", + "API-KEY-12345", + `{"token": "secret"}`, + "Unicode: ζ—₯本θͺž πŸ”‘", + } + + for _, plaintext := range testCases { + t.Run(plaintext, func(t *testing.T) { + encrypted, err := enc.EncryptString(plaintext, nil) + require.NoError(t, err) + require.NotEqual(t, plaintext, encrypted) + + decrypted, err := enc.DecryptString(encrypted, nil) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted) + }) + } +} + +func TestUniqueNoncesPerEncryption(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + plaintext := []byte("same plaintext") + nonces := make(map[string]bool) + + // Encrypt the same plaintext multiple times + for i := 0; i < 100; i++ { + encrypted, err := enc.Encrypt(plaintext, nil) + require.NoError(t, err) + + var encData EncryptedData + err = json.Unmarshal(encrypted, &encData) + require.NoError(t, err) + + // Each nonce should be unique + require.False(t, nonces[encData.Nonce], "nonce should be unique for each encryption") + nonces[encData.Nonce] = true + } +} + +func TestDifferentKeysCannotDecrypt(t *testing.T) { + key1, err := GenerateKey() + require.NoError(t, err) + + key2, err := GenerateKey() + require.NoError(t, err) + + enc1, err := NewEncryptor(key1) + require.NoError(t, err) + + enc2, err := NewEncryptor(key2) + require.NoError(t, err) + + plaintext := []byte("secret message") + + // Encrypt with key1 + encrypted, err := enc1.Encrypt(plaintext, nil) + require.NoError(t, err) + + // Try to decrypt with key2 - should fail + _, err = enc2.Decrypt(encrypted, nil) + require.ErrorIs(t, err, ErrDecryptionFailed) +} + +func TestGenerateKey(t *testing.T) { + key1, err := GenerateKey() + require.NoError(t, err) + require.Len(t, key1, KeySize) + + key2, err := GenerateKey() + require.NoError(t, err) + require.Len(t, key2, KeySize) + + // Keys should be different + require.NotEqual(t, key1, key2) +} + +func TestIsEncryptedData(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + encrypted, err := enc.Encrypt([]byte("test"), nil) + require.NoError(t, err) + + // Valid 12-byte nonce encoded in base64 for manual test cases + validNonceBase64 := "AAAAAAAAAAAAAAAA" // 12 bytes of zeros in base64 + + tests := []struct { + name string + data []byte + want bool + }{ + { + name: "valid-encrypted-data", + data: encrypted, + want: true, + }, + { + name: "valid-format-manual", + data: []byte(`{"encrypted": "YWJjZGVm", "nonce": "` + validNonceBase64 + `"}`), + want: true, + }, + { + name: "invalid-base64-encrypted", + data: []byte(`{"encrypted": "not-valid-base64!!!", "nonce": "` + validNonceBase64 + `"}`), + want: false, + }, + { + name: "invalid-base64-nonce", + data: []byte(`{"encrypted": "YWJjZGVm", "nonce": "not-valid-base64!!!"}`), + want: false, + }, + { + name: "invalid-nonce-size", + data: []byte(`{"encrypted": "YWJjZGVm", "nonce": "YWJj"}`), // "abc" = 3 bytes, not 12 + want: false, + }, + { + name: "missing-encrypted-field", + data: []byte(`{"nonce": "` + validNonceBase64 + `"}`), + want: false, + }, + { + name: "missing-nonce-field", + data: []byte(`{"encrypted": "YWJjZGVm"}`), + want: false, + }, + { + name: "empty-fields", + data: []byte(`{"encrypted": "", "nonce": ""}`), + want: false, + }, + { + name: "invalid-json", + data: []byte("not json"), + want: false, + }, + { + name: "empty-data", + data: []byte{}, + want: false, + }, + { + name: "nil-data", + data: nil, + want: false, + }, + { + name: "plain-text", + data: []byte("just plain text"), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsEncryptedData(tt.data) + require.Equal(t, tt.want, got) + }) + } +} + +func TestTamperDetection(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + plaintext := []byte("sensitive data") + encrypted, err := enc.Encrypt(plaintext, nil) + require.NoError(t, err) + + // Parse the encrypted data + var encData EncryptedData + err = json.Unmarshal(encrypted, &encData) + require.NoError(t, err) + + // Tamper with the ciphertext by modifying the base64 string + // We'll change the first character + tampered := encData.Encrypted + if tampered[0] == 'A' { + tampered = "B" + tampered[1:] + } else { + tampered = "A" + tampered[1:] + } + encData.Encrypted = tampered + + tamperedJSON, err := json.Marshal(encData) + require.NoError(t, err) + + // Decryption should fail due to authentication tag mismatch + _, err = enc.Decrypt(tamperedJSON, nil) + require.Error(t, err) +} + +// Tests for Associated Data (AD) functionality + +func TestEncryptWithAssociatedData(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + plaintext := []byte("secret data") + resourceID := "/planes/radius/local/resourceGroups/test/providers/Foo.Bar/myResources/test" + fieldPath := "credentials.password" + ad := []byte(resourceID + ":" + fieldPath) + + // Encrypt with AD + encrypted, err := enc.Encrypt(plaintext, ad) + require.NoError(t, err) + + // Verify AD hash is stored + var encData EncryptedData + err = json.Unmarshal(encrypted, &encData) + require.NoError(t, err) + require.NotEmpty(t, encData.AD, "AD hash should be stored") + + // Decrypt with same AD should succeed + decrypted, err := enc.Decrypt(encrypted, ad) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted) +} + +func TestDecryptWithWrongAssociatedData(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + plaintext := []byte("secret data") + ad1 := []byte("/resource/1:password") + ad2 := []byte("/resource/2:password") + + // Encrypt with AD1 + encrypted, err := enc.Encrypt(plaintext, ad1) + require.NoError(t, err) + + // Decrypt with different AD should fail with mismatch error + _, err = enc.Decrypt(encrypted, ad2) + require.ErrorIs(t, err, ErrAssociatedDataMismatch) +} + +func TestDecryptWithMissingAssociatedData(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + plaintext := []byte("secret data") + ad := []byte("/resource/1:password") + + // Encrypt with AD + encrypted, err := enc.Encrypt(plaintext, ad) + require.NoError(t, err) + + // Decrypt without AD when AD was used should fail + _, err = enc.Decrypt(encrypted, nil) + require.ErrorIs(t, err, ErrAssociatedDataMismatch) +} + +func TestEncryptWithoutAssociatedDataDecryptWithAD(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + plaintext := []byte("secret data") + + // Encrypt without AD + encrypted, err := enc.Encrypt(plaintext, nil) + require.NoError(t, err) + + // Verify no AD hash is stored + var encData EncryptedData + err = json.Unmarshal(encrypted, &encData) + require.NoError(t, err) + require.Empty(t, encData.AD, "AD hash should not be stored when no AD provided") + + // Decrypt without AD should succeed + decrypted, err := enc.Decrypt(encrypted, nil) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted) + + // Decrypt with AD when no AD was used - AEAD will fail because the auth tag won't match + _, err = enc.Decrypt(encrypted, []byte("unexpected-ad")) + require.ErrorIs(t, err, ErrDecryptionFailed) +} + +func TestAssociatedDataPreventsContextSwitch(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + enc, err := NewEncryptor(key) + require.NoError(t, err) + + // Simulate encrypting a password for resource1 + password := []byte("super-secret-password") + resource1AD := []byte("/resource/1:password") + resource2AD := []byte("/resource/2:password") + + // Encrypt password for resource1 + encryptedForResource1, err := enc.Encrypt(password, resource1AD) + require.NoError(t, err) + + // Attacker tries to use this encrypted value for resource2 + // This should fail because the AD is different + _, err = enc.Decrypt(encryptedForResource1, resource2AD) + require.Error(t, err, "should not be able to decrypt with different resource context") +} diff --git a/pkg/crypto/encryption/keyprovider.go b/pkg/crypto/encryption/keyprovider.go new file mode 100644 index 0000000000..ed302ca7e2 --- /dev/null +++ b/pkg/crypto/encryption/keyprovider.go @@ -0,0 +1,159 @@ +/* +Copyright 2023 The Radius Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryption + +import ( + "context" + "errors" + "fmt" + + corev1 "k8s.io/api/core/v1" + k8s_error "k8s.io/apimachinery/pkg/api/errors" + controller_runtime "sigs.k8s.io/controller-runtime/pkg/client" +) + +const ( + // DefaultEncryptionKeySecretName is the default name of the Kubernetes Secret containing the encryption key. + + // This is the Secret's name, not actual credentials. + DefaultEncryptionKeySecretName = "radius-encryption-key" //nolint:gosec // This is a Secret name, not credentials + + // DefaultEncryptionKeySecretKey is the key within the Secret that contains the encryption key. + DefaultEncryptionKeySecretKey = "key" + + // RadiusNamespace is the namespace where Radius secrets are stored. + RadiusNamespace = "radius-system" +) + +var ( + // ErrKeyNotFound is returned when the encryption key is not found. + ErrKeyNotFound = errors.New("encryption key not found") + + // ErrKeyLoadFailed is returned when loading the encryption key fails. + ErrKeyLoadFailed = errors.New("failed to load encryption key") +) + +// KeyProvider defines the interface for retrieving encryption keys. +// +//go:generate mockgen -typed -destination=./mock_keyprovider.go -package=encryption -self_package github.com/radius-project/radius/pkg/crypto/encryption github.com/radius-project/radius/pkg/crypto/encryption KeyProvider +type KeyProvider interface { + // GetKey retrieves the encryption key. + // Returns ErrKeyNotFound if the key does not exist. + GetKey(ctx context.Context) ([]byte, error) +} + +// KubernetesKeyProvider implements KeyProvider by loading the encryption key from a Kubernetes Secret. +type KubernetesKeyProvider struct { + client controller_runtime.Client + secretName string + secretKey string + namespace string +} + +// KubernetesKeyProviderOptions contains options for creating a KubernetesKeyProvider. +type KubernetesKeyProviderOptions struct { + // SecretName is the name of the Kubernetes Secret containing the encryption key. + // Defaults to DefaultEncryptionKeySecretName if not specified. + SecretName string + + // SecretKey is the key within the Secret that contains the encryption key. + // Defaults to DefaultEncryptionKeySecretKey if not specified. + SecretKey string + + // Namespace is the namespace where the Secret is located. + // Defaults to RadiusNamespace if not specified. + Namespace string +} + +// NewKubernetesKeyProvider creates a new KubernetesKeyProvider with the given Kubernetes client and options. +func NewKubernetesKeyProvider(client controller_runtime.Client, opts *KubernetesKeyProviderOptions) *KubernetesKeyProvider { + secretName := DefaultEncryptionKeySecretName + secretKey := DefaultEncryptionKeySecretKey + namespace := RadiusNamespace + + if opts != nil { + if opts.SecretName != "" { + secretName = opts.SecretName + } + if opts.SecretKey != "" { + secretKey = opts.SecretKey + } + if opts.Namespace != "" { + namespace = opts.Namespace + } + } + + return &KubernetesKeyProvider{ + client: client, + secretName: secretName, + secretKey: secretKey, + namespace: namespace, + } +} + +// GetKey retrieves the encryption key from the Kubernetes Secret. +func (p *KubernetesKeyProvider) GetKey(ctx context.Context) ([]byte, error) { + secret := &corev1.Secret{} + objectKey := controller_runtime.ObjectKey{ + Name: p.secretName, + Namespace: p.namespace, + } + + if err := p.client.Get(ctx, objectKey, secret); err != nil { + if k8s_error.IsNotFound(err) { + return nil, fmt.Errorf("%w: secret %s/%s not found", ErrKeyNotFound, p.namespace, p.secretName) + } + return nil, fmt.Errorf("%w: %v", ErrKeyLoadFailed, err) + } + + key, ok := secret.Data[p.secretKey] + if !ok { + return nil, fmt.Errorf("%w: key %q not found in secret %s/%s", ErrKeyNotFound, p.secretKey, p.namespace, p.secretName) + } + + if len(key) != KeySize { + return nil, fmt.Errorf("%w: key in secret %s/%s has invalid size (expected %d bytes, got %d)", ErrKeyLoadFailed, p.namespace, p.secretName, KeySize, len(key)) + } + + return key, nil +} + +// InMemoryKeyProvider implements KeyProvider with an in-memory key. +// This is useful for testing or development environments. +type InMemoryKeyProvider struct { + key []byte +} + +// NewInMemoryKeyProvider creates a new InMemoryKeyProvider with the given key. +func NewInMemoryKeyProvider(key []byte) (*InMemoryKeyProvider, error) { + if len(key) != KeySize { + return nil, ErrInvalidKeySize + } + keyCopy := make([]byte, KeySize) + copy(keyCopy, key) + return &InMemoryKeyProvider{key: keyCopy}, nil +} + +// GetKey returns a copy of the in-memory encryption key. +// A copy is returned to prevent callers from mutating the provider's internal state. +func (p *InMemoryKeyProvider) GetKey(ctx context.Context) ([]byte, error) { + if p.key == nil { + return nil, ErrKeyNotFound + } + // Return a copy to prevent mutation of the internal key + return append([]byte(nil), p.key...), nil +} diff --git a/pkg/crypto/encryption/keyprovider_test.go b/pkg/crypto/encryption/keyprovider_test.go new file mode 100644 index 0000000000..7398c66ff4 --- /dev/null +++ b/pkg/crypto/encryption/keyprovider_test.go @@ -0,0 +1,239 @@ +/* +Copyright 2023 The Radius Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryption + +import ( + "context" + "testing" + + "github.com/radius-project/radius/test/k8sutil" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/kubectl/pkg/scheme" + controller_runtime "sigs.k8s.io/controller-runtime/pkg/client" +) + +func TestKubernetesKeyProvider_GetKey(t *testing.T) { + ctx := context.Background() + validKey := make([]byte, KeySize) + for i := range validKey { + validKey[i] = byte(i) + } + + tests := []struct { + name string + setupFunc func(k8sClient controller_runtime.Client) + opts *KubernetesKeyProviderOptions + wantErr error + wantKey []byte + wantErrMsg string + }{ + { + name: "success-with-default-options", + setupFunc: func(k8sClient controller_runtime.Client) { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultEncryptionKeySecretName, + Namespace: RadiusNamespace, + }, + Data: map[string][]byte{ + DefaultEncryptionKeySecretKey: validKey, + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + opts: nil, + wantKey: validKey, + }, + { + name: "success-with-custom-options", + setupFunc: func(k8sClient controller_runtime.Client) { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "custom-secret", + Namespace: "custom-namespace", + }, + Data: map[string][]byte{ + "custom-key": validKey, + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + opts: &KubernetesKeyProviderOptions{ + SecretName: "custom-secret", + SecretKey: "custom-key", + Namespace: "custom-namespace", + }, + wantKey: validKey, + }, + { + name: "error-secret-not-found", + setupFunc: func(k8sClient controller_runtime.Client) {}, + opts: nil, + wantErr: ErrKeyNotFound, + wantErrMsg: "not found", + }, + { + name: "error-key-not-in-secret", + setupFunc: func(k8sClient controller_runtime.Client) { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultEncryptionKeySecretName, + Namespace: RadiusNamespace, + }, + Data: map[string][]byte{ + "wrong-key": validKey, + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + opts: nil, + wantErr: ErrKeyNotFound, + wantErrMsg: "not found in secret", + }, + { + name: "error-invalid-key-size", + setupFunc: func(k8sClient controller_runtime.Client) { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultEncryptionKeySecretName, + Namespace: RadiusNamespace, + }, + Data: map[string][]byte{ + DefaultEncryptionKeySecretKey: make([]byte, 16), // Too short + }, + } + err := k8sClient.Create(ctx, secret) + require.NoError(t, err) + }, + opts: nil, + wantErr: ErrKeyLoadFailed, + wantErrMsg: "invalid size", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k8sClient := k8sutil.NewFakeKubeClient(scheme.Scheme) + tt.setupFunc(k8sClient) + + provider := NewKubernetesKeyProvider(k8sClient, tt.opts) + key, err := provider.GetKey(ctx) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + if tt.wantErrMsg != "" { + require.Contains(t, err.Error(), tt.wantErrMsg) + } + require.Nil(t, key) + } else { + require.NoError(t, err) + require.Equal(t, tt.wantKey, key) + } + }) + } +} + +func TestNewKubernetesKeyProvider_DefaultOptions(t *testing.T) { + k8sClient := k8sutil.NewFakeKubeClient(scheme.Scheme) + + // Test with nil options + provider := NewKubernetesKeyProvider(k8sClient, nil) + require.Equal(t, DefaultEncryptionKeySecretName, provider.secretName) + require.Equal(t, DefaultEncryptionKeySecretKey, provider.secretKey) + require.Equal(t, RadiusNamespace, provider.namespace) + + // Test with empty options + provider = NewKubernetesKeyProvider(k8sClient, &KubernetesKeyProviderOptions{}) + require.Equal(t, DefaultEncryptionKeySecretName, provider.secretName) + require.Equal(t, DefaultEncryptionKeySecretKey, provider.secretKey) + require.Equal(t, RadiusNamespace, provider.namespace) +} + +func TestInMemoryKeyProvider(t *testing.T) { + ctx := context.Background() + validKey := make([]byte, KeySize) + for i := range validKey { + validKey[i] = byte(i) + } + + t.Run("success", func(t *testing.T) { + provider, err := NewInMemoryKeyProvider(validKey) + require.NoError(t, err) + + key, err := provider.GetKey(ctx) + require.NoError(t, err) + require.Equal(t, validKey, key) + }) + + t.Run("error-invalid-key-size", func(t *testing.T) { + _, err := NewInMemoryKeyProvider(make([]byte, 16)) + require.ErrorIs(t, err, ErrInvalidKeySize) + }) + + t.Run("key-is-copied", func(t *testing.T) { + originalKey := make([]byte, KeySize) + for i := range originalKey { + originalKey[i] = byte(i) + } + + provider, err := NewInMemoryKeyProvider(originalKey) + require.NoError(t, err) + + // Modify the original key + originalKey[0] = 0xff + + // The provider's key should not be affected + key, err := provider.GetKey(ctx) + require.NoError(t, err) + require.NotEqual(t, originalKey[0], key[0]) + require.Equal(t, byte(0), key[0]) + }) +} + +func TestKeyProviderIntegration(t *testing.T) { + ctx := context.Background() + + // Generate a key + key, err := GenerateKey() + require.NoError(t, err) + + // Create an in-memory provider + provider, err := NewInMemoryKeyProvider(key) + require.NoError(t, err) + + // Get the key from the provider + retrievedKey, err := provider.GetKey(ctx) + require.NoError(t, err) + + // Create an encryptor with the retrieved key + enc, err := NewEncryptor(retrievedKey) + require.NoError(t, err) + + // Test encryption/decryption + plaintext := []byte("secret data from key provider") + encrypted, err := enc.Encrypt(plaintext, nil) + require.NoError(t, err) + + decrypted, err := enc.Decrypt(encrypted, nil) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted) +} diff --git a/pkg/crypto/encryption/sensitive.go b/pkg/crypto/encryption/sensitive.go new file mode 100644 index 0000000000..ade7c6f8ff --- /dev/null +++ b/pkg/crypto/encryption/sensitive.go @@ -0,0 +1,611 @@ +/* +Copyright 2023 The Radius Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryption + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" +) + +var ( + // ErrFieldNotFound is returned when a field path cannot be found in the data. + ErrFieldNotFound = errors.New("field not found") + + // ErrInvalidFieldPath is returned when a field path is invalid. + ErrInvalidFieldPath = errors.New("invalid field path") + + // ErrFieldEncryptionFailed is returned when encryption of a field fails. + ErrFieldEncryptionFailed = errors.New("field encryption failed") + + // ErrFieldDecryptionFailed is returned when decryption of a field fails. + ErrFieldDecryptionFailed = errors.New("field decryption failed") +) + +// SensitiveDataHandler provides methods for encrypting and decrypting sensitive fields +// in data structures based on field paths marked with x-radius-sensitive annotation. +type SensitiveDataHandler struct { + encryptor *Encryptor +} + +// NewSensitiveDataHandler creates a new SensitiveDataHandler with the provided encryptor. +func NewSensitiveDataHandler(encryptor *Encryptor) *SensitiveDataHandler { + return &SensitiveDataHandler{encryptor: encryptor} +} + +// NewSensitiveDataHandlerFromKey creates a new SensitiveDataHandler from a raw encryption key. +func NewSensitiveDataHandlerFromKey(key []byte) (*SensitiveDataHandler, error) { + encryptor, err := NewEncryptor(key) + if err != nil { + return nil, err + } + return &SensitiveDataHandler{encryptor: encryptor}, nil +} + +// NewSensitiveDataHandlerFromProvider creates a new SensitiveDataHandler using a key provider. +func NewSensitiveDataHandlerFromProvider(ctx context.Context, provider KeyProvider) (*SensitiveDataHandler, error) { + key, err := provider.GetKey(ctx) + if err != nil { + return nil, err + } + return NewSensitiveDataHandlerFromKey(key) +} + +// EncryptSensitiveFields encrypts all sensitive fields in the data based on the provided field paths. +// The data is modified in place. Field paths support dot notation and [*] for arrays/maps. +// Examples: "credentials.password", "secrets[*].value", "config[*]" +// +// The resourceID is used as Associated Data (AD) for context binding. This prevents encrypted +// values from being moved between different resources. The resourceID should be the full +// resource ID (e.g., "/planes/radius/local/resourceGroups/test/providers/Foo.Bar/myResources/test"). +// +// Returns an error if any field encryption fails. In case of error, partial encryption may have occurred. +func (h *SensitiveDataHandler) EncryptSensitiveFields(data map[string]any, sensitiveFieldPaths []string, resourceID string) error { + for _, path := range sensitiveFieldPaths { + // Build associated data from resource ID and field path + ad := buildAssociatedData(resourceID, path) + if err := h.encryptFieldAtPath(data, path, ad); err != nil { + return fmt.Errorf("%w: path %q: %v", ErrFieldEncryptionFailed, path, err) + } + } + return nil +} + +// DecryptSensitiveFields decrypts all sensitive fields in the data based on the provided field paths. +// The data is modified in place. Field paths support dot notation and [*] for arrays/maps. +// +// The resourceID must match what was provided during encryption for successful decryption. +// +// Note: This method does not use schema information for type restoration. Numbers in decrypted +// objects will be returned as float64 (standard Go JSON behavior). For accurate type restoration, +// use DecryptSensitiveFieldsWithSchema instead. +// +// Returns an error if any field decryption fails. In case of error, partial decryption may have occurred. +func (h *SensitiveDataHandler) DecryptSensitiveFields(data map[string]any, sensitiveFieldPaths []string, resourceID string) error { + for _, path := range sensitiveFieldPaths { + ad := buildAssociatedData(resourceID, path) + if err := h.decryptFieldAtPath(data, path, nil, ad); err != nil { + // Skip fields that are not found - they may not exist in this resource instance + if errors.Is(err, ErrFieldNotFound) { + continue + } + return fmt.Errorf("%w: path %q: %v", ErrFieldDecryptionFailed, path, err) + } + } + return nil +} + +// DecryptSensitiveFieldsWithSchema decrypts all sensitive fields in the data using schema information +// for accurate type restoration. The schema should be the OpenAPI schema for the resource type. +// The data is modified in place. Field paths support dot notation and [*] for arrays/maps. +// +// The resourceID must match what was provided during encryption for successful decryption. +// The schema is used to restore the correct types for fields within encrypted objects (e.g., integers +// that would otherwise be decoded as float64). +// +// Returns an error if any field decryption fails. In case of error, partial decryption may have occurred. +func (h *SensitiveDataHandler) DecryptSensitiveFieldsWithSchema(data map[string]any, sensitiveFieldPaths []string, resourceID string, schema map[string]any) error { + for _, path := range sensitiveFieldPaths { + // Get the schema for this specific field path + fieldSchema := getSchemaForPath(schema, path) + ad := buildAssociatedData(resourceID, path) + if err := h.decryptFieldAtPath(data, path, fieldSchema, ad); err != nil { + // Skip fields that are not found - they may not exist in this resource instance + if errors.Is(err, ErrFieldNotFound) { + continue + } + return fmt.Errorf("%w: path %q: %v", ErrFieldDecryptionFailed, path, err) + } + } + return nil +} + +// encryptFieldAtPath encrypts the value at the given field path in the data. +func (h *SensitiveDataHandler) encryptFieldAtPath(data map[string]any, path string, associatedData []byte) error { + processor := func(value any) (any, error) { + return h.encryptValue(value, associatedData) + } + return h.processFieldAtPath(data, path, processor) +} + +// decryptFieldAtPath decrypts the value at the given field path in the data. +// If fieldSchema is provided, it will be used for type restoration. +func (h *SensitiveDataHandler) decryptFieldAtPath(data map[string]any, path string, fieldSchema map[string]any, associatedData []byte) error { + processor := func(value any) (any, error) { + return h.decryptValue(value, fieldSchema, associatedData) + } + return h.processFieldAtPath(data, path, processor) +} + +// processFieldAtPath traverses the data structure and applies the processor function to the field at the path. +func (h *SensitiveDataHandler) processFieldAtPath(data map[string]any, path string, processor func(any) (any, error)) error { + if path == "" { + return ErrInvalidFieldPath + } + + segments := parseFieldPath(path) + if len(segments) == 0 { + return ErrInvalidFieldPath + } + + return h.processPathSegments(data, segments, processor) +} + +// processPathSegments recursively processes path segments to find and transform the target field. +func (h *SensitiveDataHandler) processPathSegments(current any, segments []pathSegment, processor func(any) (any, error)) error { + if len(segments) == 0 { + return nil + } + + segment := segments[0] + remainingSegments := segments[1:] + + switch segment.segmentType { + case segmentTypeField: + return h.processFieldSegment(current, segment.value, remainingSegments, processor) + case segmentTypeWildcard: + return h.processWildcardSegment(current, remainingSegments, processor) + case segmentTypeIndex: + return h.processIndexSegment(current, segment.value, remainingSegments, processor) + default: + return ErrInvalidFieldPath + } +} + +// processFieldSegment handles a regular field name segment in the path. +func (h *SensitiveDataHandler) processFieldSegment(current any, fieldName string, remainingSegments []pathSegment, processor func(any) (any, error)) error { + dataMap, ok := current.(map[string]any) + if !ok { + return ErrFieldNotFound + } + + value, exists := dataMap[fieldName] + if !exists { + return ErrFieldNotFound + } + + // If this is the last segment, process the value + if len(remainingSegments) == 0 { + processed, err := processor(value) + if err != nil { + return err + } + dataMap[fieldName] = processed + return nil + } + + // Continue traversing + return h.processPathSegments(value, remainingSegments, processor) +} + +// processWildcardSegment handles [*] segments for arrays and maps. +func (h *SensitiveDataHandler) processWildcardSegment(current any, remainingSegments []pathSegment, processor func(any) (any, error)) error { + // Handle array + if arr, ok := current.([]any); ok { + for i := range arr { + if len(remainingSegments) == 0 { + // Process each array element + processed, err := processor(arr[i]) + if err != nil { + return fmt.Errorf("index %d: %w", i, err) + } + arr[i] = processed + } else { + // Continue traversing into each element + if err := h.processPathSegments(arr[i], remainingSegments, processor); err != nil { + // Skip elements that don't have the field + if !errors.Is(err, ErrFieldNotFound) { + return fmt.Errorf("index %d: %w", i, err) + } + } + } + } + return nil + } + + // Handle map + if dataMap, ok := current.(map[string]any); ok { + for key := range dataMap { + if len(remainingSegments) == 0 { + // Process each map value + processed, err := processor(dataMap[key]) + if err != nil { + return fmt.Errorf("key %q: %w", key, err) + } + dataMap[key] = processed + } else { + // Continue traversing into each value + if err := h.processPathSegments(dataMap[key], remainingSegments, processor); err != nil { + // Skip elements that don't have the field + if !errors.Is(err, ErrFieldNotFound) { + return fmt.Errorf("key %q: %w", key, err) + } + } + } + } + return nil + } + + return ErrFieldNotFound +} + +// processIndexSegment handles specific index segments like [0], [1], etc. +func (h *SensitiveDataHandler) processIndexSegment(current any, indexStr string, remainingSegments []pathSegment, processor func(any) (any, error)) error { + arr, ok := current.([]any) + if !ok { + return ErrFieldNotFound + } + + index, err := strconv.Atoi(indexStr) + if err != nil { + return fmt.Errorf("%w: invalid index %q", ErrInvalidFieldPath, indexStr) + } + + if index < 0 || index >= len(arr) { + return ErrFieldNotFound + } + + if len(remainingSegments) == 0 { + processed, err := processor(arr[index]) + if err != nil { + return err + } + arr[index] = processed + return nil + } + + return h.processPathSegments(arr[index], remainingSegments, processor) +} + +// encryptValue encrypts a single value, handling different types appropriately. +func (h *SensitiveDataHandler) encryptValue(value any, associatedData []byte) (any, error) { + if value == nil { + return nil, nil + } + + var dataToEncrypt []byte + var err error + + switch v := value.(type) { + case string: + if v == "" { + return v, nil + } + dataToEncrypt = []byte(v) + case map[string]any, []any: + dataToEncrypt, err = json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("failed to marshal value: %w", err) + } + default: + // For other types, convert to JSON + dataToEncrypt, err = json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("failed to marshal value: %w", err) + } + } + + encrypted, err := h.encryptor.Encrypt(dataToEncrypt, associatedData) + if err != nil { + return nil, err + } + + // Return as a structured object + var result map[string]any + if err := json.Unmarshal(encrypted, &result); err != nil { + return nil, err + } + + return result, nil +} + +// decryptValue decrypts a single value, restoring the original type using schema information if provided. +func (h *SensitiveDataHandler) decryptValue(value any, fieldSchema map[string]any, associatedData []byte) (any, error) { + if value == nil { + return nil, nil + } + + // Check if this looks like encrypted data + encMap, ok := value.(map[string]any) + if !ok { + // Not encrypted data, return as-is + return value, nil + } + + _, hasEncrypted := encMap["encrypted"].(string) + _, hasNonce := encMap["nonce"].(string) + + if !hasEncrypted || !hasNonce { + // Not our encrypted format, return as-is + return value, nil + } + + // Convert back to JSON for decryption + encryptedJSON, err := json.Marshal(encMap) + if err != nil { + return nil, err + } + + decrypted, err := h.encryptor.Decrypt(encryptedJSON, associatedData) + if err != nil { + return nil, err + } + + // Determine the expected type from schema + expectedType := getSchemaType(fieldSchema) + + // For string type, return as string directly + if expectedType == "string" { + return string(decrypted), nil + } + + // For object/array types, unmarshal and apply type coercion based on schema + var result any + if err := json.Unmarshal(decrypted, &result); err != nil { + // If not valid JSON, return as string + return string(decrypted), nil + } + + // Apply schema-based type coercion if schema is available + if fieldSchema != nil && expectedType == "object" { + if resultMap, ok := result.(map[string]any); ok { + coerceTypesFromSchema(resultMap, fieldSchema) + } + } + + return result, nil +} + +// buildAssociatedData constructs the associated data for AEAD encryption from the resource ID and field path. +// This binds the ciphertext to its context, preventing encrypted values from being moved between +// different resources or fields. +func buildAssociatedData(resourceID, fieldPath string) []byte { + if resourceID == "" && fieldPath == "" { + return nil + } + // Combine resource ID and field path with a separator + // Format: "resourceID:fieldPath" + return []byte(resourceID + ":" + fieldPath) +} + +// getSchemaType returns the type from a schema, or empty string if not specified. +func getSchemaType(schema map[string]any) string { + if schema == nil { + return "" + } + if t, ok := schema["type"].(string); ok { + return t + } + return "" +} + +// getSchemaForPath retrieves the schema definition for a specific field path. +// It navigates through the schema following the path segments (supporting nested properties, +// array items via [*], and additionalProperties for maps). +func getSchemaForPath(schema map[string]any, path string) map[string]any { + if schema == nil || path == "" { + return nil + } + + segments := parseFieldPath(path) + current := schema + + for _, segment := range segments { + switch segment.segmentType { + case segmentTypeField: + // Navigate to properties -> fieldName + properties, ok := current["properties"].(map[string]any) + if !ok { + return nil + } + fieldSchema, ok := properties[segment.value].(map[string]any) + if !ok { + return nil + } + current = fieldSchema + + case segmentTypeWildcard: + // Could be array items or additionalProperties + if items, ok := current["items"].(map[string]any); ok { + current = items + } else if addProps, ok := current["additionalProperties"].(map[string]any); ok { + current = addProps + } else { + return nil + } + + case segmentTypeIndex: + // Specific array index - use items schema + if items, ok := current["items"].(map[string]any); ok { + current = items + } else { + return nil + } + } + } + + return current +} + +// coerceTypesFromSchema recursively walks through a data map and coerces types +// to match the schema definition. This primarily handles converting float64 to int64 +// for integer fields. +func coerceTypesFromSchema(data map[string]any, schema map[string]any) { + if schema == nil { + return + } + + properties, ok := schema["properties"].(map[string]any) + if !ok { + return + } + + for fieldName, fieldValue := range data { + fieldSchema, ok := properties[fieldName].(map[string]any) + if !ok { + continue + } + + fieldType := getSchemaType(fieldSchema) + + switch fieldType { + case "integer": + // Coerce float64 to int64 + if f, ok := fieldValue.(float64); ok { + data[fieldName] = int64(f) + } + + case "object": + // Recursively coerce nested objects + if nestedMap, ok := fieldValue.(map[string]any); ok { + coerceTypesFromSchema(nestedMap, fieldSchema) + } + + case "array": + // Coerce array items if they have a schema + if arr, ok := fieldValue.([]any); ok { + itemSchema, _ := fieldSchema["items"].(map[string]any) + if itemSchema != nil { + itemType := getSchemaType(itemSchema) + for i, item := range arr { + if itemType == "integer" { + if f, ok := item.(float64); ok { + arr[i] = int64(f) + } + } else if itemType == "object" { + if itemMap, ok := item.(map[string]any); ok { + coerceTypesFromSchema(itemMap, itemSchema) + } + } + } + } + } + } + + // Handle additionalProperties for map types + if addPropsSchema, ok := fieldSchema["additionalProperties"].(map[string]any); ok { + if nestedMap, ok := fieldValue.(map[string]any); ok { + addPropsType := getSchemaType(addPropsSchema) + for key, val := range nestedMap { + if addPropsType == "integer" { + if f, ok := val.(float64); ok { + nestedMap[key] = int64(f) + } + } else if addPropsType == "object" { + if valMap, ok := val.(map[string]any); ok { + coerceTypesFromSchema(valMap, addPropsSchema) + } + } + } + } + } + } +} + +// pathSegment represents a segment of a field path. +type pathSegment struct { + segmentType segmentType + value string +} + +type segmentType int + +const ( + segmentTypeField segmentType = iota + segmentTypeWildcard + segmentTypeIndex +) + +// parseFieldPath parses a field path into segments. +// Examples: +// - "credentials.password" -> [field:credentials, field:password] +// - "secrets[*].value" -> [field:secrets, wildcard, field:value] +// - "config[*]" -> [field:config, wildcard] +// - "items[0].name" -> [field:items, index:0, field:name] +func parseFieldPath(path string) []pathSegment { + var segments []pathSegment + var current strings.Builder + + i := 0 + for i < len(path) { + ch := path[i] + + switch ch { + case '.': + if current.Len() > 0 { + segments = append(segments, pathSegment{segmentType: segmentTypeField, value: current.String()}) + current.Reset() + } + i++ + + case '[': + if current.Len() > 0 { + segments = append(segments, pathSegment{segmentType: segmentTypeField, value: current.String()}) + current.Reset() + } + + // Find the closing bracket + end := strings.Index(path[i:], "]") + if end == -1 { + // Invalid path - unterminated bracket, return nil to signal error + return nil + } + + bracketContent := path[i+1 : i+end] + if bracketContent == "*" { + segments = append(segments, pathSegment{segmentType: segmentTypeWildcard}) + } else { + segments = append(segments, pathSegment{segmentType: segmentTypeIndex, value: bracketContent}) + } + i += end + 1 + + default: + current.WriteByte(ch) + i++ + } + } + + // Don't forget the last segment + if current.Len() > 0 { + segments = append(segments, pathSegment{segmentType: segmentTypeField, value: current.String()}) + } + + return segments +} diff --git a/pkg/crypto/encryption/sensitive_test.go b/pkg/crypto/encryption/sensitive_test.go new file mode 100644 index 0000000000..75efbbb3fb --- /dev/null +++ b/pkg/crypto/encryption/sensitive_test.go @@ -0,0 +1,844 @@ +/* +Copyright 2023 The Radius Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package encryption + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +const ( + // testResourceID is a sample resource ID used for testing associated data + testResourceID = "/planes/radius/local/resourceGroups/test/providers/Test.Resource/testResources/myResource" +) + +func TestParseFieldPath(t *testing.T) { + tests := []struct { + name string + path string + expected []pathSegment + }{ + { + name: "simple-field", + path: "password", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "password"}, + }, + }, + { + name: "nested-field", + path: "credentials.password", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "credentials"}, + {segmentType: segmentTypeField, value: "password"}, + }, + }, + { + name: "deeply-nested-field", + path: "config.database.connection.password", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "config"}, + {segmentType: segmentTypeField, value: "database"}, + {segmentType: segmentTypeField, value: "connection"}, + {segmentType: segmentTypeField, value: "password"}, + }, + }, + { + name: "array-wildcard", + path: "secrets[*].value", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "secrets"}, + {segmentType: segmentTypeWildcard}, + {segmentType: segmentTypeField, value: "value"}, + }, + }, + { + name: "map-wildcard", + path: "config[*]", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "config"}, + {segmentType: segmentTypeWildcard}, + }, + }, + { + name: "specific-index", + path: "items[0].name", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "items"}, + {segmentType: segmentTypeIndex, value: "0"}, + {segmentType: segmentTypeField, value: "name"}, + }, + }, + { + name: "multiple-wildcards", + path: "data[*].secrets[*].value", + expected: []pathSegment{ + {segmentType: segmentTypeField, value: "data"}, + {segmentType: segmentTypeWildcard}, + {segmentType: segmentTypeField, value: "secrets"}, + {segmentType: segmentTypeWildcard}, + {segmentType: segmentTypeField, value: "value"}, + }, + }, + { + name: "unterminated-bracket", + path: "secrets[*", + expected: nil, + }, + { + name: "unterminated-bracket-with-index", + path: "items[0", + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseFieldPath(tt.path) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestSensitiveDataHandler_EncryptDecrypt_SimpleField(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "username": "admin", + "password": "super-secret-password", + } + + // Encrypt + err = handler.EncryptSensitiveFields(data, []string{"password"}, testResourceID) + require.NoError(t, err) + + // Verify password is encrypted + password := data["password"] + encMap, ok := password.(map[string]any) + require.True(t, ok, "password should be encrypted map") + require.NotEmpty(t, encMap["encrypted"]) + require.NotEmpty(t, encMap["nonce"]) + + // Username should be unchanged + require.Equal(t, "admin", data["username"]) + + // Decrypt + err = handler.DecryptSensitiveFields(data, []string{"password"}, testResourceID) + require.NoError(t, err) + + // Verify password is decrypted + require.Equal(t, "super-secret-password", data["password"]) + require.Equal(t, "admin", data["username"]) +} + +func TestSensitiveDataHandler_EncryptDecrypt_NestedField(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "name": "test-resource", + "credentials": map[string]any{ + "username": "admin", + "password": "nested-secret", + "apiKey": "key-12345", + }, + } + + // Encrypt password and apiKey + err = handler.EncryptSensitiveFields(data, []string{"credentials.password", "credentials.apiKey"}, testResourceID) + require.NoError(t, err) + + // Verify encrypted fields + creds := data["credentials"].(map[string]any) + require.Equal(t, "admin", creds["username"]) + + _, passwordIsEncrypted := creds["password"].(map[string]any) + require.True(t, passwordIsEncrypted) + + _, apiKeyIsEncrypted := creds["apiKey"].(map[string]any) + require.True(t, apiKeyIsEncrypted) + + // Decrypt + err = handler.DecryptSensitiveFields(data, []string{"credentials.password", "credentials.apiKey"}, testResourceID) + require.NoError(t, err) + + creds = data["credentials"].(map[string]any) + require.Equal(t, "nested-secret", creds["password"]) + require.Equal(t, "key-12345", creds["apiKey"]) +} + +func TestSensitiveDataHandler_EncryptDecrypt_ArrayWildcard(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "name": "test-resource", + "secrets": []any{ + map[string]any{"name": "secret1", "value": "value1"}, + map[string]any{"name": "secret2", "value": "value2"}, + map[string]any{"name": "secret3", "value": "value3"}, + }, + } + + // Encrypt all secret values + err = handler.EncryptSensitiveFields(data, []string{"secrets[*].value"}, testResourceID) + require.NoError(t, err) + + // Verify all values are encrypted + secrets := data["secrets"].([]any) + for i, s := range secrets { + secret := s.(map[string]any) + require.Equal(t, "secret"+string(rune('1'+i)), secret["name"]) + + _, valueIsEncrypted := secret["value"].(map[string]any) + require.True(t, valueIsEncrypted, "secret[%d].value should be encrypted", i) + } + + // Decrypt + err = handler.DecryptSensitiveFields(data, []string{"secrets[*].value"}, testResourceID) + require.NoError(t, err) + + secrets = data["secrets"].([]any) + for i, s := range secrets { + secret := s.(map[string]any) + require.Equal(t, "value"+string(rune('1'+i)), secret["value"]) + } +} + +func TestSensitiveDataHandler_EncryptDecrypt_MapWildcard(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "name": "test-resource", + "config": map[string]any{ + "database_password": "db-secret", + "api_key": "api-secret", + "token": "token-secret", + }, + } + + // Encrypt all config values + err = handler.EncryptSensitiveFields(data, []string{"config[*]"}, testResourceID) + require.NoError(t, err) + + // Verify all config values are encrypted + config := data["config"].(map[string]any) + for key, value := range config { + _, isEncrypted := value.(map[string]any) + require.True(t, isEncrypted, "config[%s] should be encrypted", key) + } + + // Decrypt + err = handler.DecryptSensitiveFields(data, []string{"config[*]"}, testResourceID) + require.NoError(t, err) + + config = data["config"].(map[string]any) + require.Equal(t, "db-secret", config["database_password"]) + require.Equal(t, "api-secret", config["api_key"]) + require.Equal(t, "token-secret", config["token"]) +} + +func TestSensitiveDataHandler_EncryptDecrypt_ObjectValue(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "name": "test-resource", + "sensitiveConfig": map[string]any{ + "password": "secret-pass", + "privateKey": "-----BEGIN PRIVATE KEY-----\nMIIE...", + "nested": map[string]any{ + "deep": "value", + }, + }, + } + + // Encrypt entire object + err = handler.EncryptSensitiveFields(data, []string{"sensitiveConfig"}, testResourceID) + require.NoError(t, err) + + // Verify the entire object is encrypted + _, isEncrypted := data["sensitiveConfig"].(map[string]any) + require.True(t, isEncrypted) + + encData := data["sensitiveConfig"].(map[string]any) + require.NotEmpty(t, encData["encrypted"]) + require.NotEmpty(t, encData["nonce"]) + + // Decrypt + err = handler.DecryptSensitiveFields(data, []string{"sensitiveConfig"}, testResourceID) + require.NoError(t, err) + + // Verify decrypted object + config := data["sensitiveConfig"].(map[string]any) + require.Equal(t, "secret-pass", config["password"]) + require.Equal(t, "-----BEGIN PRIVATE KEY-----\nMIIE...", config["privateKey"]) + + nested := config["nested"].(map[string]any) + require.Equal(t, "value", nested["deep"]) +} + +func TestSensitiveDataHandler_FieldNotFound(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "username": "admin", + } + + // Encrypting non-existent field should return error + err = handler.EncryptSensitiveFields(data, []string{"password"}, testResourceID) + require.Error(t, err) + require.ErrorIs(t, err, ErrFieldEncryptionFailed) + + // Decrypting non-existent field should be skipped (no error) + err = handler.DecryptSensitiveFields(data, []string{"password"}, testResourceID) + require.NoError(t, err) +} + +func TestSensitiveDataHandler_EmptyValue(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "password": "", + } + + // Empty string should remain empty + err = handler.EncryptSensitiveFields(data, []string{"password"}, testResourceID) + require.NoError(t, err) + require.Equal(t, "", data["password"]) +} + +func TestSensitiveDataHandler_NilValue(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "password": nil, + } + + // Nil should remain nil + err = handler.EncryptSensitiveFields(data, []string{"password"}, testResourceID) + require.NoError(t, err) + require.Nil(t, data["password"]) +} + +func TestSensitiveDataHandler_InvalidFieldPath(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "password": "secret", + } + + // Empty path should return error + err = handler.EncryptSensitiveFields(data, []string{""}, testResourceID) + require.Error(t, err) +} + +func TestSensitiveDataHandler_RoundTrip_ComplexStructure(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + original := map[string]any{ + "name": "my-application", + "database": map[string]any{ + "host": "localhost", + "port": 5432, + "password": "db-password-123", + }, + "secrets": []any{ + map[string]any{"key": "API_KEY", "value": "api-secret-value"}, + map[string]any{"key": "AUTH_TOKEN", "value": "auth-token-value"}, + }, + "config": map[string]any{ + "public_setting": "visible", + "private_key": "secret-key-data", + }, + } + + sensitivePaths := []string{ + "database.password", + "secrets[*].value", + "config.private_key", + } + + // Make a copy to encrypt + data := deepCopyMap(original) + + // Encrypt + err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + // Verify sensitive fields are encrypted + dbPassword := data["database"].(map[string]any)["password"] + _, isEncrypted := dbPassword.(map[string]any) + require.True(t, isEncrypted, "database.password should be encrypted") + + secrets := data["secrets"].([]any) + for i, s := range secrets { + secret := s.(map[string]any) + _, valueEncrypted := secret["value"].(map[string]any) + require.True(t, valueEncrypted, "secrets[%d].value should be encrypted", i) + } + + configPrivateKey := data["config"].(map[string]any)["private_key"] + _, isEncrypted = configPrivateKey.(map[string]any) + require.True(t, isEncrypted, "config.private_key should be encrypted") + + // Verify non-sensitive fields are unchanged + require.Equal(t, "my-application", data["name"]) + require.Equal(t, "localhost", data["database"].(map[string]any)["host"]) + require.Equal(t, 5432, data["database"].(map[string]any)["port"]) + require.Equal(t, "visible", data["config"].(map[string]any)["public_setting"]) + + // Decrypt + err = handler.DecryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + // Verify values are restored + require.Equal(t, "db-password-123", data["database"].(map[string]any)["password"]) + require.Equal(t, "api-secret-value", data["secrets"].([]any)[0].(map[string]any)["value"]) + require.Equal(t, "auth-token-value", data["secrets"].([]any)[1].(map[string]any)["value"]) + require.Equal(t, "secret-key-data", data["config"].(map[string]any)["private_key"]) +} + +func TestSensitiveDataHandler_FromProvider(t *testing.T) { + ctx := context.Background() + + key, err := GenerateKey() + require.NoError(t, err) + + provider, err := NewInMemoryKeyProvider(key) + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromProvider(ctx, provider) + require.NoError(t, err) + require.NotNil(t, handler) + + // Test basic functionality + data := map[string]any{ + "secret": "my-secret", + } + + err = handler.EncryptSensitiveFields(data, []string{"secret"}, testResourceID) + require.NoError(t, err) + + _, isEncrypted := data["secret"].(map[string]any) + require.True(t, isEncrypted) + + err = handler.DecryptSensitiveFields(data, []string{"secret"}, testResourceID) + require.NoError(t, err) + require.Equal(t, "my-secret", data["secret"]) +} + +func TestSensitiveDataHandler_SpecificIndex(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + data := map[string]any{ + "items": []any{ + map[string]any{"value": "public"}, + map[string]any{"value": "secret"}, + map[string]any{"value": "public2"}, + }, + } + + // Only encrypt second item + err = handler.EncryptSensitiveFields(data, []string{"items[1].value"}, testResourceID) + require.NoError(t, err) + + items := data["items"].([]any) + + // First and third should remain strings + require.Equal(t, "public", items[0].(map[string]any)["value"]) + require.Equal(t, "public2", items[2].(map[string]any)["value"]) + + // Second should be encrypted + _, isEncrypted := items[1].(map[string]any)["value"].(map[string]any) + require.True(t, isEncrypted) + + // Decrypt + err = handler.DecryptSensitiveFields(data, []string{"items[1].value"}, testResourceID) + require.NoError(t, err) + + require.Equal(t, "secret", items[1].(map[string]any)["value"]) +} + +func TestSensitiveDataHandler_DecryptWithSchema_IntegerRestoration(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + // Data with integer fields inside a sensitive object + data := map[string]any{ + "name": "test-resource", + "sensitiveConfig": map[string]any{ + "port": 5432, + "timeout": 30, + "password": "secret", + "enabled": true, + }, + } + + // Schema that describes the sensitive field + schema := map[string]any{ + "properties": map[string]any{ + "sensitiveConfig": map[string]any{ + "type": "object", + "x-radius-sensitive": true, + "properties": map[string]any{ + "port": map[string]any{ + "type": "integer", + }, + "timeout": map[string]any{ + "type": "integer", + }, + "password": map[string]any{ + "type": "string", + }, + "enabled": map[string]any{ + "type": "boolean", + }, + }, + }, + }, + } + + sensitivePaths := []string{"sensitiveConfig"} + + // Encrypt + err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + // Verify it's encrypted + _, isEncrypted := data["sensitiveConfig"].(map[string]any)["encrypted"] + require.True(t, isEncrypted) + + // Decrypt WITH schema + err = handler.DecryptSensitiveFieldsWithSchema(data, sensitivePaths, testResourceID, schema) + require.NoError(t, err) + + // Verify types are correctly restored + config := data["sensitiveConfig"].(map[string]any) + + // Integers should be int64, not float64 + port, ok := config["port"].(int64) + require.True(t, ok, "port should be int64, got %T", config["port"]) + require.Equal(t, int64(5432), port) + + timeout, ok := config["timeout"].(int64) + require.True(t, ok, "timeout should be int64, got %T", config["timeout"]) + require.Equal(t, int64(30), timeout) + + // String should remain string + password, ok := config["password"].(string) + require.True(t, ok, "password should be string") + require.Equal(t, "secret", password) + + // Boolean should remain boolean + enabled, ok := config["enabled"].(bool) + require.True(t, ok, "enabled should be bool") + require.True(t, enabled) +} + +func TestSensitiveDataHandler_DecryptWithSchema_NestedObjects(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + // Data with nested objects containing integers + data := map[string]any{ + "credentials": map[string]any{ + "database": map[string]any{ + "host": "localhost", + "port": 5432, + "maxConns": 100, + }, + "apiKey": "secret-key", + }, + } + + schema := map[string]any{ + "properties": map[string]any{ + "credentials": map[string]any{ + "type": "object", + "x-radius-sensitive": true, + "properties": map[string]any{ + "database": map[string]any{ + "type": "object", + "properties": map[string]any{ + "host": map[string]any{ + "type": "string", + }, + "port": map[string]any{ + "type": "integer", + }, + "maxConns": map[string]any{ + "type": "integer", + }, + }, + }, + "apiKey": map[string]any{ + "type": "string", + }, + }, + }, + }, + } + + sensitivePaths := []string{"credentials"} + + // Encrypt and decrypt with schema + err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + err = handler.DecryptSensitiveFieldsWithSchema(data, sensitivePaths, testResourceID, schema) + require.NoError(t, err) + + // Verify nested integers are restored + creds := data["credentials"].(map[string]any) + db := creds["database"].(map[string]any) + + require.Equal(t, "localhost", db["host"]) + require.Equal(t, int64(5432), db["port"]) + require.Equal(t, int64(100), db["maxConns"]) + require.Equal(t, "secret-key", creds["apiKey"]) +} + +func TestSensitiveDataHandler_DecryptWithSchema_ArrayWithIntegers(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + // Sensitive object containing an array with integers + data := map[string]any{ + "config": map[string]any{ + "ports": []any{80, 443, 8080}, + "name": "my-config", + }, + } + + schema := map[string]any{ + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "x-radius-sensitive": true, + "properties": map[string]any{ + "ports": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "integer", + }, + }, + "name": map[string]any{ + "type": "string", + }, + }, + }, + }, + } + + sensitivePaths := []string{"config"} + + err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + err = handler.DecryptSensitiveFieldsWithSchema(data, sensitivePaths, testResourceID, schema) + require.NoError(t, err) + + config := data["config"].(map[string]any) + ports := config["ports"].([]any) + + require.Len(t, ports, 3) + require.Equal(t, int64(80), ports[0]) + require.Equal(t, int64(443), ports[1]) + require.Equal(t, int64(8080), ports[2]) + require.Equal(t, "my-config", config["name"]) +} + +func TestSensitiveDataHandler_DecryptWithoutSchema_NoTypeCoercion(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + handler, err := NewSensitiveDataHandlerFromKey(key) + require.NoError(t, err) + + // Data with integer fields + data := map[string]any{ + "sensitiveConfig": map[string]any{ + "port": 5432, + }, + } + + sensitivePaths := []string{"sensitiveConfig"} + + err = handler.EncryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + // Decrypt WITHOUT schema + err = handler.DecryptSensitiveFields(data, sensitivePaths, testResourceID) + require.NoError(t, err) + + config := data["sensitiveConfig"].(map[string]any) + + // Without schema, integer should be float64 (standard JSON behavior) + _, isFloat := config["port"].(float64) + require.True(t, isFloat, "without schema, port should be float64, got %T", config["port"]) +} + +func TestGetSchemaForPath(t *testing.T) { + schema := map[string]any{ + "properties": map[string]any{ + "credentials": map[string]any{ + "type": "object", + "properties": map[string]any{ + "password": map[string]any{ + "type": "string", + }, + }, + }, + "secrets": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "value": map[string]any{ + "type": "string", + }, + }, + }, + }, + "config": map[string]any{ + "type": "object", + "additionalProperties": map[string]any{ + "type": "string", + }, + }, + }, + } + + tests := []struct { + name string + path string + expectedType string + }{ + { + name: "simple-nested-field", + path: "credentials.password", + expectedType: "string", + }, + { + name: "array-wildcard-nested", + path: "secrets[*].value", + expectedType: "string", + }, + { + name: "map-wildcard", + path: "config[*]", + expectedType: "string", + }, + { + name: "object-field", + path: "credentials", + expectedType: "object", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getSchemaForPath(schema, tt.path) + require.NotNil(t, result) + require.Equal(t, tt.expectedType, result["type"]) + }) + } +} + +// Helper function to deep copy a map for testing +func deepCopyMap(original map[string]any) map[string]any { + result := make(map[string]any) + for key, value := range original { + switch v := value.(type) { + case map[string]any: + result[key] = deepCopyMap(v) + case []any: + result[key] = deepCopySlice(v) + default: + result[key] = value + } + } + return result +} + +func deepCopySlice(original []any) []any { + result := make([]any, len(original)) + for i, value := range original { + switch v := value.(type) { + case map[string]any: + result[i] = deepCopyMap(v) + case []any: + result[i] = deepCopySlice(v) + default: + result[i] = value + } + } + return result +}