Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,10 @@ class DefaultProviderAttrs {
+ " # PQC key factories\n"
+ " # =======================================================================\n"
+ " #\n"
+ "Service.KeyFactory.ML-KEM = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM\n"
+ "KeyFactory.ML-KEM-512.alias.add = ML_KEM_512, MLKEM512, OID.2.16.840.1.101.3.4.4.1, 2.16.840.1.101.3.4.4.1\n"
+ "Service.KeyFactory.ML-KEM-512 = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM512\n"
+ "KeyFactory.ML-KEM-768.alias.add = ML-KEM, ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2\n"
+ "KeyFactory.ML-KEM-768.alias.add = ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2\n"
+ "Service.KeyFactory.ML-KEM-768 = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM768\n"
+ "KeyFactory.ML-KEM-1024.alias.add = ML_KEM_1024, MLKEM1024, OID.2.16.840.1.101.3.4.4.3, 2.16.840.1.101.3.4.4.3\n"
+ "Service.KeyFactory.ML-KEM-1024 = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM1024\n"
Expand Down Expand Up @@ -315,9 +316,10 @@ class DefaultProviderAttrs {
+ " # PQC key encapsulation mechanisms\n"
+ " # =======================================================================\n"
+ " #\n"
+ "Service.KEM.ML-KEM = com.ibm.crypto.plus.provider.MLKEMImpl$MLKEM\n"
+ "KEM.ML-KEM-512.alias.add = ML_KEM_512, MLKEM512, OID.2.16.840.1.101.3.4.4.1, 2.16.840.1.101.3.4.4.1\n"
+ "Service.KEM.ML-KEM-512 = com.ibm.crypto.plus.provider.MLKEMImpl$MLKEM512\n"
+ "KEM.ML-KEM-768.alias.add = ML-KEM, ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2\n"
+ "KEM.ML-KEM-768.alias.add = ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2\n"
+ "Service.KEM.ML-KEM-768 = com.ibm.crypto.plus.provider.MLKEMImpl$MLKEM768\n"

+ "KEM.ML-KEM-1024.alias.add = ML_KEM_1024, MLKEM1024, OID.2.16.840.1.101.3.4.4.3, 2.16.840.1.101.3.4.4.3\n"
Expand Down
92 changes: 83 additions & 9 deletions src/main/java/com/ibm/crypto/plus/provider/MLKEMImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,45 @@ public MLKEMImpl(OpenJCEPlusProvider provider, String alg) {
this.alg = alg;
}

private int getEncapsulationLength() {
/**
* Validates that the key's algorithm matches this KEM instance's algorithm.
* The generic "ML-KEM" instance accepts keys from any ML-KEM variant.
* Specific instances (ML-KEM-512, ML-KEM-768, ML-KEM-1024) accept:
* - Keys with matching specific algorithm (e.g., ML-KEM-512)
* - Keys with generic "ML-KEM" algorithm (for interop with providers that use generic naming)
*
* @param keyAlgorithm the algorithm from the key
* @throws InvalidKeyException if the key algorithm doesn't match the instance algorithm
*/
private void validateKeyAlgorithm(String keyAlgorithm) throws InvalidKeyException {
// Generic ML-KEM instance accepts any ML-KEM variant key algorithm
if (this.alg.equals("ML-KEM")) {
return;
}

// Specific instance accepts exact match or generic "ML-KEM"
if (!this.alg.equals(keyAlgorithm) && !keyAlgorithm.equals("ML-KEM")) {
throw new InvalidKeyException("Key algorithm " + keyAlgorithm +
" does not match KEM instance algorithm " + this.alg);
}
}

private int getEncapsulationLength(String algorithm) {
int size = 0;

switch (this.alg) {
switch (algorithm) {
case "ML-KEM-512":
size = 768;
break;
case "ML-KEM-768":
size = 1088;
break;
default:
case "ML-KEM-1024":
size = 1568;
break;
default:
// If algorithm is generic "ML-KEM", default to ML-KEM-768
size = 1088;
}
return size;
}
Expand All @@ -72,14 +99,27 @@ public KEMSpi.EncapsulatorSpi engineNewEncapsulator(PublicKey publicKey,

if (!(pubKey instanceof PQCPublicKey)) {
// Try and convert this key to a usage PQCPublicKey
// First verify it's an ML-KEM key
String keyAlgorithm = publicKey.getAlgorithm();
if (keyAlgorithm == null || !keyAlgorithm.startsWith("ML-KEM")) {
throw new InvalidKeyException("unsupported key");
}

// Validate algorithm match (unless this is the generic ML-KEM instance)
validateKeyAlgorithm(keyAlgorithm);

// Use the key's actual algorithm, not the generic "ML-KEM"
try {
KeyFactory kf = KeyFactory.getInstance(this.alg, this.provider.getName());
KeyFactory kf = KeyFactory.getInstance(keyAlgorithm, this.provider.getName());
EncodedKeySpec publicKeySpec = new X509EncodedKeySpec(publicKey.getEncoded());
pubKey = kf.generatePublic(publicKeySpec);

} catch (Exception e) {
throw new InvalidKeyException("unsupported key", e);
}
} else {
// Key is already a PQCPublicKey, validate algorithm match
validateKeyAlgorithm(pubKey.getAlgorithm());
}

if (spec != null) {
Expand All @@ -105,7 +145,9 @@ class MLKEMEncapsulator implements KEMSpi.EncapsulatorSpi {

@Override
public KEM.Encapsulated engineEncapsulate(int from, int to, String algorithm) {
int encapLen = getEncapsulationLength();
// Get the actual algorithm from the public key
String keyAlgorithm = publicKey.getAlgorithm();
int encapLen = getEncapsulationLength(keyAlgorithm);
byte[] encapsulation = new byte[encapLen];
byte[] secret = new byte[SECRETSIZE];

Expand All @@ -130,7 +172,8 @@ public KEM.Encapsulated engineEncapsulate(int from, int to, String algorithm) {

@Override
public int engineEncapsulationSize() {
return getEncapsulationLength();
String keyAlgorithm = publicKey.getAlgorithm();
return getEncapsulationLength(keyAlgorithm);
}

@Override
Expand All @@ -155,9 +198,19 @@ public KEMSpi.DecapsulatorSpi engineNewDecapsulator(PrivateKey privateKey,

if (!(privKey instanceof PQCPrivateKey)) {
// Try and convert this key to a usage PQCPrivateKey
// First verify it's an ML-KEM key
String keyAlgorithm = privateKey.getAlgorithm();
if (keyAlgorithm == null || !keyAlgorithm.startsWith("ML-KEM")) {
throw new InvalidKeyException("unsupported key");
}

// Validate algorithm match (unless this is the generic ML-KEM instance)
validateKeyAlgorithm(keyAlgorithm);

// Use the key's actual algorithm, not the generic "ML-KEM"
byte[] encoding = null;
try {
KeyFactory kf = KeyFactory.getInstance(this.alg, this.provider.getName());
KeyFactory kf = KeyFactory.getInstance(keyAlgorithm, this.provider.getName());
encoding = privateKey.getEncoded();
PKCS8EncodedKeySpec privateKeySpec = new PKCS8EncodedKeySpec(encoding);
privKey = kf.generatePrivate(privateKeySpec);
Expand All @@ -167,6 +220,9 @@ public KEMSpi.DecapsulatorSpi engineNewDecapsulator(PrivateKey privateKey,
Arrays.fill(encoding, (byte) 0);
}

} else {
// Key is already a PQCPrivateKey, validate algorithm match
validateKeyAlgorithm(privKey.getAlgorithm());
}

if (spec != null) {
Expand Down Expand Up @@ -197,6 +253,17 @@ public SecretKey engineDecapsulate(byte[] cipherText, int from, int to, String a
if (algorithm == null || cipherText == null) {
throw new NullPointerException();
}

// Validate encapsulation length matches the key's algorithm
String keyAlgorithm = privateKey.getAlgorithm();
int expectedEncapLen = getEncapsulationLength(keyAlgorithm);
if (cipherText.length != expectedEncapLen) {
throw new DecapsulateException(
"Invalid key encapsulation message length: expected " +
expectedEncapLen + " bytes for " + keyAlgorithm +
", but got " + cipherText.length + " bytes");
}

try {
secret = OJPKEM.KEM_decapsulate(((PQCPrivateKey) this.privateKey).getPQCKey().getPKeyId(),
cipherText, provider);
Expand All @@ -210,8 +277,8 @@ public SecretKey engineDecapsulate(byte[] cipherText, int from, int to, String a

@Override
public int engineEncapsulationSize() {

return getEncapsulationLength();
String keyAlgorithm = privateKey.getAlgorithm();
return getEncapsulationLength(keyAlgorithm);
}

@Override
Expand All @@ -222,6 +289,13 @@ public int engineSecretSize() {

}

public static final class MLKEM extends MLKEMImpl {

public MLKEM(OpenJCEPlusProvider provider) {
super(provider, "ML-KEM");
}
}

public static final class MLKEM512 extends MLKEMImpl {

public MLKEM512(OpenJCEPlusProvider provider) {
Expand Down
23 changes: 21 additions & 2 deletions src/main/java/com/ibm/crypto/plus/provider/PQCKeyFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,20 @@ private void checkKeyAlgo(Key key) throws InvalidKeyException {
String keyAlg = key.getAlgorithm();
if (keyAlg == null) {
throw new InvalidKeyException("Algorithm associate with key is null.");
} else if (!(key.getAlgorithm().equalsIgnoreCase(this.algName) ||
(PQCKnownOIDs.findMatch(key.getAlgorithm()).stdName().equalsIgnoreCase(this.algName)))) {
}

// Check if algorithms match exactly or via OID lookup
boolean matches = key.getAlgorithm().equalsIgnoreCase(this.algName) ||
(PQCKnownOIDs.findMatch(key.getAlgorithm()).stdName().equalsIgnoreCase(this.algName));

// Special case for generic ML-KEM: Allow any ML-KEM parameter set variant
// (ML-KEM-512, ML-KEM-768, ML-KEM-1024) when using the generic "ML-KEM" KeyFactory.
// This enables interoperability with KEM.getInstance("ML-KEM", ...).
if (!matches && "ML-KEM".equals(this.algName) && keyAlg.startsWith("ML-KEM")) {
matches = true;
}

if (!matches) {
throw new InvalidKeyException("Expected a " + this.algName + " key, but got " + keyAlg);
}

Expand Down Expand Up @@ -217,6 +229,13 @@ private boolean checkEncoded(byte[] key, boolean pub) {
}
}

public static final class MLKEM extends PQCKeyFactory {

public MLKEM(OpenJCEPlusProvider provider) {
super(provider, "ML-KEM");
}
}

public static final class MLKEM512 extends PQCKeyFactory {

public MLKEM512(OpenJCEPlusProvider provider) {
Expand Down
8 changes: 6 additions & 2 deletions src/test/ProviderDefAttrs.config
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,12 @@ Service.KeyFactory.RSAPSS = com.ibm.crypto.plus.provider.RSAKeyFactory$PSS
# PQC key factories
# =======================================================================
#
Service.KeyFactory.ML-KEM = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM

KeyFactory.ML-KEM-512.alias.add = ML_KEM_512, MLKEM512, OID.2.16.840.1.101.3.4.4.1, 2.16.840.1.101.3.4.4.1
Service.KeyFactory.ML-KEM-512 = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM512

KeyFactory.ML-KEM-768.alias.add = ML-KEM, ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2
KeyFactory.ML-KEM-768.alias.add = ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2
Service.KeyFactory.ML-KEM-768 = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM768

KeyFactory.ML-KEM-1024.alias.add = ML_KEM_1024, MLKEM1024, OID.2.16.840.1.101.3.4.4.3, 2.16.840.1.101.3.4.4.3
Expand Down Expand Up @@ -464,10 +466,12 @@ Service.MessageDigest.SHA3-512 = com.ibm.crypto.plus.provider.MessageDigest$SHA3
# PQC key encapsulation mechanisms
# =======================================================================
#
Service.KEM.ML-KEM = com.ibm.crypto.plus.provider.MLKEMImpl$MLKEM

KEM.ML-KEM-512.alias.add = ML_KEM_512, MLKEM512, OID.2.16.840.1.101.3.4.4.1, 2.16.840.1.101.3.4.4.1
Service.KEM.ML-KEM-512 = com.ibm.crypto.plus.provider.MLKEMImpl$MLKEM512

KEM.ML-KEM-768.alias.add = ML-KEM, ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2
KEM.ML-KEM-768.alias.add = ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2
Service.KEM.ML-KEM-768 = com.ibm.crypto.plus.provider.MLKEMImpl$MLKEM768

KEM.ML-KEM-1024.alias.add = ML_KEM_1024, MLKEM1024, OID.2.16.840.1.101.3.4.4.3, 2.16.840.1.101.3.4.4.3
Expand Down
Loading
Loading