Skip to content

Commit c8ad850

Browse files
authored
feat: add an encrypt_secret helper function (#279)
Add a `users.encrypt_secret` helper that conforms to Unstructured's secret formats. This allows users to encrypt credentials locally before using them in a connector config.
1 parent 7d6d402 commit c8ad850

File tree

3 files changed

+275
-1
lines changed

3 files changed

+275
-1
lines changed
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from cryptography import x509
2+
from cryptography.hazmat.primitives import serialization, hashes
3+
from cryptography.hazmat.primitives.asymmetric import padding, rsa
4+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
5+
from cryptography.hazmat.backends import default_backend
6+
import os
7+
import base64
8+
from typing import Optional
9+
10+
import pytest
11+
12+
from unstructured_client import UnstructuredClient
13+
14+
@pytest.fixture
def rsa_key_pair():
    """
    Generate a fresh 2048-bit RSA key pair for a test.

    Returns:
        tuple: ``(private_key_pem, public_key_pem)`` as UTF-8 strings. The
        private key is an unencrypted TraditionalOpenSSL PEM; the public key
        is a SubjectPublicKeyInfo PEM.
    """
    generated_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
        backend=default_backend(),
    )

    private_pem_bytes = generated_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem_bytes = generated_key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    return private_pem_bytes.decode('utf-8'), public_pem_bytes.decode('utf-8')
35+
36+
def test_encrypt_rsa(rsa_key_pair):
    """Round-trip a short secret and check that direct RSA is selected."""
    private_pem, public_pem = rsa_key_pair
    client = UnstructuredClient()
    message = "This is a secret message."

    encrypted = client.users.encrypt_secret(public_pem, message)

    # Short payloads fit inside a single RSA-OAEP block, so no AES envelope.
    assert encrypted["type"] == 'rsa'

    # The AES key and IV are unused for plain RSA, hence the empty strings.
    round_tripped = client.users.decrypt_secret(
        private_key_pem=private_pem,
        encrypted_value=encrypted["encrypted_value"],
        secret_type=encrypted["type"],
        encrypted_aes_key="",
        aes_iv="",
    )
    assert round_tripped == message
56+
57+
58+
def test_encrypt_rsa_aes(rsa_key_pair):
    """Round-trip a long secret and check that the RSA-AES hybrid is selected."""
    private_pem, public_pem = rsa_key_pair
    client = UnstructuredClient()
    message = "This is a secret message." * 100

    encrypted = client.users.encrypt_secret(public_pem, message)

    # The payload is too large for one RSA-OAEP block, so the envelope is used.
    assert encrypted["type"] == 'rsa_aes'

    round_tripped = client.users.decrypt_secret(
        private_key_pem=private_pem,
        encrypted_value=encrypted["encrypted_value"],
        secret_type=encrypted["type"],
        encrypted_aes_key=encrypted["encrypted_aes_key"],
        aes_iv=encrypted["aes_iv"],
    )
    assert round_tripped == message
78+
79+
80+
# Size in bytes of the 2048-bit modulus produced by the rsa_key_pair fixture.
rsa_key_size_bytes = 2048 // 8
# Largest message RSA-OAEP can carry: k - 2 * hLen - 2, with hLen = 32 for SHA-256.
max_payload_size = rsa_key_size_bytes - 66  # OAEP SHA256 overhead


@pytest.mark.parametrize(("plaintext", "secret_type"), [
    ("Short message", "rsa"),
    ("A" * (max_payload_size), "rsa"),  # Just at the RSA limit
    ("A" * (max_payload_size + 1), "rsa_aes"),  # Just over the RSA limit
    ("A" * 500, "rsa_aes"),  # Well over the RSA limit
])
def test_encrypt_around_rsa_size_limit(rsa_key_pair, plaintext, secret_type):
    """
    Test that payloads around the RSA size limit choose the correct algorithm.
    """
    _, public_key_pem = rsa_key_pair

    print(f"Testing plaintext of length {len(plaintext)} with expected type {secret_type}")

    # encrypt_secret parses the PEM itself, so the key is not loaded here.
    client = UnstructuredClient()

    secret_obj = client.users.encrypt_secret(public_key_pem, plaintext)

    assert secret_obj["type"] == secret_type
    assert secret_obj["encrypted_value"] is not None

