Don't pass in MLKEM key factory on Bob side, not required

Fix NPE in MLKEMDHState.destroy()
This commit is contained in:
zzz
2025-03-27 12:45:40 -04:00
parent 6afd938c53
commit d06b776189
3 changed files with 29 additions and 14 deletions

View File

@ -271,6 +271,7 @@ public class HandshakeState implements Destroyable, Cloneable {
/**
* Creates a new Noise handshake.
* Noise protocol name is hardcoded.
* Not for PQ Alice side.
*
* @param patternId XK, IK, or N
* @param role The role, HandshakeState.INITIATOR or HandshakeState.RESPONDER.
@ -294,7 +295,7 @@ public class HandshakeState implements Destroyable, Cloneable {
* @param patternId XK, IK, or N
* @param role The role, HandshakeState.INITIATOR or HandshakeState.RESPONDER.
* @param xdh The key pair factory for ephemeral keys
* @param hdh The key pair factory for hybrid keys, or null for non-hybrid
* @param hdh The key pair factory for hybrid keys, Alice side only, or null for Bob or non-hybrid
*
* @throws IllegalArgumentException The protocolName is not
* formatted correctly, or the role is not recognized.
@ -349,16 +350,14 @@ public class HandshakeState implements Destroyable, Cloneable {
if ((flags & Pattern.FLAG_LOCAL_HYBRID) != 0) {
if (isInitiator && hdh == null)
throw new IllegalArgumentException("Hybrid patterns require hybrid key generator");
localHybrid = isInitiator ? new MLKEMDHState(hdh, patternId) : new MLKEMDHState(patternId);
localHybrid = isInitiator ? new MLKEMDHState(hdh, patternId) : new MLKEMDHState(false, patternId);
}
if ((flags & Pattern.FLAG_REMOTE_STATIC) != 0)
remotePublicKey = new Curve25519DHState(xdh);
if ((flags & Pattern.FLAG_REMOTE_EPHEMERAL) != 0)
remoteEphemeral = new Curve25519DHState(xdh);
if ((flags & Pattern.FLAG_REMOTE_HYBRID) != 0) {
if (!isInitiator && hdh == null)
throw new IllegalArgumentException("Hybrid patterns require hybrid key generator");
remoteHybrid = isInitiator ? new MLKEMDHState(patternId) : new MLKEMDHState(hdh, patternId);
remoteHybrid = new MLKEMDHState(!isInitiator, patternId);
}
}
@ -803,6 +802,8 @@ public class HandshakeState implements Destroyable, Cloneable {
if (localHybrid == null)
throw new IllegalStateException("Pattern definition error");
byte[] shared = null;
//System.out.println("State before writing F");
//System.out.println(toString());
if (isInitiator) {
// Only Alice generates a keypair
localHybrid.generateKeyPair();
@ -1035,6 +1036,9 @@ public class HandshakeState implements Destroyable, Cloneable {
if (remoteHybrid == null)
throw new IllegalStateException("Pattern definition error");
len = remoteHybrid.getPublicKeyLength();
//System.out.println("State before reading F");
//System.out.println(toString());
//System.out.println("State: F, reading remote eph. key len=" + len);
macLen = symmetric.getMACLength();
if (space < (len + macLen))
throw new ShortBufferException();

View File

@ -44,19 +44,27 @@ class MLKEMDHState implements DHState, Cloneable {
private final KeyFactory _hdh;
/**
* Bob side, do not call generateKeyPair()
* Bob local/remote or Alice remote side, do not call generateKeyPair()
* @param isAlice true for Bob remote side, false for Bob local side and Alice remote side
*/
public MLKEMDHState(String patternId)
public MLKEMDHState(boolean isAlice, String patternId)
{
this(null, patternId);
this(isAlice, null, patternId);
}
/**
* Alice side
* Alice local side
*/
public MLKEMDHState(KeyFactory hdh, String patternId)
{
boolean isAlice = hdh != null;
this(true, hdh, patternId);
}
/**
* Internal
*/
private MLKEMDHState(boolean isAlice, KeyFactory hdh, String patternId)
{
if (patternId.equals(HandshakeState.PATTERN_ID_IKHFS_512)) {
type = isAlice ? EncType.MLKEM512_X25519_INT : EncType.MLKEM512_X25519_CT;
} else if (patternId.equals(HandshakeState.PATTERN_ID_IKHFS_768)) {
@ -107,7 +115,7 @@ class MLKEMDHState implements DHState, Cloneable {
}
/**
* Alice side ONLY
* Alice local side ONLY
*/
@Override
public void generateKeyPair() {
@ -162,14 +170,16 @@ class MLKEMDHState implements DHState, Cloneable {
@Override
public void setToNullPublicKey() {
Arrays.fill(publicKey, (byte)0);
Arrays.fill(privateKey, (byte)0);
if (privateKey != null)
Arrays.fill(privateKey, (byte)0);
mode = 0x01;
}
@Override
public void clearKey() {
Noise.destroy(publicKey);
Noise.destroy(privateKey);
if (privateKey != null)
Noise.destroy(privateKey);
mode = 0;
}

View File

@ -481,7 +481,8 @@ public final class ECIESAEADEngine {
EncType type = targetPrivateKey.getType();
try {
String pattern = getNoisePattern(type);
state = new HandshakeState(pattern, HandshakeState.RESPONDER, _edhThread, getHybridKeyFactory(type));
// Bob does not need a hybrid key factory
state = new HandshakeState(pattern, HandshakeState.RESPONDER, _edhThread);
} catch (GeneralSecurityException gse) {
throw new IllegalStateException("bad proto", gse);
}