Skip to content

feat: add an encrypt_secret helper function #279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions _test_unstructured_client/unit/test_encryption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from cryptography import x509
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import os
import base64
from typing import Optional

import pytest

from unstructured_client import UnstructuredClient

@pytest.fixture
def rsa_key_pair():
    """Generate a throwaway 2048-bit RSA key pair for the encryption tests.

    Returns:
        tuple: ``(private_key_pem, public_key_pem)``. Both are PEM documents
        decoded to ``str``; the private key is unencrypted and uses the
        TraditionalOpenSSL format, the public key uses SubjectPublicKeyInfo.
    """
    key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
        backend=default_backend(),
    )

    private_pem_bytes = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem_bytes = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    return private_pem_bytes.decode('utf-8'), public_pem_bytes.decode('utf-8')


def decrypt_secret(
    private_key_pem: str,
    encrypted_value: str,
    type: str,
    encrypted_aes_key: str,
    aes_iv: str,
) -> str:
    """Reverse ``encrypt_secret`` so the tests can check a round trip.

    Args:
        private_key_pem: PEM-encoded, unencrypted RSA private key.
        encrypted_value: Base64 ciphertext of the secret.
        type: ``'rsa'`` for direct RSA-OAEP; any other value is handled as
            the hybrid RSA + AES-CFB envelope.
        encrypted_aes_key: Base64 RSA-OAEP-wrapped AES key (envelope only).
        aes_iv: Base64 AES initialization vector (envelope only).

    Returns:
        The decrypted secret decoded as UTF-8.
    """
    rsa_private_key = serialization.load_pem_private_key(
        private_key_pem.encode('utf-8'),
        password=None,
        backend=default_backend(),
    )

    # Both modes unwrap with the same OAEP parameters (SHA-256 / MGF1-SHA-256).
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )

    if type == 'rsa':
        payload = base64.b64decode(encrypted_value)
        return rsa_private_key.decrypt(payload, oaep).decode('utf-8')

    # Envelope mode: recover the AES key with RSA, then decrypt the payload.
    wrapped_key = base64.b64decode(encrypted_aes_key)
    init_vector = base64.b64decode(aes_iv)
    payload = base64.b64decode(encrypted_value)

    session_key = rsa_private_key.decrypt(wrapped_key, oaep)
    aes_decryptor = Cipher(
        algorithms.AES(session_key),
        modes.CFB(init_vector),
    ).decryptor()
    secret_bytes = aes_decryptor.update(payload) + aes_decryptor.finalize()
    return secret_bytes.decode('utf-8')


def test_encrypt_rsa(rsa_key_pair):
    """A short secret is encrypted with direct RSA and survives a round trip."""
    private_pem, public_pem = rsa_key_pair
    message = "This is a secret message."

    sdk = UnstructuredClient()
    encrypted = sdk.users.encrypt_secret(public_pem, message)

    # A short payload should use direct RSA encryption
    assert encrypted["type"] == 'rsa'

    # The AES fields are unused in direct RSA mode, so pass empty strings.
    round_tripped = decrypt_secret(
        private_key_pem=private_pem,
        encrypted_value=encrypted["encrypted_value"],
        type=encrypted["type"],
        encrypted_aes_key="",
        aes_iv="",
    )
    assert round_tripped == message


def test_encrypt_rsa_aes(rsa_key_pair):
    """A long secret is envelope-encrypted (RSA + AES) and survives a round trip."""
    private_pem, public_pem = rsa_key_pair
    message = "This is a secret message." * 100

    sdk = UnstructuredClient()
    encrypted = sdk.users.encrypt_secret(public_pem, message)

    # A longer payload uses hybrid RSA-AES encryption
    assert encrypted["type"] == 'rsa_aes'

    round_tripped = decrypt_secret(
        private_key_pem=private_pem,
        encrypted_value=encrypted["encrypted_value"],
        type=encrypted["type"],
        encrypted_aes_key=encrypted["encrypted_aes_key"],
        aes_iv=encrypted["aes_iv"],
    )
    assert round_tripped == message
2 changes: 1 addition & 1 deletion gen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ python:
clientServerStatusCodesAsErrors: true
defaultErrorName: SDKError
description: Python Client SDK for Unstructured API
enableCustomCodeRegions: false
enableCustomCodeRegions: true
enumFormat: enum
fixFlags:
responseRequiredSep2024: false
Expand Down
131 changes: 131 additions & 0 deletions src/unstructured_client/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@
from unstructured_client.models import errors, operations, shared
from unstructured_client.types import BaseModel, OptionalNullable, UNSET

# region imports
from cryptography import x509
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import os
import base64
# endregion imports

