From 232a8055887f61cc0812f584dbc07df846cf407c Mon Sep 17 00:00:00 2001 From: sujan kota Date: Fri, 29 May 2026 12:54:55 -0400 Subject: [PATCH] feat(sdk): DSPX-3383 add pure ML-KEM-768 and ML-KEM-1024 key wrapping --- scripts/README.md | 98 +++++++ scripts/test-mlkem.sh | 252 ++++++++++++++++++ sdk/pom.xml | 6 + .../java/io/opentdf/platform/sdk/KeyType.java | 14 +- .../io/opentdf/platform/sdk/MLKEMKeyPair.java | 214 +++++++++++++++ .../java/io/opentdf/platform/sdk/TDF.java | 10 +- .../platform/sdk/MLKEMKeyPairTest.java | 110 ++++++++ .../io/opentdf/platform/sdk/TDFMLKEMTest.java | 118 ++++++++ 8 files changed, 820 insertions(+), 2 deletions(-) create mode 100644 scripts/README.md create mode 100755 scripts/test-mlkem.sh create mode 100644 sdk/src/main/java/io/opentdf/platform/sdk/MLKEMKeyPair.java create mode 100644 sdk/src/test/java/io/opentdf/platform/sdk/MLKEMKeyPairTest.java create mode 100644 sdk/src/test/java/io/opentdf/platform/sdk/TDFMLKEMTest.java diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 00000000..9faaf3ec --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,98 @@ +# scripts/ + +Developer scripts for the OpenTDF Java SDK. Not bundled with the published +artifacts. + +## `test-mlkem.sh` + +End-to-end test of the Java SDK's pure ML-KEM (FIPS 203) post-quantum key +wrapping (`mlkem:768`, `mlkem:1024`) against a locally running OpenTDF +platform. Per algorithm it: + +1. Confirms the KAS publishes an ML-KEM PEM for that algorithm (`grpcurl` + pre-flight, optional). +2. Encrypts a small payload via the `cmdline` jar using + `--encap-key-type=MLKEM768Key` (or `MLKEM1024Key`). +3. Asserts the resulting TDF manifest has: + - `keyAccess[0].type == "wrapped"` (reuses the existing RSA-wrapped slot; + KAS disambiguates by registered key algorithm) + - `keyAccess[0].ephemeralPublicKey` empty + - `keyAccess[0].wrappedKey` decoded length equals + `ciphertextSize + 12 (nonce) + 32 (DEK) + 16 (GCM tag)` +4. Decrypts the TDF — exercises the KAS rewrap path with server-side ML-KEM + decapsulation. +5. Diffs the decrypted payload against the original. + +On success the script prints the plaintext, the full `keyAccess[0]` (KAO), +and the decrypted output for each algorithm. + +### Prerequisites + +| Requirement | Notes | +|---|---| +| **JDK 17** | The project's Kotlin compiler can't parse newer JDK version strings. Use Corretto/Temurin/etc. 17. On macOS: `export JAVA_HOME=$(/usr/libexec/java_home -v 17)`. | +| **Maven 3.9+** | Project uses standard `mvn clean install`. | +| **Buf token** | Proto generation requires auth. Either `buf registry login` once, or export `BUF_INPUT_HTTPS_USERNAME` / `BUF_INPUT_HTTPS_PASSWORD`. | +| **`non-fips` Maven profile (default)** | Pure ML-KEM needs `bcprov-jdk18on` at compile/runtime scope (no JDK 11 stdlib equivalent; the JCA KEM API is JDK 21+). The default `non-fips` profile pulls it in. The `fips` profile does not yet support ML-KEM — follow-up. | +| **Local platform with ML-KEM support** | `opentdf/platform` checked out on a branch with PR 3491 applied; `preview.mlkem_enabled: true` in `opentdf-dev.yaml`; an `mlkem:768` KAS key registered. ML-KEM-1024 has no Go reference at the time of writing — Java unit tests cover it but end-to-end is 768-only for now. | +| **CLI tools** | `java`, `mvn`, `unzip`, `jq` on `PATH`. `grpcurl` optional but recommended (drives the pre-flight check). | + +### Run it + +From the repo root: + +```bash +# Default — mlkem:768 only (Go KAS scope) +PLATFORM_ENDPOINT=http://localhost:8080 scripts/test-mlkem.sh + +# Skip rebuild on iterative runs +scripts/test-mlkem.sh --skip-build + +# Include 1024 (only works once Go KAS supports it) +scripts/test-mlkem.sh --algorithms MLKEM768Key,MLKEM1024Key + +# Skip the grpcurl pre-flight (use when grpcurl isn't installed) +scripts/test-mlkem.sh --skip-kas-check +``` + +### Configuration + +| Flag / Env | Default | Description | +|---|---|---| +| `--platform-endpoint` / `PLATFORM_ENDPOINT` | `http://localhost:8080` | Platform base URL | +| `--kas-url` / `KAS_URL` | same as platform endpoint | KAS URL passed to cmdline `encrypt` | +| `--client-id` / `CLIENT_ID` | `opentdf-sdk` | OIDC client id | +| `--client-secret` / `CLIENT_SECRET` | `secret` | OIDC client secret | +| `--attr` / `DATA_ATTR` | `https://example.com/attr/attr1/value/value1` | Attribute FQN attached to encrypt | +| `--algorithms` | `MLKEM768Key` | Comma-separated subset of `KeyType` enum names | +| `--skip-build` | (off) | Reuse `cmdline/target/cmdline.jar` | +| `--skip-kas-check` | (off) | Skip the `grpcurl` pre-flight | + +### Troubleshooting + +| Symptom | Likely cause / fix | +|---|---| +| `Maven build failed ... Buf API token` | Run `buf registry login`, or export `BUF_INPUT_HTTPS_USERNAME` and `BUF_INPUT_HTTPS_PASSWORD`. | +| `Maven build failed ... Kotlin ... isAtLeastJava9` (stack trace) | JDK too new. `export JAVA_HOME=$(/usr/libexec/java_home -v 17)` and rerun. | +| `KAS returned no publicKey` | Platform isn't running, isn't reachable at `$PLATFORM_ENDPOINT`, or `preview.mlkem_enabled` is off. | +| `KAS returned a non-ML-KEM PEM` | Platform is up but no ML-KEM KAS key is registered for that algorithm. Register one (e.g. via `otdfctl`) and rerun. | +| `type='null'` (manifest assertion) | You're on an old branch where `TDF.java` doesn't yet route ML-KEM algorithms. Pull the latest branch HEAD. | +| `decrypt failed` after manifest passes | KAS-side rewrap doesn't yet support ML-KEM, or the Go SDK's HKDF discrepancy (see Known SDK gap below) hasn't been resolved. | +| `wrappedKey decoded length N != expected M` | Wire format drift — likely your local Java SDK has a different layout than what was tested. | + +### Known SDK gap + +`KeyType.fromAlgorithm` and `KeyType.fromPublicKeyAlgorithm` +(`sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java`) don't yet map the +ML-KEM algorithm protobuf enums. Auto-discovery via the KAS registry +(`Config.KASInfo.fromKeyAccessServer`) will throw `IllegalArgumentException` +once the platform's proto definitions include `KAS_PUBLIC_KEY_ALG_ENUM_MLKEM_*` +values. This script bypasses that path by using `--encap-key-type` explicitly; +extending the script to also exercise registry-discovery should wait until +the mapping is added. + +### Reference + +- Go server: opentdf/platform PR 3491 (`lib/ocrypto/mlkem_key_pair.go`) +- Go SDK: opentdf/platform PR 3486 (`wrapKeyWithMLKEM`) +- Jira: DSPX-2399 diff --git a/scripts/test-mlkem.sh b/scripts/test-mlkem.sh new file mode 100755 index 00000000..7164fc46 --- /dev/null +++ b/scripts/test-mlkem.sh @@ -0,0 +1,252 @@ +#!/usr/bin/env bash +# +# test-mlkem.sh — round-trip the Java SDK's pure ML-KEM key wrapping (FIPS 203) +# against a locally running OpenTDF platform. +# +# Per algorithm: encrypt → assert manifest → KAS rewrap → decrypt → diff. +# Wire format: ct(1088 or 1568) || AES-GCM(nonce(12) || DEK(32) || tag(16)), +# base64'd into keyAccess.wrappedKey. keyAccess.type stays "wrapped" (NOT +# "hybrid-wrapped" — pure ML-KEM reuses the existing wrapped slot; the KAS +# disambiguates by the registered key's algorithm). +# +# Prereqs: +# * Local platform up at $PLATFORM_ENDPOINT on a branch with the ML-KEM PRs +# applied (opentdf/platform PR 3491 server + key registration) +# * preview.mlkem_enabled: true in opentdf-dev.yaml +# * An mlkem:768 KAS key registered (and mlkem:1024 if you test 1024 — Go +# KAS doesn't support 1024 yet at the time of writing) +# * java, mvn (JDK 17), unzip, jq on PATH +# * grpcurl optional (pre-flight key-publication check) +# +# Usage: +# scripts/test-mlkem.sh # 768 only (Go KAS scope) +# scripts/test-mlkem.sh --algorithms MLKEM768Key,MLKEM1024Key # both +# scripts/test-mlkem.sh --skip-build # reuse existing jar +# scripts/test-mlkem.sh --skip-kas-check # skip grpcurl pre-flight +# PLATFORM_ENDPOINT=http://localhost:8080 scripts/test-mlkem.sh + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +JAR="$REPO_ROOT/cmdline/target/cmdline.jar" + +PLATFORM_ENDPOINT="${PLATFORM_ENDPOINT:-http://localhost:8080}" +KAS_URL="${KAS_URL:-$PLATFORM_ENDPOINT}" +CLIENT_ID="${CLIENT_ID:-opentdf-sdk}" +CLIENT_SECRET="${CLIENT_SECRET:-secret}" +DATA_ATTR="${DATA_ATTR:-https://example.com/attr/attr1/value/value1}" +# Default to 768 only because Go KAS hasn't shipped 1024 support yet. +ALGORITHMS=(MLKEM768Key) +SKIP_BUILD=0 +SKIP_KAS_CHECK=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --skip-build) SKIP_BUILD=1; shift ;; + --skip-kas-check) SKIP_KAS_CHECK=1; shift ;; + --algorithms) IFS=, read -r -a ALGORITHMS <<< "$2"; shift 2 ;; + --platform-endpoint) PLATFORM_ENDPOINT="$2"; shift 2 ;; + --kas-url) KAS_URL="$2"; shift 2 ;; + --attr) DATA_ATTR="$2"; shift 2 ;; + --client-id) CLIENT_ID="$2"; shift 2 ;; + --client-secret) CLIENT_SECRET="$2"; shift 2 ;; + -h|--help) sed -n '2,/^$/p' "$0" | sed 's/^# \{0,1\}//'; exit 0 ;; + *) echo "unknown option: $1" >&2; exit 2 ;; + esac +done + +# Map KeyType enum name → (algorithm string, ciphertext size, PEM marker). +# Function form (case statement) for bash 3.2 (macOS system bash) compat. +alg_to_string() { + case "$1" in + MLKEM768Key) echo "mlkem:768" ;; + MLKEM1024Key) echo "mlkem:1024" ;; + *) return 1 ;; + esac +} +ciphertext_size() { + case "$1" in + MLKEM768Key) echo 1088 ;; + MLKEM1024Key) echo 1568 ;; + *) return 1 ;; + esac +} + +WORK_DIR="$(mktemp -d -t mlkem-pqc-XXXXXX)" +trap 'rm -rf "$WORK_DIR"' EXIT + +if [[ -t 1 ]]; then + GREEN=$'\033[0;32m'; RED=$'\033[0;31m'; YELLOW=$'\033[0;33m'; RESET=$'\033[0m' +else + GREEN=''; RED=''; YELLOW=''; RESET='' +fi +pass() { echo "${GREEN}[OK]${RESET} $*"; } +fail() { echo "${RED}[FAIL]${RESET} $*"; } +info() { echo "${YELLOW}[..]${RESET} $*"; } + +require() { command -v "$1" >/dev/null 2>&1 || { fail "missing required tool: $1"; exit 2; }; } +require java; require unzip; require jq +[[ $SKIP_BUILD -eq 1 ]] || require mvn + +run_cmdline() { + java -jar "$JAR" \ + --client-id="$CLIENT_ID" \ + --client-secret="$CLIENT_SECRET" \ + --platform-endpoint="$PLATFORM_ENDPOINT" \ + -h "$@" +} + +##### 1. Build +if [[ $SKIP_BUILD -eq 0 ]]; then + info "Building cmdline (mvn clean install -DskipTests)" + build_log="$WORK_DIR/build.log" + if ! (cd "$REPO_ROOT" && mvn --batch-mode clean install -DskipTests) > "$build_log" 2>&1; then + fail "Maven build failed. Tail of build log:" + tail -40 "$build_log" | sed 's/^/ /' + if grep -q "Buf API token" "$build_log" 2>/dev/null; then + fail "Hint: run 'buf registry login' or export BUF_INPUT_HTTPS_USERNAME / BUF_INPUT_HTTPS_PASSWORD before retrying." + fi + exit 1 + fi + pass "Build complete" +else + info "Skipping build (--skip-build)" +fi +[[ -f "$JAR" ]] || { fail "jar not found at $JAR — run without --skip-build"; exit 1; } + +##### 2. Pre-flight: confirm KAS publishes ML-KEM keys +if [[ $SKIP_KAS_CHECK -eq 0 ]] && command -v grpcurl >/dev/null 2>&1; then + info "Pre-flight: querying KAS for ML-KEM public keys" + host="${PLATFORM_ENDPOINT#http://}"; host="${host#https://}" + for alg_name in "${ALGORITHMS[@]}"; do + if ! alg=$(alg_to_string "$alg_name"); then + fail "unknown algorithm: $alg_name"; exit 2 + fi + resp=$(grpcurl -plaintext -d "{\"algorithm\":\"$alg\"}" \ + "$host" kas.AccessService/PublicKey 2>&1 || true) + pem=$(jq -r '.publicKey // empty' <<<"$resp" 2>/dev/null || true) + if [[ -z "$pem" ]]; then + fail "$alg: KAS returned no publicKey. Response was:" + echo "$resp" | head -5 | sed 's/^/ /' + fail "Is the platform running with preview.mlkem_enabled=true and the key registered?" + exit 1 + fi + first_line=$(echo "$pem" | head -1) + if [[ "$first_line" != *"ML-KEM"* ]]; then + fail "$alg: KAS returned a non-ML-KEM PEM (first line: $first_line)" + exit 1 + fi + pass "$alg: KAS returns ML-KEM PEM ($first_line)" + done +else + info "Skipping KAS pre-flight check" +fi + +##### 3. Round-trip each algorithm +PAYLOAD="$WORK_DIR/payload" +printf 'pure ml-kem round-trip payload @ %s\n' "$(date)" > "$PAYLOAD" +PAYLOAD_BYTES=$(wc -c < "$PAYLOAD" | tr -d ' ') +info "Test payload: $PAYLOAD_BYTES bytes" +echo " --- plaintext ---" +sed 's/^/ /' < "$PAYLOAD" +echo " --- end plaintext ---" + +failures=() +for alg_name in "${ALGORITHMS[@]}"; do + ct_size=$(ciphertext_size "$alg_name") + # Expected wrappedKey decoded length: ct || nonce(12) || DEK(32) || tag(16) + expected_wk_len=$((ct_size + 12 + 32 + 16)) + + tdf="$WORK_DIR/test-${alg_name}.tdf" + out="$WORK_DIR/out-${alg_name}" + enc_log="$WORK_DIR/encrypt-${alg_name}.log" + dec_log="$WORK_DIR/decrypt-${alg_name}.log" + + info "[$alg_name] encrypt" + if ! run_cmdline encrypt \ + --kas-url="$KAS_URL" \ + --mime-type=text/plain \ + --attr="$DATA_ATTR" \ + --autoconfigure=false \ + --encap-key-type="$alg_name" \ + -f "$PAYLOAD" > "$tdf" 2> "$enc_log"; then + fail "$alg_name: encrypt failed" + sed 's/^/ /' < "$enc_log" + failures+=("$alg_name (encrypt)") + continue + fi + + info "[$alg_name] verify manifest" + manifest_entry=$(unzip -l "$tdf" 2>/dev/null | awk '/manifest\.json$/ {print $NF; exit}') + if [[ -z "$manifest_entry" ]]; then + fail "$alg_name: no manifest.json entry inside $tdf" + failures+=("$alg_name (manifest entry missing)") + continue + fi + manifest=$(unzip -p "$tdf" "$manifest_entry") + # JSON key is "type" (Manifest.keyType is @SerializedName("type")). + type=$(jq -r '.encryptionInformation.keyAccess[0].type' <<<"$manifest") + ephem=$(jq -r '.encryptionInformation.keyAccess[0].ephemeralPublicKey // ""' <<<"$manifest") + wrapped=$(jq -r '.encryptionInformation.keyAccess[0].wrappedKey // ""' <<<"$manifest") + + if [[ "$type" != "wrapped" ]]; then + fail "$alg_name: type='$type' (expected 'wrapped' — ML-KEM reuses the RSA-wrapped slot)" + echo " keyAccess[0]:" + jq '.encryptionInformation.keyAccess[0]' <<<"$manifest" 2>/dev/null | sed 's/^/ /' + failures+=("$alg_name (bad type: $type)") + continue + fi + if [[ -n "$ephem" ]]; then + fail "$alg_name: ephemeralPublicKey unexpectedly set ('$ephem')" + failures+=("$alg_name (stray ephemeralPublicKey)") + continue + fi + if [[ -z "$wrapped" ]]; then + fail "$alg_name: wrappedKey is empty" + failures+=("$alg_name (empty wrappedKey)") + continue + fi + actual_wk_len=$(base64 -d <<<"$wrapped" 2>/dev/null | wc -c | tr -d ' ') + if [[ "$actual_wk_len" != "$expected_wk_len" ]]; then + fail "$alg_name: wrappedKey decoded length $actual_wk_len != expected $expected_wk_len (ct=$ct_size + nonce(12) + DEK(32) + tag(16))" + failures+=("$alg_name (bad wrappedKey length: $actual_wk_len)") + continue + fi + pass "$alg_name: manifest OK (type=wrapped, length=$actual_wk_len, no ephemeralPublicKey)" + echo " --- keyAccess[0] (KAO) ---" + jq '.encryptionInformation.keyAccess[0]' <<<"$manifest" | sed 's/^/ /' + echo " --- end keyAccess[0] ---" + + info "[$alg_name] decrypt (rewrap via KAS)" + if ! run_cmdline decrypt -f "$tdf" > "$out" 2> "$dec_log"; then + fail "$alg_name: decrypt failed" + sed 's/^/ /' < "$dec_log" + failures+=("$alg_name (decrypt)") + continue + fi + if ! diff -q "$PAYLOAD" "$out" >/dev/null; then + fail "$alg_name: decrypted payload differs from original" + echo " --- expected ---" + head -c 200 "$PAYLOAD" | sed 's/^/ /'; echo + echo " --- got ---" + head -c 200 "$out" | sed 's/^/ /'; echo + failures+=("$alg_name (payload mismatch)") + continue + fi + pass "$alg_name: round-trip OK" + out_bytes=$(wc -c < "$out" | tr -d ' ') + echo " --- decrypted ($out_bytes bytes) ---" + sed 's/^/ /' < "$out" + echo " --- end decrypted ---" +done + +echo +if [[ ${#failures[@]} -eq 0 ]]; then + echo "${GREEN}All ${#ALGORITHMS[@]} ML-KEM algorithm(s) passed round-trip.${RESET}" + exit 0 +else + echo "${RED}FAILURES (${#failures[@]}):${RESET}" + printf ' - %s\n' "${failures[@]}" + exit 1 +fi diff --git a/sdk/pom.xml b/sdk/pom.xml index de6b7618..8945001a 100644 --- a/sdk/pom.xml +++ b/sdk/pom.xml @@ -483,6 +483,12 @@ true + + + org.bouncycastle + bcprov-jdk18on + org.bouncycastle bcpkix-jdk18on diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java b/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java index 06be4cf6..80291584 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/KeyType.java @@ -14,7 +14,9 @@ public enum KeyType { RSA4096Key("rsa:4096"), EC256Key("ec:secp256r1", SECP256R1), EC384Key("ec:secp384r1", SECP384R1), - EC521Key("ec:secp521r1", SECP521R1); + EC521Key("ec:secp521r1", SECP521R1), + MLKEM768Key("mlkem:768"), + MLKEM1024Key("mlkem:1024"); private final String keyType; private final ECCurve curve; @@ -93,4 +95,14 @@ public static KeyType fromPublicKeyAlgorithm(KasPublicKeyAlgEnum algorithm) { public boolean isEc() { return this.curve != null; } + + public boolean isMLKEM() { + switch (this) { + case MLKEM768Key: + case MLKEM1024Key: + return true; + default: + return false; + } + } } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/MLKEMKeyPair.java b/sdk/src/main/java/io/opentdf/platform/sdk/MLKEMKeyPair.java new file mode 100644 index 00000000..04d95548 --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/MLKEMKeyPair.java @@ -0,0 +1,214 @@ +package io.opentdf.platform.sdk; + +import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import org.bouncycastle.crypto.SecretWithEncapsulation; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMExtractor; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMGenerator; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMKeyGenerationParameters; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMKeyPairGenerator; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMParameters; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters; + +import java.security.SecureRandom; +import java.util.Arrays; +import java.util.Base64; + +/** + * Pure ML-KEM (FIPS 203) key encapsulation for DEK wrapping. + * + * Wire format of the {@code wrappedKey} field (after base64 decode): + *
+ *   mlkem_ciphertext (1088 bytes for 768; 1568 for 1024)
+ *   || aes_gcm_nonce (12 bytes)
+ *   || aes_gcm_ciphertext_and_tag
+ * 
+ * + * Wrap key derivation: {@code HKDF-SHA256(ikm = mlkem_shared_secret, + * salt = SHA-256("TDF"), info = empty, L = 32)} via + * {@link ECKeyPair#calculateHKDF(byte[], byte[])}. + * + * The KAO uses {@code type == "wrapped"} (the default; no override needed) + * and leaves {@code ephemeralPublicKey} empty — the KAS disambiguates from + * RSA-wrapped by looking up the registered key's algorithm. + * + * BouncyCastle is required for the ML-KEM primitives because JDK 11 stdlib + * has no KEM API (added in JDK 21). + */ +final class MLKEMKeyPair { + + static final MLKEMKeyPair MLKEM_768 = new MLKEMKeyPair( + MLKEMParameters.ml_kem_768, + "ML-KEM-768 PUBLIC KEY", + "ML-KEM-768 PRIVATE KEY", + /* publicKeySize */ 1184, + /* ciphertextSize */ 1088, + KeyType.MLKEM768Key); + + static final MLKEMKeyPair MLKEM_1024 = new MLKEMKeyPair( + MLKEMParameters.ml_kem_1024, + "ML-KEM-1024 PUBLIC KEY", + "ML-KEM-1024 PRIVATE KEY", + /* publicKeySize */ 1568, + /* ciphertextSize */ 1568, + KeyType.MLKEM1024Key); + + /** FIPS 203 seed (d || z) — same 64 bytes for both 768 and 1024. */ + static final int SEED_SIZE = 64; + static final int SHARED_SECRET_SIZE = 32; + + private final MLKEMParameters mlkemParams; + private final String pubPemBlock; + private final String privPemBlock; + private final int publicKeySize; + private final int ciphertextSize; + private final KeyType keyType; + + private final byte[] publicKey; + private final byte[] privateKey; + + private MLKEMKeyPair(MLKEMParameters mlkemParams, String pubPemBlock, String privPemBlock, + int publicKeySize, int ciphertextSize, KeyType keyType) { + this.mlkemParams = mlkemParams; + this.pubPemBlock = pubPemBlock; + this.privPemBlock = privPemBlock; + this.publicKeySize = publicKeySize; + this.ciphertextSize = ciphertextSize; + this.keyType = keyType; + this.publicKey = null; + this.privateKey = null; + } + + private MLKEMKeyPair(MLKEMKeyPair params, byte[] publicKey, byte[] privateKey) { + this.mlkemParams = params.mlkemParams; + this.pubPemBlock = params.pubPemBlock; + this.privPemBlock = params.privPemBlock; + this.publicKeySize = params.publicKeySize; + this.ciphertextSize = params.ciphertextSize; + this.keyType = params.keyType; + this.publicKey = publicKey; + this.privateKey = privateKey; + } + + static MLKEMKeyPair forKeyType(KeyType kt) { + switch (kt) { + case MLKEM768Key: return MLKEM_768; + case MLKEM1024Key: return MLKEM_1024; + default: throw new SDKException("not an ML-KEM key type: " + kt); + } + } + + int publicKeySize() { return publicKeySize; } + int ciphertextSize() { return ciphertextSize; } + KeyType keyType() { return keyType; } + + MLKEMKeyPair generate() { + SecureRandom random = new SecureRandom(); + MLKEMKeyPairGenerator gen = new MLKEMKeyPairGenerator(); + gen.init(new MLKEMKeyGenerationParameters(random, mlkemParams)); + AsymmetricCipherKeyPair kp = gen.generateKeyPair(); + byte[] pub = ((MLKEMPublicKeyParameters) kp.getPublic()).getEncoded(); + byte[] seed = ((MLKEMPrivateKeyParameters) kp.getPrivate()).getSeed(); + if (pub.length != publicKeySize) { + throw new SDKException("ML-KEM public key size " + pub.length + " != expected " + publicKeySize); + } + if (seed.length != SEED_SIZE) { + throw new SDKException("ML-KEM seed size " + seed.length + " != expected " + SEED_SIZE); + } + return new MLKEMKeyPair(this, pub, seed); + } + + String publicKeyInPemFormat() { + return rawToPem(pubPemBlock, publicKey, publicKeySize); + } + + String privateKeyInPemFormat() { + return rawToPem(privPemBlock, privateKey, SEED_SIZE); + } + + byte[] getPublicKey() { return publicKey == null ? null : publicKey.clone(); } + byte[] getPrivateKey() { return privateKey == null ? null : privateKey.clone(); } + + byte[] pubKeyFromPem(String pem) { + return decodeSizedPemBlock(pem, pubPemBlock, publicKeySize); + } + + byte[] privateKeyFromPem(String pem) { + return decodeSizedPemBlock(pem, privPemBlock, SEED_SIZE); + } + + /** + * Encapsulate against {@code rawPub} (an ML-KEM encapsulation key) and AES-256-GCM + * wrap the {@code dek}. Returns the raw, un-base64'd blob: ciphertext || AES-GCM(nonce||ct||tag). + */ + byte[] wrapDEK(byte[] rawPub, byte[] dek) { + if (rawPub.length != publicKeySize) { + throw new SDKException("invalid " + keyType + " public key size: got " + rawPub.length + " want " + publicKeySize); + } + MLKEMPublicKeyParameters pub = new MLKEMPublicKeyParameters(mlkemParams, rawPub); + SecretWithEncapsulation enc = new MLKEMGenerator(new SecureRandom()).generateEncapsulated(pub); + byte[] sharedSecret = enc.getSecret(); + byte[] ciphertext = enc.getEncapsulation(); + if (ciphertext.length != ciphertextSize) { + throw new SDKException("ML-KEM ciphertext size " + ciphertext.length + " != expected " + ciphertextSize); + } + + byte[] wrapKey = ECKeyPair.calculateHKDF(TDF.GLOBAL_KEY_SALT, sharedSecret); + byte[] encryptedDek = new AesGcm(wrapKey).encrypt(dek).asBytes(); + byte[] out = new byte[ciphertextSize + encryptedDek.length]; + System.arraycopy(ciphertext, 0, out, 0, ciphertextSize); + System.arraycopy(encryptedDek, 0, out, ciphertextSize, encryptedDek.length); + return out; + } + + /** + * Inverse of {@link #wrapDEK(byte[], byte[])}. Used by unit tests and any future + * client-side decap path; the production decrypt flow defers unwrap to the KAS. + */ + byte[] unwrapDEK(byte[] rawPriv, byte[] wrappedBlob) { + if (rawPriv.length != SEED_SIZE) { + throw new SDKException("invalid " + keyType + " private key seed size: got " + rawPriv.length + " want " + SEED_SIZE); + } + if (wrappedBlob.length <= ciphertextSize) { + throw new SDKException(keyType + " wrapped blob too short: got " + wrappedBlob.length + + ", need > " + ciphertextSize); + } + byte[] ciphertext = Arrays.copyOfRange(wrappedBlob, 0, ciphertextSize); + byte[] encryptedDek = Arrays.copyOfRange(wrappedBlob, ciphertextSize, wrappedBlob.length); + + MLKEMPrivateKeyParameters priv = new MLKEMPrivateKeyParameters(mlkemParams, rawPriv); + byte[] sharedSecret = new MLKEMExtractor(priv).extractSecret(ciphertext); + byte[] wrapKey = ECKeyPair.calculateHKDF(TDF.GLOBAL_KEY_SALT, sharedSecret); + return new AesGcm(wrapKey).decrypt(new AesGcm.Encrypted(encryptedDek)); + } + + private static String rawToPem(String blockType, byte[] raw, int expectedSize) { + if (raw == null || raw.length != expectedSize) { + throw new SDKException("invalid " + blockType + " size: got " + (raw == null ? -1 : raw.length) + + " want " + expectedSize); + } + String b64 = Base64.getMimeEncoder(64, new byte[] { '\n' }).encodeToString(raw); + return "-----BEGIN " + blockType + "-----\n" + b64 + "\n-----END " + blockType + "-----\n"; + } + + private static byte[] decodeSizedPemBlock(String pem, String expectedType, int expectedSize) { + String header = "-----BEGIN " + expectedType + "-----"; + String footer = "-----END " + expectedType + "-----"; + int headerIdx = pem.indexOf(header); + int footerIdx = pem.indexOf(footer); + if (headerIdx < 0 || footerIdx < 0 || footerIdx <= headerIdx) { + throw new SDKException("failed to parse PEM formatted " + expectedType); + } + String body = pem.substring(headerIdx + header.length(), footerIdx).replaceAll("\\s", ""); + byte[] raw; + try { + raw = Base64.getDecoder().decode(body); + } catch (IllegalArgumentException e) { + throw new SDKException("failed to base64-decode " + expectedType + " PEM body", e); + } + if (raw.length != expectedSize) { + throw new SDKException("invalid " + expectedType + " size: got " + raw.length + " want " + expectedSize); + } + return raw; + } +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java index b30460eb..151fa82c 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java @@ -226,7 +226,15 @@ private Manifest.KeyAccess createKeyAccess(Config.TDFConfig tdfConfig, Config.KA : kasInfo.Algorithm; var keyType = KeyType.fromString(algorithm); - if (keyType.isEc()) { + if (keyType.isMLKEM()) { + // Pure ML-KEM: ct(1088/1568) || AES-GCM(nonce(12)||ct||tag(16)), + // base64'd into wrappedKey. KAO type stays "wrapped"; KAS + // disambiguates from RSA by the registered key's algorithm. + var mlkem = MLKEMKeyPair.forKeyType(keyType); + byte[] wrapped = mlkem.wrapDEK(mlkem.pubKeyFromPem(kasInfo.PublicKey), symKey); + keyAccess.wrappedKey = Base64.getEncoder().encodeToString(wrapped); + keyAccess.keyType = kWrapped; + } else if (keyType.isEc()) { var ecKeyWrappedKeyInfo = createECWrappedKey(kasInfo, symKey, keyType); keyAccess.wrappedKey = ecKeyWrappedKeyInfo.wrappedKey; keyAccess.ephemeralPublicKey = ecKeyWrappedKeyInfo.publicKey; diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/MLKEMKeyPairTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/MLKEMKeyPairTest.java new file mode 100644 index 00000000..db3ab1fe --- /dev/null +++ b/sdk/src/test/java/io/opentdf/platform/sdk/MLKEMKeyPairTest.java @@ -0,0 +1,110 @@ +package io.opentdf.platform.sdk; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Unit tests for pure ML-KEM key wrapping. Every test runs once per parameter + * set (ML-KEM-768 and ML-KEM-1024) so 1024 stays exercised even though no Go + * KAS reference exists for it yet. + * + * Exercises the full producer/consumer path locally: generate keypair → PEM + * round-trip → wrap a known DEK → unwrap with matching private key. If the + * wire format drifts (e.g. someone re-orders ciphertext and encrypted-DEK, + * or changes HKDF salt), the round-trip fails. + */ +class MLKEMKeyPairTest { + + private static final byte[] DEK = "0123456789abcdef0123456789abcdef".getBytes(); + private static final int AES_GCM_OVERHEAD = 12 + 16; // 12-byte nonce + 16-byte tag + + static Stream algorithms() { + return Stream.of(MLKEMKeyPair.MLKEM_768, MLKEMKeyPair.MLKEM_1024); + } + + @ParameterizedTest + @MethodSource("algorithms") + void roundTrip(MLKEMKeyPair alg) { + MLKEMKeyPair kp = alg.generate(); + + // PEM round-trip preserves both halves + String pubPem = kp.publicKeyInPemFormat(); + String privPem = kp.privateKeyInPemFormat(); + assertTrue(pubPem.startsWith("-----BEGIN " + headerFor(alg, "PUBLIC") + "-----"), "public PEM header"); + assertTrue(privPem.contains(headerFor(alg, "PRIVATE")), "private PEM header"); + + byte[] rawPub = alg.pubKeyFromPem(pubPem); + byte[] rawPriv = alg.privateKeyFromPem(privPem); + assertEquals(alg.publicKeySize(), rawPub.length); + assertEquals(MLKEMKeyPair.SEED_SIZE, rawPriv.length); + + // Wrap → expected blob length is ct + AES-GCM(12+|DEK|+16) + byte[] wrapped = alg.wrapDEK(rawPub, DEK); + assertNotNull(wrapped); + assertEquals(alg.ciphertextSize() + DEK.length + AES_GCM_OVERHEAD, wrapped.length, + "wrapped blob must be ct || nonce||ct||tag"); + + // Unwrap recovers the original DEK + byte[] unwrapped = alg.unwrapDEK(rawPriv, wrapped); + assertArrayEquals(DEK, unwrapped); + } + + @ParameterizedTest + @MethodSource("algorithms") + void wrapRejectsWrongSizePublicKey(MLKEMKeyPair alg) { + byte[] tooShort = new byte[alg.publicKeySize() - 1]; + SDKException ex = assertThrows(SDKException.class, () -> alg.wrapDEK(tooShort, DEK)); + assertTrue(ex.getMessage().contains("public key size"), ex.getMessage()); + } + + @ParameterizedTest + @MethodSource("algorithms") + void unwrapRejectsWrongSizePrivateKey(MLKEMKeyPair alg) { + MLKEMKeyPair kp = alg.generate(); + byte[] wrapped = alg.wrapDEK(kp.getPublicKey(), DEK); + byte[] badPriv = new byte[MLKEMKeyPair.SEED_SIZE - 1]; + SDKException ex = assertThrows(SDKException.class, () -> alg.unwrapDEK(badPriv, wrapped)); + assertTrue(ex.getMessage().contains("seed size"), ex.getMessage()); + } + + @ParameterizedTest + @MethodSource("algorithms") + void unwrapRejectsShortBlob(MLKEMKeyPair alg) { + MLKEMKeyPair kp = alg.generate(); + byte[] tooShort = new byte[alg.ciphertextSize()]; // no room for the AES-GCM-wrapped DEK + SDKException ex = assertThrows(SDKException.class, () -> alg.unwrapDEK(kp.getPrivateKey(), tooShort)); + assertTrue(ex.getMessage().contains("too short"), ex.getMessage()); + } + + @ParameterizedTest + @MethodSource("algorithms") + void tamperedCiphertextFailsAesGcmTag(MLKEMKeyPair alg) { + MLKEMKeyPair kp = alg.generate(); + byte[] wrapped = alg.wrapDEK(kp.getPublicKey(), DEK); + // Flip a bit inside the AES-GCM-wrapped DEK section — must fail the tag check + wrapped[wrapped.length - 1] ^= 0x01; + assertThrows(Exception.class, () -> alg.unwrapDEK(kp.getPrivateKey(), wrapped)); + } + + @Test + void forKeyTypeDispatchesCorrectly() { + assertEquals(MLKEMKeyPair.MLKEM_768, MLKEMKeyPair.forKeyType(KeyType.MLKEM768Key)); + assertEquals(MLKEMKeyPair.MLKEM_1024, MLKEMKeyPair.forKeyType(KeyType.MLKEM1024Key)); + assertThrows(SDKException.class, () -> MLKEMKeyPair.forKeyType(KeyType.RSA2048Key)); + } + + private static String headerFor(MLKEMKeyPair alg, String half) { + return alg == MLKEMKeyPair.MLKEM_768 + ? "ML-KEM-768 " + half + " KEY" + : "ML-KEM-1024 " + half + " KEY"; + } +} diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/TDFMLKEMTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/TDFMLKEMTest.java new file mode 100644 index 00000000..16365642 --- /dev/null +++ b/sdk/src/test/java/io/opentdf/platform/sdk/TDFMLKEMTest.java @@ -0,0 +1,118 @@ +package io.opentdf.platform.sdk; + +import com.connectrpc.ResponseMessage; +import com.connectrpc.UnaryBlockingCall; +import io.opentdf.platform.policy.KeyAccessServer; +import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClient; +import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersRequest; +import io.opentdf.platform.policy.kasregistry.ListKeyAccessServersResponse; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.Base64; +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Producer-side integration test for pure ML-KEM key access. Drives + * {@link TDF#createTDF} with each ML-KEM {@link KeyType}; asserts the resulting + * manifest matches the wire format that the Go KAS expects (mirrors + * {@code lib/ocrypto/mlkem_key_pair.go} in opentdf/platform PR 3491): + *
    + *
  • {@code keyAccess.type == "wrapped"} (reused from the RSA slot)
  • + *
  • {@code ephemeralPublicKey} is null/empty
  • + *
  • {@code wrappedKey} decoded length == {@code ciphertextSize + 12 + 32 + 16}
  • + *
+ * Round-trips the wrappedKey through the matching private key locally to + * confirm the producer/consumer wire format agrees. + */ +class TDFMLKEMTest { + + private static KeyAccessServerRegistryServiceClient kasRegistryService; + private static final String KAS_URL = "https://kas.example.com"; + + @BeforeAll + static void setupRegistryMock() { + kasRegistryService = mock(KeyAccessServerRegistryServiceClient.class); + ListKeyAccessServersResponse response = ListKeyAccessServersResponse.newBuilder() + .addKeyAccessServers(KeyAccessServer.newBuilder().setUri(KAS_URL).build()) + .build(); + when(kasRegistryService.listKeyAccessServersBlocking(any(ListKeyAccessServersRequest.class), any())) + .thenReturn(new UnaryBlockingCall<>() { + @Override + public ResponseMessage execute() { + return new ResponseMessage.Success<>(response, + Collections.emptyMap(), Collections.emptyMap()); + } + @Override public void cancel() {} + }); + } + + @ParameterizedTest + @EnumSource(value = KeyType.class, names = {"MLKEM768Key", "MLKEM1024Key"}) + void createTDFProducesMLKEMWrappedKeyAccess(KeyType keyType) throws Exception { + MLKEMKeyPair alg = MLKEMKeyPair.forKeyType(keyType); + MLKEMKeyPair kp = alg.generate(); + String pubPem = kp.publicKeyInPemFormat(); + byte[] privSeed = alg.privateKeyFromPem(kp.privateKeyInPemFormat()); + + // Stub the SDK.KAS so getPublicKey returns the generated ML-KEM PEM. + // No unwrap mock is needed because this test never calls loadTDF. + SDK.KAS kasStub = new SDK.KAS() { + @Override public void close() {} + @Override public Config.KASInfo getPublicKey(Config.KASInfo kasInfo) { + Config.KASInfo out = new Config.KASInfo(); + out.URL = kasInfo.URL; + out.KID = "mlkem-test-kid"; + out.Algorithm = keyType.toString(); + out.PublicKey = pubPem; + return out; + } + @Override public byte[] unwrap(Manifest.KeyAccess ka, String policy, KeyType skt) { + throw new UnsupportedOperationException("not used in this test"); + } + @Override public KASKeyCache getKeyCache() { return new KASKeyCache(); } + }; + + Config.KASInfo kasInfo = new Config.KASInfo(); + kasInfo.URL = KAS_URL; + + Config.TDFConfig cfg = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(kasInfo), + Config.WithWrappingKeyAlg(keyType)); + + TDF tdf = new TDF(new FakeServicesBuilder() + .setKas(kasStub) + .setKeyAccessServerRegistryService(kasRegistryService) + .build()); + + ByteArrayOutputStream tdfOut = new ByteArrayOutputStream(); + var manifest = tdf.createTDF( + new ByteArrayInputStream("ml-kem round-trip payload".getBytes()), + tdfOut, cfg).getManifest(); + + assertThat(manifest.encryptionInformation.keyAccessObj).hasSize(1); + Manifest.KeyAccess ka = manifest.encryptionInformation.keyAccessObj.get(0); + + // Wire-format invariants the KAS depends on. + assertThat(ka.keyType).isEqualTo("wrapped"); // serialized as JSON "type" + assertThat(ka.ephemeralPublicKey).isNullOrEmpty(); + assertThat(ka.wrappedKey).isNotEmpty(); + + // wrappedKey layout: mlkem_ct || AES-GCM(nonce(12) || DEK(32) || tag(16)) + byte[] wrappedBytes = Base64.getDecoder().decode(ka.wrappedKey); + assertThat(wrappedBytes.length).isEqualTo(alg.ciphertextSize() + 12 + 32 + 16); + + // Round-trip the wrappedKey through the matching private key. + byte[] symKey = alg.unwrapDEK(privSeed, wrappedBytes); + assertThat(symKey).hasSize(32); + } +}