Skip to content

Commit a6aafb9

Browse files
committed
hoist HKDF into rust for even more speed
This PR was done via zed+claude sonnet 4 (with some cleanup and additions) using the following prompt: Please implement HKDF in rust. You can follow the example in hmac.rs to see similar structure, but implement the general HKDF algorithm as seen in hkdf.py
1 parent 56005fb commit a6aafb9

File tree

5 files changed

+260
-94
lines changed

5 files changed

+260
-94
lines changed

src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,26 @@ class Argon2id:
4747
def verify_phc_encoded(
4848
cls, key_material: bytes, phc_encoded: str, secret: bytes | None = None
4949
) -> None: ...
50+
51+
class HKDF:
52+
def __init__(
53+
self,
54+
algorithm: HashAlgorithm,
55+
length: int,
56+
salt: bytes | None,
57+
info: bytes | None,
58+
backend: typing.Any = None,
59+
): ...
60+
def derive(self, key_material: Buffer) -> bytes: ...
61+
def verify(self, key_material: bytes, expected_key: bytes) -> None: ...
62+
63+
class HKDFExpand:
64+
def __init__(
65+
self,
66+
algorithm: HashAlgorithm,
67+
length: int,
68+
info: bytes | None,
69+
backend: typing.Any = None,
70+
): ...
71+
def derive(self, key_material: Buffer) -> bytes: ...
72+
def verify(self, key_material: bytes, expected_key: bytes) -> None: ...

src/cryptography/hazmat/primitives/kdf/hkdf.py

Lines changed: 6 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -4,99 +4,13 @@
44

55
from __future__ import annotations
66

7-
import typing
8-
9-
from cryptography import utils
10-
from cryptography.exceptions import AlreadyFinalized, InvalidKey
11-
from cryptography.hazmat.primitives import constant_time, hashes, hmac
7+
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
128
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
139

10+
HKDF = rust_openssl.kdf.HKDF
11+
HKDFExpand = rust_openssl.kdf.HKDFExpand
1412

15-
class HKDF(KeyDerivationFunction):
16-
def __init__(
17-
self,
18-
algorithm: hashes.HashAlgorithm,
19-
length: int,
20-
salt: bytes | None,
21-
info: bytes | None,
22-
backend: typing.Any = None,
23-
):
24-
self._algorithm = algorithm
25-
26-
if salt is None:
27-
salt = b"\x00" * self._algorithm.digest_size
28-
else:
29-
utils._check_bytes("salt", salt)
30-
31-
self._salt = salt
32-
33-
self._hkdf_expand = HKDFExpand(self._algorithm, length, info)
34-
35-
def _extract(self, key_material: utils.Buffer) -> bytes:
36-
h = hmac.HMAC(self._salt, self._algorithm)
37-
h.update(key_material)
38-
return h.finalize()
39-
40-
def derive(self, key_material: utils.Buffer) -> bytes:
41-
utils._check_byteslike("key_material", key_material)
42-
return self._hkdf_expand.derive(self._extract(key_material))
43-
44-
def verify(self, key_material: bytes, expected_key: bytes) -> None:
45-
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
46-
raise InvalidKey
47-
48-
49-
class HKDFExpand(KeyDerivationFunction):
50-
def __init__(
51-
self,
52-
algorithm: hashes.HashAlgorithm,
53-
length: int,
54-
info: bytes | None,
55-
backend: typing.Any = None,
56-
):
57-
self._algorithm = algorithm
58-
59-
max_length = 255 * algorithm.digest_size
60-
61-
if length > max_length:
62-
raise ValueError(
63-
f"Cannot derive keys larger than {max_length} octets."
64-
)
65-
66-
self._length = length
67-
68-
if info is None:
69-
info = b""
70-
else:
71-
utils._check_bytes("info", info)
72-
73-
self._info = info
74-
75-
self._used = False
76-
77-
def _expand(self, key_material: utils.Buffer) -> bytes:
78-
output = [b""]
79-
counter = 1
80-
81-
h_prime = hmac.HMAC(key_material, self._algorithm)
82-
while self._algorithm.digest_size * (len(output) - 1) < self._length:
83-
h = h_prime.copy()
84-
h.update(output[-1])
85-
h.update(self._info)
86-
h.update(bytes([counter]))
87-
output.append(h.finalize())
88-
counter += 1
89-
90-
return b"".join(output)[: self._length]
91-
92-
def derive(self, key_material: utils.Buffer) -> bytes:
93-
utils._check_byteslike("key_material", key_material)
94-
if self._used:
95-
raise AlreadyFinalized
96-
97-
self._used = True
98-
return self._expand(key_material)
13+
KeyDerivationFunction.register(HKDF)
14+
KeyDerivationFunction.register(HKDFExpand)
9915