gen.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ python:
3939
clientServerStatusCodesAsErrors: true
4040
defaultErrorName: SDKError
4141
description: Python Client SDK for Unstructured API
42-
enableCustomCodeRegions: false
42+
enableCustomCodeRegions: true
4343
enumFormat: enum
4444
fixFlags:
4545
responseRequiredSep2024: false

src/unstructured_client/users.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@
77
from unstructured_client.models import errors, operations, shared
88
from unstructured_client.types import BaseModel, OptionalNullable, UNSET
99

10+
# region imports
11+
from cryptography import x509
12+
from cryptography.hazmat.primitives import serialization, hashes
13+
from cryptography.hazmat.primitives.asymmetric import padding, rsa
14+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
15+
from cryptography.hazmat.backends import default_backend
16+
import os
17+
import base64
18+
# endregion imports
1019

1120
class Users(BaseSDK):
1221
def retrieve(
@@ -458,3 +467,160 @@ async def store_secret_async(
458467
http_res_text,
459468
http_res,
460469
)
470+
471+
# region sdk-class-body
472+
def _encrypt_rsa_aes(
    self,
    public_key: rsa.RSAPublicKey,
    plaintext: str,
) -> dict:
    """
    Envelope-encrypt ``plaintext`` with a one-off AES key wrapped by RSA.

    The UTF-8 plaintext is encrypted with AES-CFB under a random 32-byte key
    and 16-byte IV; the AES key is then encrypted with ``public_key`` using
    OAEP (SHA-256 for both the hash and MGF1).

    Args:
        public_key (rsa.RSAPublicKey): Key used to wrap the AES key.
        plaintext (str): The string to encrypt.

    Returns:
        dict: Base64 strings under 'encrypted_aes_key', 'aes_iv' and
        'encrypted_value', plus 'type' set to 'rsa_aes'.
    """
    def _b64(raw: bytes) -> str:
        # Every binary field is returned as base64 text.
        return base64.b64encode(raw).decode('utf-8')

    session_key = os.urandom(32)  # 256-bit AES key
    init_vector = os.urandom(16)

    aes_encryptor = Cipher(
        algorithms.AES(session_key),
        modes.CFB(init_vector),
    ).encryptor()
    encrypted_payload = (
        aes_encryptor.update(plaintext.encode('utf-8')) + aes_encryptor.finalize()
    )

    oaep_padding = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    wrapped_key = public_key.encrypt(session_key, oaep_padding)

    return {
        'encrypted_aes_key': _b64(wrapped_key),
        'aes_iv': _b64(init_vector),
        'encrypted_value': _b64(encrypted_payload),
        'type': 'rsa_aes',
    }
507+
508+
def _encrypt_rsa(
    self,
    public_key: rsa.RSAPublicKey,
    plaintext: str,
) -> dict:
    """
    Encrypt ``plaintext`` directly with RSA-OAEP (SHA-256 hash and MGF1).

    Args:
        public_key (rsa.RSAPublicKey): Key used for encryption.
        plaintext (str): The string to encrypt.

    Returns:
        dict: 'encrypted_value' (base64 ciphertext) and 'type' set to 'rsa'.
        'encrypted_aes_key' and 'aes_iv' are present but empty, so the result
        has the same keys as the RSA-AES variant.
    """
    oaep_padding = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    encrypted_bytes = public_key.encrypt(plaintext.encode(), oaep_padding)
    encoded_value = base64.b64encode(encrypted_bytes).decode('utf-8')

    return {
        'encrypted_value': encoded_value,
        'type': 'rsa',
        'encrypted_aes_key': "",
        'aes_iv': "",
    }
528+
529+
def decrypt_secret(
    self,
    private_key_pem: str,
    encrypted_value: str,
    secret_type: str,
    encrypted_aes_key: str,
    aes_iv: str,
) -> str:
    """
    Decrypt a secret produced by ``encrypt_secret``.

    Args:
        private_key_pem (str): A PEM-encoded, unencrypted RSA private key.
        encrypted_value (str): Base64-encoded ciphertext.
        secret_type (str): 'rsa' for direct RSA decryption; any other value
            is handled as the RSA-AES envelope.
        encrypted_aes_key (str): Base64-encoded RSA-wrapped AES key (only
            read for the envelope path).
        aes_iv (str): Base64-encoded AES IV (only read for the envelope path).

    Returns:
        str: The decrypted plaintext, decoded as UTF-8.

    Raises:
        TypeError: If the PEM does not contain an RSA private key.
    """
    loaded_key = serialization.load_pem_private_key(
        private_key_pem.encode('utf-8'),
        password=None,
        backend=default_backend(),
    )
    if not isinstance(loaded_key, rsa.RSAPrivateKey):
        raise TypeError("Private key must be a RSA private key for decryption.")

    def _oaep():
        # Same OAEP parameters that the encryption helpers use.
        return padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        )

    if secret_type == 'rsa':
        raw_value = base64.b64decode(encrypted_value)
        return loaded_key.decrypt(raw_value, _oaep()).decode('utf-8')

    # Envelope path: unwrap the AES key with RSA, then decrypt the payload.
    wrapped_key = base64.b64decode(encrypted_aes_key)
    init_vector = base64.b64decode(aes_iv)
    raw_value = base64.b64decode(encrypted_value)

    session_key = loaded_key.decrypt(wrapped_key, _oaep())
    aes_decryptor = Cipher(
        algorithms.AES(session_key),
        modes.CFB(init_vector),
    ).decryptor()
    decrypted_bytes = aes_decryptor.update(raw_value) + aes_decryptor.finalize()
    return decrypted_bytes.decode('utf-8')
