diff --git a/.github/workflows/dependency-submission.yml b/.github/workflows/dependency-submission.yml index c1491141a..5fa23850a 100644 --- a/.github/workflows/dependency-submission.yml +++ b/.github/workflows/dependency-submission.yml @@ -18,5 +18,9 @@ jobs: with: distribution: 'temurin' java-version: 17 + - name: Install Rust toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + target: wasm32-unknown-unknown - name: Generate and submit dependency graph uses: gradle/actions/dependency-submission@v4 \ No newline at end of file diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index f69e02065..cfaf65bbd 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -32,6 +32,11 @@ jobs: - name: Setup Gradle uses: gradle/actions/setup-gradle@v4 + - name: Install Rust toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + target: wasm32-unknown-unknown + - name: Log into GitHub container registry uses: docker/login-action@v2 with: diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index df36b392d..8c1d08bf5 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -123,6 +123,12 @@ jobs: if: ${{ inputs.serviceImage == '' }} uses: gradle/actions/setup-gradle@v4 + - name: Install Rust toolchain + if: ${{ inputs.serviceImage == '' }} + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + target: wasm32-unknown-unknown + - name: Build restatedev/test-services-java image if: ${{ inputs.serviceImage == '' }} run: ./gradlew -Djib.console=plain :test-services:jibDockerBuild diff --git a/.github/workflows/release-docs.yml b/.github/workflows/release-docs.yml index b6ebdc5e9..23e043687 100644 --- a/.github/workflows/release-docs.yml +++ b/.github/workflows/release-docs.yml @@ -31,6 +31,10 @@ jobs: java-version: '21' - name: Setup Gradle uses: gradle/actions/setup-gradle@v4 + - name: Install Rust toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + target: wasm32-unknown-unknown - name: Build Javadocs run: gradle :sdk-aggregated-javadocs:javadoc diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2a8f09e8d..e0cb542b6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,6 +20,11 @@ jobs: - name: Setup Gradle uses: gradle/actions/setup-gradle@v4 + - name: Install Rust toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + target: wasm32-unknown-unknown + # Retrieve the version of the SDK - name: Install dasel run: curl -sSLf "$(curl -sSLf https://api.github.com/repos/tomwright/dasel/releases/latest | grep browser_download_url | grep linux_amd64 | grep -v .gz | cut -d\" -f 4)" -L -o dasel && chmod +x dasel && mv ./dasel /usr/local/bin/dasel diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5fdaeb1a1..f1c138ede 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,6 +27,11 @@ jobs: - name: Setup Gradle uses: gradle/actions/setup-gradle@v4 + - name: Install Rust toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + target: wasm32-unknown-unknown + - name: Pull Restate docker image run: docker pull ghcr.io/restatedev/restate:main @@ -54,6 +59,10 @@ jobs: java-version: '21' - name: Setup Gradle uses: gradle/actions/setup-gradle@v4 + - name: Install Rust toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + target: wasm32-unknown-unknown - name: Build Javadocs run: gradle :sdk-aggregated-javadocs:javadoc diff --git a/.gitignore b/.gitignore index 8eda7e77d..4a7221e7e 100644 --- a/.gitignore +++ b/.gitignore @@ -37,4 +37,5 @@ build kls_database.db .kotlin -.restate \ No newline at end of file +.restate +/sdk-core/src/main/rust/target/ diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 0085c5e2d..21cf25dd4 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -117,18 +117,6 @@ [libraries.opentelemetry-kotlin.version] ref = 'opentelemetry' - [libraries.protobuf-java] - module = 'com.google.protobuf:protobuf-java' - - [libraries.protobuf-java.version] - ref = 'protobuf' - - [libraries.protobuf-kotlin] - module = 'com.google.protobuf:protobuf-kotlin' - - [libraries.protobuf-kotlin.version] - ref = 'protobuf' - [libraries.schema-kenerator-core] module = 'io.github.smiley4:schema-kenerator-core' @@ -213,7 +201,32 @@ [libraries.victools-jsonschema-module-jackson.version] ref = 'victools-json-schema' + [libraries.chicory-runtime] + module = 'com.dylibso.chicory:runtime' + + [libraries.chicory-runtime.version] + ref = 'chicory' + + [libraries.chicory-annotations] + module = 'com.dylibso.chicory:annotations' + + [libraries.chicory-annotations.version] + ref = 'chicory' + + [libraries.chicory-annotations-processor] + module = 'com.dylibso.chicory:annotations-processor' + + [libraries.chicory-annotations-processor.version] + ref = 'chicory' + + [libraries.jackson-cbor] + module = 'com.fasterxml.jackson.dataformat:jackson-dataformat-cbor' + + [libraries.jackson-cbor.version] + ref = 'jackson' + [plugins] + wasm2class = 'at.released.wasm2class.plugin:0.5.0' aggregate-javadoc = 'io.freefair.aggregate-javadoc:8.14' dependency-license-report = 'com.github.jk1.dependency-license-report:2.9' dokka = 'org.jetbrains.dokka:1.9.20' @@ -221,7 +234,6 @@ jsonschema2pojo = 'org.jsonschema2pojo:1.2.2' nexus-publish = 'io.github.gradle-nexus.publish-plugin:1.3.0' openapi-generator = 'org.openapi.generator:7.17.0' - protobuf = 'com.google.protobuf:0.9.4' shadow = 'com.gradleup.shadow:9.0.0-beta8' spotless = 'com.diffplug.spotless:7.2.1' spring-dependency-management = 'io.spring.dependency-management:1.1.6' @@ -233,6 +245,7 @@ ref = 'ksp' [versions] + chicory = '1.7.5' jackson = '2.19.4' junit = '5.14.1' kotlinx-coroutines = '1.10.2' @@ -240,7 +253,6 @@ ksp = '2.2.10-2.0.2' log4j = '2.24.3' opentelemetry = '1.58.0' - protobuf = '4.29.3' restate = '2.8.0-SNAPSHOT' schema-kenerator = '2.1.2' spring-boot = '3.4.13' diff --git a/sdk-core/build.gradle.kts b/sdk-core/build.gradle.kts index 87ce62e49..f1b67839e 100644 --- a/sdk-core/build.gradle.kts +++ b/sdk-core/build.gradle.kts @@ -8,9 +8,9 @@ plugins { `kotlin-conventions` `library-publishing-conventions` alias(libs.plugins.jsonschema2pojo) - alias(libs.plugins.protobuf) - alias(libs.plugins.shadow) alias(libs.plugins.ksp) + // Chicory AOT: compile .wasm to JVM bytecode at build time + id("at.released.wasm2class.plugin") version "0.5.0" // https://github.com/gradle/gradle/issues/20084#issuecomment-1060822638 id(libs.plugins.spotless.get().pluginId) apply false @@ -18,6 +18,97 @@ plugins { description = "Restate SDK Core" +// --------------------------------------------------------------------------- +// Rust WASM build pipeline (merged from sdk-shared-core) +// --------------------------------------------------------------------------- + +val rustSrcDir = file("src/main/rust") +val wasmReleaseDir = file("$rustSrcDir/target/wasm32-unknown-unknown/release") +val wasmFile = file("$wasmReleaseDir/restate_sdk_shared_core_wasm.wasm") +val wasmResourceDir = layout.buildDirectory.dir("wasm-resource") +val wasmResourceFile = wasmResourceDir.map { it.file("restate_sdk_shared_core_wasm.wasm") } + +val compileRustToWasm by + tasks.registering(Exec::class) { + group = "build" + description = "Compile the Rust WASM wrapper crate" + workingDir = rustSrcDir + commandLine("cargo", "build", "--target", "wasm32-unknown-unknown", "--release") + inputs.dir("$rustSrcDir/src") + inputs.file("$rustSrcDir/Cargo.toml") + outputs.file(wasmFile) + } + +val copyWasm by + tasks.registering(Copy::class) { + group = "build" + dependsOn(compileRustToWasm) + from(wasmFile) + into(wasmResourceDir) + } + +// Chicory AOT: compile .wasm → JVM bytecode +wasm2class { + modules { + targetPackage = "dev.restate.sdk.core.sharedcore.generated" + create("SharedCoreWasm") { wasm = wasmResourceFile } + } +} + +tasks.named("precompileWasm2Class") { dependsOn(copyWasm) } + +tasks.named("processResources") { dependsOn(copyWasm) } + +tasks.withType().configureEach { + mustRunAfter(generateWasmMarker, "precompileWasm2Class") +} + +// Generate the @WasmModuleInterface marker class with the absolute wasm file URI so the +// Chicory annotation processor can find it at compile time (it uses StandardLocation.CLASS_OUTPUT +// which is not reliable cross-tool, so a file: URI works around that). +val generatedWasmMarkerDir = layout.buildDirectory.dir("generated-wasm-marker") + +val generateWasmMarker by + tasks.registering { + group = "build" + dependsOn(copyWasm) + inputs.file(wasmResourceFile) + outputs.dir(generatedWasmMarkerDir) + doLast { + val wasmUri = wasmResourceFile.get().asFile.toURI().toString() + val content = + """ +// AUTO-GENERATED — do not edit. Regenerated by generateWasmMarker Gradle task. +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +package dev.restate.sdk.core.sharedcore; + +import com.dylibso.chicory.annotations.WasmModuleInterface; + +/** + * Marker class processed by Chicory's annotation processor. Generates + * {@code SharedCoreWasm_ModuleExports} (typed wrapper for all WASM exports). + */ +@WasmModuleInterface("$wasmUri") +public final class SharedCoreWasm {} + """ + .trimIndent() + "\n" + val out = + generatedWasmMarkerDir + .get() + .file("dev/restate/sdk/core/sharedcore/SharedCoreWasm.java") + .asFile + out.parentFile.mkdirs() + out.writeText(content) + } + } + +// --------------------------------------------------------------------------- +// Dependency configurations +// --------------------------------------------------------------------------- + val shade by configurations.creating val implementation by configurations.getting @@ -30,17 +121,21 @@ api.extendsFrom(shade) dependencies { compileOnly(libs.jspecify) - shadow(project(":sdk-common")) + // Chicory annotation processor (@WasmModuleInterface, @HostModule) + compileOnly(libs.chicory.annotations) + annotationProcessor(libs.chicory.annotations.processor) - shadow(libs.log4j.api) - shadow(libs.opentelemetry.api) + implementation(project(":sdk-common")) + implementation(libs.log4j.api) + implementation(libs.opentelemetry.api) // We need this for the manifest - shadow(libs.jackson.annotations) - shadow(libs.jackson.databind) + implementation(libs.jackson.annotations) + implementation(libs.jackson.databind) - // We shade protobuf java - shade(libs.protobuf.java) + // Chicory runtime + Jackson CBOR for the WASM bridge — shaded into the jar + implementation(libs.chicory.runtime) + implementation(libs.jackson.cbor) // We don't want a hard-dependency on it compileOnly(libs.log4j.core) @@ -58,8 +153,9 @@ dependencies { testImplementation(project(":sdk-lambda")) testImplementation(libs.jackson.annotations) testImplementation(libs.jackson.databind) + testImplementation(libs.jackson.cbor) testImplementation(libs.opentelemetry.api) - testImplementation(libs.protobuf.java) + testImplementation(libs.chicory.runtime) testImplementation(libs.mutiny) testImplementation(libs.junit.jupiter) testImplementation(libs.assertj) @@ -71,53 +167,47 @@ dependencies { testRuntimeOnly(libs.junit.platform.launcher) } -// Configure source sets for protobuf plugin and jsonschema2pojo +// --------------------------------------------------------------------------- +// Source sets +// --------------------------------------------------------------------------- + val generatedJ2SPDir = layout.buildDirectory.dir("generated/j2sp") sourceSets { main { java.srcDir(generatedJ2SPDir) - proto { srcDirs("src/main/service-protocol") } + java.srcDir(generatedWasmMarkerDir) + resources.srcDir(wasmResourceDir) } } -// Configure jsonSchema2Pojo +// --------------------------------------------------------------------------- +// jsonSchema2Pojo +// --------------------------------------------------------------------------- + jsonSchema2Pojo { setSource(files("$projectDir/src/main/service-protocol/endpoint_manifest_schema.json")) targetPackage = "dev.restate.sdk.core.generated.manifest" targetDirectory = generatedJ2SPDir.get().asFile - useLongIntegers = true includeSetters = true includeGetters = true generateBuilders = true } -// Configure protobuf - -val protobufVersion = libs.versions.protobuf.get() - -protobuf { protoc { artifact = "com.google.protobuf:protoc:$protobufVersion" } } - -// Make sure task dependencies are correct +// --------------------------------------------------------------------------- +// Task wiring +// --------------------------------------------------------------------------- tasks { withType { - dependsOn(generateJsonSchema2Pojo, generateProto) + dependsOn(generateJsonSchema2Pojo, generateWasmMarker, "precompileWasm2Class") val disabledClassesCodegen = listOf( - "dev.restate.sdk.core.javaapi.reflections.CheckedException", - "dev.restate.sdk.core.javaapi.reflections.CustomSerde", - "dev.restate.sdk.core.javaapi.reflections.Empty", - "dev.restate.sdk.core.javaapi.reflections.GreeterInterface", "dev.restate.sdk.core.javaapi.reflections.MyWorkflow", - "dev.restate.sdk.core.javaapi.reflections.ObjectGreeter", - "dev.restate.sdk.core.javaapi.reflections.ObjectGreeterImplementedFromInterface", - "dev.restate.sdk.core.javaapi.reflections.PrimitiveTypes", "dev.restate.sdk.core.javaapi.reflections.RawInputOutput", "dev.restate.sdk.core.javaapi.reflections.RawService", - "dev.restate.sdk.core.javaapi.reflections.ServiceGreeter", ) options.compilerArgs.addAll( @@ -127,62 +217,24 @@ tasks { ) ) } - withType().configureEach { dependsOn(generateJsonSchema2Pojo, generateProto) } + withType().configureEach { dependsOn(generateJsonSchema2Pojo) } withType().configureEach { - dependsOn(generateJsonSchema2Pojo, generateProto) - } - withType().configureEach { dependsOn(generateJsonSchema2Pojo, generateProto) } - - getByName("jar") { - enabled = false - dependsOn(shadowJar) + dependsOn(generateJsonSchema2Pojo, generateWasmMarker) } - - shadowJar { - configurations = listOf(shade) - enableRelocation = true - archiveClassifier = null - relocate("com.google.protobuf", "dev.restate.shaded.com.google.protobuf") - dependencies { - project.configurations["shadow"].allDependencies.forEach { exclude(dependency(it)) } - exclude("**/google/protobuf/*.proto") - } + withType().configureEach { + dependsOn(generateJsonSchema2Pojo, generateWasmMarker) } } ksp { val disabledClassesCodegen = listOf( - "dev.restate.sdk.core.kotlinapi.reflections.CheckedException", - "dev.restate.sdk.core.kotlinapi.reflections.CustomSerdeService", "dev.restate.sdk.core.kotlinapi.reflections.Empty", - "dev.restate.sdk.core.kotlinapi.reflections.GreeterInterface", - "dev.restate.sdk.core.kotlinapi.reflections.NestedDataClass", - "dev.restate.sdk.core.kotlinapi.reflections.CornerCases", - "dev.restate.sdk.core.kotlinapi.reflections.GreeterWithExplicitName", - "dev.restate.sdk.core.kotlinapi.reflections.MyWorkflow", - "dev.restate.sdk.core.kotlinapi.reflections.ObjectGreeter", - "dev.restate.sdk.core.kotlinapi.reflections.ObjectGreeterImplementedFromInterface", "dev.restate.sdk.core.kotlinapi.reflections.PrimitiveTypes", + "dev.restate.sdk.core.kotlinapi.reflections.CornerCases", "dev.restate.sdk.core.kotlinapi.reflections.RawInputOutput", - "dev.restate.sdk.core.kotlinapi.reflections.ServiceGreeter", + "dev.restate.sdk.core.kotlinapi.reflections.MyWorkflow", + "dev.restate.sdk.core.kotlinapi.reflections.GreeterWithExplicitName", ) arg("dev.restate.codegen.disabledClasses", disabledClassesCodegen.joinToString(",")) } - -// spotless configuration for protobuf - -configure { - format("proto") { - target("**/*.proto") - - // Exclude proto and service-protocol directories because those get the license header from - // their repos. - targetExclude( - fileTree("$rootDir/sdk-common/src/main/proto") { include("**/*.*") }, - fileTree("$rootDir/sdk-core/src/main/service-protocol") { include("**/*.*") }, - ) - - licenseHeaderFile("$rootDir/config/license-header", "syntax") - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java b/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java index 4f128cca6..bda83e806 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java @@ -11,8 +11,6 @@ import dev.restate.common.function.ThrowingFunction; import dev.restate.sdk.common.AbortedExecutionException; import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.statemachine.NotificationValue; -import dev.restate.sdk.core.statemachine.StateMachine; import dev.restate.sdk.endpoint.definition.AsyncResult; import java.util.*; import java.util.concurrent.CompletableFuture; @@ -26,6 +24,11 @@ interface Completer { void complete(NotificationValue value, CompletableFuture future); } + @FunctionalInterface + interface NotificationReader { + java.util.Optional take(int handle); + } + private AsyncResults() {} static AsyncResultInternal single( @@ -48,7 +51,7 @@ interface AsyncResultInternal extends AsyncResult { void tryCancel(); - void tryComplete(StateMachine stateMachine); + void tryComplete(NotificationReader reader); CompletableFuture publicFuture(); @@ -111,9 +114,9 @@ public void tryCancel() { } @Override - public void tryComplete(StateMachine stateMachine) { - stateMachine - .takeNotification(handle) + public void tryComplete(NotificationReader reader) { + reader + .take(handle) .ifPresent( value -> { try { @@ -161,8 +164,8 @@ public void tryCancel() { } @Override - public void tryComplete(StateMachine stateMachine) { - asyncResult.tryComplete(stateMachine); + public void tryComplete(NotificationReader reader) { + asyncResult.tryComplete(reader); } @Override @@ -275,8 +278,8 @@ public void tryCancel() { } @Override - public void tryComplete(StateMachine stateMachine) { - asyncResults.forEach(ar -> ar.tryComplete(stateMachine)); + public void tryComplete(NotificationReader reader) { + asyncResults.forEach(ar -> ar.tryComplete(reader)); for (int i = 0; i < asyncResults.size(); i++) { if (asyncResults.get(i).isDone()) { publicFuture.complete(i); @@ -322,8 +325,8 @@ public void tryCancel() { } @Override - public void tryComplete(StateMachine stateMachine) { - asyncResults.forEach(ar -> ar.tryComplete(stateMachine)); + public void tryComplete(NotificationReader reader) { + asyncResults.forEach(ar -> ar.tryComplete(reader)); asyncResults.stream() .filter(ar -> ar.publicFuture().isCompletedExceptionally()) .findFirst() diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/DiscoveryProtocol.java b/sdk-core/src/main/java/dev/restate/sdk/core/DiscoveryProtocol.java index 1d1a79c04..67172ecde 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/DiscoveryProtocol.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/DiscoveryProtocol.java @@ -13,7 +13,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ser.impl.SimpleBeanPropertyFilter; import com.fasterxml.jackson.databind.ser.impl.SimpleFilterProvider; -import dev.restate.sdk.core.generated.discovery.Discovery; import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema; import dev.restate.sdk.core.generated.manifest.Handler; import dev.restate.sdk.core.generated.manifest.Service; @@ -23,18 +22,41 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -class DiscoveryProtocol { - static final Discovery.ServiceDiscoveryProtocolVersion MIN_SERVICE_DISCOVERY_PROTOCOL_VERSION = - Discovery.ServiceDiscoveryProtocolVersion.V1; - static final Discovery.ServiceDiscoveryProtocolVersion MAX_SERVICE_DISCOVERY_PROTOCOL_VERSION = - Discovery.ServiceDiscoveryProtocolVersion.V4; - - static boolean isSupported( - Discovery.ServiceDiscoveryProtocolVersion serviceDiscoveryProtocolVersion) { - return MIN_SERVICE_DISCOVERY_PROTOCOL_VERSION.getNumber() - <= serviceDiscoveryProtocolVersion.getNumber() - && serviceDiscoveryProtocolVersion.getNumber() - <= MAX_SERVICE_DISCOVERY_PROTOCOL_VERSION.getNumber(); +public class DiscoveryProtocol { + public enum Version { + V1("application/vnd.restate.endpointmanifest.v1+json"), + V2("application/vnd.restate.endpointmanifest.v2+json"), + V3("application/vnd.restate.endpointmanifest.v3+json"), + V4("application/vnd.restate.endpointmanifest.v4+json"); + + private final String header; + + Version(String header) { + this.header = header; + } + + public String getHeader() { + return header; + } + + public int getNumber() { + return ordinal() + 1; + } + + public boolean isSupported() { + // We support all versions so far + return true; + } + + public static final Version MIN = Version.V1; + public static final Version MAX = Version.V4; + + public static Optional fromHeader(String headerValue) { + String trimmed = headerValue.trim(); + return Stream.of(values()) + .filter(version -> version.header.equalsIgnoreCase(trimmed)) + .findFirst(); + } } /** @@ -44,69 +66,36 @@ static boolean isSupported( * @return The highest supported service protocol version, otherwise * Protocol.ServiceProtocolVersion.SERVICE_PROTOCOL_VERSION_UNSPECIFIED */ - static Discovery.ServiceDiscoveryProtocolVersion selectSupportedServiceDiscoveryProtocolVersion( - String acceptedVersionsString) { + static Version selectSupportedServiceDiscoveryProtocolVersion(String acceptedVersionsString) { // assume V1 in case nothing was set if (acceptedVersionsString == null || acceptedVersionsString.isEmpty()) { - return Discovery.ServiceDiscoveryProtocolVersion.V1; + return Version.V1; } final String[] supportedVersions = acceptedVersionsString.split(","); - Discovery.ServiceDiscoveryProtocolVersion maxVersion = - Discovery.ServiceDiscoveryProtocolVersion.SERVICE_DISCOVERY_PROTOCOL_VERSION_UNSPECIFIED; + Version maxVersion = null; for (String versionString : supportedVersions) { - final Optional optionalVersion = - parseServiceDiscoveryProtocolVersion(versionString); + final Optional optionalVersion = Version.fromHeader(versionString); if (optionalVersion.isPresent()) { - final Discovery.ServiceDiscoveryProtocolVersion version = optionalVersion.get(); - if (isSupported(version) && version.getNumber() > maxVersion.getNumber()) { + final Version version = optionalVersion.get(); + if (version.isSupported() + && (maxVersion == null || version.getNumber() > maxVersion.getNumber())) { maxVersion = version; } } } - return maxVersion; - } - - static Optional parseServiceDiscoveryProtocolVersion( - String versionString) { - versionString = versionString.trim(); - - if (versionString.equals("application/vnd.restate.endpointmanifest.v1+json")) { - return Optional.of(Discovery.ServiceDiscoveryProtocolVersion.V1); - } - if (versionString.equals("application/vnd.restate.endpointmanifest.v2+json")) { - return Optional.of(Discovery.ServiceDiscoveryProtocolVersion.V2); - } - if (versionString.equals("application/vnd.restate.endpointmanifest.v3+json")) { - return Optional.of(Discovery.ServiceDiscoveryProtocolVersion.V3); - } - if (versionString.equals("application/vnd.restate.endpointmanifest.v4+json")) { - return Optional.of(Discovery.ServiceDiscoveryProtocolVersion.V4); + if (Objects.isNull(maxVersion)) { + throw new ProtocolException( + String.format( + "Unsupported Discovery version in the Accept header '%s'", acceptedVersionsString), + ProtocolException.UNSUPPORTED_MEDIA_TYPE_CODE); } - return Optional.empty(); - } - static String serviceDiscoveryProtocolVersionToHeaderValue( - Discovery.ServiceDiscoveryProtocolVersion version) { - if (Objects.requireNonNull(version) == Discovery.ServiceDiscoveryProtocolVersion.V1) { - return "application/vnd.restate.endpointmanifest.v1+json"; - } - if (Objects.requireNonNull(version) == Discovery.ServiceDiscoveryProtocolVersion.V2) { - return "application/vnd.restate.endpointmanifest.v2+json"; - } - if (Objects.requireNonNull(version) == Discovery.ServiceDiscoveryProtocolVersion.V3) { - return "application/vnd.restate.endpointmanifest.v3+json"; - } - if (Objects.requireNonNull(version) == Discovery.ServiceDiscoveryProtocolVersion.V4) { - return "application/vnd.restate.endpointmanifest.v4+json"; - } - throw new IllegalArgumentException( - String.format( - "Service discovery protocol version '%s' has no header value", version.getNumber())); + return maxVersion; } static final ObjectMapper MANIFEST_OBJECT_MAPPER = new ObjectMapper(); @@ -139,12 +128,11 @@ interface FieldsMixin {} } static byte[] serializeManifest( - Discovery.ServiceDiscoveryProtocolVersion serviceDiscoveryProtocolVersion, - EndpointManifestSchema response) + Version serviceDiscoveryProtocolVersion, EndpointManifestSchema response) throws ProtocolException { try { SimpleBeanPropertyFilter filter; - if (serviceDiscoveryProtocolVersion == Discovery.ServiceDiscoveryProtocolVersion.V1) { + if (serviceDiscoveryProtocolVersion == Version.V1) { filter = SimpleBeanPropertyFilter.serializeAllExcept( Stream.concat( @@ -153,14 +141,14 @@ static byte[] serializeManifest( DISCOVERY_FIELDS_ADDED_IN_V3.stream()), DISCOVERY_FIELDS_ADDED_IN_V4.stream()) .collect(Collectors.toSet())); - } else if (serviceDiscoveryProtocolVersion == Discovery.ServiceDiscoveryProtocolVersion.V2) { + } else if (serviceDiscoveryProtocolVersion == Version.V2) { filter = SimpleBeanPropertyFilter.serializeAllExcept( Stream.concat( DISCOVERY_FIELDS_ADDED_IN_V3.stream(), DISCOVERY_FIELDS_ADDED_IN_V4.stream()) .collect(Collectors.toSet())); - } else if (serviceDiscoveryProtocolVersion == Discovery.ServiceDiscoveryProtocolVersion.V3) { + } else if (serviceDiscoveryProtocolVersion == Version.V3) { filter = SimpleBeanPropertyFilter.serializeAllExcept(DISCOVERY_FIELDS_ADDED_IN_V4); } else { filter = SimpleBeanPropertyFilter.serializeAll(); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java index 1b6b513f9..7a9b72a61 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java @@ -9,11 +9,8 @@ package dev.restate.sdk.core; import static dev.restate.sdk.core.DiscoveryProtocol.MANIFEST_OBJECT_MAPPER; -import static dev.restate.sdk.core.statemachine.ServiceProtocol.MAX_SERVICE_PROTOCOL_VERSION; -import static dev.restate.sdk.core.statemachine.ServiceProtocol.MIN_SERVICE_PROTOCOL_VERSION; import com.fasterxml.jackson.core.JsonProcessingException; -import dev.restate.sdk.core.generated.discovery.Discovery; import dev.restate.sdk.core.generated.manifest.*; import dev.restate.sdk.endpoint.definition.*; import dev.restate.serde.Serde; @@ -36,23 +33,22 @@ final class EndpointManifest { } EndpointManifestSchema manifest( - Discovery.ServiceDiscoveryProtocolVersion version, - EndpointManifestSchema.ProtocolMode protocolMode) { + DiscoveryProtocol.Version version, EndpointManifestSchema.ProtocolMode protocolMode) { EndpointManifestSchema manifest = new EndpointManifestSchema() .withProtocolMode(protocolMode) - .withMinProtocolVersion((long) MIN_SERVICE_PROTOCOL_VERSION.getNumber()) - .withMaxProtocolVersion((long) MAX_SERVICE_PROTOCOL_VERSION.getNumber()) + .withMinProtocolVersion(5L) + .withMaxProtocolVersion(7L) .withServices(this.services); // Verify that the user didn't set fields that we don't support in the discovery version we set for (var service : manifest.getServices()) { - if (version.getNumber() < Discovery.ServiceDiscoveryProtocolVersion.V2.getNumber()) { + if (version.getNumber() < DiscoveryProtocol.Version.V2.getNumber()) { verifyFieldNotSet( "metadata", service, s -> s.getMetadata() != null && !s.getMetadata().getAdditionalProperties().isEmpty()); } - if (version.getNumber() < Discovery.ServiceDiscoveryProtocolVersion.V3.getNumber()) { + if (version.getNumber() < DiscoveryProtocol.Version.V3.getNumber()) { verifyFieldNull("idempotency retention", service.getIdempotencyRetention()); verifyFieldNull("journal retention", service.getJournalRetention()); verifyFieldNull("inactivity timeout", service.getInactivityTimeout()); @@ -60,7 +56,7 @@ EndpointManifestSchema manifest( verifyFieldNull("enable lazy state", service.getEnableLazyState()); verifyFieldNull("ingress private", service.getIngressPrivate()); } - if (version.getNumber() < Discovery.ServiceDiscoveryProtocolVersion.V4.getNumber()) { + if (version.getNumber() < DiscoveryProtocol.Version.V4.getNumber()) { verifyFieldNull("retry policy initial interval", service.getRetryPolicyInitialInterval()); verifyFieldNull("retry policy max interval", service.getRetryPolicyMaxInterval()); verifyFieldNull("retry policy max attempts", service.getRetryPolicyMaxAttempts()); @@ -69,13 +65,13 @@ EndpointManifestSchema manifest( "retry policy exponentiation factor", service.getRetryPolicyExponentiationFactor()); } for (var handler : service.getHandlers()) { - if (version.getNumber() < Discovery.ServiceDiscoveryProtocolVersion.V2.getNumber()) { + if (version.getNumber() < DiscoveryProtocol.Version.V2.getNumber()) { verifyFieldNotSet( "metadata", handler, h -> h.getMetadata() != null && !h.getMetadata().getAdditionalProperties().isEmpty()); } - if (version.getNumber() < Discovery.ServiceDiscoveryProtocolVersion.V3.getNumber()) { + if (version.getNumber() < DiscoveryProtocol.Version.V3.getNumber()) { verifyFieldNull("idempotency retention", handler.getIdempotencyRetention()); verifyFieldNull("journal retention", handler.getJournalRetention()); verifyFieldNull("inactivity timeout", handler.getInactivityTimeout()); @@ -83,7 +79,7 @@ EndpointManifestSchema manifest( verifyFieldNull("enable lazy state", handler.getEnableLazyState()); verifyFieldNull("ingress private", handler.getIngressPrivate()); } - if (version.getNumber() < Discovery.ServiceDiscoveryProtocolVersion.V4.getNumber()) { + if (version.getNumber() < DiscoveryProtocol.Version.V4.getNumber()) { verifyFieldNull("retry policy initial interval", handler.getRetryPolicyInitialInterval()); verifyFieldNull("retry policy max interval", handler.getRetryPolicyMaxInterval()); verifyFieldNull("retry policy max attempts", handler.getRetryPolicyMaxAttempts()); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointRequestHandler.java b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointRequestHandler.java index 261b14a42..dcf21ae43 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointRequestHandler.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointRequestHandler.java @@ -9,10 +9,9 @@ package dev.restate.sdk.core; import dev.restate.common.Slice; -import dev.restate.sdk.core.generated.discovery.Discovery; import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema; import dev.restate.sdk.core.generated.manifest.Service; -import dev.restate.sdk.core.statemachine.StateMachine; +import dev.restate.sdk.core.sharedcore.SharedCoreVM; import dev.restate.sdk.endpoint.Endpoint; import dev.restate.sdk.endpoint.HeadersAccessor; import dev.restate.sdk.endpoint.definition.HandlerDefinition; @@ -179,9 +178,6 @@ public RequestProcessor processorForRequest( loggingContextSetter.set(LoggingContextSetter.INVOCATION_ID_KEY, invocationIdHeader); } - // Instantiate state machine - StateMachine stateMachine = StateMachine.init(headersAccessor, loggingContextSetter); - // Resolve the service method definition ServiceDefinition svc = this.endpoint.resolveService(serviceName); if (svc == null) { @@ -214,8 +210,8 @@ public RequestProcessor processorForRequest( LoggingContextSetter.INVOCATION_TARGET_KEY, fullyQualifiedServiceMethod); return new RequestProcessorImpl( + SharedCoreVM.create(headersAccessor), fullyQualifiedServiceMethod, - stateMachine, svc.getServiceType(), handler, otelContext, @@ -228,14 +224,8 @@ StaticResponseRequestProcessor handleDiscoveryRequest( throws ProtocolException { String acceptContentType = headersAccessor.get(ACCEPT); - Discovery.ServiceDiscoveryProtocolVersion version = + DiscoveryProtocol.Version version = DiscoveryProtocol.selectSupportedServiceDiscoveryProtocolVersion(acceptContentType); - if (!DiscoveryProtocol.isSupported(version)) { - throw new ProtocolException( - String.format( - "Unsupported Discovery version in the Accept header '%s'", acceptContentType), - ProtocolException.UNSUPPORTED_MEDIA_TYPE_CODE); - } EndpointManifestSchema response = this.deploymentManifest.manifest( @@ -249,7 +239,7 @@ StaticResponseRequestProcessor handleDiscoveryRequest( return new StaticResponseRequestProcessor( 200, - DiscoveryProtocol.serviceDiscoveryProtocolVersionToHeaderValue(version), + version.getHeader(), Slice.wrap(DiscoveryProtocol.serializeManifest(version, response))); } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingHandlerContextImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingHandlerContextImpl.java index 257e951a6..d0659907c 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingHandlerContextImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingHandlerContextImpl.java @@ -12,7 +12,7 @@ import dev.restate.common.Slice; import dev.restate.common.Target; import dev.restate.sdk.common.*; -import dev.restate.sdk.core.statemachine.StateMachine; +import dev.restate.sdk.core.sharedcore.SharedCoreVM; import dev.restate.sdk.endpoint.definition.AsyncResult; import dev.restate.sdk.endpoint.definition.HandlerType; import dev.restate.sdk.endpoint.definition.ServiceType; @@ -32,14 +32,30 @@ final class ExecutorSwitchingHandlerContextImpl extends HandlerContextImpl { private final Executor coreExecutor; ExecutorSwitchingHandlerContextImpl( + SharedCoreVM vm, + Consumer outputSink, + Runnable closeCallback, String fullyQualifiedHandlerName, ServiceType serviceType, @Nullable HandlerType handlerType, - StateMachine stateMachine, Context otelContext, - StateMachine.Input input, + InvocationId invocationId, + Slice body, + Map headers, + @Nullable String key, Executor coreExecutor) { - super(fullyQualifiedHandlerName, serviceType, handlerType, stateMachine, otelContext, input); + super( + vm, + outputSink, + closeCallback, + fullyQualifiedHandlerName, + serviceType, + handlerType, + otelContext, + invocationId, + body, + headers, + key); this.coreExecutor = coreExecutor; } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java index 57906ef86..669ec69cf 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java @@ -15,91 +15,136 @@ import dev.restate.common.function.ThrowingSupplier; import dev.restate.sdk.common.*; import dev.restate.sdk.core.AsyncResults.AsyncResultInternal; -import dev.restate.sdk.core.statemachine.InvocationState; -import dev.restate.sdk.core.statemachine.NotificationValue; -import dev.restate.sdk.core.statemachine.StateMachine; +import dev.restate.sdk.core.sharedcore.SharedCoreVM; import dev.restate.sdk.endpoint.definition.AsyncResult; import dev.restate.sdk.endpoint.definition.HandlerType; import dev.restate.sdk.endpoint.definition.ServiceType; import io.opentelemetry.context.Context; +import java.io.PrintWriter; +import java.io.StringWriter; import java.time.Duration; import java.time.Instant; import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.function.Consumer; -import java.util.stream.Stream; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.jspecify.annotations.NonNull; import org.jspecify.annotations.Nullable; +/** + * Mirrors TS {@code ContextImpl}: depends only on {@link SharedCoreVM} plus two functional + * callbacks supplied by {@link RequestProcessorImpl}. + */ class HandlerContextImpl implements HandlerContextInternal { private static final Logger LOG = LogManager.getLogger(HandlerContextImpl.class); - private static final int CANCEL_HANDLE = 1; + // --- VM and I/O (supplied by RequestProcessorImpl, no class-level dependency on it) + + final SharedCoreVM vm; + private final Consumer outputSink; + private final Runnable closeCallback; + + // --- Poll-loop coordination (owned entirely by HandlerContextImpl) + + private @NonNull Runnable nextEventListener = () -> {}; + + // --- Handler metadata private final HandlerRequest handlerRequest; - private final StateMachine stateMachine; private final @Nullable String objectKey; private final String fullyQualifiedHandlerName; private final ServiceType serviceType; private final @Nullable HandlerType handlerType; + boolean closed = false; - private final List> invocationIdsToCancel; private final HashMap> scheduledRuns; HandlerContextImpl( + SharedCoreVM vm, + Consumer outputSink, + Runnable closeCallback, String fullyQualifiedHandlerName, ServiceType serviceType, @Nullable HandlerType handlerType, - StateMachine stateMachine, Context otelContext, - StateMachine.Input input) { - this.handlerRequest = - new HandlerRequest(input.invocationId(), otelContext, input.body(), input.headers()); - this.objectKey = input.key(); - this.stateMachine = stateMachine; + InvocationId invocationId, + Slice body, + Map headers, + @Nullable String key) { + this.vm = vm; + this.outputSink = outputSink; + this.closeCallback = closeCallback; + this.handlerRequest = new HandlerRequest(invocationId, otelContext, body, headers); + this.objectKey = key; this.fullyQualifiedHandlerName = fullyQualifiedHandlerName; this.serviceType = serviceType; this.handlerType = handlerType; - this.invocationIdsToCancel = new ArrayList<>(); this.scheduledRuns = new HashMap<>(); } - private static void parseSuccessOrFailure(NotificationValue s, CompletableFuture cf) { - if (s instanceof NotificationValue.Success success) { - cf.complete(success.slice()); - } else if (s instanceof NotificationValue.Failure failure) { - cf.completeExceptionally(failure.exception()); - } else { - throw ProtocolException.unexpectedNotificationVariant(s.getClass()); - } + // --------------------------------------------------------------------------- + // External-progress trigger (called by RequestProcessorImpl on new input) + // --------------------------------------------------------------------------- + + void triggerExternalProgress() { + Runnable listener = this.nextEventListener; + this.nextEventListener = () -> {}; + listener.run(); } - private static void parseEmptyOrSuccessOrFailure( - NotificationValue s, CompletableFuture> cf) { - if (s instanceof NotificationValue.Empty) { - cf.complete(Output.notReady()); - } else if (s instanceof NotificationValue.Success success) { - cf.complete(Output.ready(success.slice())); - } else if (s instanceof NotificationValue.Failure failure) { - cf.completeExceptionally(failure.exception()); - } else { - throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + // --------------------------------------------------------------------------- + // Output pump (mirrors TS OutputPump) + // --------------------------------------------------------------------------- + + private void pumpOutput() { + byte[] chunk = vm.takeOutput(); + if (chunk.length > 0) outputSink.accept(Slice.wrap(chunk)); + } + + private void drainAllOutput() { + while (true) { + byte[] chunk = vm.takeOutput(); + if (chunk.length == 0) return; + outputSink.accept(Slice.wrap(chunk)); } } - private static void parseEmptyOrFailure(NotificationValue s, CompletableFuture cf) { - if (s instanceof NotificationValue.Empty) { - cf.complete(null); - } else if (s instanceof NotificationValue.Failure failure) { - cf.completeExceptionally(failure.exception()); - } else { - throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + // --------------------------------------------------------------------------- + // Notification reader (used by AsyncResults.tryComplete) + // --------------------------------------------------------------------------- + + Optional takeNotification(int handle) { + SharedCoreVM.NotificationValue raw = vm.takeNotification(handle); + if (raw == null) return Optional.empty(); + return Optional.of(mapNotificationValue(raw)); + } + + private static NotificationValue mapNotificationValue(SharedCoreVM.NotificationValue raw) { + if (raw instanceof SharedCoreVM.NotificationValue.Void) { + return NotificationValue.Empty.INSTANCE; + } else if (raw instanceof SharedCoreVM.NotificationValue.Success s) { + return new NotificationValue.Success(Slice.wrap(s.value())); + } else if (raw instanceof SharedCoreVM.NotificationValue.Failure f) { + Map meta = new LinkedHashMap<>(); + if (f.metadata() != null) { + for (String[] pair : f.metadata()) meta.put(pair[0], pair[1]); + } + return new NotificationValue.Failure(new TerminalException(f.code(), f.message(), meta)); + } else if (raw instanceof SharedCoreVM.NotificationValue.StateKeys sk) { + return new NotificationValue.StateKeys(sk.keys()); + } else if (raw instanceof SharedCoreVM.NotificationValue.InvocationId id) { + return new NotificationValue.InvocationId(id.id()); } + throw new IllegalStateException("Unknown NotificationValue: " + raw); } + // --------------------------------------------------------------------------- + // HandlerContextInternal — metadata + // --------------------------------------------------------------------------- + @Override public String objectKey() { return this.objectKey; @@ -137,7 +182,7 @@ public String getFullyQualifiedMethodName() { @Override public InvocationState getInvocationState() { - return this.stateMachine.state(); + return closed ? InvocationState.CLOSED : InvocationState.PROCESSING; } @Override @@ -145,13 +190,75 @@ public Executor stateMachineExecutor() { return Runnable::run; } + // --------------------------------------------------------------------------- + // HandlerContextInternal — poll loop (mirrors TS VMProgressCoordinator) + // --------------------------------------------------------------------------- + + @Override + public void pollAsyncResult(AsyncResultInternal asyncResult) { + // Drain one output chunk BEFORE the first do_progress — mirrors TS outer doProgress drain. + pumpOutput(); + pollAsyncResultInner(asyncResult); + } + + private void pollAsyncResultInner(AsyncResultInternal asyncResult) { + while (true) { + if (closed) { + asyncResult.publicFuture().completeExceptionally(AbortedExecutionException.INSTANCE); + return; + } + if (asyncResult.isDone()) { + return; + } + + asyncResult.tryComplete(this::takeNotification); + + List uncompletedLeaves = asyncResult.uncompletedLeaves().toList(); + if (uncompletedLeaves.isEmpty()) { + return; + } + + int[] handles = uncompletedLeaves.stream().mapToInt(Integer::intValue).toArray(); + SharedCoreVM.DoProgressResult result; + try { + result = vm.doProgress(handles); + } catch (Throwable e) { + failWithoutContextSwitch(e); + asyncResult.publicFuture().completeExceptionally(AbortedExecutionException.INSTANCE); + return; + } + + if (result instanceof SharedCoreVM.DoProgressResult.AnyCompleted) { + // loop + } else if (result instanceof SharedCoreVM.DoProgressResult.WaitExternalProgress) { + // Drain one chunk after WaitExternalProgress — mirrors TS inner doProgress drain. + pumpOutput(); + this.nextEventListener = () -> pollAsyncResultInner(asyncResult); + return; + } else if (result instanceof SharedCoreVM.DoProgressResult.CancelSignalReceived) { + asyncResult.tryCancel(); + return; + } else if (result instanceof SharedCoreVM.DoProgressResult.ExecuteRun r) { + triggerScheduledRun(r.handle()); + // loop + } else if (result instanceof SharedCoreVM.DoProgressResult.Suspended) { + pumpOutput(); + ExceptionUtils.sneakyThrow(AbortedExecutionException.INSTANCE); + } + } + } + + // --------------------------------------------------------------------------- + // HandlerContextInternal — state ops + // --------------------------------------------------------------------------- + @Override public CompletableFuture>> get(String name) { return catchExceptions( () -> AsyncResults.single( this, - this.stateMachine.stateGet(name), + vm.sysStateGet(name), (s, cf) -> { if (s instanceof NotificationValue.Empty) { cf.complete(Optional.empty()); @@ -169,7 +276,7 @@ public CompletableFuture>> getKeys() { () -> AsyncResults.single( this, - this.stateMachine.stateGetKeys(), + vm.sysStateGetKeys(), (s, cf) -> { if (s instanceof NotificationValue.StateKeys stateKeys) { cf.complete(stateKeys.stateKeys()); @@ -181,26 +288,30 @@ public CompletableFuture>> getKeys() { @Override public CompletableFuture clear(String name) { - return this.catchExceptions(() -> this.stateMachine.stateClear(name)); + return catchExceptions(() -> vm.sysStateClear(name)); } @Override public CompletableFuture clearAll() { - return this.catchExceptions(this.stateMachine::stateClearAll); + return catchExceptions(vm::sysStateClearAll); } @Override public CompletableFuture set(String name, Slice value) { - return this.catchExceptions(() -> this.stateMachine.stateSet(name, value)); + return catchExceptions(() -> vm.sysStateSet(name, value.toByteArray())); } + // --------------------------------------------------------------------------- + // HandlerContextInternal — timer + // --------------------------------------------------------------------------- + @Override public CompletableFuture> timer(Duration duration, String name) { return catchExceptions( () -> AsyncResults.single( this, - this.stateMachine.sleep(duration, name), + vm.sysSleep(duration.toMillis(), name), (s, cf) -> { if (s instanceof NotificationValue.Empty) { cf.complete(null); @@ -210,6 +321,10 @@ public CompletableFuture> timer(Duration duration, String name })); } + // --------------------------------------------------------------------------- + // HandlerContextInternal — call / send + // --------------------------------------------------------------------------- + @Override public CompletableFuture call( Target target, @@ -218,16 +333,21 @@ public CompletableFuture call( @Nullable Collection> headers) { return catchExceptions( () -> { - StateMachine.CallHandle callHandle = - this.stateMachine.call(target, parameter, idempotencyKey, headers); + SharedCoreVM.CallHandleResult r = + vm.sysCall( + target.getService(), + target.getHandler(), + target.getKey(), + parameter.toByteArray(), + idempotencyKey, + headers != null ? new ArrayList<>(headers) : null); AsyncResultInternal invocationIdAsyncResult = - AsyncResults.single(this, callHandle.invocationIdHandle(), invocationIdCompleter()); - this.invocationIdsToCancel.add(invocationIdAsyncResult); + AsyncResults.single(this, r.invocationIdHandle(), invocationIdCompleter()); AsyncResult callAsyncResult = AsyncResults.single( - this, callHandle.resultHandle(), HandlerContextImpl::parseSuccessOrFailure); + this, r.resultHandle(), HandlerContextImpl::parseSuccessOrFailure); return new CallResult(invocationIdAsyncResult, callAsyncResult); }); @@ -243,7 +363,14 @@ public CompletableFuture> send( return catchExceptions( () -> { int sendHandle = - this.stateMachine.send(target, parameter, idempotencyKey, headers, delay); + vm.sysSend( + target.getService(), + target.getHandler(), + target.getKey(), + parameter.toByteArray(), + idempotencyKey, + headers != null ? new ArrayList<>(headers) : null, + delay != null ? delay.toMillis() : null); return AsyncResults.single(this, sendHandle, invocationIdCompleter()); }); @@ -259,47 +386,116 @@ private static AsyncResults.Completer invocationIdCompleter() { }; } + // --------------------------------------------------------------------------- + // HandlerContextInternal — run + // --------------------------------------------------------------------------- + @Override public CompletableFuture> submitRun( @Nullable String name, Consumer closure) { return catchExceptions( () -> { - int runHandle = this.stateMachine.run(name); + int runHandle = vm.sysRun(name); this.scheduledRuns.put(runHandle, closure); return AsyncResults.single(this, runHandle, HandlerContextImpl::parseSuccessOrFailure); }); } + @Override + public void proposeRunSuccess(int runHandle, Slice toWrite) { + try { + vm.proposeRunCompletionSuccess(runHandle, toWrite.toByteArray()); + pumpOutput(); + } catch (Exception e) { + failWithoutContextSwitch(e); + } + } + + @Override + public void proposeRunFailure( + int runHandle, + Throwable toWrite, + Duration attemptDuration, + @Nullable RetryPolicy retryPolicy) { + try { + SharedCoreVM.WasmRetryPolicy rp = toWasmRetryPolicy(retryPolicy); + if (toWrite instanceof TerminalException te) { + vm.proposeRunCompletionTerminalFailure( + runHandle, te.getCode(), te.getMessage(), toMetaList(te)); + } else { + vm.proposeRunCompletionRetryableFailure( + runHandle, + 500, + toWrite.getMessage() != null ? toWrite.getMessage() : toWrite.getClass().getName(), + stacktraceToString(toWrite), + attemptDuration.toMillis(), + rp); + } + pumpOutput(); + } catch (Exception e) { + failWithoutContextSwitch(e); + } + } + + private void triggerScheduledRun(int handle) { + var consumer = + Objects.requireNonNull( + this.scheduledRuns.get(handle), "The given handle doesn't exist, this is an SDK bug"); + var startTime = Instant.now(); + consumer.accept( + new RunCompleter() { + @Override + public void proposeSuccess(Slice toWrite) { + proposeRunSuccess(handle, toWrite); + } + + @Override + public void proposeFailure(Throwable toWrite, @Nullable RetryPolicy retryPolicy) { + proposeRunFailure( + handle, toWrite, Duration.between(startTime, Instant.now()), retryPolicy); + } + }); + } + + // --------------------------------------------------------------------------- + // HandlerContextInternal — awakeable + // --------------------------------------------------------------------------- + @Override public CompletableFuture awakeable() { return catchExceptions( () -> { - StateMachine.Awakeable awakeable = this.stateMachine.awakeable(); + SharedCoreVM.AwakeableResult r = vm.sysAwakeable(); return new Awakeable( - awakeable.awakeableId(), + r.awakeableId(), AsyncResults.single( - this, awakeable.handle(), HandlerContextImpl::parseSuccessOrFailure)); + this, r.signalHandle(), HandlerContextImpl::parseSuccessOrFailure)); }); } @Override public CompletableFuture resolveAwakeable(String id, Slice payload) { - return this.catchExceptions(() -> this.stateMachine.completeAwakeable(id, payload)); + return catchExceptions(() -> vm.sysCompleteAwakeableSuccess(id, payload.toByteArray())); } @Override public CompletableFuture rejectAwakeable(String id, TerminalException reason) { - return this.catchExceptions(() -> this.stateMachine.completeAwakeable(id, reason)); + return catchExceptions( + () -> + vm.sysCompleteAwakeableFailure( + id, reason.getCode(), reason.getMessage(), toMetaList(reason))); } + // --------------------------------------------------------------------------- + // HandlerContextInternal — promises + // --------------------------------------------------------------------------- + @Override public CompletableFuture> promise(String key) { return catchExceptions( () -> AsyncResults.single( - this, - this.stateMachine.promiseGet(key), - HandlerContextImpl::parseSuccessOrFailure)); + this, vm.sysPromiseGet(key), HandlerContextImpl::parseSuccessOrFailure)); } @Override @@ -307,9 +503,7 @@ public CompletableFuture>> peekPromise(String key) { return catchExceptions( () -> AsyncResults.single( - this, - this.stateMachine.promisePeek(key), - HandlerContextImpl::parseEmptyOrSuccessOrFailure)); + this, vm.sysPromisePeek(key), HandlerContextImpl::parseEmptyOrSuccessOrFailure)); } @Override @@ -318,7 +512,7 @@ public CompletableFuture> resolvePromise(String key, Slice pay () -> AsyncResults.single( this, - this.stateMachine.promiseComplete(key, payload), + vm.sysPromiseCompleteSuccess(key, payload.toByteArray()), HandlerContextImpl::parseEmptyOrFailure)); } @@ -328,13 +522,18 @@ public CompletableFuture> rejectPromise(String key, TerminalEx () -> AsyncResults.single( this, - this.stateMachine.promiseComplete(key, reason), + vm.sysPromiseCompleteFailure( + key, reason.getCode(), reason.getMessage(), toMetaList(reason)), HandlerContextImpl::parseEmptyOrFailure)); } + // --------------------------------------------------------------------------- + // HandlerContextInternal — invocation control + // --------------------------------------------------------------------------- + @Override public CompletableFuture cancelInvocation(String invocationId) { - return this.catchExceptions(() -> this.stateMachine.cancelInvocation(invocationId)); + return catchExceptions(() -> vm.sysCancelInvocation(invocationId)); } @Override @@ -343,7 +542,7 @@ public CompletableFuture> attachInvocation(String invocationI () -> AsyncResults.single( this, - this.stateMachine.attachInvocation(invocationId), + vm.sysAttachInvocation(invocationId), HandlerContextImpl::parseSuccessOrFailure)); } @@ -353,172 +552,128 @@ public CompletableFuture>> getInvocationOutput(String () -> AsyncResults.single( this, - this.stateMachine.getInvocationOutput(invocationId), + vm.sysGetInvocationOutput(invocationId), HandlerContextImpl::parseEmptyOrSuccessOrFailure)); } + // --------------------------------------------------------------------------- + // HandlerContextInternal — output / lifecycle + // --------------------------------------------------------------------------- + @Override public CompletableFuture writeOutput(Slice value) { - return this.catchExceptions(() -> this.stateMachine.writeOutput(value)); + return catchExceptions(() -> vm.sysWriteOutputSuccess(value.toByteArray())); } @Override public CompletableFuture writeOutput(TerminalException throwable) { - return this.catchExceptions(() -> this.stateMachine.writeOutput(throwable)); + return catchExceptions( + () -> + vm.sysWriteOutputFailure( + throwable.getCode(), throwable.getMessage(), toMetaList(throwable))); } @Override - public void pollAsyncResult(AsyncResultInternal asyncResult) { - // We use the separate function for the recursion, - // as there's no need to jump back and forth between threads again. - this.pollAsyncResultInner(asyncResult); + public void close() { + vm.sysEnd(); + drainAllOutput(); + closeCallback.run(); + this.closed = true; } - private void pollAsyncResultInner(AsyncResultInternal asyncResult) { - while (true) { - if (this.stateMachine.state() == InvocationState.CLOSED) { - asyncResult.publicFuture().completeExceptionally(AbortedExecutionException.INSTANCE); - return; - } - if (asyncResult.isDone()) { - return; - } - - // Let's look for the cancellation notification - var cancellationNotification = this.stateMachine.takeNotification(CANCEL_HANDLE); - if (cancellationNotification.isPresent()) { - LOG.info("Detected cancellation signal! Will start cancelling child invocations"); - - // Let's wait to cancel all - @SuppressWarnings({"rawtypes", "unchecked"}) - AsyncResultInternal allInvocationIds = - AsyncResults.all(this, (List) this.invocationIdsToCancel); - allInvocationIds - .publicFuture() - .whenComplete( - (ignored, throwable) -> { - if (throwable != null) { - // Already handled - return; - } - LOG.info("All child invocation ids retrieved"); - try { - for (var invocationIdAr : this.invocationIdsToCancel) { - this.stateMachine.cancelInvocation( - Objects.requireNonNull(invocationIdAr.publicFuture().getNow(null))); - } - asyncResult.tryCancel(); - } catch (Throwable e) { - // Not good! - this.failWithoutContextSwitch(e); - } - }); - // Let's resolve all the invocation IDs - pollAsyncResultInner(allInvocationIds); - return; - } - - // Let's start by trying to complete it - asyncResult.tryComplete(this.stateMachine); - - // Now let's take the unprocessed leaves - List uncompletedLeaves = - Stream.concat(asyncResult.uncompletedLeaves(), Stream.of(CANCEL_HANDLE)).toList(); - if (uncompletedLeaves.size() == 1) { - // Nothing else to do! - return; - } - - // Not ready yet, let's try to do some progress - StateMachine.DoProgressResponse response; - try { - response = this.stateMachine.doProgress(uncompletedLeaves); - } catch (Throwable e) { - this.failWithoutContextSwitch(e); - asyncResult.publicFuture().completeExceptionally(AbortedExecutionException.INSTANCE); - return; - } - - if (response instanceof StateMachine.DoProgressResponse.AnyCompleted) { - // Let it loop now - } else if (response instanceof StateMachine.DoProgressResponse.ReadFromInput - || response instanceof StateMachine.DoProgressResponse.WaitingPendingRun) { - this.stateMachine.onNextEvent( - () -> this.pollAsyncResultInner(asyncResult), - response instanceof StateMachine.DoProgressResponse.ReadFromInput); - return; - } else if (response instanceof StateMachine.DoProgressResponse.ExecuteRun) { - triggerScheduledRun(((StateMachine.DoProgressResponse.ExecuteRun) response).handle()); - // Let it loop now - } - } + @Override + public void fail(Throwable cause) { + failWithoutContextSwitch(cause); } @Override - public void proposeRunSuccess(int runHandle, Slice toWrite) { + public void failWithoutContextSwitch(Throwable cause) { try { - this.stateMachine.proposeRunCompletion(runHandle, toWrite); - } catch (Exception e) { - this.failWithoutContextSwitch(e); + String message = cause.getMessage() != null ? cause.getMessage() : cause.getClass().getName(); + vm.notifyError(message, stacktraceToString(cause), null); + pumpOutput(); + } catch (Throwable ignored) { + // already in error handling } } - @Override - public void proposeRunFailure( - int runHandle, - Throwable toWrite, - Duration attemptDuration, - @Nullable RetryPolicy retryPolicy) { - try { - this.stateMachine.proposeRunCompletion(runHandle, toWrite, attemptDuration, retryPolicy); - } catch (Exception e) { - this.failWithoutContextSwitch(e); + // --------------------------------------------------------------------------- + // Notification value parsing helpers + // --------------------------------------------------------------------------- + + private static void parseSuccessOrFailure(NotificationValue s, CompletableFuture cf) { + if (s instanceof NotificationValue.Success success) { + cf.complete(success.slice()); + } else if (s instanceof NotificationValue.Failure failure) { + cf.completeExceptionally(failure.exception()); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); } } - private void triggerScheduledRun(int handle) { - var consumer = - Objects.requireNonNull( - this.scheduledRuns.get(handle), "The given handle doesn't exist, this is an SDK bug"); - var startTime = Instant.now(); - consumer.accept( - new RunCompleter() { - @Override - public void proposeSuccess(Slice toWrite) { - proposeRunSuccess(handle, toWrite); - } + private static void parseEmptyOrSuccessOrFailure( + NotificationValue s, CompletableFuture> cf) { + if (s instanceof NotificationValue.Empty) { + cf.complete(Output.notReady()); + } else if (s instanceof NotificationValue.Success success) { + cf.complete(Output.ready(success.slice())); + } else if (s instanceof NotificationValue.Failure failure) { + cf.completeExceptionally(failure.exception()); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + } + } - @Override - public void proposeFailure(Throwable toWrite, @Nullable RetryPolicy retryPolicy) { - proposeRunFailure( - handle, toWrite, Duration.between(startTime, Instant.now()), retryPolicy); - } - }); + private static void parseEmptyOrFailure(NotificationValue s, CompletableFuture cf) { + if (s instanceof NotificationValue.Empty) { + cf.complete(null); + } else if (s instanceof NotificationValue.Failure failure) { + cf.completeExceptionally(failure.exception()); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + } } - @Override - public void close() { - this.stateMachine.end(); + // --------------------------------------------------------------------------- + // Static helpers + // --------------------------------------------------------------------------- + + private static @Nullable List toMetaList(TerminalException exception) { + Map meta = exception.getMetadata(); + if (meta == null || meta.isEmpty()) return null; + List r = new ArrayList<>(meta.size()); + for (Map.Entry e : meta.entrySet()) + r.add(new String[] {e.getKey(), e.getValue()}); + return r; } - @Override - public void fail(Throwable cause) { - this.failWithoutContextSwitch(cause); + private static SharedCoreVM.@Nullable WasmRetryPolicy toWasmRetryPolicy( + @Nullable RetryPolicy retryPolicy) { + if (retryPolicy == null) return null; + return new SharedCoreVM.WasmRetryPolicy( + retryPolicy.getInitialDelay().toMillis(), + retryPolicy.getExponentiationFactor(), + retryPolicy.getMaxDelay() != null ? retryPolicy.getMaxDelay().toMillis() : null, + retryPolicy.getMaxAttempts(), + retryPolicy.getMaxDuration() != null ? retryPolicy.getMaxDuration().toMillis() : null); } - @Override - public void failWithoutContextSwitch(Throwable cause) { - this.stateMachine.onError(cause); + private static String stacktraceToString(Throwable t) { + StringWriter sw = new StringWriter(); + t.printStackTrace(new PrintWriter(sw)); + return sw.toString(); } - // -- Wrapper for failure propagation + // --------------------------------------------------------------------------- + // catchExceptions helpers + // --------------------------------------------------------------------------- private CompletableFuture catchExceptions(ThrowingRunnable r) { try { r.run(); return CompletableFuture.completedFuture(null); } catch (Throwable e) { - this.failWithoutContextSwitch(e); + failWithoutContextSwitch(e); return CompletableFuture.failedFuture(AbortedExecutionException.INSTANCE); } } @@ -527,7 +682,7 @@ private CompletableFuture catchExceptions(ThrowingSupplier r) { try { return CompletableFuture.completedFuture(r.get()); } catch (Throwable e) { - this.failWithoutContextSwitch(e); + failWithoutContextSwitch(e); return CompletableFuture.failedFuture(AbortedExecutionException.INSTANCE); } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java index b89b18c9c..5bcd859b2 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java @@ -11,7 +11,6 @@ import dev.restate.common.Slice; import dev.restate.sdk.common.RetryPolicy; import dev.restate.sdk.core.AsyncResults.AsyncResultInternal; -import dev.restate.sdk.core.statemachine.InvocationState; import dev.restate.sdk.endpoint.definition.AsyncResult; import dev.restate.sdk.endpoint.definition.HandlerContext; import java.time.Duration; @@ -50,6 +49,8 @@ void proposeRunFailure( void close(); + void fail(Throwable throwable); + // -- State machine introspection (used by logging propagator) /** diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationIdImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationIdImpl.java similarity index 97% rename from sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationIdImpl.java rename to sdk-core/src/main/java/dev/restate/sdk/core/InvocationIdImpl.java index 547326c48..85f64981c 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationIdImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationIdImpl.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; +package dev.restate.sdk.core; import dev.restate.sdk.common.InvocationId; import java.nio.charset.StandardCharsets; diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationState.java b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationState.java similarity index 90% rename from sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationState.java rename to sdk-core/src/main/java/dev/restate/sdk/core/InvocationState.java index 3820f41da..2944d6c3e 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationState.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationState.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; +package dev.restate.sdk.core; public enum InvocationState { WAITING_START, diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationValue.java b/sdk-core/src/main/java/dev/restate/sdk/core/NotificationValue.java similarity index 95% rename from sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationValue.java rename to sdk-core/src/main/java/dev/restate/sdk/core/NotificationValue.java index 8834dd1d1..202157201 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationValue.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/NotificationValue.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; +package dev.restate.sdk.core; import dev.restate.common.Slice; import dev.restate.sdk.common.TerminalException; diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java b/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java index 572b187a2..34af00658 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java @@ -8,12 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import com.google.protobuf.MessageLite; import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.core.statemachine.NotificationId; -import java.util.List; -import java.util.Map; public class ProtocolException extends RuntimeException { @@ -21,9 +16,7 @@ public class ProtocolException extends RuntimeException { static final int NOT_FOUND_CODE = 404; public static final int UNSUPPORTED_MEDIA_TYPE_CODE = 415; public static final int INTERNAL_CODE = 500; - public static final int JOURNAL_MISMATCH_CODE = 570; - static final int PROTOCOL_VIOLATION_CODE = 571; - static final int UNSUPPORTED_FEATURE = 573; + @Deprecated public static final int UNSUPPORTED_FEATURE = 573; private final int code; @@ -40,39 +33,9 @@ public int getCode() { return code; } - public static ProtocolException unexpectedMessage( - Class expected, MessageLite actual) { - return new ProtocolException( - "Unexpected message type received from the runtime. Expected: '" - + expected.getCanonicalName() - + "', Actual: '" - + actual.getClass().getCanonicalName() - + "'", - PROTOCOL_VIOLATION_CODE); - } - - public static ProtocolException unexpectedMessage(String expected, MessageLite actual) { - return new ProtocolException( - "Unexpected message type received from the runtime. Expected: '" - + expected - + "', Actual: '" - + actual.getClass().getCanonicalName() - + "'", - PROTOCOL_VIOLATION_CODE); - } - static ProtocolException unexpectedNotificationVariant(Class clazz) { return new ProtocolException( - "Unexpected notification variant " + clazz.getName(), PROTOCOL_VIOLATION_CODE); - } - - public static ProtocolException commandsToProcessIsEmpty() { - return new ProtocolException("Expecting command queue to be non empty", JOURNAL_MISMATCH_CODE); - } - - public static ProtocolException unknownMessageType(short type) { - return new ProtocolException( - "MessageType " + Integer.toHexString(type) + " unknown", PROTOCOL_VIOLATION_CODE); + "Unexpected notification variant " + clazz.getName(), INTERNAL_CODE); } public static ProtocolException methodNotFound(String serviceName, String handlerName) { @@ -80,50 +43,7 @@ public static ProtocolException methodNotFound(String serviceName, String handle "Cannot find handler '" + serviceName + "/" + handlerName + "'", NOT_FOUND_CODE); } - public static ProtocolException badState(Object thisState) { - return new ProtocolException( - "Cannot process operation because the handler is in unexpected state: " + thisState, - INTERNAL_CODE); - } - - public static ProtocolException badNotificationMessage(String missingField) { - return new ProtocolException( - "Bad notification message, missing field " + missingField, PROTOCOL_VIOLATION_CODE); - } - - public static ProtocolException badRunNotificationId(NotificationId notificationId) { - return new ProtocolException( - "Bad run handle, should be mapped to a completion notification id, but was " - + notificationId, - PROTOCOL_VIOLATION_CODE); - } - - public static ProtocolException commandMissingField(Class clazz, String missingField) { - return new ProtocolException( - "Bad command " + clazz.getName() + ", missing field " + missingField, - PROTOCOL_VIOLATION_CODE); - } - - public static ProtocolException inputClosedWhileWaitingEntries() { - return new ProtocolException( - "The input was closed while still waiting to receive all the `known_entries`", - PROTOCOL_VIOLATION_CODE); - } - - public static ProtocolException closedWhileWaitingEntries() { - return new ProtocolException( - "The state machine was closed while still waiting to receive all the `known_entries`", - PROTOCOL_VIOLATION_CODE); - } - @Deprecated - static ProtocolException invalidSideEffectCall() { - return new ProtocolException( - "A syscall was invoked from within a side effect closure.", - TerminalException.INTERNAL_SERVER_ERROR_CODE, - null); - } - public static ProtocolException idempotencyKeyIsEmpty() { return new ProtocolException( "The provided idempotency key is empty.", @@ -134,44 +54,4 @@ public static ProtocolException idempotencyKeyIsEmpty() { public static ProtocolException unauthorized(Throwable e) { return new ProtocolException("Unauthorized", UNAUTHORIZED_CODE, e); } - - public static ProtocolException uncompletedDoProgressDuringReplay( - List sortedNotificationIds, - Map notificationDescriptions) { - var sb = new StringBuilder(); - sb.append( - "Found a mismatch between the code paths taken during the previous execution and the paths taken during this execution.\n"); - sb.append( - "'Awaiting a future' could not be replayed. This usually means the code was mutated adding an 'await' without registering a new service revision.\n"); - sb.append("Notifications awaited on this await point:"); - for (var notificationId : sortedNotificationIds) { - sb.append("\n - "); - String description = notificationDescriptions.get(notificationId); - if (description != null) { - sb.append(description); - } else if (notificationId instanceof NotificationId.CompletionId completionId) { - sb.append("completion id ").append(completionId.id()); - } else if (notificationId instanceof NotificationId.SignalId signalId) { - sb.append("signal [").append(signalId.id()).append("]"); - } else if (notificationId instanceof NotificationId.SignalName signalName) { - sb.append("signal '").append(signalName.name()).append("'"); - } - } - return new ProtocolException(sb.toString(), JOURNAL_MISMATCH_CODE); - } - - public static ProtocolException unsupportedFeature( - String featureName, - Protocol.ServiceProtocolVersion requiredVersion, - Protocol.ServiceProtocolVersion negotiatedVersion) { - return new ProtocolException( - "Current service protocol version does not support " - + featureName - + ". " - + "Negotiated version: " - + negotiatedVersion.getNumber() - + ", minimum required: " - + requiredVersion.getNumber(), - UNSUPPORTED_FEATURE); - } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java index 7f7137e54..74864c302 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java @@ -9,26 +9,59 @@ package dev.restate.sdk.core; import dev.restate.common.Slice; +import dev.restate.sdk.common.InvocationId; import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.statemachine.InvocationState; -import dev.restate.sdk.core.statemachine.StateMachine; +import dev.restate.sdk.core.sharedcore.SharedCoreVM; import dev.restate.sdk.endpoint.definition.HandlerDefinition; import dev.restate.sdk.endpoint.definition.ServiceType; import io.opentelemetry.context.Context; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicReference; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.jspecify.annotations.NonNull; import org.jspecify.annotations.Nullable; +/** + * Handles I/O (Flow.Processor), pre-flight replay, and user code orchestration. Mirrors TS {@code + * RestateInvokeResponse.process()}. + * + *

{@link HandlerContextImpl} depends only on {@link SharedCoreVM} and two lambdas supplied here + * — no reference back to this class. + */ final class RequestProcessorImpl implements RequestProcessor { private static final Logger LOG = LogManager.getLogger(RequestProcessorImpl.class); + // --------------------------------------------------------------------------- + // VM (owned here — lifecycle: created in constructor, closed after handler finishes) + // --------------------------------------------------------------------------- + + private final SharedCoreVM vm; + + // --------------------------------------------------------------------------- + // Streaming state (mirrors WasmStateMachineImpl / TS RestateInvokeResponse) + // --------------------------------------------------------------------------- + + private final CompletableFuture waitForReadyFuture = new CompletableFuture<>(); + + /** Wired to {@code HandlerContextImpl::triggerExternalProgress} once the handler is ready. */ + private @NonNull Runnable externalProgressTrigger = () -> {}; + + private Flow.@Nullable Subscriber outputSubscriber; + private Flow.@Nullable Subscription inputSubscription; + private boolean inputClosed = false; + + // --------------------------------------------------------------------------- + // Handler orchestration + // --------------------------------------------------------------------------- + private final String fullyQualifiedHandlerName; - private final StateMachine stateMachine; private final ServiceType serviceType; private final HandlerDefinition handlerDefinition; private final Context otelContext; @@ -38,15 +71,15 @@ final class RequestProcessorImpl implements RequestProcessor { @SuppressWarnings("unchecked") RequestProcessorImpl( + SharedCoreVM vm, String fullyQualifiedHandlerName, - StateMachine stateMachine, ServiceType serviceType, HandlerDefinition handlerDefinition, Context otelContext, EndpointRequestHandler.LoggingContextSetter loggingContextSetter, Executor syscallExecutor) { + this.vm = vm; this.fullyQualifiedHandlerName = fullyQualifiedHandlerName; - this.stateMachine = stateMachine; this.serviceType = serviceType; this.otelContext = otelContext; this.loggingContextSetter = loggingContextSetter; @@ -55,151 +88,228 @@ final class RequestProcessorImpl implements RequestProcessor { this.onHandlerTaskCancellation = new AtomicReference<>(); } - // Flow methods implementation + // =========================================================================== + // RequestProcessor — public streaming interface + // =========================================================================== @Override - public void subscribe(Flow.Subscriber subscriber) { - LOG.trace("Start processing invocation"); - this.stateMachine.subscribe( - new Flow.Subscriber<>() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscriber.onSubscribe(subscription); - } + public int statusCode() { + return 200; + } - @Override - public void onNext(Slice slice) { - subscriber.onNext(slice); - } + @Override + public String responseContentType() { + return vm.getResponseContentType(); + } + @Override + public void subscribe(Flow.Subscriber subscriber) { + LOG.trace("Start processing invocation"); + this.outputSubscriber = subscriber; + subscriber.onSubscribe( + new Flow.Subscription() { @Override - public void onError(Throwable throwable) { - Runnable cancelTask = onHandlerTaskCancellation.get(); - if (cancelTask != null) { - cancelTask.run(); - } - subscriber.onError(throwable); - } + public void request(long n) {} @Override - public void onComplete() { - Runnable cancelTask = onHandlerTaskCancellation.get(); - if (cancelTask != null) { - cancelTask.run(); - } - subscriber.onComplete(); + public void cancel() { + // Transport cancelled — just clean up without calling subscriber methods. + inputClosed = true; + cancelInputSubscription(); + vm.close(); } }); - stateMachine - .waitForReady() - .thenCompose(v -> this.onReady()) + waitForReadyFuture + .thenCompose(v -> onReady()) .whenComplete( (v, t) -> { - if (t != null) { - this.onError(t); - } + if (t != null) notifyErrorToVm(t); }); } @Override public void onSubscribe(Flow.Subscription subscription) { - this.stateMachine.onSubscribe(subscription); + this.inputSubscription = subscription; + subscription.request(Long.MAX_VALUE); } @Override - public void onNext(Slice item) { - this.stateMachine.onNext(item); + public void onNext(Slice slice) { + try { + vm.notifyInput(slice.toByteArray()); + checkReadyToExecute(); + externalProgressTrigger.run(); + } catch (Throwable e) { + onError(e); + } } @Override public void onError(Throwable throwable) { - this.stateMachine.onError(throwable); + notifyErrorToVm(throwable); + if (!waitForReadyFuture.isDone()) { + waitForReadyFuture.completeExceptionally(throwable); + } + Runnable cancelTask = onHandlerTaskCancellation.get(); + if (cancelTask != null) cancelTask.run(); + if (outputSubscriber != null) outputSubscriber.onError(throwable); + externalProgressTrigger.run(); + cancelInputSubscription(); } @Override public void onComplete() { - this.stateMachine.onComplete(); + try { + vm.notifyInputClosed(); + checkReadyToExecute(); + } catch (Throwable e) { + onError(e); + return; + } + Runnable cancelTask = onHandlerTaskCancellation.get(); + if (cancelTask != null) cancelTask.run(); + if (outputSubscriber != null) outputSubscriber.onComplete(); + externalProgressTrigger.run(); + cancelInputSubscription(); } - @Override - public int statusCode() { - return 200; + // --------------------------------------------------------------------------- + // Streaming helpers + // --------------------------------------------------------------------------- + + private void checkReadyToExecute() { + if (!waitForReadyFuture.isDone() && vm.isReadyToExecute()) { + waitForReadyFuture.complete(null); + } } - @Override - public String responseContentType() { - return this.stateMachine.getResponseContentType(); + private void cancelInputSubscription() { + this.inputClosed = true; + if (this.inputSubscription != null) { + this.inputSubscription.cancel(); + this.inputSubscription = null; + } } - private CompletableFuture onReady() { - StateMachine.Input input = stateMachine.input(); + private void notifyErrorToVm(Throwable t) { + try { + String msg = t.getMessage() != null ? t.getMessage() : t.getClass().getName(); + vm.notifyError(msg, stacktraceToString(t), null); + // Drain one chunk in case the VM produced a response. + pumpOutputOnce(); + } catch (Throwable ignored) { + } + } + + private void pumpOutputOnce() { + if (outputSubscriber == null) return; + byte[] chunk = vm.takeOutput(); + if (chunk.length > 0) outputSubscriber.onNext(Slice.wrap(chunk)); + } + + // --------------------------------------------------------------------------- + // Handler orchestration (mirrors TS startUserHandler + flushAndClose) + // --------------------------------------------------------------------------- - if (input == null) { + private CompletableFuture onReady() { + SharedCoreVM.Input raw = vm.sysInput(); + if (raw == null) { return CompletableFuture.failedFuture( new IllegalStateException("State machine input is empty")); } + Map headers = new LinkedHashMap<>(); + if (raw.headers() != null) { + for (var e : raw.headers()) headers.put(e.getKey(), e.getValue()); + } + InvocationId invocationId = new InvocationIdImpl(raw.invocationId(), raw.randomSeed()); + String key = raw.key() != null && !raw.key().isEmpty() ? raw.key() : null; + this.loggingContextSetter.set( - EndpointRequestHandler.LoggingContextSetter.INVOCATION_ID_KEY, - input.invocationId().toString()); + EndpointRequestHandler.LoggingContextSetter.INVOCATION_ID_KEY, invocationId.toString()); + + // Build the output sink and close callback — no reference to `this` class leaks into ctx. + var outputSink = + (java.util.function.Consumer) + slice -> { + if (outputSubscriber != null) outputSubscriber.onNext(slice); + }; - // Prepare HandlerContext object - HandlerContextInternal contextInternal = + var closeCallback = + (Runnable) + () -> { + if (outputSubscriber != null) outputSubscriber.onComplete(); + cancelInputSubscription(); + }; + + HandlerContextImpl ctx = this.syscallsExecutor != null ? new ExecutorSwitchingHandlerContextImpl( + vm, + outputSink, + closeCallback, fullyQualifiedHandlerName, serviceType, handlerDefinition.getHandlerType(), - stateMachine, otelContext, - input, - this.syscallsExecutor) + invocationId, + Slice.wrap(raw.body()), + Collections.unmodifiableMap(headers), + key, + syscallsExecutor) : new HandlerContextImpl( + vm, + outputSink, + closeCallback, fullyQualifiedHandlerName, serviceType, handlerDefinition.getHandlerType(), - stateMachine, otelContext, - input); + invocationId, + Slice.wrap(raw.body()), + Collections.unmodifiableMap(headers), + key); + + // Wire external-progress trigger to HandlerContextImpl's internal listener. + this.externalProgressTrigger = ctx::triggerExternalProgress; CompletableFuture userCodeFuture = this.handlerDefinition .getRunner() .run( - contextInternal, + ctx, handlerDefinition.getRequestSerde(), handlerDefinition.getResponseSerde(), onHandlerTaskCancellation); - return userCodeFuture.handle( - (slice, t) -> { - if (t != null) { - this.end(contextInternal, t); - } else { - this.writeOutputAndEnd(contextInternal, slice); - } - return null; - }); + return userCodeFuture + .handle( + (slice, t) -> { + if (t != null) { + endInvocation(ctx, t); + } else { + writeOutputAndEnd(ctx, slice); + } + return null; + }) + .thenApply(v -> null) + .whenComplete((v, t) -> vm.close()); // VM lifecycle owned here } - private CompletableFuture writeOutputAndEnd( - HandlerContextInternal contextInternal, Slice output) { - return contextInternal.writeOutput(output).thenAccept(v -> this.end(contextInternal, null)); + private CompletableFuture writeOutputAndEnd(HandlerContextImpl ctx, Slice output) { + return ctx.writeOutput(output).thenAccept(v -> endInvocation(ctx, null)); } - private CompletableFuture end( - HandlerContextInternal contextInternal, @Nullable Throwable exception) { + private CompletableFuture endInvocation( + HandlerContextImpl ctx, @Nullable Throwable exception) { if (exception == null || ExceptionUtils.containsSuspendedException(exception)) { - contextInternal.close(); - } else if (contextInternal.getInvocationState() != InvocationState.CLOSED) { + ctx.close(); + } else if (!ctx.closed) { if (ExceptionUtils.isTerminalException(exception)) { LOG.info("Invocation completed with terminal error", exception); - return contextInternal - .writeOutput((TerminalException) exception) - .thenAccept(v -> contextInternal.close()); + return ctx.writeOutput((TerminalException) exception).thenAccept(v -> ctx.close()); } else { - // No need to log here, fail inside will log - contextInternal.fail(exception); + ctx.fail(exception); } } else if (!"kotlinx.coroutines.JobCancellationException" .equals(exception.getClass().getCanonicalName())) { @@ -207,4 +317,14 @@ private CompletableFuture end( } return CompletableFuture.completedFuture(null); } + + // --------------------------------------------------------------------------- + // Static helpers + // --------------------------------------------------------------------------- + + private static String stacktraceToString(Throwable t) { + StringWriter sw = new StringWriter(); + t.printStackTrace(new PrintWriter(sw)); + return sw.toString(); + } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/sharedcore/SharedCoreInstance.java b/sdk-core/src/main/java/dev/restate/sdk/core/sharedcore/SharedCoreInstance.java new file mode 100644 index 000000000..f0d002ca7 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/sharedcore/SharedCoreInstance.java @@ -0,0 +1,186 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.sharedcore; + +import com.dylibso.chicory.annotations.HostModule; +import com.dylibso.chicory.annotations.WasmExport; +import com.dylibso.chicory.runtime.HostFunction; +import com.dylibso.chicory.runtime.ImportValues; +import com.dylibso.chicory.runtime.Instance; +import com.dylibso.chicory.runtime.Memory; +import com.dylibso.chicory.wasm.WasmModule; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.dataformat.cbor.databind.CBORMapper; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.sharedcore.generated.SharedCoreWasmMachine; +import java.io.IOException; +import java.util.function.Function; +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +class SharedCoreInstance { + + private static final Logger LOG = LogManager.getLogger(SharedCoreInstance.class); + private static final CBORMapper CBOR = CBORMapper.builder().build(); + private static final WasmModule WASM_MODULE = + dev.restate.sdk.core.sharedcore.generated.SharedCoreWasm.load(); + private static final ThreadLocal THREAD_LOCAL = + ThreadLocal.withInitial(SharedCoreInstance::create); + + private final Memory memory; + private final SharedCoreWasm_ModuleExports exports; + + private SharedCoreInstance(Memory memory, SharedCoreWasm_ModuleExports exports) { + this.memory = memory; + this.exports = exports; + } + + static SharedCoreInstance get() { + return THREAD_LOCAL.get(); + } + + private static SharedCoreInstance create() { + ImportValues importValues = + ImportValues.builder().addFunction(SharedCoreImports.INSTANCE.toHostFunctions()).build(); + + Instance instance = + Instance.builder(WASM_MODULE) + .withMachineFactory(SharedCoreWasmMachine::new) + .withImportValues(importValues) + .build(); + + Memory mem = instance.memory(); + SharedCoreWasm_ModuleExports exp = new SharedCoreWasm_ModuleExports(instance); + + exp.init(toWasmLevel(LOG.getLevel())); + return new SharedCoreInstance(mem, exp); + } + + public SharedCoreWasm_ModuleExports getExports() { + return exports; + } + + public byte[] readAndFree(long packed) { + int ptr = (int) (packed >>> 32); + int len = (int) (packed & 0xFFFFFFFFL); + byte[] bytes = memory.readBytes(ptr, len); + exports.deallocate(ptr, len); + return bytes; + } + + public T readCborAndFree(long packed, Class outputClazz) { + byte[] retBytes = readAndFree(packed); + try { + return CBOR.readValue(retBytes, outputClazz); + } catch (IOException e) { + throw new ProtocolException("Failed to decode CBOR", ProtocolException.INTERNAL_CODE, e); + } + } + + public int write(byte[] bytes) { + int hPtr = exports.allocate(bytes.length); + memory.write(hPtr, bytes); + return hPtr; + } + + public record BufferPointer(int ptr, int len) {} + + public BufferPointer writeCbor(Object input) { + byte[] cbor; + try { + cbor = CBOR.writeValueAsBytes(input); + } catch (JsonProcessingException e) { + throw new ProtocolException("Failed to encode CBOR", ProtocolException.INTERNAL_CODE, e); + } + int ptr = write(cbor); + return new BufferPointer(ptr, cbor.length); + } + + public T callCborVmFunction( + TriFunction func, + Object input, + Class outputClazz) { + var inputBufferPtr = writeCbor(input); + long packed = func.apply(exports, inputBufferPtr.ptr, inputBufferPtr.len); + return readCborAndFree(packed, outputClazz); + } + + public void callCborVmFunction( + TriConsumer func, Object input) { + var inputBufferPtr = writeCbor(input); + func.accept(exports, inputBufferPtr.ptr, inputBufferPtr.len); + } + + public T callCborVmFunction( + Function func, Class outputClazz) { + long packed = func.apply(exports); + return readCborAndFree(packed, outputClazz); + } + + @FunctionalInterface + public interface TriFunction { + R apply(X x, Y y, Z z); + } + + @FunctionalInterface + public interface QuadFunction { + R apply(W w, X x, Y y, Z z); + } + + @FunctionalInterface + public interface TriConsumer { + void accept(X x, Y y, Z z); + } + + static Level toLog4jLevel(int level) { + return switch (level) { + case 0 -> Level.TRACE; + case 1 -> Level.DEBUG; + case 2 -> Level.INFO; + case 3 -> Level.WARN; + default -> Level.ERROR; + }; + } + + static int toWasmLevel(Level level) { + if (level == Level.TRACE) { + return 0; + } else if (level == Level.DEBUG) { + return 1; + } else if (level == Level.INFO) { + return 2; + } else if (level == Level.WARN) { + return 3; + } else { + return 4; + } + } + + @HostModule("env") + static final class SharedCoreImports { + + private SharedCoreImports() {} + + public static final SharedCoreImports INSTANCE = new SharedCoreImports(); + + @WasmExport + public void log(Memory memory, int level, int ptr, int len) { + if (len <= 0) { + return; + } + String message = memory.readString(ptr, len); + LOG.atLevel(toLog4jLevel(level)).log(message); + } + + public HostFunction[] toHostFunctions() { + return SharedCoreImports_ModuleFactory.toHostFunctions(this); + } + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/sharedcore/SharedCoreVM.java b/sdk-core/src/main/java/dev/restate/sdk/core/sharedcore/SharedCoreVM.java new file mode 100644 index 000000000..fb662b4a3 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/sharedcore/SharedCoreVM.java @@ -0,0 +1,867 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.sharedcore; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.annotation.JsonTypeInfo.As; +import com.fasterxml.jackson.annotation.JsonTypeInfo.Id; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.endpoint.HeadersAccessor; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.jspecify.annotations.Nullable; + +/** + * Java wrapper around the Rust {@code restate-sdk-shared-core} VM, embedded as a Chicory WASM + * module. + * + *

Mirrors sdk-go/internal/statemachine/wasm.go. Every WASM function returns {@code u64 = (ptr << + * 32) | len} pointing to CBOR. Response types match the Rust output DTOs 1:1. + */ +public final class SharedCoreVM implements AutoCloseable { + + private static final Logger LOG = LogManager.getLogger(SharedCoreVM.class); + + private final SharedCoreInstance instance; + private final int vmPtr; + private boolean closed; + + private SharedCoreVM(SharedCoreInstance instance, int vmPtr) { + this.instance = instance; + this.vmPtr = vmPtr; + this.closed = false; + } + + // ------------------------------------------------------------------------- + // Factory + // ------------------------------------------------------------------------- + + public static SharedCoreVM create(HeadersAccessor headersAccessor) { + LOG.trace("create()"); + SharedCoreInstance instance = SharedCoreInstance.get(); + + var newVmReturn = + instance.callCborVmFunction( + SharedCoreWasm_ModuleExports::vmNew, + new VmNewParameters( + StreamSupport.stream(headersAccessor.keys().spliterator(), false) + .map(key -> new String[] {key, headersAccessor.get(key)}) + .filter(arr -> arr[1] != null) + .collect(Collectors.toList())), + VmNewReturn.class); + if (newVmReturn instanceof VmNewReturn.Failure f) { + throw new ProtocolException("Failed to create state machine: " + f.message(), f.code); + } + int vmPtr = ((VmNewReturn.Ok) newVmReturn).pointer(); + return new SharedCoreVM(instance, vmPtr); + } + + @Override + public void close() { + if (!closed) { + LOG.trace("[vm=0x{}] close()", Integer.toHexString(vmPtr)); + instance.getExports().vmFree(vmPtr); + closed = true; + } + } + + private void verifyNotClosed() { + if (closed) { + throw new IllegalStateException("Attempting to use the VM when is closed"); + } + } + + public void notifyInput(byte[] bytes) { + if (closed) { + return; + } + LOG.trace("[vm=0x{}] notifyInput()", Integer.toHexString(vmPtr)); + var bufferPtr = instance.write(bytes); + instance.getExports().vmNotifyInput(vmPtr, bufferPtr, bytes.length); + } + + public void notifyInputClosed() { + if (closed) { + return; + } + LOG.trace("[vm=0x{}] notifyInputClosed()", Integer.toHexString(vmPtr)); + instance.getExports().vmNotifyInputClosed(vmPtr); + } + + public void notifyError( + String message, @Nullable String stacktrace, @Nullable Long delayOverrideMillis) { + if (closed) { + return; + } + LOG.trace("[vm=0x{}] notifyError()", Integer.toHexString(vmPtr)); + instance.callCborVmFunction( + (exports, ptr, len) -> exports.vmNotifyError(vmPtr, ptr, len), + new VmNotifyError(message, stacktrace, delayOverrideMillis)); + } + + public byte[] takeOutput() { + if (closed) { + return new byte[0]; + } + LOG.trace("[vm=0x{}] takeOutput()", Integer.toHexString(vmPtr)); + + var ptr = instance.getExports().vmTakeOutput(vmPtr); + return instance.readAndFree(ptr); + } + + public String getResponseContentType() { + LOG.trace("[vm=0x{}] getResponseContentType()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + // TODO change this to just return headers back bro + var ret = + instance.callCborVmFunction( + exports -> exports.vmGetResponseHead(vmPtr), ResponseHeadReturn.class); + if (ret.headers() == null) return ""; + for (String[] pair : ret.headers()) { + if ("content-type".equalsIgnoreCase(pair[0])) return pair[1]; + } + return ""; + } + + public boolean isReadyToExecute() { + LOG.trace("[vm=0x{}] isReadyToExecute()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + var ret = + instance.callCborVmFunction( + exports -> exports.vmIsReadyToExecute(vmPtr), IsReadyReturn.class); + if (ret instanceof IsReadyReturn.Failure f) throw vmError(f.code, f.message); + return ((IsReadyReturn.Ok) ret).ready(); + } + + public boolean isCompleted(int handle) { + LOG.trace("[vm=0x{}] isCompleted()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + return instance.getExports().vmIsCompleted(vmPtr, handle) != 0L; + } + + public DoProgressResult doProgress(int[] handles) { + LOG.trace("[vm=0x{}] doProgress()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + var ret = + instance.callCborVmFunction( + (exports, ptr, len) -> exports.vmDoProgress(vmPtr, ptr, len), + new VmDoProgressParameters(handles), + DoProgressReturn.class); + if (ret instanceof DoProgressReturn.AnyCompleted) return DoProgressResult.ANY_COMPLETED; + if (ret instanceof DoProgressReturn.WaitingExternalProgress) + return DoProgressResult.WAIT_EXTERNAL_PROGRESS; + if (ret instanceof DoProgressReturn.CancelSignalReceived) + return DoProgressResult.CANCEL_SIGNAL_RECEIVED; + if (ret instanceof DoProgressReturn.ExecuteRun r) + return new DoProgressResult.ExecuteRun(r.handle()); + if (ret instanceof DoProgressReturn.Suspended) return DoProgressResult.SUSPENDED; + DoProgressReturn.Failure f = (DoProgressReturn.Failure) ret; + throw vmError(f.code, f.message); + } + + public @Nullable NotificationValue takeNotification(int handle) { + LOG.trace("[vm=0x{}] takeNotification()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + var ret = + instance.callCborVmFunction( + exports -> exports.vmTakeNotification(vmPtr, handle), TakeNotificationReturn.class); + if (ret instanceof TakeNotificationReturn.NotReady) return null; + if (ret instanceof TakeNotificationReturn.Suspended) return null; + if (ret instanceof TakeNotificationReturn.Value v) return v.value(); + TakeNotificationReturn.Failure f = (TakeNotificationReturn.Failure) ret; + throw vmError(f.code, f.message); + } + + public Input sysInput() { + LOG.trace("[vm=0x{}] sysInput()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + var ret = + instance.callCborVmFunction(exports -> exports.vmSysInput(vmPtr), SysInputReturn.class); + if (ret instanceof SysInputReturn.Failure f) throw vmError(f.code(), f.message()); + WasmInput w = ((SysInputReturn.Ok) ret).input(); + List> hdrs = new ArrayList<>(); + if (w.headers() != null) { + for (String[] pair : w.headers()) hdrs.add(Map.entry(pair[0], pair[1])); + } + return new Input( + w.invocationId(), w.randomSeed(), w.key(), Collections.unmodifiableList(hdrs), w.input()); + } + + // ------------------------------------------------------------------------- + // State + // ------------------------------------------------------------------------- + + public int sysStateGet(String key) { + LOG.trace("[vm=0x{}] sysStateGet()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + return callWithHandleReturn( + SharedCoreWasm_ModuleExports::vmSysStateGet, new VmSysStateGetParameters(key)); + } + + public int sysStateGetKeys() { + LOG.trace("[vm=0x{}] sysStateGetKeys()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + return callWithHandleReturn(SharedCoreWasm_ModuleExports::vmSysStateGetKeys); + } + + public void sysStateSet(String key, byte[] value) { + LOG.trace("[vm=0x{}] sysStateSet()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + callWithEmptyReturn( + SharedCoreWasm_ModuleExports::vmSysStateSet, new VmSysStateSetParameters(key, value)); + } + + public void sysStateClear(String key) { + LOG.trace("[vm=0x{}] sysStateClear()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + callWithEmptyReturn( + SharedCoreWasm_ModuleExports::vmSysStateClear, new VmSysStateClearParameters(key)); + } + + public void sysStateClearAll() { + LOG.trace("[vm=0x{}] sysStateClearAll()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + callWithEmptyReturn(SharedCoreWasm_ModuleExports::vmSysStateClearAll); + } + + // ------------------------------------------------------------------------- + // Sleep + // ------------------------------------------------------------------------- + + public int sysSleep(long durationMillis, @Nullable String name) { + LOG.trace("[vm=0x{}] sysSleep()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + long now = System.currentTimeMillis(); + return callWithHandleReturn( + SharedCoreWasm_ModuleExports::vmSysSleep, + new VmSysSleepParameters(name != null ? name : "", now + durationMillis, now)); + } + + // ------------------------------------------------------------------------- + // Call / Send + // ------------------------------------------------------------------------- + + public CallHandleResult sysCall( + String service, + String handler, + @Nullable String key, + byte[] payload, + @Nullable String idempotencyKey, + @Nullable List> extraHeaders) { + LOG.trace("[vm=0x{}] sysCall()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + var ret = + instance.callCborVmFunction( + (exports, ptr, len) -> exports.vmSysCall(vmPtr, ptr, len), + new VmSysCallParameters( + service, handler, key, idempotencyKey, toHeaderList(extraHeaders), payload), + SysCallReturn.class); + if (ret instanceof SysCallReturn.Failure f) throw vmError(f.code(), f.message()); + SysCallReturn.Ok ok = (SysCallReturn.Ok) ret; + return new CallHandleResult(ok.invocationIdHandle(), ok.resultHandle()); + } + + public int sysSend( + String service, + String handler, + @Nullable String key, + byte[] payload, + @Nullable String idempotencyKey, + @Nullable List> extraHeaders, + @Nullable Long delayMillis) { + LOG.trace("[vm=0x{}] sysSend()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + Long executionTime = delayMillis != null ? System.currentTimeMillis() + delayMillis : null; + return callWithHandleReturn( + SharedCoreWasm_ModuleExports::vmSysSend, + new VmSysSendParameters( + service, + handler, + key, + idempotencyKey, + toHeaderList(extraHeaders), + payload, + executionTime)); + } + + // ------------------------------------------------------------------------- + // Awakeables + // ------------------------------------------------------------------------- + + public AwakeableResult sysAwakeable() { + LOG.trace("[vm=0x{}] sysAwakeable()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + AwakeableReturn ret = + instance.callCborVmFunction( + (exports) -> exports.vmSysAwakeable(vmPtr), AwakeableReturn.class); + if (ret instanceof AwakeableReturn.Failure f) throw vmError(f.code(), f.message()); + AwakeableReturn.Ok ok = (AwakeableReturn.Ok) ret; + return new AwakeableResult(ok.handle(), ok.id()); + } + + public void sysCompleteAwakeable(String id, NonEmptyValueParam result) { + LOG.trace("[vm=0x{}] sysCompleteAwakeable()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + callWithEmptyReturn( + SharedCoreWasm_ModuleExports::vmSysCompleteAwakeable, + new VmSysCompleteAwakeableParameters(id, result)); + } + + public void sysCompleteAwakeableSuccess(String id, byte[] value) { + sysCompleteAwakeable(id, new NonEmptyValueParam.Success(value)); + } + + public void sysCompleteAwakeableFailure(String id, int code, String message) { + sysCompleteAwakeable(id, new NonEmptyValueParam.Failure(code, message, null)); + } + + public void sysCompleteAwakeableFailure( + String id, int code, String message, @Nullable List metadata) { + sysCompleteAwakeable(id, new NonEmptyValueParam.Failure(code, message, metadata)); + } + + public int sysCreateSignalHandle(String signalName) { + LOG.trace("[vm=0x{}] sysCreateSignalHandle()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + return callWithHandleReturn( + SharedCoreWasm_ModuleExports::vmSysCreateSignalHandle, + new VmSysCreateSignalHandleParameters(signalName)); + } + + public void sysCompleteSignal(String target, String signalName, NonEmptyValueParam result) { + LOG.trace("[vm=0x{}] sysCompleteSignal()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + callWithEmptyReturn( + SharedCoreWasm_ModuleExports::vmSysCompleteSignal, + new VmSysCompleteSignalParameters(target, signalName, result)); + } + + public void sysCompleteSignalSuccess(String target, String signalName, byte[] value) { + sysCompleteSignal(target, signalName, new NonEmptyValueParam.Success(value)); + } + + public void sysCompleteSignalFailure(String target, String signalName, int code, String message) { + sysCompleteSignal(target, signalName, new NonEmptyValueParam.Failure(code, message, null)); + } + + public void sysCompleteSignalFailure( + String target, + String signalName, + int code, + String message, + @Nullable List metadata) { + sysCompleteSignal(target, signalName, new NonEmptyValueParam.Failure(code, message, metadata)); + } + + public int sysPromiseGet(String key) { + LOG.trace("[vm=0x{}] sysPromiseGet()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + return callWithHandleReturn( + SharedCoreWasm_ModuleExports::vmSysPromiseGet, new VmSysPromiseGetParameters(key)); + } + + public int sysPromisePeek(String key) { + LOG.trace("[vm=0x{}] sysPromisePeek()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + return callWithHandleReturn( + SharedCoreWasm_ModuleExports::vmSysPromisePeek, new VmSysPromisePeekParameters(key)); + } + + public int sysPromiseComplete(String key, NonEmptyValueParam result) { + LOG.trace("[vm=0x{}] sysPromiseComplete()", Integer.toHexString(vmPtr)); + return callWithHandleReturn( + SharedCoreWasm_ModuleExports::vmSysPromiseComplete, + new VmSysPromiseCompleteParameters(key, result)); + } + + public int sysPromiseCompleteSuccess(String key, byte[] value) { + return sysPromiseComplete(key, new NonEmptyValueParam.Success(value)); + } + + public int sysPromiseCompleteFailure(String key, int code, String message) { + return sysPromiseComplete(key, new NonEmptyValueParam.Failure(code, message, null)); + } + + public int sysPromiseCompleteFailure( + String key, int code, String message, @Nullable List metadata) { + return sysPromiseComplete(key, new NonEmptyValueParam.Failure(code, message, metadata)); + } + + public int sysRun(String name) { + LOG.trace("[vm=0x{}] sysRun()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + return callWithHandleReturn( + SharedCoreWasm_ModuleExports::vmSysRun, new VmSysRunParameters(name)); + } + + public void proposeRunCompletionSuccess(int handle, byte[] value) { + LOG.trace("[vm=0x{}] proposeRunCompletionSuccess()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + callWithEmptyReturn( + SharedCoreWasm_ModuleExports::vmProposeRunCompletion, + new VmProposeRunCompletionParameters(handle, new RunResult.Success(value), 0L, null)); + } + + public void proposeRunCompletionTerminalFailure( + int handle, int code, String message, @Nullable List metadata) { + LOG.trace("[vm=0x{}] proposeRunCompletionTerminalFailure()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + callWithEmptyReturn( + SharedCoreWasm_ModuleExports::vmProposeRunCompletion, + new VmProposeRunCompletionParameters( + handle, new RunResult.TerminalFailure(code, message, metadata), 0L, null)); + } + + public void proposeRunCompletionRetryableFailure( + int handle, + int code, + String message, + @Nullable String stacktrace, + long attemptDurationMillis, + @Nullable WasmRetryPolicy retryPolicy) { + LOG.trace("[vm=0x{}] proposeRunCompletionRetryableFailure()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + callWithEmptyReturn( + SharedCoreWasm_ModuleExports::vmProposeRunCompletion, + new VmProposeRunCompletionParameters( + handle, + new RunResult.RetryableFailure(code, message, stacktrace), + attemptDurationMillis, + retryPolicy)); + } + + public void sysCancelInvocation(String invocationId) { + LOG.trace("[vm=0x{}] sysCancelInvocation()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + callWithEmptyReturn( + SharedCoreWasm_ModuleExports::vmSysCancelInvocation, + new VmSysCancelInvocation(invocationId)); + } + + public int sysAttachInvocation(String invocationId) { + LOG.trace("[vm=0x{}] sysAttachInvocation()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + return callWithHandleReturn( + SharedCoreWasm_ModuleExports::vmSysAttachInvocation, + new VmSysAttachInvocation(invocationId)); + } + + public int sysGetInvocationOutput(String invocationId) { + LOG.trace("[vm=0x{}] sysGetInvocationOutput()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + return callWithHandleReturn( + SharedCoreWasm_ModuleExports::vmSysGetInvocationOutput, + new VmSysGetInvocationOutput(invocationId)); + } + + public void sysWriteOutput(NonEmptyValueParam result) { + LOG.trace("[vm=0x{}] sysWriteOutput()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + callWithEmptyReturn( + SharedCoreWasm_ModuleExports::vmSysWriteOutput, new VmSysWriteOutputParameters(result)); + } + + public void sysWriteOutputSuccess(byte[] value) { + sysWriteOutput(new NonEmptyValueParam.Success(value)); + } + + public void sysWriteOutputFailure(int code, String message) { + sysWriteOutput(new NonEmptyValueParam.Failure(code, message, null)); + } + + public void sysWriteOutputFailure(int code, String message, @Nullable List metadata) { + sysWriteOutput(new NonEmptyValueParam.Failure(code, message, metadata)); + } + + public void sysEnd() { + LOG.trace("[vm=0x{}] sysEnd()", Integer.toHexString(vmPtr)); + verifyNotClosed(); + + callWithEmptyReturn(SharedCoreWasm_ModuleExports::vmSysEnd); + } + + // ========================================================================= + // Result types (returned to callers of this class) + // ========================================================================= + + public sealed interface DoProgressResult { + DoProgressResult ANY_COMPLETED = new AnyCompleted(); + DoProgressResult WAIT_EXTERNAL_PROGRESS = new WaitExternalProgress(); + DoProgressResult CANCEL_SIGNAL_RECEIVED = new CancelSignalReceived(); + DoProgressResult SUSPENDED = new Suspended(); + + record AnyCompleted() implements DoProgressResult {} + + record WaitExternalProgress() implements DoProgressResult {} + + record ExecuteRun(int handle) implements DoProgressResult {} + + record CancelSignalReceived() implements DoProgressResult {} + + record Suspended() implements DoProgressResult {} + } + + // ========================================================================= + // CBOR response types (decoded from WASM responses) + // ========================================================================= + + @JsonTypeInfo(use = Id.NAME, include = As.PROPERTY, property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = VmNewReturn.Ok.class, name = "ok"), + @JsonSubTypes.Type(value = VmNewReturn.Failure.class, name = "failure"), + }) + public sealed interface VmNewReturn { + record Ok(int pointer) implements VmNewReturn {} + + record Failure(int code, String message) implements VmNewReturn {} + } + + @JsonTypeInfo(use = Id.NAME, include = As.PROPERTY, property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = EmptyReturn.Ok.class, name = "ok"), + @JsonSubTypes.Type(value = EmptyReturn.Failure.class, name = "failure"), + }) + public sealed interface EmptyReturn { + record Ok() implements EmptyReturn {} + + record Failure(int code, String message) implements EmptyReturn {} + } + + @JsonTypeInfo(use = Id.NAME, include = As.PROPERTY, property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = HandleReturn.Ok.class, name = "ok"), + @JsonSubTypes.Type(value = HandleReturn.Failure.class, name = "failure"), + }) + public sealed interface HandleReturn { + record Ok(int handle) implements HandleReturn {} + + record Failure(int code, String message) implements HandleReturn {} + } + + @JsonTypeInfo(use = Id.NAME, include = As.PROPERTY, property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = IsReadyReturn.Ok.class, name = "ok"), + @JsonSubTypes.Type(value = IsReadyReturn.Failure.class, name = "failure"), + }) + public sealed interface IsReadyReturn { + record Ok(boolean ready) implements IsReadyReturn {} + + record Failure(int code, String message) implements IsReadyReturn {} + } + + /** Plain struct — mirrors Go's VmGetResponseHeadReturn (no Ok/Failure wrapper). */ + public record ResponseHeadReturn(int statusCode, List headers) {} + + @JsonTypeInfo(use = Id.NAME, include = As.PROPERTY, property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = DoProgressReturn.AnyCompleted.class, name = "anyCompleted"), + @JsonSubTypes.Type( + value = DoProgressReturn.WaitingExternalProgress.class, + name = "waitingExternalProgress"), + @JsonSubTypes.Type(value = DoProgressReturn.ExecuteRun.class, name = "executeRun"), + @JsonSubTypes.Type( + value = DoProgressReturn.CancelSignalReceived.class, + name = "cancelSignalReceived"), + @JsonSubTypes.Type(value = DoProgressReturn.Suspended.class, name = "suspended"), + @JsonSubTypes.Type(value = DoProgressReturn.Failure.class, name = "failure"), + }) + public sealed interface DoProgressReturn { + record AnyCompleted() implements DoProgressReturn {} + + record WaitingExternalProgress() implements DoProgressReturn {} + + record ExecuteRun(int handle) implements DoProgressReturn {} + + record CancelSignalReceived() implements DoProgressReturn {} + + record Suspended() implements DoProgressReturn {} + + record Failure(int code, String message) implements DoProgressReturn {} + } + + /** Notification value payload — matches Rust {@code NotificationValue}. */ + @JsonTypeInfo(use = Id.NAME, include = As.PROPERTY, property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = NotificationValue.Void.class, name = "void"), + @JsonSubTypes.Type(value = NotificationValue.Success.class, name = "success"), + @JsonSubTypes.Type(value = NotificationValue.Failure.class, name = "failure"), + @JsonSubTypes.Type(value = NotificationValue.StateKeys.class, name = "stateKeys"), + @JsonSubTypes.Type(value = NotificationValue.InvocationId.class, name = "invocationId"), + }) + public sealed interface NotificationValue { + NotificationValue VOID = new Void(); + + record Void() implements NotificationValue {} + + record Success(byte[] value) implements NotificationValue {} + + record Failure(int code, String message, @Nullable List metadata) + implements NotificationValue {} + + record StateKeys(List keys) implements NotificationValue {} + + record InvocationId(String id) implements NotificationValue {} + } + + @JsonTypeInfo(use = Id.NAME, include = As.PROPERTY, property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = TakeNotificationReturn.NotReady.class, name = "notReady"), + @JsonSubTypes.Type(value = TakeNotificationReturn.Value.class, name = "value"), + @JsonSubTypes.Type(value = TakeNotificationReturn.Suspended.class, name = "suspended"), + @JsonSubTypes.Type(value = TakeNotificationReturn.Failure.class, name = "failure"), + }) + public sealed interface TakeNotificationReturn { + record NotReady() implements TakeNotificationReturn {} + + record Value(NotificationValue value) implements TakeNotificationReturn {} + + record Suspended() implements TakeNotificationReturn {} + + record Failure(int code, String message) implements TakeNotificationReturn {} + } + + public record WasmInput( + String invocationId, + String key, + List headers, + byte[] input, + long randomSeed, + boolean shouldUseRandomSeed) {} + + @JsonTypeInfo(use = Id.NAME, include = As.PROPERTY, property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = SysInputReturn.Ok.class, name = "ok"), + @JsonSubTypes.Type(value = SysInputReturn.Failure.class, name = "failure"), + }) + public sealed interface SysInputReturn { + record Ok(WasmInput input) implements SysInputReturn {} + + record Failure(int code, String message) implements SysInputReturn {} + } + + @JsonTypeInfo(use = Id.NAME, include = As.PROPERTY, property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = AwakeableReturn.Ok.class, name = "ok"), + @JsonSubTypes.Type(value = AwakeableReturn.Failure.class, name = "failure"), + }) + public sealed interface AwakeableReturn { + record Ok(String id, int handle) implements AwakeableReturn {} + + record Failure(int code, String message) implements AwakeableReturn {} + } + + @JsonTypeInfo(use = Id.NAME, include = As.PROPERTY, property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = SysCallReturn.Ok.class, name = "ok"), + @JsonSubTypes.Type(value = SysCallReturn.Failure.class, name = "failure"), + }) + public sealed interface SysCallReturn { + record Ok(int invocationIdHandle, int resultHandle) implements SysCallReturn {} + + record Failure(int code, String message) implements SysCallReturn {} + } + + // ========================================================================= + // Input DTOs (Java → Rust, CBOR maps with camelCase keys) + // Field names match Rust struct field names after camelCase renaming. + // ========================================================================= + + public record VmNotifyError( + String message, @Nullable String stacktrace, @Nullable Long delayOverrideMillis) {} + + public record VmNewParameters(List headers) {} + + public record VmDoProgressParameters(int[] handles) {} + + public record VmSysStateGetParameters(String key) {} + + public record VmSysStateSetParameters(String key, byte[] value) {} + + public record VmSysStateClearParameters(String key) {} + + public record VmSysSleepParameters( + String name, long wakeUpTimeSinceUnixEpochMillis, long nowSinceUnixEpochMillis) {} + + /** Combined success/failure union — matches Rust {@code NonEmptyValueParam}. */ + @JsonTypeInfo(use = Id.NAME, include = As.PROPERTY, property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = NonEmptyValueParam.Success.class, name = "success"), + @JsonSubTypes.Type(value = NonEmptyValueParam.Failure.class, name = "failure"), + }) + public sealed interface NonEmptyValueParam { + record Success(byte[] value) implements NonEmptyValueParam {} + + record Failure(int code, String message, @Nullable List metadata) + implements NonEmptyValueParam {} + } + + public record VmSysCompleteAwakeableParameters(String id, NonEmptyValueParam result) {} + + public record VmSysCallParameters( + String service, + String handler, + @Nullable String key, + @Nullable String idempotencyKey, + List headers, + byte[] input) {} + + public record VmSysSendParameters( + String service, + String handler, + @Nullable String key, + @Nullable String idempotencyKey, + List headers, + byte[] input, + @Nullable Long executionTimeSinceUnixEpochMillis) {} + + public record VmSysCancelInvocation(String invocationId) {} + + public record VmSysAttachInvocation(String invocationId) {} + + public record VmSysGetInvocationOutput(String invocationId) {} + + public record VmSysPromiseGetParameters(String key) {} + + public record VmSysPromisePeekParameters(String key) {} + + public record VmSysPromiseCompleteParameters(String id, NonEmptyValueParam result) {} + + public record VmSysRunParameters(String name) {} + + @JsonTypeInfo(use = Id.NAME, include = As.PROPERTY, property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = RunResult.Success.class, name = "success"), + @JsonSubTypes.Type(value = RunResult.TerminalFailure.class, name = "terminalFailure"), + @JsonSubTypes.Type(value = RunResult.RetryableFailure.class, name = "retryableFailure"), + }) + public sealed interface RunResult { + record Success(byte[] value) implements RunResult {} + + record TerminalFailure(int code, String message, @Nullable List metadata) + implements RunResult {} + + record RetryableFailure(int code, String message, @Nullable String stacktrace) + implements RunResult {} + } + + public record VmProposeRunCompletionParameters( + int handle, + RunResult result, + long attemptDurationMillis, + @Nullable WasmRetryPolicy retryPolicy) {} + + public record WasmRetryPolicy( + long initialIntervalMillis, + float factor, + @Nullable Long maxIntervalMillis, + @Nullable Integer maxAttempts, + @Nullable Long maxDurationMillis) {} + + public record VmSysCreateSignalHandleParameters(String name) {} + + public record VmSysCompleteSignalParameters( + String target, String name, NonEmptyValueParam result) {} + + public record VmSysWriteOutputParameters(NonEmptyValueParam result) {} + + // ========================================================================= + // Result value types returned to WasmStateMachineImpl + // ========================================================================= + + public record Input( + String invocationId, + long randomSeed, + String key, + List> headers, + byte[] body) {} + + public record CallHandleResult(int invocationIdHandle, int resultHandle) {} + + public record AwakeableResult(int signalHandle, String awakeableId) {} + + private int callWithHandleReturn( + SharedCoreInstance.QuadFunction + func, + Object input) { + var ret = + instance.callCborVmFunction( + (exports, ptr, len) -> func.apply(exports, vmPtr, ptr, len), input, HandleReturn.class); + if (ret instanceof HandleReturn.Failure f) throw vmError(f.code, f.message); + return ((HandleReturn.Ok) ret).handle(); + } + + private int callWithHandleReturn(BiFunction func) { + var ret = + instance.callCborVmFunction(exports -> func.apply(exports, vmPtr), HandleReturn.class); + if (ret instanceof HandleReturn.Failure f) throw vmError(f.code, f.message); + return ((HandleReturn.Ok) ret).handle(); + } + + private void callWithEmptyReturn( + SharedCoreInstance.QuadFunction + func, + Object input) { + var ret = + instance.callCborVmFunction( + (exports, ptr, len) -> func.apply(exports, vmPtr, ptr, len), input, EmptyReturn.class); + if (ret instanceof EmptyReturn.Failure f) throw vmError(f.code, f.message); + } + + private void callWithEmptyReturn(BiFunction func) { + var ret = instance.callCborVmFunction(exports -> func.apply(exports, vmPtr), EmptyReturn.class); + if (ret instanceof EmptyReturn.Failure f) throw vmError(f.code, f.message); + } + + private static ProtocolException vmError(int code, String message) { + return new ProtocolException(message, code); + } + + private static List toHeaderList(@Nullable List> headers) { + if (headers == null || headers.isEmpty()) return Collections.emptyList(); + List r = new ArrayList<>(headers.size()); + for (Map.Entry e : headers) r.add(new String[] {e.getKey(), e.getValue()}); + return r; + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/AsyncResultsState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/AsyncResultsState.java deleted file mode 100644 index 0594624c4..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/AsyncResultsState.java +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.ByteString; -import dev.restate.sdk.core.ProtocolException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import java.util.*; - -final class AsyncResultsState { - public static final int CANCEL_NOTIFICATION_HANDLE = 1; - - private final Deque> toProcess; - private final Map ready; - private final Map handleMapping; - - private int nextNotificationHandle; - - public AsyncResultsState() { - this.toProcess = new ArrayDeque<>(); - this.ready = new HashMap<>(); - - this.handleMapping = new HashMap<>(); - // Prepare built in signal handles here - this.handleMapping.put(CANCEL_NOTIFICATION_HANDLE, new NotificationId.SignalId(1)); - - // First 15 are reserved for built-in signals! - nextNotificationHandle = 17; - } - - public void enqueue(Protocol.NotificationTemplate notification) { - var notificationId = - switch (notification.getIdCase()) { - case COMPLETION_ID -> new NotificationId.CompletionId(notification.getCompletionId()); - case SIGNAL_ID -> new NotificationId.SignalId(notification.getSignalId()); - case SIGNAL_NAME -> new NotificationId.SignalName(notification.getSignalName()); - case ID_NOT_SET -> throw ProtocolException.badNotificationMessage("id"); - }; - - var notificationValue = - switch (notification.getResultCase()) { - case VOID -> NotificationValue.Empty.INSTANCE; - case VALUE -> - new NotificationValue.Success( - Util.byteStringToSlice(notification.getValue().getContent())); - case FAILURE -> - new NotificationValue.Failure(Util.toRestateException(notification.getFailure())); - case INVOCATION_ID -> new NotificationValue.InvocationId(notification.getInvocationId()); - case STATE_KEYS -> - new NotificationValue.StateKeys( - notification.getStateKeys().getKeysList().stream() - .map(ByteString::toStringUtf8) - .toList()); - case RESULT_NOT_SET -> throw ProtocolException.badNotificationMessage("result"); - }; - - toProcess.addLast(Map.entry(notificationId, notificationValue)); - } - - public void insertReady(NotificationId id, NotificationValue value) { - ready.put(id, value); - } - - public int createHandleMapping(NotificationId notificationId) { - int assignedHandle = nextNotificationHandle; - nextNotificationHandle++; - handleMapping.put(assignedHandle, notificationId); - return assignedHandle; - } - - public boolean processNextUntilAnyFound(Set ids) { - while (!toProcess.isEmpty()) { - Map.Entry notif = toProcess.removeFirst(); - boolean anyFound = ids.contains(notif.getKey()); - ready.put(notif.getKey(), notif.getValue()); - if (anyFound) { - return true; - } - } - return false; - } - - public boolean isHandleCompleted(int handle) { - NotificationId id = handleMapping.get(handle); - return id != null && ready.containsKey(id); - } - - public boolean nonDeterministicFindId(NotificationId id) { - if (ready.containsKey(id)) { - return true; - } - return toProcess.stream().anyMatch(notif -> notif.getKey().equals(id)); - } - - public Set resolveNotificationHandles(List handles) { - Set result = new LinkedHashSet<>(); - for (int handle : handles) { - NotificationId id = handleMapping.get(handle); - if (id != null) { - result.add(id); - } - } - return result; - } - - public NotificationId mustResolveNotificationHandle(int handle) { - NotificationId id = handleMapping.get(handle); - if (id == null) { - throw new IllegalStateException("If there is a handle, there must be a corresponding id"); - } - return id; - } - - public Optional takeHandle(int handle) { - NotificationId id = handleMapping.get(handle); - if (id != null) { - NotificationValue result = ready.remove(id); - if (result != null) { - handleMapping.remove(handle); - return Optional.of(result); - } - } - return Optional.empty(); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ClosedState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ClosedState.java deleted file mode 100644 index 68ac44b92..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ClosedState.java +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import java.time.Duration; -import org.jspecify.annotations.Nullable; - -final class ClosedState implements State { - - @Override - public void hitError( - Throwable throwable, - @Nullable CommandRelationship relatedCommand, - @Nullable Duration nextRetryDelay, - StateContext stateContext) { - // Ignore, as we closed already - } - - @Override - public void end(StateContext stateContext) { - // Ignore, as we closed already - } - - @Override - public InvocationState getInvocationState() { - return InvocationState.CLOSED; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandAccessor.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandAccessor.java deleted file mode 100644 index 89e496ccb..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandAccessor.java +++ /dev/null @@ -1,434 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.MessageLite; -import dev.restate.sdk.core.ProtocolException; -import dev.restate.sdk.core.generated.protocol.Protocol; - -interface CommandAccessor { - - String getName(E expected); - - void checkEntryHeader(int commandIndex, E expected, MessageLite actual) throws ProtocolException; - - CommandAccessor INPUT = - new CommandAccessor<>() { - @Override - public String getName(Protocol.InputCommandMessage expected) { - return ""; - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.InputCommandMessage expected, MessageLite actual) - throws ProtocolException { - // Nothing to check - } - }; - CommandAccessor OUTPUT = - new CommandAccessor<>() { - @Override - public String getName(Protocol.OutputCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.OutputCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.OutputCommandMessage.class, expected, actual) - .checkField("name", Protocol.OutputCommandMessage::getName) - .checkField("result", Protocol.OutputCommandMessage::getResultCase) - .checkField("value", Protocol.OutputCommandMessage::getValue) - .checkField("failure", Protocol.OutputCommandMessage::getFailure) - .verify(); - } - }; - CommandAccessor GET_EAGER_STATE = - new CommandAccessor<>() { - @Override - public void checkEntryHeader( - int commandIndex, Protocol.GetEagerStateCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.GetEagerStateCommandMessage.class, expected, actual) - .checkField("name", Protocol.GetEagerStateCommandMessage::getName) - .checkField("key", Protocol.GetEagerStateCommandMessage::getKey) - .verify(); - } - - @Override - public String getName(Protocol.GetEagerStateCommandMessage expected) { - return expected.getName(); - } - }; - CommandAccessor GET_LAZY_STATE = - new CommandAccessor<>() { - @Override - public String getName(Protocol.GetLazyStateCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.GetLazyStateCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.GetLazyStateCommandMessage.class, expected, actual) - .checkField("name", Protocol.GetLazyStateCommandMessage::getName) - .checkField("key", Protocol.GetLazyStateCommandMessage::getKey) - .checkField( - "result_completion_id", - Protocol.GetLazyStateCommandMessage::getResultCompletionId) - .verify(); - } - }; - CommandAccessor GET_EAGER_STATE_KEYS = - new CommandAccessor<>() { - @Override - public String getName(Protocol.GetEagerStateKeysCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.GetEagerStateKeysCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.GetEagerStateKeysCommandMessage.class, expected, actual) - .checkField("name", Protocol.GetEagerStateKeysCommandMessage::getName) - .verify(); - } - }; - CommandAccessor GET_LAZY_STATE_KEYS = - new CommandAccessor<>() { - @Override - public String getName(Protocol.GetLazyStateKeysCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.GetLazyStateKeysCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.GetLazyStateKeysCommandMessage.class, expected, actual) - .checkField("name", Protocol.GetLazyStateKeysCommandMessage::getName) - .verify(); - } - }; - CommandAccessor CLEAR_STATE = - new CommandAccessor<>() { - @Override - public String getName(Protocol.ClearStateCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.ClearStateCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.ClearStateCommandMessage.class, expected, actual) - .checkField("name", Protocol.ClearStateCommandMessage::getName) - .checkField("key", Protocol.ClearStateCommandMessage::getKey) - .verify(); - } - }; - CommandAccessor CLEAR_ALL_STATE = - new CommandAccessor<>() { - @Override - public String getName(Protocol.ClearAllStateCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.ClearAllStateCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.ClearAllStateCommandMessage.class, expected, actual) - .checkField("name", Protocol.ClearAllStateCommandMessage::getName) - .verify(); - } - }; - CommandAccessor SET_STATE = - new CommandAccessor<>() { - @Override - public String getName(Protocol.SetStateCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.SetStateCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.SetStateCommandMessage.class, expected, actual) - .checkField("name", Protocol.SetStateCommandMessage::getName) - .checkField("key", Protocol.SetStateCommandMessage::getKey) - .checkField("value", Protocol.SetStateCommandMessage::getValue) - .verify(); - } - }; - - CommandAccessor SLEEP = - new CommandAccessor<>() { - @Override - public String getName(Protocol.SleepCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.SleepCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.SleepCommandMessage.class, expected, actual) - .checkField("name", Protocol.SleepCommandMessage::getName) - .checkField( - "result_completion_id", Protocol.SleepCommandMessage::getResultCompletionId) - .verify(); - } - }; - - CommandAccessor CALL = - new CommandAccessor<>() { - @Override - public String getName(Protocol.CallCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.CallCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.CallCommandMessage.class, expected, actual) - .checkField("name", Protocol.CallCommandMessage::getName) - .checkField("service_name", Protocol.CallCommandMessage::getServiceName) - .checkField("handler_name", Protocol.CallCommandMessage::getHandlerName) - .checkField("parameter", Protocol.CallCommandMessage::getParameter) - .checkField("key", Protocol.CallCommandMessage::getKey) - .checkField("idempotency_key", Protocol.CallCommandMessage::getIdempotencyKey) - .checkField("headers", Protocol.CallCommandMessage::getHeadersList) - .checkField( - "invocation_id_notification_idx", - Protocol.CallCommandMessage::getInvocationIdNotificationIdx) - .checkField( - "result_completion_id", Protocol.CallCommandMessage::getResultCompletionId) - .verify(); - } - }; - CommandAccessor ONE_WAY_CALL = - new CommandAccessor<>() { - @Override - public String getName(Protocol.OneWayCallCommandMessage expected) { - return ""; - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.OneWayCallCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.OneWayCallCommandMessage.class, expected, actual) - .checkField("name", Protocol.OneWayCallCommandMessage::getName) - .checkField("service_name", Protocol.OneWayCallCommandMessage::getServiceName) - .checkField("handler_name", Protocol.OneWayCallCommandMessage::getHandlerName) - .checkField("parameter", Protocol.OneWayCallCommandMessage::getParameter) - .checkField("key", Protocol.OneWayCallCommandMessage::getKey) - .checkField("headers", Protocol.OneWayCallCommandMessage::getHeadersList) - .checkField("idempotency_key", Protocol.OneWayCallCommandMessage::getIdempotencyKey) - .checkField( - "invocation_id_notification_idx", - Protocol.OneWayCallCommandMessage::getInvocationIdNotificationIdx) - .verify(); - } - }; - - CommandAccessor COMPLETE_AWAKEABLE = - new CommandAccessor<>() { - @Override - public String getName(Protocol.CompleteAwakeableCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.CompleteAwakeableCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.CompleteAwakeableCommandMessage.class, expected, actual) - .checkField("name", Protocol.CompleteAwakeableCommandMessage::getName) - .checkField("awakeable_id", Protocol.CompleteAwakeableCommandMessage::getAwakeableId) - .checkField("result", Protocol.CompleteAwakeableCommandMessage::getResultCase) - .checkField("value", Protocol.CompleteAwakeableCommandMessage::getValue) - .checkField("failure", Protocol.CompleteAwakeableCommandMessage::getFailure) - .verify(); - } - }; - CommandAccessor RUN = - new CommandAccessor<>() { - @Override - public String getName(Protocol.RunCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.RunCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check(commandIndex, Protocol.RunCommandMessage.class, expected, actual) - .checkField("name", Protocol.RunCommandMessage::getName) - .verify(); - } - }; - - CommandAccessor GET_PROMISE = - new CommandAccessor<>() { - @Override - public String getName(Protocol.GetPromiseCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.GetPromiseCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.GetPromiseCommandMessage.class, expected, actual) - .checkField("name", Protocol.GetPromiseCommandMessage::getName) - .checkField("key", Protocol.GetPromiseCommandMessage::getKey) - .checkField( - "result_completion_id", Protocol.GetPromiseCommandMessage::getResultCompletionId) - .verify(); - } - }; - CommandAccessor PEEK_PROMISE = - new CommandAccessor<>() { - @Override - public String getName(Protocol.PeekPromiseCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.PeekPromiseCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.PeekPromiseCommandMessage.class, expected, actual) - .checkField("name", Protocol.PeekPromiseCommandMessage::getName) - .checkField("key", Protocol.PeekPromiseCommandMessage::getKey) - .checkField( - "result_completion_id", Protocol.PeekPromiseCommandMessage::getResultCompletionId) - .verify(); - } - }; - CommandAccessor COMPLETE_PROMISE = - new CommandAccessor<>() { - @Override - public String getName(Protocol.CompletePromiseCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.CompletePromiseCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.CompletePromiseCommandMessage.class, expected, actual) - .checkField("name", Protocol.CompletePromiseCommandMessage::getName) - .checkField("key", Protocol.CompletePromiseCommandMessage::getKey) - .checkField( - "result_completion_id", - Protocol.CompletePromiseCommandMessage::getResultCompletionId) - .checkField("completion", Protocol.CompletePromiseCommandMessage::getCompletionCase) - .checkField( - "completionValue", Protocol.CompletePromiseCommandMessage::getCompletionValue) - .checkField( - "completionFailure", Protocol.CompletePromiseCommandMessage::getCompletionFailure) - .verify(); - } - }; - - CommandAccessor SEND_SIGNAL = - new CommandAccessor<>() { - @Override - public String getName(Protocol.SendSignalCommandMessage expected) { - return expected.getEntryName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.SendSignalCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.SendSignalCommandMessage.class, expected, actual) - .checkField("entry_name", Protocol.SendSignalCommandMessage::getEntryName) - .checkField( - "target_invocation_id", Protocol.SendSignalCommandMessage::getTargetInvocationId) - .checkField("signal_id", Protocol.SendSignalCommandMessage::getSignalIdCase) - .checkField("result", Protocol.SendSignalCommandMessage::getResultCase) - .checkField("void", Protocol.SendSignalCommandMessage::getVoid) - .checkField("value", Protocol.SendSignalCommandMessage::getValue) - .checkField("failure", Protocol.SendSignalCommandMessage::getFailure) - .verify(); - } - }; - - CommandAccessor ATTACH_INVOCATION = - new CommandAccessor<>() { - @Override - public String getName(Protocol.AttachInvocationCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, Protocol.AttachInvocationCommandMessage expected, MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.AttachInvocationCommandMessage.class, expected, actual) - .checkField("name", Protocol.AttachInvocationCommandMessage::getName) - .checkField("invocation_id", Protocol.AttachInvocationCommandMessage::getInvocationId) - .checkField( - "result_completion_id", - Protocol.AttachInvocationCommandMessage::getResultCompletionId) - .verify(); - } - }; - - CommandAccessor GET_INVOCATION_OUTPUT = - new CommandAccessor<>() { - @Override - public String getName(Protocol.GetInvocationOutputCommandMessage expected) { - return expected.getName(); - } - - @Override - public void checkEntryHeader( - int commandIndex, - Protocol.GetInvocationOutputCommandMessage expected, - MessageLite actual) - throws ProtocolException { - EntryHeaderChecker.check( - commandIndex, Protocol.GetInvocationOutputCommandMessage.class, expected, actual) - .checkField("name", Protocol.GetInvocationOutputCommandMessage::getName) - .checkField( - "invocation_id", Protocol.GetInvocationOutputCommandMessage::getInvocationId) - .checkField( - "result_completion_id", - Protocol.GetInvocationOutputCommandMessage::getResultCompletionId) - .verify(); - } - }; -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandMetadata.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandMetadata.java deleted file mode 100644 index d9280fbb9..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandMetadata.java +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import org.jspecify.annotations.Nullable; - -/** Metadata about a command. */ -record CommandMetadata(int index, MessageType type, @Nullable String name) {} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandRelationship.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandRelationship.java deleted file mode 100644 index ea4e853a8..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandRelationship.java +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import org.jspecify.annotations.Nullable; - -/** Used in `hitError` to specify which command this error relates to. */ -sealed interface CommandRelationship { - /** The error is related to the last command. */ - record Last() implements CommandRelationship { - public static final Last INSTANCE = new Last(); - } - - /** The error is related to the next command of the specified type. */ - record Next(CommandType type, @Nullable String name) implements CommandRelationship {} - - /** The error is related to a specific command. */ - record Specific(int commandIndex, CommandType type, @Nullable String name) - implements CommandRelationship {} -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandType.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandType.java deleted file mode 100644 index c983bed60..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandType.java +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -/** Enum representing the type of command. This is used for error reporting. */ -enum CommandType { - INPUT, - OUTPUT, - GET_STATE, - GET_STATE_KEYS, - SET_STATE, - CLEAR_STATE, - CLEAR_ALL_STATE, - GET_PROMISE, - PEEK_PROMISE, - COMPLETE_PROMISE, - SLEEP, - CALL, - ONE_WAY_CALL, - SEND_SIGNAL, - RUN, - ATTACH_INVOCATION, - GET_INVOCATION_OUTPUT, - COMPLETE_AWAKEABLE, - CANCEL_INVOCATION; - - /** Convert a CommandType to a MessageType. */ - public MessageType toMessageType() { - return switch (this) { - case INPUT -> MessageType.InputCommandMessage; - case OUTPUT -> MessageType.OutputCommandMessage; - case GET_STATE -> MessageType.GetLazyStateCommandMessage; - case GET_STATE_KEYS -> MessageType.GetLazyStateKeysCommandMessage; - case SET_STATE -> MessageType.SetStateCommandMessage; - case CLEAR_STATE -> MessageType.ClearStateCommandMessage; - case CLEAR_ALL_STATE -> MessageType.ClearAllStateCommandMessage; - case GET_PROMISE -> MessageType.GetPromiseCommandMessage; - case PEEK_PROMISE -> MessageType.PeekPromiseCommandMessage; - case COMPLETE_PROMISE -> MessageType.CompletePromiseCommandMessage; - case SLEEP -> MessageType.SleepCommandMessage; - case CALL -> MessageType.CallCommandMessage; - case ONE_WAY_CALL -> MessageType.OneWayCallCommandMessage; - case SEND_SIGNAL, CANCEL_INVOCATION -> MessageType.SendSignalCommandMessage; - case RUN -> MessageType.RunCommandMessage; - case ATTACH_INVOCATION -> MessageType.AttachInvocationCommandMessage; - case GET_INVOCATION_OUTPUT -> MessageType.GetInvocationOutputCommandMessage; - case COMPLETE_AWAKEABLE -> MessageType.CompleteAwakeableCommandMessage; - }; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/EagerState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/EagerState.java deleted file mode 100644 index 419bad194..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/EagerState.java +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.ByteString; -import dev.restate.common.Slice; -import dev.restate.sdk.core.generated.protocol.Protocol; -import java.util.HashMap; -import java.util.Set; -import org.jspecify.annotations.Nullable; - -final class EagerState { - - private boolean isPartial; - private final HashMap map; - - EagerState(Protocol.StartMessage startMessage) { - this.isPartial = startMessage.getPartialState(); - this.map = new HashMap<>(startMessage.getStateMapCount()); - for (int i = 0; i < startMessage.getStateMapCount(); i++) { - Protocol.StartMessage.StateEntry entry = startMessage.getStateMap(i); - this.map.put( - entry.getKey(), - new NotificationValue.Success(Slice.wrap(entry.getValue().asReadOnlyByteBuffer()))); - } - } - - public @Nullable NotificationValue get(ByteString key) { - return this.map.getOrDefault(key, isComplete() ? NotificationValue.Empty.INSTANCE : null); - } - - public void set(ByteString key, Slice value) { - this.map.put(key, new NotificationValue.Success(value)); - } - - public void clear(ByteString key) { - this.map.put(key, NotificationValue.Empty.INSTANCE); - } - - public void clearAll() { - this.map.clear(); - this.isPartial = false; - } - - public boolean isComplete() { - return !isPartial; - } - - public @Nullable Set keys() { - if (isComplete()) { - return this.map.keySet(); - } - return null; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/EntryHeaderChecker.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/EntryHeaderChecker.java deleted file mode 100644 index 245fcd7df..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/EntryHeaderChecker.java +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import static dev.restate.sdk.core.ProtocolException.JOURNAL_MISMATCH_CODE; - -import com.google.protobuf.MessageLite; -import dev.restate.sdk.core.ProtocolException; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.function.Function; - -/** - * A builder-style utility for checking entry headers and generating detailed error messages. This - * class encapsulates the logic for validating message fields and throwing appropriate protocol - * exceptions when mismatches are found. - */ -final class EntryHeaderChecker { - private final int commandIndex; - private final E expected; - private final E actual; - private List mismatches; - - private EntryHeaderChecker(int commandIndex, E expected, E actual) { - this.commandIndex = commandIndex; - this.expected = expected; - this.actual = actual; - } - - /** - * Creates a new EntryHeaderChecker for the given expected and actual messages. - * - * @param The type of the expected message - * @param commandIndex The index of this command - * @param expected The expected message - * @param actual The actual message - * @return A new EntryHeaderChecker - */ - @SuppressWarnings("unchecked") - public static EntryHeaderChecker check( - int commandIndex, Class expectedClass, E expected, MessageLite actual) { - if (!expectedClass.isInstance(actual)) { - throw new ProtocolException( - "Found a mismatch between the code paths taken during the previous execution and the paths taken during this execution.\n" - + "This typically happens when some parts of the code are non-deterministic.\n" - + "- Expecting command '" - + Util.commandMessageToString(expected) - + "' (index " - + commandIndex - + ") but was '" - + Util.commandMessageToString(actual) - + "'", - JOURNAL_MISMATCH_CODE); - } - return new EntryHeaderChecker<>(commandIndex, expected, (E) actual); - } - - /** - * Checks that a field in the expected and actual messages match. - * - * @param fieldName The name of the field being checked - * @param getter Function to extract the field value from the message - * @param The type of the field - * @return This EntryHeaderChecker for method chaining - */ - public EntryHeaderChecker checkField(String fieldName, Function getter) { - T expectedValue = getter.apply(expected); - T actualValue = getter.apply(actual); - - if (!Objects.equals(expectedValue, actualValue)) { - if (mismatches == null) { - mismatches = new ArrayList<>(); - } - mismatches.add(new FieldMismatch(fieldName, expectedValue, actualValue)); - } - - return this; - } - - /** - * Verifies all checks and throws a ProtocolException if any mismatches were found. - * - * @throws ProtocolException if any mismatches were found - */ - public void verify() throws ProtocolException { - if (mismatches != null && !mismatches.isEmpty()) { - throw createMismatchException(); - } - } - - private ProtocolException createMismatchException() { - StringBuilder customMessage = - new StringBuilder( - "Found a mismatch between the code paths taken during the previous execution and the paths taken during this execution.\n" - + "This typically happens when some parts of the code are non-deterministic.\n" - + "- The mismatch happened while executing '" - + Util.commandMessageToString(expected) - + " (index " - + commandIndex - + ")'\n" - + "- Difference:"); - for (FieldMismatch mismatch : mismatches) { - customMessage - .append("\n ") - .append(mismatch.fieldName) - .append(": '") - .append(mismatch.expectedValue) - .append("' != '") - .append(mismatch.actualValue) - .append("'"); - } - - return new ProtocolException(customMessage.toString(), JOURNAL_MISMATCH_CODE); - } - - private record FieldMismatch(String fieldName, Object expectedValue, Object actualValue) {} -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationInput.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationInput.java deleted file mode 100644 index fafab731a..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationInput.java +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.MessageLite; - -public interface InvocationInput { - MessageHeader header(); - - MessageLite message(); - - static InvocationInput of(MessageHeader header, MessageLite message) { - return new InvocationInput() { - @Override - public MessageHeader header() { - return header; - } - - @Override - public MessageLite message() { - return message; - } - - @Override - public String toString() { - return header.toString() + " " + message.toString(); - } - }; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Journal.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Journal.java deleted file mode 100644 index a9b64a7d2..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Journal.java +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.MessageLite; - -class Journal { - private int commandIndex; - private int notificationIndex; - private int completionIndex; - private int signalIndex; - private MessageType currentEntryTy; - private String currentEntryName; - - Journal() { - this.commandIndex = -1; - this.notificationIndex = -1; - // Clever trick for protobuf here - this.completionIndex = 1; - // 1 to 16 are reserved! - this.signalIndex = 17; - this.currentEntryTy = MessageType.StartMessage; - this.currentEntryName = ""; - } - - public void commandTransition(String entryName, MessageLite expected) { - this.commandIndex++; - this.currentEntryName = entryName; - this.currentEntryTy = MessageType.fromMessage(expected); - } - - public void notificationTransition(MessageLite expected) { - this.notificationIndex++; - this.currentEntryName = ""; - this.currentEntryTy = null; - } - - public int getCommandIndex() { - return this.commandIndex; - } - - public MessageType getCurrentEntryTy() { - return currentEntryTy; - } - - public String getCurrentEntryName() { - return currentEntryName; - } - - public int getNotificationIndex() { - return this.notificationIndex; - } - - public int nextCompletionNotificationId() { - int next = this.completionIndex; - this.completionIndex++; - return next; - } - - public int nextSignalNotificationId() { - int next = this.signalIndex; - this.signalIndex++; - return next; - } - - /** Resolve a command relationship to a command metadata. */ - public CommandMetadata resolveRelatedCommand(CommandRelationship relationship) { - if (relationship instanceof CommandRelationship.Last) { - return lastCommandMetadata(); - } else if (relationship instanceof CommandRelationship.Next next) { - return new CommandMetadata(this.commandIndex + 1, next.type().toMessageType(), next.name()); - } else if (relationship instanceof CommandRelationship.Specific specific) { - return new CommandMetadata( - specific.commandIndex(), specific.type().toMessageType(), specific.name()); - } else { - throw new IllegalArgumentException("Unknown command relationship type: " + relationship); - } - } - - /** Get the metadata for the last command. */ - public CommandMetadata lastCommandMetadata() { - return new CommandMetadata( - this.commandIndex, - this.currentEntryTy, - this.currentEntryName.isEmpty() ? null : this.currentEntryName); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageDecoder.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageDecoder.java deleted file mode 100644 index abc518fcb..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageDecoder.java +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.ByteString; -import com.google.protobuf.InvalidProtocolBufferException; -import com.google.protobuf.UnsafeByteOperations; -import dev.restate.common.Slice; -import java.util.ArrayDeque; -import java.util.Queue; -import org.jspecify.annotations.Nullable; - -public final class MessageDecoder { - - private enum State { - WAITING_HEADER, - WAITING_PAYLOAD, - FAILED - } - - private final Queue parsedMessages; - private ByteString internalBuffer; - - private State state; - private MessageHeader lastParsedMessageHeader; - private RuntimeException lastParsingFailure; - - public MessageDecoder() { - this.parsedMessages = new ArrayDeque<>(); - this.internalBuffer = ByteString.EMPTY; - - this.state = State.WAITING_HEADER; - this.lastParsedMessageHeader = null; - this.lastParsingFailure = null; - } - - // -- Subscriber methods - - public void offer(Slice item) { - this.offer(UnsafeByteOperations.unsafeWrap(item.asReadOnlyByteBuffer())); - } - - public @Nullable InvocationInput next() { - if (this.state == State.FAILED) { - throw lastParsingFailure; - } - return this.parsedMessages.poll(); - } - - public boolean isNextAvailable() { - return !this.parsedMessages.isEmpty(); - } - - // -- Internal methods to handle decoding - - private void offer(ByteString buffer) { - if (this.state != State.FAILED) { - this.internalBuffer = this.internalBuffer.concat(buffer); - this.tryConsumeInternalBuffer(); - } - } - - private void tryConsumeInternalBuffer() { - while (this.state != State.FAILED && this.internalBuffer.size() >= wantBytes()) { - if (state == State.WAITING_HEADER) { - try { - this.lastParsedMessageHeader = MessageHeader.parse(readLongAtBeginning()); - this.state = State.WAITING_PAYLOAD; - this.sliceInternalBuffer(8); - } catch (RuntimeException e) { - this.lastParsingFailure = e; - this.state = State.FAILED; - } - } else { - try { - this.parsedMessages.offer( - InvocationInput.of( - this.lastParsedMessageHeader, - this.lastParsedMessageHeader - .getType() - .messageParser() - .parseFrom( - this.internalBuffer.substring( - 0, this.lastParsedMessageHeader.getLength())))); - this.state = State.WAITING_HEADER; - this.sliceInternalBuffer(this.lastParsedMessageHeader.getLength()); - } catch (InvalidProtocolBufferException e) { - this.lastParsingFailure = new RuntimeException("Cannot parse the protobuf message", e); - this.state = State.FAILED; - } catch (RuntimeException e) { - this.lastParsingFailure = e; - this.state = State.FAILED; - } - } - } - } - - private int wantBytes() { - if (state == State.WAITING_HEADER) { - return 8; - } else { - return lastParsedMessageHeader.getLength(); - } - } - - private void sliceInternalBuffer(int substring) { - if (this.internalBuffer.size() == substring) { - this.internalBuffer = ByteString.EMPTY; - } else { - this.internalBuffer = this.internalBuffer.substring(substring); - } - } - - private long readLongAtBeginning() { - return ((this.internalBuffer.byteAt(7) & 0xffL) - | ((this.internalBuffer.byteAt(6) & 0xffL) << 8) - | ((this.internalBuffer.byteAt(5) & 0xffL) << 16) - | ((this.internalBuffer.byteAt(4) & 0xffL) << 24) - | ((this.internalBuffer.byteAt(3) & 0xffL) << 32) - | ((this.internalBuffer.byteAt(2) & 0xffL) << 40) - | ((this.internalBuffer.byteAt(1) & 0xffL) << 48) - | ((this.internalBuffer.byteAt(0) & 0xffL) << 56)); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageEncoder.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageEncoder.java deleted file mode 100644 index f98b79db9..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageEncoder.java +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.MessageLite; -import dev.restate.common.Slice; -import java.nio.ByteBuffer; -import java.util.concurrent.Flow; - -final class MessageEncoder implements Flow.Subscriber { - - private final Flow.Subscriber inner; - - MessageEncoder(Flow.Subscriber inner) { - this.inner = inner; - } - - @Override - public void onSubscribe(Flow.Subscription subscription) { - inner.onSubscribe(subscription); - } - - @Override - public void onNext(MessageLite item) { - // We could pool those buffers somehow? - ByteBuffer buffer = ByteBuffer.allocate(MessageEncoder.encodeLength(item)); - MessageEncoder.encode(buffer, item); - inner.onNext(Slice.wrap(buffer)); - } - - @Override - public void onError(Throwable throwable) { - inner.onError(throwable); - } - - @Override - public void onComplete() { - inner.onComplete(); - } - - static int encodeLength(MessageLite msg) { - return 8 + msg.getSerializedSize(); - } - - static ByteBuffer encode(ByteBuffer buffer, MessageLite msg) { - MessageHeader header = MessageHeader.fromMessage(msg); - - buffer.putLong(header.encode()); - buffer.put(msg.toByteString().asReadOnlyByteBuffer()); - - buffer.flip(); - - return buffer; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageHeader.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageHeader.java deleted file mode 100644 index c5a9e6123..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageHeader.java +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.MessageLite; -import dev.restate.sdk.core.ProtocolException; - -public class MessageHeader { - - private final MessageType type; - private final int flags; - private final int length; - - public MessageHeader(MessageType type, int flags, int length) { - this.type = type; - this.flags = flags; - this.length = length; - } - - public MessageType getType() { - return type; - } - - public int getLength() { - return length; - } - - public long encode() { - long res = 0L; - res |= ((long) type.encode() << 48); - res |= ((long) flags << 32); - res |= length; - return res; - } - - public static MessageHeader parse(long encoded) throws ProtocolException { - var ty_code = (short) (encoded >> 48); - var flags = (short) (encoded >> 32); - var len = (int) encoded; - - return new MessageHeader(MessageType.decode(ty_code), flags, len); - } - - public static MessageHeader fromMessage(MessageLite msg) { - return new MessageHeader(MessageType.fromMessage(msg), 0, msg.getSerializedSize()); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageType.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageType.java deleted file mode 100644 index b4dd6e10f..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageType.java +++ /dev/null @@ -1,361 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.MessageLite; -import com.google.protobuf.Parser; -import dev.restate.sdk.core.ProtocolException; -import dev.restate.sdk.core.generated.protocol.Protocol; - -public enum MessageType { - StartMessage, - SuspensionMessage, - ErrorMessage, - EndMessage, - ProposeRunCompletionMessage, - - InputCommandMessage, - OutputCommandMessage, - GetLazyStateCommandMessage, - GetLazyStateCompletionNotificationMessage, - SetStateCommandMessage, - ClearStateCommandMessage, - ClearAllStateCommandMessage, - GetLazyStateKeysCommandMessage, - GetLazyStateKeysCompletionNotificationMessage, - GetEagerStateCommandMessage, - GetEagerStateKeysCommandMessage, - GetPromiseCommandMessage, - GetPromiseCompletionNotificationMessage, - PeekPromiseCommandMessage, - PeekPromiseCompletionNotificationMessage, - CompletePromiseCommandMessage, - CompletePromiseCompletionNotificationMessage, - SleepCommandMessage, - SleepCompletionNotificationMessage, - CallCommandMessage, - CallInvocationIdCompletionNotificationMessage, - CallCompletionNotificationMessage, - OneWayCallCommandMessage, - SendSignalCommandMessage, - RunCommandMessage, - RunCompletionNotificationMessage, - AttachInvocationCommandMessage, - AttachInvocationCompletionNotificationMessage, - GetInvocationOutputCommandMessage, - GetInvocationOutputCompletionNotificationMessage, - CompleteAwakeableCommandMessage, - SignalNotificationMessage; - - public static final short StartMessage_TYPE = (short) 0x0000; - public static final short SuspensionMessage_TYPE = (short) 0x0001; - public static final short ErrorMessage_TYPE = (short) 0x0002; - public static final short EndMessage_TYPE = (short) 0x0003; - public static final short ProposeRunCompletionMessage_TYPE = (short) 0x0005; - public static final short InputCommandMessage_TYPE = (short) 0x0400; - public static final short OutputCommandMessage_TYPE = (short) 0x0401; - public static final short GetLazyStateCommandMessage_TYPE = (short) 0x0402; - public static final short GetLazyStateCompletionNotificationMessage_TYPE = (short) 0x8002; - public static final short SetStateCommandMessage_TYPE = (short) 0x0403; - public static final short ClearStateCommandMessage_TYPE = (short) 0x0404; - public static final short ClearAllStateCommandMessage_TYPE = (short) 0x0405; - public static final short GetLazyStateKeysCommandMessage_TYPE = (short) 0x0406; - public static final short GetLazyStateKeysCompletionNotificationMessage_TYPE = (short) 0x8006; - public static final short GetEagerStateCommandMessage_TYPE = (short) 0x0407; - public static final short GetEagerStateKeysCommandMessage_TYPE = (short) 0x0408; - public static final short GetPromiseCommandMessage_TYPE = (short) 0x0409; - public static final short GetPromiseCompletionNotificationMessage_TYPE = (short) 0x8009; - public static final short PeekPromiseCommandMessage_TYPE = (short) 0x040A; - public static final short PeekPromiseCompletionNotificationMessage_TYPE = (short) 0x800A; - public static final short CompletePromiseCommandMessage_TYPE = (short) 0x040B; - public static final short CompletePromiseCompletionNotificationMessage_TYPE = (short) 0x800B; - public static final short SleepCommandMessage_TYPE = (short) 0x040C; - public static final short SleepCompletionNotificationMessage_TYPE = (short) 0x800C; - public static final short CallCommandMessage_TYPE = (short) 0x040D; - public static final short CallInvocationIdCompletionNotificationMessage_TYPE = (short) 0x800E; - public static final short CallCompletionNotificationMessage_TYPE = (short) 0x800D; - public static final short OneWayCallCommandMessage_TYPE = (short) 0x040E; - public static final short SendSignalCommandMessage_TYPE = (short) 0x0410; - public static final short RunCommandMessage_TYPE = (short) 0x0411; - public static final short RunCompletionNotificationMessage_TYPE = (short) 0x8011; - public static final short AttachInvocationCommandMessage_TYPE = (short) 0x0412; - public static final short AttachInvocationCompletionNotificationMessage_TYPE = (short) 0x8012; - public static final short GetInvocationOutputCommandMessage_TYPE = (short) 0x0413; - public static final short GetInvocationOutputCompletionNotificationMessage_TYPE = (short) 0x8013; - public static final short CompleteAwakeableCommandMessage_TYPE = (short) 0x0414; - public static final short SignalNotificationMessage_TYPE = (short) 0xFBFF; - - public Parser messageParser() { - return switch (this) { - case StartMessage -> Protocol.StartMessage.parser(); - case SuspensionMessage -> Protocol.SuspensionMessage.parser(); - case ErrorMessage -> Protocol.ErrorMessage.parser(); - case EndMessage -> Protocol.EndMessage.parser(); - case ProposeRunCompletionMessage -> Protocol.ProposeRunCompletionMessage.parser(); - case InputCommandMessage -> Protocol.InputCommandMessage.parser(); - case OutputCommandMessage -> Protocol.OutputCommandMessage.parser(); - case GetLazyStateCommandMessage -> Protocol.GetLazyStateCommandMessage.parser(); - case SetStateCommandMessage -> Protocol.SetStateCommandMessage.parser(); - case ClearStateCommandMessage -> Protocol.ClearStateCommandMessage.parser(); - case ClearAllStateCommandMessage -> Protocol.ClearAllStateCommandMessage.parser(); - case GetLazyStateKeysCommandMessage -> Protocol.GetLazyStateKeysCommandMessage.parser(); - case GetEagerStateCommandMessage -> Protocol.GetEagerStateCommandMessage.parser(); - case GetEagerStateKeysCommandMessage -> Protocol.GetEagerStateKeysCommandMessage.parser(); - case GetPromiseCommandMessage -> Protocol.GetPromiseCommandMessage.parser(); - case PeekPromiseCommandMessage -> Protocol.PeekPromiseCommandMessage.parser(); - case CompletePromiseCommandMessage -> Protocol.CompletePromiseCommandMessage.parser(); - case SleepCommandMessage -> Protocol.SleepCommandMessage.parser(); - case CallCommandMessage -> Protocol.CallCommandMessage.parser(); - case OneWayCallCommandMessage -> Protocol.OneWayCallCommandMessage.parser(); - case SendSignalCommandMessage -> Protocol.SendSignalCommandMessage.parser(); - case RunCommandMessage -> Protocol.RunCommandMessage.parser(); - case AttachInvocationCommandMessage -> Protocol.AttachInvocationCommandMessage.parser(); - case GetInvocationOutputCommandMessage -> Protocol.GetInvocationOutputCommandMessage.parser(); - case CompleteAwakeableCommandMessage -> Protocol.CompleteAwakeableCommandMessage.parser(); - case GetLazyStateCompletionNotificationMessage, - SignalNotificationMessage, - GetLazyStateKeysCompletionNotificationMessage, - GetPromiseCompletionNotificationMessage, - PeekPromiseCompletionNotificationMessage, - CompletePromiseCompletionNotificationMessage, - SleepCompletionNotificationMessage, - CallInvocationIdCompletionNotificationMessage, - CallCompletionNotificationMessage, - RunCompletionNotificationMessage, - AttachInvocationCompletionNotificationMessage, - GetInvocationOutputCompletionNotificationMessage -> - Protocol.NotificationTemplate.parser(); - }; - } - - public short encode() { - return switch (this) { - case StartMessage -> StartMessage_TYPE; - case SuspensionMessage -> SuspensionMessage_TYPE; - case ErrorMessage -> ErrorMessage_TYPE; - case EndMessage -> EndMessage_TYPE; - case ProposeRunCompletionMessage -> ProposeRunCompletionMessage_TYPE; - case InputCommandMessage -> InputCommandMessage_TYPE; - case OutputCommandMessage -> OutputCommandMessage_TYPE; - case GetLazyStateCommandMessage -> GetLazyStateCommandMessage_TYPE; - case GetLazyStateCompletionNotificationMessage -> - GetLazyStateCompletionNotificationMessage_TYPE; - case SetStateCommandMessage -> SetStateCommandMessage_TYPE; - case ClearStateCommandMessage -> ClearStateCommandMessage_TYPE; - case ClearAllStateCommandMessage -> ClearAllStateCommandMessage_TYPE; - case GetLazyStateKeysCommandMessage -> GetLazyStateKeysCommandMessage_TYPE; - case GetLazyStateKeysCompletionNotificationMessage -> - GetLazyStateKeysCompletionNotificationMessage_TYPE; - case GetEagerStateCommandMessage -> GetEagerStateCommandMessage_TYPE; - case GetEagerStateKeysCommandMessage -> GetEagerStateKeysCommandMessage_TYPE; - case GetPromiseCommandMessage -> GetPromiseCommandMessage_TYPE; - case GetPromiseCompletionNotificationMessage -> GetPromiseCompletionNotificationMessage_TYPE; - case PeekPromiseCommandMessage -> PeekPromiseCommandMessage_TYPE; - case PeekPromiseCompletionNotificationMessage -> - PeekPromiseCompletionNotificationMessage_TYPE; - case CompletePromiseCommandMessage -> CompletePromiseCommandMessage_TYPE; - case CompletePromiseCompletionNotificationMessage -> - CompletePromiseCompletionNotificationMessage_TYPE; - case SleepCommandMessage -> SleepCommandMessage_TYPE; - case SleepCompletionNotificationMessage -> SleepCompletionNotificationMessage_TYPE; - case CallCommandMessage -> CallCommandMessage_TYPE; - case CallInvocationIdCompletionNotificationMessage -> - CallInvocationIdCompletionNotificationMessage_TYPE; - case CallCompletionNotificationMessage -> CallCompletionNotificationMessage_TYPE; - case OneWayCallCommandMessage -> OneWayCallCommandMessage_TYPE; - case SendSignalCommandMessage -> SendSignalCommandMessage_TYPE; - case RunCommandMessage -> RunCommandMessage_TYPE; - case RunCompletionNotificationMessage -> RunCompletionNotificationMessage_TYPE; - case AttachInvocationCommandMessage -> AttachInvocationCommandMessage_TYPE; - case AttachInvocationCompletionNotificationMessage -> - AttachInvocationCompletionNotificationMessage_TYPE; - case GetInvocationOutputCommandMessage -> GetInvocationOutputCommandMessage_TYPE; - case GetInvocationOutputCompletionNotificationMessage -> - GetInvocationOutputCompletionNotificationMessage_TYPE; - case CompleteAwakeableCommandMessage -> CompleteAwakeableCommandMessage_TYPE; - case SignalNotificationMessage -> SignalNotificationMessage_TYPE; - }; - } - - public boolean isCommand() { - return switch (this) { - case InputCommandMessage, - GetLazyStateCommandMessage, - OutputCommandMessage, - SetStateCommandMessage, - ClearStateCommandMessage, - ClearAllStateCommandMessage, - GetLazyStateKeysCommandMessage, - GetEagerStateCommandMessage, - GetEagerStateKeysCommandMessage, - GetPromiseCommandMessage, - PeekPromiseCommandMessage, - CompletePromiseCommandMessage, - SleepCommandMessage, - CallCommandMessage, - OneWayCallCommandMessage, - SendSignalCommandMessage, - RunCommandMessage, - AttachInvocationCommandMessage, - GetInvocationOutputCommandMessage, - CompleteAwakeableCommandMessage -> - true; - default -> false; - }; - } - - public boolean isNotification() { - return switch (this) { - case GetLazyStateCompletionNotificationMessage, - SignalNotificationMessage, - GetLazyStateKeysCompletionNotificationMessage, - GetPromiseCompletionNotificationMessage, - PeekPromiseCompletionNotificationMessage, - CompletePromiseCompletionNotificationMessage, - SleepCompletionNotificationMessage, - CallInvocationIdCompletionNotificationMessage, - CallCompletionNotificationMessage, - RunCompletionNotificationMessage, - AttachInvocationCompletionNotificationMessage, - GetInvocationOutputCompletionNotificationMessage -> - true; - default -> false; - }; - } - - public static MessageType decode(short value) throws ProtocolException { - return switch (value) { - case StartMessage_TYPE -> StartMessage; - case SuspensionMessage_TYPE -> SuspensionMessage; - case ErrorMessage_TYPE -> ErrorMessage; - case EndMessage_TYPE -> EndMessage; - case ProposeRunCompletionMessage_TYPE -> ProposeRunCompletionMessage; - case InputCommandMessage_TYPE -> InputCommandMessage; - case OutputCommandMessage_TYPE -> OutputCommandMessage; - case GetLazyStateCommandMessage_TYPE -> GetLazyStateCommandMessage; - case GetLazyStateCompletionNotificationMessage_TYPE -> - GetLazyStateCompletionNotificationMessage; - case SetStateCommandMessage_TYPE -> SetStateCommandMessage; - case ClearStateCommandMessage_TYPE -> ClearStateCommandMessage; - case ClearAllStateCommandMessage_TYPE -> ClearAllStateCommandMessage; - case GetLazyStateKeysCommandMessage_TYPE -> GetLazyStateKeysCommandMessage; - case GetLazyStateKeysCompletionNotificationMessage_TYPE -> - GetLazyStateKeysCompletionNotificationMessage; - case GetEagerStateCommandMessage_TYPE -> GetEagerStateCommandMessage; - case GetEagerStateKeysCommandMessage_TYPE -> GetEagerStateKeysCommandMessage; - case GetPromiseCommandMessage_TYPE -> GetPromiseCommandMessage; - case GetPromiseCompletionNotificationMessage_TYPE -> GetPromiseCompletionNotificationMessage; - case PeekPromiseCommandMessage_TYPE -> PeekPromiseCommandMessage; - case PeekPromiseCompletionNotificationMessage_TYPE -> - PeekPromiseCompletionNotificationMessage; - case CompletePromiseCommandMessage_TYPE -> CompletePromiseCommandMessage; - case CompletePromiseCompletionNotificationMessage_TYPE -> - CompletePromiseCompletionNotificationMessage; - case SleepCommandMessage_TYPE -> SleepCommandMessage; - case SleepCompletionNotificationMessage_TYPE -> SleepCompletionNotificationMessage; - case CallCommandMessage_TYPE -> CallCommandMessage; - case CallInvocationIdCompletionNotificationMessage_TYPE -> - CallInvocationIdCompletionNotificationMessage; - case CallCompletionNotificationMessage_TYPE -> CallCompletionNotificationMessage; - case OneWayCallCommandMessage_TYPE -> OneWayCallCommandMessage; - case SendSignalCommandMessage_TYPE -> SendSignalCommandMessage; - case RunCommandMessage_TYPE -> RunCommandMessage; - case RunCompletionNotificationMessage_TYPE -> RunCompletionNotificationMessage; - case AttachInvocationCommandMessage_TYPE -> AttachInvocationCommandMessage; - case AttachInvocationCompletionNotificationMessage_TYPE -> - AttachInvocationCompletionNotificationMessage; - case GetInvocationOutputCommandMessage_TYPE -> GetInvocationOutputCommandMessage; - case GetInvocationOutputCompletionNotificationMessage_TYPE -> - GetInvocationOutputCompletionNotificationMessage; - case CompleteAwakeableCommandMessage_TYPE -> CompleteAwakeableCommandMessage; - case SignalNotificationMessage_TYPE -> SignalNotificationMessage; - default -> throw ProtocolException.unknownMessageType(value); - }; - } - - public static MessageType fromMessage(MessageLite msg) { - if (msg instanceof Protocol.StartMessage) { - return MessageType.StartMessage; - } else if (msg instanceof Protocol.SuspensionMessage) { - return MessageType.SuspensionMessage; - } else if (msg instanceof Protocol.ErrorMessage) { - return MessageType.ErrorMessage; - } else if (msg instanceof Protocol.EndMessage) { - return MessageType.EndMessage; - } else if (msg instanceof Protocol.ProposeRunCompletionMessage) { - return MessageType.ProposeRunCompletionMessage; - } else if (msg instanceof Protocol.InputCommandMessage) { - return MessageType.InputCommandMessage; - } else if (msg instanceof Protocol.OutputCommandMessage) { - return MessageType.OutputCommandMessage; - } else if (msg instanceof Protocol.GetLazyStateCommandMessage) { - return MessageType.GetLazyStateCommandMessage; - } else if (msg instanceof Protocol.GetLazyStateCompletionNotificationMessage) { - return MessageType.GetLazyStateCompletionNotificationMessage; - } else if (msg instanceof Protocol.SetStateCommandMessage) { - return MessageType.SetStateCommandMessage; - } else if (msg instanceof Protocol.ClearStateCommandMessage) { - return MessageType.ClearStateCommandMessage; - } else if (msg instanceof Protocol.ClearAllStateCommandMessage) { - return MessageType.ClearAllStateCommandMessage; - } else if (msg instanceof Protocol.GetLazyStateKeysCommandMessage) { - return MessageType.GetLazyStateKeysCommandMessage; - } else if (msg instanceof Protocol.GetLazyStateKeysCompletionNotificationMessage) { - return MessageType.GetLazyStateKeysCompletionNotificationMessage; - } else if (msg instanceof Protocol.GetEagerStateCommandMessage) { - return MessageType.GetEagerStateCommandMessage; - } else if (msg instanceof Protocol.GetEagerStateKeysCommandMessage) { - return MessageType.GetEagerStateKeysCommandMessage; - } else if (msg instanceof Protocol.GetPromiseCommandMessage) { - return MessageType.GetPromiseCommandMessage; - } else if (msg instanceof Protocol.GetPromiseCompletionNotificationMessage) { - return MessageType.GetPromiseCompletionNotificationMessage; - } else if (msg instanceof Protocol.PeekPromiseCommandMessage) { - return MessageType.PeekPromiseCommandMessage; - } else if (msg instanceof Protocol.PeekPromiseCompletionNotificationMessage) { - return MessageType.PeekPromiseCompletionNotificationMessage; - } else if (msg instanceof Protocol.CompletePromiseCommandMessage) { - return MessageType.CompletePromiseCommandMessage; - } else if (msg instanceof Protocol.CompletePromiseCompletionNotificationMessage) { - return MessageType.CompletePromiseCompletionNotificationMessage; - } else if (msg instanceof Protocol.SleepCommandMessage) { - return MessageType.SleepCommandMessage; - } else if (msg instanceof Protocol.SleepCompletionNotificationMessage) { - return MessageType.SleepCompletionNotificationMessage; - } else if (msg instanceof Protocol.CallCommandMessage) { - return MessageType.CallCommandMessage; - } else if (msg instanceof Protocol.CallInvocationIdCompletionNotificationMessage) { - return MessageType.CallInvocationIdCompletionNotificationMessage; - } else if (msg instanceof Protocol.CallCompletionNotificationMessage) { - return MessageType.CallCompletionNotificationMessage; - } else if (msg instanceof Protocol.OneWayCallCommandMessage) { - return MessageType.OneWayCallCommandMessage; - } else if (msg instanceof Protocol.SendSignalCommandMessage) { - return MessageType.SendSignalCommandMessage; - } else if (msg instanceof Protocol.RunCommandMessage) { - return MessageType.RunCommandMessage; - } else if (msg instanceof Protocol.RunCompletionNotificationMessage) { - return MessageType.RunCompletionNotificationMessage; - } else if (msg instanceof Protocol.AttachInvocationCommandMessage) { - return MessageType.AttachInvocationCommandMessage; - } else if (msg instanceof Protocol.AttachInvocationCompletionNotificationMessage) { - return MessageType.AttachInvocationCompletionNotificationMessage; - } else if (msg instanceof Protocol.GetInvocationOutputCommandMessage) { - return MessageType.GetInvocationOutputCommandMessage; - } else if (msg instanceof Protocol.GetInvocationOutputCompletionNotificationMessage) { - return MessageType.GetInvocationOutputCompletionNotificationMessage; - } else if (msg instanceof Protocol.CompleteAwakeableCommandMessage) { - return MessageType.CompleteAwakeableCommandMessage; - } else if (msg instanceof Protocol.SignalNotificationMessage) { - return MessageType.SignalNotificationMessage; - } - - throw new IllegalStateException("Unexpected protobuf message"); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationId.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationId.java deleted file mode 100644 index 5b3a0a1d4..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationId.java +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -public sealed interface NotificationId { - - record CompletionId(int id) implements NotificationId {} - - record SignalId(int id) implements NotificationId {} - - record SignalName(String name) implements NotificationId {} -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ProcessingState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ProcessingState.java deleted file mode 100644 index 5300444d0..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ProcessingState.java +++ /dev/null @@ -1,378 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import static dev.restate.sdk.core.statemachine.Util.durationMin; -import static dev.restate.sdk.core.statemachine.Util.sliceToByteString; - -import com.google.protobuf.ByteString; -import com.google.protobuf.MessageLite; -import dev.restate.common.Slice; -import dev.restate.sdk.common.AbortedExecutionException; -import dev.restate.sdk.common.RetryPolicy; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.ExceptionUtils; -import dev.restate.sdk.core.ProtocolException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.core.statemachine.StateMachine.DoProgressResponse; -import java.time.Duration; -import java.util.List; -import java.util.Optional; -import java.util.concurrent.CompletableFuture; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.jspecify.annotations.Nullable; - -final class ProcessingState implements State { - - private static final Logger LOG = LogManager.getLogger(ProcessingState.class); - - private final AsyncResultsState asyncResultsState; - private final RunState runState; - private boolean processingFirstEntry; - - ProcessingState(AsyncResultsState asyncResultsState, RunState runState) { - this.asyncResultsState = asyncResultsState; - this.runState = runState; - this.processingFirstEntry = true; - } - - @Override - public void onNewMessage( - InvocationInput invocationInput, - StateContext stateContext, - CompletableFuture waitForReadyFuture) { - if (invocationInput.header().getType().isNotification()) { - if (!(invocationInput.message() - instanceof Protocol.NotificationTemplate notificationTemplate)) { - throw ProtocolException.unexpectedMessage( - Protocol.NotificationTemplate.class, invocationInput.message()); - } - this.asyncResultsState.enqueue(notificationTemplate); - } else { - throw ProtocolException.unexpectedMessage("notification", invocationInput.message()); - } - } - - @Override - public DoProgressResponse doProgress(List awaitingOn, StateContext stateContext) { - if (awaitingOn.stream().anyMatch(this.asyncResultsState::isHandleCompleted)) { - return DoProgressResponse.AnyCompleted.INSTANCE; - } - - var notificationIds = asyncResultsState.resolveNotificationHandles(awaitingOn); - if (notificationIds.isEmpty()) { - return DoProgressResponse.AnyCompleted.INSTANCE; - } - - if (asyncResultsState.processNextUntilAnyFound(notificationIds)) { - return DoProgressResponse.AnyCompleted.INSTANCE; - } - - Integer maybeRunHandle = runState.tryExecuteRun(awaitingOn); - if (maybeRunHandle != null) { - return new DoProgressResponse.ExecuteRun(maybeRunHandle); - } - - if (stateContext.isInputClosed()) { - if (runState.anyExecuting(awaitingOn)) { - return DoProgressResponse.WaitingPendingRun.INSTANCE; - } - - this.hitSuspended(notificationIds, stateContext); - ExceptionUtils.sneakyThrow(AbortedExecutionException.INSTANCE); - } - - return DoProgressResponse.ReadFromInput.INSTANCE; - } - - @Override - public boolean isCompleted(int handle) { - return this.asyncResultsState.isHandleCompleted(handle); - } - - @Override - public Optional takeNotification(int handle, StateContext stateContext) { - return this.asyncResultsState.takeHandle(handle); - } - - @Override - public int processRunCommand(String name, StateContext stateContext) { - var completionId = stateContext.getJournal().nextCompletionNotificationId(); - var notificationId = new NotificationId.CompletionId(completionId); - - var runCmdBuilder = Protocol.RunCommandMessage.newBuilder().setResultCompletionId(completionId); - if (name != null) { - runCmdBuilder.setName(name); - } - - var runCmd = runCmdBuilder.build(); - var notificationHandle = - this.processCompletableCommand( - runCmd, CommandAccessor.RUN, new int[] {completionId}, stateContext)[0]; - - LOG.trace("Enqueued run notification for {} with id {}.", notificationHandle, notificationId); - runState.insertRunToExecute( - notificationHandle, - stateContext.getJournal().lastCommandMetadata().index(), - name != null ? name : ""); - - return notificationHandle; - } - - @Override - public int processStateGetCommand(String key, StateContext stateContext) { - this.flipFirstProcessingEntry(); - var completionId = stateContext.getJournal().nextCompletionNotificationId(); - var handle = - asyncResultsState.createHandleMapping(new NotificationId.CompletionId(completionId)); - - ByteString keyBytes = ByteString.copyFromUtf8(key); - var eagerStateQuery = stateContext.getEagerState().get(keyBytes); - if (eagerStateQuery == null) { - // Lazy state case - var commandMessage = - Protocol.GetLazyStateCommandMessage.newBuilder() - .setKey(keyBytes) - .setResultCompletionId(completionId) - .build(); - stateContext - .getJournal() - .commandTransition( - CommandAccessor.GET_LAZY_STATE.getName(commandMessage), commandMessage); - stateContext.writeMessageOut(commandMessage); - - return handle; - } - - // Eager state case - var commandMessageBuilder = Protocol.GetEagerStateCommandMessage.newBuilder().setKey(keyBytes); - if (eagerStateQuery instanceof NotificationValue.Success) { - commandMessageBuilder.setValue( - Protocol.Value.newBuilder() - .setContent(sliceToByteString(((NotificationValue.Success) eagerStateQuery).slice())) - .build()); - } else { - commandMessageBuilder.setVoid(Protocol.Void.getDefaultInstance()); - } - var commandMessage = commandMessageBuilder.build(); - stateContext - .getJournal() - .commandTransition(CommandAccessor.GET_EAGER_STATE.getName(commandMessage), commandMessage); - - asyncResultsState.insertReady(new NotificationId.CompletionId(completionId), eagerStateQuery); - stateContext.writeMessageOut(commandMessage); - - return handle; - } - - @Override - public int processStateGetKeysCommand(StateContext stateContext) { - this.flipFirstProcessingEntry(); - var completionId = stateContext.getJournal().nextCompletionNotificationId(); - var handle = - asyncResultsState.createHandleMapping(new NotificationId.CompletionId(completionId)); - - var eagerStateQuery = stateContext.getEagerState().keys(); - if (eagerStateQuery == null) { - // Lazy state case - var commandMessage = - Protocol.GetLazyStateKeysCommandMessage.newBuilder() - .setResultCompletionId(completionId) - .build(); - stateContext - .getJournal() - .commandTransition( - CommandAccessor.GET_LAZY_STATE_KEYS.getName(commandMessage), commandMessage); - stateContext.writeMessageOut(commandMessage); - - return handle; - } - - // Eager state case - var commandMessage = - Protocol.GetEagerStateKeysCommandMessage.newBuilder() - .setValue(Protocol.StateKeys.newBuilder().addAllKeys(eagerStateQuery).build()) - .build(); - stateContext - .getJournal() - .commandTransition( - CommandAccessor.GET_EAGER_STATE_KEYS.getName(commandMessage), commandMessage); - - asyncResultsState.insertReady( - new NotificationId.CompletionId(completionId), - new NotificationValue.StateKeys( - eagerStateQuery.stream().map(ByteString::toStringUtf8).toList())); - stateContext.writeMessageOut(commandMessage); - - return handle; - } - - @Override - public void processNonCompletableCommand( - E commandMessage, CommandAccessor commandAccessor, StateContext stateContext) { - stateContext - .getJournal() - .commandTransition(commandAccessor.getName(commandMessage), commandMessage); - this.flipFirstProcessingEntry(); - - stateContext.writeMessageOut(commandMessage); - } - - @Override - public int[] processCompletableCommand( - E commandMessage, - CommandAccessor commandAccessor, - int[] completionIds, - StateContext stateContext) { - stateContext - .getJournal() - .commandTransition(commandAccessor.getName(commandMessage), commandMessage); - this.flipFirstProcessingEntry(); - - stateContext.writeMessageOut(commandMessage); - - int[] handles = new int[completionIds.length]; - for (int i = 0; i < handles.length; i++) { - handles[i] = - asyncResultsState.createHandleMapping(new NotificationId.CompletionId(completionIds[i])); - } - - return handles; - } - - @Override - public int createSignalHandle(NotificationId notificationId, StateContext stateContext) { - return asyncResultsState.createHandleMapping(notificationId); - } - - @Override - public void proposeRunCompletion(int handle, Slice value, StateContext stateContext) { - var notificationId = asyncResultsState.mustResolveNotificationHandle(handle); - if (!(notificationId instanceof NotificationId.CompletionId)) { - throw ProtocolException.badRunNotificationId(notificationId); - } - - runState.notifyExecutionCompleted(handle); - - proposeRunCompletion( - handle, - Protocol.ProposeRunCompletionMessage.newBuilder() - .setResultCompletionId(((NotificationId.CompletionId) notificationId).id()) - .setValue(sliceToByteString(value)), - stateContext); - } - - @Override - public void proposeRunCompletion( - int handle, - Throwable runException, - Duration attemptDuration, - @Nullable RetryPolicy retryPolicy, - StateContext stateContext) { - var notificationId = asyncResultsState.mustResolveNotificationHandle(handle); - if (!(notificationId instanceof NotificationId.CompletionId)) { - throw ProtocolException.badRunNotificationId(notificationId); - } - - RunState.CommandInfo commandInfo = runState.notifyExecutionCompleted(handle); - - Duration retryLoopDuration = - this.getDurationSinceLastStoredEntry(stateContext).plus(attemptDuration); - int retryCount = this.getRetryCountSinceLastStoredEntry(stateContext) + 1; - - TerminalException terminalExceptionToWrite = null; - if (runException instanceof TerminalException) { - LOG.trace("The run completed with a terminal exception"); - terminalExceptionToWrite = (TerminalException) runException; - } else if (retryPolicy != null - && ((retryPolicy.getMaxAttempts() != null && retryPolicy.getMaxAttempts() <= retryCount) - || (retryPolicy.getMaxDuration() != null - && retryPolicy.getMaxDuration().compareTo(retryLoopDuration) <= 0))) { - LOG.trace("The run completed with a retryable exception and attempts were exhausted"); - // We need to convert it to TerminalException - terminalExceptionToWrite = new TerminalException(runException.toString()); - } else { - // In the other cases, it's a retryable error! - } - - if (terminalExceptionToWrite != null) { - // Terminal exception case - this.proposeRunCompletion( - handle, - Protocol.ProposeRunCompletionMessage.newBuilder() - .setResultCompletionId(((NotificationId.CompletionId) notificationId).id()) - .setFailure(Util.toProtocolFailure(terminalExceptionToWrite)), - stateContext); - } else { - // Compute retry delay - Duration nextRetryDelay = null; - if (retryPolicy != null) { - Duration nextComputedDelay = - retryPolicy - .getInitialDelay() - .multipliedBy((long) Math.pow(retryPolicy.getExponentiationFactor(), retryCount)); - nextRetryDelay = - retryPolicy.getMaxDelay() != null - ? durationMin(retryPolicy.getMaxDelay(), nextComputedDelay) - : nextComputedDelay; - } - - this.hitError( - runException, - new CommandRelationship.Specific( - commandInfo.commandIndex(), CommandType.RUN, commandInfo.commandName()), - nextRetryDelay, - stateContext); - } - } - - private void proposeRunCompletion( - int handle, - Protocol.ProposeRunCompletionMessage.Builder messageBuilder, - StateContext stateContext) { - if (!stateContext.maybeWriteMessageOut(messageBuilder.build())) { - LOG.warn( - "Cannot write proposed completion for run with handle {} because the output stream was already closed.", - handle); - } - } - - private Duration getDurationSinceLastStoredEntry(StateContext stateContext) { - // We need to check if this is the first entry we try to commit after replay, and only in this - // case we need to return the info we got from the start message - // - // Moreover, when the retry count is == 0, the durationSinceLastStoredEntry might not be zero. - // In fact, in that case the duration is the interval between the previously stored entry and - // the time to start/resume the invocation. - // For the sake of entry retries though, we're not interested in that time elapsed, so we 0 it - // here for simplicity of the downstream consumer (the retry policy). - return this.processingFirstEntry - && stateContext.getStartInfo().retryCountSinceLastStoredEntry() > 0 - ? stateContext.getStartInfo().durationSinceLastStoredEntry() - : Duration.ZERO; - } - - private int getRetryCountSinceLastStoredEntry(StateContext stateContext) { - // We need to check if this is the first entry we try to commit after replay, and only in this - // case we need to return the info we got from the start message - return this.processingFirstEntry - ? stateContext.getStartInfo().retryCountSinceLastStoredEntry() - : 0; - } - - private void flipFirstProcessingEntry() { - this.processingFirstEntry = false; - } - - @Override - public InvocationState getInvocationState() { - return InvocationState.PROCESSING; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java deleted file mode 100644 index c88974797..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java +++ /dev/null @@ -1,390 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import static dev.restate.sdk.core.statemachine.StateMachineImpl.CANCEL_SIGNAL_ID; -import static dev.restate.sdk.core.statemachine.Util.byteStringToSlice; - -import com.google.protobuf.ByteString; -import com.google.protobuf.MessageLite; -import dev.restate.sdk.common.AbortedExecutionException; -import dev.restate.sdk.core.ExceptionUtils; -import dev.restate.sdk.core.ProtocolException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.core.statemachine.StateMachine.DoProgressResponse; -import java.util.*; -import java.util.concurrent.CompletableFuture; -import java.util.stream.Collectors; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -final class ReplayingState implements State { - - private static final Logger LOG = LogManager.getLogger(ReplayingState.class); - - /** - * Comparator for notification IDs in error messages. Orders: completions first (by id), then - * named signals (by name), then signal IDs (by id, with cancel signal last). - */ - private static final Comparator NOTIFICATION_ID_COMPARATOR_FOR_JOURNAL_MISMATCH = - Comparator.comparingInt( - id -> { - if (id instanceof NotificationId.CompletionId) return 0; - if (id instanceof NotificationId.SignalName) return 1; - return 2; - }) - .thenComparing( - (a, b) -> { - if (a instanceof NotificationId.CompletionId ac - && b instanceof NotificationId.CompletionId bc) { - return Integer.compare(ac.id(), bc.id()); - } - if (a instanceof NotificationId.SignalName an - && b instanceof NotificationId.SignalName bn) { - return an.name().compareTo(bn.name()); - } - if (a instanceof NotificationId.SignalId as_ - && b instanceof NotificationId.SignalId bs) { - boolean aIsCancel = as_.id() == CANCEL_SIGNAL_ID; - boolean bIsCancel = bs.id() == CANCEL_SIGNAL_ID; - if (aIsCancel != bIsCancel) return aIsCancel ? 1 : -1; - return Integer.compare(as_.id(), bs.id()); - } - return 0; - }); - - private final Deque commandsToProcess; - private final AsyncResultsState asyncResultsState; - private final RunState runState; - - ReplayingState(Deque commandsToProcess, AsyncResultsState asyncResultsState) { - this.commandsToProcess = commandsToProcess; - this.asyncResultsState = asyncResultsState; - this.runState = new RunState(); - } - - @Override - public void onNewMessage( - InvocationInput invocationInput, - StateContext stateContext, - CompletableFuture waitForReadyFuture) { - if (invocationInput.header().getType().isNotification()) { - if (!(invocationInput.message() - instanceof Protocol.NotificationTemplate notificationTemplate)) { - throw ProtocolException.unexpectedMessage( - Protocol.NotificationTemplate.class, invocationInput.message()); - } - this.asyncResultsState.enqueue(notificationTemplate); - } else { - throw ProtocolException.unexpectedMessage("notification", invocationInput.message()); - } - } - - @Override - public DoProgressResponse doProgress(List awaitingOn, StateContext stateContext) { - if (awaitingOn.stream().anyMatch(this.asyncResultsState::isHandleCompleted)) { - return DoProgressResponse.AnyCompleted.INSTANCE; - } - - var notificationIds = asyncResultsState.resolveNotificationHandles(awaitingOn); - if (notificationIds.isEmpty()) { - return DoProgressResponse.AnyCompleted.INSTANCE; - } - - if (asyncResultsState.processNextUntilAnyFound(notificationIds)) { - return DoProgressResponse.AnyCompleted.INSTANCE; - } - - // This assertion proves the user mutated the code, adding an await point. - // - // During replay, we transition to processing AFTER replaying all COMMANDS. - // If we reach this point, none of the previous checks succeeded, meaning we don't have - // enough notifications to complete this await point. But if this await cannot be completed - // during replay, then no progress should have been made afterward, meaning there should be - // no more commands to replay. However, we ARE still replaying, which means there ARE commands - // to replay after this await point. - // - // This contradiction proves the code was mutated: an await must have been added after - // the journal was originally created. - - // Prepare error metadata to make it easier to debug - Map knownNotificationMetadata = new HashMap<>(); - CommandRelationship relatedCommand = null; - - // Collect run info - for (int handle : awaitingOn) { - RunState.Run runInfo = runState.getRunInfo(handle); - if (runInfo != null) { - var notifId = asyncResultsState.mustResolveNotificationHandle(handle); - knownNotificationMetadata.put( - notifId, - MessageType.RunCommandMessage.name() - + " '" - + runInfo.commandName() - + "' (command index " - + runInfo.commandIndex() - + ")"); - relatedCommand = - new CommandRelationship.Specific( - runInfo.commandIndex(), CommandType.RUN, runInfo.commandName()); - } - } - - // For awakeables and cancellation, add descriptions - for (var notifId : notificationIds) { - if (notifId instanceof NotificationId.SignalId signalId) { - if (signalId.id() == CANCEL_SIGNAL_ID) { - knownNotificationMetadata.put(notifId, "Cancellation"); - } else if (signalId.id() > 16) { - knownNotificationMetadata.put( - notifId, - "Awakeable " + Util.awakeableIdStr(stateContext.getStartInfo().id(), signalId.id())); - } - } - } - - this.hitError( - ProtocolException.uncompletedDoProgressDuringReplay( - notificationIds.stream() - .sorted(NOTIFICATION_ID_COMPARATOR_FOR_JOURNAL_MISMATCH) - .collect(Collectors.toList()), - knownNotificationMetadata), - relatedCommand, - null, - stateContext); - ExceptionUtils.sneakyThrow(AbortedExecutionException.INSTANCE); - return null; // unreachable - } - - @Override - public boolean isCompleted(int handle) { - return this.asyncResultsState.isHandleCompleted(handle); - } - - @Override - public Optional takeNotification(int handle, StateContext stateContext) { - return this.asyncResultsState.takeHandle(handle); - } - - @Override - public StateMachine.Input processInputCommand(StateContext stateContext) { - stateContext - .getJournal() - .commandTransition("", Protocol.InputCommandMessage.getDefaultInstance()); - - MessageLite actual = takeNextCommandToProcess(); - if (!(actual instanceof Protocol.InputCommandMessage inputCommandMessage)) { - throw ProtocolException.unexpectedMessage(Protocol.InputCommandMessage.class, actual); - } - - afterProcessingCommand(stateContext); - - //noinspection unchecked - return new StateMachine.Input( - new InvocationIdImpl( - stateContext.getStartInfo().debugId(), stateContext.getStartInfo().randomSeed()), - byteStringToSlice(inputCommandMessage.getValue().getContent()), - Map.ofEntries( - inputCommandMessage.getHeadersList().stream() - .map(h -> Map.entry(h.getKey(), h.getValue())) - .toArray(Map.Entry[]::new)), - stateContext.getStartInfo().objectKey()); - } - - @Override - public int processRunCommand(String name, StateContext stateContext) { - var completionId = stateContext.getJournal().nextCompletionNotificationId(); - var notificationId = new NotificationId.CompletionId(completionId); - - var runCmdBuilder = Protocol.RunCommandMessage.newBuilder().setResultCompletionId(completionId); - if (name != null) { - runCmdBuilder.setName(name); - } - - var notificationHandle = - this.processCompletableCommand( - runCmdBuilder.build(), CommandAccessor.RUN, new int[] {completionId}, stateContext)[0]; - - if (asyncResultsState.nonDeterministicFindId(notificationId)) { - LOG.trace( - "Found notification for {} with id {} while replaying, the run closure won't be executed.", - notificationHandle, - notificationId); - } else { - LOG.trace( - "Run notification for {} with id {} not found while replaying, so we enqueue the run to be executed later.", - notificationHandle, - notificationId); - runState.insertRunToExecute( - notificationHandle, - stateContext.getJournal().lastCommandMetadata().index(), - name != null ? name : ""); - } - - return notificationHandle; - } - - @Override - public void processNonCompletableCommand( - E commandMessage, CommandAccessor commandAccessor, StateContext stateContext) { - stateContext - .getJournal() - .commandTransition(commandAccessor.getName(commandMessage), commandMessage); - - MessageLite actual = takeNextCommandToProcess(); - try { - commandAccessor.checkEntryHeader( - stateContext.getJournal().getCommandIndex(), commandMessage, actual); - } catch (ProtocolException e) { - this.hitError(e, CommandRelationship.Last.INSTANCE, null, stateContext); - AbortedExecutionException.sneakyThrow(); - } - - afterProcessingCommand(stateContext); - } - - @Override - public int[] processCompletableCommand( - E commandMessage, - CommandAccessor commandAccessor, - int[] completionIds, - StateContext stateContext) { - stateContext - .getJournal() - .commandTransition(commandAccessor.getName(commandMessage), commandMessage); - MessageLite actual = takeNextCommandToProcess(); - try { - commandAccessor.checkEntryHeader( - stateContext.getJournal().getCommandIndex(), commandMessage, actual); - } catch (ProtocolException e) { - this.hitError(e, CommandRelationship.Last.INSTANCE, null, stateContext); - AbortedExecutionException.sneakyThrow(); - } - - int[] handles = new int[completionIds.length]; - for (int i = 0; i < handles.length; i++) { - handles[i] = - asyncResultsState.createHandleMapping(new NotificationId.CompletionId(completionIds[i])); - } - - afterProcessingCommand(stateContext); - - return handles; - } - - @Override - public int processStateGetCommand(String key, StateContext stateContext) { - var completionId = stateContext.getJournal().nextCompletionNotificationId(); - var handle = - asyncResultsState.createHandleMapping(new NotificationId.CompletionId(completionId)); - - stateContext - .getJournal() - .commandTransition("", Protocol.GetEagerStateCommandMessage.getDefaultInstance()); - MessageLite actual = takeNextCommandToProcess(); - - if (actual instanceof Protocol.GetEagerStateCommandMessage eagerStateCommandMessage) { - CommandAccessor.GET_EAGER_STATE.checkEntryHeader( - stateContext.getJournal().getCommandIndex(), - Protocol.GetEagerStateCommandMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(key)) - .build(), - actual); - - asyncResultsState.insertReady( - new NotificationId.CompletionId(completionId), - switch (eagerStateCommandMessage.getResultCase()) { - case VOID -> NotificationValue.Empty.INSTANCE; - case VALUE -> - new NotificationValue.Success( - byteStringToSlice(eagerStateCommandMessage.getValue().getContent())); - case RESULT_NOT_SET -> - throw ProtocolException.commandMissingField( - Protocol.GetEagerStateCommandMessage.class, "result"); - }); - - } else if (actual instanceof Protocol.GetLazyStateCommandMessage) { - CommandAccessor.GET_LAZY_STATE.checkEntryHeader( - stateContext.getJournal().getCommandIndex(), - Protocol.GetLazyStateCommandMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(key)) - .setResultCompletionId(completionId) - .build(), - actual); - } else { - throw ProtocolException.unexpectedMessage("get state", actual); - } - - afterProcessingCommand(stateContext); - - return handle; - } - - @Override - public int processStateGetKeysCommand(StateContext stateContext) { - var completionId = stateContext.getJournal().nextCompletionNotificationId(); - var handle = - asyncResultsState.createHandleMapping(new NotificationId.CompletionId(completionId)); - - stateContext - .getJournal() - .commandTransition("", Protocol.GetEagerStateKeysCommandMessage.getDefaultInstance()); - MessageLite actual = takeNextCommandToProcess(); - - if (actual instanceof Protocol.GetEagerStateKeysCommandMessage eagerStateCommandMessage) { - CommandAccessor.GET_EAGER_STATE_KEYS.checkEntryHeader( - stateContext.getJournal().getCommandIndex(), - Protocol.GetEagerStateKeysCommandMessage.getDefaultInstance(), - actual); - - asyncResultsState.insertReady( - new NotificationId.CompletionId(completionId), - new NotificationValue.StateKeys( - eagerStateCommandMessage.getValue().getKeysList().stream() - .map(ByteString::toStringUtf8) - .toList())); - } else if (actual instanceof Protocol.GetLazyStateKeysCommandMessage) { - CommandAccessor.GET_LAZY_STATE_KEYS.checkEntryHeader( - stateContext.getJournal().getCommandIndex(), - Protocol.GetLazyStateKeysCommandMessage.newBuilder() - .setResultCompletionId(completionId) - .build(), - actual); - } else { - throw ProtocolException.unexpectedMessage("get state keys", actual); - } - - afterProcessingCommand(stateContext); - - return handle; - } - - @Override - public int createSignalHandle(NotificationId notificationId, StateContext stateContext) { - return asyncResultsState.createHandleMapping(notificationId); - } - - private void afterProcessingCommand(StateContext stateContext) { - if (commandsToProcess.isEmpty()) { - stateContext.getStateHolder().transition(new ProcessingState(asyncResultsState, runState)); - } - } - - private MessageLite takeNextCommandToProcess() { - if (commandsToProcess.isEmpty()) { - throw ProtocolException.commandsToProcessIsEmpty(); - } - return commandsToProcess.removeFirst(); - } - - @Override - public InvocationState getInvocationState() { - return InvocationState.REPLAYING; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java deleted file mode 100644 index dedae9b66..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; -import org.jspecify.annotations.Nullable; - -final class RunState { - private final Map runs = new HashMap<>(); - - public void insertRunToExecute(int handle, int commandIndex, String commandName) { - runs.put(handle, new Run(commandIndex, commandName, RunStateInner.ToExecute)); - } - - public @Nullable Integer tryExecuteRun(Collection anyHandle) { - for (Map.Entry entry : runs.entrySet()) { - Integer handle = entry.getKey(); - Run run = entry.getValue(); - if (run.state == RunStateInner.ToExecute && anyHandle.contains(handle)) { - entry.setValue(new Run(run.commandIndex, run.commandName, RunStateInner.Executing)); - return handle; - } - } - return null; - } - - public @Nullable Run getRunInfo(int handle) { - return runs.get(handle); - } - - public boolean anyExecuting(Collection anyHandle) { - return anyHandle.stream() - .anyMatch(h -> runs.containsKey(h) && runs.get(h).state == RunStateInner.Executing); - } - - /** - * Notifies that execution has completed for the given handle. - * - * @param executed the handle of the completed execution - * @return a tuple of (commandName, commandIndex) - */ - public CommandInfo notifyExecutionCompleted(int executed) { - Run run = runs.remove(executed); - if (run == null) { - throw new IllegalStateException("There must be a corresponding run for the given handle"); - } - return new CommandInfo(run.commandName, run.commandIndex); - } - - enum RunStateInner { - ToExecute, - Executing - } - - record Run(int commandIndex, String commandName, RunStateInner state) {} - - /** - * Represents the command information in the order expected by the Rust code (commandName, - * commandIndex). - */ - record CommandInfo(String commandName, int commandIndex) {} -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ServiceProtocol.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ServiceProtocol.java deleted file mode 100644 index 214fb0e2b..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ServiceProtocol.java +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import dev.restate.sdk.core.generated.protocol.Protocol; -import java.util.Objects; - -public class ServiceProtocol { - public static final Protocol.ServiceProtocolVersion MIN_SERVICE_PROTOCOL_VERSION = - Protocol.ServiceProtocolVersion.V5; - public static final Protocol.ServiceProtocolVersion MAX_SERVICE_PROTOCOL_VERSION = - Protocol.ServiceProtocolVersion.V6; - - static final String CONTENT_TYPE = "content-type"; - - static Protocol.ServiceProtocolVersion parseServiceProtocolVersion(String version) { - if (version == null) { - return Protocol.ServiceProtocolVersion.SERVICE_PROTOCOL_VERSION_UNSPECIFIED; - } - version = version.trim(); - - if (version.equals("application/vnd.restate.invocation.v1")) { - return Protocol.ServiceProtocolVersion.V1; - } - if (version.equals("application/vnd.restate.invocation.v2")) { - return Protocol.ServiceProtocolVersion.V2; - } - if (version.equals("application/vnd.restate.invocation.v3")) { - return Protocol.ServiceProtocolVersion.V3; - } - if (version.equals("application/vnd.restate.invocation.v4")) { - return Protocol.ServiceProtocolVersion.V4; - } - if (version.equals("application/vnd.restate.invocation.v5")) { - return Protocol.ServiceProtocolVersion.V5; - } - if (version.equals("application/vnd.restate.invocation.v6")) { - return Protocol.ServiceProtocolVersion.V6; - } - return Protocol.ServiceProtocolVersion.SERVICE_PROTOCOL_VERSION_UNSPECIFIED; - } - - static String serviceProtocolVersionToHeaderValue(Protocol.ServiceProtocolVersion version) { - if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V1) { - return "application/vnd.restate.invocation.v1"; - } - if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V2) { - return "application/vnd.restate.invocation.v2"; - } - if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V3) { - return "application/vnd.restate.invocation.v3"; - } - if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V4) { - return "application/vnd.restate.invocation.v4"; - } - if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V5) { - return "application/vnd.restate.invocation.v5"; - } - if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V6) { - return "application/vnd.restate.invocation.v6"; - } - throw new IllegalArgumentException( - String.format("Service protocol version '%s' has no header value", version.getNumber())); - } - - static boolean isSupported(Protocol.ServiceProtocolVersion serviceProtocolVersion) { - return MIN_SERVICE_PROTOCOL_VERSION.getNumber() <= serviceProtocolVersion.getNumber() - && serviceProtocolVersion.getNumber() <= MAX_SERVICE_PROTOCOL_VERSION.getNumber(); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StartInfo.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StartInfo.java deleted file mode 100644 index 03bf90c37..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StartInfo.java +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.ByteString; -import java.time.Duration; - -record StartInfo( - ByteString id, - String debugId, - String objectKey, - int entriesToReplay, - int retryCountSinceLastStoredEntry, - Duration durationSinceLastStoredEntry, - Long randomSeed) {} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/State.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/State.java deleted file mode 100644 index 5c8cb6758..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/State.java +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.MessageLite; -import dev.restate.common.Slice; -import dev.restate.sdk.common.RetryPolicy; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.ProtocolException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import java.io.PrintWriter; -import java.io.StringWriter; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.concurrent.CompletableFuture; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.jspecify.annotations.Nullable; - -sealed interface State - permits ClosedState, - ProcessingState, - ReplayingState, - WaitingReplayEntriesState, - WaitingStartState { - - Logger LOG = LogManager.getLogger(State.class); - - default void onNewMessage( - InvocationInput invocationInput, - StateContext stateContext, - CompletableFuture waitForReadyFuture) { - throw ProtocolException.badState(this); - } - - default StateMachine.DoProgressResponse doProgress( - List anyHandle, StateContext stateContext) { - throw ProtocolException.badState(this); - } - - default boolean isCompleted(int handle) { - throw ProtocolException.badState(this); - } - - default Optional takeNotification(int handle, StateContext stateContext) { - throw ProtocolException.badState(this); - } - - default StateMachine.@Nullable Input processInputCommand(StateContext stateContext) { - throw ProtocolException.badState(this); - } - - default int processStateGetCommand(String key, StateContext stateContext) { - throw ProtocolException.badState(this); - } - - default int processStateGetKeysCommand(StateContext stateContext) { - throw ProtocolException.badState(this); - } - - default void processNonCompletableCommand( - E commandMessage, CommandAccessor commandAccessor, StateContext stateContext) { - throw ProtocolException.badState(this); - } - - default int[] processCompletableCommand( - E commandMessage, - CommandAccessor commandAccessor, - int[] completionIds, - StateContext stateContext) { - throw ProtocolException.badState(this); - } - - default int createSignalHandle(NotificationId notificationId, StateContext stateContext) { - throw ProtocolException.badState(this); - } - - default int processRunCommand(String name, StateContext stateContext) { - throw ProtocolException.badState(this); - } - - default void proposeRunCompletion(int handle, Slice value, StateContext stateContext) { - LOG.warn( - "Going to ignore proposed run completion with handle {} because the state machine is not in processing state.", - handle); - } - - default void proposeRunCompletion( - int handle, - Throwable exception, - Duration attemptDuration, - @Nullable RetryPolicy retryPolicy, - StateContext stateContext) { - LOG.warn( - "Going to ignore proposed run completion with handle {} because the state machine is not in processing state.", - handle); - } - - default void hitError( - Throwable throwable, - @Nullable CommandRelationship relatedCommand, - @Nullable Duration nextRetryDelay, - StateContext stateContext) { - LOG.warn("Invocation failed", throwable); - - var errorMessageBuilder = Protocol.ErrorMessage.newBuilder(); - - // Figure out message - if (throwable.getMessage() == null) { - // This happens only with few common exceptions, but anyway - errorMessageBuilder.setMessage(throwable.toString()); - } else { - errorMessageBuilder.setMessage(throwable.getMessage()); - } - - // Figure out code - if (throwable instanceof ProtocolException) { - errorMessageBuilder.setCode(((ProtocolException) throwable).getCode()); - } else { - errorMessageBuilder.setCode(TerminalException.INTERNAL_SERVER_ERROR_CODE); - } - - // Convert stacktrace to string - StringWriter sw = new StringWriter(); - PrintWriter pw = new PrintWriter(sw); - throwable.printStackTrace(pw); - errorMessageBuilder.setStacktrace(sw.toString()); - - // Add command metadata, if any - CommandMetadata commandMetadata = - (relatedCommand != null) - ? stateContext.getJournal().resolveRelatedCommand(relatedCommand) - : null; - if (commandMetadata != null) { - if (commandMetadata.index() >= 0) { - errorMessageBuilder.setRelatedCommandIndex(commandMetadata.index()); - } - if (commandMetadata.name() != null) { - errorMessageBuilder.setRelatedCommandName(commandMetadata.name()); - } - if (commandMetadata.type() != null) { - errorMessageBuilder.setRelatedCommandType(commandMetadata.type().encode()); - } - } - - // Add next retry delay, if any - if (nextRetryDelay != null) { - errorMessageBuilder.setNextRetryDelay(nextRetryDelay.toMillis()); - } - - stateContext.maybeWriteMessageOut(errorMessageBuilder.build()); - stateContext.getStateHolder().transition(new ClosedState()); - - stateContext.closeOutputSubscriber(); - } - - default void hitSuspended(Collection awaitingOn, StateContext stateContext) { - LOG.info("Invocation suspended"); - LOG.debug("Awaiting on {}", awaitingOn); - - var suspensionMessageBuilder = Protocol.SuspensionMessage.newBuilder(); - for (var notificationId : awaitingOn) { - if (notificationId instanceof NotificationId.CompletionId completionId) { - suspensionMessageBuilder.addWaitingCompletions(completionId.id()); - } else if (notificationId instanceof NotificationId.SignalId signalId) { - suspensionMessageBuilder.addWaitingSignals(signalId.id()); - } else if (notificationId instanceof NotificationId.SignalName signalName) { - suspensionMessageBuilder.addWaitingNamedSignals(signalName.name()); - } - } - - stateContext.maybeWriteMessageOut(suspensionMessageBuilder.build()); - stateContext.getStateHolder().transition(new ClosedState()); - - stateContext.closeOutputSubscriber(); - } - - default void end(StateContext stateContext) { - LOG.info("Invocation ended"); - - stateContext.writeMessageOut(Protocol.EndMessage.getDefaultInstance()); - stateContext.getStateHolder().transition(new ClosedState()); - - stateContext.closeOutputSubscriber(); - } - - default void onInputClosed(StateContext stateContext) { - LOG.trace("Marking input closed"); - stateContext.markInputClosed(); - } - - InvocationState getInvocationState(); -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateContext.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateContext.java deleted file mode 100644 index a680374b0..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateContext.java +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.MessageLite; -import dev.restate.sdk.core.EndpointRequestHandler; -import dev.restate.sdk.core.generated.protocol.Protocol; -import java.util.Objects; -import java.util.concurrent.Flow; - -final class StateContext { - - private final Protocol.ServiceProtocolVersion negotiatedProtocolVersion; - private final StateHolder stateHolder; - private final Journal journal; - private EagerState eagerState; - private transient StartInfo startInfo; - private boolean inputClosed; - private Flow.Subscriber outputSubscriber; - - StateContext( - EndpointRequestHandler.LoggingContextSetter loggingContextSetter, - Protocol.ServiceProtocolVersion negotiatedProtocolVersion) { - this.stateHolder = new StateHolder(loggingContextSetter); - this.negotiatedProtocolVersion = negotiatedProtocolVersion; - this.journal = new Journal(); - this.inputClosed = false; - } - - public Protocol.ServiceProtocolVersion getNegotiatedProtocolVersion() { - return negotiatedProtocolVersion; - } - - public State getCurrentState() { - return stateHolder.getState(); - } - - public StateHolder getStateHolder() { - return stateHolder; - } - - public Journal getJournal() { - return journal; - } - - public StateContext setEagerState(EagerState eagerState) { - this.eagerState = eagerState; - return this; - } - - public StateContext setStartInfo(StartInfo startInfo) { - this.startInfo = startInfo; - return this; - } - - EagerState getEagerState() { - return Objects.requireNonNull(eagerState, "The state machine should be initialized"); - } - - StartInfo getStartInfo() { - return Objects.requireNonNull(startInfo, "The state machine should be initialized"); - } - - public void markInputClosed() { - this.inputClosed = true; - } - - public boolean isInputClosed() { - return this.inputClosed; - } - - public void writeMessageOut(MessageLite msg) { - Objects.requireNonNull( - this.outputSubscriber, - "Output subscriber should be configured before running the state machine") - .onNext(msg); - } - - public boolean maybeWriteMessageOut(MessageLite msg) { - if (this.outputSubscriber != null) { - this.outputSubscriber.onNext(msg); - return true; - } - return false; - } - - public void closeOutputSubscriber() { - if (this.outputSubscriber != null) { - this.outputSubscriber.onComplete(); - this.outputSubscriber = null; - } - } - - public void registerOutputSubscriber(Flow.Subscriber outputSubscriber) { - this.outputSubscriber = outputSubscriber; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateHolder.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateHolder.java deleted file mode 100644 index 3d58f5a8e..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateHolder.java +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import dev.restate.sdk.core.EndpointRequestHandler; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -final class StateHolder { - - Logger LOG = LogManager.getLogger(StateHolder.class); - - private State state; - private final EndpointRequestHandler.LoggingContextSetter loggingContextSetter; - - StateHolder(EndpointRequestHandler.LoggingContextSetter loggingContextSetter) { - this.loggingContextSetter = loggingContextSetter; - this.state = new WaitingStartState(); - } - - State getState() { - return state; - } - - void transition(State state) { - this.state = state; - LOG.debug("Transitioning state machine to {}", state.getInvocationState()); - this.loggingContextSetter.set( - EndpointRequestHandler.LoggingContextSetter.INVOCATION_STATUS_KEY, - state.getInvocationState().toString()); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java deleted file mode 100644 index 14d810c11..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import dev.restate.common.Slice; -import dev.restate.common.Target; -import dev.restate.sdk.common.*; -import dev.restate.sdk.core.EndpointRequestHandler; -import dev.restate.sdk.endpoint.HeadersAccessor; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Flow; -import org.jspecify.annotations.Nullable; - -/** - * More or less same as the VM trait - */ -public interface StateMachine extends Flow.Processor { - - static StateMachine init( - HeadersAccessor headersAccessor, - EndpointRequestHandler.LoggingContextSetter loggingContextSetter) { - return new StateMachineImpl(headersAccessor, loggingContextSetter); - } - - // --- Response metadata - - String getResponseContentType(); - - // --- Execution starting point - - CompletableFuture waitForReady(); - - // --- Await next event - - void onNextEvent(Runnable runnable, boolean triggerNowIfInputClosed); - - // --- Async results - - sealed interface DoProgressResponse { - record AnyCompleted() implements DoProgressResponse { - static AnyCompleted INSTANCE = new AnyCompleted(); - } - - record ReadFromInput() implements DoProgressResponse { - static ReadFromInput INSTANCE = new ReadFromInput(); - } - - record ExecuteRun(int handle) implements DoProgressResponse {} - - record WaitingPendingRun() implements DoProgressResponse { - static WaitingPendingRun INSTANCE = new WaitingPendingRun(); - } - } - - DoProgressResponse doProgress(List anyHandle); - - boolean isCompleted(int handle); - - Optional takeNotification(int handle); - - // --- Commands. The int return value is the handle of the operation. - - record Input( - InvocationId invocationId, Slice body, Map headers, @Nullable String key) {} - - @Nullable Input input(); - - int stateGet(String key); - - int stateGetKeys(); - - void stateSet(String key, Slice bytes); - - void stateClear(String key); - - void stateClearAll(); - - int sleep(Duration duration, String name); - - record CallHandle(int invocationIdHandle, int resultHandle) {} - - CallHandle call( - Target target, - Slice payload, - @Nullable String idempotencyKey, - @Nullable Collection> headers); - - int send( - Target target, - Slice payload, - @Nullable String idempotencyKey, - @Nullable Collection> headers, - @Nullable Duration delay); - - record Awakeable(String awakeableId, int handle) {} - - Awakeable awakeable(); - - void completeAwakeable(String awakeableId, Slice value); - - void completeAwakeable(String awakeableId, TerminalException exception); - - int createSignalHandle(String signalName); - - void completeSignal(String targetInvocationId, String signalName, Slice value); - - void completeSignal(String targetInvocationId, String signalName, TerminalException exception); - - int promiseGet(String key); - - int promisePeek(String key); - - int promiseComplete(String key, Slice value); - - int promiseComplete(String key, TerminalException exception); - - int run(String name); - - void proposeRunCompletion(int handle, Slice value); - - void proposeRunCompletion( - int handle, Throwable exception, Duration attemptDuration, RetryPolicy retryPolicy); - - void cancelInvocation(String targetInvocationId); - - int attachInvocation(String invocationId); - - int getInvocationOutput(String invocationId); - - void writeOutput(Slice value); - - void writeOutput(TerminalException exception); - - void end(); - - // -- Introspection - - InvocationState state(); -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java deleted file mode 100644 index 5d9a5ddfd..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java +++ /dev/null @@ -1,677 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import static dev.restate.sdk.core.statemachine.Util.sliceToByteString; -import static dev.restate.sdk.core.statemachine.Util.toProtocolFailure; - -import com.google.protobuf.ByteString; -import dev.restate.common.Slice; -import dev.restate.common.Target; -import dev.restate.sdk.common.*; -import dev.restate.sdk.core.EndpointRequestHandler; -import dev.restate.sdk.core.ProtocolException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.endpoint.HeadersAccessor; -import java.time.Duration; -import java.time.Instant; -import java.util.*; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Flow; -import java.util.function.Consumer; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.jspecify.annotations.NonNull; -import org.jspecify.annotations.Nullable; - -class StateMachineImpl implements StateMachine { - - private static final Logger LOG = LogManager.getLogger(StateMachineImpl.class); - static final int CANCEL_SIGNAL_ID = 1; - - // Callbacks - private final CompletableFuture waitForReadyFuture = new CompletableFuture<>(); - private @NonNull Runnable nextEventListener = () -> {}; - - // Java Flow and message handling - private final MessageDecoder messageDecoder = new MessageDecoder(); - private Flow.@Nullable Subscription inputSubscription; - - // State machine context - private final StateContext stateContext; - - StateMachineImpl( - HeadersAccessor headersAccessor, - EndpointRequestHandler.LoggingContextSetter loggingContextSetter) { - String contentTypeHeader = headersAccessor.get(ServiceProtocol.CONTENT_TYPE); - - var serviceProtocolVersion = ServiceProtocol.parseServiceProtocolVersion(contentTypeHeader); - if (!ServiceProtocol.isSupported(serviceProtocolVersion)) { - throw new ProtocolException( - String.format( - "Service endpoint does not support the service protocol version '%s'.", - contentTypeHeader), - ProtocolException.UNSUPPORTED_MEDIA_TYPE_CODE); - } - - this.stateContext = new StateContext(loggingContextSetter, serviceProtocolVersion); - } - - // -- Few callbacks - - @Override - public CompletableFuture waitForReady() { - return waitForReadyFuture; - } - - @Override - public void onNextEvent(Runnable runnable, boolean triggerNowIfInputClosed) { - this.nextEventListener = - () -> { - this.nextEventListener.run(); - runnable.run(); - }; - // Trigger this now - if (triggerNowIfInputClosed && this.stateContext.isInputClosed()) { - this.triggerNextEventSignal(); - } - } - - private void triggerNextEventSignal() { - Runnable listener = this.nextEventListener; - this.nextEventListener = () -> {}; - listener.run(); - } - - // -- IO - - @Override - public void subscribe(Flow.Subscriber subscriber) { - var outputSubscriber = new MessageEncoder(subscriber); - this.stateContext.registerOutputSubscriber(outputSubscriber); - outputSubscriber.onSubscribe( - new Flow.Subscription() { - @Override - public void request(long l) {} - - @Override - public void cancel() { - end(); - } - }); - } - - // --- Input Subscriber impl - - @Override - public void onSubscribe(Flow.Subscription subscription) { - try { - this.inputSubscription = subscription; - this.inputSubscription.request(Long.MAX_VALUE); - } catch (Throwable e) { - this.onError(e); - } - } - - @Override - public void onNext(Slice slice) { - try { - LOG.trace("Received input slice"); - this.messageDecoder.offer(slice); - - boolean shouldTriggerInputListener = this.messageDecoder.isNextAvailable(); - InvocationInput invocationInput = this.messageDecoder.next(); - while (invocationInput != null) { - LOG.trace( - "Received input message {} {}", - invocationInput.message().getClass(), - invocationInput.message()); - - this.stateContext - .getCurrentState() - .onNewMessage(invocationInput, this.stateContext, this.waitForReadyFuture); - - invocationInput = this.messageDecoder.next(); - } - - if (shouldTriggerInputListener) { - this.triggerNextEventSignal(); - } - - } catch (Throwable e) { - this.onError(e); - } - } - - @Override - public void onError(Throwable throwable) { - this.stateContext.getCurrentState().hitError(throwable, null, null, this.stateContext); - this.triggerNextEventSignal(); - cancelInputSubscription(); - } - - @Override - public void onComplete() { - LOG.trace("Input publisher closed"); - try { - this.stateContext.getCurrentState().onInputClosed(this.stateContext); - } catch (Throwable e) { - this.onError(e); - return; - } - this.triggerNextEventSignal(); - this.cancelInputSubscription(); - } - - // -- State machine - - @Override - public String getResponseContentType() { - return ServiceProtocol.serviceProtocolVersionToHeaderValue( - stateContext.getNegotiatedProtocolVersion()); - } - - @Override - public DoProgressResponse doProgress(List anyHandle) { - return this.stateContext.getCurrentState().doProgress(anyHandle, this.stateContext); - } - - @Override - public boolean isCompleted(int handle) { - return this.stateContext.getCurrentState().isCompleted(handle); - } - - @Override - public Optional takeNotification(int handle) { - return this.stateContext.getCurrentState().takeNotification(handle, this.stateContext); - } - - @Override - public @Nullable Input input() { - return this.stateContext.getCurrentState().processInputCommand(this.stateContext); - } - - @Override - public int stateGet(String key) { - LOG.debug("Executing 'Get state {}'", key); - return this.stateContext.getCurrentState().processStateGetCommand(key, this.stateContext); - } - - @Override - public int stateGetKeys() { - LOG.debug("Executing 'Get state keys'"); - return this.stateContext.getCurrentState().processStateGetKeysCommand(this.stateContext); - } - - @Override - public void stateSet(String key, Slice bytes) { - LOG.debug("Executing 'Set state {}'", key); - ByteString keyBuffer = ByteString.copyFromUtf8(key); - this.stateContext.getEagerState().set(keyBuffer, bytes); - this.stateContext - .getCurrentState() - .processNonCompletableCommand( - Protocol.SetStateCommandMessage.newBuilder() - .setKey(keyBuffer) - .setValue(Protocol.Value.newBuilder().setContent(sliceToByteString(bytes)).build()) - .build(), - CommandAccessor.SET_STATE, - this.stateContext); - } - - @Override - public void stateClear(String key) { - LOG.debug("Executing 'Clear state {}'", key); - ByteString keyBuffer = ByteString.copyFromUtf8(key); - this.stateContext.getEagerState().clear(keyBuffer); - this.stateContext - .getCurrentState() - .processNonCompletableCommand( - Protocol.ClearStateCommandMessage.newBuilder().setKey(keyBuffer).build(), - CommandAccessor.CLEAR_STATE, - this.stateContext); - } - - @Override - public void stateClearAll() { - LOG.debug("Executing 'Clear all state'"); - this.stateContext.getEagerState().clearAll(); - this.stateContext - .getCurrentState() - .processNonCompletableCommand( - Protocol.ClearAllStateCommandMessage.getDefaultInstance(), - CommandAccessor.CLEAR_ALL_STATE, - this.stateContext); - } - - @Override - public int sleep(Duration duration, @Nullable String name) { - LOG.debug("Executing 'Sleeping for {}'", duration); - var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); - - var sleepCommandBuilder = - Protocol.SleepCommandMessage.newBuilder() - .setWakeUpTime(Instant.now().toEpochMilli() + duration.toMillis()) - .setResultCompletionId(completionId); - if (name != null) { - sleepCommandBuilder.setName(name); - } - - return this.stateContext.getCurrentState() - .processCompletableCommand( - sleepCommandBuilder.build(), - CommandAccessor.SLEEP, - new int[] {completionId}, - this.stateContext)[0]; - } - - @Override - public CallHandle call( - Target target, - Slice payload, - @Nullable String idempotencyKey, - @Nullable Collection> headers) { - LOG.debug("Executing 'Call {}'", target); - if (idempotencyKey != null && idempotencyKey.isBlank()) { - throw ProtocolException.idempotencyKeyIsEmpty(); - } - - var invocationIdCompletionId = this.stateContext.getJournal().nextCompletionNotificationId(); - var callCompletionId = this.stateContext.getJournal().nextCompletionNotificationId(); - - var callCommandBuilder = - Protocol.CallCommandMessage.newBuilder() - .setServiceName(target.getService()) - .setHandlerName(target.getHandler()) - .setParameter(sliceToByteString(payload)) - .setInvocationIdNotificationIdx(invocationIdCompletionId) - .setResultCompletionId(callCompletionId); - if (target.getKey() != null) { - callCommandBuilder.setKey(target.getKey()); - } - if (idempotencyKey != null) { - callCommandBuilder.setIdempotencyKey(idempotencyKey); - } - if (headers != null) { - for (var header : headers) { - callCommandBuilder.addHeaders( - Protocol.Header.newBuilder() - .setKey(header.getKey()) - .setValue(header.getValue()) - .build()); - } - } - - var notificationHandles = - this.stateContext - .getCurrentState() - .processCompletableCommand( - callCommandBuilder.build(), - CommandAccessor.CALL, - new int[] {invocationIdCompletionId, callCompletionId}, - this.stateContext); - - return new CallHandle(notificationHandles[0], notificationHandles[1]); - } - - @Override - public int send( - Target target, - Slice payload, - @Nullable String idempotencyKey, - @Nullable Collection> headers, - @Nullable Duration delay) { - if (delay != null && !delay.isZero()) { - LOG.debug("Executing 'Delayed send {} with delay {}'", target, delay); - } else { - LOG.debug("Executing 'Send {}'", target); - } - if (idempotencyKey != null && idempotencyKey.isBlank()) { - throw ProtocolException.idempotencyKeyIsEmpty(); - } - - var invocationIdCompletionId = this.stateContext.getJournal().nextCompletionNotificationId(); - - var sendCommandBuilder = - Protocol.OneWayCallCommandMessage.newBuilder() - .setServiceName(target.getService()) - .setHandlerName(target.getHandler()) - .setParameter(sliceToByteString(payload)) - .setInvocationIdNotificationIdx(invocationIdCompletionId); - if (target.getKey() != null) { - sendCommandBuilder.setKey(target.getKey()); - } - if (idempotencyKey != null) { - sendCommandBuilder.setIdempotencyKey(idempotencyKey); - } - if (headers != null) { - for (var header : headers) { - sendCommandBuilder.addHeaders( - Protocol.Header.newBuilder() - .setKey(header.getKey()) - .setValue(header.getValue()) - .build()); - } - } - if (delay != null && !delay.isZero()) { - sendCommandBuilder.setInvokeTime(Instant.now().toEpochMilli() + delay.toMillis()); - } - - return this.stateContext.getCurrentState() - .processCompletableCommand( - sendCommandBuilder.build(), - CommandAccessor.ONE_WAY_CALL, - new int[] {invocationIdCompletionId}, - this.stateContext)[0]; - } - - @Override - public Awakeable awakeable() { - LOG.debug("Executing 'Create awakeable'"); - - var signalId = this.stateContext.getJournal().nextSignalNotificationId(); - - var signalHandle = - this.stateContext - .getCurrentState() - .createSignalHandle(new NotificationId.SignalId(signalId), this.stateContext); - - // Encode awakeable id - String awakeableId = Util.awakeableIdStr(this.stateContext.getStartInfo().id(), signalId); - - return new Awakeable(awakeableId, signalHandle); - } - - @Override - public void completeAwakeable(String awakeableId, Slice value) { - LOG.debug("Executing 'Complete awakeable {} with success'", awakeableId); - completeAwakeable( - awakeableId, - builder -> - builder.setValue( - Protocol.Value.newBuilder().setContent(sliceToByteString(value)).build())); - } - - @Override - public void completeAwakeable(String awakeableId, TerminalException exception) { - LOG.debug("Executing 'Complete awakeable {} with failure'", awakeableId); - verifyErrorMetadataFeatureSupport(exception); - completeAwakeable(awakeableId, builder -> builder.setFailure(toProtocolFailure(exception))); - } - - private void completeAwakeable( - String awakeableId, Consumer filler) { - var builder = Protocol.CompleteAwakeableCommandMessage.newBuilder().setAwakeableId(awakeableId); - filler.accept(builder); - - this.stateContext - .getCurrentState() - .processNonCompletableCommand( - builder.build(), CommandAccessor.COMPLETE_AWAKEABLE, this.stateContext); - } - - @Override - public int createSignalHandle(String signalName) { - LOG.debug("Executing 'Create signal handle {}'", signalName); - - return this.stateContext - .getCurrentState() - .createSignalHandle(new NotificationId.SignalName(signalName), this.stateContext); - } - - @Override - public void completeSignal(String targetInvocationId, String signalName, Slice value) { - LOG.debug( - "Executing 'Complete signal {} to invocation {} with success'", - signalName, - targetInvocationId); - this.completeSignal( - targetInvocationId, - signalName, - builder -> - builder.setValue( - Protocol.Value.newBuilder().setContent(sliceToByteString(value)).build())); - } - - @Override - public void completeSignal( - String targetInvocationId, String signalName, TerminalException exception) { - LOG.debug( - "Executing 'Complete signal {} to invocation {} with failure'", - signalName, - targetInvocationId); - verifyErrorMetadataFeatureSupport(exception); - this.completeSignal( - targetInvocationId, - signalName, - builder -> builder.setFailure(toProtocolFailure(exception))); - } - - private void completeSignal( - String targetInvocationId, - String signalName, - Consumer filler) { - var builder = - Protocol.SendSignalCommandMessage.newBuilder() - .setTargetInvocationId(targetInvocationId) - .setName(signalName); - filler.accept(builder); - - this.stateContext - .getCurrentState() - .processNonCompletableCommand( - builder.build(), CommandAccessor.SEND_SIGNAL, this.stateContext); - } - - @Override - public int promiseGet(String key) { - LOG.debug("Executing 'Await promise {}'", key); - var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); - return this.stateContext.getCurrentState() - .processCompletableCommand( - Protocol.GetPromiseCommandMessage.newBuilder() - .setKey(key) - .setResultCompletionId(completionId) - .build(), - CommandAccessor.GET_PROMISE, - new int[] {completionId}, - this.stateContext)[0]; - } - - @Override - public int promisePeek(String key) { - LOG.debug("Executing 'Peek promise {}'", key); - var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); - return this.stateContext.getCurrentState() - .processCompletableCommand( - Protocol.PeekPromiseCommandMessage.newBuilder() - .setKey(key) - .setResultCompletionId(completionId) - .build(), - CommandAccessor.PEEK_PROMISE, - new int[] {completionId}, - this.stateContext)[0]; - } - - @Override - public int promiseComplete(String key, Slice value) { - LOG.debug("Executing 'Complete promise {} with success'", key); - return this.promiseComplete( - key, - builder -> - builder.setCompletionValue( - Protocol.Value.newBuilder().setContent(sliceToByteString(value)).build())); - } - - @Override - public int promiseComplete(String key, TerminalException exception) { - LOG.debug("Executing 'Complete promise {} with failure'", key); - verifyErrorMetadataFeatureSupport(exception); - return this.promiseComplete( - key, builder -> builder.setCompletionFailure(toProtocolFailure(exception))); - } - - private int promiseComplete( - String key, Consumer filler) { - var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); - - var builder = - Protocol.CompletePromiseCommandMessage.newBuilder() - .setResultCompletionId(completionId) - .setKey(key); - filler.accept(builder); - - return this.stateContext.getCurrentState() - .processCompletableCommand( - builder.build(), - CommandAccessor.COMPLETE_PROMISE, - new int[] {completionId}, - this.stateContext)[0]; - } - - @Override - public int run(String name) { - LOG.debug("Executing 'Created run {}'", name); - return this.stateContext.getCurrentState().processRunCommand(name, this.stateContext); - } - - @Override - public void proposeRunCompletion(int handle, Slice value) { - LOG.debug("Executing 'Run completed with success'"); - try { - this.stateContext.getCurrentState().proposeRunCompletion(handle, value, this.stateContext); - } catch (Throwable e) { - this.onError(e); - return; - } - this.triggerNextEventSignal(); - } - - @Override - public void proposeRunCompletion( - int handle, - Throwable exception, - Duration attemptDuration, - @Nullable RetryPolicy retryPolicy) { - LOG.debug("Executing 'Run completed with failure'"); - if (exception instanceof TerminalException) { - verifyErrorMetadataFeatureSupport((TerminalException) exception); - } - try { - this.stateContext - .getCurrentState() - .proposeRunCompletion(handle, exception, attemptDuration, retryPolicy, this.stateContext); - } catch (Throwable e) { - this.onError(e); - return; - } - this.triggerNextEventSignal(); - } - - @Override - public void cancelInvocation(String targetInvocationId) { - LOG.debug("Executing 'Cancel invocation {}'", targetInvocationId); - this.stateContext - .getCurrentState() - .processNonCompletableCommand( - Protocol.SendSignalCommandMessage.newBuilder() - .setTargetInvocationId(targetInvocationId) - .setIdx(CANCEL_SIGNAL_ID) - .setVoid(Protocol.Void.getDefaultInstance()) - .build(), - CommandAccessor.SEND_SIGNAL, - this.stateContext); - } - - @Override - public int attachInvocation(String invocationId) { - LOG.debug("Executing 'Attach invocation {}'", invocationId); - var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); - return this.stateContext.getCurrentState() - .processCompletableCommand( - Protocol.AttachInvocationCommandMessage.newBuilder() - .setInvocationId(invocationId) - .setResultCompletionId(completionId) - .build(), - CommandAccessor.ATTACH_INVOCATION, - new int[] {completionId}, - this.stateContext)[0]; - } - - @Override - public int getInvocationOutput(String invocationId) { - LOG.debug("Executing 'Get invocation output {}'", invocationId); - var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); - return this.stateContext.getCurrentState() - .processCompletableCommand( - Protocol.GetInvocationOutputCommandMessage.newBuilder() - .setInvocationId(invocationId) - .setResultCompletionId(completionId) - .build(), - CommandAccessor.GET_INVOCATION_OUTPUT, - new int[] {completionId}, - this.stateContext)[0]; - } - - @Override - public void writeOutput(Slice value) { - LOG.debug("Executing 'Write invocation output with success'"); - this.stateContext - .getCurrentState() - .processNonCompletableCommand( - Protocol.OutputCommandMessage.newBuilder() - .setValue(Protocol.Value.newBuilder().setContent(sliceToByteString(value)).build()) - .build(), - CommandAccessor.OUTPUT, - this.stateContext); - } - - @Override - public void writeOutput(TerminalException exception) { - LOG.debug("Executing 'Write invocation output with failure'"); - verifyErrorMetadataFeatureSupport(exception); - this.stateContext - .getCurrentState() - .processNonCompletableCommand( - Protocol.OutputCommandMessage.newBuilder() - .setFailure(toProtocolFailure(exception)) - .build(), - CommandAccessor.OUTPUT, - this.stateContext); - } - - @Override - public void end() { - this.stateContext.getCurrentState().end(this.stateContext); - cancelInputSubscription(); - } - - @Override - public InvocationState state() { - return this.stateContext.getCurrentState().getInvocationState(); - } - - private void cancelInputSubscription() { - if (this.inputSubscription != null) { - this.inputSubscription.cancel(); - this.inputSubscription = null; - } - } - - private void verifyErrorMetadataFeatureSupport(TerminalException exception) { - if (!exception.getMetadata().isEmpty() - && stateContext.getNegotiatedProtocolVersion().getNumber() - < Protocol.ServiceProtocolVersion.V6.getNumber()) { - throw ProtocolException.unsupportedFeature( - "terminal error metadata", - Protocol.ServiceProtocolVersion.V6, - stateContext.getNegotiatedProtocolVersion()); - } - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java deleted file mode 100644 index 9c14f3258..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.ByteString; -import com.google.protobuf.MessageLite; -import com.google.protobuf.UnsafeByteOperations; -import dev.restate.common.Slice; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import java.nio.ByteBuffer; -import java.time.Duration; -import java.util.Base64; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - -public class Util { - - static Protocol.Failure toProtocolFailure( - int code, String message, Map metadata) { - Protocol.Failure.Builder builder = Protocol.Failure.newBuilder().setCode(code); - if (message != null) { - builder.setMessage(message); - } - if (metadata != null) { - for (Map.Entry entry : metadata.entrySet()) { - builder.addMetadata( - Protocol.FailureMetadata.newBuilder() - .setKey(entry.getKey()) - .setValue(entry.getValue())); - } - } - return builder.build(); - } - - static Protocol.Failure toProtocolFailure(Throwable throwable) { - if (throwable instanceof TerminalException) { - return toProtocolFailure( - ((TerminalException) throwable).getCode(), - throwable.getMessage(), - ((TerminalException) throwable).getMetadata()); - } - return toProtocolFailure( - TerminalException.INTERNAL_SERVER_ERROR_CODE, throwable.toString(), Map.of()); - } - - static TerminalException toRestateException(Protocol.Failure failure) { - return new TerminalException( - failure.getCode(), - failure.getMessage(), - failure.getMetadataList().stream() - .collect( - Collectors.toMap( - Protocol.FailureMetadata::getKey, Protocol.FailureMetadata::getValue))); - } - - /** NOTE! This method rewinds the buffer!!! */ - static ByteString nioBufferToProtobufBuffer(ByteBuffer nioBuffer) { - return UnsafeByteOperations.unsafeWrap(nioBuffer); - } - - /** NOTE! This method rewinds the buffer!!! */ - static ByteString sliceToByteString(Slice slice) { - return nioBufferToProtobufBuffer(slice.asReadOnlyByteBuffer()); - } - - static Slice byteStringToSlice(ByteString byteString) { - return new ByteStringSlice(byteString); - } - - static Duration durationMin(Duration a, Duration b) { - return (a.compareTo(b) <= 0) ? a : b; - } - - private static final String AWAKEABLE_IDENTIFIER_PREFIX = "sign_1"; - - static String awakeableIdStr(ByteString invocationId, int signalId) { - return AWAKEABLE_IDENTIFIER_PREFIX - + Base64.getUrlEncoder() - .encodeToString( - invocationId - .concat(ByteString.copyFrom(ByteBuffer.allocate(4).putInt(signalId).flip())) - .toByteArray()); - } - - /** - * Returns a string representation of a command message. - * - * @param message The command message - * @return A string representation of the command message - */ - static String commandMessageToString(MessageLite message) { - if (message instanceof Protocol.InputCommandMessage) { - return "handler input"; - } else if (message instanceof Protocol.OutputCommandMessage) { - return "handler return"; - } else if (message instanceof Protocol.GetLazyStateCommandMessage) { - return "get state"; - } else if (message instanceof Protocol.GetLazyStateKeysCommandMessage) { - return "get state keys"; - } else if (message instanceof Protocol.SetStateCommandMessage) { - return "set state"; - } else if (message instanceof Protocol.ClearStateCommandMessage) { - return "clear state"; - } else if (message instanceof Protocol.ClearAllStateCommandMessage) { - return "clear all state"; - } else if (message instanceof Protocol.GetPromiseCommandMessage) { - return "get promise"; - } else if (message instanceof Protocol.PeekPromiseCommandMessage) { - return "peek promise"; - } else if (message instanceof Protocol.CompletePromiseCommandMessage) { - return "complete promise"; - } else if (message instanceof Protocol.SleepCommandMessage) { - return "sleep"; - } else if (message instanceof Protocol.CallCommandMessage) { - return "call"; - } else if (message instanceof Protocol.OneWayCallCommandMessage) { - return "one way call/send"; - } else if (message instanceof Protocol.SendSignalCommandMessage) { - return "send signal"; - } else if (message instanceof Protocol.RunCommandMessage) { - return "run"; - } else if (message instanceof Protocol.AttachInvocationCommandMessage) { - return "attach invocation"; - } else if (message instanceof Protocol.GetInvocationOutputCommandMessage) { - return "get invocation output"; - } else if (message instanceof Protocol.CompleteAwakeableCommandMessage) { - return "complete awakeable"; - } - - return message.getClass().getSimpleName(); - } - - private static final class ByteStringSlice implements Slice { - private final ByteString byteString; - - public ByteStringSlice(ByteString bytes) { - this.byteString = Objects.requireNonNull(bytes); - } - - @Override - public ByteBuffer asReadOnlyByteBuffer() { - return byteString.asReadOnlyByteBuffer(); - } - - @Override - public int readableBytes() { - return byteString.size(); - } - - @Override - public void copyTo(byte[] target) { - copyTo(target, 0); - } - - @Override - public void copyTo(byte[] target, int targetOffset) { - byteString.copyTo(target, targetOffset); - } - - @Override - public byte byteAt(int position) { - return byteString.byteAt(position); - } - - @Override - public void copyTo(ByteBuffer buffer) { - byteString.copyTo(buffer); - } - - @Override - public byte[] toByteArray() { - return byteString.toByteArray(); - } - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingReplayEntriesState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingReplayEntriesState.java deleted file mode 100644 index 8be9c36f6..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingReplayEntriesState.java +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.MessageLite; -import dev.restate.sdk.core.ProtocolException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import java.util.ArrayDeque; -import java.util.Deque; -import java.util.concurrent.CompletableFuture; - -final class WaitingReplayEntriesState implements State { - - private int receivedEntries = 0; - private final Deque commandsToProcess = new ArrayDeque<>(); - private final AsyncResultsState asyncResultsState = new AsyncResultsState(); - - @Override - public void onNewMessage( - InvocationInput invocationInput, - StateContext stateContext, - CompletableFuture waitForReadyFuture) { - if (invocationInput.header().getType().isNotification()) { - if (!(invocationInput.message() - instanceof Protocol.NotificationTemplate notificationTemplate)) { - throw ProtocolException.unexpectedMessage( - Protocol.NotificationTemplate.class, invocationInput.message()); - } - - this.asyncResultsState.enqueue(notificationTemplate); - } else if (invocationInput.header().getType().isCommand()) { - this.commandsToProcess.add(invocationInput.message()); - } else { - throw ProtocolException.unexpectedMessage( - "command or notification", invocationInput.message()); - } - - this.receivedEntries++; - - if (stateContext.getStartInfo().entriesToReplay() == this.receivedEntries) { - stateContext - .getStateHolder() - .transition(new ReplayingState(commandsToProcess, asyncResultsState)); - waitForReadyFuture.complete(null); - } - } - - @Override - public void onInputClosed(StateContext stateContext) { - throw ProtocolException.inputClosedWhileWaitingEntries(); - } - - @Override - public void end(StateContext stateContext) { - throw ProtocolException.closedWhileWaitingEntries(); - } - - @Override - public InvocationState getInvocationState() { - return InvocationState.WAITING_START; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingStartState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingStartState.java deleted file mode 100644 index 6b7ded0f5..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingStartState.java +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.ProtocolException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import java.time.Duration; -import java.util.concurrent.CompletableFuture; - -final class WaitingStartState implements State { - - @Override - public void onNewMessage( - InvocationInput invocationInput, - StateContext stateContext, - CompletableFuture waitForReadyFuture) { - if (!(invocationInput.message() instanceof Protocol.StartMessage startMessage)) { - throw ProtocolException.unexpectedMessage( - Protocol.StartMessage.class, invocationInput.message()); - } - - // Sanity checks - if (startMessage.getKnownEntries() == 0) { - throw new ProtocolException( - "Expected at least one entry with Input, got 0 entries", - TerminalException.INTERNAL_SERVER_ERROR_CODE); - } - - // Register start info and eager state - stateContext.setStartInfo( - new StartInfo( - startMessage.getId(), - startMessage.getDebugId(), - startMessage.getKey(), - startMessage.getKnownEntries(), - startMessage.getRetryCountSinceLastStoredEntry(), - Duration.ofMillis(startMessage.getDurationSinceLastStoredEntry()), - // Random seed from start message will be set only if protocol >= 6 - stateContext.getNegotiatedProtocolVersion().getNumber() - >= Protocol.ServiceProtocolVersion.V6_VALUE - ? startMessage.getRandomSeed() - : null)); - stateContext.setEagerState(new EagerState(startMessage)); - - // Tracing and logging setup - LOG.info("Start invocation"); - - // Execute state transition - stateContext.getStateHolder().transition(new WaitingReplayEntriesState()); - } - - @Override - public void onInputClosed(StateContext stateContext) { - throw ProtocolException.inputClosedWhileWaitingEntries(); - } - - @Override - public void end(StateContext stateContext) { - throw ProtocolException.closedWhileWaitingEntries(); - } - - @Override - public InvocationState getInvocationState() { - return InvocationState.WAITING_START; - } -} diff --git a/sdk-core/src/main/rust/Cargo.lock b/sdk-core/src/main/rust/Cargo.lock new file mode 100644 index 000000000..93d4eba65 --- /dev/null +++ b/sdk-core/src/main/rust/Cargo.lock @@ -0,0 +1,477 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "bytes-utils" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" +dependencies = [ + "bytes", + "either", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "zerocopy", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "pastey" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5a797f0e07bdf071d15742978fc3128ec6c22891c31a3a931513263904c982a" + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "prost" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "restate-sdk-shared-core" +version = "0.10.0" +source = "git+https://github.com/restatedev/sdk-shared-core?branch=main#d8a42ecceab6e7874138b6316e128a09f2de76d1" +dependencies = [ + "base64", + "bytes", + "bytes-utils", + "pastey", + "prost", + "serde", + "strum", + "thiserror", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "restate-sdk-shared-core-wasm" +version = "0.1.0" +dependencies = [ + "bytes", + "ciborium", + "restate-sdk-shared-core", + "serde", + "serde_bytes", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_bytes" +version = "0.11.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d440709e79d88e51ac01c4b72fc6cb7314017bb7da9eeff678aa94c10e3ea8" +dependencies = [ + "serde", + "serde_core", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "zerocopy" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/sdk-core/src/main/rust/Cargo.toml b/sdk-core/src/main/rust/Cargo.toml new file mode 100644 index 000000000..dc02be3f8 --- /dev/null +++ b/sdk-core/src/main/rust/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "restate-sdk-shared-core-wasm" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core", branch = "main", features = ["tracing_pretty"] } +bytes = "1" +serde = { version = "1", features = ["derive"] } +serde_bytes = "0.11" +ciborium = "0.2" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } + +[profile.release] +opt-level = 3 +lto = true \ No newline at end of file diff --git a/sdk-core/src/main/rust/src/lib.rs b/sdk-core/src/main/rust/src/lib.rs new file mode 100644 index 000000000..d31353857 --- /dev/null +++ b/sdk-core/src/main/rust/src/lib.rs @@ -0,0 +1,1425 @@ +//! Rust WASM wrapper around `restate-sdk-shared-core` for the Java SDK. +//! +//! Mirrors sdk-go/shared-core/src/lib.rs in structure exactly. +//! The only difference is CBOR (ciborium + serde) instead of protobuf (prost). +//! +//! Structure: +//! - Each exported function is a thin `pub unsafe extern "C"` wrapper (prefixed `_`) +//! that calls `ptr_to_input` / `output_to_ptr` and delegates to a safe inner fn. +//! - Inner functions take `&Rc>` and return a typed CBOR response. +//! - `From` impls at the bottom convert core results to CBOR response types, +//! enabling `.into()` in inner functions (same pattern as Go). + +#![allow(clippy::missing_safety_doc)] + +use bytes::Bytes; +use restate_sdk_shared_core::{ + AttachInvocationTarget, AwaitResponse, CoreVM, Error, Header, HeaderMap, NonEmptyValue, + NotificationHandle, PayloadOptions, ResponseHead, RetryPolicy, RunExitResult, TakeOutputResult, + Target, TerminalFailure, UnresolvedFuture, VMOptions, Value, Version, VM, +}; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; +use std::cell::RefCell; +use std::convert::Infallible; +use std::io::Write; +use std::mem::MaybeUninit; +use std::rc::Rc; +use std::time::Duration; +use tracing::level_filters::LevelFilter; +use tracing::{Level, Subscriber}; +use tracing_subscriber::fmt::format::FmtSpan; +use tracing_subscriber::fmt::MakeWriter; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::{Layer, Registry}; + +// --------- Init and logging + +#[export_name = "init"] +pub unsafe extern "C" fn init(level: u32) { + std::panic::set_hook(Box::new(|panic| { + let panic_str = format!("Core panicked: {panic}"); + log(AbiLogLevel::Error, &panic_str) + })); + let _ = tracing::subscriber::set_global_default(log_subscriber(level.into())); +} + +#[repr(u32)] +enum AbiLogLevel { + Trace = 0, + Debug = 1, + Info = 2, + Warn = 3, + Error = 4, +} + +impl From for AbiLogLevel { + fn from(value: u32) -> Self { + match value { + 0 => AbiLogLevel::Trace, + 1 => AbiLogLevel::Debug, + 2 => AbiLogLevel::Info, + 3 => AbiLogLevel::Warn, + 4 => AbiLogLevel::Error, + _ => AbiLogLevel::Error, + } + } +} + +impl From for AbiLogLevel { + fn from(value: Level) -> Self { + match value { + Level::TRACE => AbiLogLevel::Trace, + Level::DEBUG => AbiLogLevel::Debug, + Level::INFO => AbiLogLevel::Info, + Level::WARN => AbiLogLevel::Warn, + Level::ERROR => AbiLogLevel::Error, + } + } +} + +impl From for Level { + fn from(value: AbiLogLevel) -> Self { + match value { + AbiLogLevel::Trace => Level::TRACE, + AbiLogLevel::Debug => Level::DEBUG, + AbiLogLevel::Info => Level::INFO, + AbiLogLevel::Warn => Level::WARN, + AbiLogLevel::Error => Level::ERROR, + } + } +} + +pub struct MakeAbiLogWriter; + +impl<'a> MakeWriter<'a> for MakeAbiLogWriter { + type Writer = ConsoleWriter; + + fn make_writer(&'a self) -> Self::Writer { + ConsoleWriter { + buffer: vec![], + level: Level::TRACE, + } + } + + fn make_writer_for(&'a self, meta: &tracing::Metadata<'_>) -> Self::Writer { + let level = *meta.level(); + ConsoleWriter { + buffer: vec![], + level, + } + } +} + +pub struct ConsoleWriter { + buffer: Vec, + level: Level, +} + +impl Write for ConsoleWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.buffer.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + +impl Drop for ConsoleWriter { + fn drop(&mut self) { + let mut len = self.buffer.len(); + if len > 0 && self.buffer[len - 1] == b'\n' { + len -= 1; + } + unsafe { + _log( + AbiLogLevel::from(self.level) as u32, + self.buffer.as_ptr() as u32, + len as u32, + ) + } + } +} + +fn log_subscriber(level: AbiLogLevel) -> impl Subscriber + Send + Sync + 'static { + let level = level.into(); + let fmt_layer = tracing_subscriber::fmt::layer() + .with_ansi(false) + .without_time() + .with_thread_names(false) + .with_thread_ids(false) + .with_file(false) + .with_line_number(false) + .with_target(level == Level::TRACE) + .with_level(false) + .with_span_events(if level == Level::TRACE { + FmtSpan::ENTER + } else { + FmtSpan::NONE + }) + .with_writer(MakeAbiLogWriter) + .with_filter(LevelFilter::from_level(level)); + Registry::default().with(fmt_layer) +} + +// --------- VM + +pub struct WasmVM { + vm: CoreVM, +} + +pub struct WasmHeaders(Vec<(String, String)>); + +impl HeaderMap for WasmHeaders { + type Error = Infallible; + + fn extract(&self, name: &str) -> Result, Self::Error> { + for (key, value) in &self.0 { + if key.eq_ignore_ascii_case(name) { + return Ok(Some(value)); + } + } + Ok(None) + } +} + +#[export_name = "vm_new"] +pub unsafe extern "C" fn _vm_new(ptr: *mut u8, len: usize) -> u64 { + let input = ptr_to_input(ptr, len); + let response = vm_new(input); + output_to_ptr(response) +} + +fn vm_new(input: VmNewParameters) -> VmNewReturn { + match CoreVM::new(WasmHeaders(input.headers), VMOptions::default()) { + Ok(vm) => { + let wasm_vm = WasmVM { vm }; + VmNewReturn::Ok { + pointer: Rc::into_raw(Rc::new(RefCell::new(wasm_vm))) as u32, + } + } + Err(e) => VmNewReturn::from_err(e), + } +} + +#[export_name = "vm_get_response_head"] +pub unsafe extern "C" fn _vm_get_response_head(vm_pointer: *const RefCell) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let response: ResponseHeadReturn = VM::get_response_head(&rc_vm.borrow().vm).into(); + output_to_ptr(response) +} + +#[export_name = "vm_notify_input"] +pub unsafe extern "C" fn _vm_notify_input( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_vec(ptr, len); + VM::notify_input(&mut rc_vm.borrow_mut().vm, input.into()); +} + +#[export_name = "vm_notify_input_closed"] +pub unsafe extern "C" fn _vm_notify_input_closed(vm_pointer: *const RefCell) { + let rc_vm = vm_ptr_to_rc(vm_pointer); + VM::notify_input_closed(&mut rc_vm.borrow_mut().vm); +} + +#[export_name = "vm_notify_error"] +pub unsafe extern "C" fn _vm_notify_error( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + vm_notify_error(&rc_vm, input); +} + +fn vm_notify_error(rc_vm: &Rc>, input: VmNotifyError) { + let mut error = Error::new(500u16, Cow::Owned(input.message)); + if let Some(st) = input.stacktrace { + error = error.with_stacktrace(st); + } + if let Some(delay) = input.delay_override_millis { + error = error.with_next_retry_delay_override(Duration::from_millis(delay)); + } + VM::notify_error(&mut rc_vm.borrow_mut().vm, error, None) +} + +#[export_name = "vm_take_output"] +pub unsafe extern "C" fn _vm_take_output(vm_pointer: *const RefCell) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let res: Vec = match VM::take_output(&mut rc_vm.borrow_mut().vm) { + TakeOutputResult::Buffer(b) => b.to_vec(), + TakeOutputResult::EOF => Vec::default(), + }; + vec_to_ptr(res) +} + +#[export_name = "vm_is_ready_to_execute"] +pub unsafe extern "C" fn _vm_is_ready_to_execute(vm_pointer: *const RefCell) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let res = vm_is_ready_to_execute(&rc_vm); + output_to_ptr(res) +} + +fn vm_is_ready_to_execute(rc_vm: &Rc>) -> IsReadyReturn { + match VM::is_ready_to_execute(&rc_vm.borrow().vm) { + Ok(ready) => IsReadyReturn::Ok { ready }, + Err(e) => IsReadyReturn::from_err(e), + } +} + +#[export_name = "vm_is_completed"] +pub unsafe extern "C" fn _vm_is_completed(vm_pointer: *const RefCell, handle: u32) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let result = VM::is_completed(&rc_vm.borrow().vm, NotificationHandle::from(handle)); + result as u64 +} + +#[export_name = "vm_is_processing"] +pub unsafe extern "C" fn _vm_is_processing(vm_pointer: *const RefCell) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let result = VM::is_processing(&rc_vm.borrow().vm); + result as u64 +} + +#[export_name = "vm_do_progress"] +pub unsafe extern "C" fn _vm_do_progress( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_do_progress(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_do_progress(rc_vm: &Rc>, input: VmDoProgressParameters) -> DoProgressReturn { + match VM::do_await( + &mut rc_vm.borrow_mut().vm, + UnresolvedFuture::FirstCompleted( + input + .handles + .into_iter() + .map(NotificationHandle::from) + .map(UnresolvedFuture::Single) + .collect(), + ), + ) { + Ok(AwaitResponse::AnyCompleted) => DoProgressReturn::AnyCompleted, + Ok(AwaitResponse::WaitingExternalProgress { .. }) => { + DoProgressReturn::WaitingExternalProgress + } + Ok(AwaitResponse::CancelSignalReceived) => DoProgressReturn::CancelSignalReceived, + Ok(AwaitResponse::ExecuteRun(handle)) => DoProgressReturn::ExecuteRun { + handle: handle.into(), + }, + Err(e) if e.is_suspended_error() => DoProgressReturn::Suspended, + Err(e) => DoProgressReturn::from_err(e), + } +} + +#[export_name = "vm_take_notification"] +pub unsafe extern "C" fn _vm_take_notification( + vm_pointer: *const RefCell, + handle: u32, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let res = vm_take_notification(&rc_vm, NotificationHandle::from(handle)); + output_to_ptr(res) +} + +fn vm_take_notification( + rc_vm: &Rc>, + handle: NotificationHandle, +) -> TakeNotificationReturn { + match VM::take_notification(&mut rc_vm.borrow_mut().vm, handle) { + Ok(None) => TakeNotificationReturn::NotReady, + Ok(Some(v)) => TakeNotificationReturn::Value { + value: NotificationValue::from(v), + }, + Err(e) if e.is_suspended_error() => TakeNotificationReturn::Suspended, + Err(e) => TakeNotificationReturn::from_err(e), + } +} + +#[export_name = "vm_sys_input"] +pub unsafe extern "C" fn _vm_sys_input(vm_pointer: *const RefCell) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let res = vm_sys_input(&rc_vm); + output_to_ptr(res) +} + +fn vm_sys_input(rc_vm: &Rc>) -> SysInputReturn { + let mut vm = rc_vm.borrow_mut(); + let protocol_version = vm.vm.get_response_head().version; + match VM::sys_input(&mut vm.vm) { + Ok(input) => SysInputReturn::Ok { + input: WasmInput { + invocation_id: input.invocation_id, + key: input.key, + headers: input + .headers + .into_iter() + .map(|h| (h.key.into_owned(), h.value.into_owned())) + .collect(), + input: input.input.to_vec(), + random_seed: input.random_seed as i64, + should_use_random_seed: protocol_version >= Version::V6, + }, + }, + Err(e) => SysInputReturn::from_err(e), + } +} + +#[export_name = "vm_sys_state_get"] +pub unsafe extern "C" fn _vm_sys_state_get( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_state_get(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_state_get(rc_vm: &Rc>, input: VmSysStateGetParameters) -> HandleReturn { + VM::sys_state_get( + &mut rc_vm.borrow_mut().vm, + input.key, + PayloadOptions::default(), + ) + .into() +} + +#[export_name = "vm_sys_state_get_keys"] +pub unsafe extern "C" fn _vm_sys_state_get_keys(vm_pointer: *const RefCell) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let res = vm_sys_state_get_keys(&rc_vm); + output_to_ptr(res) +} + +fn vm_sys_state_get_keys(rc_vm: &Rc>) -> HandleReturn { + VM::sys_state_get_keys(&mut rc_vm.borrow_mut().vm).into() +} + +#[export_name = "vm_sys_state_set"] +pub unsafe extern "C" fn _vm_sys_state_set( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_state_set(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_state_set(rc_vm: &Rc>, input: VmSysStateSetParameters) -> EmptyReturn { + VM::sys_state_set( + &mut rc_vm.borrow_mut().vm, + input.key, + Bytes::from(input.value), + PayloadOptions::default(), + ) + .into() +} + +#[export_name = "vm_sys_state_clear"] +pub unsafe extern "C" fn _vm_sys_state_clear( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_state_clear(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_state_clear( + rc_vm: &Rc>, + input: VmSysStateClearParameters, +) -> EmptyReturn { + VM::sys_state_clear(&mut rc_vm.borrow_mut().vm, input.key).into() +} + +#[export_name = "vm_sys_state_clear_all"] +pub unsafe extern "C" fn _vm_sys_state_clear_all(vm_pointer: *const RefCell) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let res = vm_sys_state_clear_all(&rc_vm); + output_to_ptr(res) +} + +fn vm_sys_state_clear_all(rc_vm: &Rc>) -> EmptyReturn { + VM::sys_state_clear_all(&mut rc_vm.borrow_mut().vm).into() +} + +#[export_name = "vm_sys_sleep"] +pub unsafe extern "C" fn _vm_sys_sleep( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_sleep(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_sleep(rc_vm: &Rc>, input: VmSysSleepParameters) -> HandleReturn { + VM::sys_sleep( + &mut rc_vm.borrow_mut().vm, + input.name, + Duration::from_millis(input.wake_up_time_since_unix_epoch_millis), + Some(Duration::from_millis(input.now_since_unix_epoch_millis)), + ) + .into() +} + +#[export_name = "vm_sys_awakeable"] +pub unsafe extern "C" fn _vm_sys_awakeable(vm_pointer: *const RefCell) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let res = vm_sys_awakeable(&rc_vm); + output_to_ptr(res) +} + +fn vm_sys_awakeable(rc_vm: &Rc>) -> AwakeableReturn { + match VM::sys_awakeable(&mut rc_vm.borrow_mut().vm) { + Ok((awakeable_id, handle)) => AwakeableReturn::Ok { + id: awakeable_id, + handle: handle.into(), + }, + Err(e) => AwakeableReturn::from_err(e), + } +} + +#[export_name = "vm_sys_complete_awakeable"] +pub unsafe extern "C" fn _vm_sys_complete_awakeable( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_complete_awakeable(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_complete_awakeable( + rc_vm: &Rc>, + input: VmSysCompleteAwakeableParameters, +) -> EmptyReturn { + VM::sys_complete_awakeable( + &mut rc_vm.borrow_mut().vm, + input.id, + input.result.into(), + PayloadOptions::default(), + ) + .into() +} + +#[export_name = "vm_sys_call"] +pub unsafe extern "C" fn _vm_sys_call( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_call(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_call(rc_vm: &Rc>, input: VmSysCallParameters) -> SysCallReturn { + match VM::sys_call( + &mut rc_vm.borrow_mut().vm, + Target { + service: input.service, + handler: input.handler, + key: input.key, + idempotency_key: input.idempotency_key, + headers: input + .headers + .into_iter() + .map(|(k, v)| Header { + key: k.into(), + value: v.into(), + }) + .collect(), + }, + Bytes::from(input.input), + None, + PayloadOptions::default(), + ) { + Ok(call_handle) => SysCallReturn::Ok { + invocation_id_handle: call_handle.invocation_id_notification_handle.into(), + result_handle: call_handle.call_notification_handle.into(), + }, + Err(e) => SysCallReturn::from_err(e), + } +} + +#[export_name = "vm_sys_send"] +pub unsafe extern "C" fn _vm_sys_send( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_send(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_send(rc_vm: &Rc>, input: VmSysSendParameters) -> HandleReturn { + VM::sys_send( + &mut rc_vm.borrow_mut().vm, + Target { + service: input.service, + handler: input.handler, + key: input.key, + idempotency_key: input.idempotency_key, + headers: input + .headers + .into_iter() + .map(|(k, v)| Header { + key: k.into(), + value: v.into(), + }) + .collect(), + }, + Bytes::from(input.input), + input + .execution_time_since_unix_epoch_millis + .map(Duration::from_millis), + None, + PayloadOptions::default(), + ) + .map(|s| s.invocation_id_notification_handle) + .into() +} + +#[export_name = "vm_sys_cancel_invocation"] +pub unsafe extern "C" fn _vm_sys_cancel_invocation( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_cancel_invocation(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_cancel_invocation( + rc_vm: &Rc>, + input: VmSysCancelInvocation, +) -> EmptyReturn { + VM::sys_cancel_invocation(&mut rc_vm.borrow_mut().vm, input.invocation_id).into() +} + +#[export_name = "vm_sys_attach_invocation"] +pub unsafe extern "C" fn _vm_sys_attach_invocation( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_attach_invocation(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_attach_invocation( + rc_vm: &Rc>, + input: VmSysAttachInvocation, +) -> HandleReturn { + VM::sys_attach_invocation( + &mut rc_vm.borrow_mut().vm, + AttachInvocationTarget::InvocationId(input.invocation_id), + ) + .into() +} + +#[export_name = "vm_sys_get_invocation_output"] +pub unsafe extern "C" fn _vm_sys_get_invocation_output( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_get_invocation_output(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_get_invocation_output( + rc_vm: &Rc>, + input: VmSysGetInvocationOutput, +) -> HandleReturn { + VM::sys_get_invocation_output( + &mut rc_vm.borrow_mut().vm, + AttachInvocationTarget::InvocationId(input.invocation_id), + ) + .into() +} + +#[export_name = "vm_sys_promise_get"] +pub unsafe extern "C" fn _vm_sys_promise_get( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_promise_get(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_promise_get( + rc_vm: &Rc>, + input: VmSysPromiseGetParameters, +) -> HandleReturn { + VM::sys_get_promise(&mut rc_vm.borrow_mut().vm, input.key).into() +} + +#[export_name = "vm_sys_promise_peek"] +pub unsafe extern "C" fn _vm_sys_promise_peek( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_promise_peek(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_promise_peek( + rc_vm: &Rc>, + input: VmSysPromisePeekParameters, +) -> HandleReturn { + VM::sys_peek_promise(&mut rc_vm.borrow_mut().vm, input.key).into() +} + +#[export_name = "vm_sys_promise_complete"] +pub unsafe extern "C" fn _vm_sys_promise_complete( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_promise_complete(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_promise_complete( + rc_vm: &Rc>, + input: VmSysPromiseCompleteParameters, +) -> HandleReturn { + VM::sys_complete_promise( + &mut rc_vm.borrow_mut().vm, + input.id, + input.result.into(), + PayloadOptions::default(), + ) + .into() +} + +#[export_name = "vm_sys_run"] +pub unsafe extern "C" fn _vm_sys_run( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_run(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_run(rc_vm: &Rc>, input: VmSysRunParameters) -> HandleReturn { + VM::sys_run(&mut rc_vm.borrow_mut().vm, input.name).into() +} + +#[export_name = "vm_propose_run_completion"] +pub unsafe extern "C" fn _vm_propose_run_completion( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_propose_run_completion(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_propose_run_completion( + rc_vm: &Rc>, + input: VmProposeRunCompletionParameters, +) -> EmptyReturn { + let run_exit_result = match input.result { + RunResult::Success { value } => RunExitResult::Success(Bytes::from(value)), + RunResult::TerminalFailure { code, message, metadata } => { + RunExitResult::TerminalFailure(TerminalFailure { + code: code as u16, + message, + metadata: metadata.unwrap_or_default(), + }) + } + RunResult::RetryableFailure { code, message, stacktrace } => { + let mut error = Error::new(code as u16, message); + if let Some(st) = stacktrace { + error = error.with_stacktrace(st); + } + RunExitResult::RetryableFailure { + attempt_duration: Duration::from_millis(input.attempt_duration_millis), + error, + } + } + }; + + let retry_policy = match input.retry_policy { + None => RetryPolicy::default(), + Some(rp) => RetryPolicy::Exponential { + initial_interval: Duration::from_millis(rp.initial_interval_millis), + factor: rp.factor, + max_interval: rp.max_interval_millis.map(Duration::from_millis), + max_attempts: rp.max_attempts, + max_duration: rp.max_duration_millis.map(Duration::from_millis), + }, + }; + + VM::propose_run_completion( + &mut rc_vm.borrow_mut().vm, + input.handle.into(), + run_exit_result, + retry_policy, + ) + .into() +} + +// Java SDK-specific: signals (not in Go SDK yet) + +#[export_name = "vm_sys_create_signal_handle"] +pub unsafe extern "C" fn _vm_sys_create_signal_handle( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_create_signal_handle(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_create_signal_handle( + rc_vm: &Rc>, + input: VmSysCreateSignalHandleParameters, +) -> HandleReturn { + VM::create_signal_handle(&mut rc_vm.borrow_mut().vm, input.name).into() +} + +#[export_name = "vm_sys_complete_signal"] +pub unsafe extern "C" fn _vm_sys_complete_signal( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_complete_signal(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_complete_signal( + rc_vm: &Rc>, + input: VmSysCompleteSignalParameters, +) -> EmptyReturn { + VM::sys_complete_signal( + &mut rc_vm.borrow_mut().vm, + input.target, + input.name, + input.result.into(), + ) + .into() +} + +#[export_name = "vm_sys_write_output"] +pub unsafe extern "C" fn _vm_sys_write_output( + vm_pointer: *const RefCell, + ptr: *mut u8, + len: usize, +) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let input = ptr_to_input(ptr, len); + let res = vm_sys_write_output(&rc_vm, input); + output_to_ptr(res) +} + +fn vm_sys_write_output( + rc_vm: &Rc>, + input: VmSysWriteOutputParameters, +) -> EmptyReturn { + VM::sys_write_output( + &mut rc_vm.borrow_mut().vm, + input.result.into(), + PayloadOptions::default(), + ) + .into() +} + +#[export_name = "vm_sys_end"] +pub unsafe extern "C" fn _vm_sys_end(vm_pointer: *const RefCell) -> u64 { + let rc_vm = vm_ptr_to_rc(vm_pointer); + let res = vm_sys_end(&rc_vm); + output_to_ptr(res) +} + +fn vm_sys_end(rc_vm: &Rc>) -> EmptyReturn { + VM::sys_end(&mut rc_vm.borrow_mut().vm).into() +} + +#[export_name = "vm_free"] +pub unsafe extern "C" fn _vm_free(vm: *const RefCell) { + assert_not_null(vm); + // We don't need to increment the counter, we're materializing the initial leak! + let rc = Rc::from_raw(vm); + match Rc::try_unwrap(rc) { + Ok(cell) => drop(cell.into_inner()), + Err(_) => panic!("attempted to free vm while still borrowed"), + } +} + +// --------- Logging infra + +fn log(level: AbiLogLevel, message: &str) { + unsafe { + let (ptr, len) = string_to_ptr(message); + _log(level as u32, ptr, len); + } +} + +#[link(wasm_import_module = "env")] +extern "C" { + #[link_name = "log"] + fn _log(level: u32, ptr: u32, size: u32); +} + +// --------- Unsafe memory helpers + +#[inline] +pub fn assert_not_null(s: *const T) { + if s.is_null() { + panic!("Null pointer exception on input") + } +} + +#[inline] +unsafe fn ptr_to_vec(ptr: *mut u8, len: usize) -> Vec { + assert_not_null(ptr); + Vec::from_raw_parts(ptr, len, len) +} + +/// Deserializes CBOR from caller-allocated memory (ownership transferred to us — memory freed). +#[inline] +unsafe fn ptr_to_input(ptr: *mut u8, len: usize) -> T { + let vec = ptr_to_vec(ptr, len); + ciborium::from_reader(vec.as_slice()).expect("CBOR deserialization of input should not fail") +} + +#[inline] +unsafe fn vec_to_ptr(v: Vec) -> u64 { + let len = v.len(); + let ptr = Box::into_raw(v.into_boxed_slice()) as *mut u8; + ((ptr as u64) << 32) | len as u64 +} + +/// Serializes `t` to CBOR and returns packed `(ptr << 32) | len`. +#[inline] +fn output_to_ptr(t: T) -> u64 { + let mut buf = Vec::new(); + ciborium::into_writer(&t, &mut buf).expect("CBOR serialization of output should not fail"); + unsafe { vec_to_ptr(buf) } +} + +unsafe fn vm_ptr_to_rc(vm_pointer: *const RefCell) -> Rc> { + assert_not_null(vm_pointer); + Rc::increment_strong_count(vm_pointer); + Rc::from_raw(vm_pointer) +} + +unsafe fn string_to_ptr(s: &str) -> (u32, u32) { + (s.as_ptr() as u32, s.len() as u32) +} + +#[export_name = "allocate"] +pub unsafe extern "C" fn _allocate(size: usize) -> *mut u8 { + allocate(size) +} + +fn allocate(size: usize) -> *mut u8 { + let vec: Vec> = vec![MaybeUninit::uninit(); size]; + Box::into_raw(vec.into_boxed_slice()) as *mut u8 +} + +#[export_name = "deallocate"] +pub unsafe extern "C" fn _deallocate(ptr: *mut u8, size: usize) { + deallocate(ptr, size); +} + +unsafe fn deallocate(ptr: *mut u8, size: usize) { + let _: Vec = Vec::from_raw_parts(ptr, 0, size); +} + +// --------- Input DTOs (Java → Rust, CBOR maps with camelCase keys) + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmNewParameters { + headers: Vec<(String, String)>, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmNotifyError { + message: String, + #[serde(default)] + stacktrace: Option, + #[serde(default)] + delay_override_millis: Option, +} + +/// Flat list of handles — Rust converts to FirstCompleted(handles.map(Single)) internally. +#[derive(Deserialize)] +struct VmDoProgressParameters { + handles: Vec, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmSysStateGetParameters { + key: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmSysStateSetParameters { + key: String, + #[serde(with = "serde_bytes")] + value: Vec, +} + +#[derive(Deserialize)] +struct VmSysStateClearParameters { + key: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmSysSleepParameters { + name: String, + wake_up_time_since_unix_epoch_millis: u64, + now_since_unix_epoch_millis: u64, +} + +/// Combined success/failure for awakeable completion (mirrors Go's single endpoint). +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmSysCompleteAwakeableParameters { + id: String, + result: NonEmptyValueParam, +} + +/// Combined success/failure union used by awakeable, promise, write_output, signal. +#[derive(Deserialize)] +#[serde(tag = "type", rename_all = "camelCase", rename_all_fields = "camelCase")] +enum NonEmptyValueParam { + Success { + #[serde(with = "serde_bytes")] + value: Vec, + }, + Failure { + code: u32, + message: String, + #[serde(default)] + metadata: Option>, + }, +} + +impl From for NonEmptyValue { + fn from(p: NonEmptyValueParam) -> Self { + match p { + NonEmptyValueParam::Success { value } => NonEmptyValue::Success(Bytes::from(value)), + NonEmptyValueParam::Failure { code, message, metadata } => { + NonEmptyValue::Failure(TerminalFailure { + code: code as u16, + message, + metadata: metadata.unwrap_or_default(), + }) + } + } + } +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmSysCallParameters { + service: String, + handler: String, + key: Option, + idempotency_key: Option, + headers: Vec<(String, String)>, + #[serde(with = "serde_bytes")] + input: Vec, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmSysSendParameters { + service: String, + handler: String, + key: Option, + idempotency_key: Option, + headers: Vec<(String, String)>, + #[serde(with = "serde_bytes")] + input: Vec, + execution_time_since_unix_epoch_millis: Option, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmSysCancelInvocation { + invocation_id: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmSysAttachInvocation { + invocation_id: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmSysGetInvocationOutput { + invocation_id: String, +} + +#[derive(Deserialize)] +struct VmSysPromiseGetParameters { + key: String, +} + +#[derive(Deserialize)] +struct VmSysPromisePeekParameters { + key: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmSysPromiseCompleteParameters { + id: String, + result: NonEmptyValueParam, +} + +#[derive(Deserialize)] +struct VmSysRunParameters { + name: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmProposeRunCompletionParameters { + handle: u32, + result: RunResult, + attempt_duration_millis: u64, + retry_policy: Option, +} + +#[derive(Deserialize)] +#[serde(tag = "type", rename_all = "camelCase", rename_all_fields = "camelCase")] +enum RunResult { + Success { + #[serde(with = "serde_bytes")] + value: Vec, + }, + TerminalFailure { + code: u32, + message: String, + #[serde(default)] + metadata: Option>, + }, + RetryableFailure { + code: u32, + message: String, + #[serde(default)] + stacktrace: Option, + }, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct WasmRetryPolicy { + initial_interval_millis: u64, + factor: f32, + max_interval_millis: Option, + max_attempts: Option, + max_duration_millis: Option, +} + +#[derive(Deserialize)] +struct VmSysCreateSignalHandleParameters { + name: String, +} + +#[derive(Deserialize)] +struct VmSysCompleteSignalParameters { + target: String, + name: String, + result: NonEmptyValueParam, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct VmSysWriteOutputParameters { + result: NonEmptyValueParam, +} + +// --------- Output DTOs (Rust → Java) +// Each return type has a `from_err` helper matching Go's `.into()` from Failure conversions. + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "camelCase", rename_all_fields = "camelCase")] +enum VmNewReturn { + Ok { pointer: u32 }, + Failure { code: u32, message: String }, +} +impl VmNewReturn { + fn from_err(e: Error) -> Self { + Self::Failure { + code: e.code() as u32, + message: e.to_string(), + } + } +} + +/// Equivalent to Go's GenericEmptyReturn. +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "camelCase", rename_all_fields = "camelCase")] +enum EmptyReturn { + Ok, + Failure { code: u32, message: String }, +} + +/// Equivalent to Go's SimpleSysAsyncResultReturn. +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "camelCase", rename_all_fields = "camelCase")] +enum HandleReturn { + Ok { handle: u32 }, + Failure { code: u32, message: String }, +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "camelCase", rename_all_fields = "camelCase")] +enum IsReadyReturn { + Ok { ready: bool }, + Failure { code: u32, message: String }, +} +impl IsReadyReturn { + fn from_err(e: Error) -> Self { + Self::Failure { + code: e.code() as u32, + message: e.to_string(), + } + } +} + +/// Plain struct (no Ok/Failure wrapper) — mirrors Go's VmGetResponseHeadReturn. +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct ResponseHeadReturn { + status_code: u32, + headers: Vec<(String, String)>, +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "camelCase", rename_all_fields = "camelCase")] +enum DoProgressReturn { + AnyCompleted, + WaitingExternalProgress, + ExecuteRun { handle: u32 }, + CancelSignalReceived, + Suspended, + Failure { code: u32, message: String }, +} +impl DoProgressReturn { + fn from_err(e: Error) -> Self { + Self::Failure { + code: e.code() as u32, + message: e.to_string(), + } + } +} + +/// Notification value (the payload of a completed handle). +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "camelCase", rename_all_fields = "camelCase")] +enum NotificationValue { + Void, + Success { + #[serde(with = "serde_bytes")] + value: Vec, + }, + Failure { + code: u16, + message: String, + metadata: Vec<(String, String)>, + }, + StateKeys { + keys: Vec, + }, + InvocationId { + id: String, + }, +} + +impl From for NotificationValue { + fn from(v: Value) -> Self { + match v { + Value::Void => NotificationValue::Void, + Value::Success(b) => NotificationValue::Success { value: b.to_vec() }, + Value::Failure(TerminalFailure { code, message, metadata }) => { + NotificationValue::Failure { code, message, metadata } + } + Value::StateKeys(keys) => NotificationValue::StateKeys { keys }, + Value::InvocationId(id) => NotificationValue::InvocationId { id }, + } + } +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "camelCase", rename_all_fields = "camelCase")] +enum TakeNotificationReturn { + NotReady, + Value { value: NotificationValue }, + Suspended, + Failure { code: u32, message: String }, +} +impl TakeNotificationReturn { + fn from_err(e: Error) -> Self { + Self::Failure { + code: e.code() as u32, + message: e.to_string(), + } + } +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct WasmInput { + invocation_id: String, + key: String, + headers: Vec<(String, String)>, + #[serde(with = "serde_bytes")] + input: Vec, + random_seed: i64, + should_use_random_seed: bool, +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "camelCase", rename_all_fields = "camelCase")] +enum SysInputReturn { + Ok { input: WasmInput }, + Failure { code: u32, message: String }, +} +impl SysInputReturn { + fn from_err(e: Error) -> Self { + Self::Failure { + code: e.code() as u32, + message: e.to_string(), + } + } +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "camelCase", rename_all_fields = "camelCase")] +enum AwakeableReturn { + Ok { id: String, handle: u32 }, + Failure { code: u32, message: String }, +} +impl AwakeableReturn { + fn from_err(e: Error) -> Self { + Self::Failure { + code: e.code() as u32, + message: e.to_string(), + } + } +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "camelCase", rename_all_fields = "camelCase")] +enum SysCallReturn { + Ok { + invocation_id_handle: u32, + result_handle: u32, + }, + Failure { + code: u32, + message: String, + }, +} +impl SysCallReturn { + fn from_err(e: Error) -> Self { + Self::Failure { + code: e.code() as u32, + message: e.to_string(), + } + } +} + +// --------- From impls (enable `.into()` in inner functions, like Go's `into()` on pb types) + +impl From> for EmptyReturn { + fn from(value: Result<(), Error>) -> Self { + match value { + Ok(()) => EmptyReturn::Ok, + Err(e) => EmptyReturn::Failure { + code: e.code() as u32, + message: e.to_string(), + }, + } + } +} + +impl From> for HandleReturn { + fn from(value: Result) -> Self { + match value { + Ok(h) => HandleReturn::Ok { handle: h.into() }, + Err(e) => HandleReturn::Failure { + code: e.code() as u32, + message: e.to_string(), + }, + } + } +} + +impl From for ResponseHeadReturn { + fn from(head: ResponseHead) -> Self { + ResponseHeadReturn { + status_code: head.status_code as u32, + headers: head + .headers + .into_iter() + .map(|h| (h.key.into_owned(), h.value.into_owned())) + .collect(), + } + } +} diff --git a/sdk-core/src/main/service-protocol/.github/workflows/lint.yaml b/sdk-core/src/main/service-protocol/.github/workflows/lint.yaml deleted file mode 100644 index 8b2d8e5fc..000000000 --- a/sdk-core/src/main/service-protocol/.github/workflows/lint.yaml +++ /dev/null @@ -1,20 +0,0 @@ -name: Lint Code Base - -on: - push: - pull_request: - branches: [main] - -jobs: - build: - name: Lint - runs-on: ubuntu-latest - - steps: - - name: Checkout Code - uses: actions/checkout@v3 - - - name: Run protolint - uses: plexsystems/protolint-action@v0.7.0 - with: - configDirectory: . \ No newline at end of file diff --git a/sdk-core/src/main/service-protocol/.gitignore b/sdk-core/src/main/service-protocol/.gitignore deleted file mode 100644 index 29b636a48..000000000 --- a/sdk-core/src/main/service-protocol/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -.idea -*.iml \ No newline at end of file diff --git a/sdk-core/src/main/service-protocol/.prettierrc.toml b/sdk-core/src/main/service-protocol/.prettierrc.toml deleted file mode 100644 index 1191103fe..000000000 --- a/sdk-core/src/main/service-protocol/.prettierrc.toml +++ /dev/null @@ -1,3 +0,0 @@ -embeddedLanguageFormatting = "off" -proseWrap = "always" -printWidth = 120 \ No newline at end of file diff --git a/sdk-core/src/main/service-protocol/.protolint.yaml b/sdk-core/src/main/service-protocol/.protolint.yaml deleted file mode 100644 index bfe300a60..000000000 --- a/sdk-core/src/main/service-protocol/.protolint.yaml +++ /dev/null @@ -1,12 +0,0 @@ -lint: - rules: - remove: - - ENUM_FIELD_NAMES_PREFIX - - rules_option: - max_line_length: - max_chars: 180 - tab_chars: 2 - - indent: - style: 2 \ No newline at end of file diff --git a/sdk-core/src/main/service-protocol/LICENSE b/sdk-core/src/main/service-protocol/LICENSE deleted file mode 100644 index b81eecf56..000000000 --- a/sdk-core/src/main/service-protocol/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2023 - Restate Software, Inc., Restate GmbH - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE \ No newline at end of file diff --git a/sdk-core/src/main/service-protocol/README.md b/sdk-core/src/main/service-protocol/README.md deleted file mode 100644 index 4a0ca91ff..000000000 --- a/sdk-core/src/main/service-protocol/README.md +++ /dev/null @@ -1,13 +0,0 @@ -# Restate Service Protocol - -This repo contains specification documents and Protobuf schemas of the Restate Service Protocol. - -* [Service invocation protocol specification](./service-invocation-protocol.md) - -## Development - -To format the spec document: - -```shell -npx prettier -w service-invocation-protocol.md -``` \ No newline at end of file diff --git a/sdk-core/src/main/service-protocol/buf.lock b/sdk-core/src/main/service-protocol/buf.lock deleted file mode 100644 index 4f98143f5..000000000 --- a/sdk-core/src/main/service-protocol/buf.lock +++ /dev/null @@ -1,2 +0,0 @@ -# Generated by buf. DO NOT EDIT. -version: v2 diff --git a/sdk-core/src/main/service-protocol/buf.yaml b/sdk-core/src/main/service-protocol/buf.yaml deleted file mode 100644 index ab3bd5be4..000000000 --- a/sdk-core/src/main/service-protocol/buf.yaml +++ /dev/null @@ -1,8 +0,0 @@ -version: v2 -name: buf.build/restatedev/service-protocol -lint: - use: - - DEFAULT -breaking: - use: - - FILE diff --git a/sdk-core/src/main/service-protocol/dev/restate/service/discovery.proto b/sdk-core/src/main/service-protocol/dev/restate/service/discovery.proto deleted file mode 100644 index 519dc4e54..000000000 --- a/sdk-core/src/main/service-protocol/dev/restate/service/discovery.proto +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2024 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate service protocol, which is -// released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/service-protocol/blob/main/LICENSE - -syntax = "proto3"; - -package dev.restate.service.discovery; - -option java_package = "dev.restate.sdk.core.generated.discovery"; - -// Service discovery protocol version. -enum ServiceDiscoveryProtocolVersion { - SERVICE_DISCOVERY_PROTOCOL_VERSION_UNSPECIFIED = 0; - // initial service discovery protocol version using endpoint_manifest_schema.json - V1 = 1; - // add custom metadata and documentation for services/handlers - V2 = 2; - // add options for private service, journal retention, idempotency retention, workflow completion retention, inactivity timeout, abort timeout, enable lazy state - V3 = 3; - // add lambda compression - V4 = 4; -} diff --git a/sdk-core/src/main/service-protocol/dev/restate/service/protocol.proto b/sdk-core/src/main/service-protocol/dev/restate/service/protocol.proto deleted file mode 100644 index 0a6696533..000000000 --- a/sdk-core/src/main/service-protocol/dev/restate/service/protocol.proto +++ /dev/null @@ -1,671 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate service protocol, which is -// released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/service-protocol/blob/main/LICENSE - -syntax = "proto3"; - -package dev.restate.service.protocol; - -option java_package = "dev.restate.sdk.core.generated.protocol"; - -// Service protocol version. -enum ServiceProtocolVersion { - SERVICE_PROTOCOL_VERSION_UNSPECIFIED = 0; - // initial service protocol version - V1 = 1; - // Added - // * Entry retry mechanism: ErrorMessage.next_retry_delay, StartMessage.retry_count_since_last_stored_entry and StartMessage.duration_since_last_stored_entry - V2 = 2; - // **Yanked** - V3 = 3; - // **Yanked** - V4 = 4; - // Immutable journal. Added: - // * New command to cancel invocations - // * Both Call and Send commands now return an additional notification to return the invocation id - // * New field to set idempotency key for Call/Send commands - // * New command to attach to existing invocation - // * New command to get output of existing invocation - V5 = 5; - // Added: - // * StartMessage.random_seed - // * Failure.metadata - V6 = 6; -} - -// --- Core frames --- - -// Type: 0x0000 + 0 -message StartMessage { - message StateEntry { - bytes key = 1; - // If value is an empty byte array, - // then it means the value is empty and not "missing" (e.g. empty string). - bytes value = 2; - } - - // Unique id of the invocation. This id is unique across invocations and won't change when replaying the journal. - bytes id = 1; - - // Invocation id that can be used for logging. - // The user can use this id to address this invocation in admin and status introspection apis. - string debug_id = 2; - - // This is the sum of known commands + notifications - uint32 known_entries = 3; - - // protolint:disable:next REPEATED_FIELD_NAMES_PLURALIZED - repeated StateEntry state_map = 4; - bool partial_state = 5; - - // If this invocation has a key associated (e.g. for objects and workflows), then this key is filled in. Empty otherwise. - string key = 6; - - // Retry count since the last stored entry. - // - // Please note that this count might not be accurate, as it's not durably stored, - // thus it might get reset in case Restate crashes/changes leader. - uint32 retry_count_since_last_stored_entry = 7; - - // Duration since the last stored entry, in milliseconds. - // - // Please note this duration might not be accurate, - // and might change depending on which Restate replica executes the request. - uint64 duration_since_last_stored_entry = 8; - - // Random seed to use to seed the deterministic RNG exposed in the context API. - // This will be stable across restarts. - uint64 random_seed = 9; -} - -// Type: 0x0000 + 1 -// Implementations MUST send this message when suspending an invocation. -// -// These lists represent any of the notification_idx and/or notification_name the invocation is waiting on to progress. -// The runtime will resume the invocation as soon as either one of the given notification_idx or notification_name is completed. -// Between the two lists there MUST be at least one element. -message SuspensionMessage { - repeated uint32 waiting_completions = 1; - repeated uint32 waiting_signals = 2; - repeated string waiting_named_signals = 3; -} - -// Type: 0x0000 + 2 -message ErrorMessage { - // The code can be any HTTP status code, as described https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml. - // In addition, we define the following error codes that MAY be used by the SDK for better error reporting: - // * JOURNAL_MISMATCH = 570, that is when the SDK cannot replay a journal due to the mismatch between the journal and the actual code. - // * PROTOCOL_VIOLATION = 571, that is when the SDK receives an unexpected message or an expected message variant, given its state. - uint32 code = 1; - // Contains a concise error message, e.g. Throwable#getMessage() in Java. - string message = 2; - // The exception stacktrace, if available. - string stacktrace = 3; - - // Command that caused the failure. This may be outside the current stored journal size. - // If no specific entry caused the failure, the current replayed/processed entry can be used. - optional uint32 related_command_index = 4; - // Name of the entry that caused the failure. - optional string related_command_name = 5; - // Command type. - optional uint32 related_command_type = 6; - - // Delay before executing the next retry, specified as duration in milliseconds. - // If provided, it will override the default retry policy used by Restate's invoker ONLY for the next retry attempt. - optional uint64 next_retry_delay = 8; -} - -// Type: 0x0000 + 3 -// Implementations MUST send this message when the invocation lifecycle ends. -message EndMessage { -} - -// Type: 0x0000 + 4 -message CommandAckMessage { - uint32 command_index = 1; -} - -// This is a special control message to propose ctx.run completions to the runtime. -// This won't be written to the journal immediately, but will appear later as a new notification (meaning the result was stored). -// -// Type: 0x0000 + 5 -message ProposeRunCompletionMessage { - uint32 result_completion_id = 1; - oneof result { - bytes value = 14; - Failure failure = 15; - }; -} - -// --- Commands and Notifications --- - -// The Journal is modelled as commands and notifications. -// Commands define the operations executed, while notifications can be: -// * Completions to commands -// * Unnamed signals -// * Named signals -// -// An individual command can produce 0 or more completions, where the respective completion id(s) are defined in the command message. - -// A notification message follows the following duck-type: -// -message NotificationTemplate { - reserved 12; - - oneof id { - uint32 completion_id = 1; - uint32 signal_id = 2; - string signal_name = 3; - } - - oneof result { - Void void = 4; - Value value = 5; - Failure failure = 6; - - // Used by specific commands - string invocation_id = 16; - StateKeys state_keys = 17; - }; -} - -// ------ Input and output ------ - -// Completable: No -// Fallible: No -// Type: 0x0400 + 0 -message InputCommandMessage { - repeated Header headers = 1; - - Value value = 14; - - // Entry name - string name = 12; -} - -// Completable: No -// Fallible: No -// Type: 0x0400 + 1 -message OutputCommandMessage { - oneof result { - Value value = 14; - Failure failure = 15; - }; - - // Entry name - string name = 12; -} - -// ------ State access ------ - -// Completable: Yes -// Fallible: No -// Type: 0x0400 + 2 -message GetLazyStateCommandMessage { - bytes key = 1; - - uint32 result_completion_id = 11; - string name = 12; -} - -// Notification for GetLazyStateCommandMessage -// Type: 0x8000 + 2 -message GetLazyStateCompletionNotificationMessage { - // See NotificationMessage above - reserved 2, 3, 6, 7, 8, 12; - - uint32 completion_id = 1; - - oneof result { - Void void = 4; - Value value = 5; - }; -} - -// Completable: No -// Fallible: No -// Type: 0x0400 + 3 -message SetStateCommandMessage { - bytes key = 1; - Value value = 3; - - // Entry name - string name = 12; -} - -// Completable: No -// Fallible: No -// Type: 0x0400 + 4 -message ClearStateCommandMessage { - bytes key = 1; - - // Entry name - string name = 12; -} - -// Completable: No -// Fallible: No -// Type: 0x0400 + 5 -message ClearAllStateCommandMessage { - // Entry name - string name = 12; -} - -// Completable: Yes -// Fallible: No -// Type: 0x0400 + 6 -message GetLazyStateKeysCommandMessage { - uint32 result_completion_id = 11; - string name = 12; -} - -// Notification for GetLazyStateKeysCommandMessage -// Type: 0x8000 + 6 -message GetLazyStateKeysCompletionNotificationMessage { - // See NotificationMessage above - reserved 2 to 8, 12, 16; - - uint32 completion_id = 1; - StateKeys state_keys = 17; -} - -// Completable: No -// Fallible: No -// Type: 0x0400 + 7 -message GetEagerStateCommandMessage { - bytes key = 1; - - oneof result { - Void void = 13; - Value value = 14; - }; - - // Entry name - string name = 12; -} - -// Completable: No -// Fallible: No -// Type: 0x0400 + 8 -message GetEagerStateKeysCommandMessage { - StateKeys value = 14; - - // Entry name - string name = 12; -} - -// Completable: Yes -// Fallible: No -// Type: 0x0400 + 9 -message GetPromiseCommandMessage { - string key = 1; - - uint32 result_completion_id = 11; - string name = 12; -} - -// Notification for GetPromiseCommandMessage -// Type: 0x8000 + 9 -message GetPromiseCompletionNotificationMessage { - // See NotificationMessage above - reserved 2, 3, 4, 12, 16, 17; - - uint32 completion_id = 1; - - oneof result { - Value value = 5; - Failure failure = 6; - }; -} - -// Completable: Yes -// Fallible: No -// Type: 0x0400 + A -message PeekPromiseCommandMessage { - string key = 1; - - uint32 result_completion_id = 11; - string name = 12; -} - -// Notification for PeekPromiseCommandMessage -// Type: 0x8000 + A -message PeekPromiseCompletionNotificationMessage { - // See NotificationMessage above - reserved 2, 3, 12, 16, 17; - - uint32 completion_id = 1; - - oneof result { - Void void = 4; - Value value = 5; - Failure failure = 6; - }; -} - -// Completable: Yes -// Fallible: No -// Type: 0x0400 + B -message CompletePromiseCommandMessage { - string key = 1; - - // The value to use to complete the promise - oneof completion { - Value completion_value = 2; - Failure completion_failure = 3; - }; - - uint32 result_completion_id = 11; - string name = 12; -} - -// Notification for CompletePromiseCommandMessage -// Type: 0x8000 + B -message CompletePromiseCompletionNotificationMessage { - // See NotificationMessage above - reserved 2, 3, 5, 7, 8, 12, 16, 17; - - uint32 completion_id = 1; - - oneof result { - Void void = 4; - Failure failure = 6; - }; -} - -// ------ Syscalls ------ - -// Completable: Yes -// Fallible: No -// Type: 0x0400 + C -message SleepCommandMessage { - // Wake up time. - // The time is set as duration since UNIX Epoch. - uint64 wake_up_time = 1; - - uint32 result_completion_id = 11; - string name = 12; -} - -// Notification for SleepCommandMessage -// Type: 0x8000 + C -message SleepCompletionNotificationMessage { - // See NotificationMessage above - reserved 2, 3, 5, 6, 7, 8, 12, 16, 17; - - uint32 completion_id = 1; - Void void = 4; -} - -// Completable: Yes (two notifications: one with invocation id, then one with the actual result) -// Fallible: Yes -// Type: 0x0400 + D -message CallCommandMessage { - string service_name = 1; - string handler_name = 2; - - bytes parameter = 3; - - repeated Header headers = 4; - - // If this invocation has a key associated (e.g. for objects and workflows), then this key is filled in. Empty otherwise. - string key = 5; - - // If present, it must be non empty. - optional string idempotency_key = 6; - - uint32 invocation_id_notification_idx = 10; - uint32 result_completion_id = 11; - string name = 12; -} - -// Notification for CallCommandMessage and OneWayCallCommandMessage -// Type: 0x8000 + E -message CallInvocationIdCompletionNotificationMessage { - // See NotificationMessage above - reserved 2, 3, 4, 5, 6, 7, 8, 12, 17; - - uint32 completion_id = 1; - string invocation_id = 16; -} - -// Notification for CallCommandMessage -// Type: 0x8000 + D -message CallCompletionNotificationMessage { - // See NotificationMessage above - reserved 2, 3, 4, 12, 16, 17; - - uint32 completion_id = 1; - - oneof result { - Value value = 5; - Failure failure = 6; - }; -} - -// Completable: Yes (only one notification with invocation id) -// Fallible: Yes -// Type: 0x0400 + E -message OneWayCallCommandMessage { - string service_name = 1; - string handler_name = 2; - - bytes parameter = 3; - - // Time when this BackgroundInvoke should be executed. - // The time is set as duration since UNIX Epoch. - // If this value is not set, equal to 0, or past in time, - // the runtime will execute this BackgroundInvoke as soon as possible. - uint64 invoke_time = 4; - - repeated Header headers = 5; - - // If this invocation has a key associated (e.g. for objects and workflows), then this key is filled in. Empty otherwise. - string key = 6; - - // If present, it must be non empty. - optional string idempotency_key = 7; - - uint32 invocation_id_notification_idx = 10; - string name = 12; -} - -// Completable: No -// Fallible: Yes -// Type: 0x04000 + 10 -message SendSignalCommandMessage { - string target_invocation_id = 1; - - oneof signal_id { - uint32 idx = 2; - string name = 3; - } - - oneof result { - Void void = 4; - Value value = 5; - Failure failure = 6; - }; - - // Cannot use the field 'name' here because used above - string entry_name = 12; -} - -// Proposals for Run completions are sent through ProposeRunCompletionMessage -// -// Completable: Yes -// Fallible: No -// Type: 0x0400 + 11 -message RunCommandMessage { - uint32 result_completion_id = 11; - string name = 12; -} - -// Notification for RunCommandMessage -// Type: 0x8000 + 11 -message RunCompletionNotificationMessage { - // See NotificationMessage above - reserved 2, 3, 4, 12, 16, 17; - - uint32 completion_id = 1; - - oneof result { - Value value = 5; - Failure failure = 6; - }; -} - -// Completable: Yes -// Fallible: Yes -// Type: 0x0400 + 12 -message AttachInvocationCommandMessage { - oneof target { - // Target invocation id - string invocation_id = 1; - // Target idempotent request - IdempotentRequestTarget idempotent_request_target = 3; - // Target workflow target - WorkflowTarget workflow_target = 4; - } - - uint32 result_completion_id = 11; - string name = 12; -} - -// Notification for AttachInvocationCommandMessage -// Type: 0x8000 + 12 -message AttachInvocationCompletionNotificationMessage { - // See NotificationMessage above - reserved 2, 3, 4, 12, 16, 17; - - uint32 completion_id = 1; - - oneof result { - Value value = 5; - Failure failure = 6; - }; -} - -// Completable: Yes -// Fallible: Yes -// Type: 0x0400 + 13 -message GetInvocationOutputCommandMessage { - oneof target { - // Target invocation id - string invocation_id = 1; - // Target idempotent request - IdempotentRequestTarget idempotent_request_target = 3; - // Target workflow target - WorkflowTarget workflow_target = 4; - } - - uint32 result_completion_id = 11; - string name = 12; -} - -// Notification for GetInvocationOutputCommandMessage -// Type: 0x8000 + 13 -message GetInvocationOutputCompletionNotificationMessage { - // See NotificationMessage above - reserved 2, 3, 12, 16, 17; - - uint32 completion_id = 1; - - oneof result { - Void void = 4; - Value value = 5; - Failure failure = 6; - }; -} - -// We have this for backward compatibility, because we need to parse both old and new awakeable id. -// Completable: No -// Fallible: Yes -// Type: 0x0400 + 14 -message CompleteAwakeableCommandMessage { - string awakeable_id = 1; - - oneof result { - Value value = 2; - Failure failure = 3; - }; - - // Cannot use the field 'name' here because used above - string name = 12; -} - -// Notification message for signals -// Type: 0xFBFF -message SignalNotificationMessage { - // See NotificationMessage above - reserved 1, 12, 16, 17; - - oneof signal_id { - uint32 idx = 2; - string name = 3; - } - - oneof result { - Void void = 4; - Value value = 5; - Failure failure = 6; - }; -} - -// --- Nested messages - -message StateKeys { - repeated bytes keys = 1; -} - -message Value { - bytes content = 1; -} - -// This failure object carries user visible errors, -// e.g. invocation failure return value or failure result of an InvokeCommandMessage. -message Failure { - // The code can be any HTTP status code, as described https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml. - uint32 code = 1; - // Contains a concise error message, e.g. Throwable#getMessage() in Java. - string message = 2; - - // Error metadata - repeated FailureMetadata metadata = 3; -} - -message FailureMetadata { - string key = 1; - string value = 2; -} - -message Header { - string key = 1; - string value = 2; -} - -message WorkflowTarget { - string workflow_name = 1; - string workflow_key = 2; -} - -message IdempotentRequestTarget { - string service_name = 1; - optional string service_key = 2; - string handler_name = 3; - string idempotency_key = 4; -} - -message Void { -} - -enum BuiltInSignal { - UNKNOWN = 0; - CANCEL = 1; - reserved 2 to 15; -} \ No newline at end of file diff --git a/sdk-core/src/main/service-protocol/service-invocation-protocol.md b/sdk-core/src/main/service-protocol/service-invocation-protocol.md deleted file mode 100644 index 45fcc1179..000000000 --- a/sdk-core/src/main/service-protocol/service-invocation-protocol.md +++ /dev/null @@ -1,491 +0,0 @@ -# Restate Service Invocation Protocol - -The following specification describes the protocol used by Restate to invoke remote Restate services. - -## Architecture - -The system is composed of two actors: - -- Restate Runtime -- Service deployment, which is split into: - - SDK, which contains the implementation of the Restate Protocol - - User business logic, which interacts with the SDK to access Restate system calls (or handlerContext) - -Each invocation is modeled by the protocol as a state machine, where state transitions can be caused either by user code -or by _Runtime events_. - -Every state transition is logged in the _Invocation journal_, used to implement Restate's durable execution model. The -journal is also used to suspend an invocation and resume it at a later point in time. The _Invocation journal_ is -tracked both by Restate's runtime and the service deployment. - -Runtime and service deployment exchange _Messages_ containing the invocation journal and runtime events through an HTTP -message stream. - -## State machine and journal - -Every invocation state machine begins when the stream is opened and ends when the stream is closed. In the middle, -arbitrary interaction can be performed from the Service deployment to the Runtime and vice versa via well-defined -messages. - -The state machine is summarized in the following diagram: - -```mermaid -sequenceDiagram - Note over Runtime,SDK: Start - Runtime->>SDK: HTTP Request to /invoke/{service}/{handler} - Runtime->>SDK: StartMessage - Note over Runtime,SDK: Replaying - Runtime->>SDK: [...]EntryMessage(s) - Note over Runtime,SDK: Processing - SDK->>Runtime: HTTP Response headersAccessor - loop - SDK->>Runtime: [...]EntryMessage - Runtime->>SDK: CompletionMessage and/or EntryAckMessage - end - Note over SDK: Reached close condition - alt - SDK->>Runtime: SuspensionMessage - else - SDK->>Runtime: ErrorMessage - else - SDK->>Runtime: EndMessage - end - SDK->>Runtime: Close HTTP Response - Note over Runtime,SDK: Closed -``` - -### Replaying and Processing - -Both runtime and SDKs transition the message stream through 2 states: - -- _Replaying_, that is when there are journal entries to replay before continuing the execution. Described in - [Suspension behavior](#suspension-behavior). -- _Processing_, that is after the _replaying_ state is over. - -There are a couple of properties that we enforce through the design of the protocol: - -- Runtime and service deployment both have their view of the journal -- The source of truth of the journal and its ordering is: - - The runtime, when the invocation is not in _processing_ state - - The service deployment, when the invocation is in _processing_ state -- When in _replaying_ state, the service deployment cannot create new journal entries. -- When in _processing_ state, only the service deployment can create new journal entries, picking their order. - Consequently, it might have newer entries that the runtime is not aware of. It’s also the responsibility of the - service deployment to make sure the runtime has the same ordered view of the journal it has. -- Only in processing state the runtime can send - [`CompletionMessage`](#completable-journal-entries-and-completionmessage) - -### Syscalls - -Most Restate features, such as interaction with other services, accessing service instance state, and so on, are defined -as _Restate syscalls_ and exposed through the service protocol. The user interacts with these handlerContext using the SDK -APIs, which generate _Journal Entry_ messages that will be handled by the invocation state machine. - -Depending on the specific syscall, the Restate runtime generates as response either: - -- A completion, that is the response to the syscall -- An ack, that is a confirmation the syscall has been persisted and **will** be executed -- Nothing - -Each syscall defines a priori whether it replies with an ack or a completion, or doesn't reply at all. - -## Messages - -The protocol is composed by messages that are sent back and forth between runtime and the service deployment. The -protocol mandates the following messages: - -- `StartMessage` -- `[..]EntryMessage` -- `CompletionMessage` -- `SuspensionMessage` -- `EntryAckMessage` -- `EndMessage` - -### Message stream - -In order to execute an invocation, service deployment and restate Runtime open a single stream between the runtime and -the service deployment. Given 10 concurrent invocations to a service deployment, there are 10 concurrent streams, each -of them mapping to a specific invocation. - -Every unit of the stream contains a Message serialized using the -[Protobuf encoding](https://protobuf.dev/programming-guides/encoding/), using the definitions in -[`protocol.proto`](dev/restate/service/protocol.proto), prepended by a [message header](#message-header). - -This stream is implemented using HTTP, and depending on the deployment environment and the HTTP version it can operate -in two modes: - -- Full duplex (bidirectional) stream: Messages are sent back and forth on the same stream at the same time. This option - is supported only when using HTTP/2. -- Request/Response stream: Messages are sent from runtime to service deployment, and later from service deployment to - runtime. Once the service deployment starts sending messages to the runtime, the runtime cannot send messages anymore - back to the service deployment. - -A message stream MUST start with `StartMessage` and MUST end with either: - -- One [`SuspensionMessage`](#suspension) -- One [`ErrorMessage`](#failures) -- One `EndMessage` - -If the message stream does not end with any of these two messages, it will be considered equivalent to sending an -`ErrorMessage` with an [unknown failure](#failures). - -The `EndMessage` marks the end of the invocation lifecycle, that is the end of the journal. - -### Initiating the stream - -As described above, the runtime opens an HTTP request to the SDK to initiate the message stream. - -#### Method - -The request method used is always `POST`. - -#### Path - -The request path has the following format: - -``` -/invoke/{serviceName}/{handlerName} -``` - -For example: - -``` -/invoke/counter.Counter/Add -``` - -An arbitrary path MAY prepend the aforementioned path format. - -In case the path format is not respected, or `serviceName` or `handlerName` is unknown, the SDK MUST close the stream -replying back with a `404` status code. - -#### Content type and protocol version - -The request contains the content-type `application/vnd.restate.invocation.vX` where `X` is the service protocol version -chosen by the runtime, e.g.: - -```http request -content-type: application/vnd.restate.invocation.v1 -``` - -The service protocol version is defined by `ServiceProtocolVersion` in -[`protocol.proto`](dev/restate/service/protocol.proto). - -The SDK MUST return back the same content-type in the successful response case. If the SDK doesn't support the -content-type, It SHOULD close the stream replying back with a `415` status code. - -#### Stream ready - -To notify that the stream is ready to be used, the SDK MUST reply with `200` status code. - -#### SDK version - -The SDK MAY send back the response header `x-restate-server`: - -```http request -x-restate-server: / -``` - -E.g.: - -```http request -x-restate-server: restate-sdk-java/0.8.0 -``` - -This header is used for observability purposes by the Restate observability tools. - -### Message header - -Each message is sent together with a message header prepending the serialized message bytes. - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Type | Reserved | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - -The message header is a fixed 64-bit number containing: - -- (MSB) Message type: 16 bit. The type of the message. Used to deserialize the message. The first 6 bits are used as the - message namespace, to categorize the different message types. -- Message reserved bits: 16 bit. These bits can be used to send flags and other information, and are defined per message - type/namespace. -- Message length: 32 bit. Length of serialized message bytes, excluding header length. - -### StartMessage - -The `StartMessage` carries the metadata required to bootstrap the invocation state machine, including: - -- `known_entries`: The known journal length -- `state_map`: The eager state map (see [Eager state](#eager-state)) - -**Header** - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | 0x0000 | Reserved | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - -Flags: - -- 16 bits: Reserved - -### Entries and Completions - -For each journal entry the runtime commits the entry message and executes the corresponding action atomically. The -runtime won't commit the entry, nor perform the action, if the entry is invalid. If an entry is not committed, all the -subsequent entries are not committed as well. - -Entries can be: - -- Completable or not: These represent actions the runtime will perform, and for which consequently provide a completion - value. All these entries have a `result` field defined in the message descriptor, defining the different variants of - the completion value, and have a `COMPLETED` flag in the header. -- Fallible or not: These can be rejected by the runtime when trying to commit them. The failure is not recorded in the - journal, thus the runtime will abort the stream after receiving an invalid entry from the SDK. - -The type of the journal entry is intrinsic in the definition of the journal action itself. - -The header format for journal entries applies both when the runtime is sending entries to the SDK during a replay, and -when the SDK sends entries to the runtime during processing. - -**Headers** - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Type |A| Reserved |C| - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - -Flags: - -- 1 bit (MSB) `A`: [`REQUIRES_ACK` flag](#acknowledgment-of-stored-entries). Mask: `0x0000_8000_0000_0000` -- 14 bits: Reserved -- 1 bit `C`: `COMPLETED` flag (only Completable journal entries). Mask: `0x0000_0001_0000_0000` - -#### Completable journal entries and `CompletionMessage` - -A completable journal entry at any point in time is either completed or not. After a completable journal entry is -completed, it cannot change its state back to not completed. - -There are three situations where a completable journal entry can be completed: - -- At creation time: when the SDK creates a completable journal entry, it can fill its `result` field and set the - `COMPLETED` flag before sending the entry to the runtime. When replaying, the same `result` will be used. -- At suspension time: when the invocation is suspended, meaning there is no in-flight message stream, the runtime might - internally complete a journal entry filling its `result` field. -- During the invocation processing: when the message stream is active and in [Full duplex mode](#message-stream), the - runtime can notify a completion by sending a `CompletionMessage`. - -A `CompletionMessage` holds the `result` of the JournalEntry and its `entry_index`. A `CompletionMessage` can hold all -the possible variants of a `result` field, and the SDK MUST be able to correlate the `result` field of the entry with -the `result` field of `CompletionMessage` through the `entry_index`. After the completion is notified, the SDK MUST NOT -send any additional messages related to this specific entry. On subsequent replays, the runtime automatically fills the -`result` field of this entry, without sending a subsequent `CompletionMessage`. - -The runtime can send `CompletionMessage` in a different order than the one used to store journal entries. The SDK might -also not be interested in the `result` of completable journal entries, or it might be interested in the `results` in a -different order used to create the related journal entries. Usually it's the service business logic that dictates in -which `result`s the SDK is interested, and in which order. - -**`CompletionMessage` Header** - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | 0x0001 | Reserved | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - -#### Acknowledgment of stored entries - -If the SDK needs an acknowledgment that a journal entry, of any type, has been persisted, it can set the `REQUIRES_ACK` -flag in the header. When set, as soon as the entry is persisted, the runtime will send back a `EntryAckMessage` with the -index of the corresponding entry. - -**`EntryAckMessage` Header** - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | 0x0004 | Reserved | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - -#### Entry names - -Every Journal entry has a field `string name = 12`, which can be set by the SDK when recording the entry. This field is -used for observability purposes by Restate observability tools. - -### Journal entries reference - -The following tables describe the currently available journal entries. For more details, check the protobuf message -descriptions in [`protocol.proto`](dev/restate/service/protocol.proto). - -| Message | Type | Completable | Fallible | Description | -|-----------------------------------|----------|-------------|----------|------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `InputEntryMessage` | `0x0400` | No | No | Carries the invocation input message(s) of the invocation. | -| `GetStateEntryMessage` | `0x0800` | Yes | No | Get the value of a service instance state key. | -| `GetStateKeysEntryMessage` | `0x0804` | Yes | No | Get all the known state keys for this service instance. Note: the completion value for this message is a protobuf of type `GetStateKeysEntryMessage.StateKeys`. | -| `SleepEntryMessage` | `0x0C00` | Yes | No | Initiate a timer that completes after the given time. | -| `CallEntryMessage` | `0x0C01` | Yes | Yes | Invoke another Restate service. | -| `AwakeableEntryMessage` | `0x0C03` | Yes | No | Arbitrary result container which can be completed from another service, given a specific id. See [Awakeable identifier](#awakeable-identifier) for more details. | -| `OneWayCallEntryMessage` | `0x0C02` | No | Yes | Invoke another Restate service at the given time, without waiting for the response. | -| `CompleteAwakeableEntryMessage` | `0x0C04` | No | Yes | Complete an `Awakeable`, given its id. See [Awakeable identifier](#awakeable-identifier) for more details. | -| `OutputEntryMessage` | `0x0401` | No | No | Carries the invocation output message(s) or terminal failure of the invocation. | -| `SetStateEntryMessage` | `0x0800` | No | No | Set the value of a service instance state key. | -| `ClearStateEntryMessage` | `0x0801` | No | No | Clear the value of a service instance state key. | -| `ClearAllStateEntryMessage` | `0x0802` | No | No | Clear all the values of the service instance state. | -| `RunEntryMessage` | `0x0C05` | No | No | Run non-deterministic user provided code and persist the result. | -| `GetPromiseEntryMessage` | `0x0808` | Yes | No | Get or wait the value of the given promise. If the value is not present yet, this entry will block waiting for the value. | -| `PeekPromiseEntryMessage` | `0x0809` | Yes | No | Get the value of the given promise. If the value is not present, this entry completes immediately with empty completion. | -| `CompletePromiseEntryMessage` | `0x080A` | Yes | No | Complete the given promise. If the promise was completed already, this entry completes with a failure. | -| `CancelInvocationEntryMessage` | `0x0C06` | No | Yes | Cancel the target invocation id or the target journal entry. | -| `GetCallInvocationIdEntryMessage` | `0x0C07` | Yes | Yes | Get the invocation id of a previously created call/one way call. | - -#### Awakeable identifier - -When creating an `AwakeableEntryMessage`, the SDK MUST expose to the user code an id, required to later complete the -entry, using either `CompleteAwakeableEntryMessage` or some other mechanism provided by the runtime. - -The id format is a string starts with `prom_1` concatenated with a -[Base64 URL Safe string](https://datatracker.ietf.org/doc/html/rfc4648#section-5) encoding of a byte array that -concatenates: - -- `StartMessage.id` -- The index of the Awakeable entry, encoded as unsigned 32 bit integer big endian. - -An example of a valid identifier would look like `prom_1NMyOAvDK2CcBjUH4Rmb7eGBp0DNNDnmsAAAAAQ` - -## Suspension - -As mentioned in [Replaying and processing](#replaying-and-processing), an invocation can be suspended while waiting for -some journal entries to complete. When suspended, no message stream is in-flight for the given invocation. - -To suspend an invocation, the SDK MUST send a `SuspensionMessage` containing entry indexes of the journal entry results -required to continue the computation. This set MUST contain only indexes of completable journal entries that are not -completed and that have been sent to the runtime. After sending the `SuspensionMessage`, the stream MUST be closed. - -The runtime will resume the invocation as soon as at least one of the given indexes is completed. - -## Failures - -There are a number of failures that can incur during a service invocation, including: - -- Transient network failures that interrupt the message stream -- SDK bugs -- Protocol violations -- Business logic bugs -- User thrown retryable errors - -To notify a failure, the SDK can either: - -- Close the stream with `ErrorMessage` as last message. This message is used by the runtime for accurate reporting to - the user. -- Close the stream without `EndMessage` or `SuspensionMessage` or `ErrorMessage`. This is equivalent to sending an - `ErrorMessage` with unknown reason. - -The runtime takes care of retrying to execute the invocation after such failures occur, following a defined set of -policies. When retrying, the previous stored journal will be reused. Moreover, the SDK MUST NOT assume that every -journal entry previously sent on the same message stream has been correctly stored. - -The SDK can allow users to end/terminate invocations with an exceptional return value. This is done in a similar fashion -to the successful return value case, by generating a `OutputStreamEntry` with the `failure` variant set, sending it and -closing the stream afterward. - -**`ErrorMessage` Header** - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | 0x0003 | Reserved | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - -## Endpoint discovery - -Restate expects SDKs to provide reflective information about the exposed services and the supported protocol versions at -`/discovery`. These reflective information are propagated through an _endpoint manifest_. This document MUST follow the -schema defined in [endpoint_manifest_schema.json](./endpoint_manifest_schema.json) and is identified by the content-type -string `application/vnd.restate.endpointmanifest.vX+json`, where `X` is the manifest version. - -When sending the discovery request, the Restate runtime might specify a set of supported endpoint manifest schemas in -the [`Accept`](https://httpwg.org/specs/rfc9110.html#field.accept) header, for example: - -```http -accept: application/vnd.restate.endpointmanifest.v2+json, application/vnd.restate.endpointmanifest.v1+json -``` - -When replying, the content-type MUST contain the chosen endpoint manifest type/version: - -```http -content-type: application/vnd.restate.endpointmanifest.v1+json -``` - -The service discovery protocol version is defined by `ServiceDiscoveryProtocolVersion` in -[`discovery.proto`](dev/restate/service/discovery.proto). - -## Optional features - -The following section describes optional features SDK developers MAY implement to improve the experience and provide -additional features to the users. - -### Custom entry messages - -The protocol allows the SDK to register an arbitrary entry type within the journal. The type MUST be `>= 0xFC00`. The -runtime will treat this entry as any other entry, persisting it and sending it during replay in the correct order. - -Custom entries MAY have the entry name field `12`, as described in [entry names](#entry-names). - -The field numbers 13, 14 and 15 MUST not be used, as they're reserved for completable journal entries, as described in -[completable journal entries](#completable-journal-entries-and-completionmessage). - -**Header** - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Type |A| Reserved | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - -- Type MUST be `>= 0xFC00` - -Flags: - -- 1 bit (MSB) `A`: [`REQUIRES_ACK` flag](#acknowledgment-of-stored-entries). Mask: `0x0000_8000_0000_0000` -- 15 bits: Reserved - -### Eager state - -As described in [Journal entries reference](#journal-entries-reference), to get a service instance state entry, the SDK -creates a `GetStateEntryMessage` without a result, and waits for a `Completion` with the result, or alternatively -suspends and expects the `GetStateEntryMessage.result` is filled when replaying. - -SDKs MAY optimize the state access operations by reading the `partial_state` and `state_map` fields within the -[`StartMessage`](#startmessage). The `state_map` field contains key-value pairs of the current state of the service -instance. When `partial_state` is set, the `state_map` is partial/incomplete, meaning there might be entries stored in -the Runtime that are not part of `state_map`. When `partial_state` is unset, the `state_map` is complete, thus if an -entry is not within the map, the SDK can assume it's not stored in the runtime either. - -A possible implementation could be the following. Given a user requests a state entry with key `my-key`: - -- If `my-key` is available in `state_map`, generate a `GetStateEntryMessage` with filled `result`, and return the value - to the user -- If `my-key` is not available in `state_map` - - If `partial_state` is unset, generate a `GetStateEntryMessage` with empty `result`, and return empty to the user - - If `partial_state` is set, generate a `GetStateEntryMessage` without a `result`, and wait for the runtime to send a - `Completion` back (same logic as without eager state) - -In order for the aforementioned algorithm to work, set, clear and clear all state operations must be reflected on the -local `state_map` as well. diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/AssertUtils.java b/sdk-core/src/test/java/dev/restate/sdk/core/AssertUtils.java index d0dcc2349..3e4dd1c2c 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/AssertUtils.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/AssertUtils.java @@ -10,73 +10,18 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; -import static org.assertj.core.api.InstanceOfAssertFactories.STRING; -import static org.assertj.core.api.InstanceOfAssertFactories.type; -import com.google.protobuf.MessageLite; -import dev.restate.common.Slice; -import dev.restate.sdk.common.TerminalException; import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema; import dev.restate.sdk.core.generated.manifest.Handler; import dev.restate.sdk.core.generated.manifest.Service; -import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.core.statemachine.InvocationInput; -import dev.restate.sdk.core.statemachine.MessageDecoder; import dev.restate.sdk.endpoint.Endpoint; -import java.util.ArrayList; -import java.util.List; import java.util.Optional; -import java.util.function.Consumer; import java.util.stream.Collectors; -import java.util.stream.Stream; import org.assertj.core.api.AbstractObjectAssert; -import org.assertj.core.api.ListAssert; import org.assertj.core.api.ObjectAssert; public class AssertUtils { - public static Consumer> containsOnly(Consumer consumer) { - return msgs -> assertThat(msgs).satisfiesExactly(consumer); - } - - public static Consumer> containsOnlyExactErrorMessage(Throwable e) { - return containsOnly(exactErrorMessage(e)); - } - - public static Consumer errorMessage( - Consumer consumer) { - return msg -> - assertThat(msg).asInstanceOf(type(Protocol.ErrorMessage.class)).satisfies(consumer); - } - - public static Consumer exactErrorMessage(Throwable e) { - return errorMessage( - msg -> - assertThat(msg) - .returns(e.getMessage(), Protocol.ErrorMessage::getMessage) - .returns( - TerminalException.INTERNAL_SERVER_ERROR_CODE, Protocol.ErrorMessage::getCode) - .extracting(Protocol.ErrorMessage::getStacktrace, STRING) - .startsWith(e.getClass().getName())); - } - - public static Consumer errorDescriptionStartingWith(String str) { - return errorMessage( - msg -> - assertThat(msg) - .extracting(Protocol.ErrorMessage::getStacktrace, STRING) - .startsWith(str)); - } - - public static Consumer protocolExceptionErrorMessage(int code) { - return errorMessage( - msg -> - assertThat(msg) - .returns(code, Protocol.ErrorMessage::getCode) - .extracting(Protocol.ErrorMessage::getStacktrace, STRING) - .startsWith(ProtocolException.class.getCanonicalName())); - } - public static EndpointManifestSchemaAssert assertThatDiscovery(Object... services) { Endpoint.Builder builder = Endpoint.builder(); for (var svc : services) { @@ -94,22 +39,10 @@ public static EndpointManifestSchemaAssert assertThatDiscovery(Endpoint endpoint return new EndpointManifestSchemaAssert( new EndpointManifest(endpoint.getServiceDefinitions(), true) .manifest( - DiscoveryProtocol.MAX_SERVICE_DISCOVERY_PROTOCOL_VERSION, - EndpointManifestSchema.ProtocolMode.BIDI_STREAM), + DiscoveryProtocol.Version.MAX, EndpointManifestSchema.ProtocolMode.BIDI_STREAM), EndpointManifestSchemaAssert.class); } - public static ListAssert assertThatDecodingMessages(Slice... slices) { - var messageDecoder = new MessageDecoder(); - Stream.of(slices).forEach(messageDecoder::offer); - - var outputList = new ArrayList(); - while (messageDecoder.isNextAvailable()) { - outputList.add(messageDecoder.next()); - } - return assertThat(outputList); - } - public static class EndpointManifestSchemaAssert extends AbstractObjectAssert { public EndpointManifestSchemaAssert( diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/AsyncResultTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/AsyncResultTestSuite.java deleted file mode 100644 index cb170a9e5..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/AsyncResultTestSuite.java +++ /dev/null @@ -1,323 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.TestDefinitions.*; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; -import static org.assertj.core.api.Assertions.assertThat; - -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import java.util.function.Supplier; -import java.util.stream.Stream; - -public abstract class AsyncResultTestSuite implements TestSuite { - - protected abstract TestInvocationBuilder reverseAwaitOrder(); - - protected abstract TestInvocationBuilder awaitTwiceTheSameAwaitable(); - - protected abstract TestInvocationBuilder awaitAll(); - - protected abstract TestInvocationBuilder awaitAny(); - - protected abstract TestInvocationBuilder combineAnyWithAll(); - - protected abstract TestInvocationBuilder awaitAnyIndex(); - - protected abstract TestInvocationBuilder awaitOnAlreadyResolvedAwaitables(); - - protected abstract TestInvocationBuilder awaitWithTimeout(); - - protected Stream anyTestDefinitions( - Supplier testInvocation) { - return Stream.of( - testInvocation - .get() - .withInput(startMessage(1), inputCmd()) - .expectingOutput( - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - suspensionMessage(2, 4)) - .named("No completions will suspend"), - testInvocation - .get() - .withInput( - startMessage(4), - inputCmd(), - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - callCompletion(4, "TILL")) - .expectingOutput(outputCmd("TILL"), END_MESSAGE) - .named("Only one completion completes any combinator"), - testInvocation - .get() - .withInput( - startMessage(4), - inputCmd(), - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - callCompletion(4, new TerminalException("My error"))) - .expectingOutput(outputCmd(new TerminalException("My error")), END_MESSAGE) - .named("Only one failure completes any combinator"), - testInvocation - .get() - .withInput( - startMessage(5), - inputCmd(), - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCompletion(2, "FRANCESCO"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - callCompletion(4, "TILL")) - .assertingOutput( - msgs -> { - assertThat(msgs).hasSize(2); - - assertThat(msgs).element(0).isIn(outputCmd("FRANCESCO"), outputCmd("TILL")); - assertThat(msgs).element(1).isEqualTo(END_MESSAGE); - }) - .named("Everything completed completes the any combinator"), - testInvocation - .get() - .withInput(startMessage(1), inputCmd(), callCompletion(2, "FRANCESCO")) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - outputCmd("FRANCESCO"), - END_MESSAGE) - .named("Complete any asynchronously")); - } - - @Override - public Stream definitions() { - return Stream.concat( - // --- Any combinator - anyTestDefinitions(this::awaitAny), - Stream.of( - // --- Reverse await order - this.reverseAwaitOrder() - .withInput(startMessage(1), inputCmd()) - .expectingOutput( - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - suspensionMessage(4)) - .named("None completed"), - this.reverseAwaitOrder() - .withInput( - startMessage(1), - inputCmd(), - callCompletion(2, "FRANCESCO"), - callCompletion(4, "TILL")) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - setStateCmd("A2", "TILL"), - outputCmd("FRANCESCO-TILL"), - END_MESSAGE) - .named("A1 and A2 completed later"), - this.reverseAwaitOrder() - .withInput( - startMessage(1), - inputCmd(), - callCompletion(4, "TILL"), - callCompletion(2, "FRANCESCO")) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - setStateCmd("A2", "TILL"), - outputCmd("FRANCESCO-TILL"), - END_MESSAGE) - .named("A2 and A1 completed later in reverse order"), - this.reverseAwaitOrder() - .withInput(startMessage(1), inputCmd(), callCompletion(4, "TILL")) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - setStateCmd("A2", "TILL"), - suspensionMessage(2)) - .named("Only A2 completed"), - this.reverseAwaitOrder() - .withInput(startMessage(1), inputCmd(), callCompletion(2, "FRANCESCO")) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - suspensionMessage(4)) - .named("Only A1 completed"), - - // --- Await twice the same executable - this.awaitTwiceTheSameAwaitable() - .withInput(startMessage(1), inputCmd(), callCompletion(2, "FRANCESCO")) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - outputCmd("FRANCESCO-FRANCESCO"), - END_MESSAGE), - - // --- All combinator - this.awaitAll() - .withInput(startMessage(1), inputCmd()) - .expectingOutput( - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - suspensionMessage(2, 4)) - .named("No completions will suspend"), - this.awaitAll() - .withInput( - startMessage(4), - inputCmd(), - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - callCompletion(4, "TILL")) - .expectingOutput(suspensionMessage(2)) - .named("Only one completion will suspend"), - this.awaitAll() - .withInput( - startMessage(3), - inputCmd(), - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - callCompletion(2, "FRANCESCO"), - callCompletion(4, "TILL")) - .expectingOutput(outputCmd("FRANCESCO-TILL"), END_MESSAGE) - .named("Everything completed completes the all combinator"), - this.awaitAll() - .withInput( - startMessage(1), - inputCmd(), - callCompletion(2, "FRANCESCO"), - callCompletion(4, "TILL")) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - outputCmd("FRANCESCO-TILL"), - END_MESSAGE) - .named("Complete all asynchronously"), - this.awaitAll() - .withInput( - startMessage(1), - inputCmd(), - callCompletion(2, new IllegalStateException("My error"))) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - outputCmd(new IllegalStateException("My error")), - END_MESSAGE) - .named("All fails on first failure"), - this.awaitAll() - .withInput( - startMessage(1), - inputCmd(), - callCompletion(2, "FRANCESCO"), - callCompletion(4, new IllegalStateException("My error"))) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), - outputCmd(new IllegalStateException("My error")), - END_MESSAGE) - .named("All fails on second failure"), - - // --- Compose any with all - this.combineAnyWithAll() - .withInput( - startMessage(5), - inputCmd(), - signalNotification(17, "1"), - signalNotification(18, "2"), - signalNotification(19, "3"), - signalNotification(20, "4")) - .expectingOutput(outputCmd("123"), END_MESSAGE), - this.combineAnyWithAll() - .withInput( - startMessage(5), - inputCmd(), - signalNotification(18, "2"), - signalNotification(17, "1"), - signalNotification(20, "4"), - signalNotification(19, "3")) - .expectingOutput(outputCmd("224"), END_MESSAGE) - .named("Inverted order"), - - // --- Await Any with index - this.awaitAnyIndex() - .withInput( - startMessage(5), - inputCmd(), - signalNotification(17, "1"), - signalNotification(18, "2"), - signalNotification(19, "3"), - signalNotification(20, "4")) - .expectingOutput(outputCmd("0"), END_MESSAGE), - this.awaitAnyIndex() - .withInput( - startMessage(5), - inputCmd(), - signalNotification(19, "3"), - signalNotification(18, "2"), - signalNotification(17, "1"), - signalNotification(20, "4")) - .expectingOutput(outputCmd("1"), END_MESSAGE) - .named("Complete all"), - - // --- Compose nested and resolved all should work - this.awaitOnAlreadyResolvedAwaitables() - .withInput( - startMessage(3), - inputCmd(), - signalNotification(17, "1"), - signalNotification(18, "2")) - .expectingOutput(outputCmd("12"), END_MESSAGE), - - // --- Await with timeout - this.awaitWithTimeout() - .withInput(startMessage(1), inputCmd(), callCompletion(2, "FRANCESCO")) - .onlyBidiStream() - .assertingOutput( - messages -> { - assertThat(messages).hasSize(4); - assertThat(messages) - .element(0) - .isEqualTo(callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco").build()); - assertThat(messages) - .element(1) - .isInstanceOf(Protocol.SleepCommandMessage.class); - assertThat(messages).element(2).isEqualTo(outputCmd("FRANCESCO")); - assertThat(messages).element(3).isEqualTo(END_MESSAGE); - }), - this.awaitWithTimeout() - .withInput( - startMessage(1), - inputCmd(), - Protocol.SleepCompletionNotificationMessage.newBuilder() - .setCompletionId(3) - .setVoid(Protocol.Void.getDefaultInstance()) - .build()) - .onlyBidiStream() - .assertingOutput( - messages -> { - assertThat(messages).hasSize(4); - assertThat(messages) - .element(0) - .isEqualTo(callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco").build()); - assertThat(messages) - .element(1) - .isInstanceOf(Protocol.SleepCommandMessage.class); - assertThat(messages).element(2).isEqualTo(outputCmd("timeout")); - assertThat(messages).element(3).isEqualTo(END_MESSAGE); - }) - .named("Fires timeout"))); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/AwakeableIdTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/AwakeableIdTestSuite.java deleted file mode 100644 index 6f9408423..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/AwakeableIdTestSuite.java +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; - -import com.google.protobuf.ByteString; -import dev.restate.sdk.core.TestDefinitions.TestDefinition; -import dev.restate.sdk.core.TestDefinitions.TestSuite; -import java.nio.ByteBuffer; -import java.util.Base64; -import java.util.UUID; -import java.util.stream.Stream; - -public abstract class AwakeableIdTestSuite implements TestSuite { - - protected abstract TestDefinitions.TestInvocationBuilder returnAwakeableId(); - - @Override - public Stream definitions() { - UUID id = UUID.randomUUID(); - String debugId = id.toString(); - byte[] serializedId = serializeUUID(id); - - ByteBuffer expectedAwakeableId = ByteBuffer.allocate(serializedId.length + 4); - expectedAwakeableId.put(serializedId); - expectedAwakeableId.putInt(17); - expectedAwakeableId.flip(); - String base64ExpectedAwakeableId = - "sign_1" + Base64.getUrlEncoder().encodeToString(expectedAwakeableId.array()); - - return Stream.of( - returnAwakeableId() - .withInput( - startMessage(1).setDebugId(debugId).setId(ByteString.copyFrom(serializedId)), - inputCmd()) - .expectingOutput(outputCmd(base64ExpectedAwakeableId), END_MESSAGE)); - } - - private byte[] serializeUUID(UUID uuid) { - ByteBuffer serializedId = ByteBuffer.allocate(16); - serializedId.putLong(uuid.getMostSignificantBits()); - serializedId.putLong(uuid.getLeastSignificantBits()); - serializedId.flip(); - return serializedId.array(); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/CallTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/CallTestSuite.java deleted file mode 100644 index 511edeb86..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/CallTestSuite.java +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; - -import dev.restate.common.Slice; -import dev.restate.common.Target; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.TestDefinitions.TestDefinition; -import dev.restate.sdk.core.TestDefinitions.TestSuite; -import dev.restate.sdk.core.generated.protocol.Protocol; -import java.util.Map; -import java.util.stream.Stream; - -public abstract class CallTestSuite implements TestSuite { - - protected abstract TestDefinitions.TestInvocationBuilder oneWayCall( - Target target, String idempotencyKey, Map headers, Slice body); - - protected abstract TestDefinitions.TestInvocationBuilder implicitCancellation( - Target target, Slice body); - - private static String IDEMPOTENCY_KEY = "my-idempotency-key"; - private static Map HEADERS = Map.of("abc", "123", "fge", "456"); - private static Slice BODY = Slice.wrap("bla"); - - @Override - public Stream definitions() { - return Stream.of( - oneWayCall(GREETER_SERVICE_TARGET, IDEMPOTENCY_KEY, HEADERS, BODY) - .withInput(startMessage(1), inputCmd()) - .expectingOutput( - oneWayCallCmd(1, GREETER_SERVICE_TARGET, IDEMPOTENCY_KEY, HEADERS, BODY), - outputCmd(), - END_MESSAGE), - oneWayCall(GREETER_SERVICE_TARGET, IDEMPOTENCY_KEY, HEADERS, BODY) - .withInput( - startMessage(3), - inputCmd(), - oneWayCallCmd(1, GREETER_SERVICE_TARGET, IDEMPOTENCY_KEY, HEADERS, BODY), - callInvocationIdCompletion(1, "abc")) - .expectingOutput(outputCmd(), END_MESSAGE) - .named("With invocation ID completion"), - oneWayCall(GREETER_VIRTUAL_OBJECT_TARGET, IDEMPOTENCY_KEY, HEADERS, BODY) - .withInput(startMessage(1), inputCmd()) - .expectingOutput( - oneWayCallCmd(1, GREETER_VIRTUAL_OBJECT_TARGET, IDEMPOTENCY_KEY, HEADERS, BODY), - outputCmd(), - END_MESSAGE), - implicitCancellation(GREETER_SERVICE_TARGET, BODY) - .withInput( - startMessage(3), - inputCmd(), - callCmd(1, 2, GREETER_SERVICE_TARGET, BODY.toByteArray()), - CANCELLATION_SIGNAL) - .onlyBidiStream() - .expectingOutput(Protocol.SuspensionMessage.newBuilder().addWaitingCompletions(1)) - .named("Suspends on waiting the invocation id"), - implicitCancellation(GREETER_SERVICE_TARGET, BODY) - .withInput( - startMessage(4), - inputCmd(), - callCmd(1, 2, GREETER_SERVICE_TARGET, BODY.toByteArray()), - CANCELLATION_SIGNAL, - callInvocationIdCompletion(1, "my-id")) - .onlyBidiStream() - .expectingOutput( - sendCancelSignal("my-id"), - outputCmd(new TerminalException(TerminalException.CANCELLED_CODE)), - END_MESSAGE) - .named("Surfaces cancellation")); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/ComponentDiscoveryHandlerTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/ComponentDiscoveryHandlerTest.java index 59b3f59e7..2c4a72f24 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/ComponentDiscoveryHandlerTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/ComponentDiscoveryHandlerTest.java @@ -38,8 +38,7 @@ void handleWithMultipleServices() { EndpointManifestSchema manifest = deploymentManifest.manifest( - DiscoveryProtocol.MAX_SERVICE_DISCOVERY_PROTOCOL_VERSION, - EndpointManifestSchema.ProtocolMode.REQUEST_RESPONSE); + DiscoveryProtocol.Version.MAX, EndpointManifestSchema.ProtocolMode.REQUEST_RESPONSE); assertThat(manifest.getServices()).extracting(Service::getName).containsOnly("MyGreeter"); assertThat(manifest.getProtocolMode()) diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/EagerStateTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/EagerStateTestSuite.java deleted file mode 100644 index 0f3c34ebc..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/EagerStateTestSuite.java +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.TestDefinitions.*; -import static dev.restate.sdk.core.generated.protocol.Protocol.*; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; -import static org.assertj.core.api.AssertionsForClassTypes.entry; - -import com.google.protobuf.MessageLite; -import java.util.Map; -import java.util.stream.Stream; - -public abstract class EagerStateTestSuite implements TestSuite { - - protected abstract TestInvocationBuilder getEmpty(); - - protected abstract TestInvocationBuilder get(); - - protected abstract TestInvocationBuilder getAppendAndGet(); - - protected abstract TestInvocationBuilder getClearAndGet(); - - protected abstract TestInvocationBuilder getClearAllAndGet(); - - protected abstract TestInvocationBuilder listKeys(); - - protected abstract TestInvocationBuilder consecutiveGetWithEmpty(); - - private static final Map.Entry STATE_FRANCESCO = entry("STATE", "Francesco"); - private static final Map.Entry ANOTHER_STATE_FRANCESCO = - entry("ANOTHER_STATE", "Francesco"); - private static final MessageLite INPUT_TILL = inputCmd("Till"); - private static final MessageLite GET_STATE_FRANCESCO = getEagerStateCmd("STATE", "Francesco"); - private static final MessageLite GET_STATE_FRANCESCO_TILL = - getEagerStateCmd("STATE", "FrancescoTill"); - private static final MessageLite SET_STATE_FRANCESCO_TILL = setStateCmd("STATE", "FrancescoTill"); - private static final MessageLite OUTPUT_FRANCESCO = outputCmd("Francesco"); - private static final MessageLite OUTPUT_FRANCESCO_TILL = outputCmd("FrancescoTill"); - - @Override - public Stream definitions() { - return Stream.of( - this.getEmpty() - .withInput(startMessage(1).setPartialState(false), INPUT_TILL) - .expectingOutput(getEagerStateEmptyCmd("STATE"), outputCmd("true"), END_MESSAGE) - .named("With complete state"), - this.getEmpty() - .withInput(startMessage(1).setPartialState(true), INPUT_TILL) - .expectingOutput(getLazyStateCmd(1, "STATE"), suspensionMessage(1)) - .named("With partial state"), - this.getEmpty() - .withInput( - startMessage(2).setPartialState(true), INPUT_TILL, getEagerStateEmptyCmd("STATE")) - .expectingOutput(outputCmd("true"), END_MESSAGE) - .named("Resume with partial state"), - this.get() - .withInput( - startMessage(1, "my-greeter", STATE_FRANCESCO).setPartialState(false), INPUT_TILL) - .expectingOutput(GET_STATE_FRANCESCO, OUTPUT_FRANCESCO, END_MESSAGE) - .named("With complete state"), - this.get() - .withInput( - startMessage(1, "my-greeter", STATE_FRANCESCO).setPartialState(true), INPUT_TILL) - .expectingOutput(GET_STATE_FRANCESCO, OUTPUT_FRANCESCO, END_MESSAGE) - .named("With partial state"), - this.get() - .withInput(startMessage(1).setPartialState(true), INPUT_TILL) - .expectingOutput(getLazyStateCmd(1, "STATE"), suspensionMessage(1)) - .named("With partial state without the state entry"), - this.getAppendAndGet() - .withInput(startMessage(1, "my-greeter", STATE_FRANCESCO), INPUT_TILL) - .expectingOutput( - GET_STATE_FRANCESCO, - SET_STATE_FRANCESCO_TILL, - GET_STATE_FRANCESCO_TILL, - OUTPUT_FRANCESCO_TILL, - END_MESSAGE) - .named("With state in the state_map"), - this.getAppendAndGet() - .withInput( - startMessage(1).setPartialState(true), - INPUT_TILL, - getLazyStateCompletion(1, "Francesco")) - .onlyBidiStream() - .expectingOutput( - getLazyStateCmd(1, "STATE"), - SET_STATE_FRANCESCO_TILL, - GET_STATE_FRANCESCO_TILL, - OUTPUT_FRANCESCO_TILL, - END_MESSAGE) - .named("With partial state on the first get"), - this.getClearAndGet() - .withInput(startMessage(1, "my-greeter", STATE_FRANCESCO), INPUT_TILL) - .expectingOutput( - GET_STATE_FRANCESCO, - clearStateCmd("STATE"), - getEagerStateEmptyCmd("STATE"), - OUTPUT_FRANCESCO, - END_MESSAGE) - .named("With state in the state_map"), - this.getClearAndGet() - .withInput( - startMessage(1).setPartialState(true), - INPUT_TILL, - getLazyStateCompletion(1, "Francesco")) - .onlyBidiStream() - .expectingOutput( - getLazyStateCmd(1, "STATE"), - clearStateCmd("STATE"), - getEagerStateEmptyCmd("STATE"), - OUTPUT_FRANCESCO, - END_MESSAGE) - .named("With partial state on the first get"), - this.getClearAllAndGet() - .withInput( - startMessage(1, "my-greeter", STATE_FRANCESCO, ANOTHER_STATE_FRANCESCO), INPUT_TILL) - .expectingOutput( - GET_STATE_FRANCESCO, - ClearAllStateCommandMessage.getDefaultInstance(), - getEagerStateEmptyCmd("STATE"), - getEagerStateEmptyCmd("ANOTHER_STATE"), - OUTPUT_FRANCESCO, - END_MESSAGE) - .named("With state in the state_map"), - this.getClearAllAndGet() - .withInput( - startMessage(1).setPartialState(true), - INPUT_TILL, - getLazyStateCompletion(1, STATE_FRANCESCO.getValue())) - .onlyBidiStream() - .expectingOutput( - getLazyStateCmd(1, "STATE"), - ClearAllStateCommandMessage.getDefaultInstance(), - getEagerStateEmptyCmd("STATE"), - getEagerStateEmptyCmd("ANOTHER_STATE"), - OUTPUT_FRANCESCO, - END_MESSAGE) - .named("With partial state on the first get"), - this.listKeys() - .withInput( - startMessage(1, "my-greeter", STATE_FRANCESCO).setPartialState(true), - INPUT_TILL, - GetLazyStateKeysCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setStateKeys(stateKeys("a", "b"))) - .onlyBidiStream() - .expectingOutput( - GetLazyStateKeysCommandMessage.newBuilder().setResultCompletionId(1), - outputCmd("a,b"), - END_MESSAGE) - .named("With partial state"), - this.listKeys() - .withInput( - startMessage(1, "my-greeter", STATE_FRANCESCO).setPartialState(false), INPUT_TILL) - .expectingOutput( - GetEagerStateKeysCommandMessage.newBuilder() - .setValue(stateKeys(STATE_FRANCESCO.getKey())), - outputCmd(STATE_FRANCESCO.getKey()), - END_MESSAGE) - .named("With complete state"), - this.listKeys() - .withInput( - startMessage(2).setPartialState(true), - INPUT_TILL, - GetEagerStateKeysCommandMessage.newBuilder().setValue(stateKeys("3", "2", "1"))) - .expectingOutput(outputCmd("3,2,1"), END_MESSAGE) - .named("With replayed list"), - this.consecutiveGetWithEmpty() - .withInput(startMessage(1).setPartialState(false), inputCmd()) - .expectingOutput( - getEagerStateEmptyCmd("key-0"), - getEagerStateEmptyCmd("key-0"), - outputCmd(), - END_MESSAGE), - this.consecutiveGetWithEmpty() - .withInput( - startMessage(2).setPartialState(false), inputCmd(), getEagerStateEmptyCmd("key-0")) - .expectingOutput(getEagerStateEmptyCmd("key-0"), outputCmd(), END_MESSAGE) - .named("With replay of the first get")); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/InvocationIdTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/InvocationIdTestSuite.java deleted file mode 100644 index 00bf88d9b..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/InvocationIdTestSuite.java +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; - -import com.google.protobuf.ByteString; -import dev.restate.sdk.core.TestDefinitions.TestDefinition; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestDefinitions.TestSuite; -import java.util.stream.Stream; - -public abstract class InvocationIdTestSuite implements TestSuite { - - protected abstract TestInvocationBuilder returnInvocationId(); - - @Override - public Stream definitions() { - String debugId = "my-debug-id"; - ByteString id = ByteString.copyFromUtf8(debugId); - - return Stream.of( - returnInvocationId() - .withInput(startMessage(1).setDebugId(debugId).setId(id), inputCmd()) - .onlyBidiStream() - .expectingOutput(outputCmd(debugId), END_MESSAGE)); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/MockBidiStream.java b/sdk-core/src/test/java/dev/restate/sdk/core/MockBidiStream.java deleted file mode 100644 index 3b144dee6..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/MockBidiStream.java +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.AssertUtils.assertThatDecodingMessages; - -import com.google.protobuf.MessageLite; -import dev.restate.common.Slice; -import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.core.statemachine.InvocationInput; -import dev.restate.sdk.core.statemachine.ProtoUtils; -import dev.restate.sdk.endpoint.Endpoint; -import dev.restate.sdk.endpoint.HeadersAccessor; -import dev.restate.sdk.endpoint.definition.ServiceDefinition; -import io.smallrye.mutiny.Multi; -import io.smallrye.mutiny.helpers.test.AssertSubscriber; -import io.smallrye.mutiny.subscription.DemandPacer; -import io.smallrye.mutiny.subscription.FixedDemandPacer; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; -import org.apache.logging.log4j.ThreadContext; - -public final class MockBidiStream implements TestDefinitions.TestExecutor { - - public static final MockBidiStream INSTANCE = new MockBidiStream(); - - private MockBidiStream() {} - - @Override - public boolean buffered() { - return false; - } - - @Override - public void executeTest(TestDefinitions.TestDefinition definition) { - Executor coreExecutor = Executors.newSingleThreadExecutor(); - - // This test infra supports only services returning one service definition - ServiceDefinition serviceDefinition = definition.getServiceDefinition(); - - // Prepare server - Endpoint.Builder builder = - Endpoint.builder().bind(serviceDefinition, definition.getServiceOptions()); - if (definition.isEnablePreviewContext()) { - builder.enablePreviewContext(); - } - EndpointRequestHandler server = EndpointRequestHandler.create(builder.build()); - - // Start invocation - RequestProcessor handler = - server.processorForRequest( - "/" + serviceDefinition.getServiceName() + "/" + definition.getMethod(), - HeadersAccessor.wrap( - Map.of( - "content-type", - ProtoUtils.serviceProtocolContentTypeHeader( - definition.isEnablePreviewContext()))), - EndpointRequestHandler.LoggingContextSetter.THREAD_LOCAL_INSTANCE, - coreExecutor, - true); - - // Wire invocation - AssertSubscriber assertSubscriber = AssertSubscriber.create(Long.MAX_VALUE); - - // Wire invocation and start it - Multi.createFrom() - .iterable(definition.getInput()) - .runSubscriptionOn(coreExecutor) - .map(ProtoUtils::invocationInputToByteString) - .map(Slice::wrap) - .paceDemand() - .using(inputPacer(definition.getInput())) - .emitOn(coreExecutor) - .subscribe(handler); - Multi.createFrom() - .publisher(handler) - .runSubscriptionOn(coreExecutor) - .subscribe(assertSubscriber); - - // Check completed - assertSubscriber.awaitCompletion(Duration.ofSeconds(10)); - - // Unwrap messages and decode them - //noinspection unchecked - assertThatDecodingMessages(assertSubscriber.getItems().toArray(Slice[]::new)) - .map(InvocationInput::message) - .satisfies(l -> definition.getOutputAssert().accept((List) l)); - - // Clean logging - ThreadContext.clearAll(); - } - - private DemandPacer inputPacer(List input) { - if (input.get(0).message() instanceof Protocol.StartMessage startMessage) { - int knownEntries = startMessage.getKnownEntries(); - if (knownEntries != input.size() - 1) { - // We're sending a journal to replay plus more stuff, let's pace after the replay ends - return new FixedDemandPacer(knownEntries + 1, Duration.ofMillis(200)); - } - } - // We're only sending a journal to replay, or we're not sending start message, let's just pace - // right in the middle - return new FixedDemandPacer(Math.min(1, input.size() / 2), Duration.ofMillis(100)); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java b/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java deleted file mode 100644 index a5488af1f..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.AssertUtils.assertThatDecodingMessages; - -import com.google.protobuf.MessageLite; -import dev.restate.common.Slice; -import dev.restate.sdk.core.TestDefinitions.TestDefinition; -import dev.restate.sdk.core.TestDefinitions.TestExecutor; -import dev.restate.sdk.core.statemachine.InvocationInput; -import dev.restate.sdk.core.statemachine.ProtoUtils; -import dev.restate.sdk.endpoint.Endpoint; -import dev.restate.sdk.endpoint.HeadersAccessor; -import dev.restate.sdk.endpoint.definition.ServiceDefinition; -import io.smallrye.mutiny.Multi; -import io.smallrye.mutiny.helpers.test.AssertSubscriber; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; -import org.apache.logging.log4j.ThreadContext; - -public final class MockRequestResponse implements TestExecutor { - - public static final MockRequestResponse INSTANCE = new MockRequestResponse(); - - private MockRequestResponse() {} - - @Override - public boolean buffered() { - return true; - } - - @Override - public void executeTest(TestDefinition definition) { - Executor syscallsExecutor = Executors.newSingleThreadExecutor(); - - ServiceDefinition serviceDefinition = definition.getServiceDefinition(); - - // Prepare server - Endpoint.Builder builder = - Endpoint.builder().bind(serviceDefinition, definition.getServiceOptions()); - if (definition.isEnablePreviewContext()) { - builder.enablePreviewContext(); - } - EndpointRequestHandler server = EndpointRequestHandler.create(builder.build()); - - // Start invocation - RequestProcessor handler = - server.processorForRequest( - "/" + serviceDefinition.getServiceName() + "/" + definition.getMethod(), - HeadersAccessor.wrap( - Map.of( - "content-type", - ProtoUtils.serviceProtocolContentTypeHeader( - definition.isEnablePreviewContext()))), - EndpointRequestHandler.LoggingContextSetter.THREAD_LOCAL_INSTANCE, - syscallsExecutor, - false); - - // Wire invocation - AssertSubscriber assertSubscriber = AssertSubscriber.create(Long.MAX_VALUE); - Multi.createFrom() - .iterable(definition.getInput()) - .runSubscriptionOn(syscallsExecutor) - .map(ProtoUtils::invocationInputToByteString) - .map(Slice::wrap) - .subscribe(handler); - Multi.createFrom() - .publisher(handler) - .runSubscriptionOn(syscallsExecutor) - .subscribe(assertSubscriber); - - // Check completed - assertSubscriber.awaitCompletion(Duration.ofSeconds(10000)); - // Unwrap messages and decode them - //noinspection unchecked - assertThatDecodingMessages(assertSubscriber.getItems().toArray(Slice[]::new)) - .map(InvocationInput::message) - .satisfies(l -> definition.getOutputAssert().accept((List) l)); - - // Clean logging - ThreadContext.clearAll(); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/OnlyInputAndOutputTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/OnlyInputAndOutputTestSuite.java deleted file mode 100644 index dfcb89072..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/OnlyInputAndOutputTestSuite.java +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.TestDefinitions.TestDefinition; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; - -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestDefinitions.TestSuite; -import java.util.stream.Stream; - -public abstract class OnlyInputAndOutputTestSuite implements TestSuite { - - protected abstract TestInvocationBuilder noSyscallsGreeter(); - - @Override - public Stream definitions() { - return Stream.of( - this.noSyscallsGreeter() - .withInput(startMessage(1), inputCmd("Francesco")) - .expectingOutput(outputCmd("Hello Francesco"), END_MESSAGE)); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/PromiseTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/PromiseTestSuite.java deleted file mode 100644 index 151f94e22..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/PromiseTestSuite.java +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.TestDefinitions.*; -import static dev.restate.sdk.core.generated.protocol.Protocol.*; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; - -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.core.generated.protocol.Protocol.GetPromiseCompletionNotificationMessage; -import dev.restate.sdk.core.statemachine.ProtoUtils; -import java.util.stream.Stream; - -public abstract class PromiseTestSuite implements TestSuite { - - private static final String PROMISE_KEY = "my-prom"; - - protected abstract TestInvocationBuilder awaitPromise(String promiseKey); - - protected abstract TestInvocationBuilder awaitPeekPromise( - String promiseKey, String emptyCaseReturnValue); - - protected abstract TestInvocationBuilder awaitIsPromiseCompleted(String promiseKey); - - protected abstract TestInvocationBuilder awaitResolvePromise( - String promiseKey, String completionValue); - - protected abstract TestInvocationBuilder awaitRejectPromise( - String promiseKey, String rejectReason); - - @Override - public Stream definitions() { - return Stream.of( - // --- Await promise - this.awaitPromise(PROMISE_KEY) - .withInput( - startMessage(1), - inputCmd(), - GetPromiseCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setValue(value("my value"))) - .expectingOutput(getPromiseCmd(1, PROMISE_KEY), outputCmd("my value"), END_MESSAGE) - .named("Completed with success"), - this.awaitPromise(PROMISE_KEY) - .withInput( - startMessage(1), - inputCmd(), - GetPromiseCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setFailure(ProtoUtils.failure(new TerminalException("myerror")))) - .expectingOutput( - getPromiseCmd(1, PROMISE_KEY), - outputCmd(new TerminalException("myerror")), - END_MESSAGE) - .named("Completed with failure"), - // --- Peek promise - this.awaitPeekPromise(PROMISE_KEY, "null") - .withInput( - startMessage(1), - inputCmd(), - PeekPromiseCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setValue(value("my value"))) - .expectingOutput(peekPromiseCmd(1, PROMISE_KEY), outputCmd("my value"), END_MESSAGE) - .named("Completed with success"), - this.awaitPeekPromise(PROMISE_KEY, "null") - .withInput( - startMessage(1), - inputCmd(), - PeekPromiseCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setFailure(ProtoUtils.failure(new TerminalException("myerror")))) - .expectingOutput( - peekPromiseCmd(1, PROMISE_KEY), - outputCmd(new TerminalException("myerror")), - END_MESSAGE) - .named("Completed with failure"), - this.awaitPeekPromise(PROMISE_KEY, "null") - .withInput( - startMessage(1), - inputCmd(), - PeekPromiseCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setVoid(Protocol.Void.getDefaultInstance())) - .expectingOutput(peekPromiseCmd(1, PROMISE_KEY), outputCmd("null"), END_MESSAGE) - .named("Completed with null"), - // --- Promise is completed - this.awaitIsPromiseCompleted(PROMISE_KEY) - .withInput( - startMessage(1), - inputCmd(), - PeekPromiseCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setValue(value("my value"))) - .onlyBidiStream() - .expectingOutput( - peekPromiseCmd(1, PROMISE_KEY), outputCmd(TestSerdes.BOOLEAN, true), END_MESSAGE) - .named("Completed with success"), - this.awaitIsPromiseCompleted(PROMISE_KEY) - .withInput( - startMessage(1), - inputCmd(), - PeekPromiseCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setVoid(Protocol.Void.getDefaultInstance())) - .expectingOutput( - peekPromiseCmd(1, PROMISE_KEY), outputCmd(TestSerdes.BOOLEAN, false), END_MESSAGE) - .named("Not completed"), - // --- Promise resolve - this.awaitResolvePromise(PROMISE_KEY, "my val") - .withInput( - startMessage(1), - inputCmd(), - CompletePromiseCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setVoid(Protocol.Void.getDefaultInstance()) - .build()) - .expectingOutput( - completePromiseCmd(1, PROMISE_KEY, "my val"), - outputCmd(TestSerdes.BOOLEAN, true), - END_MESSAGE) - .named("resolve succeeds"), - this.awaitResolvePromise(PROMISE_KEY, "my val") - .withInput( - startMessage(1), - inputCmd(), - CompletePromiseCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setFailure(failure(new TerminalException("cannot write promise"))) - .build()) - .expectingOutput( - completePromiseCmd(1, PROMISE_KEY, "my val"), - outputCmd(TestSerdes.BOOLEAN, false), - END_MESSAGE) - .named("resolve fails"), - // --- Promise reject - this.awaitRejectPromise(PROMISE_KEY, "my failure") - .withInput( - startMessage(1), - inputCmd(), - CompletePromiseCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setVoid(Protocol.Void.getDefaultInstance()) - .build()) - .expectingOutput( - completePromiseCmd(1, PROMISE_KEY, new TerminalException("my failure")), - outputCmd(TestSerdes.BOOLEAN, true), - END_MESSAGE) - .named("resolve succeeds"), - this.awaitRejectPromise(PROMISE_KEY, "my failure") - .withInput( - startMessage(1), - inputCmd(), - CompletePromiseCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setFailure(failure(new TerminalException("cannot write promise"))) - .build()) - .expectingOutput( - completePromiseCmd(1, PROMISE_KEY, new TerminalException("my failure")), - outputCmd(TestSerdes.BOOLEAN, false), - END_MESSAGE) - .named("resolve fails")); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/RandomTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/RandomTestSuite.java deleted file mode 100644 index c615bab51..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/RandomTestSuite.java +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; -import static dev.restate.sdk.core.statemachine.ProtoUtils.invocationIdToRandomSeed; - -import dev.restate.sdk.core.TestDefinitions.TestDefinition; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestDefinitions.TestSuite; -import dev.restate.sdk.core.statemachine.ProtoUtils; -import java.util.stream.Stream; - -public abstract class RandomTestSuite implements TestSuite { - - protected abstract TestInvocationBuilder randomShouldBeDeterministic(); - - protected abstract int getExpectedInt(long seed); - - @Override - public Stream definitions() { - String debugId = "my-id"; - long startMessageSeed = System.currentTimeMillis(); - - return Stream.of( - this.randomShouldBeDeterministic() - .withInput( - startMessage(1).setDebugId(debugId).setRandomSeed(startMessageSeed), - ProtoUtils.inputCmd()) - // This enables protocol v6 - .enablePreviewContext() - .expectingOutput(outputCmd(getExpectedInt(startMessageSeed)), END_MESSAGE) - .named("Using StartMessage.random_seed"), - this.randomShouldBeDeterministic() - .withInput(startMessage(1).setDebugId(debugId), ProtoUtils.inputCmd()) - .expectingOutput( - outputCmd(getExpectedInt(invocationIdToRandomSeed(debugId))), END_MESSAGE) - .named("Using invocation id")); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java deleted file mode 100644 index 581d0bea7..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java +++ /dev/null @@ -1,379 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.AssertUtils.*; -import static dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.InstanceOfAssertFactories.STRING; -import static org.assertj.core.api.InstanceOfAssertFactories.type; - -import com.google.protobuf.ByteString; -import dev.restate.sdk.common.RetryPolicy; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.core.statemachine.MessageType; -import java.time.Duration; -import java.util.stream.Stream; -import org.assertj.core.data.Index; - -public abstract class SideEffectTestSuite implements TestDefinitions.TestSuite { - - protected abstract TestInvocationBuilder sideEffect(String sideEffectOutput); - - protected abstract TestInvocationBuilder namedSideEffect(String name, String sideEffectOutput); - - protected abstract TestInvocationBuilder consecutiveSideEffect(String sideEffectOutput); - - protected abstract TestInvocationBuilder checkContextSwitching(); - - protected abstract TestInvocationBuilder failingSideEffect(String name, String reason); - - protected abstract TestInvocationBuilder awaitAllSideEffectWithFirstFailing( - String firstSideEffect, String secondSideEffect, String successValue, String failureReason); - - protected abstract TestInvocationBuilder awaitAllSideEffectWithSecondFailing( - String firstSideEffect, String secondSideEffect, String successValue, String failureReason); - - protected abstract TestInvocationBuilder failingSideEffectWithRetryPolicy( - String reason, RetryPolicy retryPolicy); - - protected abstract TestInvocationBuilder sideEffectGuard(); - - protected abstract TestInvocationBuilder sideEffectGuardAwait(); - - protected abstract TestInvocationBuilder instantNow(); - - protected abstract void assertIsInstant(ByteString bytes); - - @Override - public Stream definitions() { - return Stream.of( - this.sideEffect("Francesco") - .withInput(startMessage(1), inputCmd("Till")) - .expectingOutput(runCmd(1), proposeRunCompletion(1, "Francesco"), suspensionMessage(1)) - .named("Run and propose completion"), - this.sideEffect("Francesco") - .withInput(startMessage(3), inputCmd("Till"), runCmd(1), runCompletion(1, "Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Hello Francesco"), END_MESSAGE) - .named("Replay from completion"), - this.namedSideEffect("get-my-name", "Francesco") - .withInput(startMessage(1), inputCmd("Till")) - .expectingOutput( - runCmd(1, "get-my-name"), - proposeRunCompletion(1, "Francesco"), - suspensionMessage(1)), - this.consecutiveSideEffect("Francesco") - .withInput(startMessage(3), inputCmd("Till"), runCmd(1), runCompletion(1, "Francesco")) - .expectingOutput(runCmd(2), proposeRunCompletion(2, "FRANCESCO"), suspensionMessage(2)) - .named("Suspends on second run"), - this.consecutiveSideEffect("Francesco") - .withInput( - startMessage(5), - inputCmd("Till"), - runCmd(1), - runCmd(2), - runCompletion(1, "Francesco"), - runCompletion(2, "FRANCESCO")) - .expectingOutput(outputCmd("Hello FRANCESCO"), END_MESSAGE) - .named("With optimization and ack on first and second side effect will resume"), - this.failingSideEffect("my-side-effect", "some failure") - .withInput(startMessage(1), inputCmd()) - .assertingOutput( - msgs -> - assertThat(msgs) - .satisfiesExactly( - msg -> assertThat(msg).isEqualTo(runCmd(1, "my-side-effect")), - errorMessage( - errorMessage -> - assertThat(errorMessage) - .returns( - TerminalException.INTERNAL_SERVER_ERROR_CODE, - Protocol.ErrorMessage::getCode) - .returns(1, Protocol.ErrorMessage::getRelatedCommandIndex) - .returns( - (int) MessageType.RunCommandMessage.encode(), - Protocol.ErrorMessage::getRelatedCommandType) - .returns( - "my-side-effect", - Protocol.ErrorMessage::getRelatedCommandName) - .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains("some failure")))) - .named("Fail on first attempt"), - this.failingSideEffect("my-side-effect", "some failure") - .withInput(startMessage(2), inputCmd(), runCmd(1, "my-side-effect")) - .assertingOutput( - containsOnly( - errorMessage( - errorMessage -> - assertThat(errorMessage) - .returns( - TerminalException.INTERNAL_SERVER_ERROR_CODE, - Protocol.ErrorMessage::getCode) - .returns(1, Protocol.ErrorMessage::getRelatedCommandIndex) - .returns( - (int) MessageType.RunCommandMessage.encode(), - Protocol.ErrorMessage::getRelatedCommandType) - .returns( - "my-side-effect", Protocol.ErrorMessage::getRelatedCommandName) - .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains("some failure")))) - .named("Fail on second attempt"), - this.awaitAllSideEffectWithFirstFailing( - "first-side-effect", "second-side-effect", "Francesco", "some failure") - .withInput(startMessage(1), inputCmd()) - .assertingOutput( - msgs -> { - // The thing here is, it depends on timing. Sometimes we might get the proposal - // for the succeeded one, sometimes not. - // - // So we need to take that in account in the assertions. - assertThat(msgs).size().isBetween(3, 4); - assertThat(msgs) - .satisfies( - msg -> assertThat(msg).isEqualTo(runCmd(1, "first-side-effect")), - Index.atIndex(0)); - assertThat(msgs) - .satisfies( - msg -> assertThat(msg).isEqualTo(runCmd(2, "second-side-effect")), - Index.atIndex(1)); - - if (msgs.size() == 4) { - // If there's four messages, the third one must be the run completion proposal - assertThat(msgs) - .satisfies( - msg -> - assertThat(msg) - .isEqualTo(proposeRunCompletion(2, "Francesco").build()), - Index.atIndex(2)); - } - - // Last message must be the error - assertThat(msgs) - .satisfies( - errorMessage( - errorMessage -> - assertThat(errorMessage) - .returns( - TerminalException.INTERNAL_SERVER_ERROR_CODE, - Protocol.ErrorMessage::getCode) - .returns(1, Protocol.ErrorMessage::getRelatedCommandIndex) - .returns( - (int) MessageType.RunCommandMessage.encode(), - Protocol.ErrorMessage::getRelatedCommandType) - .returns( - "first-side-effect", - Protocol.ErrorMessage::getRelatedCommandName) - .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains("some failure")), - Index.atIndex(msgs.size() - 1)); - }) - .named("Fail the first side effect"), - this.awaitAllSideEffectWithFirstFailing( - "first-side-effect", "second-side-effect", "Francesco", "some failure") - .withInput( - startMessage(3), - inputCmd(), - runCmd(1, "first-side-effect"), - runCmd(2, "second-side-effect")) - .assertingOutput( - msgs -> { - // The thing here is, it depends on timing. Sometimes we might get the proposal - // for the succeeded one, sometimes not. - // - // So we need to take that in account in the assertions. - assertThat(msgs).size().isBetween(1, 2); - - if (msgs.size() == 2) { - // If there's four messages, the third one must be the run completion proposal - assertThat(msgs) - .satisfies( - msg -> - assertThat(msg) - .isEqualTo(proposeRunCompletion(2, "Francesco").build()), - Index.atIndex(0)); - } - - // Last message must be the error - assertThat(msgs) - .satisfies( - errorMessage( - errorMessage -> - assertThat(errorMessage) - .returns( - TerminalException.INTERNAL_SERVER_ERROR_CODE, - Protocol.ErrorMessage::getCode) - .returns(1, Protocol.ErrorMessage::getRelatedCommandIndex) - .returns( - (int) MessageType.RunCommandMessage.encode(), - Protocol.ErrorMessage::getRelatedCommandType) - .returns( - "first-side-effect", - Protocol.ErrorMessage::getRelatedCommandName) - .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains("some failure")), - Index.atIndex(msgs.size() - 1)); - }) - .named("Fail the first side effect during replay"), - this.awaitAllSideEffectWithSecondFailing( - "first-side-effect", "second-side-effect", "Francesco", "some failure") - .withInput(startMessage(1), inputCmd()) - .assertingOutput( - msgs -> { - // The thing here is, it depends on timing. Sometimes we might get the proposal - // for the succeeded one, sometimes not. - // - // So we need to take that in account in the assertions. - assertThat(msgs).size().isBetween(3, 4); - assertThat(msgs) - .satisfies( - msg -> assertThat(msg).isEqualTo(runCmd(1, "first-side-effect")), - Index.atIndex(0)); - assertThat(msgs) - .satisfies( - msg -> assertThat(msg).isEqualTo(runCmd(2, "second-side-effect")), - Index.atIndex(1)); - - if (msgs.size() == 4) { - // If there's four messages, the third one must be the run completion proposal - assertThat(msgs) - .satisfies( - msg -> - assertThat(msg) - .isEqualTo(proposeRunCompletion(1, "Francesco").build()), - Index.atIndex(2)); - } - - // Last message must be the error - assertThat(msgs) - .satisfies( - errorMessage( - errorMessage -> - assertThat(errorMessage) - .returns( - TerminalException.INTERNAL_SERVER_ERROR_CODE, - Protocol.ErrorMessage::getCode) - .returns(2, Protocol.ErrorMessage::getRelatedCommandIndex) - .returns( - (int) MessageType.RunCommandMessage.encode(), - Protocol.ErrorMessage::getRelatedCommandType) - .returns( - "second-side-effect", - Protocol.ErrorMessage::getRelatedCommandName) - .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains("some failure")), - Index.atIndex(msgs.size() - 1)); - }) - .named("Fail the second side effect"), - this.failingSideEffectWithRetryPolicy( - "some failure", - RetryPolicy.exponential(Duration.ofMillis(100), 1.0f).setMaxAttempts(2)) - .withInput(startMessage(1).setRetryCountSinceLastStoredEntry(0), inputCmd()) - .onlyBidiStream() - .assertingOutput( - msgs -> - assertThat(msgs) - .satisfiesExactly( - msg -> assertThat(msg).isEqualTo(runCmd(1)), - errorMessage( - errorMessage -> - assertThat(errorMessage) - .returns( - TerminalException.INTERNAL_SERVER_ERROR_CODE, - Protocol.ErrorMessage::getCode) - .returns(1, Protocol.ErrorMessage::getRelatedCommandIndex) - .returns( - (int) MessageType.RunCommandMessage.encode(), - Protocol.ErrorMessage::getRelatedCommandType) - .returns(100L, Protocol.ErrorMessage::getNextRetryDelay) - .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains("some failure")))) - .named("Should fail as retryable error with the attached next retry delay"), - this.failingSideEffectWithRetryPolicy( - "some failure", - RetryPolicy.exponential(Duration.ofMillis(100), 1.0f).setMaxAttempts(2)) - .withInput(startMessage(2).setRetryCountSinceLastStoredEntry(1), inputCmd(), runCmd(1)) - .expectingOutput( - proposeRunCompletion(1, 500, "java.lang.IllegalStateException: some failure"), - suspensionMessage(1)) - .named("Should convert retryable error to terminal"), - // --- Other tests - this.checkContextSwitching() - .withInput(startMessage(1), inputCmd()) - .onlyBidiStream() - .assertingOutput( - actualOutputMessages -> - assertThat(actualOutputMessages).element(2).isEqualTo(suspensionMessage(1))), - this.instantNow() - .withInput(startMessage(1), inputCmd()) - .onlyBidiStream() - .assertingOutput( - msgs -> - assertThat(msgs) - .satisfiesExactly( - msg -> - assertThat(msg) - .asInstanceOf(type(Protocol.RunCommandMessage.class)) - .returns(1, Protocol.RunCommandMessage::getResultCompletionId), - msg -> - assertThat(msg) - .asInstanceOf(type(Protocol.ProposeRunCompletionMessage.class)) - .returns( - 1, - Protocol.ProposeRunCompletionMessage::getResultCompletionId) - .extracting(Protocol.ProposeRunCompletionMessage::getValue) - .satisfies(this::assertIsInstant), - msg -> assertThat(msg).isEqualTo(suspensionMessage(1)))), - this.sideEffectGuard() - .withInput(startMessage(1), inputCmd()) - .assertingOutput( - msgs -> - assertThat(msgs) - .satisfiesExactly( - msg -> assertThat(msg).isEqualTo(runCmd(1)), - errorMessage( - errorMessage -> - assertThat(errorMessage) - .returns( - TerminalException.INTERNAL_SERVER_ERROR_CODE, - Protocol.ErrorMessage::getCode) - .returns(1, Protocol.ErrorMessage::getRelatedCommandIndex) - .returns( - (int) MessageType.RunCommandMessage.encode(), - Protocol.ErrorMessage::getRelatedCommandType) - .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains( - "Cannot invoke context method inside ctx.run()")))) - .named("Side effect guard prevents context usage inside run"), - this.sideEffectGuardAwait() - .withInput(startMessage(1), inputCmd()) - .assertingOutput( - msgs -> - assertThat(msgs) - .satisfiesExactly( - msg -> assertThat(msg).isInstanceOf(Protocol.SleepCommandMessage.class), - msg -> assertThat(msg).isEqualTo(runCmd(2)), - errorMessage( - errorMessage -> - assertThat(errorMessage) - .returns( - TerminalException.INTERNAL_SERVER_ERROR_CODE, - Protocol.ErrorMessage::getCode) - .returns(2, Protocol.ErrorMessage::getRelatedCommandIndex) - .returns( - (int) MessageType.RunCommandMessage.encode(), - Protocol.ErrorMessage::getRelatedCommandType) - .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains( - "Cannot invoke context method inside ctx.run()")))) - .named("Side effect guard prevents awaiting durable future inside run")); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/SleepTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/SleepTestSuite.java deleted file mode 100644 index 05ecdbf48..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/SleepTestSuite.java +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.InstanceOfAssertFactories.LONG; -import static org.assertj.core.api.InstanceOfAssertFactories.type; - -import com.google.protobuf.MessageLiteOrBuilder; -import dev.restate.sdk.core.generated.protocol.Protocol; -import java.time.Instant; -import java.util.function.Function; -import java.util.stream.IntStream; -import java.util.stream.Stream; - -public abstract class SleepTestSuite implements TestDefinitions.TestSuite { - - final Long startTime = System.currentTimeMillis(); - - protected abstract TestInvocationBuilder sleepGreeter(); - - protected abstract TestInvocationBuilder manySleeps(); - - @Override - public Stream definitions() { - return Stream.of( - this.sleepGreeter() - .withInput(startMessage(1), inputCmd("Till")) - .assertingOutput( - messageLites -> { - assertThat(messageLites) - .element(0) - .asInstanceOf(type(Protocol.SleepCommandMessage.class)) - .extracting(Protocol.SleepCommandMessage::getWakeUpTime, LONG) - .isGreaterThanOrEqualTo(startTime + 1000) - .isLessThanOrEqualTo(Instant.now().toEpochMilli() + 1000); - - assertThat(messageLites) - .element(1) - .isInstanceOf(Protocol.SuspensionMessage.class); - }) - .named("Sleep 1000 ms not completed"), - this.sleepGreeter() - .withInput( - startMessage(2), - inputCmd("Till"), - Protocol.SleepCommandMessage.newBuilder() - .setWakeUpTime(Instant.now().toEpochMilli()) - .setResultCompletionId(1), - Protocol.SleepCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setVoid(Protocol.Void.getDefaultInstance())) - .expectingOutput(outputCmd("Hello"), END_MESSAGE) - .named("Sleep 1000 ms sleep completed"), - this.sleepGreeter() - .withInput( - startMessage(2), - inputCmd("Till"), - Protocol.SleepCommandMessage.newBuilder() - .setResultCompletionId(1) - .setWakeUpTime(Instant.now().toEpochMilli()) - .build()) - .expectingOutput(suspensionMessage(1)) - .named("Sleep 1000 ms still sleeping"), - this.manySleeps() - .withInput( - Stream.concat( - Stream.of(startMessage(14), inputCmd("Till")), - IntStream.rangeClosed(1, 10) - .mapToObj( - i -> - (i % 3 == 0) - ? Stream.of( - Protocol.SleepCommandMessage.newBuilder() - .setWakeUpTime(Instant.now().toEpochMilli()) - .setResultCompletionId(i), - Protocol.SleepCompletionNotificationMessage.newBuilder() - .setCompletionId(i) - .setVoid(Protocol.Void.getDefaultInstance())) - : Stream.of( - Protocol.SleepCommandMessage.newBuilder() - .setWakeUpTime(Instant.now().toEpochMilli()) - .setResultCompletionId(i))) - .flatMap(Function.identity()))) - .expectingOutput(suspensionMessage(1, 2, 4, 5, 7, 8, 10)) - .named("Sleep 1000 ms sleep completed"), - this.sleepGreeter() - .withInput( - startMessage(1), - inputCmd("Till"), - Protocol.SleepCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setVoid(Protocol.Void.getDefaultInstance())) - .onlyBidiStream() - .assertingOutput( - messageLites -> { - assertThat(messageLites) - .element(0) - .isInstanceOf(Protocol.SleepCommandMessage.class); - assertThat(messageLites).element(1).isEqualTo(outputCmd("Hello")); - assertThat(messageLites).element(2).isEqualTo(END_MESSAGE); - }) - .named("Failing sleep")); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/StateMachineFailuresTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/StateMachineFailuresTestSuite.java deleted file mode 100644 index 1bd24a839..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/StateMachineFailuresTestSuite.java +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.AssertUtils.*; -import static dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.InstanceOfAssertFactories.STRING; - -import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.serde.Serde; -import java.nio.charset.StandardCharsets; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.Stream; - -public abstract class StateMachineFailuresTestSuite implements TestDefinitions.TestSuite { - - protected abstract TestInvocationBuilder getState(AtomicInteger nonTerminalExceptionsSeen); - - protected abstract TestInvocationBuilder sideEffectFailure(Serde serde); - - protected abstract TestInvocationBuilder awaitRunAfterProgressWasMade(); - - protected abstract TestInvocationBuilder awaitSleepAfterProgressWasMade(); - - protected abstract TestInvocationBuilder awaitAwakeableAfterProgressWasMade(); - - private static final Serde FAILING_SERIALIZATION_INTEGER_TYPE_TAG = - Serde.using( - i -> { - throw new IllegalStateException("Cannot serialize integer"); - }, - b -> Integer.parseInt(new String(b, StandardCharsets.UTF_8))); - - private static final Serde FAILING_DESERIALIZATION_INTEGER_TYPE_TAG = - Serde.using( - i -> Integer.toString(i).getBytes(StandardCharsets.UTF_8), - b -> { - throw new IllegalStateException("Cannot deserialize integer"); - }); - - @Override - public Stream definitions() { - AtomicInteger nonTerminalExceptionsSeenTest1 = new AtomicInteger(); - AtomicInteger nonTerminalExceptionsSeenTest2 = new AtomicInteger(); - - return Stream.of( - this.getState(nonTerminalExceptionsSeenTest1) - .withInput(startMessage(2), inputCmd("Till"), getLazyStateCmd(1, "Something")) - .assertingOutput( - msgs -> { - assertThat(msgs) - .satisfiesExactly( - protocolExceptionErrorMessage(ProtocolException.JOURNAL_MISMATCH_CODE)); - assertThat(nonTerminalExceptionsSeenTest1).hasValue(0); - }) - .named("Protocol Exception"), - this.getState(nonTerminalExceptionsSeenTest2) - .withInput( - startMessage(2), - inputCmd("Till"), - getLazyStateCmd(1, "STATE"), - getLazyStateCompletion(1, "This is not an integer")) - .assertingOutput( - msgs -> { - assertThat(msgs) - .satisfiesExactly( - errorDescriptionStartingWith( - NumberFormatException.class.getCanonicalName())); - assertThat(nonTerminalExceptionsSeenTest2).hasValue(0); - }) - .named("Serde error"), - this.sideEffectFailure(FAILING_SERIALIZATION_INTEGER_TYPE_TAG) - .withInput(startMessage(1), inputCmd("Till")) - .assertingOutput( - msgs -> - assertThat(msgs.get(1)) - .satisfies( - errorDescriptionStartingWith( - IllegalStateException.class.getCanonicalName()))) - .named("Serde serialization error"), - this.sideEffectFailure(FAILING_DESERIALIZATION_INTEGER_TYPE_TAG) - .withInput( - startMessage(3), - inputCmd("Till"), - runCmd(1), - Protocol.RunCompletionNotificationMessage.newBuilder() - .setCompletionId(1) - .setValue(Protocol.Value.getDefaultInstance()) - .build()) - .assertingOutput( - containsOnly( - errorDescriptionStartingWith(IllegalStateException.class.getCanonicalName()))) - .named("Serde deserialization error"), - // --- Uncompleted doProgress during replay (bad await) tests - this.awaitRunAfterProgressWasMade() - .withInput( - startMessage(4), - inputCmd(), - runCmd(1, "my-side-effect"), - Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(2).build(), - Protocol.SleepCompletionNotificationMessage.newBuilder() - .setCompletionId(2) - .setVoid(Protocol.Void.getDefaultInstance()) - .build()) - .assertingOutput( - containsOnly( - errorMessage( - errorMessage -> - assertThat(errorMessage) - .returns( - ProtocolException.JOURNAL_MISMATCH_CODE, - Protocol.ErrorMessage::getCode) - .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains("could not be replayed") - .contains("await")))) - .named("Add await on run after progress was made"), - this.awaitSleepAfterProgressWasMade() - .withInput( - startMessage(4), - inputCmd(), - Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(1).build(), - Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(2).build(), - Protocol.SleepCompletionNotificationMessage.newBuilder() - .setCompletionId(2) - .setVoid(Protocol.Void.getDefaultInstance()) - .build()) - .assertingOutput( - containsOnly( - errorMessage( - errorMessage -> - assertThat(errorMessage) - .returns( - ProtocolException.JOURNAL_MISMATCH_CODE, - Protocol.ErrorMessage::getCode) - .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains("could not be replayed") - .contains("await")))) - .named("Add await on sleep after progress was made"), - this.awaitAwakeableAfterProgressWasMade() - .withInput( - startMessage(3), - inputCmd(), - Protocol.SleepCommandMessage.newBuilder().setResultCompletionId(2).build(), - Protocol.SleepCompletionNotificationMessage.newBuilder() - .setCompletionId(2) - .setVoid(Protocol.Void.getDefaultInstance()) - .build()) - .assertingOutput( - containsOnly( - errorMessage( - errorMessage -> - assertThat(errorMessage) - .returns( - ProtocolException.JOURNAL_MISMATCH_CODE, - Protocol.ErrorMessage::getCode) - .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains("could not be replayed") - .contains("await")))) - .named("Add await on awakeable after progress was made")); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/StateTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/StateTestSuite.java deleted file mode 100644 index 3c73b74ea..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/StateTestSuite.java +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.AssertUtils.*; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import static org.assertj.core.api.InstanceOfAssertFactories.STRING; - -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.generated.protocol.Protocol; -import java.util.stream.Stream; - -public abstract class StateTestSuite implements TestDefinitions.TestSuite { - - protected abstract TestInvocationBuilder getState(); - - protected abstract TestInvocationBuilder getAndSetState(); - - protected abstract TestInvocationBuilder setNullState(); - - @Override - public Stream definitions() { - return Stream.of( - this.getState() - .withInput( - startMessage(3), - inputCmd("Till"), - getLazyStateCmd(1, "STATE"), - getLazyStateCompletion(1, "Francesco")) - .expectingOutput(outputCmd("Hello Francesco"), END_MESSAGE) - .named("With GetStateEntry already completed"), - this.getState() - .withInput( - startMessage(3), - inputCmd("Till"), - getLazyStateCmd(1, "STATE"), - getLazyStateCompletionEmpty(1)) - .expectingOutput(outputCmd("Hello Unknown"), END_MESSAGE) - .named("With GetStateEntry already completed empty"), - this.getState() - .withInput(startMessage(1), inputCmd("Till")) - .expectingOutput(getLazyStateCmd(1, "STATE"), suspensionMessage(1)) - .named("Without GetStateEntry"), - this.getState() - .withInput(startMessage(2), inputCmd("Till"), getLazyStateCmd(1, "STATE")) - .expectingOutput(suspensionMessage(1)) - .named("With GetStateEntry not completed"), - this.getState() - .withInput( - startMessage(2), - inputCmd("Till"), - getLazyStateCmd(1, "STATE"), - getLazyStateCompletion(1, "Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Hello Francesco"), END_MESSAGE) - .named("With GetStateEntry and completed with later CompletionFrame"), - this.getState() - .withInput(startMessage(1), inputCmd("Till"), getLazyStateCompletion(1, "Francesco")) - .onlyBidiStream() - .expectingOutput(getLazyStateCmd(1, "STATE"), outputCmd("Hello Francesco"), END_MESSAGE) - .named("Without GetStateEntry and completed with later CompletionFrame"), - this.getAndSetState() - .withInput( - startMessage(4), - inputCmd("Till"), - getLazyStateCmd(1, "STATE"), - getLazyStateCompletion(1, "Francesco"), - setStateCmd("STATE", "Till")) - .expectingOutput(outputCmd("Hello Francesco"), END_MESSAGE) - .named("With GetState and SetState"), - this.getAndSetState() - .withInput( - startMessage(3), - inputCmd("Till"), - getLazyStateCmd(1, "STATE"), - getLazyStateCompletion(1, "Francesco")) - .expectingOutput( - setStateCmd("STATE", "Till"), outputCmd("Hello Francesco"), END_MESSAGE) - .named("With GetState already completed"), - this.getAndSetState() - .withInput(startMessage(1), inputCmd("Till"), getLazyStateCompletion(1, "Francesco")) - .onlyBidiStream() - .expectingOutput( - getLazyStateCmd(1, "STATE"), - setStateCmd("STATE", "Till"), - outputCmd("Hello Francesco"), - END_MESSAGE) - .named("With GetState completed later"), - this.setNullState() - .withInput(startMessage(1), inputCmd("Till")) - .assertingOutput( - containsOnly( - errorMessage( - errorMessage -> - assertThat(errorMessage) - .extracting(Protocol.ErrorMessage::getStacktrace, STRING) - .startsWith(NullPointerException.class.getName())))) - .named("Set null state")); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/TestDefinitions.java b/sdk-core/src/test/java/dev/restate/sdk/core/TestDefinitions.java deleted file mode 100644 index b10fdc265..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/TestDefinitions.java +++ /dev/null @@ -1,357 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static org.assertj.core.api.Assertions.assertThat; - -import com.google.protobuf.MessageLite; -import com.google.protobuf.MessageLiteOrBuilder; -import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.core.statemachine.InvocationInput; -import dev.restate.sdk.core.statemachine.MessageHeader; -import dev.restate.sdk.core.statemachine.ProtoUtils; -import dev.restate.sdk.endpoint.definition.HandlerRunner; -import dev.restate.sdk.endpoint.definition.ServiceDefinition; -import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactories; -import java.util.*; -import java.util.function.Consumer; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.assertj.core.api.InstanceOfAssertFactories; -import org.jspecify.annotations.Nullable; - -public final class TestDefinitions { - - private TestDefinitions() {} - - public interface TestDefinition { - ServiceDefinition getServiceDefinition(); - - HandlerRunner.Options getServiceOptions(); - - String getMethod(); - - boolean isOnlyUnbuffered(); - - boolean isEnablePreviewContext(); - - List getInput(); - - Consumer> getOutputAssert(); - - String getTestCaseName(); - - default boolean isValid() { - return this.getInvalidReason() == null; - } - - @Nullable String getInvalidReason(); - } - - public interface TestSuite { - Stream definitions(); - } - - public interface TestExecutor { - boolean buffered(); - - void executeTest(TestDefinition definition); - } - - public static TestInvocationBuilder testInvocation(Supplier svcSupplier, String handler) { - Object service; - try { - service = svcSupplier.get(); - } catch (UnsupportedOperationException e) { - return new TestInvocationBuilder(Objects.requireNonNull(e.getMessage())); - } - return testInvocation(service, handler); - } - - public static TestInvocationBuilder testInvocation(Object service, String handler) { - if (service instanceof ServiceDefinition) { - return new TestInvocationBuilder((ServiceDefinition) service, null, handler); - } - - // In case it's code generated, discover the adapter - ServiceDefinition serviceDefinition = - ServiceDefinitionFactories.discover(service).create(service, null); - return new TestInvocationBuilder(serviceDefinition, null, handler); - } - - public static TestInvocationBuilder testInvocation( - ServiceDefinition service, HandlerRunner.Options options, String handler) { - return new TestInvocationBuilder(service, options, handler); - } - - public static TestInvocationBuilder unsupported(String reason) { - return new TestInvocationBuilder(Objects.requireNonNull(reason)); - } - - public static class TestInvocationBuilder { - protected final @Nullable ServiceDefinition service; - protected final HandlerRunner.@Nullable Options options; - protected final @Nullable String handler; - protected final @Nullable String invalidReason; - - TestInvocationBuilder( - ServiceDefinition service, HandlerRunner.@Nullable Options options, String handler) { - this.service = service; - this.options = options; - this.handler = handler; - - this.invalidReason = null; - } - - TestInvocationBuilder(String invalidReason) { - this.service = null; - this.options = null; - this.handler = null; - - this.invalidReason = invalidReason; - } - - public WithInputBuilder withInput(Stream messages) { - if (invalidReason != null) { - return new WithInputBuilder(invalidReason); - } - - return new WithInputBuilder( - service, - options, - handler, - messages - .map( - msgOrBuilder -> { - MessageLite msg = ProtoUtils.build(msgOrBuilder); - return InvocationInput.of(MessageHeader.fromMessage(msg), msg); - }) - .collect(Collectors.toList())); - } - - public WithInputBuilder withInput(MessageLiteOrBuilder... messages) { - return withInput(Arrays.stream(messages)); - } - } - - public static class WithInputBuilder extends TestInvocationBuilder { - private final List input; - private boolean onlyUnbuffered = false; - private boolean enablePreviewContext = false; - - WithInputBuilder(@Nullable String invalidReason) { - super(invalidReason); - this.input = Collections.emptyList(); - } - - WithInputBuilder( - ServiceDefinition service, - HandlerRunner.@Nullable Options options, - String method, - List input) { - super(service, options, method); - this.input = new ArrayList<>(input); - } - - @Override - public WithInputBuilder withInput(MessageLiteOrBuilder... messages) { - if (this.invalidReason == null) { - this.input.addAll( - Arrays.stream(messages) - .map( - msgOrBuilder -> { - MessageLite msg = ProtoUtils.build(msgOrBuilder); - return InvocationInput.of(MessageHeader.fromMessage(msg), msg); - }) - .toList()); - } - return this; - } - - public WithInputBuilder onlyBidiStream() { - this.onlyUnbuffered = true; - return this; - } - - public WithInputBuilder enablePreviewContext() { - this.enablePreviewContext = true; - return this; - } - - public ExpectingOutputMessages expectingOutput(MessageLiteOrBuilder... messages) { - List builtMessages = - Arrays.stream(messages).map(ProtoUtils::build).collect(Collectors.toList()); - return assertingOutput( - actual -> - assertThat(actual) - .asInstanceOf(InstanceOfAssertFactories.LIST) - .containsExactlyElementsOf(builtMessages)); - } - - public ExpectingOutputMessages assertingOutput(Consumer> messages) { - return new ExpectingOutputMessages( - service, - options, - invalidReason, - handler, - input, - onlyUnbuffered, - enablePreviewContext, - messages); - } - } - - public abstract static class BaseTestDefinition implements TestDefinition { - protected final @Nullable ServiceDefinition service; - protected final HandlerRunner.@Nullable Options options; - protected final @Nullable String invalidReason; - protected final String method; - protected final List input; - protected final boolean onlyUnbuffered; - protected final boolean enablePreviewContext; - protected final String named; - - private BaseTestDefinition( - @Nullable ServiceDefinition service, - HandlerRunner.@Nullable Options options, - @Nullable String invalidReason, - String method, - List input, - boolean onlyUnbuffered, - boolean enablePreviewContext, - String named) { - this.service = service; - this.options = options; - this.invalidReason = invalidReason; - this.method = method; - this.input = input; - this.onlyUnbuffered = onlyUnbuffered; - this.enablePreviewContext = enablePreviewContext; - this.named = named; - } - - @Override - public ServiceDefinition getServiceDefinition() { - return Objects.requireNonNull(service); - } - - @Override - public HandlerRunner.Options getServiceOptions() { - return options; - } - - @Override - public String getMethod() { - return method; - } - - @Override - public List getInput() { - return input; - } - - @Override - public boolean isOnlyUnbuffered() { - return onlyUnbuffered; - } - - @Override - public boolean isEnablePreviewContext() { - return enablePreviewContext; - } - - @Override - public String getTestCaseName() { - return this.named; - } - - @Override - @Nullable - public String getInvalidReason() { - return invalidReason; - } - } - - public static class ExpectingOutputMessages extends BaseTestDefinition { - private final Consumer> messagesAssert; - - private ExpectingOutputMessages( - @Nullable ServiceDefinition service, - HandlerRunner.@Nullable Options options, - @Nullable String invalidReason, - String method, - List input, - boolean onlyUnbuffered, - boolean enablePreviewContext, - Consumer> messagesAssert) { - super( - service, - options, - invalidReason, - method, - input, - onlyUnbuffered, - enablePreviewContext, - service != null ? service.getServiceName() + "#" + method : "Unknown"); - this.messagesAssert = messagesAssert; - } - - ExpectingOutputMessages( - @Nullable ServiceDefinition service, - HandlerRunner.@Nullable Options options, - @Nullable String invalidReason, - String method, - List input, - boolean onlyUnbuffered, - boolean enablePreviewContext, - Consumer> messagesAssert, - String named) { - super( - service, - options, - invalidReason, - method, - input, - onlyUnbuffered, - enablePreviewContext, - named); - this.messagesAssert = messagesAssert; - } - - public ExpectingOutputMessages named(String name) { - return new ExpectingOutputMessages( - service, - options, - invalidReason, - method, - input, - onlyUnbuffered, - enablePreviewContext, - messagesAssert, - this.named + ": " + name); - } - - @Override - public Consumer> getOutputAssert() { - return outputMessages -> { - messagesAssert.accept(outputMessages); - - // Assert the last message is either an OutputStreamEntry or a SuspensionMessage - assertThat(outputMessages) - .last() - .isNotNull() - .isInstanceOfAny( - Protocol.ErrorMessage.class, - Protocol.SuspensionMessage.class, - Protocol.EndMessage.class); - }; - } - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/TestRunner.java b/sdk-core/src/test/java/dev/restate/sdk/core/TestRunner.java deleted file mode 100644 index 902002a1f..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/TestRunner.java +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static java.lang.String.format; -import static org.assertj.core.api.Assertions.entry; -import static org.junit.jupiter.params.provider.Arguments.arguments; - -import com.google.protobuf.InvalidProtocolBufferException; -import dev.restate.sdk.core.TestDefinitions.TestDefinition; -import dev.restate.sdk.core.TestDefinitions.TestExecutor; -import dev.restate.sdk.core.TestDefinitions.TestSuite; -import dev.restate.sdk.core.statemachine.MessageType; -import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.List; -import java.util.Objects; -import java.util.stream.Stream; -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.TestInstance; -import org.junit.jupiter.api.extension.*; -import org.junit.jupiter.api.parallel.Execution; -import org.junit.jupiter.api.parallel.ExecutionMode; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.opentest4j.TestAbortedException; - -@TestInstance(TestInstance.Lifecycle.PER_CLASS) -public abstract class TestRunner { - - protected abstract Stream executors(); - - protected abstract Stream definitions(); - - final Stream source() { - List executors = executors().toList(); - - return definitions() - .flatMap(ts -> ts.definitions().map(def -> entry(ts.getClass().getName(), def))) - .flatMap( - entry -> - executors.stream() - .filter( - executor -> !entry.getValue().isOnlyUnbuffered() || !executor.buffered()) - .map( - executor -> - arguments( - "[" - + executor.getClass().getSimpleName() - + "][" - + entry.getKey() - + "] " - + entry.getValue().getTestCaseName(), - executor, - entry.getValue()))); - } - - private static class DisableInvalidTestDefinition implements InvocationInterceptor { - - @Override - public void interceptTestTemplateMethod( - Invocation invocation, - ReflectiveInvocationContext invocationContext, - ExtensionContext extensionContext) - throws Throwable { - Method testMethod = extensionContext.getRequiredTestMethod(); - List arguments = invocationContext.getArguments(); - if (arguments.isEmpty()) { - throw new ExtensionConfigurationException( - format( - "Can't disable based on arguments, because method %s had no parameters.", - testMethod.getName())); - } - - Object maybeTestDefinition = arguments.get(2); - if (!(maybeTestDefinition instanceof TestDefinition)) { - throw new ExtensionConfigurationException( - format( - "Expected second argument to be a TestDefinition, but is %s.", - maybeTestDefinition)); - } - - if (!((TestDefinition) maybeTestDefinition).isValid()) { - throw new TestAbortedException( - "Disabled test definition: " - + ((TestDefinition) maybeTestDefinition).getInvalidReason()); - } - invocation.proceed(); - } - } - - static { - registerMessageFormatters(); - } - - private static void registerMessageFormatters() { - Arrays.stream(MessageType.values()) - .map( - mt -> { - try { - return mt.messageParser().parseFrom(new byte[] {}).getClass(); - } catch (InvalidProtocolBufferException e) { - return null; - } - }) - .filter(Objects::nonNull) - .forEach( - messageClazz -> - Assertions.registerFormatterForType( - messageClazz, ml -> ml.getClass().getSimpleName() + " { " + ml + "}")); - } - - @ExtendWith(DisableInvalidTestDefinition.class) - @ParameterizedTest(name = "{index}: {0}") - @MethodSource("source") - @Execution(ExecutionMode.CONCURRENT) - void executeTest(String testName, TestExecutor executor, TestDefinition definition) { - executor.executeTest(definition); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/UserFailuresTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/UserFailuresTestSuite.java deleted file mode 100644 index 6dab29f14..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/UserFailuresTestSuite.java +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.AssertUtils.containsOnlyExactErrorMessage; -import static dev.restate.sdk.core.AssertUtils.exactErrorMessage; -import static dev.restate.sdk.core.TestDefinitions.*; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; -import static org.assertj.core.api.Assertions.assertThat; - -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.core.statemachine.ProtoUtils; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.Stream; - -public abstract class UserFailuresTestSuite implements TestSuite { - - public static final String MY_ERROR = "my error"; - - public static final String WHATEVER = "Whatever"; - - protected abstract TestInvocationBuilder throwIllegalStateException(); - - protected abstract TestInvocationBuilder sideEffectThrowIllegalStateException( - AtomicInteger nonTerminalExceptionsSeen); - - protected abstract TestInvocationBuilder throwTerminalException(int code, String message); - - protected abstract TestInvocationBuilder sideEffectThrowTerminalException( - int code, String message); - - @Override - public Stream definitions() { - AtomicInteger nonTerminalExceptionsSeen = new AtomicInteger(); - - return Stream.of( - // Cases returning ErrorMessage - this.throwIllegalStateException() - .withInput(startMessage(1), inputCmd()) - .assertingOutput(containsOnlyExactErrorMessage(new IllegalStateException("Whatever"))), - this.sideEffectThrowIllegalStateException(nonTerminalExceptionsSeen) - .withInput(startMessage(1), inputCmd()) - .assertingOutput( - msgs -> { - assertThat(msgs.get(1)) - .satisfies(exactErrorMessage(new IllegalStateException("Whatever"))); - - // Check the counter has not been incremented - assertThat(nonTerminalExceptionsSeen).hasValue(0); - }), - - // Cases completing the invocation with OutputStreamEntry.failure - this.throwTerminalException(TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR) - .withInput(startMessage(1), inputCmd()) - .expectingOutput( - outputCmd(TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR), END_MESSAGE) - .named("With internal error"), - this.throwTerminalException(501, WHATEVER) - .withInput(startMessage(1), inputCmd()) - .expectingOutput(outputCmd(501, WHATEVER), END_MESSAGE) - .named("With unknown error"), - this.sideEffectThrowTerminalException( - TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR) - .withInput(startMessage(1), inputCmd()) - .expectingOutput( - Protocol.RunCommandMessage.newBuilder().setResultCompletionId(1), - Protocol.ProposeRunCompletionMessage.newBuilder() - .setResultCompletionId(1) - .setFailure( - ProtoUtils.failure(TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR)), - suspensionMessage(1)) - .named("With internal error"), - this.sideEffectThrowTerminalException(501, WHATEVER) - .withInput(startMessage(1), inputCmd()) - .expectingOutput( - Protocol.RunCommandMessage.newBuilder().setResultCompletionId(1), - Protocol.ProposeRunCompletionMessage.newBuilder() - .setResultCompletionId(1) - .setFailure(ProtoUtils.failure(501, WHATEVER)), - suspensionMessage(1)) - .named("With unknown error"), - this.sideEffectThrowTerminalException( - TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR) - .withInput( - startMessage(3), - inputCmd(), - runCmd(1), - runCompletion(1, TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR)) - .expectingOutput( - outputCmd(TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR), END_MESSAGE) - .named("With internal error during replay"), - this.sideEffectThrowTerminalException(501, WHATEVER) - .withInput(startMessage(3), inputCmd(), runCmd(1), runCompletion(1, 501, WHATEVER)) - .expectingOutput(outputCmd(501, WHATEVER), END_MESSAGE) - .named("With unknown error during replay")); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AsyncResultTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AsyncResultTest.java deleted file mode 100644 index 4699870ea..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AsyncResultTest.java +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.*; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; -import static org.assertj.core.api.Assertions.assertThat; - -import dev.restate.sdk.DurableFuture; -import dev.restate.sdk.Select; -import dev.restate.sdk.common.StateKey; -import dev.restate.sdk.common.TimeoutException; -import dev.restate.sdk.core.AsyncResultTestSuite; -import dev.restate.sdk.core.TestDefinitions; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.serde.Serde; -import java.time.Duration; -import java.util.stream.Stream; - -public class AsyncResultTest extends AsyncResultTestSuite { - - @Override - protected TestInvocationBuilder reverseAwaitOrder() { - return testDefinitionForVirtualObject( - "ReverseAwaitOrder", - Serde.VOID, - TestSerdes.STRING, - (context, unused) -> { - DurableFuture a1 = callGreeterGreetService(context, "Francesco"); - DurableFuture a2 = callGreeterGreetService(context, "Till"); - - String a2Res = a2.await(); - context.set(StateKey.of("A2", TestSerdes.STRING), a2Res); - - String a1Res = a1.await(); - - return a1Res + "-" + a2Res; - }); - } - - @Override - protected TestInvocationBuilder awaitTwiceTheSameAwaitable() { - return testDefinitionForService( - "AwaitTwiceTheSameAwaitable", - Serde.VOID, - TestSerdes.STRING, - (context, unused) -> { - DurableFuture a = callGreeterGreetService(context, "Francesco"); - - return a.await() + "-" + a.await(); - }); - } - - @Override - protected TestInvocationBuilder awaitAll() { - return testDefinitionForService( - "AwaitAll", - Serde.VOID, - TestSerdes.STRING, - (context, unused) -> { - DurableFuture a1 = callGreeterGreetService(context, "Francesco"); - DurableFuture a2 = callGreeterGreetService(context, "Till"); - - DurableFuture.all(a1, a2).await(); - - return a1.await() + "-" + a2.await(); - }); - } - - @Override - protected TestInvocationBuilder awaitAny() { - return testDefinitionForService( - "AwaitAny", - Serde.VOID, - TestSerdes.STRING, - (context, unused) -> { - DurableFuture a1 = callGreeterGreetService(context, "Francesco"); - DurableFuture a2 = callGreeterGreetService(context, "Till"); - - return Select.select().or(a1).or(a2).await(); - }); - } - - @Override - protected TestInvocationBuilder combineAnyWithAll() { - return testDefinitionForService( - "CombineAnyWithAll", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - DurableFuture a1 = ctx.awakeable(String.class); - DurableFuture a2 = ctx.awakeable(String.class); - DurableFuture a3 = ctx.awakeable(String.class); - DurableFuture a4 = ctx.awakeable(String.class); - - DurableFuture a12 = Select.select().or(a1).or(a2); - DurableFuture a23 = Select.select().or(a2).or(a3); - DurableFuture a34 = Select.select().or(a3).or(a4); - DurableFuture result = - DurableFuture.all(a12, a23, a34).map(v -> a12.await() + a23.await() + a34.await()); - - return result.await(); - }); - } - - @Override - protected TestInvocationBuilder awaitAnyIndex() { - return testDefinitionForService( - "AwaitAnyIndex", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - DurableFuture a1 = ctx.awakeable(String.class); - DurableFuture a2 = ctx.awakeable(String.class); - DurableFuture a3 = ctx.awakeable(String.class); - DurableFuture a4 = ctx.awakeable(String.class); - - return String.valueOf(DurableFuture.any(a1, DurableFuture.all(a2, a3), a4).await()); - }); - } - - @Override - protected TestInvocationBuilder awaitOnAlreadyResolvedAwaitables() { - return testDefinitionForService( - "AwaitOnAlreadyResolvedAwaitables", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - DurableFuture a1 = ctx.awakeable(String.class); - DurableFuture a2 = ctx.awakeable(String.class); - - DurableFuture a12 = DurableFuture.all(a1, a2); - DurableFuture a12and1 = DurableFuture.all(a12, a1); - DurableFuture a121and12 = DurableFuture.all(a12and1, a12); - - a12and1.await(); - a121and12.await(); - - return a1.await() + a2.await(); - }); - } - - @Override - protected TestInvocationBuilder awaitWithTimeout() { - return testDefinitionForService( - "AwaitWithTimeout", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - DurableFuture call = callGreeterGreetService(ctx, "Francesco"); - - String result; - try { - result = call.await(Duration.ofDays(1)); - } catch (TimeoutException e) { - result = "timeout"; - } - - return result; - }); - } - - private TestInvocationBuilder checkAwaitableMapThread() { - return testDefinitionForService( - "CheckAwaitableThread", - Serde.VOID, - Serde.VOID, - (ctx, unused) -> { - var currentThreadName = Thread.currentThread().getName().split("-"); - var currentThreadPool = currentThreadName[0] + "-" + currentThreadName[1]; - - callGreeterGreetService(ctx, "Francesco") - .map( - u -> { - assertThat(Thread.currentThread().getName()).startsWith(currentThreadPool); - return null; - }) - .await(); - - return null; - }); - } - - @Override - public Stream definitions() { - return Stream.concat( - super.definitions(), - Stream.of( - this.checkAwaitableMapThread() - .withInput( - startMessage(3), - inputCmd(), - callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), - callCompletion(2, "FRANCESCO")) - .onlyBidiStream() - .expectingOutput(outputCmd(), END_MESSAGE) - .named("Check map constraints"))); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AwakeableIdTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AwakeableIdTest.java deleted file mode 100644 index 3deafeaf4..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AwakeableIdTest.java +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; - -import dev.restate.sdk.core.AwakeableIdTestSuite; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.serde.Serde; - -public class AwakeableIdTest extends AwakeableIdTestSuite { - - @Override - protected TestInvocationBuilder returnAwakeableId() { - return testDefinitionForService( - "ReturnAwakeableId", - Serde.VOID, - TestSerdes.STRING, - (context, unused) -> context.awakeable(TestSerdes.STRING).id()); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CallTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CallTest.java deleted file mode 100644 index e7dc29b68..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CallTest.java +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; - -import dev.restate.common.Request; -import dev.restate.common.Slice; -import dev.restate.common.Target; -import dev.restate.sdk.core.CallTestSuite; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.serde.Serde; -import java.util.Map; - -public class CallTest extends CallTestSuite { - - @Override - protected TestInvocationBuilder oneWayCall( - Target target, String idempotencyKey, Map headers, Slice body) { - return testDefinitionForService( - "OneWayCall", - Serde.VOID, - Serde.VOID, - (context, unused) -> { - context.send( - Request.of(target, body.toByteArray()) - .headers(headers) - .idempotencyKey(idempotencyKey)); - return null; - }); - } - - @Override - protected TestInvocationBuilder implicitCancellation(Target target, Slice body) { - return testDefinitionForService( - "ImplicitCancellation", - Serde.VOID, - Serde.RAW, - (context, unused) -> context.call(Request.of(target, body.toByteArray())).await()); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenDiscoveryTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenDiscoveryTest.java deleted file mode 100644 index e9dc2e2f0..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenDiscoveryTest.java +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.AssertUtils.assertThatDiscovery; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.InstanceOfAssertFactories.type; - -import dev.restate.sdk.core.generated.manifest.Handler; -import dev.restate.sdk.core.generated.manifest.Input; -import dev.restate.sdk.core.generated.manifest.Output; -import dev.restate.sdk.core.generated.manifest.Service; -import dev.restate.sdk.endpoint.Endpoint; -import org.junit.jupiter.api.Test; - -public class CodegenDiscoveryTest { - - @Test - void checkCustomInputContentType() { - assertThatDiscovery(new CodegenTest.RawInputOutput()) - .extractingService("RawInputOutput") - .extractingHandler("rawInputWithCustomCt") - .extracting(Handler::getInput, type(Input.class)) - .extracting(Input::getContentType) - .isEqualTo("application/vnd.my.custom"); - } - - @Test - void checkCustomInputAcceptContentType() { - assertThatDiscovery(new CodegenTest.RawInputOutput()) - .extractingService("RawInputOutput") - .extractingHandler("rawInputWithCustomAccept") - .extracting(Handler::getInput, type(Input.class)) - .extracting(Input::getContentType) - .isEqualTo("application/*"); - } - - @Test - void checkCustomOutputContentType() { - assertThatDiscovery(new CodegenTest.RawInputOutput()) - .extractingService("RawInputOutput") - .extractingHandler("rawOutputWithCustomCT") - .extracting(Handler::getOutput, type(Output.class)) - .extracting(Output::getContentType) - .isEqualTo("application/vnd.my.custom"); - } - - @Test - void explicitNames() { - assertThatDiscovery((GreeterWithExplicitName) (context, request) -> "") - .extractingService("MyExplicitName") - .extractingHandler("my_greeter"); - assertThat(GreeterWithExplicitNameHandlers.Metadata.SERVICE_NAME).isEqualTo("MyExplicitName"); - } - - @Test - void workflowType() { - assertThatDiscovery(new CodegenTest.MyWorkflow()) - .extractingService("MyWorkflow") - .returns(Service.Ty.WORKFLOW, Service::getTy) - .extractingHandler("run") - .returns(Handler.Ty.WORKFLOW, Handler::getTy); - } - - @Test - void usingTransformer() { - assertThatDiscovery( - Endpoint.bind( - new CodegenTest.RawInputOutput(), - sd -> - sd.documentation("My service documentation") - .configureHandler( - "rawInputWithCustomCt", - hd -> hd.documentation("My handler documentation")))) - .extractingService("RawInputOutput") - .returns("My service documentation", Service::getDocumentation) - .extractingHandler("rawInputWithCustomCt") - .returns("My handler documentation", Handler::getDocumentation); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenTest.java deleted file mode 100644 index 759296085..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenTest.java +++ /dev/null @@ -1,327 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.TestDefinitions.testInvocation; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; -import static org.assertj.core.api.Assertions.assertThat; - -import dev.restate.common.Target; -import dev.restate.sdk.*; -import dev.restate.sdk.annotation.*; -import dev.restate.sdk.core.TestDefinitions; -import dev.restate.sdk.core.TestDefinitions.TestSuite; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.serde.Serde; -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.stream.Stream; - -public class CodegenTest implements TestSuite { - - @Service - static class ServiceGreeter { - @Handler - String greet(Context context, String request) { - return request; - } - } - - @VirtualObject - static class ObjectGreeter { - @Exclusive - String greet(ObjectContext context, String request) { - return request; - } - - @Handler - @Shared - String sharedGreet(SharedObjectContext context, String request) { - return request; - } - } - - @VirtualObject - public interface GreeterInterface { - @Exclusive - String greet(ObjectContext context, String request); - } - - private static class ObjectGreeterImplementedFromInterface implements GreeterInterface { - - @Override - public String greet(ObjectContext context, String request) { - return request; - } - } - - @Service - @Name("Empty") - static class Empty { - - @Handler - public String emptyInput(Context context) { - var client = CodegenTestEmptyClient.fromContext(context); - return client.emptyInput().await(); - } - - @Handler - public void emptyOutput(Context context, String request) { - var client = CodegenTestEmptyClient.fromContext(context); - client.emptyOutput(request).await(); - } - - @Handler - public void emptyInputOutput(Context context) { - var client = CodegenTestEmptyClient.fromContext(context); - client.emptyInputOutput().await(); - } - } - - @Service - @Name("PrimitiveTypes") - static class PrimitiveTypes { - - @Handler - public int primitiveOutput(Context context) { - var client = CodegenTestPrimitiveTypesClient.fromContext(context); - return client.primitiveOutput().await(); - } - - @Handler - public void primitiveInput(Context context, int input) { - var client = CodegenTestPrimitiveTypesClient.fromContext(context); - client.primitiveInput(input).await(); - } - } - - @VirtualObject - static class CornerCases { - @Exclusive - public String send(ObjectContext context, String request) { - // Just needs to compile - return CodegenTestCornerCasesClient.fromContext(context, request)._send("my_send").await(); - } - } - - @Workflow - static class WorkflowCornerCases { - @Workflow - public String run(WorkflowContext context, String request) { - return null; - } - - @Shared - public String submit(SharedWorkflowContext context, String request) { - // Just needs to compile - String ignored = - CodegenTestWorkflowCornerCasesClient.connect("invalid", request)._submit("my_send"); - CodegenTestWorkflowCornerCasesClient.connect("invalid", request).submit("my_send"); - return CodegenTestWorkflowCornerCasesClient.connect("invalid", request) - .workflowHandle() - .getOutput() - .response() - .getValue(); - } - } - - @Service - @Name("RawInputOutput") - static class RawInputOutput { - - @Handler - @Raw - public byte[] rawOutput(Context context) { - var client = CodegenTestRawInputOutputClient.fromContext(context); - return client.rawOutput().await(); - } - - @Handler - @Raw(contentType = "application/vnd.my.custom") - public byte[] rawOutputWithCustomCT(Context context) { - var client = CodegenTestRawInputOutputClient.fromContext(context); - return client.rawOutputWithCustomCT().await(); - } - - @Handler - public void rawInput(Context context, @Raw byte[] input) { - var client = CodegenTestRawInputOutputClient.fromContext(context); - client.rawInput(input).await(); - } - - @Handler - public void rawInputWithCustomCt( - Context context, @Raw(contentType = "application/vnd.my.custom") byte[] input) { - var client = CodegenTestRawInputOutputClient.fromContext(context); - client.rawInputWithCustomCt(input).await(); - } - - @Handler - public void rawInputWithCustomAccept( - Context context, - @Accept("application/*") @Raw(contentType = "application/vnd.my.custom") byte[] input) { - var client = CodegenTestRawInputOutputClient.fromContext(context); - client.rawInputWithCustomCt(input).await(); - } - } - - @Workflow - @Name("MyWorkflow") - static class MyWorkflow { - - @Workflow - public void run(WorkflowContext context, String myInput) { - var client = CodegenTestMyWorkflowClient.fromContext(context, context.key()); - client.send().sharedHandler(myInput); - } - - @Handler - public String sharedHandler(SharedWorkflowContext context, String myInput) { - var client = CodegenTestMyWorkflowClient.fromContext(context, context.key()); - return client.sharedHandler(myInput).await(); - } - } - - @Service - static class CheckedException { - @Handler - String greet(Context context, String request) throws IOException { - return request; - } - } - - @Service - @CustomSerdeFactory(MySerdeFactory.class) - static class CustomSerde { - @Handler - String greet(Context context, String request) { - assertThat(request).isEqualTo("INPUT"); - return "output"; - } - } - - @Override - public Stream definitions() { - return Stream.of( - testInvocation(ServiceGreeter::new, "greet") - .withInput(startMessage(1), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation(ObjectGreeter::new, "greet") - .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation(ObjectGreeter::new, "sharedGreet") - .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation(ObjectGreeterImplementedFromInterface::new, "greet") - .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation(Empty::new, "emptyInput") - .withInput(startMessage(1), inputCmd(), callCompletion(2, "Till")) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("Empty", "emptyInput")), - outputCmd("Till"), - END_MESSAGE) - .named("empty output"), - testInvocation(Empty::new, "emptyOutput") - .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("Empty", "emptyOutput"), "Francesco"), - outputCmd(), - END_MESSAGE) - .named("empty output"), - testInvocation(Empty::new, "emptyInputOutput") - .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("Empty", "emptyInputOutput")), - outputCmd(), - END_MESSAGE) - .named("empty input and empty output"), - testInvocation(PrimitiveTypes::new, "primitiveOutput") - .withInput(startMessage(1), inputCmd(), callCompletion(2, TestSerdes.INT, 10)) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, 2, Target.service("PrimitiveTypes", "primitiveOutput"), Serde.VOID, null), - outputCmd(TestSerdes.INT, 10), - END_MESSAGE) - .named("primitive output"), - testInvocation(PrimitiveTypes::new, "primitiveInput") - .withInput(startMessage(1), inputCmd(10), callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, 2, Target.service("PrimitiveTypes", "primitiveInput"), TestSerdes.INT, 10), - outputCmd(), - END_MESSAGE) - .named("primitive input"), - testInvocation(RawInputOutput::new, "rawInput") - .withInput( - startMessage(1), - inputCmd("{{".getBytes(StandardCharsets.UTF_8)), - callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("RawInputOutput", "rawInput"), - "{{".getBytes(StandardCharsets.UTF_8)), - outputCmd(), - END_MESSAGE), - testInvocation(RawInputOutput::new, "rawInputWithCustomCt") - .withInput( - startMessage(1), - inputCmd("{{".getBytes(StandardCharsets.UTF_8)), - callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("RawInputOutput", "rawInputWithCustomCt"), - "{{".getBytes(StandardCharsets.UTF_8)), - outputCmd(), - END_MESSAGE), - testInvocation(RawInputOutput::new, "rawOutput") - .withInput( - startMessage(1), - inputCmd(), - callCompletion(2, Serde.RAW, "{{".getBytes(StandardCharsets.UTF_8))) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("RawInputOutput", "rawOutput"), Serde.VOID, null), - outputCmd("{{".getBytes(StandardCharsets.UTF_8)), - END_MESSAGE), - testInvocation(RawInputOutput::new, "rawOutputWithCustomCT") - .withInput( - startMessage(1), - inputCmd(), - callCompletion(2, Serde.RAW, "{{".getBytes(StandardCharsets.UTF_8))) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("RawInputOutput", "rawOutputWithCustomCT"), - Serde.VOID, - null), - outputCmd("{{".getBytes(StandardCharsets.UTF_8)), - END_MESSAGE), - testInvocation(CustomSerde::new, "greet") - .withInput(startMessage(1), inputCmd(MySerdeFactory.SERDE, "input")) - .expectingOutput(outputCmd(MySerdeFactory.SERDE, "OUTPUT"), END_MESSAGE)); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/EagerStateTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/EagerStateTest.java deleted file mode 100644 index 96aa3dff5..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/EagerStateTest.java +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForVirtualObject; -import static org.assertj.core.api.Assertions.assertThat; - -import dev.restate.sdk.common.StateKey; -import dev.restate.sdk.core.EagerStateTestSuite; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.serde.Serde; - -public class EagerStateTest extends EagerStateTestSuite { - - @Override - protected TestInvocationBuilder getEmpty() { - return testDefinitionForVirtualObject( - "GetEmpty", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> - String.valueOf(ctx.get(StateKey.of("STATE", TestSerdes.STRING)).isEmpty())); - } - - @Override - protected TestInvocationBuilder get() { - return testDefinitionForVirtualObject( - "GetEmpty", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> ctx.get(StateKey.of("STATE", TestSerdes.STRING)).get()); - } - - @Override - protected TestInvocationBuilder getAppendAndGet() { - return testDefinitionForVirtualObject( - "GetAppendAndGet", - TestSerdes.STRING, - TestSerdes.STRING, - (ctx, input) -> { - String oldState = ctx.get(StateKey.of("STATE", TestSerdes.STRING)).get(); - ctx.set(StateKey.of("STATE", TestSerdes.STRING), oldState + input); - - return ctx.get(StateKey.of("STATE", TestSerdes.STRING)).get(); - }); - } - - @Override - protected TestInvocationBuilder getClearAndGet() { - return testDefinitionForVirtualObject( - "GetClearAndGet", - Serde.VOID, - TestSerdes.STRING, - (ctx, input) -> { - String oldState = ctx.get(StateKey.of("STATE", TestSerdes.STRING)).get(); - - ctx.clear(StateKey.of("STATE", TestSerdes.STRING)); - assertThat(ctx.get(StateKey.of("STATE", TestSerdes.STRING))).isEmpty(); - return oldState; - }); - } - - @Override - protected TestInvocationBuilder getClearAllAndGet() { - return testDefinitionForVirtualObject( - "GetClearAllAndGet", - Serde.VOID, - TestSerdes.STRING, - (ctx, input) -> { - String oldState = ctx.get(StateKey.of("STATE", TestSerdes.STRING)).get(); - - ctx.clearAll(); - assertThat(ctx.get(StateKey.of("STATE", TestSerdes.STRING))).isEmpty(); - assertThat(ctx.get(StateKey.of("ANOTHER_STATE", TestSerdes.STRING))).isEmpty(); - - return oldState; - }); - } - - @Override - protected TestInvocationBuilder listKeys() { - return testDefinitionForVirtualObject( - "ListKeys", - Serde.VOID, - TestSerdes.STRING, - (ctx, input) -> String.join(",", ctx.stateKeys())); - } - - @Override - protected TestInvocationBuilder consecutiveGetWithEmpty() { - return testDefinitionForVirtualObject( - "ConsecutiveGetWithEmpty", - Serde.VOID, - Serde.VOID, - (ctx, input) -> { - assertThat(ctx.get(StateKey.of("key-0", TestSerdes.STRING))).isEmpty(); - assertThat(ctx.get(StateKey.of("key-0", TestSerdes.STRING))).isEmpty(); - return null; - }); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/GreeterWithoutExplicitName.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/GreeterWithoutExplicitName.java deleted file mode 100644 index 77571d77d..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/GreeterWithoutExplicitName.java +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import dev.restate.sdk.Context; -import dev.restate.sdk.annotation.Handler; -import dev.restate.sdk.annotation.Service; - -@Service -public interface GreeterWithoutExplicitName { - @Handler - String greet(Context context, String request); -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/InvocationIdTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/InvocationIdTest.java deleted file mode 100644 index 97e10d934..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/InvocationIdTest.java +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; - -import dev.restate.sdk.core.InvocationIdTestSuite; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.serde.Serde; - -public class InvocationIdTest extends InvocationIdTestSuite { - - @Override - protected TestInvocationBuilder returnInvocationId() { - return testDefinitionForService( - "ReturnInvocationId", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> ctx.request().invocationId().toString()); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java deleted file mode 100644 index 1d6ef41c0..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.statemachine.ProtoUtils.GREETER_SERVICE_TARGET; - -import dev.restate.common.Request; -import dev.restate.common.function.ThrowingBiFunction; -import dev.restate.sdk.*; -import dev.restate.sdk.core.*; -import dev.restate.sdk.core.TestDefinitions.TestExecutor; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestDefinitions.TestSuite; -import dev.restate.sdk.core.javaapi.reflections.ReflectionTest; -import dev.restate.sdk.endpoint.definition.HandlerDefinition; -import dev.restate.sdk.endpoint.definition.HandlerType; -import dev.restate.sdk.endpoint.definition.ServiceDefinition; -import dev.restate.sdk.endpoint.definition.ServiceType; -import dev.restate.serde.Serde; -import dev.restate.serde.jackson.JacksonSerdeFactory; -import java.util.List; -import java.util.stream.Stream; - -public class JavaAPITests extends TestRunner { - - @Override - protected Stream executors() { - return Stream.of(MockRequestResponse.INSTANCE, MockBidiStream.INSTANCE); - } - - @Override - public Stream definitions() { - return Stream.of( - new AwakeableIdTest(), - new AsyncResultTest(), - new CallTest(), - new EagerStateTest(), - new StateTest(), - new InvocationIdTest(), - new OnlyInputAndOutputTest(), - new PromiseTest(), - new SideEffectTest(), - new SleepTest(), - new StateMachineFailuresTest(), - new UserFailuresTest(), - new RandomTest(), - new CodegenTest(), - new ReflectionTest()); - } - - public static TestInvocationBuilder testDefinitionForService( - String name, Serde reqSerde, Serde resSerde, ThrowingBiFunction runner) { - return TestDefinitions.testInvocation( - ServiceDefinition.of( - name, - ServiceType.SERVICE, - List.of( - HandlerDefinition.of( - "run", - HandlerType.SHARED, - reqSerde, - resSerde, - HandlerRunner.of(runner, new JacksonSerdeFactory(), null)))), - "run"); - } - - public static TestInvocationBuilder testDefinitionForVirtualObject( - String name, - Serde reqSerde, - Serde resSerde, - ThrowingBiFunction runner) { - return TestDefinitions.testInvocation( - ServiceDefinition.of( - name, - ServiceType.VIRTUAL_OBJECT, - List.of( - HandlerDefinition.of( - "run", - HandlerType.EXCLUSIVE, - reqSerde, - resSerde, - HandlerRunner.of(runner, new JacksonSerdeFactory(), null)))), - "run"); - } - - public static TestInvocationBuilder testDefinitionForWorkflow( - String name, - Serde reqSerde, - Serde resSerde, - ThrowingBiFunction runner) { - return TestDefinitions.testInvocation( - ServiceDefinition.of( - name, - ServiceType.WORKFLOW, - List.of( - HandlerDefinition.of( - "run", - HandlerType.WORKFLOW, - reqSerde, - resSerde, - HandlerRunner.of(runner, new JacksonSerdeFactory(), null)))), - "run"); - } - - public static DurableFuture callGreeterGreetService(Context ctx, String parameter) { - return ctx.call( - Request.of(GREETER_SERVICE_TARGET, TestSerdes.STRING, TestSerdes.STRING, parameter)); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/MySerdeFactory.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/MySerdeFactory.java deleted file mode 100644 index 8d7dbd53f..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/MySerdeFactory.java +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static org.assertj.core.api.Assertions.assertThat; - -import dev.restate.serde.Serde; -import dev.restate.serde.SerdeFactory; -import dev.restate.serde.TypeRef; -import java.nio.charset.StandardCharsets; - -@SuppressWarnings("unchecked") -public class MySerdeFactory implements SerdeFactory { - - public static Serde SERDE = - Serde.using( - "mycontent/type", - s -> s.toUpperCase().getBytes(), - b -> new String(b, StandardCharsets.UTF_8).toUpperCase()); - - @Override - public Serde create(TypeRef typeRef) { - assertThat(typeRef.getType()).isEqualTo(String.class); - return (Serde) SERDE; - } - - @Override - public Serde create(Class clazz) { - assertThat(clazz).isEqualTo(String.class); - return (Serde) SERDE; - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/NameInferenceTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/NameInferenceTest.java deleted file mode 100644 index 86f495c1c..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/NameInferenceTest.java +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static org.assertj.core.api.Assertions.assertThat; - -import org.junit.jupiter.api.Test; - -public class NameInferenceTest { - - @Test - void expectedName() { - assertThat(CodegenTestServiceGreeterHandlers.Metadata.SERVICE_NAME) - .isEqualTo("CodegenTestServiceGreeter"); - assertThat(GreeterWithoutExplicitNameHandlers.Metadata.SERVICE_NAME) - .isEqualTo("GreeterWithoutExplicitName"); - assertThat(GreeterWithExplicitNameHandlers.Metadata.SERVICE_NAME).isEqualTo("MyExplicitName"); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/OnlyInputAndOutputTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/OnlyInputAndOutputTest.java deleted file mode 100644 index 506b31114..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/OnlyInputAndOutputTest.java +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; - -import dev.restate.sdk.core.OnlyInputAndOutputTestSuite; -import dev.restate.sdk.core.TestDefinitions; -import dev.restate.sdk.core.TestSerdes; - -public class OnlyInputAndOutputTest extends OnlyInputAndOutputTestSuite { - - @Override - protected TestDefinitions.TestInvocationBuilder noSyscallsGreeter() { - return testDefinitionForService( - "NoSyscallsGreeter", - TestSerdes.STRING, - TestSerdes.STRING, - (ctx, input) -> "Hello " + input); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/PromiseTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/PromiseTest.java deleted file mode 100644 index b26221038..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/PromiseTest.java +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.*; - -import dev.restate.sdk.common.DurablePromiseKey; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.PromiseTestSuite; -import dev.restate.sdk.core.TestDefinitions; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.serde.Serde; - -public class PromiseTest extends PromiseTestSuite { - @Override - protected TestDefinitions.TestInvocationBuilder awaitPromise(String promiseKey) { - return testDefinitionForWorkflow( - "AwaitPromise", - Serde.VOID, - TestSerdes.STRING, - (context, unused) -> - context.promise(DurablePromiseKey.of(promiseKey, String.class)).future().await()); - } - - @Override - protected TestDefinitions.TestInvocationBuilder awaitPeekPromise( - String promiseKey, String emptyCaseReturnValue) { - return testDefinitionForWorkflow( - "PeekPromise", - Serde.VOID, - TestSerdes.STRING, - (context, unused) -> - context - .promise(DurablePromiseKey.of(promiseKey, String.class)) - .peek() - .orElse(emptyCaseReturnValue)); - } - - @Override - protected TestDefinitions.TestInvocationBuilder awaitIsPromiseCompleted(String promiseKey) { - return testDefinitionForWorkflow( - "IsCompletedPromise", - Serde.VOID, - TestSerdes.BOOLEAN, - (context, unused) -> - context.promise(DurablePromiseKey.of(promiseKey, String.class)).peek().isReady()); - } - - @Override - protected TestDefinitions.TestInvocationBuilder awaitResolvePromise( - String promiseKey, String completionValue) { - return testDefinitionForWorkflow( - "ResolvePromise", - Serde.VOID, - TestSerdes.BOOLEAN, - (context, unused) -> { - try { - context - .promiseHandle(DurablePromiseKey.of(promiseKey, String.class)) - .resolve(completionValue); - return true; - } catch (TerminalException e) { - return false; - } - }); - } - - @Override - protected TestDefinitions.TestInvocationBuilder awaitRejectPromise( - String promiseKey, String rejectReason) { - return testDefinitionForWorkflow( - "RejectPromise", - Serde.VOID, - TestSerdes.BOOLEAN, - (context, unused) -> { - try { - context - .promiseHandle(DurablePromiseKey.of(promiseKey, String.class)) - .reject(rejectReason); - return true; - } catch (TerminalException e) { - return false; - } - }); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/RandomTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/RandomTest.java deleted file mode 100644 index 7c9815a08..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/RandomTest.java +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; - -import dev.restate.sdk.core.RandomTestSuite; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.serde.Serde; -import java.util.Random; - -public class RandomTest extends RandomTestSuite { - - @Override - protected TestInvocationBuilder randomShouldBeDeterministic() { - return testDefinitionForService( - "RandomShouldBeDeterministic", - Serde.VOID, - TestSerdes.INT, - (ctx, unused) -> ctx.random().nextInt()); - } - - @Override - protected int getExpectedInt(long seed) { - return new Random(seed).nextInt(); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SideEffectTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SideEffectTest.java deleted file mode 100644 index c06a5cbd5..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SideEffectTest.java +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; -import static org.assertj.core.api.Assertions.assertThat; - -import com.google.protobuf.ByteString; -import dev.restate.common.Slice; -import dev.restate.sdk.DurableFuture; -import dev.restate.sdk.Restate; -import dev.restate.sdk.common.RetryPolicy; -import dev.restate.sdk.core.SideEffectTestSuite; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.serde.Serde; -import dev.restate.serde.jackson.JacksonSerdeFactory; -import java.time.Instant; -import java.util.List; -import java.util.Objects; - -public class SideEffectTest extends SideEffectTestSuite { - - @Override - protected TestInvocationBuilder sideEffect(String sideEffectOutput) { - return testDefinitionForService( - "SideEffect", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - String result = ctx.run(String.class, () -> sideEffectOutput); - return "Hello " + result; - }); - } - - @Override - protected TestInvocationBuilder namedSideEffect(String name, String sideEffectOutput) { - return testDefinitionForService( - "NamedSideEffect", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - String result = ctx.run(name, String.class, () -> sideEffectOutput); - return "Hello " + result; - }); - } - - @Override - protected TestInvocationBuilder consecutiveSideEffect(String sideEffectOutput) { - return testDefinitionForService( - "ConsecutiveSideEffect", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - String firstResult = ctx.run(String.class, () -> sideEffectOutput); - String secondResult = ctx.run(String.class, firstResult::toUpperCase); - - return "Hello " + secondResult; - }); - } - - @Override - protected TestInvocationBuilder checkContextSwitching() { - return testDefinitionForService( - "CheckContextSwitching", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - String currentThread = Thread.currentThread().getName(); - - String sideEffectThread = ctx.run(String.class, () -> Thread.currentThread().getName()); - - if (!Objects.equals(currentThread, sideEffectThread)) { - throw new IllegalStateException( - "Current thread and side effect thread do not match: " - + currentThread - + " != " - + sideEffectThread); - } - - return "Hello"; - }); - } - - @Override - protected TestInvocationBuilder failingSideEffect(String name, String reason) { - return testDefinitionForService( - "FailingSideEffect", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - ctx.run( - name, - () -> { - throw new IllegalStateException(reason); - }); - return null; - }); - } - - @Override - protected TestInvocationBuilder awaitAllSideEffectWithFirstFailing( - String firstSideEffect, String secondSideEffect, String successValue, String failureReason) { - return testDefinitionForService( - "AwaitAllSideEffectWithFirstFailing", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - DurableFuture fut1 = - ctx.runAsync( - firstSideEffect, - String.class, - () -> { - throw new IllegalStateException(failureReason); - }); - DurableFuture fut2 = - ctx.runAsync(secondSideEffect, String.class, () -> successValue); - DurableFuture.all(List.of(fut1, fut2)).await(); - return null; - }); - } - - @Override - protected TestInvocationBuilder awaitAllSideEffectWithSecondFailing( - String firstSideEffect, String secondSideEffect, String successValue, String failureReason) { - return testDefinitionForService( - "AwaitAllSideEffectWithFirstFailing", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - DurableFuture fut1 = - ctx.runAsync(firstSideEffect, String.class, () -> successValue); - DurableFuture fut2 = - ctx.runAsync( - secondSideEffect, - String.class, - () -> { - throw new IllegalStateException(failureReason); - }); - DurableFuture.all(List.of(fut1, fut2)).await(); - return null; - }); - } - - @Override - protected TestInvocationBuilder failingSideEffectWithRetryPolicy( - String reason, RetryPolicy retryPolicy) { - return testDefinitionForService( - "FailingSideEffectWithRetryPolicy", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - ctx.run( - null, - retryPolicy, - () -> { - throw new IllegalStateException(reason); - }); - return null; - }); - } - - @Override - protected TestInvocationBuilder sideEffectGuard() { - return testDefinitionForService( - "SideEffectGuard", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - ctx.run(() -> ctx.sleep(java.time.Duration.ofMillis(100))); - return null; - }); - } - - @Override - protected TestInvocationBuilder sideEffectGuardAwait() { - return testDefinitionForService( - "SideEffectGuardAwait", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - DurableFuture timer = ctx.timer("my-sleep", java.time.Duration.ofMillis(100)); - ctx.run(() -> timer.await()); - return null; - }); - } - - @Override - protected TestInvocationBuilder instantNow() { - return testDefinitionForService( - "InstantNow", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - var instant = Restate.instantNow(); - return null; - }); - } - - @Override - protected void assertIsInstant(ByteString bytes) { - Instant instant = - JacksonSerdeFactory.DEFAULT - .create(Instant.class) - .deserialize(Slice.wrap(bytes.asReadOnlyByteBuffer())); - assertThat(instant).isNotNull().isBefore(Instant.now()); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SleepTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SleepTest.java deleted file mode 100644 index e27783ba0..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SleepTest.java +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; - -import dev.restate.sdk.DurableFuture; -import dev.restate.sdk.core.SleepTestSuite; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.serde.Serde; -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; - -public class SleepTest extends SleepTestSuite { - - @Override - protected TestInvocationBuilder sleepGreeter() { - return testDefinitionForService( - "SleepGreeter", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - ctx.sleep(Duration.ofSeconds(1)); - return "Hello"; - }); - } - - @Override - protected TestInvocationBuilder manySleeps() { - return testDefinitionForService( - "ManySleeps", - Serde.VOID, - Serde.VOID, - (ctx, unused) -> { - List> collectedDurableFutures = new ArrayList<>(); - - for (int i = 0; i < 10; i++) { - collectedDurableFutures.add(ctx.timer(Duration.ofSeconds(1))); - } - - DurableFuture.all( - collectedDurableFutures.get(0), - collectedDurableFutures.get(1), - collectedDurableFutures - .subList(2, collectedDurableFutures.size()) - .toArray(DurableFuture[]::new)) - .await(); - - return null; - }); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateMachineFailuresTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateMachineFailuresTest.java deleted file mode 100644 index c346b060e..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateMachineFailuresTest.java +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; -import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForVirtualObject; - -import dev.restate.sdk.common.AbortedExecutionException; -import dev.restate.sdk.common.StateKey; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.StateMachineFailuresTestSuite; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.serde.Serde; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.concurrent.atomic.AtomicInteger; - -public class StateMachineFailuresTest extends StateMachineFailuresTestSuite { - - private static final StateKey STATE = - StateKey.of( - "STATE", - Serde.using( - i -> Integer.toString(i).getBytes(StandardCharsets.UTF_8), - b -> Integer.parseInt(new String(b, StandardCharsets.UTF_8)))); - - @Override - protected TestInvocationBuilder getState(AtomicInteger nonTerminalExceptionsSeen) { - return testDefinitionForVirtualObject( - "GetState", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - try { - ctx.get(STATE); - } catch (Throwable e) { - // A user should never catch Throwable!!! - if (AbortedExecutionException.INSTANCE.equals(e)) { - AbortedExecutionException.sneakyThrow(); - } - if (!(e instanceof TerminalException)) { - nonTerminalExceptionsSeen.addAndGet(1); - } else { - throw e; - } - } - - return "Francesco"; - }); - } - - @Override - protected TestInvocationBuilder sideEffectFailure(Serde serde) { - return testDefinitionForVirtualObject( - "SideEffectFailure", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - ctx.run(serde, () -> 0); - return "Francesco"; - }); - } - - @Override - protected TestInvocationBuilder awaitRunAfterProgressWasMade() { - return testDefinitionForService( - "AwaitRunAfterProgressWasMade", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - var runFuture = ctx.runAsync("my-side-effect", String.class, () -> "result"); - runFuture.await(); - return null; - }); - } - - @Override - protected TestInvocationBuilder awaitSleepAfterProgressWasMade() { - return testDefinitionForService( - "AwaitSleepAfterProgressWasMade", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - var sleepFuture = ctx.timer(Duration.ZERO); - sleepFuture.await(); - return null; - }); - } - - @Override - protected TestInvocationBuilder awaitAwakeableAfterProgressWasMade() { - return testDefinitionForService( - "AwaitAwakeableAfterProgressWasMade", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - var awakeable = ctx.awakeable(String.class); - awakeable.await(); - return null; - }); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateTest.java deleted file mode 100644 index 54a91a204..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateTest.java +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForVirtualObject; - -import dev.restate.sdk.common.StateKey; -import dev.restate.sdk.core.StateTestSuite; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.serde.Serde; - -public class StateTest extends StateTestSuite { - - @Override - protected TestInvocationBuilder getState() { - return testDefinitionForVirtualObject( - "GetState", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - String state = ctx.get(StateKey.of("STATE", String.class)).orElse("Unknown"); - - return "Hello " + state; - }); - } - - @Override - protected TestInvocationBuilder getAndSetState() { - return testDefinitionForVirtualObject( - "GetState", - TestSerdes.STRING, - TestSerdes.STRING, - (ctx, input) -> { - String state = ctx.get(StateKey.of("STATE", String.class)).get(); - - ctx.set(StateKey.of("STATE", String.class), input); - - return "Hello " + state; - }); - } - - @Override - protected TestInvocationBuilder setNullState() { - return testDefinitionForVirtualObject( - "GetState", - Serde.VOID, - TestSerdes.STRING, - (ctx, unused) -> { - ctx.set( - StateKey.of( - "STATE", - Serde.using( - l -> { - throw new IllegalStateException("Unexpected call to serde fn"); - }, - l -> { - throw new IllegalStateException("Unexpected call to serde fn"); - })), - null); - - throw new IllegalStateException("set did not fail"); - }); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/UserFailuresTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/UserFailuresTest.java deleted file mode 100644 index b54ab3227..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/UserFailuresTest.java +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; - -import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; - -import dev.restate.sdk.common.AbortedExecutionException; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; -import dev.restate.sdk.core.UserFailuresTestSuite; -import dev.restate.serde.Serde; -import java.util.concurrent.atomic.AtomicInteger; - -public class UserFailuresTest extends UserFailuresTestSuite { - - @Override - protected TestInvocationBuilder throwIllegalStateException() { - return testDefinitionForService( - "ThrowIllegalStateException", - Serde.VOID, - Serde.VOID, - (ctx, unused) -> { - throw new IllegalStateException("Whatever"); - }); - } - - @Override - protected TestInvocationBuilder sideEffectThrowIllegalStateException( - AtomicInteger nonTerminalExceptionsSeen) { - return testDefinitionForService( - "SideEffectThrowIllegalStateException", - Serde.VOID, - Serde.VOID, - (ctx, unused) -> { - try { - ctx.run( - () -> { - throw new IllegalStateException("Whatever"); - }); - } catch (Throwable e) { - // A user should never catch Throwable!!! - if (AbortedExecutionException.INSTANCE.equals(e)) { - AbortedExecutionException.sneakyThrow(); - } - if (!(e instanceof TerminalException)) { - nonTerminalExceptionsSeen.addAndGet(1); - } - throw e; - } - - throw new IllegalStateException("Unexpected end"); - }); - } - - @Override - protected TestInvocationBuilder throwTerminalException(int code, String message) { - return testDefinitionForService( - "ThrowTerminalException", - Serde.VOID, - Serde.VOID, - (ctx, unused) -> { - throw new TerminalException(code, message); - }); - } - - @Override - protected TestInvocationBuilder sideEffectThrowTerminalException(int code, String message) { - return testDefinitionForService( - "SideEffectThrowTerminalException", - Serde.VOID, - Serde.VOID, - (ctx, unused) -> { - ctx.run( - () -> { - throw new TerminalException(code, message); - }); - throw new IllegalStateException("This should not be reached"); - }); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/CheckedException.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/CheckedException.java deleted file mode 100644 index ecc486c90..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/CheckedException.java +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi.reflections; - -import dev.restate.sdk.annotation.Handler; -import dev.restate.sdk.annotation.Service; -import java.io.IOException; - -@Service -public class CheckedException { - @Handler - public String greet(String request) throws IOException { - return request; - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/CustomSerde.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/CustomSerde.java deleted file mode 100644 index d0b41081e..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/CustomSerde.java +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi.reflections; - -import static org.assertj.core.api.Assertions.assertThat; - -import dev.restate.sdk.annotation.CustomSerdeFactory; -import dev.restate.sdk.annotation.Handler; -import dev.restate.sdk.annotation.Service; -import dev.restate.sdk.core.javaapi.MySerdeFactory; - -@Service -@CustomSerdeFactory(MySerdeFactory.class) -public class CustomSerde { - @Handler - public String greet(String request) { - assertThat(request).isEqualTo("INPUT"); - return "output"; - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/Empty.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/Empty.java deleted file mode 100644 index 8585b4048..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/Empty.java +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi.reflections; - -import dev.restate.sdk.Restate; -import dev.restate.sdk.annotation.Handler; -import dev.restate.sdk.annotation.Name; -import dev.restate.sdk.annotation.Service; - -@Service -@Name("Empty") -public class Empty { - - @Handler - public String emptyInput() { - return Restate.service(Empty.class).emptyInput(); - } - - @Handler - public void emptyOutput(String request) { - Restate.service(Empty.class).emptyOutput(request); - } - - @Handler - public void emptyInputOutput() { - Restate.service(Empty.class).emptyInputOutput(); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/GreeterInterface.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/GreeterInterface.java deleted file mode 100644 index 8f9d267b5..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/GreeterInterface.java +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi.reflections; - -import dev.restate.sdk.annotation.Handler; -import dev.restate.sdk.annotation.Name; -import dev.restate.sdk.annotation.VirtualObject; - -@VirtualObject -@Name("GreeterInterface") -public interface GreeterInterface { - @Handler - String greet(String request); -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/GreeterWithExplicitName.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/GreeterWithExplicitName.java similarity index 92% rename from sdk-core/src/test/java/dev/restate/sdk/core/javaapi/GreeterWithExplicitName.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/GreeterWithExplicitName.java index 593a95940..eb10efc87 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/GreeterWithExplicitName.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/GreeterWithExplicitName.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi; +package dev.restate.sdk.core.javaapi.reflections; import dev.restate.sdk.Context; import dev.restate.sdk.annotation.Handler; diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ObjectGreeter.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ObjectGreeter.java deleted file mode 100644 index 5317ad328..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ObjectGreeter.java +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi.reflections; - -import dev.restate.sdk.annotation.Exclusive; -import dev.restate.sdk.annotation.Shared; -import dev.restate.sdk.annotation.VirtualObject; - -@VirtualObject -public class ObjectGreeter { - @Exclusive - public String greet(String request) { - return request; - } - - @Shared - public String sharedGreet(String request) { - return request; - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ObjectGreeterImplementedFromInterface.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ObjectGreeterImplementedFromInterface.java deleted file mode 100644 index ca49a003d..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ObjectGreeterImplementedFromInterface.java +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi.reflections; - -public class ObjectGreeterImplementedFromInterface implements GreeterInterface { - @Override - public String greet(String request) { - return request; - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/PrimitiveTypes.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/PrimitiveTypes.java deleted file mode 100644 index f747434ea..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/PrimitiveTypes.java +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi.reflections; - -import dev.restate.sdk.Restate; -import dev.restate.sdk.annotation.Handler; -import dev.restate.sdk.annotation.Name; -import dev.restate.sdk.annotation.Service; - -@Service -@Name("PrimitiveTypes") -public class PrimitiveTypes { - - @Handler - public int primitiveOutput() { - return Restate.service(PrimitiveTypes.class).primitiveOutput(); - } - - @Handler - public void primitiveInput(int input) { - Restate.service(PrimitiveTypes.class).primitiveInput(input); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionDiscoveryTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionDiscoveryTest.java index 9917863a7..327745f8b 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionDiscoveryTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionDiscoveryTest.java @@ -9,15 +9,12 @@ package dev.restate.sdk.core.javaapi.reflections; import static dev.restate.sdk.core.AssertUtils.assertThatDiscovery; -import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.InstanceOfAssertFactories.type; import dev.restate.sdk.core.generated.manifest.Handler; import dev.restate.sdk.core.generated.manifest.Input; import dev.restate.sdk.core.generated.manifest.Output; import dev.restate.sdk.core.generated.manifest.Service; -import dev.restate.sdk.core.javaapi.GreeterWithExplicitName; -import dev.restate.sdk.core.javaapi.GreeterWithExplicitNameHandlers; import dev.restate.sdk.endpoint.Endpoint; import dev.restate.serde.Serde; import org.junit.jupiter.api.Test; @@ -97,7 +94,6 @@ void explicitNames() { assertThatDiscovery((GreeterWithExplicitName) (context, request) -> "") .extractingService("MyExplicitName") .extractingHandler("my_greeter"); - assertThat(GreeterWithExplicitNameHandlers.Metadata.SERVICE_NAME).isEqualTo("MyExplicitName"); } @Test diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionTest.java deleted file mode 100644 index 0ca8e2767..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionTest.java +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi.reflections; - -import static dev.restate.sdk.core.TestDefinitions.testInvocation; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; - -import dev.restate.common.Target; -import dev.restate.sdk.core.TestDefinitions; -import dev.restate.sdk.core.TestDefinitions.TestSuite; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.sdk.core.javaapi.MySerdeFactory; -import dev.restate.serde.Serde; -import java.nio.charset.StandardCharsets; -import java.util.stream.Stream; - -public class ReflectionTest implements TestSuite { - - @Override - public Stream definitions() { - return Stream.of( - testInvocation(ServiceGreeter::new, "greet") - .withInput(startMessage(1), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation(ObjectGreeter::new, "greet") - .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation(ObjectGreeter::new, "sharedGreet") - .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation(ObjectGreeterImplementedFromInterface::new, "greet") - .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation(Empty::new, "emptyInput") - .withInput(startMessage(1), inputCmd(), callCompletion(2, "Till")) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("Empty", "emptyInput")), - outputCmd("Till"), - END_MESSAGE) - .named("empty output"), - testInvocation(Empty::new, "emptyOutput") - .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("Empty", "emptyOutput"), "Francesco"), - outputCmd(), - END_MESSAGE) - .named("empty output"), - testInvocation(Empty::new, "emptyInputOutput") - .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("Empty", "emptyInputOutput")), - outputCmd(), - END_MESSAGE) - .named("empty input and empty output"), - testInvocation(PrimitiveTypes::new, "primitiveOutput") - .withInput(startMessage(1), inputCmd(), callCompletion(2, TestSerdes.INT, 10)) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, 2, Target.service("PrimitiveTypes", "primitiveOutput"), Serde.VOID, null), - outputCmd(TestSerdes.INT, 10), - END_MESSAGE) - .named("primitive output"), - testInvocation(PrimitiveTypes::new, "primitiveInput") - .withInput(startMessage(1), inputCmd(10), callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, 2, Target.service("PrimitiveTypes", "primitiveInput"), TestSerdes.INT, 10), - outputCmd(), - END_MESSAGE) - .named("primitive input"), - testInvocation(RawInputOutput::new, "rawInput") - .withInput( - startMessage(1), - inputCmd("{{".getBytes(StandardCharsets.UTF_8)), - callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("RawInputOutput", "rawInput"), - "{{".getBytes(StandardCharsets.UTF_8)), - outputCmd(), - END_MESSAGE), - testInvocation(RawInputOutput::new, "rawInputWithCustomCt") - .withInput( - startMessage(1), - inputCmd("{{".getBytes(StandardCharsets.UTF_8)), - callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("RawInputOutput", "rawInputWithCustomCt"), - "{{".getBytes(StandardCharsets.UTF_8)), - outputCmd(), - END_MESSAGE), - testInvocation(RawInputOutput::new, "rawOutput") - .withInput( - startMessage(1), - inputCmd(), - callCompletion(2, Serde.RAW, "{{".getBytes(StandardCharsets.UTF_8))) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("RawInputOutput", "rawOutput"), Serde.VOID, null), - outputCmd("{{".getBytes(StandardCharsets.UTF_8)), - END_MESSAGE), - testInvocation(RawInputOutput::new, "rawOutputWithCustomCT") - .withInput( - startMessage(1), - inputCmd(), - callCompletion(2, Serde.RAW, "{{".getBytes(StandardCharsets.UTF_8))) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("RawInputOutput", "rawOutputWithCustomCT"), - Serde.VOID, - null), - outputCmd("{{".getBytes(StandardCharsets.UTF_8)), - END_MESSAGE), - testInvocation(CustomSerde::new, "greet") - .withInput(startMessage(1), inputCmd(MySerdeFactory.SERDE, "input")) - .expectingOutput(outputCmd(MySerdeFactory.SERDE, "OUTPUT"), END_MESSAGE)); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ServiceGreeter.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ServiceGreeter.java deleted file mode 100644 index 5357a10bc..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ServiceGreeter.java +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.javaapi.reflections; - -import dev.restate.sdk.annotation.Handler; -import dev.restate.sdk.annotation.Service; - -@Service -public class ServiceGreeter { - @Handler - public String greet(String request) { - return request; - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/lambda/LambdaHandlerTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/lambda/LambdaHandlerTest.java index 61e267104..323813d8f 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/lambda/LambdaHandlerTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/lambda/LambdaHandlerTest.java @@ -8,7 +8,6 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core.lambda; -import static dev.restate.sdk.core.statemachine.ProtoUtils.*; import static org.assertj.core.api.Assertions.assertThat; import com.amazonaws.services.lambda.runtime.ClientContext; @@ -18,61 +17,19 @@ import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.protobuf.ByteString; -import com.google.protobuf.MessageLite; +import dev.restate.sdk.core.DiscoveryProtocol; import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema; import dev.restate.sdk.core.generated.manifest.Service; -import dev.restate.sdk.core.generated.protocol.Protocol; import dev.restate.sdk.core.lambda.testservices.JavaCounterServiceHandlers; import dev.restate.sdk.core.lambda.testservices.MyServicesHandler; -import dev.restate.sdk.core.statemachine.MessageHeader; -import dev.restate.sdk.core.statemachine.ProtoUtils; import dev.restate.sdk.lambda.BaseRestateLambdaHandler; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.nio.ByteBuffer; import java.util.Base64; import java.util.Map; import org.junit.jupiter.api.Test; class LambdaHandlerTest { - @Test - public void testInvoke() throws IOException { - MyServicesHandler handler = new MyServicesHandler(); - - // Mock request - APIGatewayProxyRequestEvent request = new APIGatewayProxyRequestEvent(); - request.setHeaders(Map.of("content-type", ProtoUtils.serviceProtocolContentTypeHeader(false))); - request.setPath( - "/a/path/prefix/invoke/" + JavaCounterServiceHandlers.Metadata.SERVICE_NAME + "/get"); - request.setHttpMethod("POST"); - request.setIsBase64Encoded(true); - request.setBody( - Base64.getEncoder() - .encodeToString( - serializeEntries( - Protocol.StartMessage.newBuilder() - .setDebugId("123") - .setId(ByteString.copyFromUtf8("123")) - .setKnownEntries(1) - .setPartialState(true) - .build(), - inputCmd()))); - - // Send request - APIGatewayProxyResponseEvent response = handler.handleRequest(request, mockContext()); - - // Assert response - assertThat(response.getStatusCode()).isEqualTo(200); - assertThat(response.getHeaders()) - .containsEntry("content-type", ProtoUtils.serviceProtocolContentTypeHeader(false)); - assertThat(response.getIsBase64Encoded()).isTrue(); - assertThat(response.getBody()) - .asBase64Decoded() - .isEqualTo(serializeEntries(getLazyStateCmd(1, "counter").build(), suspensionMessage(1))); - } - @Test public void testDiscovery() throws IOException { BaseRestateLambdaHandler handler = new MyServicesHandler(); @@ -80,7 +37,7 @@ public void testDiscovery() throws IOException { // Mock request APIGatewayProxyRequestEvent request = new APIGatewayProxyRequestEvent(); request.setPath("/a/path/prefix/discover"); - request.setHeaders(Map.of("accept", ProtoUtils.serviceProtocolDiscoveryContentTypeHeader())); + request.setHeaders(Map.of("accept", DiscoveryProtocol.Version.MAX.getHeader())); // Send request APIGatewayProxyResponseEvent response = handler.handleRequest(request, mockContext()); @@ -88,7 +45,7 @@ public void testDiscovery() throws IOException { // Assert response assertThat(response.getStatusCode()).isEqualTo(200); assertThat(response.getHeaders()) - .containsEntry("content-type", ProtoUtils.serviceProtocolDiscoveryContentTypeHeader()); + .containsEntry("content-type", DiscoveryProtocol.Version.MAX.getHeader()); assertThat(response.getIsBase64Encoded()).isTrue(); byte[] decodedStringResponse = Base64.getDecoder().decode(response.getBody()); // Compute response and write it back @@ -100,17 +57,6 @@ public void testDiscovery() throws IOException { .containsOnly(JavaCounterServiceHandlers.Metadata.SERVICE_NAME); } - private static byte[] serializeEntries(MessageLite... msgs) throws IOException { - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - for (MessageLite msg : msgs) { - ByteBuffer headerBuf = ByteBuffer.allocate(8); - headerBuf.putLong(MessageHeader.fromMessage(msg).encode()); - outputStream.write(headerBuf.array()); - msg.writeTo(outputStream); - } - return outputStream.toByteArray(); - } - private Context mockContext() { return new Context() { @Override diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/statemachine/MessageDecoderTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/statemachine/MessageDecoderTest.java deleted file mode 100644 index 04038dab0..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/statemachine/MessageDecoderTest.java +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import static dev.restate.sdk.core.AssertUtils.assertThatDecodingMessages; -import static dev.restate.sdk.core.statemachine.ProtoUtils.startMessage; -import static org.assertj.core.api.Assertions.entry; - -import com.google.protobuf.MessageLite; -import dev.restate.common.Slice; -import java.nio.ByteBuffer; -import java.util.List; -import org.junit.jupiter.api.Test; - -public class MessageDecoderTest { - - @Test - void oneMessage() { - assertThatDecodingMessages( - ProtoUtils.encodeMessageToSlice(startMessage(1, "my-key", entry("key", "value")))) - .map(InvocationInput::message) - .containsExactly(startMessage(1, "my-key", entry("key", "value")).build()); - } - - @Test - void multiMessage() { - assertThatDecodingMessages( - ProtoUtils.encodeMessageToSlice(startMessage(1, "my-key", entry("key", "value"))), - ProtoUtils.encodeMessageToSlice(ProtoUtils.inputCmd("my-value"))) - .map(InvocationInput::message) - .containsExactly( - startMessage(1, "my-key", entry("key", "value")).build(), - ProtoUtils.inputCmd("my-value")); - } - - @Test - void multiMessageInSingleBuffer() { - List messages = - List.of( - startMessage(1, "my-key", entry("key", "value")).build(), - ProtoUtils.inputCmd("my-value")); - ByteBuffer byteBuffer = - ByteBuffer.allocate(messages.stream().mapToInt(MessageEncoder::encodeLength).sum()); - messages.stream().map(ProtoUtils::encodeMessageToByteBuffer).forEach(byteBuffer::put); - byteBuffer.flip(); - - assertThatDecodingMessages(Slice.wrap(byteBuffer)) - .map(InvocationInput::message) - .containsExactly( - startMessage(1, "my-key", entry("key", "value")).build(), - ProtoUtils.inputCmd("my-value")); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/statemachine/ProtoUtils.java b/sdk-core/src/test/java/dev/restate/sdk/core/statemachine/ProtoUtils.java deleted file mode 100644 index 6a68958fe..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/statemachine/ProtoUtils.java +++ /dev/null @@ -1,499 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.statemachine; - -import com.google.protobuf.ByteString; -import com.google.protobuf.MessageLite; -import com.google.protobuf.MessageLiteOrBuilder; -import com.google.protobuf.UnsafeByteOperations; -import dev.restate.common.Slice; -import dev.restate.common.Target; -import dev.restate.sdk.core.TestSerdes; -import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.core.generated.protocol.Protocol.StartMessage.StateEntry; -import dev.restate.serde.Serde; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import org.jspecify.annotations.Nullable; - -public class ProtoUtils { - - public static long invocationIdToRandomSeed(String invocationId) { - return new InvocationIdImpl(invocationId, null).toRandomSeed(); - } - - public static String serviceProtocolContentTypeHeader(boolean enableContextPreview) { - return ServiceProtocol.serviceProtocolVersionToHeaderValue( - enableContextPreview - ? ServiceProtocol.MAX_SERVICE_PROTOCOL_VERSION - : ServiceProtocol.MIN_SERVICE_PROTOCOL_VERSION); - } - - public static String serviceProtocolDiscoveryContentTypeHeader() { - return "application/vnd.restate.endpointmanifest.v2+json"; - } - - public static ByteBuffer invocationInputToByteString(InvocationInput invocationInput) { - ByteBuffer buffer = ByteBuffer.allocate(MessageEncoder.encodeLength(invocationInput.message())); - - buffer.putLong(invocationInput.header().encode()); - buffer.put(invocationInput.message().toByteString().asReadOnlyByteBuffer()); - - buffer.flip(); - return buffer; - } - - public static ByteBuffer encodeMessageToByteBuffer(MessageLiteOrBuilder msgOrBuilder) { - var msg = build(msgOrBuilder); - return invocationInputToByteString(InvocationInput.of(MessageHeader.fromMessage(msg), msg)); - } - - public static Slice encodeMessageToSlice(MessageLiteOrBuilder msgOrBuilder) { - return Slice.wrap(encodeMessageToByteBuffer(msgOrBuilder)); - } - - public static List bufferToMessages(List byteBuffers) { - var messageDecoder = new MessageDecoder(); - byteBuffers.stream().map(Slice::wrap).forEach(messageDecoder::offer); - - var outputList = new ArrayList(); - while (messageDecoder.isNextAvailable()) { - outputList.add(messageDecoder.next()); - } - return outputList.stream().map(InvocationInput::message).collect(Collectors.toList()); - } - - public static Protocol.StartMessage.Builder startMessage(int entries) { - return Protocol.StartMessage.newBuilder() - .setId(ByteString.copyFromUtf8("abc")) - .setDebugId("abc") - .setKnownEntries(entries) - .setPartialState(true); - } - - public static Protocol.StartMessage.Builder startMessage(int entries, String key) { - return Protocol.StartMessage.newBuilder() - .setId(ByteString.copyFromUtf8("abc")) - .setDebugId("abc") - .setKnownEntries(entries) - .setKey(key) - .setPartialState(true); - } - - @SafeVarargs - public static Protocol.StartMessage.Builder startMessage( - int entries, String key, Map.Entry... stateEntries) { - return startMessage(entries, key) - .addAllStateMap( - Arrays.stream(stateEntries) - .map( - e -> - StateEntry.newBuilder() - .setKey(ByteString.copyFromUtf8(e.getKey())) - .setValue( - ByteString.copyFrom( - TestSerdes.STRING.serialize(e.getValue()).toByteArray())) - .build()) - .collect(Collectors.toList())); - } - - public static Protocol.SuspensionMessage suspensionMessage(Integer... completionIds) { - return Protocol.SuspensionMessage.newBuilder() - .addAllWaitingCompletions(List.of(completionIds)) - .addWaitingSignals(1) - .build(); - } - - public static Protocol.InputCommandMessage inputCmd() { - return Protocol.InputCommandMessage.newBuilder() - .setValue(Protocol.Value.newBuilder().setContent(ByteString.EMPTY)) - .build(); - } - - public static Protocol.InputCommandMessage inputCmd(byte[] value) { - return Protocol.InputCommandMessage.newBuilder() - .setValue(Protocol.Value.newBuilder().setContent(ByteString.copyFrom(value))) - .build(); - } - - public static Protocol.InputCommandMessage inputCmd(Serde serde, T value) { - return Protocol.InputCommandMessage.newBuilder().setValue(value(serde, value)).build(); - } - - public static Protocol.InputCommandMessage inputCmd(String value) { - return inputCmd(TestSerdes.STRING, value); - } - - public static Protocol.InputCommandMessage inputCmd(int value) { - return inputCmd(TestSerdes.INT, value); - } - - public static Protocol.OutputCommandMessage outputCmd(Serde serde, T value) { - return Protocol.OutputCommandMessage.newBuilder().setValue(value(serde, value)).build(); - } - - public static Protocol.OutputCommandMessage outputCmd(String value) { - return outputCmd(TestSerdes.STRING, value); - } - - public static Protocol.OutputCommandMessage outputCmd(int value) { - return outputCmd(TestSerdes.INT, value); - } - - public static Protocol.OutputCommandMessage outputCmd(byte[] b) { - return outputCmd(Serde.RAW, b); - } - - public static Protocol.OutputCommandMessage outputCmd() { - return Protocol.OutputCommandMessage.newBuilder() - .setValue(Protocol.Value.newBuilder().setContent(ByteString.empty()).build()) - .build(); - } - - public static Protocol.OutputCommandMessage outputCmd(int code, String message) { - return Protocol.OutputCommandMessage.newBuilder().setFailure(failure(code, message)).build(); - } - - public static Protocol.OutputCommandMessage outputCmd(Throwable e) { - return Protocol.OutputCommandMessage.newBuilder().setFailure(failure(e)).build(); - } - - public static Protocol.GetLazyStateCommandMessage.Builder getLazyStateCmd( - int completionId, String key) { - return Protocol.GetLazyStateCommandMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(key)) - .setResultCompletionId(completionId); - } - - public static Protocol.GetEagerStateCommandMessage getEagerStateEmptyCmd(String key) { - return Protocol.GetEagerStateCommandMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(key)) - .setVoid(Protocol.Void.getDefaultInstance()) - .build(); - } - - public static Protocol.GetEagerStateCommandMessage getEagerStateCmd( - String key, Serde serde, T value) { - return Protocol.GetEagerStateCommandMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(key)) - .setValue(value(serde, value)) - .build(); - } - - public static Protocol.GetEagerStateCommandMessage getEagerStateCmd(String key, String value) { - return getEagerStateCmd(key, TestSerdes.STRING, value); - } - - public static Protocol.GetLazyStateCompletionNotificationMessage getLazyStateCompletion( - int completionId, Serde serde, T value) { - return Protocol.GetLazyStateCompletionNotificationMessage.newBuilder() - .setCompletionId(completionId) - .setValue( - Protocol.Value.newBuilder() - .setContent( - UnsafeByteOperations.unsafeWrap(serde.serialize(value).asReadOnlyByteBuffer()))) - .build(); - } - - public static Protocol.GetLazyStateCompletionNotificationMessage getLazyStateCompletion( - int completionId, String value) { - return getLazyStateCompletion(completionId, TestSerdes.STRING, value); - } - - public static Protocol.GetLazyStateCompletionNotificationMessage getLazyStateCompletionEmpty( - int completionId) { - return Protocol.GetLazyStateCompletionNotificationMessage.newBuilder() - .setCompletionId(completionId) - .setVoid(Protocol.Void.getDefaultInstance()) - .build(); - } - - public static Protocol.SetStateCommandMessage setStateCmd( - String key, Serde serde, T value) { - return Protocol.SetStateCommandMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(key)) - .setValue( - Protocol.Value.newBuilder() - .setContent(ByteString.copyFrom(serde.serialize(value).toByteArray()))) - .build(); - } - - public static Protocol.SetStateCommandMessage setStateCmd(String key, String value) { - return setStateCmd(key, TestSerdes.STRING, value); - } - - public static Protocol.ClearStateCommandMessage clearStateCmd(String key) { - return Protocol.ClearStateCommandMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(key)) - .build(); - } - - public static Protocol.CallCommandMessage.Builder callCmd( - int invocationIdCompletionId, int resultCompletionId, Target target) { - Protocol.CallCommandMessage.Builder builder = - Protocol.CallCommandMessage.newBuilder() - .setServiceName(target.getService()) - .setHandlerName(target.getHandler()); - if (target.getKey() != null) { - builder.setKey(target.getKey()); - } - builder - .setInvocationIdNotificationIdx(invocationIdCompletionId) - .setResultCompletionId(resultCompletionId); - - return builder; - } - - public static Protocol.CallCommandMessage.Builder callCmd( - int invocationIdCompletionId, int resultCompletionId, Target target, byte[] parameter) { - return callCmd(invocationIdCompletionId, resultCompletionId, target, Serde.RAW, parameter); - } - - public static Protocol.CallCommandMessage.Builder callCmd( - int invocationIdCompletionId, - int resultCompletionId, - Target target, - Serde reqSerde, - T parameter) { - return callCmd(invocationIdCompletionId, resultCompletionId, target) - .setParameter(ByteString.copyFrom(reqSerde.serialize(parameter).toByteArray())); - } - - public static Protocol.CallCommandMessage.Builder callCmd( - int invocationIdCompletionId, int resultCompletionId, Target target, String parameter) { - return callCmd( - invocationIdCompletionId, resultCompletionId, target, TestSerdes.STRING, parameter); - } - - public static Protocol.OneWayCallCommandMessage.Builder oneWayCallCmd( - int invocationIdCompletionId, - Target target, - @Nullable String idempotencyKey, - @Nullable Map headers, - Slice input) { - Protocol.OneWayCallCommandMessage.Builder builder = - Protocol.OneWayCallCommandMessage.newBuilder() - .setServiceName(target.getService()) - .setHandlerName(target.getHandler()); - if (target.getKey() != null) { - builder.setKey(target.getKey()); - } - if (idempotencyKey != null) { - builder.setIdempotencyKey(idempotencyKey); - } - if (headers != null) { - builder.addAllHeaders( - headers.entrySet().stream() - .map( - e -> - Protocol.Header.newBuilder() - .setKey(e.getKey()) - .setValue(e.getValue()) - .build()) - .toList()); - } - - builder - .setParameter(UnsafeByteOperations.unsafeWrap(input.asReadOnlyByteBuffer())) - .setInvocationIdNotificationIdx(invocationIdCompletionId); - - return builder; - } - - public static Protocol.CallCompletionNotificationMessage.Builder callCompletion( - int completionId, Serde reqSerde, T parameter) { - return Protocol.CallCompletionNotificationMessage.newBuilder() - .setCompletionId(completionId) - .setValue(value(reqSerde, parameter)); - } - - public static Protocol.CallCompletionNotificationMessage.Builder callCompletion( - int completionId, String result) { - return callCompletion(completionId, TestSerdes.STRING, result); - } - - public static Protocol.CallCompletionNotificationMessage.Builder callCompletion( - int completionId, Throwable failure) { - return Protocol.CallCompletionNotificationMessage.newBuilder() - .setCompletionId(completionId) - .setFailure(failure(failure)); - } - - public static - Protocol.CallInvocationIdCompletionNotificationMessage.Builder callInvocationIdCompletion( - int completionId, String invocationId) { - return Protocol.CallInvocationIdCompletionNotificationMessage.newBuilder() - .setCompletionId(completionId) - .setInvocationId(invocationId); - } - - public static Protocol.GetPromiseCommandMessage.Builder getPromiseCmd( - int completionId, String key) { - return Protocol.GetPromiseCommandMessage.newBuilder() - .setResultCompletionId(completionId) - .setKey(key); - } - - public static Protocol.PeekPromiseCommandMessage.Builder peekPromiseCmd( - int completionId, String key) { - return Protocol.PeekPromiseCommandMessage.newBuilder() - .setResultCompletionId(completionId) - .setKey(key); - } - - public static Protocol.CompletePromiseCommandMessage.Builder completePromiseCmd( - int completionId, String key, String value) { - return Protocol.CompletePromiseCommandMessage.newBuilder() - .setKey(key) - .setResultCompletionId(completionId) - .setCompletionValue(value(value)); - } - - public static Protocol.CompletePromiseCommandMessage.Builder completePromiseCmd( - int completionId, String key, Throwable e) { - return Protocol.CompletePromiseCommandMessage.newBuilder() - .setKey(key) - .setResultCompletionId(completionId) - .setCompletionFailure(failure(e)); - } - - public static Protocol.SignalNotificationMessage signalNotification( - int signalId, Serde serde, T value) { - return Protocol.SignalNotificationMessage.newBuilder() - .setIdx(signalId) - .setValue( - Protocol.Value.newBuilder() - .setContent( - UnsafeByteOperations.unsafeWrap(serde.serialize(value).asReadOnlyByteBuffer()))) - .build(); - } - - public static Protocol.SignalNotificationMessage signalNotification(int signalId, String value) { - return signalNotification(signalId, TestSerdes.STRING, value); - } - - public static Protocol.SignalNotificationMessage signalNotification( - String signalName, Serde serde, T value) { - return Protocol.SignalNotificationMessage.newBuilder() - .setName(signalName) - .setValue( - Protocol.Value.newBuilder() - .setContent( - UnsafeByteOperations.unsafeWrap(serde.serialize(value).asReadOnlyByteBuffer()))) - .build(); - } - - public static Protocol.SignalNotificationMessage signalNotification( - String signalName, String value) { - return signalNotification(signalName, TestSerdes.STRING, value); - } - - public static Protocol.RunCommandMessage runCmd(int completion) { - return Protocol.RunCommandMessage.newBuilder().setResultCompletionId(completion).build(); - } - - public static Protocol.RunCommandMessage runCmd(int completion, String name) { - return Protocol.RunCommandMessage.newBuilder() - .setResultCompletionId(completion) - .setName(name) - .build(); - } - - public static Protocol.RunCompletionNotificationMessage.Builder runCompletion( - int completionId, Serde reqSerde, T parameter) { - return Protocol.RunCompletionNotificationMessage.newBuilder() - .setCompletionId(completionId) - .setValue(value(reqSerde, parameter)); - } - - public static Protocol.RunCompletionNotificationMessage.Builder runCompletion( - int completionId, String result) { - return runCompletion(completionId, TestSerdes.STRING, result); - } - - public static Protocol.RunCompletionNotificationMessage.Builder runCompletion( - int completionId, int code, String message) { - return Protocol.RunCompletionNotificationMessage.newBuilder() - .setCompletionId(completionId) - .setFailure(failure(code, message)); - } - - public static Protocol.ProposeRunCompletionMessage.Builder proposeRunCompletion( - int completionId, Serde reqSerde, T parameter) { - return Protocol.ProposeRunCompletionMessage.newBuilder() - .setResultCompletionId(completionId) - .setValue(value(reqSerde, parameter).getContent()); - } - - public static Protocol.ProposeRunCompletionMessage.Builder proposeRunCompletion( - int completionId, String result) { - return proposeRunCompletion(completionId, TestSerdes.STRING, result); - } - - public static Protocol.ProposeRunCompletionMessage.Builder proposeRunCompletion( - int completionId, int code, String message) { - return Protocol.ProposeRunCompletionMessage.newBuilder() - .setResultCompletionId(completionId) - .setFailure(failure(code, message)); - } - - public static Protocol.SendSignalCommandMessage sendCancelSignal(String targetInvocationId) { - return Protocol.SendSignalCommandMessage.newBuilder() - .setTargetInvocationId(targetInvocationId) - .setIdx(1) - .setVoid(Protocol.Void.getDefaultInstance()) - .build(); - } - - public static Protocol.Failure failure(int code, String message) { - return Util.toProtocolFailure(code, message, Map.of()); - } - - public static Protocol.Failure failure(Throwable throwable) { - return Util.toProtocolFailure(throwable); - } - - public static Protocol.Value value(String jsonStringContent) { - return value(TestSerdes.STRING, jsonStringContent); - } - - public static Protocol.Value value(Serde serde, T value) { - return Protocol.Value.newBuilder() - .setContent(UnsafeByteOperations.unsafeWrap(serde.serialize(value).asReadOnlyByteBuffer())) - .build(); - } - - public static final Protocol.EndMessage END_MESSAGE = Protocol.EndMessage.getDefaultInstance(); - public static final Protocol.SignalNotificationMessage CANCELLATION_SIGNAL = - Protocol.SignalNotificationMessage.newBuilder() - .setVoid(Protocol.Void.getDefaultInstance()) - .setIdx(1) - .build(); - - public static final Target GREETER_SERVICE_TARGET = Target.service("Greeter", "greeter"); - public static Target GREETER_VIRTUAL_OBJECT_TARGET = - Target.virtualObject("Greeter", "Francesco", "greeter"); - - public static Protocol.StateKeys.Builder stateKeys(String... keys) { - return Protocol.StateKeys.newBuilder() - .addAllKeys(Arrays.stream(keys).map(ByteString::copyFromUtf8).collect(Collectors.toList())); - } - - public static MessageLite build(MessageLiteOrBuilder value) { - if (value instanceof MessageLite) { - return (MessageLite) value; - } else { - return ((MessageLite.Builder) value).build(); - } - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AsyncResultTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AsyncResultTest.kt deleted file mode 100644 index bef5b070f..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AsyncResultTest.kt +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.common.TimeoutException -import dev.restate.sdk.core.AsyncResultTestSuite -import dev.restate.sdk.core.TestDefinitions.* -import dev.restate.sdk.core.TestSerdes -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.callGreeterGreetService -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject -import dev.restate.sdk.kotlin.* -import java.util.stream.Stream -import kotlin.time.Duration.Companion.days - -class AsyncResultTest : AsyncResultTestSuite() { - override fun reverseAwaitOrder(): TestInvocationBuilder = - testDefinitionForVirtualObject("ReverseAwaitOrder") { ctx, _: Unit -> - val a1: DurableFuture = callGreeterGreetService(ctx, "Francesco") - val a2: DurableFuture = callGreeterGreetService(ctx, "Till") - - val a2Res: String = a2.await() - ctx.set(StateKey.of("A2", TestSerdes.STRING), a2Res) - - val a1Res: String = a1.await() - return@testDefinitionForVirtualObject "$a1Res-$a2Res" - } - - override fun awaitTwiceTheSameAwaitable(): TestInvocationBuilder = - testDefinitionForVirtualObject("AwaitTwiceTheSameAwaitable") { ctx, _: Unit -> - val a = callGreeterGreetService(ctx, "Francesco") - return@testDefinitionForVirtualObject "${a.await()}-${a.await()}" - } - - override fun awaitAll(): TestInvocationBuilder = - testDefinitionForVirtualObject("AwaitAll") { ctx, _: Unit -> - val a1 = callGreeterGreetService(ctx, "Francesco") - val a2 = callGreeterGreetService(ctx, "Till") - - return@testDefinitionForVirtualObject listOf(a1, a2) - .awaitAll() - .joinToString(separator = "-") - } - - override fun awaitAny(): TestInvocationBuilder = - testDefinitionForVirtualObject("AwaitAny") { ctx, _: Unit -> - val a1 = callGreeterGreetService(ctx, "Francesco") - val a2 = callGreeterGreetService(ctx, "Till") - - return@testDefinitionForVirtualObject DurableFuture.any(a1, a2) - .map { it -> if (it == 0) a1.await() else a2.await() } - .await() - } - - private fun awaitSelect(): TestInvocationBuilder = - testDefinitionForVirtualObject("AwaitSelect") { ctx, _: Unit -> - val a1 = callGreeterGreetService(ctx, "Francesco") - val a2 = callGreeterGreetService(ctx, "Till") - return@testDefinitionForVirtualObject select { - a1.onAwait { it } - a2.onAwait { it } - } - .await() - } - - override fun combineAnyWithAll(): TestInvocationBuilder = - testDefinitionForVirtualObject("CombineAnyWithAll") { ctx, _: Unit -> - val a1 = ctx.awakeable(TestSerdes.STRING) - val a2 = ctx.awakeable(TestSerdes.STRING) - val a3 = ctx.awakeable(TestSerdes.STRING) - val a4 = ctx.awakeable(TestSerdes.STRING) - - val a12 = DurableFuture.any(a1, a2).map { if (it == 0) a1.await() else a2.await() } - val a23 = DurableFuture.any(a2, a3).map { if (it == 0) a2.await() else a3.await() } - val a34 = DurableFuture.any(a3, a4).map { if (it == 0) a3.await() else a4.await() } - DurableFuture.all(a12, a23, a34).await() - - return@testDefinitionForVirtualObject a12.await() + a23.await() + a34.await() - } - - override fun awaitAnyIndex(): TestInvocationBuilder = - testDefinitionForVirtualObject("AwaitAnyIndex") { ctx, _: Unit -> - val a1 = ctx.awakeable(TestSerdes.STRING) - val a2 = ctx.awakeable(TestSerdes.STRING) - val a3 = ctx.awakeable(TestSerdes.STRING) - val a4 = ctx.awakeable(TestSerdes.STRING) - - return@testDefinitionForVirtualObject DurableFuture.any(a1, DurableFuture.all(a2, a3), a4) - .await() - .toString() - } - - override fun awaitOnAlreadyResolvedAwaitables(): TestInvocationBuilder = - testDefinitionForVirtualObject("AwaitOnAlreadyResolvedAwaitables") { ctx, _: Unit -> - val a1 = ctx.awakeable(TestSerdes.STRING) - val a2 = ctx.awakeable(TestSerdes.STRING) - val a12 = DurableFuture.all(a1, a2) - val a12and1 = DurableFuture.all(a12, a1) - val a121and12 = DurableFuture.all(a12and1, a12) - a12and1.await() - a121and12.await() - - return@testDefinitionForVirtualObject a1.await() + a2.await() - } - - override fun awaitWithTimeout(): TestInvocationBuilder = - testDefinitionForVirtualObject("AwaitWithTimeout") { ctx, _: Unit -> - val a1 = callGreeterGreetService(ctx, "Francesco") - return@testDefinitionForVirtualObject try { - a1.await(1.days) - } catch (_: TimeoutException) { - "timeout" - } - } - - override fun definitions(): Stream = - Stream.concat(super.definitions(), super.anyTestDefinitions { awaitSelect() }) -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AwakeableIdTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AwakeableIdTest.kt deleted file mode 100644 index 2ab5398ec..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AwakeableIdTest.kt +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.core.AwakeableIdTestSuite -import dev.restate.sdk.core.TestDefinitions -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService -import dev.restate.sdk.kotlin.* - -class AwakeableIdTest : AwakeableIdTestSuite() { - - override fun returnAwakeableId(): TestDefinitions.TestInvocationBuilder = - testDefinitionForService("ReturnAwakeableId") { ctx, _: Unit -> - val awakeable: Awakeable = ctx.awakeable() - awakeable.id - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CallTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CallTest.kt deleted file mode 100644 index 9b21532c5..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CallTest.kt +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.common.Request -import dev.restate.common.Slice -import dev.restate.common.Target -import dev.restate.sdk.core.CallTestSuite -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService -import dev.restate.serde.Serde - -class CallTest : CallTestSuite() { - - override fun oneWayCall( - target: Target, - idempotencyKey: String, - headers: Map, - body: Slice, - ) = - testDefinitionForService("OneWayCall") { ctx, _: Unit -> - val ignored = - ctx.send( - Request.of(target, Serde.SLICE, Serde.RAW, body) - .headers(headers) - .idempotencyKey(idempotencyKey) - ) - } - - override fun implicitCancellation(target: Target, body: Slice) = - testDefinitionForService("ImplicitCancellation") { ctx, _: Unit -> - val ignored = - ctx.call(Request.of(target, Serde.SLICE, Serde.RAW, body)).await() - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenDiscoveryTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenDiscoveryTest.kt deleted file mode 100644 index b2783f592..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenDiscoveryTest.kt +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.core.AssertUtils.assertThatDiscovery -import dev.restate.sdk.core.generated.manifest.Handler -import dev.restate.sdk.core.generated.manifest.Input -import dev.restate.sdk.core.generated.manifest.Output -import dev.restate.sdk.core.generated.manifest.Service -import dev.restate.sdk.kotlin.endpoint.* -import org.assertj.core.api.Assertions -import org.assertj.core.api.InstanceOfAssertFactories.type -import org.junit.jupiter.api.Test - -class CodegenDiscoveryTest { - - @Test - fun checkCustomInputContentType() { - assertThatDiscovery(CodegenTest.RawInputOutput()) - .extractingService("RawInputOutput") - .extractingHandler("rawInputWithCustomCt") - .extracting({ it.input }, type(Input::class.java)) - .extracting { it.contentType } - .isEqualTo("application/vnd.my.custom") - } - - @Test - fun checkCustomInputAcceptContentType() { - assertThatDiscovery(CodegenTest.RawInputOutput()) - .extractingService("RawInputOutput") - .extractingHandler("rawInputWithCustomAccept") - .extracting({ it.input }, type(Input::class.java)) - .extracting { it.contentType } - .isEqualTo("application/*") - } - - @Test - fun checkCustomOutputContentType() { - assertThatDiscovery(CodegenTest.RawInputOutput()) - .extractingService("RawInputOutput") - .extractingHandler("rawOutputWithCustomCT") - .extracting({ it.output }, type(Output::class.java)) - .extracting { it.contentType } - .isEqualTo("application/vnd.my.custom") - } - - @Test - fun explicitNames() { - assertThatDiscovery( - object : GreeterWithExplicitName { - override fun greet(context: dev.restate.sdk.kotlin.Context, request: String): String { - TODO("Not yet implemented") - } - } - ) - .extractingService("MyExplicitName") - .extractingHandler("my_greeter") - Assertions.assertThat(GreeterWithExplicitNameHandlers.Metadata.SERVICE_NAME) - .isEqualTo("MyExplicitName") - } - - @Test - fun workflowType() { - assertThatDiscovery(CodegenTest.MyWorkflow()) - .extractingService("MyWorkflow") - .returns(Service.Ty.WORKFLOW) { obj -> obj.ty } - .extractingHandler("run") - .returns(Handler.Ty.WORKFLOW) { obj -> obj.ty } - } - - @Test - fun usingTransformer() { - assertThatDiscovery( - endpoint { - bind(CodegenTest.RawInputOutput()) { - it.documentation = "My service documentation" - it.configureHandler("rawInputWithCustomCt") { - it.documentation = "My handler documentation" - } - } - } - ) - .extractingService("RawInputOutput") - .returns("My service documentation", Service::getDocumentation) - .extractingHandler("rawInputWithCustomCt") - .returns("My handler documentation", Handler::getDocumentation) - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenTest.kt deleted file mode 100644 index 9ee50f076..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenTest.kt +++ /dev/null @@ -1,447 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.common.Slice -import dev.restate.common.Target -import dev.restate.sdk.annotation.* -import dev.restate.sdk.core.TestDefinitions -import dev.restate.sdk.core.TestDefinitions.TestDefinition -import dev.restate.sdk.core.TestDefinitions.testInvocation -import dev.restate.sdk.core.TestSerdes -import dev.restate.sdk.core.statemachine.ProtoUtils.* -import dev.restate.sdk.kotlin.* -import dev.restate.serde.Serde -import dev.restate.serde.SerdeFactory -import dev.restate.serde.TypeRef -import dev.restate.serde.TypeTag -import dev.restate.serde.kotlinx.* -import java.util.stream.Stream -import kotlinx.serialization.Serializable - -class CodegenTest : TestDefinitions.TestSuite { - @Service - class ServiceGreeter { - @Handler - suspend fun greet(context: Context, request: String): String { - return request - } - } - - @VirtualObject - class ObjectGreeter { - @Exclusive - suspend fun greet(context: ObjectContext, request: String): String { - return request - } - - @Handler - @Shared - suspend fun sharedGreet(context: SharedObjectContext, request: String): String { - return request - } - } - - @VirtualObject - class NestedDataClass { - @Serializable data class Input(val a: String) - - @Serializable data class Output(val a: String) - - @Exclusive - suspend fun greet(context: ObjectContext, request: Input): Output { - return Output(request.a) - } - - @Exclusive - suspend fun complexType( - context: ObjectContext, - request: Map>, - ): Map> { - return mapOf() - } - } - - @VirtualObject - interface GreeterInterface { - @Exclusive suspend fun greet(context: ObjectContext, request: String): String - } - - private class ObjectGreeterImplementedFromInterface : GreeterInterface { - override suspend fun greet(context: ObjectContext, request: String): String { - return request - } - } - - @Service - @Name("Empty") - class Empty { - @Handler - suspend fun emptyInput(context: Context): String { - val client = CodegenTestEmptyClient.fromContext(context) - return client.emptyInput().await() - } - - @Handler - suspend fun emptyOutput(context: Context, request: String) { - val client = CodegenTestEmptyClient.fromContext(context) - client.emptyOutput(request).await() - } - - @Handler - suspend fun emptyInputOutput(context: Context) { - val client = CodegenTestEmptyClient.fromContext(context) - client.emptyInputOutput().await() - } - } - - @Service - @Name("PrimitiveTypes") - class PrimitiveTypes { - @Handler - suspend fun primitiveOutput(context: Context): Int { - val client = CodegenTestPrimitiveTypesClient.fromContext(context) - return client.primitiveOutput().await() - } - - @Handler - suspend fun primitiveInput(context: Context, input: Int) { - val client = CodegenTestPrimitiveTypesClient.fromContext(context) - client.primitiveInput(input).await() - } - } - - @VirtualObject - class CornerCases { - @Exclusive - suspend fun send(context: ObjectContext, request: String): String { - // Just needs to compile - return CodegenTestCornerCasesClient.fromContext(context, request)._send("my_send").await() - } - - @Exclusive - suspend fun returnNull(context: ObjectContext, request: String?): String? { - return CodegenTestCornerCasesClient.fromContext(context, context.key()) - .returnNull(request) {} - .await() - } - - @Exclusive - suspend fun badReturnTypeInferred(context: ObjectContext): Unit { - CodegenTestCornerCasesClient.fromContext(context, context.key()) - .send() - .badReturnTypeInferred() - } - } - - @Workflow - class WorkflowCornerCases { - @Workflow - fun process(context: WorkflowContext, request: String): String { - return "" - } - - @Shared - suspend fun submit(context: SharedWorkflowContext, request: String): String { - // Just needs to compile - val ignored: String = - CodegenTestWorkflowCornerCasesClient.connect("invalid", request)._submit("my_send") - CodegenTestWorkflowCornerCasesClient.connect("invalid", request).submit("my_send") - return CodegenTestWorkflowCornerCasesClient.connect("invalid", request) - .workflowHandle() - .output - .response() - .value - } - } - - @Service - @Name("RawInputOutput") - class RawInputOutput { - @Handler - @Raw - suspend fun rawOutput(context: Context): ByteArray { - val client: CodegenTestRawInputOutputClient.ContextClient = - CodegenTestRawInputOutputClient.fromContext(context) - return client.rawOutput().await() - } - - @Handler - @Raw(contentType = "application/vnd.my.custom") - suspend fun rawOutputWithCustomCT(context: Context): ByteArray { - val client: CodegenTestRawInputOutputClient.ContextClient = - CodegenTestRawInputOutputClient.fromContext(context) - return client.rawOutputWithCustomCT().await() - } - - @Handler - suspend fun rawInput(context: Context, @Raw input: ByteArray) { - val client: CodegenTestRawInputOutputClient.ContextClient = - CodegenTestRawInputOutputClient.fromContext(context) - client.rawInput(input).await() - } - - @Handler - suspend fun rawInputWithCustomCt( - context: Context, - @Raw(contentType = "application/vnd.my.custom") input: ByteArray, - ) { - val client: CodegenTestRawInputOutputClient.ContextClient = - CodegenTestRawInputOutputClient.fromContext(context) - client.rawInputWithCustomCt(input).await() - } - - @Handler - suspend fun rawInputWithCustomAccept( - context: Context, - @Accept("application/*") @Raw(contentType = "application/vnd.my.custom") input: ByteArray, - ) { - val client: CodegenTestRawInputOutputClient.ContextClient = - CodegenTestRawInputOutputClient.fromContext(context) - client.rawInputWithCustomCt(input).await() - } - } - - @Workflow - @Name("MyWorkflow") - class MyWorkflow { - @Workflow - suspend fun run(context: WorkflowContext, myInput: String) { - val client = CodegenTestMyWorkflowClient.fromContext(context, context.key()) - client.send().sharedHandler(myInput) - } - - @Handler - suspend fun sharedHandler(context: SharedWorkflowContext, myInput: String): String { - val client = CodegenTestMyWorkflowClient.fromContext(context, context.key()) - return client.sharedHandler(myInput).await() - } - } - - class MyCustomSerdeFactory : SerdeFactory { - override fun create(typeTag: TypeTag): Serde { - check(typeTag is KotlinSerializationSerdeFactory.KtTypeTag) - check(typeTag.type == Byte::class) - return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde - } - - override fun create(typeRef: TypeRef): Serde { - check(typeRef.type == Byte::class) - return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde - } - - override fun create(clazz: Class?): Serde { - check(clazz == Byte::class.java) - return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde - } - } - - @CustomSerdeFactory(MyCustomSerdeFactory::class) - @Service - @Name("CustomSerdeService") - class CustomSerdeService { - @Handler - suspend fun echo(context: Context, input: Byte): Byte { - return input - } - } - - override fun definitions(): Stream { - return Stream.of( - testInvocation({ ServiceGreeter() }, "greet") - .withInput(startMessage(1), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation({ ObjectGreeter() }, "greet") - .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation({ ObjectGreeter() }, "sharedGreet") - .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation({ NestedDataClass() }, "greet") - .withInput( - startMessage(1, "slinkydeveloper"), - inputCmd(jsonSerde(), NestedDataClass.Input("123")), - ) - .onlyBidiStream() - .expectingOutput( - outputCmd(jsonSerde(), NestedDataClass.Output("123")), - END_MESSAGE, - ), - testInvocation({ ObjectGreeterImplementedFromInterface() }, "greet") - .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation({ Empty() }, "emptyInput") - .withInput(startMessage(1), inputCmd(), callCompletion(2, "Till")) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("Empty", "emptyInput")), - outputCmd("Till"), - END_MESSAGE, - ) - .named("empty output"), - testInvocation({ Empty() }, "emptyOutput") - .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("Empty", "emptyOutput"), "Francesco"), - outputCmd(), - END_MESSAGE, - ) - .named("empty output"), - testInvocation({ Empty() }, "emptyInputOutput") - .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("Empty", "emptyInputOutput")), - outputCmd(), - END_MESSAGE, - ) - .named("empty input and empty output"), - testInvocation({ PrimitiveTypes() }, "primitiveOutput") - .withInput(startMessage(1), inputCmd(), callCompletion(2, TestSerdes.INT, 10)) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("PrimitiveTypes", "primitiveOutput"), - Serde.VOID, - null, - ), - outputCmd(TestSerdes.INT, 10), - END_MESSAGE, - ) - .named("primitive output"), - testInvocation({ PrimitiveTypes() }, "primitiveInput") - .withInput(startMessage(1), inputCmd(10), callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("PrimitiveTypes", "primitiveInput"), - TestSerdes.INT, - 10, - ), - outputCmd(), - END_MESSAGE, - ) - .named("primitive input"), - testInvocation({ RawInputOutput() }, "rawInput") - .withInput( - startMessage(1), - inputCmd("{{".toByteArray()), - callCompletion(2, KotlinSerializationSerdeFactory.UNIT, Unit), - ) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("RawInputOutput", "rawInput"), "{{".toByteArray()), - outputCmd(), - END_MESSAGE, - ), - testInvocation({ RawInputOutput() }, "rawInputWithCustomCt") - .withInput( - startMessage(1), - inputCmd("{{".toByteArray()), - callCompletion(2, KotlinSerializationSerdeFactory.UNIT, Unit), - ) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("RawInputOutput", "rawInputWithCustomCt"), - "{{".toByteArray(), - ), - outputCmd(), - END_MESSAGE, - ), - testInvocation({ RawInputOutput() }, "rawOutput") - .withInput( - startMessage(1), - inputCmd(), - callCompletion(2, Serde.RAW, "{{".toByteArray()), - ) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("RawInputOutput", "rawOutput"), - KotlinSerializationSerdeFactory.UNIT, - Unit, - ), - outputCmd("{{".toByteArray()), - END_MESSAGE, - ), - testInvocation({ RawInputOutput() }, "rawOutputWithCustomCT") - .withInput( - startMessage(1), - inputCmd(), - callCompletion(2, Serde.RAW, "{{".toByteArray()), - ) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("RawInputOutput", "rawOutputWithCustomCT"), - KotlinSerializationSerdeFactory.UNIT, - Unit, - ), - outputCmd("{{".toByteArray()), - END_MESSAGE, - ), - testInvocation({ CornerCases() }, "returnNull") - .withInput( - startMessage(1, "mykey"), - inputCmd(jsonSerde(), null), - callCompletion(2, jsonSerde(), null), - ) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.virtualObject("CodegenTestCornerCases", "mykey", "returnNull"), - jsonSerde(), - null, - ), - outputCmd(jsonSerde(), null), - END_MESSAGE, - ), - testInvocation({ CornerCases() }, "badReturnTypeInferred") - .withInput(startMessage(1, "mykey"), inputCmd()) - .onlyBidiStream() - .expectingOutput( - oneWayCallCmd( - 1, - Target.virtualObject( - "CodegenTestCornerCases", - "mykey", - "badReturnTypeInferred", - ), - null, - null, - Slice.EMPTY, - ), - outputCmd(), - END_MESSAGE, - ), - testInvocation({ CustomSerdeService() }, "echo") - .withInput(startMessage(1), inputCmd(byteArrayOf(1))) - .onlyBidiStream() - .expectingOutput(outputCmd(byteArrayOf(1)), END_MESSAGE), - ) - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/EagerStateTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/EagerStateTest.kt deleted file mode 100644 index b163050a1..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/EagerStateTest.kt +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.core.EagerStateTestSuite -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.core.TestSerdes -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject -import org.assertj.core.api.AssertionsForClassTypes.assertThat - -class EagerStateTest : EagerStateTestSuite() { - override fun getEmpty(): TestInvocationBuilder = - testDefinitionForVirtualObject("GetEmpty") { ctx, _: Unit -> - val stateIsEmpty = ctx.get(StateKey.of("STATE", TestSerdes.STRING)) == null - stateIsEmpty.toString() - } - - override fun get(): TestInvocationBuilder = - testDefinitionForVirtualObject("GetEmpty") { ctx, _: Unit -> - ctx.get(StateKey.of("STATE", TestSerdes.STRING))!! - } - - override fun getAppendAndGet(): TestInvocationBuilder = - testDefinitionForVirtualObject("GetAppendAndGet") { ctx, name: String -> - val oldState = ctx.get(StateKey.of("STATE", TestSerdes.STRING))!! - ctx.set(StateKey.of("STATE", TestSerdes.STRING), oldState + name) - ctx.get(StateKey.of("STATE", TestSerdes.STRING))!! - } - - override fun getClearAndGet(): TestInvocationBuilder = - testDefinitionForVirtualObject("GetClearAndGet") { ctx, _: Unit -> - val oldState = ctx.get(StateKey.of("STATE", TestSerdes.STRING))!! - ctx.clear(StateKey.of("STATE", TestSerdes.STRING)) - assertThat(ctx.get(StateKey.of("STATE", TestSerdes.STRING))).isNull() - oldState - } - - override fun getClearAllAndGet(): TestInvocationBuilder = - testDefinitionForVirtualObject("GetClearAllAndGet") { ctx, _: Unit -> - val oldState = ctx.get(StateKey.of("STATE", TestSerdes.STRING))!! - - ctx.clearAll() - - assertThat(ctx.get(StateKey.of("STATE", TestSerdes.STRING))).isNull() - assertThat(ctx.get(StateKey.of("ANOTHER_STATE", TestSerdes.STRING))).isNull() - oldState - } - - override fun listKeys(): TestInvocationBuilder = - testDefinitionForVirtualObject("ListKeys") { ctx, _: Unit -> - ctx.stateKeys().joinToString(separator = ",") - } - - override fun consecutiveGetWithEmpty(): TestInvocationBuilder = - testDefinitionForVirtualObject("ConsecutiveGetWithEmpty") { ctx, _: Unit -> - assertThat(ctx.get(StateKey.of("key-0", TestSerdes.STRING))).isNull() - assertThat(ctx.get(StateKey.of("key-0", TestSerdes.STRING))).isNull() - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/GreeterWithExplicitName.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/GreeterWithExplicitName.kt deleted file mode 100644 index a01eaedd3..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/GreeterWithExplicitName.kt +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.annotation.* -import dev.restate.sdk.kotlin.* - -@Service -@Name("MyExplicitName") -interface GreeterWithExplicitName { - @Handler @Name("my_greeter") fun greet(context: Context, request: String): String -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/InvocationIdTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/InvocationIdTest.kt deleted file mode 100644 index a4e504f52..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/InvocationIdTest.kt +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.core.InvocationIdTestSuite -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService - -class InvocationIdTest : InvocationIdTestSuite() { - - override fun returnInvocationId(): TestInvocationBuilder = - testDefinitionForService("ReturnInvocationId") { ctx, _: Unit -> - ctx.request().invocationId().toString() - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt deleted file mode 100644 index e50305c56..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.common.Request -import dev.restate.sdk.core.* -import dev.restate.sdk.core.TestDefinitions.TestExecutor -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.core.kotlinapi.reflections.ReflectionTest -import dev.restate.sdk.core.statemachine.ProtoUtils -import dev.restate.sdk.endpoint.definition.HandlerDefinition -import dev.restate.sdk.endpoint.definition.HandlerType -import dev.restate.sdk.endpoint.definition.ServiceDefinition -import dev.restate.sdk.endpoint.definition.ServiceType -import dev.restate.sdk.kotlin.* -import dev.restate.serde.kotlinx.* -import java.util.stream.Stream -import kotlinx.coroutines.Dispatchers - -class KotlinAPITests : TestRunner() { - override fun executors(): Stream { - return Stream.of(MockRequestResponse.INSTANCE, MockBidiStream.INSTANCE) - } - - public override fun definitions(): Stream { - return Stream.of( - AwakeableIdTest(), - AsyncResultTest(), - CallTest(), - EagerStateTest(), - StateTest(), - InvocationIdTest(), - OnlyInputAndOutputTest(), - PromiseTest(), - SideEffectTest(), - SleepTest(), - StateMachineFailuresTest(), - UserFailuresTest(), - RandomTest(), - CodegenTest(), - ReflectionTest(), - ) - } - - companion object { - inline fun testDefinitionForService( - name: String, - noinline runner: suspend (Context, REQ) -> RES, - ): TestInvocationBuilder { - return TestDefinitions.testInvocation( - ServiceDefinition.of( - name, - ServiceType.SERVICE, - listOf( - HandlerDefinition.of( - "run", - HandlerType.SHARED, - jsonSerde(), - jsonSerde(), - HandlerRunner.of( - KotlinSerializationSerdeFactory(), - HandlerRunner.Options(Dispatchers.Unconfined), - runner, - ), - ) - ), - ), - "run", - ) - } - - inline fun testDefinitionForVirtualObject( - name: String, - noinline runner: suspend (ObjectContext, REQ) -> RES, - ): TestInvocationBuilder { - return TestDefinitions.testInvocation( - ServiceDefinition.of( - name, - ServiceType.VIRTUAL_OBJECT, - listOf( - HandlerDefinition.of( - "run", - HandlerType.EXCLUSIVE, - jsonSerde(), - jsonSerde(), - HandlerRunner.of( - KotlinSerializationSerdeFactory(), - HandlerRunner.Options(Dispatchers.Unconfined), - runner, - ), - ) - ), - ), - "run", - ) - } - - inline fun testDefinitionForWorkflow( - name: String, - noinline runner: suspend (WorkflowContext, REQ) -> RES, - ): TestInvocationBuilder { - return TestDefinitions.testInvocation( - ServiceDefinition.of( - name, - ServiceType.WORKFLOW, - listOf( - HandlerDefinition.of( - "run", - HandlerType.WORKFLOW, - jsonSerde(), - jsonSerde(), - HandlerRunner.of( - KotlinSerializationSerdeFactory(), - HandlerRunner.Options(Dispatchers.Unconfined), - runner, - ), - ) - ), - ), - "run", - ) - } - - suspend fun callGreeterGreetService(ctx: Context, parameter: String): DurableFuture { - return ctx.call( - Request.of( - ProtoUtils.GREETER_SERVICE_TARGET, - TestSerdes.STRING, - TestSerdes.STRING, - parameter, - ) - ) - } - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/MyMetaServiceAnnotation.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/MyMetaServiceAnnotation.kt deleted file mode 100644 index e8c6606a2..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/MyMetaServiceAnnotation.kt +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.annotation.Service - -@Service annotation class MyMetaServiceAnnotation(val name: String = "") diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/OnlyInputAndOutputTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/OnlyInputAndOutputTest.kt deleted file mode 100644 index d8bf351ab..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/OnlyInputAndOutputTest.kt +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.core.OnlyInputAndOutputTestSuite -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService - -class OnlyInputAndOutputTest : OnlyInputAndOutputTestSuite() { - - override fun noSyscallsGreeter(): TestInvocationBuilder = - testDefinitionForService("NoSyscallsGreeter") { _, name: String -> "Hello $name" } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/PromiseTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/PromiseTest.kt deleted file mode 100644 index 4b9989663..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/PromiseTest.kt +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.common.TerminalException -import dev.restate.sdk.core.PromiseTestSuite -import dev.restate.sdk.core.TestDefinitions -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForWorkflow -import dev.restate.sdk.kotlin.* - -class PromiseTest : PromiseTestSuite() { - override fun awaitPromise(promiseKey: String): TestDefinitions.TestInvocationBuilder = - testDefinitionForWorkflow("AwaitPromise") { ctx, _: Unit -> - ctx.promise(durablePromiseKey(promiseKey)).future().await() - } - - override fun awaitPeekPromise( - promiseKey: String, - emptyCaseReturnValue: String, - ): TestDefinitions.TestInvocationBuilder = - testDefinitionForWorkflow("AwaitPeekPromise") { ctx, _: Unit -> - ctx.promise(durablePromiseKey(promiseKey)).peek().orElse(emptyCaseReturnValue) - } - - override fun awaitIsPromiseCompleted(promiseKey: String): TestDefinitions.TestInvocationBuilder = - testDefinitionForWorkflow("IsCompletedPromise") { ctx, _: Unit -> - ctx.promise(durablePromiseKey(promiseKey)).peek().isReady - } - - override fun awaitResolvePromise( - promiseKey: String, - completionValue: String, - ): TestDefinitions.TestInvocationBuilder = - testDefinitionForWorkflow("ResolvePromise") { ctx, _: Unit -> - try { - ctx.promiseHandle(durablePromiseKey(promiseKey)).resolve(completionValue) - return@testDefinitionForWorkflow true - } catch (e: TerminalException) { - return@testDefinitionForWorkflow false - } - } - - override fun awaitRejectPromise( - promiseKey: String, - rejectReason: String, - ): TestDefinitions.TestInvocationBuilder = - testDefinitionForWorkflow("RejectPromise") { ctx, _: Unit -> - try { - ctx.promiseHandle(durablePromiseKey(promiseKey)).reject(rejectReason) - return@testDefinitionForWorkflow true - } catch (e: TerminalException) { - return@testDefinitionForWorkflow false - } - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/RandomTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/RandomTest.kt deleted file mode 100644 index 99b7de8d8..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/RandomTest.kt +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.core.RandomTestSuite -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService -import kotlin.random.Random - -class RandomTest : RandomTestSuite() { - override fun randomShouldBeDeterministic(): TestInvocationBuilder = - testDefinitionForService("RandomShouldBeDeterministic") { ctx, _: Unit -> - ctx.random().nextInt() - } - - override fun getExpectedInt(seed: Long): Int { - return Random(seed).nextInt() - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt deleted file mode 100644 index e0ead9769..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import com.google.protobuf.ByteString -import dev.restate.common.Slice -import dev.restate.sdk.Restate -import dev.restate.sdk.common.RetryPolicy -import dev.restate.sdk.core.SideEffectTestSuite -import dev.restate.sdk.core.TestDefinitions -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService -import dev.restate.sdk.endpoint.definition.HandlerDefinition -import dev.restate.sdk.endpoint.definition.HandlerType -import dev.restate.sdk.endpoint.definition.ServiceDefinition -import dev.restate.sdk.endpoint.definition.ServiceType -import dev.restate.sdk.kotlin.* -import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory -import dev.restate.serde.kotlinx.jsonSerde -import dev.restate.serde.kotlinx.typeTag -import java.util.* -import kotlin.coroutines.coroutineContext -import kotlin.time.Clock -import kotlin.time.Duration.Companion.milliseconds -import kotlin.time.ExperimentalTime -import kotlin.time.Instant -import kotlin.time.toJavaInstant -import kotlin.time.toKotlinDuration -import kotlinx.coroutines.CoroutineName -import kotlinx.coroutines.Dispatchers -import org.assertj.core.api.Assertions - -class SideEffectTest : SideEffectTestSuite() { - - override fun sideEffect(sideEffectOutput: String): TestInvocationBuilder = - testDefinitionForService("SideEffect") { ctx, _: Unit -> - val result = ctx.runBlock { sideEffectOutput } - "Hello $result" - } - - override fun namedSideEffect(name: String, sideEffectOutput: String): TestInvocationBuilder = - testDefinitionForService("SideEffect") { ctx, _: Unit -> - val result = ctx.runBlock(name) { sideEffectOutput } - "Hello $result" - } - - override fun consecutiveSideEffect(sideEffectOutput: String): TestInvocationBuilder = - testDefinitionForService("ConsecutiveSideEffect") { ctx, _: Unit -> - val firstResult = ctx.runBlock { sideEffectOutput } - val secondResult = ctx.runBlock { firstResult.uppercase(Locale.getDefault()) } - "Hello $secondResult" - } - - override fun checkContextSwitching(): TestInvocationBuilder = - TestDefinitions.testInvocation( - ServiceDefinition.of( - "CheckContextSwitching", - ServiceType.SERVICE, - listOf( - HandlerDefinition.of( - "run", - HandlerType.SHARED, - jsonSerde(), - jsonSerde(), - HandlerRunner.of( - KotlinSerializationSerdeFactory(), - HandlerRunner.Options( - Dispatchers.Unconfined + - CoroutineName("CheckContextSwitchingTestCoroutine") - ), - ) { ctx: Context, _: Unit -> - val sideEffectCoroutine = - ctx.runBlock { coroutineContext[CoroutineName]!!.name } - check(sideEffectCoroutine == "CheckContextSwitchingTestCoroutine") { - "Side effect thread is not running within the same coroutine context of the handler method: $sideEffectCoroutine" - } - "Hello" - }, - ) - ), - ), - "run", - ) - - override fun failingSideEffect(name: String, reason: String) = - testDefinitionForService("FailingSideEffect") { ctx, _: Unit -> - ctx.runBlock(name) { throw IllegalStateException(reason) } - } - - override fun awaitAllSideEffectWithFirstFailing( - firstSideEffect: String, - secondSideEffect: String, - successValue: String, - failureReason: String, - ) = - testDefinitionForService("AwaitAllSideEffectWithFirstFailing") { ctx, _: Unit -> - val fut1 = - ctx.runAsync(firstSideEffect) { throw IllegalStateException(failureReason) } - val fut2 = ctx.runAsync(secondSideEffect) { successValue } - listOf(fut1, fut2).awaitAll() - } - - override fun awaitAllSideEffectWithSecondFailing( - firstSideEffect: String, - secondSideEffect: String, - successValue: String, - failureReason: String, - ) = - testDefinitionForService("AwaitAllSideEffectWithSecondFailing") { ctx, _: Unit -> - val fut1 = ctx.runAsync(firstSideEffect) { successValue } - val fut2 = - ctx.runAsync(secondSideEffect) { throw IllegalStateException(failureReason) } - listOf(fut1, fut2).awaitAll() - } - - override fun failingSideEffectWithRetryPolicy(reason: String, retryPolicy: RetryPolicy?) = - testDefinitionForService("FailingSideEffectWithRetryPolicy") { ctx, _: Unit -> - ctx.runBlock( - retryPolicy = - retryPolicy?.let { - RetryPolicy( - initialDelay = it.initialDelay.toKotlinDuration(), - exponentiationFactor = it.exponentiationFactor, - maxDelay = it.maxDelay?.toKotlinDuration(), - maxDuration = it.maxDuration?.toKotlinDuration(), - maxAttempts = it.maxAttempts, - ) - } - ) { - throw IllegalStateException(reason) - } - } - - override fun sideEffectGuard() = - testDefinitionForService("SideEffectGuard") { ctx, _: Unit -> - ctx.runBlock { ctx.sleep(100.milliseconds) } - "" - } - - override fun sideEffectGuardAwait() = - testDefinitionForService("SideEffectGuardAwait") { ctx, _: Unit -> - val timer = ctx.timer(100.milliseconds) - ctx.runBlock { timer.await() } - "" - } - - @OptIn(ExperimentalTime::class) - override fun instantNow() = - testDefinitionForService("InstantNow") { ctx, _: Unit -> Clock.Restate.now() } - - @OptIn(ExperimentalTime::class) - override fun assertIsInstant(bytes: ByteString) { - val instant = - KotlinSerializationSerdeFactory() - .create(typeTag()) - .deserialize(Slice.wrap(bytes.asReadOnlyByteBuffer())) - Assertions.assertThat(instant.toJavaInstant()).isNotNull().isBefore(java.time.Instant.now()) - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SleepTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SleepTest.kt deleted file mode 100644 index 58b069f79..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SleepTest.kt +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.core.SleepTestSuite -import dev.restate.sdk.core.TestDefinitions -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService -import dev.restate.sdk.kotlin.* -import kotlin.time.Duration.Companion.seconds - -class SleepTest : SleepTestSuite() { - - override fun sleepGreeter(): TestDefinitions.TestInvocationBuilder = - testDefinitionForService("SleepGreeter") { ctx, _: Unit -> - ctx.sleep(1.seconds) - "Hello" - } - - override fun manySleeps(): TestDefinitions.TestInvocationBuilder = - testDefinitionForService("ManySleeps") { ctx, _: Unit -> - val durableFutures = mutableListOf>() - for (i in 0..9) { - durableFutures.add(ctx.timer(1.seconds)) - } - durableFutures.awaitAll() - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt deleted file mode 100644 index 3172c5c17..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.common.AbortedExecutionException -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.common.TerminalException -import dev.restate.sdk.core.StateMachineFailuresTestSuite -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.core.TestSerdes -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject -import dev.restate.sdk.kotlin.* -import dev.restate.serde.Serde -import java.nio.charset.StandardCharsets -import java.util.concurrent.atomic.AtomicInteger -import kotlin.time.Duration.Companion.milliseconds -import kotlinx.coroutines.CancellationException - -class StateMachineFailuresTest : StateMachineFailuresTestSuite() { - companion object { - private val STATE = - StateKey.of( - "STATE", - Serde.using({ i: Int -> i.toString().toByteArray(StandardCharsets.UTF_8) }) { - b: ByteArray? -> - String(b!!, StandardCharsets.UTF_8).toInt() - }, - ) - } - - override fun getState(nonTerminalExceptionsSeen: AtomicInteger): TestInvocationBuilder = - testDefinitionForVirtualObject("GetState") { ctx, _: Unit -> - try { - ctx.get(STATE) - } catch (e: Throwable) { - // A user should never catch Throwable!!! - if (AbortedExecutionException.INSTANCE == e) { - throw e - } - // A user should never catch Throwable!!! - if (e !is CancellationException && e !is TerminalException) { - nonTerminalExceptionsSeen.addAndGet(1) - } else { - throw e - } - } - "Francesco" - } - - override fun sideEffectFailure(serde: Serde): TestInvocationBuilder = - testDefinitionForService("SideEffectFailure") { ctx, _: Unit -> - ctx.runBlock(serde) { 0 } - "Francesco" - } - - override fun awaitRunAfterProgressWasMade(): TestInvocationBuilder = - testDefinitionForService("AwaitRunAfterProgressWasMade") { ctx, _: Unit -> - val runFuture = ctx.runAsync("my-side-effect") { "result" } - runFuture.await() - null - } - - override fun awaitSleepAfterProgressWasMade(): TestInvocationBuilder = - testDefinitionForService("AwaitSleepAfterProgressWasMade") { ctx, _: Unit -> - val sleepFuture = ctx.timer(0.milliseconds) - sleepFuture.await() - null - } - - override fun awaitAwakeableAfterProgressWasMade(): TestInvocationBuilder = - testDefinitionForService("AwaitAwakeableAfterProgressWasMade") { ctx, _: Unit - -> - val awakeable = ctx.awakeable(TestSerdes.STRING) - awakeable.await() - null - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateTest.kt deleted file mode 100644 index 4f63035f0..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateTest.kt +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.core.StateTestSuite -import dev.restate.sdk.core.TestDefinitions.* -import dev.restate.sdk.core.TestSerdes -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject -import dev.restate.sdk.core.statemachine.ProtoUtils.* -import dev.restate.sdk.kotlin.* -import dev.restate.serde.kotlinx.* -import java.util.stream.Stream -import kotlinx.serialization.Serializable - -class StateTest : StateTestSuite() { - - override fun getState(): TestInvocationBuilder = - testDefinitionForVirtualObject("GetState") { ctx, _: Unit -> - val state = ctx.get(StateKey.of("STATE", TestSerdes.STRING)) ?: "Unknown" - "Hello $state" - } - - override fun getAndSetState(): TestInvocationBuilder = - testDefinitionForVirtualObject("GetAndSetState") { ctx, name: String -> - val state = ctx.get(StateKey.of("STATE", TestSerdes.STRING))!! - ctx.set(StateKey.of("STATE", TestSerdes.STRING), name) - "Hello $state" - } - - override fun setNullState(): TestInvocationBuilder { - return unsupported("The kotlin type system enforces non null state values") - } - - // --- Test using KTSerdes - - @Serializable data class Data(var a: Int, val b: String) - - private companion object { - val DATA = stateKey("STATE") - } - - private fun getAndSetStateUsingKtSerdes(): TestInvocationBuilder = - testDefinitionForVirtualObject("GetAndSetStateUsingKtSerdes") { ctx, _: Unit -> - val state = ctx.get(DATA)!! - state.a += 1 - ctx.set(DATA, state) - - "Hello $state" - } - - override fun definitions(): Stream { - return Stream.concat( - super.definitions(), - Stream.of( - getAndSetStateUsingKtSerdes() - .withInput( - startMessage(3), - inputCmd(), - getEagerStateCmd("STATE", jsonSerde(), Data(1, "Till")), - setStateCmd("STATE", jsonSerde(), Data(2, "Till")), - ) - .expectingOutput(outputCmd("Hello " + Data(2, "Till")), END_MESSAGE) - .named("With GetState and SetState"), - getAndSetStateUsingKtSerdes() - .withInput( - startMessage(2), - inputCmd(), - getEagerStateCmd("STATE", jsonSerde(), Data(1, "Till")), - ) - .expectingOutput( - setStateCmd("STATE", jsonSerde(), Data(2, "Till")), - outputCmd("Hello " + Data(2, "Till")), - END_MESSAGE, - ) - .named("With GetState already completed"), - ), - ) - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/UserFailuresTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/UserFailuresTest.kt deleted file mode 100644 index 9d8e2af58..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/UserFailuresTest.kt +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi - -import dev.restate.sdk.common.TerminalException -import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.core.UserFailuresTestSuite -import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService -import dev.restate.sdk.kotlin.* -import java.util.concurrent.atomic.AtomicInteger -import kotlin.coroutines.cancellation.CancellationException - -class UserFailuresTest : UserFailuresTestSuite() { - override fun throwIllegalStateException(): TestInvocationBuilder = - testDefinitionForService("ThrowIllegalStateException") { _, _: Unit -> - throw IllegalStateException("Whatever") - } - - override fun sideEffectThrowIllegalStateException( - nonTerminalExceptionsSeen: AtomicInteger - ): TestInvocationBuilder = - testDefinitionForService("SideEffectThrowIllegalStateException") { ctx, _: Unit -> - try { - ctx.runBlock { throw IllegalStateException("Whatever") } - } catch (e: Throwable) { - if (e !is CancellationException && e !is TerminalException) { - nonTerminalExceptionsSeen.addAndGet(1) - } else { - throw e - } - } - throw IllegalStateException("Not expected to reach this point") - } - - override fun throwTerminalException(code: Int, message: String): TestInvocationBuilder = - testDefinitionForService("ThrowTerminalException") { _, _: Unit -> - throw TerminalException(code, message) - } - - override fun sideEffectThrowTerminalException(code: Int, message: String): TestInvocationBuilder = - testDefinitionForService("SideEffectThrowTerminalException") { ctx, _: Unit -> - ctx.runBlock { throw TerminalException(code, message) } - throw IllegalStateException("Not expected to reach this point") - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionTest.kt deleted file mode 100644 index e85b49087..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionTest.kt +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.kotlinapi.reflections - -import dev.restate.common.Slice -import dev.restate.common.Target -import dev.restate.sdk.core.TestDefinitions -import dev.restate.sdk.core.TestDefinitions.TestDefinition -import dev.restate.sdk.core.TestDefinitions.testInvocation -import dev.restate.sdk.core.TestSerdes -import dev.restate.sdk.core.statemachine.ProtoUtils.* -import dev.restate.serde.Serde -import dev.restate.serde.kotlinx.* -import java.util.stream.Stream - -class ReflectionTest : TestDefinitions.TestSuite { - - override fun definitions(): Stream { - return Stream.of( - testInvocation({ ServiceGreeter() }, "greet") - .withInput(startMessage(1), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation({ ObjectGreeter() }, "greet") - .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation({ ObjectGreeter() }, "sharedGreet") - .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) - .onlyBidiStream() - .expectingOutput(outputCmd("Francesco"), END_MESSAGE), - testInvocation({ NestedDataClass() }, "greet") - .withInput( - startMessage(1, "slinkydeveloper"), - inputCmd(jsonSerde(), NestedDataClass.Input("123")), - ) - .onlyBidiStream() - .expectingOutput( - outputCmd(jsonSerde(), NestedDataClass.Output("123")), - END_MESSAGE, - ), - testInvocation({ ObjectGreeterImplementedFromInterface() }, "greet") - .withInput( - startMessage(1, "slinkydeveloper"), - inputCmd("Francesco"), - callCompletion(2, "Francesco"), - ) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.virtualObject("GreeterInterface", "slinkydeveloper", "greet"), - "Francesco", - ), - outputCmd("Francesco"), - END_MESSAGE, - ), - testInvocation({ Empty() }, "emptyInput") - .withInput(startMessage(1), inputCmd(), callCompletion(2, "Till")) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("Empty", "emptyInput")), - outputCmd("Till"), - END_MESSAGE, - ) - .named("empty output"), - testInvocation({ Empty() }, "emptyOutput") - .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("Empty", "emptyOutput"), "Francesco"), - outputCmd(), - END_MESSAGE, - ) - .named("empty output"), - testInvocation({ Empty() }, "emptyInputOutput") - .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("Empty", "emptyInputOutput")), - outputCmd(), - END_MESSAGE, - ) - .named("empty input and empty output"), - testInvocation({ PrimitiveTypes() }, "primitiveOutput") - .withInput(startMessage(1), inputCmd(), callCompletion(2, TestSerdes.INT, 10)) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("PrimitiveTypes", "primitiveOutput"), - Serde.VOID, - null, - ), - outputCmd(TestSerdes.INT, 10), - END_MESSAGE, - ) - .named("primitive output"), - testInvocation({ PrimitiveTypes() }, "primitiveInput") - .withInput(startMessage(1), inputCmd(10), callCompletion(2, Serde.VOID, null)) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("PrimitiveTypes", "primitiveInput"), - TestSerdes.INT, - 10, - ), - outputCmd(), - END_MESSAGE, - ) - .named("primitive input"), - testInvocation({ RawInputOutput() }, "rawInput") - .withInput( - startMessage(1), - inputCmd("{{".toByteArray()), - callCompletion(2, KotlinSerializationSerdeFactory.UNIT, Unit), - ) - .onlyBidiStream() - .expectingOutput( - callCmd(1, 2, Target.service("RawInputOutput", "rawInput"), "{{".toByteArray()), - outputCmd(), - END_MESSAGE, - ), - testInvocation({ RawInputOutput() }, "rawInputWithCustomCt") - .withInput( - startMessage(1), - inputCmd("{{".toByteArray()), - callCompletion(2, KotlinSerializationSerdeFactory.UNIT, Unit), - ) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("RawInputOutput", "rawInputWithCustomCt"), - "{{".toByteArray(), - ), - outputCmd(), - END_MESSAGE, - ), - testInvocation({ RawInputOutput() }, "rawOutput") - .withInput( - startMessage(1), - inputCmd(), - callCompletion(2, Serde.RAW, "{{".toByteArray()), - ) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("RawInputOutput", "rawOutput"), - KotlinSerializationSerdeFactory.UNIT, - Unit, - ), - outputCmd("{{".toByteArray()), - END_MESSAGE, - ), - testInvocation({ RawInputOutput() }, "rawOutputWithCustomCT") - .withInput( - startMessage(1), - inputCmd(), - callCompletion(2, Serde.RAW, "{{".toByteArray()), - ) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.service("RawInputOutput", "rawOutputWithCustomCT"), - KotlinSerializationSerdeFactory.UNIT, - Unit, - ), - outputCmd("{{".toByteArray()), - END_MESSAGE, - ), - testInvocation({ CornerCases() }, "returnNull") - .withInput( - startMessage(1, "mykey"), - inputCmd(jsonSerde(), null), - callCompletion(2, jsonSerde(), null), - ) - .onlyBidiStream() - .expectingOutput( - callCmd( - 1, - 2, - Target.virtualObject("CornerCases", "mykey", "returnNull"), - jsonSerde(), - null, - ), - outputCmd(jsonSerde(), null), - END_MESSAGE, - ), - testInvocation({ CornerCases() }, "badReturnTypeInferred") - .withInput(startMessage(1, "mykey"), inputCmd()) - .onlyBidiStream() - .expectingOutput( - oneWayCallCmd( - 1, - Target.virtualObject( - "CornerCases", - "mykey", - "badReturnTypeInferred", - ), - null, - null, - Slice.EMPTY, - ), - outputCmd(), - END_MESSAGE, - ), - testInvocation({ CornerCases() }, "callSuspendWithinProxy") - .withInput(startMessage(1, "mykey"), inputCmd()) - .onlyBidiStream() - .expectingOutput( - oneWayCallCmd( - 1, - Target.virtualObject( - "CornerCases", - "mykey", - "callSuspendWithinProxy", - ), - null, - null, - Slice.EMPTY, - ), - outputCmd(), - END_MESSAGE, - ), - testInvocation({ CustomSerdeService() }, "echo") - .withInput(startMessage(1), inputCmd(byteArrayOf(1))) - .onlyBidiStream() - .expectingOutput(outputCmd(byteArrayOf(1)), END_MESSAGE), - ) - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt index 4f08348c6..57151c876 100644 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt @@ -10,63 +10,7 @@ package dev.restate.sdk.core.kotlinapi.reflections import dev.restate.sdk.annotation.* import dev.restate.sdk.kotlin.* -import dev.restate.serde.Serde -import dev.restate.serde.SerdeFactory -import dev.restate.serde.TypeRef -import dev.restate.serde.TypeTag -import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory import kotlinx.coroutines.delay -import kotlinx.serialization.Serializable - -@Service -class ServiceGreeter { - @Handler - suspend fun greet(request: String): String { - return request - } -} - -@VirtualObject -class ObjectGreeter { - @Exclusive - suspend fun greet(request: String): String { - return request - } - - @Handler - @Shared - suspend fun sharedGreet(request: String): String { - return request - } -} - -@VirtualObject -class NestedDataClass { - @Serializable data class Input(val a: String) - - @Serializable data class Output(val a: String) - - @Exclusive - suspend fun greet(request: Input): Output { - return Output(request.a) - } - - @Exclusive - suspend fun complexType(request: Map>): Map> { - return mapOf() - } -} - -@VirtualObject -interface GreeterInterface { - @Exclusive suspend fun greet(request: String): String -} - -class ObjectGreeterImplementedFromInterface : GreeterInterface { - override suspend fun greet(request: String): String { - return virtualObject(objectKey()).greet(request) - } -} @Service @Name("Empty") @@ -169,35 +113,6 @@ open class MyWorkflow { workflow(workflowKey()).sharedHandler(myInput) } -@Suppress("UNCHECKED_CAST") -class MyCustomSerdeFactory : SerdeFactory { - override fun create(typeTag: TypeTag): Serde { - check(typeTag is KotlinSerializationSerdeFactory.KtTypeTag) - check(typeTag.type == Byte::class) - return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde - } - - override fun create(typeRef: TypeRef): Serde { - check(typeRef.type == Byte::class) - return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde - } - - override fun create(clazz: Class?): Serde { - check(clazz == Byte::class.java) - return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde - } -} - -@CustomSerdeFactory(MyCustomSerdeFactory::class) -@Service -@Name("CustomSerdeService") -class CustomSerdeService { - @Handler - suspend fun echo(input: Byte): Byte { - return input - } -} - @Service @Name("MyExplicitName") interface GreeterWithExplicitName { diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTest.kt index 497f5e958..b85391c86 100644 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTest.kt @@ -9,9 +9,8 @@ package dev.restate.sdk.core.vertx import com.fasterxml.jackson.databind.ObjectMapper -import com.google.protobuf.MessageLite +import dev.restate.sdk.core.DiscoveryProtocol import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema -import dev.restate.sdk.core.statemachine.ProtoUtils.* import dev.restate.sdk.endpoint.definition.HandlerDefinition import dev.restate.sdk.endpoint.definition.HandlerType import dev.restate.sdk.endpoint.definition.ServiceDefinition @@ -22,10 +21,8 @@ import dev.restate.sdk.kotlin.ObjectContext import dev.restate.sdk.kotlin.endpoint.endpoint import dev.restate.sdk.kotlin.stateKey import dev.restate.serde.kotlinx.* -import io.netty.buffer.Unpooled import io.netty.handler.codec.http.HttpResponseStatus import io.vertx.core.Vertx -import io.vertx.core.buffer.Buffer import io.vertx.core.http.* import io.vertx.junit5.VertxExtension import io.vertx.kotlin.coroutines.coAwait @@ -81,46 +78,6 @@ internal class RestateHttpServerTest { ) } - @Test - fun return404(vertx: Vertx): Unit = - runBlocking(vertx.dispatcher()) { - val endpointPort: Int = - RestateHttpServer.fromEndpoint( - vertx, - endpoint { bind(greeter()) }, - HttpServerOptions().setPort(0), - ) - .listen() - .coAwait() - .actualPort() - - val client = vertx.createHttpClient(HTTP_CLIENT_OPTIONS) - - val request = - client - .request( - HttpMethod.POST, - endpointPort, - "localhost", - "/invoke/$GREETER_NAME/unknownMethod", - ) - .coAwait() - - // Prepare request header - request - .setChunked(true) - .putHeader(HttpHeaders.CONTENT_TYPE, serviceProtocolContentTypeHeader(false)) - .putHeader(HttpHeaders.ACCEPT, serviceProtocolContentTypeHeader(false)) - request.write(encode(startMessage(0).build())) - - val response = request.response().coAwait() - - // Response status should be 404 - assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.NOT_FOUND.code()) - - response.end().coAwait() - } - @Test fun serviceDiscovery(vertx: Vertx): Unit = runBlocking(vertx.dispatcher()) { @@ -139,7 +96,7 @@ internal class RestateHttpServerTest { // Send request val request = client.request(HttpMethod.GET, endpointPort, "localhost", "/discover").coAwait() - request.putHeader(HttpHeaders.ACCEPT, serviceProtocolDiscoveryContentTypeHeader()) + request.putHeader(HttpHeaders.ACCEPT, DiscoveryProtocol.Version.MAX.header) request.end().coAwait() // Assert response @@ -148,7 +105,7 @@ internal class RestateHttpServerTest { // Response status and content type header assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.OK.code()) assertThat(response.getHeader(HttpHeaders.CONTENT_TYPE)) - .isEqualTo(serviceProtocolDiscoveryContentTypeHeader()) + .isEqualTo(DiscoveryProtocol.Version.MAX.header) // Parse response val responseBody = response.body().coAwait() @@ -158,8 +115,4 @@ internal class RestateHttpServerTest { assertThat(discoveryResponse.services).map { it.name }.containsOnly(GREETER_NAME) } - - private fun encode(msg: MessageLite): Buffer { - return Buffer.buffer(Unpooled.wrappedBuffer(encodeMessageToByteBuffer(msg))) - } } diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTestExecutor.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTestExecutor.kt deleted file mode 100644 index 5f2945bf0..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTestExecutor.kt +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.vertx - -import dev.restate.sdk.core.TestDefinitions.TestDefinition -import dev.restate.sdk.core.TestDefinitions.TestExecutor -import dev.restate.sdk.core.statemachine.ProtoUtils -import dev.restate.sdk.endpoint.Endpoint -import dev.restate.sdk.endpoint.definition.ServiceDefinition -import dev.restate.sdk.http.vertx.RestateHttpServer -import io.netty.buffer.Unpooled -import io.vertx.core.Vertx -import io.vertx.core.buffer.Buffer -import io.vertx.core.http.HttpHeaders -import io.vertx.core.http.HttpMethod -import io.vertx.core.http.HttpServerOptions -import io.vertx.kotlin.coroutines.coAwait -import io.vertx.kotlin.coroutines.dispatcher -import java.nio.ByteBuffer -import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.flow.receiveAsFlow -import kotlinx.coroutines.flow.toList -import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.yield - -class RestateHttpServerTestExecutor(private val vertx: Vertx) : TestExecutor { - override fun buffered(): Boolean { - return false - } - - override fun executeTest(definition: TestDefinition) { - runBlocking(vertx.dispatcher()) { - // Build server - val endpointBuilder = - Endpoint.builder() - .bind(definition.serviceDefinition as ServiceDefinition, definition.serviceOptions) - if (definition.isEnablePreviewContext()) { - endpointBuilder.enablePreviewContext() - } - - // Start server - val server = - RestateHttpServer.fromEndpoint( - vertx, - endpointBuilder.build(), - HttpServerOptions().setPort(0), - ) - server.listen().coAwait() - - val client = vertx.createHttpClient(RestateHttpServerTest.Companion.HTTP_CLIENT_OPTIONS) - - val request = - client - .request( - HttpMethod.POST, - server.actualPort(), - "localhost", - "/invoke/${definition.serviceDefinition.serviceName}/${definition.method}", - ) - .coAwait() - - // Prepare request header and send them - request - .setChunked(true) - .putHeader( - HttpHeaders.CONTENT_TYPE, - ProtoUtils.serviceProtocolContentTypeHeader(definition.isEnablePreviewContext), - ) - .putHeader( - HttpHeaders.ACCEPT, - ProtoUtils.serviceProtocolContentTypeHeader(definition.isEnablePreviewContext), - ) - request.sendHead().coAwait() - - launch { - for (msg in definition.input) { - request - .write( - Buffer.buffer(Unpooled.wrappedBuffer(ProtoUtils.invocationInputToByteString(msg))) - ) - .coAwait() - yield() - } - - request.end().coAwait() - } - - val response = request.response().coAwait() - - // Start the response receiver - val inputChannel = Channel() - response.handler { launch(vertx.dispatcher()) { inputChannel.send(it) } } - response.endHandler { inputChannel.close() } - response.resume() - - // Collect all the output messages - val buffers = inputChannel.receiveAsFlow().toList() - - definition.outputAssert.accept( - ProtoUtils.bufferToMessages(buffers.map { ByteBuffer.wrap(it.bytes) }) - ) - - // Close the server - server.close().coAwait() - } - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTests.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTests.kt deleted file mode 100644 index 97cda32a6..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTests.kt +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.vertx - -import dev.restate.sdk.core.TestDefinitions.TestExecutor -import dev.restate.sdk.core.TestDefinitions.TestSuite -import dev.restate.sdk.core.TestRunner -import dev.restate.sdk.core.javaapi.JavaAPITests -import dev.restate.sdk.core.kotlinapi.KotlinAPITests -import io.vertx.core.Vertx -import java.util.stream.Stream -import org.junit.jupiter.api.AfterAll -import org.junit.jupiter.api.BeforeAll - -class RestateHttpServerTests : TestRunner() { - - lateinit var vertx: Vertx - - @BeforeAll - fun beforeAll() { - vertx = Vertx.vertx() - } - - @AfterAll - fun afterAll() { - vertx.close().toCompletionStage().toCompletableFuture().get() - } - - override fun executors(): Stream { - return Stream.of(RestateHttpServerTestExecutor(vertx)) - } - - override fun definitions(): Stream { - return Stream.concat( - Stream.concat(JavaAPITests().definitions(), KotlinAPITests().definitions()), - Stream.of(ThreadTrampoliningTestSuite()), - ) - } -} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/ThreadTrampoliningTestSuite.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/ThreadTrampoliningTestSuite.kt deleted file mode 100644 index 5a4f6eafb..000000000 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/ThreadTrampoliningTestSuite.kt +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core.vertx - -import dev.restate.sdk.core.TestDefinitions -import dev.restate.sdk.core.TestDefinitions.testInvocation -import dev.restate.sdk.core.statemachine.ProtoUtils.* -import dev.restate.sdk.endpoint.definition.HandlerDefinition -import dev.restate.sdk.endpoint.definition.HandlerType -import dev.restate.sdk.endpoint.definition.ServiceDefinition -import dev.restate.sdk.endpoint.definition.ServiceType -import dev.restate.sdk.kotlin.Context -import dev.restate.sdk.kotlin.HandlerRunner -import dev.restate.sdk.kotlin.runBlock -import dev.restate.serde.Serde -import dev.restate.serde.jackson.JacksonSerdeFactory -import dev.restate.serde.kotlinx.* -import io.vertx.core.Vertx -import java.util.stream.Stream -import kotlin.coroutines.coroutineContext -import kotlinx.coroutines.CoroutineName -import kotlinx.coroutines.Dispatchers -import org.apache.logging.log4j.LogManager - -class ThreadTrampoliningTestSuite : TestDefinitions.TestSuite { - - private val nonBlockingCoroutineName = CoroutineName("CheckContextSwitchingTestCoroutine") - - companion object { - private val LOG = LogManager.getLogger() - } - - private suspend fun checkNonBlockingComponentTrampolineExecutor(ctx: Context) { - LOG.info("I am on the thread I am before executing side effect") - check(Vertx.currentContext() == null) - check(coroutineContext[CoroutineName] == nonBlockingCoroutineName) - ctx.runBlock { - LOG.info("I am on the thread I am when executing side effect") - check(Vertx.currentContext() == null) - } - LOG.info("I am on the thread I am after executing side effect") - check(coroutineContext[CoroutineName] == nonBlockingCoroutineName) - check(Vertx.currentContext() == null) - } - - private fun checkBlockingComponentTrampolineExecutor( - ctx: dev.restate.sdk.Context, - _unused: Any?, - ): Void? { - val id = Thread.currentThread().id - check(Vertx.currentContext() == null) - ctx.run { check(Vertx.currentContext() == null) } - check(Thread.currentThread().id == id) - check(Vertx.currentContext() == null) - return null - } - - override fun definitions(): Stream { - return Stream.of( - testInvocation( - ServiceDefinition.of( - "CheckNonBlockingComponentTrampolineExecutor", - ServiceType.SERVICE, - listOf( - HandlerDefinition.of( - "do", - HandlerType.SHARED, - KotlinSerializationSerdeFactory.UNIT, - KotlinSerializationSerdeFactory.UNIT, - HandlerRunner.of( - KotlinSerializationSerdeFactory(), - HandlerRunner.Options( - Dispatchers.Default + nonBlockingCoroutineName - ), - ) { ctx: Context, _: Unit -> - checkNonBlockingComponentTrampolineExecutor(ctx) - }, - ) - ), - ), - "do", - ) - .withInput(startMessage(1), inputCmd()) - .onlyBidiStream() - .expectingOutput( - runCmd(1), - proposeRunCompletion(1, Serde.VOID, null), - suspensionMessage(1), - ), - testInvocation( - ServiceDefinition.of( - "CheckBlockingComponentTrampolineExecutor", - ServiceType.SERVICE, - listOf( - HandlerDefinition.of( - "do", - HandlerType.SHARED, - Serde.VOID, - Serde.VOID, - dev.restate.sdk.HandlerRunner.of( - this::checkBlockingComponentTrampolineExecutor, - JacksonSerdeFactory(), - null, - ), - ) - ), - ), - "do", - ) - .withInput(startMessage(1), inputCmd()) - .onlyBidiStream() - .expectingOutput( - runCmd(1), - proposeRunCompletion(1, Serde.VOID, null), - suspensionMessage(1), - ), - ) - } -}