diff --git a/apps/sim/executor/execution/engine.test.ts b/apps/sim/executor/execution/engine.test.ts index b77b2f4f8f0..f1306a63326 100644 --- a/apps/sim/executor/execution/engine.test.ts +++ b/apps/sim/executor/execution/engine.test.ts @@ -4,9 +4,27 @@ import { sleep } from '@sim/utils/helpers' import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest' +const { mockCancellationSubscribers } = vi.hoisted(() => ({ + mockCancellationSubscribers: new Set<(event: { executionId: string }) => void>(), +})) + vi.mock('@/lib/execution/cancellation', () => ({ isExecutionCancelled: vi.fn(), isRedisCancellationEnabled: vi.fn(), + getCancellationChannel: () => ({ + publish: (event: { executionId: string }) => { + for (const handler of mockCancellationSubscribers) handler(event) + }, + subscribe: (handler: (event: { executionId: string }) => void) => { + mockCancellationSubscribers.add(handler) + return () => { + mockCancellationSubscribers.delete(handler) + } + }, + dispose: () => { + mockCancellationSubscribers.clear() + }, + }), })) import { isExecutionCancelled, isRedisCancellationEnabled } from '@/lib/execution/cancellation' @@ -115,6 +133,7 @@ function createMockNodeOrchestrator(executeDelay = 0): MockNodeOrchestrator { describe('ExecutionEngine', () => { beforeEach(() => { vi.clearAllMocks() + mockCancellationSubscribers.clear() ;(isExecutionCancelled as Mock).mockResolvedValue(false) ;(isRedisCancellationEnabled as Mock).mockReturnValue(false) }) @@ -346,7 +365,93 @@ describe('ExecutionEngine', () => { expect(result.status).toBe('cancelled') }) - it('should respect cancellation check interval', async () => { + it('wakes from a slow in-flight node when a pub/sub cancellation arrives', async () => { + ;(isRedisCancellationEnabled as Mock).mockReturnValue(true) + ;(isExecutionCancelled as Mock).mockResolvedValue(false) + + const startNode = createMockNode('start', 'starter') + const slowNode = createMockNode('slow', 'function') + startNode.outgoingEdges.set('edge1', { target: 'slow' }) + + const dag = createMockDAG([startNode, slowNode]) + const context = createMockContext({ executionId: 'pubsub-execution' }) + const edgeManager = createMockEdgeManager((node) => (node.id === 'start' ? ['slow'] : [])) + const nodeOrchestrator = createMockNodeOrchestrator(500) + + const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator) + const executionPromise = engine.run('start') + + setTimeout(() => { + for (const handler of mockCancellationSubscribers) { + handler({ executionId: 'pubsub-execution' }) + } + }, 5) + + const startTime = Date.now() + const result = await executionPromise + const duration = Date.now() - startTime + + expect(result.status).toBe('cancelled') + expect(duration).toBeLessThan(100) + }) + + it('ignores pub/sub events targeting other executions', async () => { + ;(isRedisCancellationEnabled as Mock).mockReturnValue(true) + ;(isExecutionCancelled as Mock).mockResolvedValue(false) + + const startNode = createMockNode('start', 'starter') + const dag = createMockDAG([startNode]) + const context = createMockContext({ executionId: 'execution-a' }) + const edgeManager = createMockEdgeManager() + const nodeOrchestrator = createMockNodeOrchestrator() + + const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator) + + for (const handler of mockCancellationSubscribers) { + handler({ executionId: 'execution-b' }) + } + + const result = await engine.run('start') + expect(result.status).toBeUndefined() + expect(result.success).toBe(true) + }) + + it('unsubscribes from the cancellation channel after run completes', async () => { + ;(isRedisCancellationEnabled as Mock).mockReturnValue(true) + ;(isExecutionCancelled as Mock).mockResolvedValue(false) + + const startNode = createMockNode('start', 'starter') + const dag = createMockDAG([startNode]) + const context = createMockContext({ executionId: 'cleanup-execution' }) + const edgeManager = createMockEdgeManager() + const nodeOrchestrator = createMockNodeOrchestrator() + + const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator) + expect(mockCancellationSubscribers.size).toBe(1) + + await engine.run('start') + + expect(mockCancellationSubscribers.size).toBe(0) + }) + + it('honours the durable backstop when cancelled before subscribing', async () => { + ;(isRedisCancellationEnabled as Mock).mockReturnValue(true) + ;(isExecutionCancelled as Mock).mockResolvedValue(true) + + const startNode = createMockNode('start', 'starter') + const dag = createMockDAG([startNode]) + const context = createMockContext() + const edgeManager = createMockEdgeManager() + const nodeOrchestrator = createMockNodeOrchestrator() + + const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator) + const result = await engine.run('start') + + expect(result.status).toBe('cancelled') + expect(nodeOrchestrator.executionCount).toBe(0) + }) + + it('calls isExecutionCancelled once as the startup backstop check', async () => { ;(isRedisCancellationEnabled as Mock).mockReturnValue(true) ;(isExecutionCancelled as Mock).mockResolvedValue(false) @@ -359,7 +464,7 @@ describe('ExecutionEngine', () => { const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator) await engine.run('start') - expect((isExecutionCancelled as Mock).mock.calls.length).toBeGreaterThanOrEqual(1) + expect((isExecutionCancelled as Mock).mock.calls.length).toBe(1) }) }) diff --git a/apps/sim/executor/execution/engine.ts b/apps/sim/executor/execution/engine.ts index 45880bc44d2..b279aae2252 100644 --- a/apps/sim/executor/execution/engine.ts +++ b/apps/sim/executor/execution/engine.ts @@ -1,6 +1,10 @@ import { createLogger, type Logger } from '@sim/logger' import { toError } from '@sim/utils/errors' -import { isExecutionCancelled, isRedisCancellationEnabled } from '@/lib/execution/cancellation' +import { + getCancellationChannel, + isExecutionCancelled, + isRedisCancellationEnabled, +} from '@/lib/execution/cancellation' import { BlockType } from '@/executor/constants' import type { DAG } from '@/executor/dag/builder' import type { EdgeManager } from '@/executor/execution/edge-manager' @@ -31,11 +35,9 @@ export class ExecutionEngine { private errorFlag = false private stoppedEarlyFlag = false private executionError: Error | null = null - private lastCancellationCheck = 0 - private readonly useRedisCancellation: boolean - private readonly CANCELLATION_CHECK_INTERVAL_MS = 500 - private abortPromise: Promise | null = null - private abortResolve: (() => void) | null = null + private abortPromise!: Promise + private abortResolve!: () => void + private cancellationUnsubscribe: (() => void) | null = null private execLogger: Logger constructor( @@ -45,7 +47,6 @@ export class ExecutionEngine { private nodeOrchestrator: NodeExecutionOrchestrator ) { this.allowResumeTriggers = this.context.metadata.resumeFromSnapshot === true - this.useRedisCancellation = isRedisCancellationEnabled() && !!this.context.executionId this.execLogger = logger.withMetadata({ workflowId: this.context.workflowId, workspaceId: this.context.workspaceId, @@ -54,72 +55,64 @@ export class ExecutionEngine { requestId: this.context.metadata.requestId, }) this.initializeAbortHandler() + this.subscribeToCancellationChannel() + } + + private subscribeToCancellationChannel(): void { + if (!this.context.executionId) return + const executionId = this.context.executionId + this.cancellationUnsubscribe = getCancellationChannel().subscribe((event) => { + if (event.executionId !== executionId) return + this.execLogger.info('Execution cancelled via pub/sub', { executionId }) + this.signalCancelled() + }) } - /** - * Sets up a single abort promise that can be reused throughout execution. - * This avoids creating multiple event listeners and potential memory leaks. - */ private initializeAbortHandler(): void { + this.abortPromise = new Promise((resolve) => { + this.abortResolve = resolve + }) + if (!this.context.abortSignal) return if (this.context.abortSignal.aborted) { - this.cancelledFlag = true - this.abortPromise = Promise.resolve() + this.signalCancelled() return } - this.abortPromise = new Promise((resolve) => { - this.abortResolve = resolve - }) - - this.context.abortSignal.addEventListener( - 'abort', - () => { - this.cancelledFlag = true - this.abortResolve?.() - }, - { once: true } - ) + this.context.abortSignal.addEventListener('abort', () => this.signalCancelled(), { once: true }) } - private async checkCancellation(): Promise { - if (this.cancelledFlag) { - return true - } - - if (this.useRedisCancellation) { - const now = Date.now() - if (now - this.lastCancellationCheck < this.CANCELLATION_CHECK_INTERVAL_MS) { - return false - } - this.lastCancellationCheck = now + private signalCancelled(): void { + if (this.cancelledFlag) return + this.cancelledFlag = true + this.abortResolve() + } - const cancelled = await isExecutionCancelled(this.context.executionId!) - if (cancelled) { - this.cancelledFlag = true - this.execLogger.info('Execution cancelled via Redis', { - executionId: this.context.executionId, - }) - } - return cancelled - } + private checkCancellation(): boolean { + return this.cancelledFlag + } - if (this.context.abortSignal?.aborted) { - this.cancelledFlag = true - return true + /** Catches cancellations published before this engine subscribed (e.g. resume from snapshot). */ + private async checkCancellationBackstop(): Promise { + if (!this.context.executionId || !isRedisCancellationEnabled()) return + const cancelled = await isExecutionCancelled(this.context.executionId) + if (cancelled) { + this.execLogger.info('Execution already cancelled at engine start (Redis backstop)', { + executionId: this.context.executionId, + }) + this.signalCancelled() } - - return false } async run(triggerBlockId?: string): Promise { const startTime = performance.now() try { this.initializeQueue(triggerBlockId) + await this.checkCancellationBackstop() while (this.hasWork()) { - if ((await this.checkCancellation()) || this.errorFlag || this.stoppedEarlyFlag) { + if (this.checkCancellation() || this.errorFlag || this.stoppedEarlyFlag) { break } await this.processQueue() @@ -194,6 +187,15 @@ export class ExecutionEngine { attachExecutionResult(error, executionResult) } throw error + } finally { + this.cleanup() + } + } + + private cleanup(): void { + if (this.cancellationUnsubscribe) { + this.cancellationUnsubscribe() + this.cancellationUnsubscribe = null } } @@ -238,32 +240,17 @@ export class ExecutionEngine { private async waitForAnyExecution(): Promise { if (this.executing.size > 0) { - const abortPromise = this.getAbortPromise() - if (abortPromise) { - await Promise.race([...this.executing, abortPromise]) - } else { - await Promise.race(this.executing) - } + await Promise.race([...this.executing, this.abortPromise]) } } private async waitForAllExecutions(): Promise { - const abortPromise = this.getAbortPromise() - if (abortPromise) { - await Promise.race([Promise.all(this.executing), abortPromise]) - } else { - await Promise.all(this.executing) + await Promise.race([Promise.all(this.executing), this.abortPromise]) + if (this.executing.size > 0) { + await Promise.allSettled(this.executing) } } - /** - * Returns the cached abort promise. This is safe to call multiple times - * as it reuses the same promise instance created during initialization. - */ - private getAbortPromise(): Promise | null { - return this.abortPromise - } - private async withQueueLock(fn: () => Promise | T): Promise { const prevLock = this.queueLock let resolveLock: () => void @@ -363,7 +350,7 @@ export class ExecutionEngine { private async processQueue(): Promise { while (this.readyQueue.length > 0) { - if ((await this.checkCancellation()) || this.errorFlag) { + if (this.checkCancellation() || this.errorFlag) { break } const nodeId = this.dequeue() diff --git a/apps/sim/executor/handlers/api/api-handler.test.ts b/apps/sim/executor/handlers/api/api-handler.test.ts index f073f996ad5..e2eeca7f795 100644 --- a/apps/sim/executor/handlers/api/api-handler.test.ts +++ b/apps/sim/executor/handlers/api/api-handler.test.ts @@ -116,8 +116,7 @@ describe('ApiBlockHandler', () => { body: { key: 'value' }, // Expect parsed body _context: { workflowId: 'test-workflow-id' }, }, - false, // skipPostProcess - mockContext // execution context + { executionContext: mockContext } ) expect(result).toEqual(expectedOutput) }) @@ -177,8 +176,7 @@ describe('ApiBlockHandler', () => { expect(mockExecuteTool).toHaveBeenCalledWith( 'http_request', expect.objectContaining({ body: expectedParsedBody }), - false, // skipPostProcess - mockContext // execution context + { executionContext: mockContext } ) }) @@ -193,8 +191,7 @@ describe('ApiBlockHandler', () => { expect(mockExecuteTool).toHaveBeenCalledWith( 'http_request', expect.objectContaining({ body: 'This is plain text' }), - false, // skipPostProcess - mockContext // execution context + { executionContext: mockContext } ) }) @@ -209,8 +206,7 @@ describe('ApiBlockHandler', () => { expect(mockExecuteTool).toHaveBeenCalledWith( 'http_request', expect.objectContaining({ body: undefined }), - false, // skipPostProcess - mockContext // execution context + { executionContext: mockContext } ) }) diff --git a/apps/sim/executor/handlers/api/api-handler.ts b/apps/sim/executor/handlers/api/api-handler.ts index 04bdea4e23e..b6a28517d7d 100644 --- a/apps/sim/executor/handlers/api/api-handler.ts +++ b/apps/sim/executor/handlers/api/api-handler.ts @@ -78,8 +78,7 @@ export class ApiBlockHandler implements BlockHandler { callChain: ctx.callChain, }, }, - false, - ctx + { executionContext: ctx } ) if (!result.success) { diff --git a/apps/sim/executor/handlers/condition/condition-handler.test.ts b/apps/sim/executor/handlers/condition/condition-handler.test.ts index d96e182e1f1..ec446cec2be 100644 --- a/apps/sim/executor/handlers/condition/condition-handler.test.ts +++ b/apps/sim/executor/handlers/condition/condition-handler.test.ts @@ -180,8 +180,7 @@ describe('ConditionBlockHandler', () => { workspaceId: 'test-workspace-id', }, }), - false, - mockContext + { executionContext: mockContext } ) }) diff --git a/apps/sim/executor/handlers/condition/condition-handler.ts b/apps/sim/executor/handlers/condition/condition-handler.ts index c6b4032636e..20ce8660dd1 100644 --- a/apps/sim/executor/handlers/condition/condition-handler.ts +++ b/apps/sim/executor/handlers/condition/condition-handler.ts @@ -54,8 +54,7 @@ async function evaluateConditionExpression( enforceCredentialAccess: ctx.enforceCredentialAccess, }, }, - false, - ctx + { executionContext: ctx } ) if (!result.success) { diff --git a/apps/sim/executor/handlers/function/function-handler.test.ts b/apps/sim/executor/handlers/function/function-handler.test.ts index aafd49faea5..bc2485db9e8 100644 --- a/apps/sim/executor/handlers/function/function-handler.test.ts +++ b/apps/sim/executor/handlers/function/function-handler.test.ts @@ -91,12 +91,9 @@ describe('FunctionBlockHandler', () => { const result = await handler.execute(mockContext, mockBlock, inputs) - expect(mockExecuteTool).toHaveBeenCalledWith( - 'function_execute', - expectedToolParams, - false, - mockContext - ) + expect(mockExecuteTool).toHaveBeenCalledWith('function_execute', expectedToolParams, { + executionContext: mockContext, + }) expect(result).toEqual(expectedOutput) }) @@ -132,12 +129,9 @@ describe('FunctionBlockHandler', () => { const result = await handler.execute(mockContext, mockBlock, inputs) - expect(mockExecuteTool).toHaveBeenCalledWith( - 'function_execute', - expectedToolParams, - false, - mockContext - ) + expect(mockExecuteTool).toHaveBeenCalledWith('function_execute', expectedToolParams, { + executionContext: mockContext, + }) expect(result).toEqual(expectedOutput) }) @@ -165,12 +159,9 @@ describe('FunctionBlockHandler', () => { await handler.execute(mockContext, mockBlock, inputs) - expect(mockExecuteTool).toHaveBeenCalledWith( - 'function_execute', - expectedToolParams, - false, // skipPostProcess - mockContext // execution context - ) + expect(mockExecuteTool).toHaveBeenCalledWith('function_execute', expectedToolParams, { + executionContext: mockContext, + }) }) it('should handle execution errors from the tool', async () => { @@ -197,8 +188,7 @@ describe('FunctionBlockHandler', () => { expect.objectContaining({ contextVariables, }), - false, - mockContext + { executionContext: mockContext } ) }) @@ -217,8 +207,7 @@ describe('FunctionBlockHandler', () => { code: 'retur globalThis["__blockRef_0"]', sourceCode: 'retur "value"', }), - false, - mockContext + { executionContext: mockContext } ) }) @@ -239,8 +228,7 @@ describe('FunctionBlockHandler', () => { workflowVariables: { 'var-1': legacyVariable }, contextVariables: {}, }), - false, - mockContext + { executionContext: mockContext } ) }) diff --git a/apps/sim/executor/handlers/function/function-handler.ts b/apps/sim/executor/handlers/function/function-handler.ts index ec08996ba5b..c22205212e3 100644 --- a/apps/sim/executor/handlers/function/function-handler.ts +++ b/apps/sim/executor/handlers/function/function-handler.ts @@ -76,7 +76,7 @@ export class FunctionBlockHandler implements BlockHandler { }, } - const result = await executeTool('function_execute', toolParams, false, ctx) + const result = await executeTool('function_execute', toolParams, { executionContext: ctx }) if (!result.success) { throw new Error(result.error || 'Function execution failed') diff --git a/apps/sim/executor/handlers/generic/generic-handler.test.ts b/apps/sim/executor/handlers/generic/generic-handler.test.ts index cf18f8a254a..be3baec9aed 100644 --- a/apps/sim/executor/handlers/generic/generic-handler.test.ts +++ b/apps/sim/executor/handlers/generic/generic-handler.test.ts @@ -92,12 +92,9 @@ describe('GenericBlockHandler', () => { const result = await handler.execute(mockContext, mockBlock, inputs) expect(mockGetTool).toHaveBeenCalledWith('some_custom_tool') - expect(mockExecuteTool).toHaveBeenCalledWith( - 'some_custom_tool', - expectedToolParams, - false, // skipPostProcess - mockContext // execution context - ) + expect(mockExecuteTool).toHaveBeenCalledWith('some_custom_tool', expectedToolParams, { + executionContext: mockContext, + }) expect(result).toEqual(expectedOutput) }) diff --git a/apps/sim/executor/handlers/generic/generic-handler.ts b/apps/sim/executor/handlers/generic/generic-handler.ts index b1d700a1f44..c6ee89dd633 100644 --- a/apps/sim/executor/handlers/generic/generic-handler.ts +++ b/apps/sim/executor/handlers/generic/generic-handler.ts @@ -72,8 +72,7 @@ export class GenericBlockHandler implements BlockHandler { enforceCredentialAccess: ctx.enforceCredentialAccess, }, }, - false, - ctx + { executionContext: ctx } ) if (!result.success) { diff --git a/apps/sim/executor/handlers/human-in-the-loop/human-in-the-loop-handler.ts b/apps/sim/executor/handlers/human-in-the-loop/human-in-the-loop-handler.ts index 0634aa65dd4..032640d1b7e 100644 --- a/apps/sim/executor/handlers/human-in-the-loop/human-in-the-loop-handler.ts +++ b/apps/sim/executor/handlers/human-in-the-loop/human-in-the-loop-handler.ts @@ -480,7 +480,7 @@ export class HumanInTheLoopBlockHandler implements BlockHandler { blockNameMapping: blockNameMappingWithPause, } - const result = await executeTool(toolId, toolParams, false, ctx) + const result = await executeTool(toolId, toolParams, { executionContext: ctx }) const durationMs = Date.now() - startTime if (!result.success) { diff --git a/apps/sim/lib/copilot/tool-executor/executor.test.ts b/apps/sim/lib/copilot/tool-executor/executor.test.ts index a0cd2eec358..adeb6ce48da 100644 --- a/apps/sim/lib/copilot/tool-executor/executor.test.ts +++ b/apps/sim/lib/copilot/tool-executor/executor.test.ts @@ -54,8 +54,7 @@ describe('copilot tool executor fallback', () => { chatId: 'chat-1', enforceCredentialAccess: true, }), - }), - false + }) ) expect(result).toEqual({ success: true, output: { emails: [] } }) }) @@ -83,8 +82,7 @@ describe('copilot tool executor fallback', () => { _context: expect.objectContaining({ copilotToolExecution: true, }), - }), - false + }) ) }) @@ -108,8 +106,7 @@ describe('copilot tool executor fallback', () => { 'function_execute', expect.objectContaining({ timeout: 10_000, - }), - false + }) ) }) @@ -133,8 +130,7 @@ describe('copilot tool executor fallback', () => { 'function_execute', expect.objectContaining({ timeout: 10_000, - }), - false + }) ) }) @@ -158,8 +154,7 @@ describe('copilot tool executor fallback', () => { 'function_execute', expect.objectContaining({ timeout: DEFAULT_EXECUTION_TIMEOUT_MS, - }), - false + }) ) }) }) diff --git a/apps/sim/lib/copilot/tool-executor/executor.ts b/apps/sim/lib/copilot/tool-executor/executor.ts index 869228970bf..084d046c027 100644 --- a/apps/sim/lib/copilot/tool-executor/executor.ts +++ b/apps/sim/lib/copilot/tool-executor/executor.ts @@ -43,7 +43,7 @@ export async function executeTool( const canUseRegisteredHandler = isKnownTool(toolId) && isSimExecuted(toolId) if (!canUseRegisteredHandler) { const appParams = buildAppToolParams(toolId, params, context) - return executeAppTool(toolId, appParams, false) + return executeAppTool(toolId, appParams) } if (context.abortSignal?.aborted) { diff --git a/apps/sim/lib/copilot/tools/handlers/function-execute.ts b/apps/sim/lib/copilot/tools/handlers/function-execute.ts index 8d6139d4c56..29f53924610 100644 --- a/apps/sim/lib/copilot/tools/handlers/function-execute.ts +++ b/apps/sim/lib/copilot/tools/handlers/function-execute.ts @@ -143,5 +143,5 @@ export async function executeFunctionExecute( enforceCredentialAccess: true, } - return executeAppTool('function_execute', enrichedParams, false) + return executeAppTool('function_execute', enrichedParams) } diff --git a/apps/sim/lib/execution/cancellation.test.ts b/apps/sim/lib/execution/cancellation.test.ts index 0f587800f99..eb440b2985e 100644 --- a/apps/sim/lib/execution/cancellation.test.ts +++ b/apps/sim/lib/execution/cancellation.test.ts @@ -1,15 +1,24 @@ import { redisConfigMock, redisConfigMockFns } from '@sim/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockRedisSet } = vi.hoisted(() => ({ +const { mockRedisSet, mockPublish, mockSubscribe } = vi.hoisted(() => ({ mockRedisSet: vi.fn(), + mockPublish: vi.fn(), + mockSubscribe: vi.fn(), })) const mockGetRedisClient = redisConfigMockFns.mockGetRedisClient vi.mock('@/lib/core/config/redis', () => redisConfigMock) +vi.mock('@/lib/events/pubsub', () => ({ + createPubSubChannel: () => ({ + publish: mockPublish, + subscribe: mockSubscribe, + dispose: vi.fn(), + }), +})) -import { markExecutionCancelled } from './cancellation' +import { getCancellationChannel, markExecutionCancelled } from './cancellation' import { abortManualExecution, registerManualExecutionAborter, @@ -49,6 +58,41 @@ describe('markExecutionCancelled', () => { reason: 'redis_write_failed', }) }) + + it('publishes even when the Redis write fails so local subscribers wake up', async () => { + mockRedisSet.mockRejectedValue(new Error('set failed')) + mockGetRedisClient.mockReturnValue({ set: mockRedisSet }) + + await markExecutionCancelled('execution-write-failed') + + expect(mockPublish).toHaveBeenCalledWith({ executionId: 'execution-write-failed' }) + }) + + it('publishes a cancellation event after a successful Redis write', async () => { + mockRedisSet.mockResolvedValue('OK') + mockGetRedisClient.mockReturnValue({ set: mockRedisSet }) + + await markExecutionCancelled('execution-2') + + expect(mockPublish).toHaveBeenCalledWith({ executionId: 'execution-2' }) + expect(mockRedisSet.mock.invocationCallOrder[0]).toBeLessThan( + mockPublish.mock.invocationCallOrder[0] + ) + }) + + it('publishes even when Redis is unavailable so local subscribers wake up', async () => { + mockGetRedisClient.mockReturnValue(null) + + await markExecutionCancelled('execution-3') + + expect(mockPublish).toHaveBeenCalledWith({ executionId: 'execution-3' }) + }) +}) + +describe('getCancellationChannel', () => { + it('returns the same channel instance across calls', () => { + expect(getCancellationChannel()).toBe(getCancellationChannel()) + }) }) describe('manual execution cancellation registry', () => { diff --git a/apps/sim/lib/execution/cancellation.ts b/apps/sim/lib/execution/cancellation.ts index 26273f8521b..ffa8ad9d444 100644 --- a/apps/sim/lib/execution/cancellation.ts +++ b/apps/sim/lib/execution/cancellation.ts @@ -1,10 +1,16 @@ import { createLogger } from '@sim/logger' import { getRedisClient } from '@/lib/core/config/redis' +import { createPubSubChannel, type PubSubChannel } from '@/lib/events/pubsub' const logger = createLogger('ExecutionCancellation') const EXECUTION_CANCEL_PREFIX = 'execution:cancel:' const EXECUTION_CANCEL_EXPIRY = 60 * 60 +const EXECUTION_CANCEL_CHANNEL = 'execution:cancel' + +export interface ExecutionCancelEvent { + executionId: string +} export type ExecutionCancellationRecordResult = | { durablyRecorded: true; reason: 'recorded' } @@ -13,36 +19,44 @@ export type ExecutionCancellationRecordResult = reason: 'redis_unavailable' | 'redis_write_failed' } +let sharedChannel: PubSubChannel | null = null + +export function getCancellationChannel(): PubSubChannel { + if (!sharedChannel) { + sharedChannel = createPubSubChannel({ + channel: EXECUTION_CANCEL_CHANNEL, + label: 'execution-cancel', + }) + } + return sharedChannel +} + export function isRedisCancellationEnabled(): boolean { return getRedisClient() !== null } -/** - * Mark an execution as cancelled in Redis. - * Returns whether the cancellation was durably recorded. - */ +/** Writes the durable key first, then publishes — so a late subscriber still sees the flag on backstop check. */ export async function markExecutionCancelled( executionId: string ): Promise { const redis = getRedisClient() if (!redis) { + getCancellationChannel().publish({ executionId }) return { durablyRecorded: false, reason: 'redis_unavailable' } } try { await redis.set(`${EXECUTION_CANCEL_PREFIX}${executionId}`, '1', 'EX', EXECUTION_CANCEL_EXPIRY) logger.info('Marked execution as cancelled', { executionId }) + getCancellationChannel().publish({ executionId }) return { durablyRecorded: true, reason: 'recorded' } } catch (error) { logger.error('Failed to mark execution as cancelled', { executionId, error }) + getCancellationChannel().publish({ executionId }) return { durablyRecorded: false, reason: 'redis_write_failed' } } } -/** - * Check if an execution has been cancelled via Redis. - * Returns false if Redis is not available (fallback to local abort signal). - */ export async function isExecutionCancelled(executionId: string): Promise { const redis = getRedisClient() if (!redis) { @@ -58,9 +72,6 @@ export async function isExecutionCancelled(executionId: string): Promise { const redis = getRedisClient() if (!redis) { diff --git a/apps/sim/providers/anthropic/core.ts b/apps/sim/providers/anthropic/core.ts index 864c9fc6b65..80f741f4ac1 100644 --- a/apps/sim/providers/anthropic/core.ts +++ b/apps/sim/providers/anthropic/core.ts @@ -571,7 +571,9 @@ export async function executeAnthropicProviderRequest( if (!tool) return null const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { @@ -1003,7 +1005,10 @@ export async function executeAnthropicProviderRequest( if (!tool) return null const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams, true) + const result = await executeTool(toolName, executionParams, { + skipPostProcess: true, + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/providers/azure-openai/index.ts b/apps/sim/providers/azure-openai/index.ts index 3ddc32f0874..a7b879891d8 100644 --- a/apps/sim/providers/azure-openai/index.ts +++ b/apps/sim/providers/azure-openai/index.ts @@ -347,7 +347,9 @@ async function executeChatCompletionsRequest( if (!tool) return null const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/providers/bedrock/index.ts b/apps/sim/providers/bedrock/index.ts index d7f8d53ce24..e6ab3d572a9 100644 --- a/apps/sim/providers/bedrock/index.ts +++ b/apps/sim/providers/bedrock/index.ts @@ -566,7 +566,9 @@ export const bedrockProvider: ProviderConfig = { if (!tool) return null const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/providers/cerebras/index.ts b/apps/sim/providers/cerebras/index.ts index 70929f23bfc..f5991be9c3c 100644 --- a/apps/sim/providers/cerebras/index.ts +++ b/apps/sim/providers/cerebras/index.ts @@ -263,7 +263,9 @@ export const cerebrasProvider: ProviderConfig = { if (!tool) return null const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/providers/deepseek/index.ts b/apps/sim/providers/deepseek/index.ts index 4f9c641291d..e42592ebc4f 100644 --- a/apps/sim/providers/deepseek/index.ts +++ b/apps/sim/providers/deepseek/index.ts @@ -276,7 +276,9 @@ export const deepseekProvider: ProviderConfig = { if (!tool) return null const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/providers/fireworks/index.ts b/apps/sim/providers/fireworks/index.ts index 8eb4e70e890..794ee3f0805 100644 --- a/apps/sim/providers/fireworks/index.ts +++ b/apps/sim/providers/fireworks/index.ts @@ -307,7 +307,9 @@ export const fireworksProvider: ProviderConfig = { if (!tool) return null const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/providers/gemini/core.ts b/apps/sim/providers/gemini/core.ts index 577b46454fd..be1161ef09f 100644 --- a/apps/sim/providers/gemini/core.ts +++ b/apps/sim/providers/gemini/core.ts @@ -129,7 +129,9 @@ async function executeToolCallsBatch( try { const { toolParams, executionParams } = prepareToolExecution(tool, args, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() const duration = toolCallEndTime - toolCallStartTime diff --git a/apps/sim/providers/groq/index.ts b/apps/sim/providers/groq/index.ts index 149a5dfb6d5..750c55d04ce 100644 --- a/apps/sim/providers/groq/index.ts +++ b/apps/sim/providers/groq/index.ts @@ -254,7 +254,9 @@ export const groqProvider: ProviderConfig = { if (!tool) return null const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/providers/mistral/index.ts b/apps/sim/providers/mistral/index.ts index 1caa08bcf8f..33fb3a5b524 100644 --- a/apps/sim/providers/mistral/index.ts +++ b/apps/sim/providers/mistral/index.ts @@ -318,7 +318,9 @@ export const mistralProvider: ProviderConfig = { if (!tool) return null const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/providers/ollama/index.ts b/apps/sim/providers/ollama/index.ts index 085cf967c59..52332aecdb2 100644 --- a/apps/sim/providers/ollama/index.ts +++ b/apps/sim/providers/ollama/index.ts @@ -327,7 +327,9 @@ export const ollamaProvider: ProviderConfig = { if (!tool) return null const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/providers/openai/core.ts b/apps/sim/providers/openai/core.ts index 11a465f2abc..6946f0c0fa3 100644 --- a/apps/sim/providers/openai/core.ts +++ b/apps/sim/providers/openai/core.ts @@ -476,7 +476,9 @@ export async function executeResponsesProviderRequest( } const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/providers/openrouter/index.ts b/apps/sim/providers/openrouter/index.ts index 1f897ffbc7a..d3d2535b43d 100644 --- a/apps/sim/providers/openrouter/index.ts +++ b/apps/sim/providers/openrouter/index.ts @@ -308,7 +308,9 @@ export const openRouterProvider: ProviderConfig = { if (!tool) return null const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/providers/vllm/index.ts b/apps/sim/providers/vllm/index.ts index 85756e47310..2de3c695116 100644 --- a/apps/sim/providers/vllm/index.ts +++ b/apps/sim/providers/vllm/index.ts @@ -379,7 +379,9 @@ export const vllmProvider: ProviderConfig = { if (!tool) return null const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/providers/xai/index.ts b/apps/sim/providers/xai/index.ts index 6f33bad536c..0d2dcb7b27c 100644 --- a/apps/sim/providers/xai/index.ts +++ b/apps/sim/providers/xai/index.ts @@ -281,7 +281,9 @@ export const xAIProvider: ProviderConfig = { } const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) const toolCallEndTime = Date.now() return { diff --git a/apps/sim/tools/index.test.ts b/apps/sim/tools/index.test.ts index 349b167c23e..1a27cc552c8 100644 --- a/apps/sim/tools/index.test.ts +++ b/apps/sim/tools/index.test.ts @@ -15,6 +15,7 @@ import { inputValidationMockFns, type MockFetchResponse, } from '@sim/testing' +import { sleep } from '@sim/utils/helpers' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' // Hoisted mock state - these are available to vi.mock factories @@ -531,7 +532,7 @@ describe('executeTool Function', () => { code: 'return 1', timeout: 5000, }, - true + { skipPostProcess: true } ) expect(result.success).toBe(true) @@ -560,7 +561,7 @@ describe('executeTool Function', () => { code: 'return { result: "hello world" }', language: 'javascript', }, - true + { skipPostProcess: true } ) // Skip proxy tools.function_execute = originalFunctionTool @@ -582,13 +583,85 @@ describe('executeTool Function', () => { vi.restoreAllMocks() }) + it('aborts the internal fetch when the caller signal is aborted', async () => { + const originalFunctionTool = { ...tools.function_execute } + tools.function_execute = { + ...tools.function_execute, + transformResponse: vi.fn().mockResolvedValue({ success: true, output: {} }), + } + + let observedSignal: AbortSignal | undefined + global.fetch = Object.assign( + vi.fn().mockImplementation(async (_url: string, init: RequestInit) => { + observedSignal = init.signal as AbortSignal + return new Promise((_resolve, reject) => { + observedSignal!.addEventListener('abort', () => { + const err = new Error('aborted') + err.name = 'AbortError' + reject(err) + }) + }) + }), + { preconnect: vi.fn() } + ) as typeof fetch + + const callerController = new AbortController() + const resultPromise = executeTool( + 'function_execute', + { code: 'return 1', timeout: 5000 }, + { skipPostProcess: true, signal: callerController.signal } + ) + + await sleep(1) + callerController.abort() + const result = await resultPromise + + expect(observedSignal?.aborted).toBe(true) + expect(result.success).toBe(false) + expect(result.error).not.toMatch(/timed out/i) + + tools.function_execute = originalFunctionTool + }) + + it('aborts immediately when the caller signal is already aborted at call time', async () => { + const originalFunctionTool = { ...tools.function_execute } + tools.function_execute = { + ...tools.function_execute, + transformResponse: vi.fn().mockResolvedValue({ success: true, output: {} }), + } + + let observedAborted = false + global.fetch = Object.assign( + vi.fn().mockImplementation(async (_url: string, init: RequestInit) => { + observedAborted = (init.signal as AbortSignal).aborted + const err = new Error('aborted') + err.name = 'AbortError' + throw err + }), + { preconnect: vi.fn() } + ) as typeof fetch + + const controller = new AbortController() + controller.abort() + const result = await executeTool( + 'function_execute', + { code: 'return 1', timeout: 5000 }, + { skipPostProcess: true, signal: controller.signal } + ) + + expect(observedAborted).toBe(true) + expect(result.success).toBe(false) + + tools.function_execute = originalFunctionTool + }) + it('should add timing information to results', async () => { const result = await executeTool( 'http_request', { url: 'https://api.example.com/data', }, - true + { skipPostProcess: true } ) expect(result.timing).toBeDefined() @@ -662,7 +735,7 @@ describe('Automatic Internal Route Detection', () => { { preconnect: vi.fn() } ) as typeof fetch - const result = await executeTool('test_internal_tool', {}, false) + const result = await executeTool('test_internal_tool', {}) expect(result.success).toBe(true) expect(result.output.result).toBe('Internal route success') @@ -924,8 +997,7 @@ describe('Copilot File Parameter Normalization', () => { const result = await executeTool( 'test_single_file_tool', { attachment: 'wf_123' }, - false, - context + { executionContext: context } ) expect(result.success).toBe(true) @@ -1014,8 +1086,7 @@ describe('Copilot File Parameter Normalization', () => { const result = await executeTool( 'test_file_array_tool', { attachments: ['wf_1', partialFileObject, existingFileObject, 'wf_2'] }, - false, - context + { executionContext: context } ) expect(result.success).toBe(true) @@ -1048,8 +1119,7 @@ describe('Copilot File Parameter Normalization', () => { const result = await executeTool( 'test_single_file_tool', { attachment: 'wf_123' }, - false, - context + { executionContext: context } ) expect(result.success).toBe(true) @@ -1079,7 +1149,7 @@ describe('Copilot OAuth Credential Enforcement', () => { copilotToolExecution: true, } as any) - const result = await executeTool('gmail_read', { maxResults: 5 }, false, context) + const result = await executeTool('gmail_read', { maxResults: 5 }, { executionContext: context }) expect(result.success).toBe(false) expect(result.error).toContain('credentialId') @@ -1123,7 +1193,7 @@ describe('Centralized Error Handling', () => { const result = await executeTool( 'function_execute', { code: 'return { result: "test" }' }, - true + { skipPostProcess: true } ) expect(result.success).toBe(false) @@ -1224,7 +1294,7 @@ describe('Centralized Error Handling', () => { const result = await executeTool( 'function_execute', { code: 'return { result: "test" }' }, - true + { skipPostProcess: true } ) expect(result.success).toBe(false) @@ -1254,7 +1324,7 @@ describe('Centralized Error Handling', () => { const result = await executeTool( 'function_execute', { code: 'return { result: "test" }' }, - true + { skipPostProcess: true } ) expect(result.success).toBe(false) @@ -1283,7 +1353,7 @@ describe('Centralized Error Handling', () => { const result = await executeTool( 'function_execute', { code: 'return { result: "test" }' }, - true + { skipPostProcess: true } ) expect(result.success).toBe(false) @@ -1361,7 +1431,11 @@ describe('MCP Tool Execution', () => { const mockContext = createToolExecutionContext() - const result = await executeTool('mcp-123-list_files', { path: '/test' }, false, mockContext) + const result = await executeTool( + 'mcp-123-list_files', + { path: '/test' }, + { executionContext: mockContext } + ) expect(result.success).toBe(true) expect(result.output).toBeDefined() @@ -1391,7 +1465,11 @@ describe('MCP Tool Execution', () => { const mockContext2 = createToolExecutionContext() - await executeTool('mcp-timestamp123-complex-tool-name', { param: 'value' }, false, mockContext2) + await executeTool( + 'mcp-timestamp123-complex-tool-name', + { param: 'value' }, + { executionContext: mockContext2 } + ) }) it('should handle MCP block arguments format', async () => { @@ -1422,8 +1500,7 @@ describe('MCP Tool Execution', () => { server: 'mcp-123', tool: 'read_file', }, - false, - mockContext3 + { executionContext: mockContext3 } ) }) @@ -1459,8 +1536,7 @@ describe('MCP Tool Execution', () => { workspaceId: 'workspace-456', requestId: 'req-123', }, - false, - mockContext4 + { executionContext: mockContext4 } ) }) @@ -1484,8 +1560,7 @@ describe('MCP Tool Execution', () => { const result = await executeTool( 'mcp-123-nonexistent_tool', { param: 'value' }, - false, - mockContext5 + { executionContext: mockContext5 } ) expect(result.success).toBe(false) @@ -1503,7 +1578,11 @@ describe('MCP Tool Execution', () => { it('should handle invalid MCP tool ID format', async () => { const mockContext6 = createToolExecutionContext() - const result = await executeTool('invalid-mcp-id', { param: 'value' }, false, mockContext6) + const result = await executeTool( + 'invalid-mcp-id', + { param: 'value' }, + { executionContext: mockContext6 } + ) expect(result.success).toBe(false) expect(result.error).toContain('Tool not found') @@ -1516,7 +1595,11 @@ describe('MCP Tool Execution', () => { const mockContext7 = createToolExecutionContext() - const result = await executeTool('mcp-123-test_tool', { param: 'value' }, false, mockContext7) + const result = await executeTool( + 'mcp-123-test_tool', + { param: 'value' }, + { executionContext: mockContext7 } + ) expect(result.success).toBe(false) expect(result.error).toContain('Network error') @@ -1827,7 +1910,7 @@ describe('Hosted Key Injection', () => { ) as typeof fetch const mockContext = createToolExecutionContext() - await executeTool('test_no_hosting', {}, false, mockContext) + await executeTool('test_no_hosting', {}, { executionContext: mockContext }) // BYOK should not be called since there's no hosting config expect(mockGetBYOKKey).not.toHaveBeenCalled() @@ -1890,7 +1973,7 @@ describe('Hosted Key Injection', () => { ) as typeof fetch const mockContext = createToolExecutionContext() - await executeTool('test_with_hosting', {}, false, mockContext) + await executeTool('test_with_hosting', {}, { executionContext: mockContext }) // With isHosted=false, BYOK won't be called - this is expected behavior // The test documents the current behavior @@ -2119,7 +2202,7 @@ describe('Rate Limiting and Retry Logic', () => { ) as typeof fetch const mockContext = createToolExecutionContext() - const resultPromise = executeTool('test_rate_limit', {}, false, mockContext) + const resultPromise = executeTool('test_rate_limit', {}, { executionContext: mockContext }) // Advance timers to skip retry delays (1s + 2s exponential backoff) await vi.advanceTimersByTimeAsync(10000) @@ -2180,7 +2263,11 @@ describe('Rate Limiting and Retry Logic', () => { ) as typeof fetch const mockContext = createToolExecutionContext() - const resultPromise = executeTool('test_persistent_rate_limit', {}, false, mockContext) + const resultPromise = executeTool( + 'test_persistent_rate_limit', + {}, + { executionContext: mockContext } + ) // Advance timers to skip retry delays (1s + 2s + 4s exponential backoff) await vi.advanceTimersByTimeAsync(15000) @@ -2243,7 +2330,7 @@ describe('Rate Limiting and Retry Logic', () => { ) as typeof fetch const mockContext = createToolExecutionContext() - const result = await executeTool('test_no_retry', {}, false, mockContext) + const result = await executeTool('test_no_retry', {}, { executionContext: mockContext }) // Should fail immediately without retries expect(result.success).toBe(false) @@ -2299,7 +2386,7 @@ describe('stripInternalFields Safety', () => { { preconnect: vi.fn() } ) as typeof fetch - const result = await executeTool('test_string_output', {}, true) + const result = await executeTool('test_string_output', {}, { skipPostProcess: true }) expect(result.success).toBe(true) expect(result.output).toBe(stringOutput) @@ -2341,7 +2428,7 @@ describe('stripInternalFields Safety', () => { { preconnect: vi.fn() } ) as typeof fetch - const result = await executeTool('test_array_output', {}, true) + const result = await executeTool('test_array_output', {}, { skipPostProcess: true }) expect(result.success).toBe(true) expect(Array.isArray(result.output)).toBe(true) @@ -2381,7 +2468,7 @@ describe('stripInternalFields Safety', () => { { preconnect: vi.fn() } ) as typeof fetch - const result = await executeTool('test_strip_internal', {}, true) + const result = await executeTool('test_strip_internal', {}, { skipPostProcess: true }) expect(result.success).toBe(true) expect(result.output.result).toBe('ok') @@ -2484,7 +2571,7 @@ describe('Cost Field Handling', () => { const mockContext = createToolExecutionContext({ userId: 'user-123', } as any) - const result = await executeTool('test_cost_per_request', {}, false, mockContext) + const result = await executeTool('test_cost_per_request', {}, { executionContext: mockContext }) expect(result.success).toBe(true) // Note: In test environment, hosted key injection may not work due to env mocking complexity. @@ -2549,8 +2636,7 @@ describe('Cost Field Handling', () => { const result = await executeTool( 'test_no_hosted_cost', { apiKey: 'user-api-key' }, - false, - mockContext + { executionContext: mockContext } ) expect(result.success).toBe(true) @@ -2617,8 +2703,7 @@ describe('Cost Field Handling', () => { const result = await executeTool( 'test_custom_pricing_cost', { mode: 'advanced' }, - false, - mockContext + { executionContext: mockContext } ) expect(result.success).toBe(true) diff --git a/apps/sim/tools/index.ts b/apps/sim/tools/index.ts index 090b30c628c..de5adb1dc7a 100644 --- a/apps/sim/tools/index.ts +++ b/apps/sim/tools/index.ts @@ -711,6 +711,12 @@ async function processFileOutputs( } } +export interface ExecuteToolOptions { + skipPostProcess?: boolean + executionContext?: ExecutionContext + signal?: AbortSignal +} + /** * Execute a tool by making the appropriate HTTP request * All requests go directly - internal routes use regular fetch, external use SSRF-protected fetch @@ -718,9 +724,9 @@ async function processFileOutputs( export async function executeTool( toolId: string, params: Record, - skipPostProcess = false, - executionContext?: ExecutionContext + options: ExecuteToolOptions = {} ): Promise { + const { skipPostProcess = false, executionContext, signal } = options // Capture start time for precise timing const startTime = new Date() const startTimeISO = startTime.toISOString() @@ -813,7 +819,8 @@ export async function executeTool( params, executionContext, requestId, - startTimeISO + startTimeISO, + signal ) } else { // For built-in tools, use the synchronous version @@ -1010,13 +1017,13 @@ export async function executeTool( // Execute the tool request directly (internal routes use regular fetch, external use SSRF-protected fetch) // Wrap with retry logic for hosted keys to handle rate limiting due to higher usage const result = hostedKeyInfo.isUsingHostedKey - ? await executeWithRetry(() => executeToolRequest(toolId, tool, contextParams), { + ? await executeWithRetry(() => executeToolRequest(toolId, tool, contextParams, signal), { requestId, toolId, envVarName: hostedKeyInfo.envVarName!, executionContext, }) - : await executeToolRequest(toolId, tool, contextParams) + : await executeToolRequest(toolId, tool, contextParams, signal) // Apply post-processing if available and not skipped let finalResult = result @@ -1300,7 +1307,8 @@ function parseRetryAfterHeader(header: string | null): number { async function executeToolRequest( toolId: string, tool: ToolConfig, - params: Record + params: Record, + signal?: AbortSignal ): Promise { const requestId = generateRequestId() @@ -1397,6 +1405,16 @@ async function executeToolRequest( timeout ) + let abortListener: (() => void) | null = null + if (signal) { + if (signal.aborted) { + controller.abort('caller_aborted') + } else { + abortListener = () => controller.abort('caller_aborted') + signal.addEventListener('abort', abortListener, { once: true }) + } + } + try { response = await fetch(fullUrl, { method: requestParams.method, @@ -1406,11 +1424,19 @@ async function executeToolRequest( }) } catch (error) { if (error instanceof Error && error.name === 'AbortError') { + // Distinguish caller cancellation from local timeout: rethrow the AbortError + // when the caller's signal triggered the abort so cancellation propagates as-is. + if (signal?.aborted) { + throw error + } throw new Error(`Request timed out after ${timeout}ms`) } throw error } finally { clearTimeout(timeoutId) + if (abortListener) { + signal?.removeEventListener('abort', abortListener) + } } } else { const urlValidation = await validateUrlWithDNS(fullUrl, 'toolUrl') @@ -1423,6 +1449,7 @@ async function executeToolRequest( headers: headersRecord, body: requestParams.body ?? undefined, timeout: requestParams.timeout, + signal, }) const responseHeaders = new Headers(secureResponse.headers.toRecord()) @@ -1702,7 +1729,8 @@ async function executeMcpTool( params: Record, executionContext?: ExecutionContext, requestId?: string, - startTimeISO?: string + startTimeISO?: string, + signal?: AbortSignal ): Promise { const actualRequestId = requestId || generateRequestId() const actualStartTime = startTimeISO || new Date().toISOString() @@ -1795,6 +1823,7 @@ async function executeMcpTool( method: 'POST', headers, body, + signal, }) const endTime = new Date()