578+
579+
def encrypt_secret(
    self,
    encryption_cert_or_key_pem: str,
    plaintext: str,
    encryption_type: Optional[str] = None,
) -> dict:
    """
    Encrypts a plaintext string for securely sending to the Unstructured API.

    Args:
        encryption_cert_or_key_pem (str): A PEM-encoded RSA public key or certificate.
        plaintext (str): The string to encrypt.
        encryption_type (str, optional): Encryption type, either "rsa" or "rsa_aes".
            When omitted, "rsa" is used if the UTF-8 encoded plaintext fits in a
            single RSA-OAEP block for this key, otherwise "rsa_aes". Any value
            other than "rsa" is treated as "rsa_aes".

    Returns:
        dict: A dictionary with the keys 'encrypted_value', 'type',
        'encrypted_aes_key' and 'aes_iv'. Binary values are base64-encoded;
        the AES key and IV are empty strings when 'type' is "rsa".

    Raises:
        TypeError: If the PEM does not contain an RSA public key.
    """
    pem_bytes = encryption_cert_or_key_pem.encode('utf-8')

    # If a cert is provided, extract the public key
    if "BEGIN CERTIFICATE" in encryption_cert_or_key_pem:
        cert = x509.load_pem_x509_certificate(pem_bytes)
        public_key = cert.public_key()  # type: ignore[assignment]
    else:
        public_key = serialization.load_pem_public_key(
            pem_bytes,
            backend=default_backend()
        )  # type: ignore[assignment]

    if not isinstance(public_key, rsa.RSAPublicKey):
        raise TypeError("Public key must be a RSA public key for encryption.")

    if not encryption_type:
        # RSA-OAEP can carry at most k - 2 * hLen - 2 bytes, where k is the
        # modulus size in bytes and hLen is 32 for SHA-256.
        key_size_bytes = public_key.key_size // 8
        max_rsa_length = key_size_bytes - 66  # OAEP SHA256 overhead

        # Compare encoded bytes, not characters: multi-byte UTF-8 text can
        # exceed the RSA limit even when its character count does not.
        payload_size = len(plaintext.encode('utf-8'))
        encryption_type = "rsa" if payload_size <= max_rsa_length else "rsa_aes"

    if encryption_type == "rsa":
        return self._encrypt_rsa(public_key, plaintext)

    return self._encrypt_rsa_aes(public_key, plaintext)
626+
# endregion sdk-class-body

0 commit comments

Comments
 (0)