class Users(BaseSDK):
def retrieve(
Expand Down Expand Up @@ -458,3 +467,125 @@ async def store_secret_async(
http_res_text,
http_res,
)

# region sdk-class-body
def _encrypt_rsa_aes(
    self,
    encryption_key_pem: str,
    plaintext: str,
) -> dict:
    """Envelope-encrypt ``plaintext`` with a fresh AES key wrapped by RSA.

    The plaintext is UTF-8 encoded and encrypted with AES-CFB under a random
    256-bit key and a random 16-byte IV; the AES key is then encrypted with
    the RSA public key using OAEP (SHA-256 / MGF1-SHA-256).

    Args:
        encryption_key_pem: PEM-encoded RSA public key.
        plaintext: The string to encrypt.

    Returns:
        dict: ``encrypted_aes_key``, ``aes_iv`` and ``encrypted_value`` as
        base64 strings, plus ``type`` set to ``'rsa_aes'``.

    Raises:
        TypeError: If the PEM does not contain an RSA public key.
    """
    rsa_public_key = serialization.load_pem_public_key(
        encryption_key_pem.encode('utf-8'),
        backend=default_backend(),
    )
    if not isinstance(rsa_public_key, rsa.RSAPublicKey):
        raise TypeError("Public key must be an RSA public key for envelope encryption.")

    # One-time symmetric material: 256-bit AES key and a 16-byte IV.
    session_key = os.urandom(32)
    init_vector = os.urandom(16)

    aes_encryptor = Cipher(
        algorithms.AES(session_key),
        modes.CFB(init_vector),
    ).encryptor()
    encrypted_payload = aes_encryptor.update(plaintext.encode('utf-8')) + aes_encryptor.finalize()

    # Wrap the AES key so only the RSA private-key holder can recover it.
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    wrapped_key = rsa_public_key.encrypt(session_key, oaep)

    def _to_b64_text(raw: bytes) -> str:
        return base64.b64encode(raw).decode('utf-8')

    return {
        'encrypted_aes_key': _to_b64_text(wrapped_key),
        'aes_iv': _to_b64_text(init_vector),
        'encrypted_value': _to_b64_text(encrypted_payload),
        'type': 'rsa_aes',
    }

def _encrypt_rsa(
    self,
    encryption_key_pem: str,
    plaintext: str,
) -> dict:
    """Encrypt ``plaintext`` directly with RSA-OAEP (SHA-256 / MGF1-SHA-256).

    Args:
        encryption_key_pem: PEM-encoded RSA public key.
        plaintext: The string to encrypt; it is UTF-8 encoded first.

    Returns:
        dict: ``encrypted_value`` as a base64 string, ``type`` set to
        ``'rsa'``, and ``encrypted_aes_key`` / ``aes_iv`` as empty strings so
        the result has the same keys as the envelope variant.

    Raises:
        TypeError: If the PEM does not contain an RSA public key.
    """
    rsa_public_key = serialization.load_pem_public_key(
        encryption_key_pem.encode('utf-8'),
        backend=default_backend(),
    )
    if not isinstance(rsa_public_key, rsa.RSAPublicKey):
        raise TypeError("Public key must be an RSA public key for encryption.")

    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    encrypted_payload = rsa_public_key.encrypt(plaintext.encode(), oaep)

    result = {
        'encrypted_value': base64.b64encode(encrypted_payload).decode('utf-8'),
        'type': 'rsa',
        'encrypted_aes_key': "",
        'aes_iv': "",
    }
    return result


def encrypt_secret(
    self,
    encryption_cert_or_key_pem: str,
    plaintext: str,
    encryption_type: Optional[str] = None,
) -> dict:
    """
    Encrypts a plaintext string for securely sending to the Unstructured API.

    Args:
        encryption_cert_or_key_pem (str): A PEM-encoded RSA public key or certificate.
        plaintext (str): The string to encrypt.
        encryption_type (str, optional): Encryption type, either "rsa" or "rsa_aes".
            When omitted (or empty), "rsa" is chosen if the UTF-8 encoded
            plaintext fits in a single RSA-OAEP block for the given key,
            otherwise "rsa_aes". Any explicit value other than "rsa" is
            treated as "rsa_aes".

    Returns:
        dict: A dictionary with the keys ``encrypted_value``, ``type``,
        ``encrypted_aes_key`` and ``aes_iv``. Binary values are
        base64-encoded; the AES fields are empty strings for type "rsa".

    Raises:
        TypeError: If the supplied key or certificate does not hold an RSA
            public key.
    """
    # If a cert is provided, extract the public key
    if "BEGIN CERTIFICATE" in encryption_cert_or_key_pem:
        cert = x509.load_pem_x509_certificate(
            encryption_cert_or_key_pem.encode('utf-8'),
        )

        # Serialize back to PEM format so both input kinds share one code path
        public_key_pem = cert.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        ).decode('utf-8')
    else:
        public_key_pem = encryption_cert_or_key_pem

    if not encryption_type:
        # Pick the mode from what RSA-OAEP can actually hold. The limit is
        # modulus_bytes - 2 * hash_bytes - 2 (190 bytes for RSA-2048 with
        # SHA-256). The previous heuristic compared the character count to
        # the PEM text length, which is far larger than that limit, so
        # mid-sized secrets were sent down the direct-RSA path and failed
        # inside the encrypt call.
        public_key = serialization.load_pem_public_key(
            public_key_pem.encode('utf-8'),
            backend=default_backend(),
        )
        if not isinstance(public_key, rsa.RSAPublicKey):
            raise TypeError("Public key must be an RSA public key for encryption.")

        max_rsa_bytes = public_key.key_size // 8 - 2 * hashes.SHA256().digest_size - 2
        # Measure bytes, not characters: non-ASCII text grows when encoded.
        fits_in_rsa = len(plaintext.encode('utf-8')) <= max_rsa_bytes
        encryption_type = "rsa" if fits_in_rsa else "rsa_aes"

    if encryption_type == "rsa":
        return self._encrypt_rsa(public_key_pem, plaintext)

    return self._encrypt_rsa_aes(public_key_pem, plaintext)
# endregion sdk-class-body