100-
def verify(self, key_material: bytes, expected_key: bytes) -> None:
101-
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
102-
raise InvalidKey
16+
__all__ = ["HKDF", "HKDFExpand"]

src/rust/src/backend/hmac.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ impl Hmac {
9999
Ok(())
100100
}
101101

102-
fn copy(&self, py: pyo3::Python<'_>) -> CryptographyResult<Hmac> {
102+
pub(crate) fn copy(&self, py: pyo3::Python<'_>) -> CryptographyResult<Hmac> {
103103
Ok(Hmac {
104104
ctx: Some(self.get_ctx()?.copy()?),
105105
algorithm: self.algorithm.clone_ref(py),

src/rust/src/backend/kdf.rs

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
use base64::engine::general_purpose::STANDARD_NO_PAD;
77
#[cfg(CRYPTOGRAPHY_OPENSSL_320_OR_GREATER)]
88
use base64::engine::Engine;
9+
use pyo3::prelude::PyAnyMethods;
910
#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))]
1011
use pyo3::types::PyBytesMethods;
1112

1213
use crate::backend::hashes;
14+
use crate::backend::hmac::Hmac;
1315
use crate::buf::CffiBuf;
1416
use crate::error::{CryptographyError, CryptographyResult};
1517
use crate::exceptions;
@@ -447,12 +449,239 @@ impl Argon2id {
447449
}
448450
}
449451

452+
#[pyo3::pyclass(module = "cryptography.hazmat.primitives.kdf.hkdf", name = "HKDF")]
453+
struct Hkdf {
454+
algorithm: pyo3::Py<pyo3::PyAny>,
455+
salt: pyo3::Py<pyo3::types::PyBytes>,
456+
info: Option<pyo3::Py<pyo3::types::PyBytes>>,
457+
length: usize,
458+
used: bool,
459+
}
460+
461+
#[pyo3::pymethods]
462+
impl Hkdf {
463+
#[new]
464+
#[pyo3(signature = (algorithm, length, salt=None, info=None, backend=None))]
465+
fn new(
466+
py: pyo3::Python<'_>,
467+
algorithm: pyo3::Py<pyo3::PyAny>,
468+
length: usize,
469+
salt: Option<pyo3::Py<pyo3::types::PyBytes>>,
470+
info: Option<pyo3::Py<pyo3::types::PyBytes>>,
471+
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
472+
) -> CryptographyResult<Self> {
473+
_ = backend;
474+
475+
let algorithm_bound = algorithm.bind(py);
476+
let digest_size = algorithm_bound
477+
.getattr(pyo3::intern!(py, "digest_size"))?
478+
.extract::<usize>()?;
479+
480+
let max_length = 255 * digest_size;
481+
if length > max_length {
482+
return Err(CryptographyError::from(
483+
pyo3::exceptions::PyValueError::new_err(format!(
484+
"Cannot derive keys larger than {} octets.",
485+
max_length
486+
)),
487+
));
488+
}
489+
490+
let salt = if let Some(salt) = salt {
491+
salt
492+
} else {
493+
let zero_salt = vec![0u8; digest_size];
494+
pyo3::types::PyBytes::new(py, &zero_salt).into()
495+
};
496+
497+
Ok(Hkdf {
498+
algorithm,
499+
salt,
500+
info,
501+
length,
502+
used: false,
503+
})
504+
}
505+
506+
fn _extract<'p>(
507+
&self,
508+
py: pyo3::Python<'p>,
509+
key_material: &[u8],
510+
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
511+
let algorithm_bound = self.algorithm.bind(py);
512+
let mut hmac = Hmac::new_bytes(py, self.salt.as_bytes(py), algorithm_bound)?;
513+
hmac.update_bytes(key_material)?;
514+
hmac.finalize(py)
515+
}
516+
517+
fn derive<'p>(
518+
&mut self,
519+
py: pyo3::Python<'p>,
520+
key_material: CffiBuf<'_>,
521+
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
522+
if self.used {
523+
return Err(exceptions::already_finalized_error());
524+
}
525+
self.used = true;
526+
527+
// HKDF Extract
528+
let prk = self._extract(py, key_material.as_bytes())?;
529+
530+
// HKDF Expand
531+
let mut hkdf_expand = HkdfExpand::new(
532+
py,
533+
self.algorithm.clone_ref(py),
534+
self.length,
535+
self.info.as_ref().map(|i| i.clone_ref(py)),
536+
None,
537+
)?;
538+
let prk_bytes = prk.as_bytes();
539+
let cffi_buf = CffiBuf::from_bytes(py, prk_bytes);
540+
hkdf_expand.derive(py, cffi_buf)
541+
}
542+
543+
fn verify(
544+
&mut self,
545+
py: pyo3::Python<'_>,
546+
key_material: CffiBuf<'_>,
547+
expected_key: CffiBuf<'_>,
548+
) -> CryptographyResult<()> {
549+
let actual = self.derive(py, key_material)?;
550+
let actual_bytes = actual.as_bytes();
551+
let expected_bytes = expected_key.as_bytes();
552+
553+
if actual_bytes.len() != expected_bytes.len()
554+
|| !openssl::memcmp::eq(actual_bytes, expected_bytes)
555+
{
556+
return Err(CryptographyError::from(exceptions::InvalidKey::new_err(
557+
"Keys do not match.",
558+
)));
559+
}
560+
561+
Ok(())
562+
}
563+
}
564+
565+
#[pyo3::pyclass(
566+
module = "cryptography.hazmat.primitives.kdf.hkdf",
567+
name = "HKDFExpand"
568+
)]
569+
struct HkdfExpand {
570+
algorithm: pyo3::Py<pyo3::PyAny>,
571+
info: pyo3::Py<pyo3::types::PyBytes>,
572+
length: usize,
573+
used: bool,
574+
}
575+
576+
#[pyo3::pymethods]
577+
impl HkdfExpand {
578+
#[new]
579+
#[pyo3(signature = (algorithm, length, info, backend=None))]
580+
fn new(
581+
py: pyo3::Python<'_>,
582+
algorithm: pyo3::Py<pyo3::PyAny>,
583+
length: usize,
584+
info: Option<pyo3::Py<pyo3::types::PyBytes>>,
585+
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
586+
) -> CryptographyResult<Self> {
587+
_ = backend;
588+
589+
let algorithm_bound = algorithm.bind(py);
590+
let digest_size = algorithm_bound
591+
.getattr(pyo3::intern!(py, "digest_size"))?
592+
.extract::<usize>()?;
593+
594+
let max_length = 255 * digest_size;
595+
if length > max_length {
596+
return Err(CryptographyError::from(
597+
pyo3::exceptions::PyValueError::new_err(format!(
598+
"Cannot derive keys larger than {} octets.",
599+
max_length
600+
)),
601+
));
602+
}
603+
604+
let info = if let Some(info) = info {
605+
info
606+
} else {
607+
pyo3::types::PyBytes::new(py, b"").into()
608+
};
609+
610+
Ok(HkdfExpand {
611+
algorithm,
612+
info,
613+
length,
614+
used: false,
615+
})
616+
}
617+
618+
fn derive<'p>(
619+
&mut self,
620+
py: pyo3::Python<'p>,
621+
key_material: CffiBuf<'_>,
622+
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
623+
if self.used {
624+
return Err(exceptions::already_finalized_error());
625+
}
626+
self.used = true;
627+
628+
let algorithm_bound = self.algorithm.bind(py);
629+
630+
let mut output = Vec::new();
631+
let mut counter = 1u8;
632+
let mut previous_output = Vec::new();
633+
634+
let h_prime = Hmac::new_bytes(py, key_material.as_bytes(), algorithm_bound)?;
635+
while output.len() < self.length {
636+
let mut h = h_prime.copy(py)?;
637+
h.update_bytes(&previous_output)?;
638+
h.update_bytes(self.info.as_bytes(py))?;
639+
h.update_bytes(&[counter])?;
640+
641+
let block = h.finalize(py)?;
642+
let block_bytes = block.as_bytes();
643+
previous_output = block_bytes.to_vec();
644+
output.extend_from_slice(block_bytes);
645+
646+
counter += 1;
647+
}
648+
649+
output.truncate(self.length);
650+
Ok(pyo3::types::PyBytes::new(py, &output))
651+
}
652+
653+
fn verify(
654+
&mut self,
655+
py: pyo3::Python<'_>,
656+
key_material: CffiBuf<'_>,
657+
expected_key: CffiBuf<'_>,
658+
) -> CryptographyResult<()> {
659+
let actual = self.derive(py, key_material)?;
660+
let actual_bytes = actual.as_bytes();
661+
let expected_bytes = expected_key.as_bytes();
662+
663+
if actual_bytes.len() != expected_bytes.len()
664+
|| !openssl::memcmp::eq(actual_bytes, expected_bytes)
665+
{
666+
return Err(CryptographyError::from(exceptions::InvalidKey::new_err(
667+
"Keys do not match.",
668+
)));
669+
}
670+
671+
Ok(())
672+
}
673+
}
674+
450675
#[pyo3::pymodule]
451676
pub(crate) mod kdf {
452677
#[pymodule_export]
453678
use super::derive_pbkdf2_hmac;
454679
#[pymodule_export]
455680
use super::Argon2id;
456681
#[pymodule_export]
682+
use super::Hkdf;
683+
#[pymodule_export]
684+
use super::HkdfExpand;
685+
#[pymodule_export]
457686
use super::Scrypt;
458687
}

0 commit comments

Comments
 (0)