org.bouncycastle
bcpkix-jdk18on
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java b/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java
index 06be4cf6..80291584 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java
@@ -14,7 +14,9 @@ public enum KeyType {
RSA4096Key("rsa:4096"),
EC256Key("ec:secp256r1", SECP256R1),
EC384Key("ec:secp384r1", SECP384R1),
- EC521Key("ec:secp521r1", SECP521R1);
+ EC521Key("ec:secp521r1", SECP521R1),
+ MLKEM768Key("mlkem:768"),
+ MLKEM1024Key("mlkem:1024");
private final String keyType;
private final ECCurve curve;
@@ -93,4 +95,14 @@ public static KeyType fromPublicKeyAlgorithm(KasPublicKeyAlgEnum algorithm) {
public boolean isEc() {
return this.curve != null;
}
+
+ public boolean isMLKEM() {
+ switch (this) {
+ case MLKEM768Key:
+ case MLKEM1024Key:
+ return true;
+ default:
+ return false;
+ }
+ }
}
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/MLKEMKeyPair.java b/sdk/src/main/java/io/opentdf/platform/sdk/MLKEMKeyPair.java
new file mode 100644
index 00000000..04d95548
--- /dev/null
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/MLKEMKeyPair.java
@@ -0,0 +1,214 @@
+package io.opentdf.platform.sdk;
+
+import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
+import org.bouncycastle.crypto.SecretWithEncapsulation;
+import org.bouncycastle.pqc.crypto.mlkem.MLKEMExtractor;
+import org.bouncycastle.pqc.crypto.mlkem.MLKEMGenerator;
+import org.bouncycastle.pqc.crypto.mlkem.MLKEMKeyGenerationParameters;
+import org.bouncycastle.pqc.crypto.mlkem.MLKEMKeyPairGenerator;
+import org.bouncycastle.pqc.crypto.mlkem.MLKEMParameters;
+import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters;
+import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters;
+
+import java.security.SecureRandom;
+import java.util.Arrays;
+import java.util.Base64;
+
+/**
+ * Pure ML-KEM (FIPS 203) key encapsulation for DEK wrapping.
+ *
+ * Wire format of the {@code wrappedKey} field (after base64 decode):
+ *
+ * mlkem_ciphertext (1088 bytes for 768; 1568 for 1024)
+ * || aes_gcm_nonce (12 bytes)
+ * || aes_gcm_ciphertext_and_tag
+ *
+ *
+ * Wrap key derivation: {@code HKDF-SHA256(ikm = mlkem_shared_secret,
+ * salt = SHA-256("TDF"), info = empty, L = 32)} via
+ * {@link ECKeyPair#calculateHKDF(byte[], byte[])}.
+ *
+ * The KAO uses {@code type == "wrapped"} (the default; no override needed)
+ * and leaves {@code ephemeralPublicKey} empty — the KAS disambiguates from
+ * RSA-wrapped by looking up the registered key's algorithm.
+ *
+ * BouncyCastle is required for the ML-KEM primitives because JDK 11 stdlib
+ * has no KEM API (added in JDK 21).
+ */
+final class MLKEMKeyPair {
+
+ static final MLKEMKeyPair MLKEM_768 = new MLKEMKeyPair(
+ MLKEMParameters.ml_kem_768,
+ "ML-KEM-768 PUBLIC KEY",
+ "ML-KEM-768 PRIVATE KEY",
+ /* publicKeySize */ 1184,
+ /* ciphertextSize */ 1088,
+ KeyType.MLKEM768Key);
+
+ static final MLKEMKeyPair MLKEM_1024 = new MLKEMKeyPair(
+ MLKEMParameters.ml_kem_1024,
+ "ML-KEM-1024 PUBLIC KEY",
+ "ML-KEM-1024 PRIVATE KEY",
+ /* publicKeySize */ 1568,
+ /* ciphertextSize */ 1568,
+ KeyType.MLKEM1024Key);
+
+ /** FIPS 203 seed (d || z) — same 64 bytes for both 768 and 1024. */
+ static final int SEED_SIZE = 64;
+ static final int SHARED_SECRET_SIZE = 32;
+
+ private final MLKEMParameters mlkemParams;
+ private final String pubPemBlock;
+ private final String privPemBlock;
+ private final int publicKeySize;
+ private final int ciphertextSize;
+ private final KeyType keyType;
+
+ private final byte[] publicKey;
+ private final byte[] privateKey;
+
+ private MLKEMKeyPair(MLKEMParameters mlkemParams, String pubPemBlock, String privPemBlock,
+ int publicKeySize, int ciphertextSize, KeyType keyType) {
+ this.mlkemParams = mlkemParams;
+ this.pubPemBlock = pubPemBlock;
+ this.privPemBlock = privPemBlock;
+ this.publicKeySize = publicKeySize;
+ this.ciphertextSize = ciphertextSize;
+ this.keyType = keyType;
+ this.publicKey = null;
+ this.privateKey = null;
+ }
+
+ private MLKEMKeyPair(MLKEMKeyPair params, byte[] publicKey, byte[] privateKey) {
+ this.mlkemParams = params.mlkemParams;
+ this.pubPemBlock = params.pubPemBlock;
+ this.privPemBlock = params.privPemBlock;
+ this.publicKeySize = params.publicKeySize;
+ this.ciphertextSize = params.ciphertextSize;
+ this.keyType = params.keyType;
+ this.publicKey = publicKey;
+ this.privateKey = privateKey;
+ }
+
+ static MLKEMKeyPair forKeyType(KeyType kt) {
+ switch (kt) {
+ case MLKEM768Key: return MLKEM_768;
+ case MLKEM1024Key: return MLKEM_1024;
+ default: throw new SDKException("not an ML-KEM key type: " + kt);
+ }
+ }
+
+ int publicKeySize() { return publicKeySize; }
+ int ciphertextSize() { return ciphertextSize; }
+ KeyType keyType() { return keyType; }
+
+ MLKEMKeyPair generate() {
+ SecureRandom random = new SecureRandom();
+ MLKEMKeyPairGenerator gen = new MLKEMKeyPairGenerator();
+ gen.init(new MLKEMKeyGenerationParameters(random, mlkemParams));
+ AsymmetricCipherKeyPair kp = gen.generateKeyPair();
+ byte[] pub = ((MLKEMPublicKeyParameters) kp.getPublic()).getEncoded();
+ byte[] seed = ((MLKEMPrivateKeyParameters) kp.getPrivate()).getSeed();
+ if (pub.length != publicKeySize) {
+ throw new SDKException("ML-KEM public key size " + pub.length + " != expected " + publicKeySize);
+ }
+ if (seed.length != SEED_SIZE) {
+ throw new SDKException("ML-KEM seed size " + seed.length + " != expected " + SEED_SIZE);
+ }
+ return new MLKEMKeyPair(this, pub, seed);
+ }
+
+ String publicKeyInPemFormat() {
+ return rawToPem(pubPemBlock, publicKey, publicKeySize);
+ }
+
+ String privateKeyInPemFormat() {
+ return rawToPem(privPemBlock, privateKey, SEED_SIZE);
+ }
+
+ byte[] getPublicKey() { return publicKey == null ? null : publicKey.clone(); }
+ byte[] getPrivateKey() { return privateKey == null ? null : privateKey.clone(); }
+
+ byte[] pubKeyFromPem(String pem) {
+ return decodeSizedPemBlock(pem, pubPemBlock, publicKeySize);
+ }
+
+ byte[] privateKeyFromPem(String pem) {
+ return decodeSizedPemBlock(pem, privPemBlock, SEED_SIZE);
+ }
+
+ /**
+ * Encapsulate against {@code rawPub} (an ML-KEM encapsulation key) and AES-256-GCM
+ * wrap the {@code dek}. Returns the raw, un-base64'd blob: ciphertext || AES-GCM(nonce||ct||tag).
+ */
+ byte[] wrapDEK(byte[] rawPub, byte[] dek) {
+ if (rawPub.length != publicKeySize) {
+ throw new SDKException("invalid " + keyType + " public key size: got " + rawPub.length + " want " + publicKeySize);
+ }
+ MLKEMPublicKeyParameters pub = new MLKEMPublicKeyParameters(mlkemParams, rawPub);
+ SecretWithEncapsulation enc = new MLKEMGenerator(new SecureRandom()).generateEncapsulated(pub);
+ byte[] sharedSecret = enc.getSecret();
+ byte[] ciphertext = enc.getEncapsulation();
+ if (ciphertext.length != ciphertextSize) {
+ throw new SDKException("ML-KEM ciphertext size " + ciphertext.length + " != expected " + ciphertextSize);
+ }
+
+ byte[] wrapKey = ECKeyPair.calculateHKDF(TDF.GLOBAL_KEY_SALT, sharedSecret);
+ byte[] encryptedDek = new AesGcm(wrapKey).encrypt(dek).asBytes();
+ byte[] out = new byte[ciphertextSize + encryptedDek.length];
+ System.arraycopy(ciphertext, 0, out, 0, ciphertextSize);
+ System.arraycopy(encryptedDek, 0, out, ciphertextSize, encryptedDek.length);
+ return out;
+ }
+
+ /**
+ * Inverse of {@link #wrapDEK(byte[], byte[])}. Used by unit tests and any future
+ * client-side decap path; the production decrypt flow defers unwrap to the KAS.
+ */
+ byte[] unwrapDEK(byte[] rawPriv, byte[] wrappedBlob) {
+ if (rawPriv.length != SEED_SIZE) {
+ throw new SDKException("invalid " + keyType + " private key seed size: got " + rawPriv.length + " want " + SEED_SIZE);
+ }
+ if (wrappedBlob.length <= ciphertextSize) {
+ throw new SDKException(keyType + " wrapped blob too short: got " + wrappedBlob.length
+ + ", need > " + ciphertextSize);
+ }
+ byte[] ciphertext = Arrays.copyOfRange(wrappedBlob, 0, ciphertextSize);
+ byte[] encryptedDek = Arrays.copyOfRange(wrappedBlob, ciphertextSize, wrappedBlob.length);
+
+ MLKEMPrivateKeyParameters priv = new MLKEMPrivateKeyParameters(mlkemParams, rawPriv);
+ byte[] sharedSecret = new MLKEMExtractor(priv).extractSecret(ciphertext);
+ byte[] wrapKey = ECKeyPair.calculateHKDF(TDF.GLOBAL_KEY_SALT, sharedSecret);
+ return new AesGcm(wrapKey).decrypt(new AesGcm.Encrypted(encryptedDek));
+ }
+
+ private static String rawToPem(String blockType, byte[] raw, int expectedSize) {
+ if (raw == null || raw.length != expectedSize) {
+ throw new SDKException("invalid " + blockType + " size: got " + (raw == null ? -1 : raw.length)
+ + " want " + expectedSize);
+ }
+ String b64 = Base64.getMimeEncoder(64, new byte[] { '\n' }).encodeToString(raw);
+ return "-----BEGIN " + blockType + "-----\n" + b64 + "\n-----END " + blockType + "-----\n";
+ }
+
+ private static byte[] decodeSizedPemBlock(String pem, String expectedType, int expectedSize) {
+ String header = "-----BEGIN " + expectedType + "-----";
+ String footer = "-----END " + expectedType + "-----";
+ int headerIdx = pem.indexOf(header);
+ int footerIdx = pem.indexOf(footer);
+ if (headerIdx < 0 || footerIdx < 0 || footerIdx <= headerIdx) {
+ throw new SDKException("failed to parse PEM formatted " + expectedType);
+ }
+ String body = pem.substring(headerIdx + header.length(), footerIdx).replaceAll("\\s", "");
+ byte[] raw;
+ try {
+ raw = Base64.getDecoder().decode(body);
+ } catch (IllegalArgumentException e) {
+ throw new SDKException("failed to base64-decode " + expectedType + " PEM body", e);
+ }
+ if (raw.length != expectedSize) {
+ throw new SDKException("invalid " + expectedType + " size: got " + raw.length + " want " + expectedSize);
+ }
+ return raw;
+ }
+}
diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java
index b30460eb..151fa82c 100644
--- a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java
+++ b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java
@@ -226,7 +226,15 @@ private Manifest.KeyAccess createKeyAccess(Config.TDFConfig tdfConfig, Config.KA
: kasInfo.Algorithm;
var keyType = KeyType.fromString(algorithm);
- if (keyType.isEc()) {
+ if (keyType.isMLKEM()) {
+ // Pure ML-KEM: ct(1088/1568) || AES-GCM(nonce(12)||ct||tag(16)),
+ // base64'd into wrappedKey. KAO type stays "wrapped"; KAS
+ // disambiguates from RSA by the registered key's algorithm.
+ var mlkem = MLKEMKeyPair.forKeyType(keyType);
+ byte[] wrapped = mlkem.wrapDEK(mlkem.pubKeyFromPem(kasInfo.PublicKey), symKey);
+ keyAccess.wrappedKey = Base64.getEncoder().encodeToString(wrapped);
+ keyAccess.keyType = kWrapped;
+ } else if (keyType.isEc()) {
var ecKeyWrappedKeyInfo = createECWrappedKey(kasInfo, symKey, keyType);
keyAccess.wrappedKey = ecKeyWrappedKeyInfo.wrappedKey;
keyAccess.ephemeralPublicKey = ecKeyWrappedKeyInfo.publicKey;
diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/MLKEMKeyPairTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/MLKEMKeyPairTest.java
new file mode 100644
index 00000000..db3ab1fe
--- /dev/null
+++ b/sdk/src/test/java/io/opentdf/platform/sdk/MLKEMKeyPairTest.java
@@ -0,0 +1,110 @@
+package io.opentdf.platform.sdk;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.util.stream.Stream;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Unit tests for pure ML-KEM key wrapping. Every test runs once per parameter
+ * set (ML-KEM-768 and ML-KEM-1024) so 1024 stays exercised even though no Go
+ * KAS reference exists for it yet.
+ *
+ * Exercises the full producer/consumer path locally: generate keypair → PEM
+ * round-trip → wrap a known DEK → unwrap with matching private key. If the
+ * wire format drifts (e.g. someone re-orders ciphertext and encrypted-DEK,
+ * or changes HKDF salt), the round-trip fails.
+ */
+class MLKEMKeyPairTest {
+
+ private static final byte[] DEK = "0123456789abcdef0123456789abcdef".getBytes();
+ private static final int AES_GCM_OVERHEAD = 12 + 16; // 12-byte nonce + 16-byte tag
+
+ static Stream algorithms() {
+ return Stream.of(MLKEMKeyPair.MLKEM_768, MLKEMKeyPair.MLKEM_1024);
+ }
+
+ @ParameterizedTest
+ @MethodSource("algorithms")
+ void roundTrip(MLKEMKeyPair alg) {
+ MLKEMKeyPair kp = alg.generate();
+
+ // PEM round-trip preserves both halves
+ String pubPem = kp.publicKeyInPemFormat();
+ String privPem = kp.privateKeyInPemFormat();
+ assertTrue(pubPem.startsWith("-----BEGIN " + headerFor(alg, "PUBLIC") + "-----"), "public PEM header");
+ assertTrue(privPem.contains(headerFor(alg, "PRIVATE")), "private PEM header");
+
+ byte[] rawPub = alg.pubKeyFromPem(pubPem);
+ byte[] rawPriv = alg.privateKeyFromPem(privPem);
+ assertEquals(alg.publicKeySize(), rawPub.length);
+ assertEquals(MLKEMKeyPair.SEED_SIZE, rawPriv.length);
+
+ // Wrap → expected blob length is ct + AES-GCM(12+|DEK|+16)
+ byte[] wrapped = alg.wrapDEK(rawPub, DEK);
+ assertNotNull(wrapped);
+ assertEquals(alg.ciphertextSize() + DEK.length + AES_GCM_OVERHEAD, wrapped.length,
+ "wrapped blob must be ct || nonce||ct||tag");
+
+ // Unwrap recovers the original DEK
+ byte[] unwrapped = alg.unwrapDEK(rawPriv, wrapped);
+ assertArrayEquals(DEK, unwrapped);
+ }
+
+ @ParameterizedTest
+ @MethodSource("algorithms")
+ void wrapRejectsWrongSizePublicKey(MLKEMKeyPair alg) {
+ byte[] tooShort = new byte[alg.publicKeySize() - 1];
+ SDKException ex = assertThrows(SDKException.class, () -> alg.wrapDEK(tooShort, DEK));
+ assertTrue(ex.getMessage().contains("public key size"), ex.getMessage());
+ }
+
+ @ParameterizedTest
+ @MethodSource("algorithms")
+ void unwrapRejectsWrongSizePrivateKey(MLKEMKeyPair alg) {
+ MLKEMKeyPair kp = alg.generate();
+ byte[] wrapped = alg.wrapDEK(kp.getPublicKey(), DEK);
+ byte[] badPriv = new byte[MLKEMKeyPair.SEED_SIZE - 1];
+ SDKException ex = assertThrows(SDKException.class, () -> alg.unwrapDEK(badPriv, wrapped));
+ assertTrue(ex.getMessage().contains("seed size"), ex.getMessage());
+ }
+
+ @ParameterizedTest
+ @MethodSource("algorithms")
+ void unwrapRejectsShortBlob(MLKEMKeyPair alg) {
+ MLKEMKeyPair kp = alg.generate();
+ byte[] tooShort = new byte[alg.ciphertextSize()]; // no room for the AES-GCM-wrapped DEK
+ SDKException ex = assertThrows(SDKException.class, () -> alg.unwrapDEK(kp.getPrivateKey(), tooShort));
+ assertTrue(ex.getMessage().contains("too short"), ex.getMessage());
+ }
+
+ @ParameterizedTest
+ @MethodSource("algorithms")
+ void tamperedCiphertextFailsAesGcmTag(MLKEMKeyPair alg) {
+ MLKEMKeyPair kp = alg.generate();
+ byte[] wrapped = alg.wrapDEK(kp.getPublicKey(), DEK);
+ // Flip a bit inside the AES-GCM-wrapped DEK section — must fail the tag check
+ wrapped[wrapped.length - 1] ^= 0x01;
+ assertThrows(Exception.class, () -> alg.unwrapDEK(kp.getPrivateKey(), wrapped));
+ }
+
+ @Test
+ void forKeyTypeDispatchesCorrectly() {
+ assertEquals(MLKEMKeyPair.MLKEM_768, MLKEMKeyPair.forKeyType(KeyType.MLKEM768Key));
+ assertEquals(MLKEMKeyPair.MLKEM_1024, MLKEMKeyPair.forKeyType(KeyType.MLKEM1024Key));
+ assertThrows(SDKException.class, () -> MLKEMKeyPair.forKeyType(KeyType.RSA2048Key));
+ }
+
+ private static String headerFor(MLKEMKeyPair alg, String half) {
+ return alg == MLKEMKeyPair.MLKEM_768
+ ? "ML-KEM-768 " + half + " KEY"
+ : "ML-KEM-1024 " + half + " KEY";
+ }
+}
diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/TDFMLKEMTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/TDFMLKEMTest.java
new file mode 100644
index 00000000..16365642
--- /dev/null
+++ b/sdk/src/test/java/io/opentdf/platform/sdk/TDFMLKEMTest.java
@@ -0,0 +1,118 @@
+package io.opentdf.platform.sdk;
+
+import com.connectrpc.ResponseMessage;
+import com.connectrpc.UnaryBlockingCall;
+import io.opentdf.platform.policy.KeyAccessServer;
+import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClient;
+import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersRequest;
+import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersResponse;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.util.Base64;
+import java.util.Collections;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Producer-side integration test for pure ML-KEM key access. Drives
+ * {@link TDF#createTDF} with each ML-KEM {@link KeyType}; asserts the resulting
+ * manifest matches the wire format that the Go KAS expects (mirrors
+ * {@code lib/ocrypto/mlkem_key_pair.go} in opentdf/platform PR 3491):
+ *
+ * - {@code keyAccess.type == "wrapped"} (reused from the RSA slot)
+ * - {@code ephemeralPublicKey} is null/empty
+ * - {@code wrappedKey} decoded length == {@code ciphertextSize + 12 + 32 + 16}
+ *
+ * Round-trips the wrappedKey through the matching private key locally to
+ * confirm the producer/consumer wire format agrees.
+ */
+class TDFMLKEMTest {
+
+ private static KeyAccessServerRegistryServiceClient kasRegistryService;
+ private static final String KAS_URL = "https://kas.example.com";
+
+ @BeforeAll
+ static void setupRegistryMock() {
+ kasRegistryService = mock(KeyAccessServerRegistryServiceClient.class);
+ ListKeyAccessServersResponse response = ListKeyAccessServersResponse.newBuilder()
+ .addKeyAccessServers(KeyAccessServer.newBuilder().setUri(KAS_URL).build())
+ .build();
+ when(kasRegistryService.listKeyAccessServersBlocking(any(ListKeyAccessServersRequest.class), any()))
+ .thenReturn(new UnaryBlockingCall<>() {
+ @Override
+ public ResponseMessage execute() {
+ return new ResponseMessage.Success<>(response,
+ Collections.emptyMap(), Collections.emptyMap());
+ }
+ @Override public void cancel() {}
+ });
+ }
+
+ @ParameterizedTest
+ @EnumSource(value = KeyType.class, names = {"MLKEM768Key", "MLKEM1024Key"})
+ void createTDFProducesMLKEMWrappedKeyAccess(KeyType keyType) throws Exception {
+ MLKEMKeyPair alg = MLKEMKeyPair.forKeyType(keyType);
+ MLKEMKeyPair kp = alg.generate();
+ String pubPem = kp.publicKeyInPemFormat();
+ byte[] privSeed = alg.privateKeyFromPem(kp.privateKeyInPemFormat());
+
+ // Stub the SDK.KAS so getPublicKey returns the generated ML-KEM PEM.
+ // No unwrap mock is needed because this test never calls loadTDF.
+ SDK.KAS kasStub = new SDK.KAS() {
+ @Override public void close() {}
+ @Override public Config.KASInfo getPublicKey(Config.KASInfo kasInfo) {
+ Config.KASInfo out = new Config.KASInfo();
+ out.URL = kasInfo.URL;
+ out.KID = "mlkem-test-kid";
+ out.Algorithm = keyType.toString();
+ out.PublicKey = pubPem;
+ return out;
+ }
+ @Override public byte[] unwrap(Manifest.KeyAccess ka, String policy, KeyType skt) {
+ throw new UnsupportedOperationException("not used in this test");
+ }
+ @Override public KASKeyCache getKeyCache() { return new KASKeyCache(); }
+ };
+
+ Config.KASInfo kasInfo = new Config.KASInfo();
+ kasInfo.URL = KAS_URL;
+
+ Config.TDFConfig cfg = Config.newTDFConfig(
+ Config.withAutoconfigure(false),
+ Config.withKasInformation(kasInfo),
+ Config.WithWrappingKeyAlg(keyType));
+
+ TDF tdf = new TDF(new FakeServicesBuilder()
+ .setKas(kasStub)
+ .setKeyAccessServerRegistryService(kasRegistryService)
+ .build());
+
+ ByteArrayOutputStream tdfOut = new ByteArrayOutputStream();
+ var manifest = tdf.createTDF(
+ new ByteArrayInputStream("ml-kem round-trip payload".getBytes()),
+ tdfOut, cfg).getManifest();
+
+ assertThat(manifest.encryptionInformation.keyAccessObj).hasSize(1);
+ Manifest.KeyAccess ka = manifest.encryptionInformation.keyAccessObj.get(0);
+
+ // Wire-format invariants the KAS depends on.
+ assertThat(ka.keyType).isEqualTo("wrapped"); // serialized as JSON "type"
+ assertThat(ka.ephemeralPublicKey).isNullOrEmpty();
+ assertThat(ka.wrappedKey).isNotEmpty();
+
+ // wrappedKey layout: mlkem_ct || AES-GCM(nonce(12) || DEK(32) || tag(16))
+ byte[] wrappedBytes = Base64.getDecoder().decode(ka.wrappedKey);
+ assertThat(wrappedBytes.length).isEqualTo(alg.ciphertextSize() + 12 + 32 + 16);
+
+ // Round-trip the wrappedKey through the matching private key.
+ byte[] symKey = alg.unwrapDEK(privSeed, wrappedBytes);
+ assertThat(symKey).hasSize(32);
+ }
+}