From 4945d55671aed3893c910a10c5a9f472f9631090 Mon Sep 17 00:00:00 2001 From: Theodore Li Date: Sat, 16 May 2026 22:51:17 -0700 Subject: [PATCH 1/3] fix(redis): apply TLS SNI override to pub/sub clients (#4638) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(redis): apply TLS SNI override to pub/sub clients too Pub/sub clients in lib/events/pubsub.ts build their own ioredis instances directly via new Redis(redisUrl, ...) because pub/sub needs dedicated connections (can't multiplex on the shared client from getRedisClient). That path skipped the resolveTlsOptions helper added for trigger.dev's PrivateLink VPCE IP, so every pub/sub channel hit 'Hostname/IP does not match certificate's altnames' on connect. Export the helper as resolveRedisTlsOptions and use it from pubsub.ts. * refactor(redis): share connection defaults via one helper Extract keepAlive/connectTimeout/enableOfflineQueue + TLS SNI into a single getRedisConnectionDefaults helper. Main client and pub/sub clients both spread it; caller-specific retry/timeout policy stays per-caller (pub/sub still needs maxRetriesPerRequest: null and a different retry strategy for SUBSCRIBE). * fix(pubsub): surface TLS config errors instead of silently degrading resolveRedisTlsOptions (via getRedisConnectionDefaults) throws if REDIS_TLS_SERVERNAME is missing for an IP-based rediss:// URL. Calling it inside the constructor let createPubSubChannel's try/catch swallow the error and fall back to in-process EventEmitter — silent cross-replica pub/sub breakage in prod. Resolve defaults before the try so config errors propagate; only catch genuine runtime construction failures. --- apps/sim/lib/core/config/redis.ts | 28 +++++++++++++++++++++------- apps/sim/lib/events/pubsub.ts | 30 ++++++++++++++++-------------- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/apps/sim/lib/core/config/redis.ts b/apps/sim/lib/core/config/redis.ts index aea707daa00..04f976b44ae 100644 --- a/apps/sim/lib/core/config/redis.ts +++ b/apps/sim/lib/core/config/redis.ts @@ -1,7 +1,7 @@ import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' import { randomFloat } from '@sim/utils/random' -import Redis from 'ioredis' +import Redis, { type RedisOptions } from 'ioredis' import { env } from '@/lib/core/config/env' const logger = createLogger('Redis') @@ -16,7 +16,7 @@ const redisUrl = env.REDIS_URL * * For DNS hosts: no override needed, default verification works. */ -function resolveTlsOptions(url: string | undefined): { servername: string } | undefined { +function resolveRedisTlsOptions(url: string | undefined): { servername: string } | undefined { if (!url) return undefined let parsed: URL try { @@ -37,6 +37,23 @@ function resolveTlsOptions(url: string | undefined): { servername: string } | un return { servername: env.REDIS_TLS_SERVERNAME } } +/** + * Shared connection defaults — keepAlive, connectTimeout, enableOfflineQueue, + * and TLS SNI when REDIS_URL targets an IP. Every Redis client we open should + * spread this; callers add their own retry / timeout policy on top. + */ +export function getRedisConnectionDefaults( + url: string | undefined +): Pick { + const tls = resolveRedisTlsOptions(url) + return { + keepAlive: 1000, + connectTimeout: 10000, + enableOfflineQueue: true, + ...(tls ? { tls } : {}), + } +} + let globalRedisClient: Redis | null = null let pingFailures = 0 let pingInterval: NodeJS.Timeout | null = null @@ -117,18 +134,15 @@ export function getRedisClient(): Redis | null { if (globalRedisClient) return globalRedisClient // Outside the try/catch so config errors aren't silently swallowed. - const tls = resolveTlsOptions(redisUrl) + const defaults = getRedisConnectionDefaults(redisUrl) try { logger.info('Initializing Redis client') globalRedisClient = new Redis(redisUrl, { - keepAlive: 1000, - connectTimeout: 10000, + ...defaults, commandTimeout: 5000, maxRetriesPerRequest: 5, - enableOfflineQueue: true, - ...(tls ? { tls } : {}), retryStrategy: (times) => { if (times > 10) { diff --git a/apps/sim/lib/events/pubsub.ts b/apps/sim/lib/events/pubsub.ts index b299eafc055..f866d8d1459 100644 --- a/apps/sim/lib/events/pubsub.ts +++ b/apps/sim/lib/events/pubsub.ts @@ -9,6 +9,7 @@ import { EventEmitter } from 'events' import { createLogger } from '@sim/logger' import Redis, { type RedisOptions } from 'ioredis' import { env } from '@/lib/core/config/env' +import { getRedisConnectionDefaults } from '@/lib/core/config/redis' const logger = createLogger('PubSub') @@ -31,13 +32,12 @@ class RedisPubSubChannel implements PubSubChannel { constructor( redisUrl: string, + connectionDefaults: ReturnType, private config: PubSubChannelConfig ) { const commonOpts = { - keepAlive: 1000, - connectTimeout: 10000, + ...connectionDefaults, maxRetriesPerRequest: null, - enableOfflineQueue: true, retryStrategy: (times: number) => { if (times > 10) return 30000 return Math.min(times * 500, 5000) @@ -139,16 +139,18 @@ class LocalPubSubChannel implements PubSubChannel { export function createPubSubChannel(config: PubSubChannelConfig): PubSubChannel { const redisUrl = env.REDIS_URL - - if (redisUrl) { - try { - logger.info(`${config.label}: Using Redis`) - return new RedisPubSubChannel(redisUrl, config) - } catch (err) { - logger.error(`Failed to create Redis ${config.label}, falling back to local:`, err) - return new LocalPubSubChannel(config) - } + if (!redisUrl) return new LocalPubSubChannel(config) + + // Resolve config-derived defaults outside the try so a missing + // REDIS_TLS_SERVERNAME (config error) surfaces instead of silently degrading + // to the in-process EventEmitter — that would break cross-replica pub/sub. + const connectionDefaults = getRedisConnectionDefaults(redisUrl) + + try { + logger.info(`${config.label}: Using Redis`) + return new RedisPubSubChannel(redisUrl, connectionDefaults, config) + } catch (err) { + logger.error(`Failed to create Redis ${config.label}, falling back to local:`, err) + return new LocalPubSubChannel(config) } - - return new LocalPubSubChannel(config) } From fd12137c28e9a01a3bf750c10b4b6dfb7b45d7ea Mon Sep 17 00:00:00 2001 From: Waleed Date: Sat, 16 May 2026 23:05:41 -0700 Subject: [PATCH 2/3] improvement(copilot): drop unused columns from mothership chat detail reads (#4640) --- .../mothership/chats/[chatId]/route.test.ts | 2 +- .../api/mothership/chats/[chatId]/route.ts | 4 +- apps/sim/lib/copilot/chat/lifecycle.ts | 65 +++++++++++++++++-- apps/sim/lib/copilot/chat/process-contents.ts | 4 +- 4 files changed, 65 insertions(+), 10 deletions(-) diff --git a/apps/sim/app/api/mothership/chats/[chatId]/route.test.ts b/apps/sim/app/api/mothership/chats/[chatId]/route.test.ts index 3d54f4fad00..7e96d476220 100644 --- a/apps/sim/app/api/mothership/chats/[chatId]/route.test.ts +++ b/apps/sim/app/api/mothership/chats/[chatId]/route.test.ts @@ -47,8 +47,8 @@ vi.mock('drizzle-orm', () => ({ vi.mock('@/lib/copilot/request/http', () => copilotHttpMock) vi.mock('@/lib/copilot/chat/lifecycle', () => ({ - getAccessibleCopilotChat: mockGetAccessibleCopilotChat, getAccessibleCopilotChatAuth: mockGetAccessibleCopilotChat, + getAccessibleCopilotChatWithMessages: mockGetAccessibleCopilotChat, })) vi.mock('@/lib/copilot/chat/stream-liveness', () => ({ diff --git a/apps/sim/app/api/mothership/chats/[chatId]/route.ts b/apps/sim/app/api/mothership/chats/[chatId]/route.ts index 121a01b5fde..2d7c20c1b1f 100644 --- a/apps/sim/app/api/mothership/chats/[chatId]/route.ts +++ b/apps/sim/app/api/mothership/chats/[chatId]/route.ts @@ -13,8 +13,8 @@ import { parseRequest } from '@/lib/api/server' import { getLatestRunForStream } from '@/lib/copilot/async-runs/repository' import { buildEffectiveChatTranscript } from '@/lib/copilot/chat/effective-transcript' import { - getAccessibleCopilotChat, getAccessibleCopilotChatAuth, + getAccessibleCopilotChatWithMessages, } from '@/lib/copilot/chat/lifecycle' import { normalizeMessage } from '@/lib/copilot/chat/persisted-message' import { reconcileChatStreamMarkers } from '@/lib/copilot/chat/stream-liveness' @@ -45,7 +45,7 @@ export const GET = withRouteHandler( if (!paramsResult.success) return paramsResult.response const { chatId } = paramsResult.data.params - const chat = await getAccessibleCopilotChat(chatId, userId) + const chat = await getAccessibleCopilotChatWithMessages(chatId, userId) if (!chat || chat.type !== 'mothership') { return NextResponse.json({ success: false, error: 'Chat not found' }, { status: 404 }) } diff --git a/apps/sim/lib/copilot/chat/lifecycle.ts b/apps/sim/lib/copilot/chat/lifecycle.ts index 5224ffdd611..01ec3f5b179 100644 --- a/apps/sim/lib/copilot/chat/lifecycle.ts +++ b/apps/sim/lib/copilot/chat/lifecycle.ts @@ -15,7 +15,7 @@ const logger = createLogger('CopilotChatLifecycle') export interface ChatLoadResult { chatId: string - chat: typeof copilotChats.$inferSelect | null + chat: CopilotChatDetailRow | null conversationHistory: unknown[] isNew: boolean } @@ -34,11 +34,43 @@ const copilotChatAuthColumns = { type: copilotChats.type, } as const +/** + * Column set for chat-detail callers that need the conversation transcript but + * not the copilot-only TOAST-able fields (`previewYaml`, `planArtifact`, + * `config`) or unused metadata (`model`, `pinned`, `lastSeenAt`). Selecting + * only these columns avoids the Postgres detoast cost on the dropped fields, + * which dominates latency for chats with large message histories. + */ +const copilotChatDetailColumns = { + ...copilotChatAuthColumns, + title: copilotChats.title, + messages: copilotChats.messages, + conversationId: copilotChats.conversationId, + resources: copilotChats.resources, + createdAt: copilotChats.createdAt, + updatedAt: copilotChats.updatedAt, +} as const + type CopilotChatAuthRow = Pick< typeof copilotChats.$inferSelect, 'id' | 'userId' | 'workflowId' | 'workspaceId' | 'type' > +export type CopilotChatDetailRow = Pick< + typeof copilotChats.$inferSelect, + | 'id' + | 'userId' + | 'workflowId' + | 'workspaceId' + | 'type' + | 'title' + | 'messages' + | 'conversationId' + | 'resources' + | 'createdAt' + | 'updatedAt' +> + async function authorizeCopilotChatRow( chat: T | undefined, chatId: string, @@ -99,8 +131,10 @@ export async function getAccessibleCopilotChatAuth( /** * Load the full copilot chat row after authorization. Use this only when the - * caller actually consumes the heavy columns (`messages`, `planArtifact`, - * `config`, etc.) — for example, chat resume or the GET-by-id endpoint. + * caller actually consumes copilot-only TOAST-able columns (`previewYaml`, + * `planArtifact`, `config`) or other extended metadata — for example the + * legacy copilot chat detail endpoint. Mothership chats and other consumers + * that only need the transcript should prefer `getAccessibleCopilotChatWithMessages`. */ export async function getAccessibleCopilotChat(chatId: string, userId: string) { const [chat] = await db @@ -112,6 +146,27 @@ export async function getAccessibleCopilotChat(chatId: string, userId: string) { return authorizeCopilotChatRow(chat, chatId, userId) } +/** + * Load a copilot chat with the conversation transcript and resources after + * authorization, omitting copilot-only TOAST-able fields (`previewYaml`, + * `planArtifact`, `config`) and unused metadata (`model`, `pinned`, + * `lastSeenAt`). Use this for the mothership chat detail endpoint and the + * shared `resolveOrCreateChat` path — every column read here is consumed + * downstream, and dropping the others avoids per-request detoast overhead. + */ +export async function getAccessibleCopilotChatWithMessages( + chatId: string, + userId: string +): Promise { + const [chat] = await db + .select(copilotChatDetailColumns) + .from(copilotChats) + .where(and(eq(copilotChats.id, chatId), eq(copilotChats.userId, userId))) + .limit(1) + + return authorizeCopilotChatRow(chat, chatId, userId) +} + /** * Resolve or create a copilot chat session. * If chatId is provided, loads the existing chat. Otherwise creates a new one. @@ -132,7 +187,7 @@ export async function resolveOrCreateChat(params: { } if (chatId) { - const chat = await getAccessibleCopilotChat(chatId, userId) + const chat = await getAccessibleCopilotChatWithMessages(chatId, userId) if (chat) { if (workflowId && chat.workflowId !== workflowId) { @@ -189,7 +244,7 @@ export async function resolveOrCreateChat(params: { messages: [], lastSeenAt: now, }) - .returning() + .returning(copilotChatDetailColumns) if (!newChat) { logger.warn('Failed to create new copilot chat row', { userId, workflowId, workspaceId }) diff --git a/apps/sim/lib/copilot/chat/process-contents.ts b/apps/sim/lib/copilot/chat/process-contents.ts index dd3cdccc3d5..21b6842cd6f 100644 --- a/apps/sim/lib/copilot/chat/process-contents.ts +++ b/apps/sim/lib/copilot/chat/process-contents.ts @@ -228,8 +228,8 @@ async function processPastChatFromDb( currentWorkspaceId?: string ): Promise { try { - const { getAccessibleCopilotChat } = await import('./lifecycle') - const chat = await getAccessibleCopilotChat(chatId, userId) + const { getAccessibleCopilotChatWithMessages } = await import('./lifecycle') + const chat = await getAccessibleCopilotChatWithMessages(chatId, userId) if (!chat) { return null } From 08eeecbebe5d1b632b8d7bc11b28f7594b09aeb9 Mon Sep 17 00:00:00 2001 From: Waleed Date: Sat, 16 May 2026 23:05:52 -0700 Subject: [PATCH 3/3] fix(security): KB fileUrl LFI, MCP/Agiloft SSRF pinning, form OTP, KB authz (#4639) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(security): KB fileUrl LFI, MCP/Agiloft SSRF pinning, form OTP, KB authz * fix(otp): don't leak caught error.message; fail-closed on DB retry exhaust - Chat/form OTP routes: replace `error.message || fallback` with generic `Failed to process request` in 500 responses (logger still captures detail). - otp.ts incrementOTPAttempts DB path: on MAX_RETRIES exhaustion, delete the verification row and return `'locked'` instead of trusting a possibly- undercounted final read. Co-Authored-By: Claude Opus 4.7 * fix(mcp): use undici fetch directly in pinned-fetch for typed dispatcher Replace `globalThis.fetch` + double-cast with `undici.fetch` so the `dispatcher` option is part of the real type contract. This guarantees pinning won't silently break if a future runtime swaps the underlying fetch implementation. Co-Authored-By: Claude Opus 4.7 * fix(build): keep agiloft/grafana tool configs client-safe Tool config files are statically reachable from the client bundle (via tools/registry.ts → tools/{service}/index.ts). Importing `@/lib/core/security/input-validation.server` from these files pulled `node:dns/promises` into the Turbopack client bundle and broke the build. Split agiloft utils into client-safe (`utils.ts`, plain fetch + sync `validateExternalUrl`) and server-only (`utils.server.ts`, DNS-pinned variants). Routes that need TOCTOU protection import the pinned helpers; the executor-side tool path falls back to sync URL validation (matches the supabase precedent and pre-PR baseline). Grafana update tools likewise switch from `secureFetchWithValidation` (server-only) to inline sync `validateExternalUrl` + plain fetch. Co-Authored-By: Claude Opus 4.7 * fix(knowledge): case-insensitive scheme checks for fileUrl Boundary schema accepted uppercase schemes (e.g. HTTPS://, DATA:) via the case-insensitive http regex, but the processor's case-sensitive startsWith('data:') / startsWith('http') / startsWith('https://') checks rejected them with a confusing "Unsupported fileUrl scheme" error. Aligns processor checks to the schema using case-insensitive regex per RFC 3986 §3.1. Co-Authored-By: Claude Opus 4.7 * fix(mcp): annotate undici/DOM type-bridge double-casts in pinned-fetch Strict audit was failing on two new `as unknown as` casts in pinned-fetch.ts. They bridge DOM `RequestInit`/`Response` ↔ undici equivalents (structurally compatible at runtime since Node's global fetch is undici) and are required to satisfy the FetchLike contract. Annotate so they count as documented exemptions instead of new violations. Co-Authored-By: Claude Opus 4.7 --------- Co-authored-by: Claude Opus 4.7 --- .../app/api/chat/[identifier]/otp/route.ts | 237 +----- .../api/form/[identifier]/otp/route.test.ts | 695 ++++++++++++++++++ .../app/api/form/[identifier]/otp/route.ts | 261 +++++++ apps/sim/app/api/form/utils.test.ts | 10 +- apps/sim/app/api/form/utils.ts | 2 +- .../[id]/documents/upsert/route.test.ts | 111 +++ apps/sim/app/api/knowledge/[id]/route.test.ts | 53 +- apps/sim/app/api/knowledge/[id]/route.ts | 8 +- apps/sim/app/api/knowledge/route.test.ts | 17 + apps/sim/app/api/knowledge/route.ts | 5 + .../api/mcp/servers/test-connection/route.ts | 12 +- .../api/tools/agiloft/attach/route.test.ts | 144 ++++ .../sim/app/api/tools/agiloft/attach/route.ts | 26 +- .../api/tools/agiloft/retrieve/route.test.ts | 163 ++++ .../app/api/tools/agiloft/retrieve/route.ts | 26 +- .../[identifier]/components/email-auth.tsx | 284 +++++++ .../app/form/[identifier]/components/index.ts | 1 + apps/sim/app/form/[identifier]/form.tsx | 5 + apps/sim/hooks/queries/forms.ts | 32 + apps/sim/lib/api/contracts/forms.ts | 38 + .../lib/api/contracts/knowledge/documents.ts | 5 +- .../api/contracts/knowledge/shared.test.ts | 57 ++ .../sim/lib/api/contracts/knowledge/shared.ts | 15 + apps/sim/lib/core/security/otp.ts | 251 +++++++ .../knowledge/documents/document-processor.ts | 19 +- apps/sim/lib/knowledge/service.test.ts | 114 +++ apps/sim/lib/knowledge/service.ts | 106 ++- apps/sim/lib/mcp/client.ts | 37 +- apps/sim/lib/mcp/connection-manager.ts | 4 +- apps/sim/lib/mcp/domain-check.test.ts | 121 ++- apps/sim/lib/mcp/domain-check.ts | 64 +- apps/sim/lib/mcp/pinned-fetch.test.ts | 85 +++ apps/sim/lib/mcp/pinned-fetch.ts | 38 + apps/sim/lib/mcp/service.ts | 67 +- apps/sim/lib/mcp/types.ts | 8 + apps/sim/package.json | 1 + apps/sim/tools/agiloft/utils.server.ts | 79 ++ apps/sim/tools/agiloft/utils.test.ts | 136 ++++ apps/sim/tools/agiloft/utils.ts | 2 +- apps/sim/tools/grafana/update_alert_rule.ts | 12 +- apps/sim/tools/grafana/update_dashboard.ts | 19 +- bun.lock | 1 + scripts/check-api-validation-contracts.ts | 4 +- 43 files changed, 2997 insertions(+), 378 deletions(-) create mode 100644 apps/sim/app/api/form/[identifier]/otp/route.test.ts create mode 100644 apps/sim/app/api/form/[identifier]/otp/route.ts create mode 100644 apps/sim/app/api/knowledge/[id]/documents/upsert/route.test.ts create mode 100644 apps/sim/app/api/tools/agiloft/attach/route.test.ts create mode 100644 apps/sim/app/api/tools/agiloft/retrieve/route.test.ts create mode 100644 apps/sim/app/form/[identifier]/components/email-auth.tsx create mode 100644 apps/sim/lib/api/contracts/knowledge/shared.test.ts create mode 100644 apps/sim/lib/core/security/otp.ts create mode 100644 apps/sim/lib/knowledge/service.test.ts create mode 100644 apps/sim/lib/mcp/pinned-fetch.test.ts create mode 100644 apps/sim/lib/mcp/pinned-fetch.ts create mode 100644 apps/sim/tools/agiloft/utils.server.ts create mode 100644 apps/sim/tools/agiloft/utils.test.ts diff --git a/apps/sim/app/api/chat/[identifier]/otp/route.ts b/apps/sim/app/api/chat/[identifier]/otp/route.ts index b2e129b5fa8..fcccc003e86 100644 --- a/apps/sim/app/api/chat/[identifier]/otp/route.ts +++ b/apps/sim/app/api/chat/[identifier]/otp/route.ts @@ -1,18 +1,24 @@ -import { randomInt } from 'crypto' import { db } from '@sim/db' -import { chat, verification } from '@sim/db/schema' +import { chat } from '@sim/db/schema' import { createLogger } from '@sim/logger' -import { generateId } from '@sim/utils/id' -import { and, eq, gt, isNull } from 'drizzle-orm' +import { and, eq, isNull } from 'drizzle-orm' import type { NextRequest } from 'next/server' import { renderOTPEmail } from '@/components/emails' import { requestChatEmailOtpContract, verifyChatEmailOtpContract } from '@/lib/api/contracts/chats' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' -import { getRedisClient } from '@/lib/core/config/redis' -import type { TokenBucketConfig } from '@/lib/core/rate-limiter' import { RateLimiter } from '@/lib/core/rate-limiter' import { addCorsHeaders, isEmailAllowed } from '@/lib/core/security/deployment' -import { getStorageMethod } from '@/lib/core/storage' +import { + decodeOTPValue, + deleteOTP, + generateOTP, + getOTP, + incrementOTPAttempts, + MAX_OTP_ATTEMPTS, + OTP_EMAIL_RATE_LIMIT, + OTP_IP_RATE_LIMIT, + storeOTP, +} from '@/lib/core/security/otp' import { generateRequestId, getClientIp } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { sendEmail } from '@/lib/messaging/email/mailer' @@ -23,199 +29,6 @@ const logger = createLogger('ChatOtpAPI') const rateLimiter = new RateLimiter() -const OTP_IP_RATE_LIMIT: TokenBucketConfig = { - maxTokens: 10, - refillRate: 10, - refillIntervalMs: 15 * 60_000, -} - -const OTP_EMAIL_RATE_LIMIT: TokenBucketConfig = { - maxTokens: 3, - refillRate: 3, - refillIntervalMs: 15 * 60_000, -} - -function generateOTP(): string { - return randomInt(100000, 1000000).toString() -} - -const OTP_EXPIRY = 15 * 60 // 15 minutes -const OTP_EXPIRY_MS = OTP_EXPIRY * 1000 -const MAX_OTP_ATTEMPTS = 5 - -/** - * OTP values are stored as "code:attempts" (e.g. "654321:0"). - * This keeps the attempt counter in the same key/row as the OTP itself. - */ -function encodeOTPValue(otp: string, attempts: number): string { - return `${otp}:${attempts}` -} - -function decodeOTPValue(value: string): { otp: string; attempts: number } { - const lastColon = value.lastIndexOf(':') - if (lastColon === -1) return { otp: value, attempts: 0 } - const attempts = Number.parseInt(value.slice(lastColon + 1), 10) - return { otp: value.slice(0, lastColon), attempts: Number.isNaN(attempts) ? 0 : attempts } -} - -/** - * Stores OTP in Redis or database depending on storage method. - * Uses the verification table for database storage. - */ -async function storeOTP(email: string, chatId: string, otp: string): Promise { - const identifier = `chat-otp:${chatId}:${email}` - const storageMethod = getStorageMethod() - const value = encodeOTPValue(otp, 0) - - if (storageMethod === 'redis') { - const redis = getRedisClient() - if (!redis) { - throw new Error('Redis configured but client unavailable') - } - await redis.set(`otp:${email}:${chatId}`, value, 'EX', OTP_EXPIRY) - } else { - const now = new Date() - const expiresAt = new Date(now.getTime() + OTP_EXPIRY_MS) - - await db.transaction(async (tx) => { - await tx.delete(verification).where(eq(verification.identifier, identifier)) - await tx.insert(verification).values({ - id: generateId(), - identifier, - value, - expiresAt, - createdAt: now, - updatedAt: now, - }) - }) - } -} - -async function getOTP(email: string, chatId: string): Promise { - const identifier = `chat-otp:${chatId}:${email}` - const storageMethod = getStorageMethod() - - if (storageMethod === 'redis') { - const redis = getRedisClient() - if (!redis) { - throw new Error('Redis configured but client unavailable') - } - return redis.get(`otp:${email}:${chatId}`) - } - - const now = new Date() - const [record] = await db - .select({ value: verification.value }) - .from(verification) - .where(and(eq(verification.identifier, identifier), gt(verification.expiresAt, now))) - .limit(1) - - return record?.value ?? null -} - -/** - * Lua script for atomic OTP attempt increment. - * Returns: "LOCKED" if max attempts reached (key deleted), new encoded value otherwise, nil if key missing. - */ -const ATOMIC_INCREMENT_SCRIPT = ` -local val = redis.call('GET', KEYS[1]) -if not val then return nil end -local colon = val:find(':([^:]*$)') -local otp, attempts -if colon then - otp = val:sub(1, colon - 1) - attempts = tonumber(val:sub(colon + 1)) or 0 -else - otp = val - attempts = 0 -end -attempts = attempts + 1 -if attempts >= tonumber(ARGV[1]) then - redis.call('DEL', KEYS[1]) - return 'LOCKED' -end -local newVal = otp .. ':' .. attempts -local ttl = redis.call('TTL', KEYS[1]) -if ttl > 0 then - redis.call('SET', KEYS[1], newVal, 'EX', ttl) -else - redis.call('SET', KEYS[1], newVal) -end -return newVal -` - -/** - * Atomically increments OTP attempts. Returns 'locked' if max reached, 'incremented' otherwise. - */ -async function incrementOTPAttempts( - email: string, - chatId: string, - currentValue: string -): Promise<'locked' | 'incremented'> { - const identifier = `chat-otp:${chatId}:${email}` - const storageMethod = getStorageMethod() - - if (storageMethod === 'redis') { - const redis = getRedisClient() - if (!redis) { - throw new Error('Redis configured but client unavailable') - } - const key = `otp:${email}:${chatId}` - const result = await redis.eval(ATOMIC_INCREMENT_SCRIPT, 1, key, MAX_OTP_ATTEMPTS) - if (result === null || result === 'LOCKED') return 'locked' - return 'incremented' - } - - // DB path: optimistic locking with retry on conflict - const MAX_RETRIES = 3 - let value = currentValue - - for (let attempt = 0; attempt < MAX_RETRIES; attempt++) { - const { otp, attempts } = decodeOTPValue(value) - const newAttempts = attempts + 1 - - if (newAttempts >= MAX_OTP_ATTEMPTS) { - await db.delete(verification).where(eq(verification.identifier, identifier)) - return 'locked' - } - - const newValue = encodeOTPValue(otp, newAttempts) - const updated = await db - .update(verification) - .set({ value: newValue, updatedAt: new Date() }) - .where(and(eq(verification.identifier, identifier), eq(verification.value, value))) - .returning({ id: verification.id }) - - if (updated.length > 0) return 'incremented' - - // Conflict: another request already incremented — re-read and retry - const fresh = await getOTP(email, chatId) - if (!fresh) return 'locked' - value = fresh - } - - // Exhausted retries — re-read final state to determine outcome - const final = await getOTP(email, chatId) - if (!final) return 'locked' - const { attempts: finalAttempts } = decodeOTPValue(final) - return finalAttempts >= MAX_OTP_ATTEMPTS ? 'locked' : 'incremented' -} - -async function deleteOTP(email: string, chatId: string): Promise { - const identifier = `chat-otp:${chatId}:${email}` - const storageMethod = getStorageMethod() - - if (storageMethod === 'redis') { - const redis = getRedisClient() - if (!redis) { - throw new Error('Redis configured but client unavailable') - } - await redis.del(`otp:${email}:${chatId}`) - } else { - await db.delete(verification).where(eq(verification.identifier, identifier)) - } -} - export const POST = withRouteHandler( async (request: NextRequest, context: { params: Promise<{ identifier: string }> }) => { const { identifier } = await context.params @@ -305,7 +118,7 @@ export const POST = withRouteHandler( } const otp = generateOTP() - await storeOTP(email, deployment.id, otp) + await storeOTP('chat', deployment.id, email, otp) const emailHtml = await renderOTPEmail( otp, @@ -330,12 +143,9 @@ export const POST = withRouteHandler( logger.info(`[${requestId}] OTP sent to ${email} for chat ${deployment.id}`) return addCorsHeaders(createSuccessResponse({ message: 'Verification code sent' }), request) - } catch (error: any) { + } catch (error) { logger.error(`[${requestId}] Error processing OTP request:`, error) - return addCorsHeaders( - createErrorResponse(error.message || 'Failed to process request', 500), - request - ) + return addCorsHeaders(createErrorResponse('Failed to process request', 500), request) } } ) @@ -379,7 +189,7 @@ export const PUT = withRouteHandler( const deployment = deploymentResult[0] - const storedValue = await getOTP(email, deployment.id) + const storedValue = await getOTP('chat', deployment.id, email) if (!storedValue) { return addCorsHeaders( createErrorResponse('No verification code found, request a new one', 400), @@ -390,7 +200,7 @@ export const PUT = withRouteHandler( const { otp: storedOTP, attempts } = decodeOTPValue(storedValue) if (attempts >= MAX_OTP_ATTEMPTS) { - await deleteOTP(email, deployment.id) + await deleteOTP('chat', deployment.id, email) logger.warn(`[${requestId}] OTP already locked out for ${email}`) return addCorsHeaders( createErrorResponse('Too many failed attempts. Please request a new code.', 429), @@ -399,7 +209,7 @@ export const PUT = withRouteHandler( } if (storedOTP !== otp) { - const result = await incrementOTPAttempts(email, deployment.id, storedValue) + const result = await incrementOTPAttempts('chat', deployment.id, email, storedValue) if (result === 'locked') { logger.warn(`[${requestId}] OTP invalidated after max failed attempts for ${email}`) return addCorsHeaders( @@ -410,7 +220,7 @@ export const PUT = withRouteHandler( return addCorsHeaders(createErrorResponse('Invalid verification code', 400), request) } - await deleteOTP(email, deployment.id) + await deleteOTP('chat', deployment.id, email) const response = addCorsHeaders( createSuccessResponse({ @@ -426,12 +236,9 @@ export const PUT = withRouteHandler( setChatAuthCookie(response, deployment.id, deployment.authType, deployment.password) return response - } catch (error: any) { + } catch (error) { logger.error(`[${requestId}] Error verifying OTP:`, error) - return addCorsHeaders( - createErrorResponse(error.message || 'Failed to process request', 500), - request - ) + return addCorsHeaders(createErrorResponse('Failed to process request', 500), request) } } ) diff --git a/apps/sim/app/api/form/[identifier]/otp/route.test.ts b/apps/sim/app/api/form/[identifier]/otp/route.test.ts new file mode 100644 index 00000000000..4b3b13441d0 --- /dev/null +++ b/apps/sim/app/api/form/[identifier]/otp/route.test.ts @@ -0,0 +1,695 @@ +/** + * Tests for form OTP API route + * + * @vitest-environment node + */ +import { + redisConfigMock, + redisConfigMockFns, + requestUtilsMockFns, + workflowsApiUtilsMock, + workflowsApiUtilsMockFns, +} from '@sim/testing' +import { NextRequest } from 'next/server' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockRedisSet, + mockRedisGet, + mockRedisDel, + mockRedisTtl, + mockRedisEval, + mockRedisClient, + mockDbSelect, + mockDbInsert, + mockDbDelete, + mockDbUpdate, + mockSendEmail, + mockRenderOTPEmail, + mockAddCorsHeaders, + mockSetFormAuthCookie, + mockGetStorageMethod, + mockZodParse, + mockGetEnv, +} = vi.hoisted(() => { + const mockRedisSet = vi.fn() + const mockRedisGet = vi.fn() + const mockRedisDel = vi.fn() + const mockRedisTtl = vi.fn() + const mockRedisEval = vi.fn() + const mockRedisClient = { + set: mockRedisSet, + get: mockRedisGet, + del: mockRedisDel, + ttl: mockRedisTtl, + eval: mockRedisEval, + } + return { + mockRedisSet, + mockRedisGet, + mockRedisDel, + mockRedisTtl, + mockRedisEval, + mockRedisClient, + mockDbSelect: vi.fn(), + mockDbInsert: vi.fn(), + mockDbDelete: vi.fn(), + mockDbUpdate: vi.fn(), + mockSendEmail: vi.fn(), + mockRenderOTPEmail: vi.fn(), + mockAddCorsHeaders: vi.fn(), + mockSetFormAuthCookie: vi.fn(), + mockGetStorageMethod: vi.fn(), + mockZodParse: vi.fn(), + mockGetEnv: vi.fn(), + } +}) + +const mockGetRedisClient = redisConfigMockFns.mockGetRedisClient +const mockCreateSuccessResponse = workflowsApiUtilsMockFns.mockCreateSuccessResponse +const mockCreateErrorResponse = workflowsApiUtilsMockFns.mockCreateErrorResponse + +vi.mock('@/lib/core/config/redis', () => redisConfigMock) + +vi.mock('@sim/db', () => ({ + db: { + select: mockDbSelect, + insert: mockDbInsert, + delete: mockDbDelete, + update: mockDbUpdate, + transaction: vi.fn(async (callback: (tx: Record) => unknown) => { + return callback({ + select: mockDbSelect, + insert: mockDbInsert, + delete: mockDbDelete, + update: mockDbUpdate, + }) + }), + }, +})) + +vi.mock('drizzle-orm', () => ({ + eq: vi.fn((field: string, value: string) => ({ field, value, type: 'eq' })), + and: vi.fn((...conditions: unknown[]) => ({ conditions, type: 'and' })), + gt: vi.fn((field: string, value: string) => ({ field, value, type: 'gt' })), + lt: vi.fn((field: string, value: string) => ({ field, value, type: 'lt' })), + isNull: vi.fn((field: unknown) => ({ field, type: 'isNull' })), +})) + +vi.mock('@/lib/core/storage', () => ({ + getStorageMethod: mockGetStorageMethod, +})) + +const { mockCheckRateLimitDirect } = vi.hoisted(() => ({ + mockCheckRateLimitDirect: vi.fn(), +})) + +vi.mock('@/lib/core/rate-limiter', () => ({ + RateLimiter: class { + checkRateLimitDirect = mockCheckRateLimitDirect + }, +})) + +vi.mock('@/lib/messaging/email/mailer', () => ({ + sendEmail: mockSendEmail, +})) + +vi.mock('@/components/emails', () => ({ + renderOTPEmail: mockRenderOTPEmail, +})) + +vi.mock('@/lib/core/security/deployment', () => ({ + addCorsHeaders: mockAddCorsHeaders, + isEmailAllowed: (email: string, allowedEmails: string[]) => { + if (allowedEmails.includes(email)) return true + const atIndex = email.indexOf('@') + if (atIndex > 0) { + const domain = email.substring(atIndex + 1) + if (domain && allowedEmails.some((allowed: string) => allowed === `@${domain}`)) return true + } + return false + }, +})) + +vi.mock('@/app/api/form/utils', () => ({ + setFormAuthCookie: mockSetFormAuthCookie, +})) + +vi.mock('@/app/api/workflows/utils', () => workflowsApiUtilsMock) + +vi.mock('@/lib/core/config/env', () => ({ + env: { + NEXT_PUBLIC_APP_URL: 'http://localhost:3000', + NODE_ENV: 'test', + }, + getEnv: mockGetEnv, + isTruthy: vi.fn().mockReturnValue(false), + isFalsy: vi.fn().mockReturnValue(true), +})) + +vi.mock('zod', () => { + class ZodError extends Error { + errors: Array<{ message: string }> + constructor(issues: Array<{ message: string }>) { + super('ZodError') + this.errors = issues + } + } + const chainable: Record = {} + const proxy: Record = new Proxy(chainable, { + get(target, prop) { + if (prop === 'parse') return mockZodParse + if (prop === 'safeParse') { + return (data: unknown) => ({ success: true, data }) + } + if (prop === 'then') return undefined + if (typeof prop === 'symbol') return Reflect.get(target, prop) + if (!(prop in target)) { + target[prop as string] = vi.fn().mockReturnValue(proxy) + } + return target[prop as string] + }, + }) + const makeChain = vi.fn(() => proxy) + return { + z: new Proxy( + { ZodError }, + { + get(target, prop) { + if (prop === 'ZodError') return ZodError + if (typeof prop === 'symbol') return Reflect.get(target, prop) + return makeChain + }, + } + ), + } +}) + +import { POST, PUT } from './route' + +describe('Form OTP API Route', () => { + const mockEmail = 'user@example.com' + const mockFormId = 'form-123' + const mockIdentifier = 'test-form' + const mockOTP = '123456' + + const deploymentRow = { + id: mockFormId, + authType: 'email', + allowedEmails: [mockEmail], + title: 'Test Form', + isActive: true, + } + + const verifyDeploymentRow = { + id: mockFormId, + authType: 'email', + password: null, + allowedEmails: [mockEmail], + isActive: true, + } + + const selectOnce = (rows: unknown[]) => + mockDbSelect.mockImplementationOnce(() => ({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue(rows), + }), + }), + })) + + beforeEach(() => { + vi.clearAllMocks() + + vi.spyOn(Math, 'random').mockReturnValue(0.123456) + vi.spyOn(Date, 'now').mockReturnValue(1640995200000) + + vi.stubGlobal('crypto', { + ...crypto, + randomUUID: vi.fn().mockReturnValue('test-uuid-1234'), + }) + + mockGetRedisClient.mockReturnValue(mockRedisClient) + mockRedisSet.mockResolvedValue('OK') + mockRedisGet.mockResolvedValue(null) + mockRedisDel.mockResolvedValue(1) + mockRedisTtl.mockResolvedValue(600) + + mockDbSelect.mockImplementation(() => ({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([]), + }), + }), + })) + mockDbInsert.mockImplementation(() => ({ values: vi.fn().mockResolvedValue(undefined) })) + mockDbDelete.mockImplementation(() => ({ where: vi.fn().mockResolvedValue(undefined) })) + mockDbUpdate.mockImplementation(() => ({ + set: vi.fn().mockReturnValue({ where: vi.fn().mockResolvedValue(undefined) }), + })) + + mockGetStorageMethod.mockReturnValue('redis') + + mockSendEmail.mockResolvedValue({ success: true }) + mockRenderOTPEmail.mockResolvedValue('OTP Email') + + mockAddCorsHeaders.mockImplementation((response: unknown) => response) + mockCreateSuccessResponse.mockImplementation((data: unknown) => ({ + json: () => Promise.resolve(data), + status: 200, + })) + mockCreateErrorResponse.mockImplementation((message: string, status: number) => ({ + json: () => Promise.resolve({ error: message }), + status, + })) + + requestUtilsMockFns.mockGenerateRequestId.mockReturnValue('req-123') + requestUtilsMockFns.mockGetClientIp.mockReturnValue('1.2.3.4') + + mockCheckRateLimitDirect.mockResolvedValue({ + allowed: true, + remaining: 10, + resetAt: new Date(Date.now() + 60_000), + }) + + mockZodParse.mockImplementation((data: unknown) => data) + mockGetEnv.mockReturnValue('http://localhost:3000') + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + describe('POST /otp - request code', () => { + it('stores OTP in Redis when storage is redis and sends email', async () => { + selectOnce([deploymentRow]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockRedisSet).toHaveBeenCalledWith( + `form-otp:${mockEmail}:${mockFormId}`, + expect.stringMatching(/^\d{6}:0$/), + 'EX', + 900 + ) + expect(mockSendEmail).toHaveBeenCalledWith( + expect.objectContaining({ to: mockEmail, subject: expect.stringContaining('Test Form') }) + ) + expect(mockDbInsert).not.toHaveBeenCalled() + }) + + it('stores OTP in database when storage is database', async () => { + mockGetStorageMethod.mockReturnValue('database') + mockGetRedisClient.mockReturnValue(null) + selectOnce([deploymentRow]) + const insertValues = vi.fn().mockResolvedValue(undefined) + mockDbInsert.mockImplementationOnce(() => ({ values: insertValues })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(insertValues).toHaveBeenCalledWith( + expect.objectContaining({ + identifier: `form-otp:${mockFormId}:${mockEmail}`, + value: expect.stringMatching(/^\d{6}:0$/), + }) + ) + expect(mockRedisSet).not.toHaveBeenCalled() + }) + + it('returns 404 when form is not found', async () => { + selectOnce([]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Form not found', 404) + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('returns 403 when form is inactive', async () => { + selectOnce([{ ...deploymentRow, isActive: false }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'This form is currently unavailable', + 403 + ) + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('returns 400 when form authType is not email', async () => { + selectOnce([{ ...deploymentRow, authType: 'public' }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'This form does not use email authentication', + 400 + ) + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('returns 403 when email is not in allowedEmails', async () => { + selectOnce([{ ...deploymentRow, allowedEmails: ['other@example.com'] }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Email not authorized for this form', + 403 + ) + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('authorizes by domain match in allowedEmails', async () => { + selectOnce([{ ...deploymentRow, allowedEmails: ['@example.com'] }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockSendEmail).toHaveBeenCalled() + }) + + it('returns 429 with Retry-After when IP rate limit is exceeded', async () => { + mockCheckRateLimitDirect.mockResolvedValueOnce({ + allowed: false, + remaining: 0, + resetAt: new Date(Date.now() + 900_000), + retryAfterMs: 900_000, + }) + const headerSet = vi.fn() + mockCreateErrorResponse.mockImplementationOnce((message: string, status: number) => ({ + json: () => Promise.resolve({ error: message }), + status, + headers: { set: headerSet }, + })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + const response = await POST(request, { + params: Promise.resolve({ identifier: mockIdentifier }), + }) + + expect(response.status).toBe(429) + expect(headerSet).toHaveBeenCalledWith('Retry-After', '900') + expect(mockSendEmail).not.toHaveBeenCalled() + expect(mockDbSelect).not.toHaveBeenCalled() + }) + + it('returns 429 with Retry-After when email rate limit is exceeded', async () => { + mockCheckRateLimitDirect + .mockResolvedValueOnce({ + allowed: true, + remaining: 9, + resetAt: new Date(Date.now() + 60_000), + }) + .mockResolvedValueOnce({ + allowed: false, + remaining: 0, + resetAt: new Date(Date.now() + 900_000), + retryAfterMs: 900_000, + }) + const headerSet = vi.fn() + mockCreateErrorResponse.mockImplementationOnce((message: string, status: number) => ({ + json: () => Promise.resolve({ error: message }), + status, + headers: { set: headerSet }, + })) + selectOnce([deploymentRow]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + const response = await POST(request, { + params: Promise.resolve({ identifier: mockIdentifier }), + }) + + expect(response.status).toBe(429) + expect(headerSet).toHaveBeenCalledWith('Retry-After', '900') + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('rate-limits the IP bucket before reading the deployment row', async () => { + mockCheckRateLimitDirect.mockResolvedValueOnce({ + allowed: false, + remaining: 0, + resetAt: new Date(Date.now() + 900_000), + retryAfterMs: 900_000, + }) + mockCreateErrorResponse.mockImplementationOnce((message: string, status: number) => ({ + json: () => Promise.resolve({ error: message }), + status, + headers: { set: vi.fn() }, + })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockDbSelect).not.toHaveBeenCalled() + }) + + it('returns 500 when email send fails', async () => { + selectOnce([deploymentRow]) + mockSendEmail.mockResolvedValueOnce({ success: false, message: 'smtp down' }) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Failed to send verification email', 500) + }) + }) + + describe('PUT /otp - verify code', () => { + it('verifies OTP, deletes it, and sets the form auth cookie on success', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue(`${mockOTP}:0`) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockRedisGet).toHaveBeenCalledWith(`form-otp:${mockEmail}:${mockFormId}`) + expect(mockRedisDel).toHaveBeenCalledWith(`form-otp:${mockEmail}:${mockFormId}`) + expect(mockSetFormAuthCookie).toHaveBeenCalledWith( + expect.any(Object), + mockFormId, + 'email', + null + ) + expect(mockCreateSuccessResponse).toHaveBeenCalledWith({ authenticated: true }) + }) + + it('returns 404 when form is not found', async () => { + selectOnce([]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Form not found', 404) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('returns 403 when form is inactive at verify time', async () => { + selectOnce([{ ...verifyDeploymentRow, isActive: false }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'This form is currently unavailable', + 403 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('returns 403 when email is no longer in allowedEmails at verify time', async () => { + selectOnce([{ ...verifyDeploymentRow, allowedEmails: ['other@example.com'] }]) + mockRedisGet.mockResolvedValue(`${mockOTP}:0`) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Email not authorized for this form', + 403 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('returns 400 when no OTP is stored', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue(null) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'No verification code found, request a new one', + 400 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('atomically increments attempts on wrong OTP and returns 400', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue('654321:0') + mockRedisEval.mockResolvedValue('654321:1') + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: 'wrong1' }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockRedisEval).toHaveBeenCalledWith( + expect.any(String), + 1, + `form-otp:${mockEmail}:${mockFormId}`, + 5 + ) + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Invalid verification code', 400) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('invalidates OTP and returns 429 after max failed attempts', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue('654321:4') + mockRedisEval.mockResolvedValue('LOCKED') + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: 'wrong5' }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Too many failed attempts. Please request a new code.', + 429 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('rejects when stored OTP is already at max attempts', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue(`${mockOTP}:5`) + const deleteWhere = vi.fn().mockResolvedValue(undefined) + mockDbDelete.mockImplementation(() => ({ where: deleteWhere })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Too many failed attempts. Please request a new code.', + 429 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('uses database storage path when configured', async () => { + mockGetStorageMethod.mockReturnValue('database') + mockGetRedisClient.mockReturnValue(null) + let selectCallCount = 0 + mockDbSelect.mockImplementation(() => ({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockImplementation(() => { + selectCallCount++ + if (selectCallCount === 1) return Promise.resolve([verifyDeploymentRow]) + return Promise.resolve([ + { + value: `${mockOTP}:0`, + expiresAt: new Date(Date.now() + 10 * 60 * 1000), + }, + ]) + }), + }), + }), + })) + const deleteWhere = vi.fn().mockResolvedValue(undefined) + mockDbDelete.mockImplementation(() => ({ where: deleteWhere })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockDbDelete).toHaveBeenCalled() + expect(mockRedisDel).not.toHaveBeenCalled() + expect(mockSetFormAuthCookie).toHaveBeenCalled() + }) + }) +}) diff --git a/apps/sim/app/api/form/[identifier]/otp/route.ts b/apps/sim/app/api/form/[identifier]/otp/route.ts new file mode 100644 index 00000000000..0d9804efa55 --- /dev/null +++ b/apps/sim/app/api/form/[identifier]/otp/route.ts @@ -0,0 +1,261 @@ +import { db } from '@sim/db' +import { form } from '@sim/db/schema' +import { createLogger } from '@sim/logger' +import { and, eq, isNull } from 'drizzle-orm' +import type { NextRequest } from 'next/server' +import { renderOTPEmail } from '@/components/emails' +import { requestFormEmailOtpContract, verifyFormEmailOtpContract } from '@/lib/api/contracts/forms' +import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' +import { RateLimiter } from '@/lib/core/rate-limiter' +import { addCorsHeaders, isEmailAllowed } from '@/lib/core/security/deployment' +import { + decodeOTPValue, + deleteOTP, + generateOTP, + getOTP, + incrementOTPAttempts, + MAX_OTP_ATTEMPTS, + OTP_EMAIL_RATE_LIMIT, + OTP_IP_RATE_LIMIT, + storeOTP, +} from '@/lib/core/security/otp' +import { generateRequestId, getClientIp } from '@/lib/core/utils/request' +import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { sendEmail } from '@/lib/messaging/email/mailer' +import { setFormAuthCookie } from '@/app/api/form/utils' +import { createErrorResponse, createSuccessResponse } from '@/app/api/workflows/utils' + +const logger = createLogger('FormOtpAPI') + +const rateLimiter = new RateLimiter() + +export const POST = withRouteHandler( + async (request: NextRequest, context: { params: Promise<{ identifier: string }> }) => { + const { identifier } = await context.params + const requestId = generateRequestId() + + try { + const ip = getClientIp(request) + const ipRateLimit = await rateLimiter.checkRateLimitDirect( + `form-otp:ip:${identifier}:${ip}`, + OTP_IP_RATE_LIMIT + ) + if (!ipRateLimit.allowed) { + logger.warn(`[${requestId}] OTP IP rate limit exceeded for ${identifier} from ${ip}`) + const retryAfter = Math.ceil( + (ipRateLimit.retryAfterMs ?? OTP_IP_RATE_LIMIT.refillIntervalMs) / 1000 + ) + const response = createErrorResponse('Too many requests. Please try again later.', 429) + response.headers.set('Retry-After', String(retryAfter)) + return addCorsHeaders(response, request) + } + + const parsed = await parseRequest(requestFormEmailOtpContract, request, context, { + validationErrorResponse: (error) => + addCorsHeaders( + createErrorResponse(getValidationErrorMessage(error, 'Invalid request'), 400), + request + ), + }) + if (!parsed.success) return parsed.response + const { email } = parsed.data.body + + const deploymentResult = await db + .select({ + id: form.id, + authType: form.authType, + allowedEmails: form.allowedEmails, + title: form.title, + isActive: form.isActive, + }) + .from(form) + .where(and(eq(form.identifier, identifier), isNull(form.archivedAt))) + .limit(1) + + if (deploymentResult.length === 0) { + logger.warn(`[${requestId}] Form not found for identifier: ${identifier}`) + return addCorsHeaders(createErrorResponse('Form not found', 404), request) + } + + const deployment = deploymentResult[0] + + if (!deployment.isActive) { + return addCorsHeaders( + createErrorResponse('This form is currently unavailable', 403), + request + ) + } + + if (deployment.authType !== 'email') { + return addCorsHeaders( + createErrorResponse('This form does not use email authentication', 400), + request + ) + } + + const allowedEmails: string[] = Array.isArray(deployment.allowedEmails) + ? (deployment.allowedEmails as string[]) + : [] + + if (!isEmailAllowed(email, allowedEmails)) { + return addCorsHeaders( + createErrorResponse('Email not authorized for this form', 403), + request + ) + } + + const emailRateLimit = await rateLimiter.checkRateLimitDirect( + `form-otp:email:${deployment.id}:${email.toLowerCase()}`, + OTP_EMAIL_RATE_LIMIT + ) + if (!emailRateLimit.allowed) { + logger.warn( + `[${requestId}] OTP email rate limit exceeded for ${email} on form ${deployment.id}` + ) + const retryAfter = Math.ceil( + (emailRateLimit.retryAfterMs ?? OTP_EMAIL_RATE_LIMIT.refillIntervalMs) / 1000 + ) + const response = createErrorResponse( + 'Too many verification code requests. Please try again later.', + 429 + ) + response.headers.set('Retry-After', String(retryAfter)) + return addCorsHeaders(response, request) + } + + const otp = generateOTP() + await storeOTP('form', deployment.id, email, otp) + + const emailHtml = await renderOTPEmail( + otp, + email, + 'email-verification', + deployment.title || 'Form' + ) + + const emailResult = await sendEmail({ + to: email, + subject: `Verification code for ${deployment.title || 'Form'}`, + html: emailHtml, + }) + + if (!emailResult.success) { + logger.error(`[${requestId}] Failed to send OTP email:`, emailResult.message) + return addCorsHeaders( + createErrorResponse('Failed to send verification email', 500), + request + ) + } + + logger.info(`[${requestId}] OTP sent to ${email} for form ${deployment.id}`) + return addCorsHeaders(createSuccessResponse({ message: 'Verification code sent' }), request) + } catch (error) { + logger.error(`[${requestId}] Error processing OTP request:`, error) + return addCorsHeaders(createErrorResponse('Failed to process request', 500), request) + } + } +) + +export const PUT = withRouteHandler( + async (request: NextRequest, context: { params: Promise<{ identifier: string }> }) => { + const { identifier } = await context.params + const requestId = generateRequestId() + + try { + const parsed = await parseRequest(verifyFormEmailOtpContract, request, context, { + validationErrorResponse: (error) => + addCorsHeaders( + createErrorResponse(getValidationErrorMessage(error, 'Invalid request'), 400), + request + ), + }) + if (!parsed.success) return parsed.response + const { email, otp } = parsed.data.body + + const deploymentResult = await db + .select({ + id: form.id, + authType: form.authType, + password: form.password, + allowedEmails: form.allowedEmails, + isActive: form.isActive, + }) + .from(form) + .where(and(eq(form.identifier, identifier), isNull(form.archivedAt))) + .limit(1) + + if (deploymentResult.length === 0) { + logger.warn(`[${requestId}] Form not found for identifier: ${identifier}`) + return addCorsHeaders(createErrorResponse('Form not found', 404), request) + } + + const deployment = deploymentResult[0] + + if (!deployment.isActive) { + return addCorsHeaders( + createErrorResponse('This form is currently unavailable', 403), + request + ) + } + + if (deployment.authType !== 'email') { + return addCorsHeaders( + createErrorResponse('This form does not use email authentication', 400), + request + ) + } + + const allowedEmails: string[] = Array.isArray(deployment.allowedEmails) + ? (deployment.allowedEmails as string[]) + : [] + + if (!isEmailAllowed(email, allowedEmails)) { + return addCorsHeaders( + createErrorResponse('Email not authorized for this form', 403), + request + ) + } + + const storedValue = await getOTP('form', deployment.id, email) + if (!storedValue) { + return addCorsHeaders( + createErrorResponse('No verification code found, request a new one', 400), + request + ) + } + + const { otp: storedOTP, attempts } = decodeOTPValue(storedValue) + + if (attempts >= MAX_OTP_ATTEMPTS) { + await deleteOTP('form', deployment.id, email) + logger.warn(`[${requestId}] OTP already locked out for ${email}`) + return addCorsHeaders( + createErrorResponse('Too many failed attempts. Please request a new code.', 429), + request + ) + } + + if (storedOTP !== otp) { + const result = await incrementOTPAttempts('form', deployment.id, email, storedValue) + if (result === 'locked') { + logger.warn(`[${requestId}] OTP invalidated after max failed attempts for ${email}`) + return addCorsHeaders( + createErrorResponse('Too many failed attempts. Please request a new code.', 429), + request + ) + } + return addCorsHeaders(createErrorResponse('Invalid verification code', 400), request) + } + + await deleteOTP('form', deployment.id, email) + + const response = addCorsHeaders(createSuccessResponse({ authenticated: true }), request) + setFormAuthCookie(response, deployment.id, deployment.authType, deployment.password) + + return response + } catch (error) { + logger.error(`[${requestId}] Error verifying OTP:`, error) + return addCorsHeaders(createErrorResponse('Failed to process request', 500), request) + } + } +) diff --git a/apps/sim/app/api/form/utils.test.ts b/apps/sim/app/api/form/utils.test.ts index 9c36ccc6e92..1826d9386c1 100644 --- a/apps/sim/app/api/form/utils.test.ts +++ b/apps/sim/app/api/form/utils.test.ts @@ -239,18 +239,20 @@ describe('Form API Utils', () => { }, } as any - // Exact email match should authorize + // Exact email match should require OTP verification, not authorize directly mockIsEmailAllowed.mockReturnValue(true) const result1 = await validateFormAuth('request-id', deployment, mockRequest, { email: 'user@example.com', }) - expect(result1.authorized).toBe(true) + expect(result1.authorized).toBe(false) + expect(result1.error).toBe('otp_required') - // Domain match should authorize + // Domain match should also require OTP verification const result2 = await validateFormAuth('request-id', deployment, mockRequest, { email: 'other@company.com', }) - expect(result2.authorized).toBe(true) + expect(result2.authorized).toBe(false) + expect(result2.error).toBe('otp_required') // Unknown email should not authorize mockIsEmailAllowed.mockReturnValue(false) diff --git a/apps/sim/app/api/form/utils.ts b/apps/sim/app/api/form/utils.ts index 55bbe65e17f..7b1f1df54dc 100644 --- a/apps/sim/app/api/form/utils.ts +++ b/apps/sim/app/api/form/utils.ts @@ -159,7 +159,7 @@ export async function validateFormAuth( const allowedEmails: string[] = deployment.allowedEmails || [] if (isEmailAllowed(email, allowedEmails)) { - return { authorized: true } + return { authorized: false, error: 'otp_required' } } return { authorized: false, error: 'Email not authorized for this form' } diff --git a/apps/sim/app/api/knowledge/[id]/documents/upsert/route.test.ts b/apps/sim/app/api/knowledge/[id]/documents/upsert/route.test.ts new file mode 100644 index 00000000000..d5f64cf306e --- /dev/null +++ b/apps/sim/app/api/knowledge/[id]/documents/upsert/route.test.ts @@ -0,0 +1,111 @@ +/** + * Tests for knowledge base document upsert API route + * + * @vitest-environment node + */ +import { + auditMock, + createMockRequest, + hybridAuthMock, + hybridAuthMockFns, + knowledgeApiUtilsMock, +} from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockDbChain } = vi.hoisted(() => { + const chain = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + limit: vi.fn().mockResolvedValue([]), + } + return { mockDbChain: chain } +}) + +vi.mock('@sim/db', () => ({ db: mockDbChain })) +vi.mock('@/lib/auth/hybrid', () => hybridAuthMock) +vi.mock('@/app/api/knowledge/utils', () => knowledgeApiUtilsMock) +vi.mock('@sim/audit', () => auditMock) + +vi.mock('@/lib/knowledge/documents/service', () => ({ + createDocumentRecords: vi.fn(), + deleteDocument: vi.fn(), + getProcessingConfig: vi.fn().mockReturnValue({ maxConcurrentDocuments: 1, batchSize: 1 }), + processDocumentsWithQueue: vi.fn(), +})) + +import { createDocumentRecords, processDocumentsWithQueue } from '@/lib/knowledge/documents/service' +import { POST } from '@/app/api/knowledge/[id]/documents/upsert/route' +import { checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils' + +describe('POST /api/knowledge/[id]/documents/upsert', () => { + const params = Promise.resolve({ id: 'kb-123' }) + + beforeEach(() => { + vi.clearAllMocks() + mockDbChain.select.mockReturnThis() + mockDbChain.from.mockReturnThis() + mockDbChain.where.mockReturnThis() + mockDbChain.limit.mockResolvedValue([]) + + hybridAuthMockFns.mockCheckSessionOrInternalAuth.mockResolvedValue({ + success: true, + userId: 'user-1', + authType: 'session', + userName: 'Test User', + userEmail: 'test@example.com', + }) + + vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({ + hasAccess: true, + knowledgeBase: { id: 'kb-123', userId: 'user-1', workspaceId: 'ws-1', name: 'KB' }, + } as any) + + vi.mocked(createDocumentRecords).mockResolvedValue([ + { documentId: 'doc-new', filename: 'note.txt' }, + ] as any) + vi.mocked(processDocumentsWithQueue).mockResolvedValue(undefined as any) + }) + + const baseBody = { + filename: 'note.txt', + fileSize: 11, + mimeType: 'text/plain', + } + + it('accepts a data: URI', async () => { + const req = createMockRequest('POST', { + ...baseBody, + fileUrl: 'data:text/plain;base64,SGVsbG8gd29ybGQ=', + }) + const res = await POST(req, { params }) + expect(res.status).toBe(200) + expect(createDocumentRecords).toHaveBeenCalled() + }) + + it('accepts an https URL', async () => { + const req = createMockRequest('POST', { + ...baseBody, + fileUrl: 'https://example.com/note.txt', + }) + const res = await POST(req, { params }) + expect(res.status).toBe(200) + expect(createDocumentRecords).toHaveBeenCalled() + }) + + it.each([ + ['absolute local path', '/etc/passwd'], + ['app config path', '/app/.env'], + ['file:// URL', 'file:///etc/passwd'], + ['relative serve path', '/api/files/serve/kb/foo.pdf'], + ['ftp URL', 'ftp://example.com/file.pdf'], + ['parent traversal', '../../etc/passwd'], + ['windows path', 'C:\\Windows\\System32\\config\\SAM'], + ])('rejects %s with 400 and never invokes the pipeline', async (_label, fileUrl) => { + const req = createMockRequest('POST', { ...baseBody, fileUrl }) + const res = await POST(req, { params }) + expect(res.status).toBe(400) + expect(createDocumentRecords).not.toHaveBeenCalled() + expect(processDocumentsWithQueue).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/app/api/knowledge/[id]/route.test.ts b/apps/sim/app/api/knowledge/[id]/route.test.ts index 111e42829bd..2d0dc1ce2e2 100644 --- a/apps/sim/app/api/knowledge/[id]/route.test.ts +++ b/apps/sim/app/api/knowledge/[id]/route.test.ts @@ -31,6 +31,7 @@ vi.mock('@/lib/knowledge/service', async (importOriginal) => { getKnowledgeBaseById: vi.fn(), updateKnowledgeBase: vi.fn(), deleteKnowledgeBase: vi.fn(), + KnowledgeBasePermissionError: actual.KnowledgeBasePermissionError, } }) @@ -39,6 +40,7 @@ vi.mock('@/app/api/knowledge/utils', () => knowledgeApiUtilsMock) import { deleteKnowledgeBase, getKnowledgeBaseById, + KnowledgeBasePermissionError, updateKnowledgeBase, } from '@/lib/knowledge/service' import { DELETE, GET, PUT } from '@/app/api/knowledge/[id]/route' @@ -229,10 +231,59 @@ describe('Knowledge Base By ID API Route', () => { workspaceId: undefined, chunkingConfig: undefined, }, - expect.any(String) + expect.any(String), + { actorUserId: 'user-123' } ) }) + it('returns 403 when service rejects a cross-workspace transfer', async () => { + authMockFns.mockGetSession.mockResolvedValue({ + user: { id: 'attacker', email: 'a@example.com' }, + }) + + resetMocks() + + vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValueOnce({ + hasAccess: true, + knowledgeBase: { id: 'kb-123', userId: 'user-123', workspaceId: 'ws-current' }, + }) + + vi.mocked(updateKnowledgeBase).mockRejectedValueOnce( + new KnowledgeBasePermissionError('User does not have permission on the target workspace') + ) + + const req = createMockRequest('PUT', { workspaceId: 'ws-target' }) + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(403) + expect(data.error).toBe('User does not have permission on the target workspace') + }) + + it('returns 403 when service rejects clearing workspaceId', async () => { + authMockFns.mockGetSession.mockResolvedValue({ + user: { id: 'user-123', email: 'test@example.com' }, + }) + + resetMocks() + + vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValueOnce({ + hasAccess: true, + knowledgeBase: { id: 'kb-123', userId: 'user-123', workspaceId: 'ws-current' }, + }) + + vi.mocked(updateKnowledgeBase).mockRejectedValueOnce( + new KnowledgeBasePermissionError('Knowledge base workspace cannot be cleared') + ) + + const req = createMockRequest('PUT', { workspaceId: null }) + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(403) + expect(data.error).toBe('Knowledge base workspace cannot be cleared') + }) + it('should return unauthorized for unauthenticated user', async () => { authMockFns.mockGetSession.mockResolvedValue(null) diff --git a/apps/sim/app/api/knowledge/[id]/route.ts b/apps/sim/app/api/knowledge/[id]/route.ts index 5013a8c29c7..34a85d1a684 100644 --- a/apps/sim/app/api/knowledge/[id]/route.ts +++ b/apps/sim/app/api/knowledge/[id]/route.ts @@ -11,6 +11,7 @@ import { deleteKnowledgeBase, getKnowledgeBaseById, KnowledgeBaseConflictError, + KnowledgeBasePermissionError, updateKnowledgeBase, } from '@/lib/knowledge/service' import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils' @@ -101,7 +102,8 @@ export const PUT = withRouteHandler( workspaceId: validatedData.workspaceId, chunkingConfig: validatedData.chunkingConfig, }, - requestId + requestId, + { actorUserId: userId } ) logger.info(`[${requestId}] Knowledge base updated: ${id} for user ${userId}`) @@ -141,6 +143,10 @@ export const PUT = withRouteHandler( if (error instanceof KnowledgeBaseConflictError) { return NextResponse.json({ error: error.message }, { status: 409 }) } + if (error instanceof KnowledgeBasePermissionError) { + logger.warn(`[${requestId}] Forbidden knowledge base update on ${id}: ${error.message}`) + return NextResponse.json({ error: error.message }, { status: 403 }) + } logger.error(`[${requestId}] Error updating knowledge base`, error) return NextResponse.json({ error: 'Failed to update knowledge base' }, { status: 500 }) diff --git a/apps/sim/app/api/knowledge/route.test.ts b/apps/sim/app/api/knowledge/route.test.ts index bc0ab08d755..4ad2aad2acf 100644 --- a/apps/sim/app/api/knowledge/route.test.ts +++ b/apps/sim/app/api/knowledge/route.test.ts @@ -155,6 +155,23 @@ describe('Knowledge Base API Route', () => { expect(data.details).toBeDefined() }) + it('returns 403 when user lacks permission on target workspace', async () => { + authMockFns.mockGetSession.mockResolvedValue({ + user: { id: 'attacker', email: 'a@example.com' }, + }) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValueOnce('read') + + const req = createMockRequest('POST', validKnowledgeBaseData) + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(403) + expect(data.error).toBe( + 'User does not have permission to create knowledge bases in this workspace' + ) + expect(mockDbChain.insert).not.toHaveBeenCalled() + }) + it('should validate chunking config constraints', async () => { authMockFns.mockGetSession.mockResolvedValue({ user: { id: 'user-123', email: 'test@example.com' }, diff --git a/apps/sim/app/api/knowledge/route.ts b/apps/sim/app/api/knowledge/route.ts index 8cea52b8eb7..e14efd14656 100644 --- a/apps/sim/app/api/knowledge/route.ts +++ b/apps/sim/app/api/knowledge/route.ts @@ -15,6 +15,7 @@ import { createKnowledgeBase, getKnowledgeBases, KnowledgeBaseConflictError, + KnowledgeBasePermissionError, type KnowledgeBaseScope, } from '@/lib/knowledge/service' import { captureServerEvent } from '@/lib/posthog/server' @@ -159,6 +160,10 @@ export const POST = withRouteHandler(async (req: NextRequest) => { if (createError instanceof KnowledgeBaseConflictError) { return NextResponse.json({ error: createError.message }, { status: 409 }) } + if (createError instanceof KnowledgeBasePermissionError) { + logger.warn(`[${requestId}] Forbidden knowledge base creation: ${createError.message}`) + return NextResponse.json({ error: createError.message }, { status: 403 }) + } throw createError } } catch (error) { diff --git a/apps/sim/app/api/mcp/servers/test-connection/route.ts b/apps/sim/app/api/mcp/servers/test-connection/route.ts index 46ef05fc2bd..bd88be77aad 100644 --- a/apps/sim/app/api/mcp/servers/test-connection/route.ts +++ b/apps/sim/app/api/mcp/servers/test-connection/route.ts @@ -95,6 +95,9 @@ export const POST = withRouteHandler( } try { + // Initial pre-resolution check; the authoritative resolved IP is + // captured after env-var resolution below and used to pin the + // connection against DNS rebinding. await validateMcpServerSsrf(body.url) } catch (e) { if (e instanceof McpDnsResolutionError) { @@ -140,8 +143,9 @@ export const POST = withRouteHandler( throw e } + let resolvedIP: string | null try { - await validateMcpServerSsrf(testConfig.url) + resolvedIP = await validateMcpServerSsrf(testConfig.url) } catch (e) { if (e instanceof McpDnsResolutionError) { return createMcpErrorResponse(e, e.message, 502) @@ -162,7 +166,11 @@ export const POST = withRouteHandler( let client: McpClient | null = null try { - client = new McpClient(testConfig, testSecurityPolicy) + client = new McpClient({ + config: testConfig, + securityPolicy: testSecurityPolicy, + resolvedIP: resolvedIP ?? undefined, + }) await client.connect() result.negotiatedVersion = client.getNegotiatedVersion() diff --git a/apps/sim/app/api/tools/agiloft/attach/route.test.ts b/apps/sim/app/api/tools/agiloft/attach/route.test.ts new file mode 100644 index 00000000000..f1e4c8c4264 --- /dev/null +++ b/apps/sim/app/api/tools/agiloft/attach/route.test.ts @@ -0,0 +1,144 @@ +/** + * @vitest-environment node + */ +import { + createMockRequest, + hybridAuthMockFns, + inputValidationMock, + inputValidationMockFns, +} from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockProcessFilesToUserFiles, mockDownloadFileFromStorage, mockAssertToolFileAccess } = + vi.hoisted(() => ({ + mockProcessFilesToUserFiles: vi.fn(), + mockDownloadFileFromStorage: vi.fn(), + mockAssertToolFileAccess: vi.fn(), + })) + +vi.mock('@/lib/core/security/input-validation.server', () => inputValidationMock) +vi.mock('@/lib/uploads/utils/file-utils', () => ({ + processFilesToUserFiles: mockProcessFilesToUserFiles, +})) +vi.mock('@/lib/uploads/utils/file-utils.server', () => ({ + downloadFileFromStorage: mockDownloadFileFromStorage, +})) +vi.mock('@/app/api/files/authorization', () => ({ + assertToolFileAccess: mockAssertToolFileAccess, +})) + +import { POST } from '@/app/api/tools/agiloft/attach/route' + +const PINNED_IP = '93.184.216.34' + +const baseBody = { + instanceUrl: 'https://example.agiloft.com', + knowledgeBase: 'demo', + login: 'admin', + password: 'secret', + table: 'contracts', + recordId: '42', + fieldName: 'attachments', + file: { key: 's3://bucket/file.txt', name: 'file.txt', size: 5, type: 'text/plain' }, + fileName: 'file.txt', +} + +function mockSecureFetchResponse(body: { + ok?: boolean + status?: number + json?: unknown + text?: string +}) { + return { + ok: body.ok ?? true, + status: body.status ?? 200, + statusText: '', + headers: new Headers(), + body: null, + text: async () => body.text ?? '', + json: async () => body.json ?? {}, + arrayBuffer: async () => new ArrayBuffer(0), + } +} + +beforeEach(() => { + vi.clearAllMocks() + hybridAuthMockFns.mockCheckInternalAuth.mockResolvedValue({ + success: true, + userId: 'user-1', + authType: 'internal_jwt', + }) + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValue({ + isValid: true, + resolvedIP: PINNED_IP, + originalHostname: 'example.agiloft.com', + }) + mockProcessFilesToUserFiles.mockReturnValue([ + { key: 's3://bucket/file.txt', name: 'file.txt', size: 5, type: 'text/plain' }, + ]) + mockAssertToolFileAccess.mockResolvedValue(null) + mockDownloadFileFromStorage.mockResolvedValue(Buffer.from('hello')) +}) + +describe('POST /api/tools/agiloft/attach', () => { + it('rejects unauthenticated requests', async () => { + hybridAuthMockFns.mockCheckInternalAuth.mockResolvedValueOnce({ + success: false, + error: 'unauthorized', + }) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(401) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) + + it('blocks SSRF when the instance URL fails DNS validation', async () => { + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValueOnce({ + isValid: false, + error: 'instanceUrl resolves to a blocked IP address', + }) + + const response = await POST( + createMockRequest('POST', { ...baseBody, instanceUrl: 'https://attacker.example.com' }) + ) + + expect(response.status).toBe(400) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) + + it('pins the resolved IP for login, attach, and logout (TOCTOU fix)', async () => { + inputValidationMockFns.mockSecureFetchWithPinnedIP + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-att' } })) + .mockResolvedValueOnce(mockSecureFetchResponse({ text: '1' })) + .mockResolvedValueOnce(mockSecureFetchResponse({})) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(200) + const data = (await response.json()) as { + success: true + output: { totalAttachments: number; fileName: string } + } + expect(data.output.totalAttachments).toBe(1) + expect(data.output.fileName).toBe('file.txt') + + const calls = inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls + expect(calls).toHaveLength(3) + for (const call of calls) { + expect(call[1]).toBe(PINNED_IP) + } + + expect(calls[0][0]).toContain('https://example.agiloft.com/ewws/EWLogin') + expect(calls[1][0]).toContain('https://example.agiloft.com/ewws/EWAttach') + expect(calls[1][2]).toMatchObject({ + method: 'PUT', + headers: { + Authorization: 'Bearer tok-att', + 'Content-Type': 'application/octet-stream', + }, + }) + expect(calls[2][0]).toContain('https://example.agiloft.com/ewws/EWLogout') + + // DNS only resolved once. + expect(inputValidationMockFns.mockValidateUrlWithDNS).toHaveBeenCalledTimes(1) + }) +}) diff --git a/apps/sim/app/api/tools/agiloft/attach/route.ts b/apps/sim/app/api/tools/agiloft/attach/route.ts index 6257502ae4c..b0fcb351751 100644 --- a/apps/sim/app/api/tools/agiloft/attach/route.ts +++ b/apps/sim/app/api/tools/agiloft/attach/route.ts @@ -4,14 +4,19 @@ import { type NextRequest, NextResponse } from 'next/server' import { agiloftAttachContract } from '@/lib/api/contracts/tools/agiloft' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkInternalAuth } from '@/lib/auth/hybrid' -import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' +import { secureFetchWithPinnedIP } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import type { RawFileInput } from '@/lib/uploads/utils/file-schemas' import { processFilesToUserFiles } from '@/lib/uploads/utils/file-utils' import { downloadFileFromStorage } from '@/lib/uploads/utils/file-utils.server' import { assertToolFileAccess } from '@/app/api/files/authorization' -import { agiloftLogin, agiloftLogout, buildAttachFileUrl } from '@/tools/agiloft/utils' +import { buildAttachFileUrl } from '@/tools/agiloft/utils' +import { + agiloftLoginPinned, + agiloftLogoutPinned, + resolveAgiloftInstance, +} from '@/tools/agiloft/utils.server' export const dynamic = 'force-dynamic' @@ -72,18 +77,17 @@ export const POST = withRouteHandler(async (request: NextRequest) => { const fileBuffer = await downloadFileFromStorage(userFile, requestId, logger) const resolvedFileName = data.fileName || userFile.name || 'attachment' - const urlValidation = await validateUrlWithDNS(data.instanceUrl, 'instanceUrl') - if (!urlValidation.isValid) { + let resolvedIP: string + try { + resolvedIP = await resolveAgiloftInstance(data.instanceUrl) + } catch (error) { logger.warn(`[${requestId}] SSRF attempt blocked for Agiloft instance URL`, { instanceUrl: data.instanceUrl, }) - return NextResponse.json( - { success: false, error: urlValidation.error || 'Invalid instance URL' }, - { status: 400 } - ) + return NextResponse.json({ success: false, error: toError(error).message }, { status: 400 }) } - const token = await agiloftLogin(data) + const token = await agiloftLoginPinned(data, resolvedIP) const base = data.instanceUrl.replace(/\/$/, '') try { @@ -91,7 +95,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { logger.info(`[${requestId}] Uploading file to Agiloft: ${resolvedFileName}`) - const agiloftResponse = await fetch(url, { + const agiloftResponse = await secureFetchWithPinnedIP(url, resolvedIP, { method: 'PUT', headers: { 'Content-Type': 'application/octet-stream', @@ -135,7 +139,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { }, }) } finally { - await agiloftLogout(data.instanceUrl, data.knowledgeBase, token) + await agiloftLogoutPinned(data.instanceUrl, data.knowledgeBase, token, resolvedIP) } } catch (error) { logger.error(`[${requestId}] Error attaching file to Agiloft:`, error) diff --git a/apps/sim/app/api/tools/agiloft/retrieve/route.test.ts b/apps/sim/app/api/tools/agiloft/retrieve/route.test.ts new file mode 100644 index 00000000000..efd435b5b04 --- /dev/null +++ b/apps/sim/app/api/tools/agiloft/retrieve/route.test.ts @@ -0,0 +1,163 @@ +/** + * @vitest-environment node + */ +import { + createMockRequest, + hybridAuthMockFns, + inputValidationMock, + inputValidationMockFns, +} from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@/lib/core/security/input-validation.server', () => inputValidationMock) + +import { POST } from '@/app/api/tools/agiloft/retrieve/route' + +const PINNED_IP = '93.184.216.34' + +const baseBody = { + instanceUrl: 'https://example.agiloft.com', + knowledgeBase: 'demo', + login: 'admin', + password: 'secret', + table: 'contracts', + recordId: '42', + fieldName: 'attachments', + position: '0', +} + +function mockSecureFetchResponse(body: { + ok?: boolean + status?: number + json?: unknown + text?: string + arrayBuffer?: ArrayBuffer + headers?: Headers +}) { + return { + ok: body.ok ?? true, + status: body.status ?? 200, + statusText: '', + headers: body.headers ?? new Headers(), + body: null, + text: async () => body.text ?? '', + json: async () => body.json ?? {}, + arrayBuffer: async () => body.arrayBuffer ?? new ArrayBuffer(0), + } +} + +beforeEach(() => { + vi.clearAllMocks() + hybridAuthMockFns.mockCheckInternalAuth.mockResolvedValue({ + success: true, + userId: 'user-1', + authType: 'internal_jwt', + }) + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValue({ + isValid: true, + resolvedIP: PINNED_IP, + originalHostname: 'example.agiloft.com', + }) +}) + +describe('POST /api/tools/agiloft/retrieve', () => { + it('rejects unauthenticated requests', async () => { + hybridAuthMockFns.mockCheckInternalAuth.mockResolvedValueOnce({ + success: false, + error: 'unauthorized', + }) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(401) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) + + it('blocks SSRF when the instance URL fails DNS validation', async () => { + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValueOnce({ + isValid: false, + error: 'instanceUrl resolves to a blocked IP address', + }) + + const response = await POST( + createMockRequest('POST', { ...baseBody, instanceUrl: 'https://attacker.example.com' }) + ) + + expect(response.status).toBe(400) + const data = (await response.json()) as { success: false; error: string } + expect(data.success).toBe(false) + expect(data.error).toContain('blocked IP') + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) + + it('pins the resolved IP for login, retrieve, and logout (TOCTOU fix)', async () => { + const fileBytes = Buffer.from('hello-attachment', 'utf-8') + + inputValidationMockFns.mockSecureFetchWithPinnedIP + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-xyz' } })) + .mockResolvedValueOnce( + mockSecureFetchResponse({ + arrayBuffer: fileBytes.buffer.slice( + fileBytes.byteOffset, + fileBytes.byteOffset + fileBytes.byteLength + ) as ArrayBuffer, + headers: new Headers({ + 'content-type': 'text/plain', + 'content-disposition': 'attachment; filename="report.txt"', + }), + }) + ) + .mockResolvedValueOnce(mockSecureFetchResponse({})) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(200) + const data = (await response.json()) as { + success: true + output: { file: { name: string; mimeType: string; data: string; size: number } } + } + + expect(data.output.file.name).toBe('report.txt') + expect(data.output.file.mimeType).toBe('text/plain') + expect(data.output.file.size).toBe(fileBytes.length) + expect(Buffer.from(data.output.file.data, 'base64').toString('utf-8')).toBe('hello-attachment') + + const calls = inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls + expect(calls).toHaveLength(3) + + // All three outbound calls must use the pre-resolved IP. + for (const call of calls) { + expect(call[1]).toBe(PINNED_IP) + } + + // Original hostname is preserved in the URL (so TLS SNI works). + expect(calls[0][0]).toContain('https://example.agiloft.com/ewws/EWLogin') + expect(calls[1][0]).toContain('https://example.agiloft.com/ewws/EWRetrieve') + expect(calls[1][2]).toMatchObject({ + method: 'GET', + headers: { Authorization: 'Bearer tok-xyz' }, + }) + expect(calls[2][0]).toContain('https://example.agiloft.com/ewws/EWLogout') + + // DNS only resolved once — no second lookup that could rebind. + expect(inputValidationMockFns.mockValidateUrlWithDNS).toHaveBeenCalledTimes(1) + }) + + it('propagates upstream errors and still calls logout', async () => { + inputValidationMockFns.mockSecureFetchWithPinnedIP + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-err' } })) + .mockResolvedValueOnce( + mockSecureFetchResponse({ ok: false, status: 404, text: 'Record not found' }) + ) + .mockResolvedValueOnce(mockSecureFetchResponse({})) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(404) + const data = (await response.json()) as { success: false; error: string } + expect(data.error).toContain('Record not found') + + // Logout still runs. + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).toHaveBeenCalledTimes(3) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls[2][0]).toContain( + '/ewws/EWLogout' + ) + }) +}) diff --git a/apps/sim/app/api/tools/agiloft/retrieve/route.ts b/apps/sim/app/api/tools/agiloft/retrieve/route.ts index 64bd72daae8..539f0bf7c2e 100644 --- a/apps/sim/app/api/tools/agiloft/retrieve/route.ts +++ b/apps/sim/app/api/tools/agiloft/retrieve/route.ts @@ -4,10 +4,15 @@ import { type NextRequest, NextResponse } from 'next/server' import { agiloftRetrieveContract } from '@/lib/api/contracts/tools/agiloft' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkInternalAuth } from '@/lib/auth/hybrid' -import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' +import { secureFetchWithPinnedIP } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { agiloftLogin, agiloftLogout, buildRetrieveAttachmentUrl } from '@/tools/agiloft/utils' +import { buildRetrieveAttachmentUrl } from '@/tools/agiloft/utils' +import { + agiloftLoginPinned, + agiloftLogoutPinned, + resolveAgiloftInstance, +} from '@/tools/agiloft/utils.server' export const dynamic = 'force-dynamic' @@ -48,18 +53,17 @@ export const POST = withRouteHandler(async (request: NextRequest) => { if (!parsed.success) return parsed.response const data = parsed.data.body - const urlValidation = await validateUrlWithDNS(data.instanceUrl, 'instanceUrl') - if (!urlValidation.isValid) { + let resolvedIP: string + try { + resolvedIP = await resolveAgiloftInstance(data.instanceUrl) + } catch (error) { logger.warn(`[${requestId}] SSRF attempt blocked for Agiloft instance URL`, { instanceUrl: data.instanceUrl, }) - return NextResponse.json( - { success: false, error: urlValidation.error || 'Invalid instance URL' }, - { status: 400 } - ) + return NextResponse.json({ success: false, error: toError(error).message }, { status: 400 }) } - const token = await agiloftLogin(data) + const token = await agiloftLoginPinned(data, resolvedIP) const base = data.instanceUrl.replace(/\/$/, '') try { @@ -71,7 +75,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { position: data.position, }) - const agiloftResponse = await fetch(url, { + const agiloftResponse = await secureFetchWithPinnedIP(url, resolvedIP, { method: 'GET', headers: { Authorization: `Bearer ${token}`, @@ -123,7 +127,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { }, }) } finally { - await agiloftLogout(data.instanceUrl, data.knowledgeBase, token) + await agiloftLogoutPinned(data.instanceUrl, data.knowledgeBase, token, resolvedIP) } } catch (error) { logger.error(`[${requestId}] Error retrieving Agiloft attachment:`, error) diff --git a/apps/sim/app/form/[identifier]/components/email-auth.tsx b/apps/sim/app/form/[identifier]/components/email-auth.tsx new file mode 100644 index 00000000000..b75cb159c3c --- /dev/null +++ b/apps/sim/app/form/[identifier]/components/email-auth.tsx @@ -0,0 +1,284 @@ +'use client' + +import { useEffect, useState } from 'react' +import { createLogger } from '@sim/logger' +import { toError } from '@sim/utils/errors' +import { Input, InputOTP, InputOTPGroup, InputOTPSlot, Label, Loader } from '@/components/emcn' +import { cn } from '@/lib/core/utils/cn' +import { quickValidateEmail } from '@/lib/messaging/email/validation' +import AuthBackground from '@/app/(auth)/components/auth-background' +import { AUTH_SUBMIT_BTN } from '@/app/(auth)/components/auth-button-classes' +import { SupportFooter } from '@/app/(auth)/components/support-footer' +import Navbar from '@/app/(landing)/components/navbar/navbar' +import { useFormEmailOtpRequest, useFormEmailOtpVerify } from '@/hooks/queries/forms' + +const logger = createLogger('FormEmailAuth') + +interface EmailAuthProps { + identifier: string + onAuthenticated: () => void +} + +function validateEmailField(emailValue: string): string[] { + const errors: string[] = [] + + if (!emailValue || !emailValue.trim()) { + errors.push('Email is required.') + return errors + } + + const validation = quickValidateEmail(emailValue.trim().toLowerCase()) + if (!validation.isValid) { + errors.push(validation.reason || 'Please enter a valid email address.') + } + + return errors +} + +export function EmailAuth({ identifier, onAuthenticated }: EmailAuthProps) { + const [email, setEmail] = useState('') + const [authError, setAuthError] = useState(null) + const [emailErrors, setEmailErrors] = useState([]) + const [showEmailValidationError, setShowEmailValidationError] = useState(false) + + const [showOtpVerification, setShowOtpVerification] = useState(false) + const [otpValue, setOtpValue] = useState('') + const [countdown, setCountdown] = useState(0) + + const requestOtp = useFormEmailOtpRequest(identifier) + const verifyOtp = useFormEmailOtpVerify(identifier) + + useEffect(() => { + if (countdown <= 0) return + const timer = setTimeout(() => setCountdown((c) => c - 1), 1000) + return () => clearTimeout(timer) + }, [countdown]) + + const handleEmailChange = (e: React.ChangeEvent) => { + const newEmail = e.target.value + setEmail(newEmail) + const errors = validateEmailField(newEmail) + setEmailErrors(errors) + setShowEmailValidationError(false) + } + + const handleSendOtp = async () => { + const emailValidationErrors = validateEmailField(email) + setEmailErrors(emailValidationErrors) + setShowEmailValidationError(emailValidationErrors.length > 0) + + if (emailValidationErrors.length > 0) return + + setAuthError(null) + + try { + await requestOtp.mutateAsync({ email }) + setShowOtpVerification(true) + } catch (error) { + logger.error('Error sending OTP:', error) + setEmailErrors([toError(error).message || 'Failed to send verification code']) + setShowEmailValidationError(true) + } + } + + const handleVerifyOtp = async (otp?: string) => { + const codeToVerify = otp || otpValue + if (!codeToVerify || codeToVerify.length !== 6) return + + setAuthError(null) + + try { + await verifyOtp.mutateAsync({ email, otp: codeToVerify }) + onAuthenticated() + } catch (error) { + logger.error('Error verifying OTP:', error) + setAuthError(toError(error).message || 'Invalid verification code') + } + } + + const handleResendOtp = async () => { + setAuthError(null) + setCountdown(30) + + try { + await requestOtp.mutateAsync({ email }) + setOtpValue('') + } catch (error) { + logger.error('Error resending OTP:', error) + setAuthError(toError(error).message || 'Failed to resend verification code') + setCountdown(0) + } + } + + return ( + +
+
+ +
+
+
+
+
+

+ {showOtpVerification ? 'Verify Your Email' : 'Email Verification'} +

+

+ {showOtpVerification + ? `A verification code has been sent to ${email}` + : 'This form requires email verification'} +

+
+ +
+ {!showOtpVerification ? ( +
{ + e.preventDefault() + handleSendOtp() + }} + className='space-y-6' + > +
+ + 0 && + 'border-red-500 focus:border-red-500' + )} + /> + {showEmailValidationError && emailErrors.length > 0 && ( +
+ {emailErrors.map((error) => ( +

{error}

+ ))} +
+ )} +
+ + +
+ ) : ( +
+

+ Enter the 6-digit code to verify your account. If you don't see it in your + inbox, check your spam folder. +

+ +
+ { + setOtpValue(value) + if (value.length === 6) { + handleVerifyOtp(value) + } + }} + disabled={verifyOtp.isPending} + className={cn('gap-2', authError && 'otp-error')} + > + + {[0, 1, 2, 3, 4, 5].map((index) => ( + + ))} + + +
+ + {authError && ( +
+

{authError}

+
+ )} + + + +
+

+ Didn't receive a code?{' '} + {countdown > 0 ? ( + + Resend in{' '} + + {countdown}s + + + ) : ( + + )} +

+
+ +
+ +
+
+ )} +
+
+
+
+ +
+
+ ) +} diff --git a/apps/sim/app/form/[identifier]/components/index.ts b/apps/sim/app/form/[identifier]/components/index.ts index 31cb46d6843..e888196967c 100644 --- a/apps/sim/app/form/[identifier]/components/index.ts +++ b/apps/sim/app/form/[identifier]/components/index.ts @@ -1,3 +1,4 @@ +export { EmailAuth } from './email-auth' export { FormErrorState } from './error-state' export { FormField } from './form-field' export { FormLoadingState } from './loading-state' diff --git a/apps/sim/app/form/[identifier]/form.tsx b/apps/sim/app/form/[identifier]/form.tsx index f6264ddf0fb..4e809096b55 100644 --- a/apps/sim/app/form/[identifier]/form.tsx +++ b/apps/sim/app/form/[identifier]/form.tsx @@ -10,6 +10,7 @@ import { AUTH_SUBMIT_BTN } from '@/app/(auth)/components/auth-button-classes' import { SupportFooter } from '@/app/(auth)/components/support-footer' import Navbar from '@/app/(landing)/components/navbar/navbar' import { + EmailAuth, FormErrorState, FormField, FormLoadingState, @@ -241,6 +242,10 @@ export default function Form({ identifier }: { identifier: string }) { return } + if (authRequired === 'email') { + return fetchFormConfig()} /> + } + if (isSubmitted && thankYouData) { return ( diff --git a/apps/sim/hooks/queries/forms.ts b/apps/sim/hooks/queries/forms.ts index e43e8fe7ecc..e20df733b5e 100644 --- a/apps/sim/hooks/queries/forms.ts +++ b/apps/sim/hooks/queries/forms.ts @@ -14,8 +14,10 @@ import { type FormStatusResponse, getFormDetailContract, getFormStatusContract, + requestFormEmailOtpContract, type UpdateFormInput, updateFormContract, + verifyFormEmailOtpContract, } from '@/lib/api/contracts/forms' import { deploymentKeys } from './deployments' @@ -35,6 +37,36 @@ export const formKeys = { */ export type { FormAuthType } +/** + * Requests a one-time passcode for an email-gated deployed form. + * Used for both the initial send and resend flows. + */ +export function useFormEmailOtpRequest(identifier: string) { + return useMutation({ + mutationFn: async ({ email }: { email: string }) => { + await requestJson(requestFormEmailOtpContract, { + params: { identifier }, + body: { email }, + }) + }, + }) +} + +/** + * Verifies a one-time passcode for an email-gated deployed form. + * On success the server sets the auth cookie; the caller should re-fetch the form config. + */ +export function useFormEmailOtpVerify(identifier: string) { + return useMutation({ + mutationFn: async ({ email, otp }: { email: string; otp: string }) => { + await requestJson(verifyFormEmailOtpContract, { + params: { identifier }, + body: { email, otp }, + }) + }, + }) +} + /** * Field configuration for form fields */ diff --git a/apps/sim/lib/api/contracts/forms.ts b/apps/sim/lib/api/contracts/forms.ts index af252043b65..0283121edca 100644 --- a/apps/sim/lib/api/contracts/forms.ts +++ b/apps/sim/lib/api/contracts/forms.ts @@ -148,6 +148,22 @@ export const formMutationResponseSchema = z.object({ message: z.string(), }) +export const formEmailOtpRequestBodySchema = z.object({ + email: z.string().email('Invalid email address'), +}) + +export const formEmailOtpVerifyBodySchema = formEmailOtpRequestBodySchema.extend({ + otp: z.string().length(6, 'OTP must be 6 digits'), +}) + +export const formEmailOtpRequestResponseSchema = z.object({ + message: z.string(), +}) + +export const formEmailOtpVerifyResponseSchema = z.object({ + authenticated: z.literal(true), +}) + export const getFormStatusContract = defineRouteContract({ method: 'GET', path: '/api/workflows/[id]/form/status', @@ -199,6 +215,28 @@ export const deleteFormContract = defineRouteContract({ }, }) +export const requestFormEmailOtpContract = defineRouteContract({ + method: 'POST', + path: '/api/form/[identifier]/otp', + params: formIdentifierParamsSchema, + body: formEmailOtpRequestBodySchema, + response: { + mode: 'json', + schema: formEmailOtpRequestResponseSchema, + }, +}) + +export const verifyFormEmailOtpContract = defineRouteContract({ + method: 'PUT', + path: '/api/form/[identifier]/otp', + params: formIdentifierParamsSchema, + body: formEmailOtpVerifyBodySchema, + response: { + mode: 'json', + schema: formEmailOtpVerifyResponseSchema, + }, +}) + export const validateFormIdentifierContract = defineRouteContract({ method: 'GET', path: '/api/form/validate', diff --git a/apps/sim/lib/api/contracts/knowledge/documents.ts b/apps/sim/lib/api/contracts/knowledge/documents.ts index 17cb344f37f..6a85005f700 100644 --- a/apps/sim/lib/api/contracts/knowledge/documents.ts +++ b/apps/sim/lib/api/contracts/knowledge/documents.ts @@ -5,6 +5,7 @@ import { documentNumberFieldSchema, documentTagFieldSchema, knowledgeBaseParamsSchema, + knowledgeDocumentFileUrlSchema, knowledgeDocumentParamsSchema, nullableWireDateSchema, paginationSchema, @@ -55,7 +56,7 @@ export const listKnowledgeDocumentsQuerySchema = z.object({ export const createDocumentBodySchema = z.object({ filename: z.string().min(1, 'Filename is required'), - fileUrl: z.string().url('File URL must be valid'), + fileUrl: knowledgeDocumentFileUrlSchema, fileSize: z.number().min(1, 'File size must be greater than 0'), mimeType: z.string().min(1, 'MIME type is required'), tag1: z.string().optional(), @@ -101,7 +102,7 @@ export type SingleCreateDocumentBody = z.input { + it('accepts data: URIs', () => { + const result = knowledgeDocumentFileUrlSchema.safeParse( + 'data:text/plain;base64,SGVsbG8gd29ybGQ=' + ) + expect(result.success).toBe(true) + }) + + it('accepts https URLs', () => { + const result = knowledgeDocumentFileUrlSchema.safeParse('https://example.com/file.pdf') + expect(result.success).toBe(true) + }) + + it('accepts http URLs', () => { + const result = knowledgeDocumentFileUrlSchema.safeParse( + 'http://localhost:3000/api/files/serve/kb/foo.pdf?context=knowledge-base' + ) + expect(result.success).toBe(true) + }) + + it('is case-insensitive on the scheme', () => { + expect(knowledgeDocumentFileUrlSchema.safeParse('HTTPS://example.com/x').success).toBe(true) + expect(knowledgeDocumentFileUrlSchema.safeParse('Http://example.com/x').success).toBe(true) + }) + + it.each([ + ['absolute local path', '/etc/passwd'], + ['app path', '/app/.env'], + ['relative path', './secrets.txt'], + ['parent traversal', '../../etc/shadow'], + ['file:// scheme', 'file:///etc/passwd'], + ['ftp scheme', 'ftp://example.com/x'], + ['javascript scheme', 'javascript:alert(1)'], + ['gopher scheme', 'gopher://example.com'], + ['relative serve path', '/api/files/serve/kb/foo.pdf'], + ['windows path', 'C:\\Windows\\System32\\config\\SAM'], + ['empty string', ''], + ['whitespace prefix', ' https://example.com/x'], + ])('rejects %s', (_label, value) => { + const result = knowledgeDocumentFileUrlSchema.safeParse(value) + expect(result.success).toBe(false) + }) + + it('returns a useful error message for unsupported schemes', () => { + const result = knowledgeDocumentFileUrlSchema.safeParse('/etc/passwd') + if (result.success) throw new Error('expected failure') + expect(result.error.issues[0].message).toMatch(/data: URI or an http\(s\):\/\/ URL/) + }) +}) diff --git a/apps/sim/lib/api/contracts/knowledge/shared.ts b/apps/sim/lib/api/contracts/knowledge/shared.ts index 070cd4606de..e75b3bde368 100644 --- a/apps/sim/lib/api/contracts/knowledge/shared.ts +++ b/apps/sim/lib/api/contracts/knowledge/shared.ts @@ -23,6 +23,21 @@ export const knowledgeConnectorParamsSchema = knowledgeBaseParamsSchema.extend({ connectorId: z.string().min(1), }) +/** + * A `fileUrl` accepted by knowledge document ingestion endpoints. + * + * Must be a `data:` URI or an `http(s)://` URL. Local paths, `file://`, + * and other schemes are rejected at the boundary to prevent the background + * parser from reading arbitrary files off the Sim server's filesystem. + */ +export const knowledgeDocumentFileUrlSchema = z + .string() + .min(1, 'File URL is required') + .refine( + (value) => /^data:/i.test(value) || /^https?:\/\//i.test(value), + 'File URL must be a data: URI or an http(s):// URL' + ) + export const documentTagFieldSchema = z.string().nullable().optional() export const documentNumberFieldSchema = z.number().nullable().optional() export const documentBooleanFieldSchema = z.boolean().nullable().optional() diff --git a/apps/sim/lib/core/security/otp.ts b/apps/sim/lib/core/security/otp.ts new file mode 100644 index 00000000000..5163487d10b --- /dev/null +++ b/apps/sim/lib/core/security/otp.ts @@ -0,0 +1,251 @@ +import { randomInt } from 'crypto' +import { db } from '@sim/db' +import { verification } from '@sim/db/schema' +import { generateId } from '@sim/utils/id' +import { and, eq, gt } from 'drizzle-orm' +import { getRedisClient } from '@/lib/core/config/redis' +import type { TokenBucketConfig } from '@/lib/core/rate-limiter' +import { getStorageMethod } from '@/lib/core/storage' + +export type DeploymentKind = 'chat' | 'form' + +/** + * Shared OTP configuration for deployment (chat/form) email-auth gates. + */ +export const OTP_EXPIRY_SECONDS = 15 * 60 +export const OTP_EXPIRY_MS = OTP_EXPIRY_SECONDS * 1000 +export const MAX_OTP_ATTEMPTS = 5 + +export const OTP_IP_RATE_LIMIT: TokenBucketConfig = { + maxTokens: 10, + refillRate: 10, + refillIntervalMs: 15 * 60_000, +} + +export const OTP_EMAIL_RATE_LIMIT: TokenBucketConfig = { + maxTokens: 3, + refillRate: 3, + refillIntervalMs: 15 * 60_000, +} + +/** + * Key formats are kept per-kind to preserve any in-flight OTPs already issued + * against existing chat deployments. The chat Redis key uses the legacy `otp:` + * prefix; the chat DB identifier uses `chat-otp:`. Forms use `form-otp:` for + * both. + */ +const OTP_KEYS = { + chat: { + redisKey: (email: string, deploymentId: string) => `otp:${email}:${deploymentId}`, + dbIdentifier: (email: string, deploymentId: string) => `chat-otp:${deploymentId}:${email}`, + }, + form: { + redisKey: (email: string, deploymentId: string) => `form-otp:${email}:${deploymentId}`, + dbIdentifier: (email: string, deploymentId: string) => `form-otp:${deploymentId}:${email}`, + }, +} as const satisfies Record< + DeploymentKind, + { + redisKey: (email: string, deploymentId: string) => string + dbIdentifier: (email: string, deploymentId: string) => string + } +> + +/** Returns a cryptographically random 6-digit OTP code. */ +export function generateOTP(): string { + return randomInt(100000, 1000000).toString() +} + +/** + * OTP values are stored as `"code:attempts"` (e.g. `"654321:0"`). + * This keeps the attempt counter in the same key/row as the OTP itself. + */ +function encodeOTPValue(otp: string, attempts: number): string { + return `${otp}:${attempts}` +} + +export function decodeOTPValue(value: string): { otp: string; attempts: number } { + const lastColon = value.lastIndexOf(':') + if (lastColon === -1) return { otp: value, attempts: 0 } + const attempts = Number.parseInt(value.slice(lastColon + 1), 10) + return { otp: value.slice(0, lastColon), attempts: Number.isNaN(attempts) ? 0 : attempts } +} + +/** + * Stores an OTP for a deployment+email pair, choosing Redis or the + * `verification` table based on the configured storage method. + */ +export async function storeOTP( + kind: DeploymentKind, + deploymentId: string, + email: string, + otp: string +): Promise { + const keys = OTP_KEYS[kind] + const value = encodeOTPValue(otp, 0) + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) throw new Error('Redis configured but client unavailable') + await redis.set(keys.redisKey(email, deploymentId), value, 'EX', OTP_EXPIRY_SECONDS) + return + } + + const now = new Date() + const expiresAt = new Date(now.getTime() + OTP_EXPIRY_MS) + const identifier = keys.dbIdentifier(email, deploymentId) + + await db.transaction(async (tx) => { + await tx.delete(verification).where(eq(verification.identifier, identifier)) + await tx.insert(verification).values({ + id: generateId(), + identifier, + value, + expiresAt, + createdAt: now, + updatedAt: now, + }) + }) +} + +export async function getOTP( + kind: DeploymentKind, + deploymentId: string, + email: string +): Promise { + const keys = OTP_KEYS[kind] + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) throw new Error('Redis configured but client unavailable') + return redis.get(keys.redisKey(email, deploymentId)) + } + + const now = new Date() + const [record] = await db + .select({ value: verification.value }) + .from(verification) + .where( + and( + eq(verification.identifier, keys.dbIdentifier(email, deploymentId)), + gt(verification.expiresAt, now) + ) + ) + .limit(1) + + return record?.value ?? null +} + +/** + * Lua script for atomic OTP attempt increment in Redis. + * Returns `'LOCKED'` if max attempts reached (key deleted), new encoded value + * otherwise, nil if key missing. + */ +const ATOMIC_INCREMENT_SCRIPT = ` +local val = redis.call('GET', KEYS[1]) +if not val then return nil end +local colon = val:find(':([^:]*$)') +local otp, attempts +if colon then + otp = val:sub(1, colon - 1) + attempts = tonumber(val:sub(colon + 1)) or 0 +else + otp = val + attempts = 0 +end +attempts = attempts + 1 +if attempts >= tonumber(ARGV[1]) then + redis.call('DEL', KEYS[1]) + return 'LOCKED' +end +local newVal = otp .. ':' .. attempts +local ttl = redis.call('TTL', KEYS[1]) +if ttl > 0 then + redis.call('SET', KEYS[1], newVal, 'EX', ttl) +else + redis.call('SET', KEYS[1], newVal) +end +return newVal +` + +/** + * Atomically increments an OTP's failed-attempt counter. Returns `'locked'` + * if the max-attempts threshold was reached (and the OTP was deleted), or + * `'incremented'` otherwise. The DB path uses optimistic locking with retry. + */ +export async function incrementOTPAttempts( + kind: DeploymentKind, + deploymentId: string, + email: string, + currentValue: string +): Promise<'locked' | 'incremented'> { + const keys = OTP_KEYS[kind] + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) throw new Error('Redis configured but client unavailable') + const key = keys.redisKey(email, deploymentId) + const result = await redis.eval(ATOMIC_INCREMENT_SCRIPT, 1, key, MAX_OTP_ATTEMPTS) + if (result === null || result === 'LOCKED') return 'locked' + return 'incremented' + } + + const identifier = keys.dbIdentifier(email, deploymentId) + const MAX_RETRIES = 3 + let value = currentValue + + for (let attempt = 0; attempt < MAX_RETRIES; attempt++) { + const { otp, attempts } = decodeOTPValue(value) + const newAttempts = attempts + 1 + + if (newAttempts >= MAX_OTP_ATTEMPTS) { + await db.delete(verification).where(eq(verification.identifier, identifier)) + return 'locked' + } + + const newValue = encodeOTPValue(otp, newAttempts) + const updated = await db + .update(verification) + .set({ value: newValue, updatedAt: new Date() }) + .where(and(eq(verification.identifier, identifier), eq(verification.value, value))) + .returning({ id: verification.id }) + + if (updated.length > 0) return 'incremented' + + const fresh = await getOTP(kind, deploymentId, email) + if (!fresh) return 'locked' + value = fresh + } + + /** + * Retry exhaustion under heavy DB-path contention: this request did not + * succeed in writing its own +1, so the stored count may not reflect it. + * Fail closed — invalidate the OTP rather than return `'incremented'` with + * a possibly-undercounted attempt total. + */ + await db.delete(verification).where(eq(verification.identifier, identifier)) + return 'locked' +} + +export async function deleteOTP( + kind: DeploymentKind, + deploymentId: string, + email: string +): Promise { + const keys = OTP_KEYS[kind] + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) throw new Error('Redis configured but client unavailable') + await redis.del(keys.redisKey(email, deploymentId)) + return + } + + await db + .delete(verification) + .where(eq(verification.identifier, keys.dbIdentifier(email, deploymentId))) +} diff --git a/apps/sim/lib/knowledge/documents/document-processor.ts b/apps/sim/lib/knowledge/documents/document-processor.ts index 67f246d06b5..5e550f5af60 100644 --- a/apps/sim/lib/knowledge/documents/document-processor.ts +++ b/apps/sim/lib/knowledge/documents/document-processor.ts @@ -15,7 +15,7 @@ import { } from '@/lib/chunkers' import type { ChunkingStrategy, StrategyOptions } from '@/lib/chunkers/types' import { env, envNumber } from '@/lib/core/config/env' -import { parseBuffer, parseFile } from '@/lib/file-parsers' +import { parseBuffer } from '@/lib/file-parsers' import type { FileParseMetadata } from '@/lib/file-parsers/types' import { resolveParserExtension } from '@/lib/knowledge/documents/parser-extension' import { retryWithExponentialBackoff } from '@/lib/knowledge/documents/utils' @@ -315,7 +315,7 @@ async function handleFileForOCR( userId?: string, workspaceId?: string | null ) { - const isExternalHttps = fileUrl.startsWith('https://') && !isInternalFileUrl(fileUrl) + const isExternalHttps = /^https:\/\//i.test(fileUrl) && !isInternalFileUrl(fileUrl) if (isExternalHttps) { if (mimeType === 'application/pdf') { @@ -385,18 +385,17 @@ async function downloadFileWithTimeout(fileUrl: string): Promise { } async function downloadFileForBase64(fileUrl: string): Promise { - if (fileUrl.startsWith('data:')) { + if (/^data:/i.test(fileUrl)) { const [, base64Data] = fileUrl.split(',') if (!base64Data) { throw new Error('Invalid data URI format') } return Buffer.from(base64Data, 'base64') } - if (fileUrl.startsWith('http')) { + if (/^https?:\/\//i.test(fileUrl)) { return downloadFileWithTimeout(fileUrl) } - const fs = await import('fs/promises') - return fs.readFile(fileUrl) + throw new Error('Unsupported fileUrl scheme: only data: URIs and http(s):// URLs are allowed') } function processOCRContent(result: OCRResult, filename: string): string { @@ -783,16 +782,14 @@ async function parseWithFileParser(fileUrl: string, filename: string, mimeType: let content: string let metadata: FileParseMetadata = {} - if (fileUrl.startsWith('data:')) { + if (/^data:/i.test(fileUrl)) { content = await parseDataURI(fileUrl, filename, mimeType) - } else if (fileUrl.startsWith('http')) { + } else if (/^https?:\/\//i.test(fileUrl)) { const result = await parseHttpFile(fileUrl, filename, mimeType) content = result.content metadata = result.metadata || {} } else { - const result = await parseFile(fileUrl) - content = result.content - metadata = result.metadata || {} + throw new Error('Unsupported fileUrl scheme: only data: URIs and http(s):// URLs are allowed') } if (!content.trim()) { diff --git a/apps/sim/lib/knowledge/service.test.ts b/apps/sim/lib/knowledge/service.test.ts new file mode 100644 index 00000000000..08e81b2746c --- /dev/null +++ b/apps/sim/lib/knowledge/service.test.ts @@ -0,0 +1,114 @@ +/** + * @vitest-environment node + */ +import { + dbChainMock, + dbChainMockFns, + permissionsMock, + permissionsMockFns, + resetDbChainMock, +} from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@sim/db', () => dbChainMock) +vi.mock('@/lib/workspaces/permissions/utils', () => permissionsMock) + +import { KnowledgeBasePermissionError, updateKnowledgeBase } from '@/lib/knowledge/service' + +/** + * These tests guard the workspace mass-assignment fix: + * a user with write/admin on the *source* workspace must not be able to move a + * knowledge base into a workspace where they have no permission, and must not + * be able to clear `workspaceId` (which would orphan the KB to its original + * `userId`, who may not be the caller). + */ +describe('updateKnowledgeBase — workspace transfer authorization', () => { + beforeEach(() => { + vi.clearAllMocks() + dbChainMockFns.limit.mockReset() + resetDbChainMock() + }) + + it('rejects workspaceId change without actorUserId', async () => { + await expect( + updateKnowledgeBase('kb-1', { workspaceId: 'ws-target' }, 'req-1') + ).rejects.toBeInstanceOf(KnowledgeBasePermissionError) + expect(permissionsMockFns.mockGetUserEntityPermissions).not.toHaveBeenCalled() + }) + + it('rejects clearing workspaceId to null when actor is not the KB owner', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'owner' }]) + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: null }, 'req-1', { actorUserId: 'attacker' }) + ).rejects.toMatchObject({ + code: 'KNOWLEDGE_BASE_FORBIDDEN', + message: 'Only the knowledge base owner can remove it from a workspace', + }) + expect(permissionsMockFns.mockGetUserEntityPermissions).not.toHaveBeenCalled() + }) + + it('allows the KB owner to clear workspaceId to null (gate passes; target permission not checked)', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'owner' }]) + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: null }, 'req-1', { actorUserId: 'owner' }) + ).rejects.not.toBeInstanceOf(KnowledgeBasePermissionError) + expect(permissionsMockFns.mockGetUserEntityPermissions).not.toHaveBeenCalled() + }) + + it('rejects transfer when actor has no permission on target workspace', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'u-1' }]) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValueOnce(null) + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: 'ws-target' }, 'req-1', { + actorUserId: 'attacker', + }) + ).rejects.toMatchObject({ + code: 'KNOWLEDGE_BASE_FORBIDDEN', + message: 'User does not have permission on the target workspace', + }) + expect(permissionsMockFns.mockGetUserEntityPermissions).toHaveBeenCalledWith( + 'attacker', + 'workspace', + 'ws-target' + ) + }) + + it('rejects transfer when actor only has read permission on target workspace', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'u-1' }]) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValueOnce('read') + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: 'ws-target' }, 'req-1', { + actorUserId: 'reader', + }) + ).rejects.toBeInstanceOf(KnowledgeBasePermissionError) + }) + + it('throws when knowledge base does not exist during transfer', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([]) + + await expect( + updateKnowledgeBase('kb-missing', { workspaceId: 'ws-target' }, 'req-1', { + actorUserId: 'u-1', + }) + ).rejects.toThrow('Knowledge base kb-missing not found') + expect(permissionsMockFns.mockGetUserEntityPermissions).not.toHaveBeenCalled() + }) + + it('locks the knowledge base row (SELECT … FOR UPDATE) before the permission check', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'u-1' }]) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValueOnce(null) + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: 'ws-target' }, 'req-1', { + actorUserId: 'attacker', + }) + ).rejects.toBeInstanceOf(KnowledgeBasePermissionError) + + expect(dbChainMockFns.transaction).toHaveBeenCalledTimes(1) + expect(dbChainMockFns.for).toHaveBeenCalledWith('update') + }) +}) diff --git a/apps/sim/lib/knowledge/service.ts b/apps/sim/lib/knowledge/service.ts index 0ba5de8162a..8bca342cb77 100644 --- a/apps/sim/lib/knowledge/service.ts +++ b/apps/sim/lib/knowledge/service.ts @@ -21,6 +21,10 @@ export class KnowledgeBaseConflictError extends Error { } } +export class KnowledgeBasePermissionError extends Error { + readonly code = 'KNOWLEDGE_BASE_FORBIDDEN' as const +} + export type KnowledgeBaseScope = 'active' | 'archived' | 'all' /** @@ -148,7 +152,9 @@ export async function createKnowledgeBase( const hasPermission = await getUserEntityPermissions(data.userId, 'workspace', data.workspaceId) if (hasPermission !== 'admin' && hasPermission !== 'write') { - throw new Error('User does not have permission to create knowledge bases in this workspace') + throw new KnowledgeBasePermissionError( + 'User does not have permission to create knowledge bases in this workspace' + ) } const newKnowledgeBase = { @@ -226,7 +232,8 @@ export async function updateKnowledgeBase( overlap: number } }, - requestId: string + requestId: string, + options?: { actorUserId?: string } ): Promise { const now = new Date() const updateData: { @@ -252,38 +259,81 @@ export async function updateKnowledgeBase( updateData.chunkingConfig = updates.chunkingConfig } - if (updates.name !== undefined) { - const existing = await db - .select({ id: knowledgeBase.id, workspaceId: knowledgeBase.workspaceId }) - .from(knowledgeBase) - .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) - .limit(1) + if (updates.workspaceId !== undefined && !options?.actorUserId) { + throw new KnowledgeBasePermissionError( + 'actorUserId is required to change a knowledge base workspace' + ) + } - if (existing.length > 0 && existing[0].workspaceId) { - const duplicate = await db - .select({ id: knowledgeBase.id }) + try { + await db.transaction(async (tx) => { + const [currentKb] = await tx + .select({ workspaceId: knowledgeBase.workspaceId, userId: knowledgeBase.userId }) .from(knowledgeBase) - .where( - and( - eq(knowledgeBase.workspaceId, existing[0].workspaceId), - eq(knowledgeBase.name, updates.name), - isNull(knowledgeBase.deletedAt), - ne(knowledgeBase.id, knowledgeBaseId) - ) - ) + .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) + .for('update') .limit(1) - if (duplicate.length > 0) { - throw new KnowledgeBaseConflictError(updates.name) + if (!currentKb) { + throw new Error(`Knowledge base ${knowledgeBaseId} not found`) } - } - } - try { - await db - .update(knowledgeBase) - .set(updateData) - .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) + if (updates.workspaceId !== undefined) { + const actorUserId = options?.actorUserId as string + const currentWorkspaceId = currentKb.workspaceId ?? null + const targetWorkspaceId = updates.workspaceId ?? null + + if (targetWorkspaceId !== currentWorkspaceId) { + if (!targetWorkspaceId) { + if (actorUserId !== currentKb.userId) { + throw new KnowledgeBasePermissionError( + 'Only the knowledge base owner can remove it from a workspace' + ) + } + } else { + const targetPermission = await getUserEntityPermissions( + actorUserId, + 'workspace', + targetWorkspaceId + ) + if (targetPermission !== 'write' && targetPermission !== 'admin') { + throw new KnowledgeBasePermissionError( + 'User does not have permission on the target workspace' + ) + } + } + } + } + + if (updates.name !== undefined) { + const effectiveWorkspaceId = + updates.workspaceId !== undefined ? updates.workspaceId : currentKb.workspaceId + + if (effectiveWorkspaceId) { + const duplicate = await tx + .select({ id: knowledgeBase.id }) + .from(knowledgeBase) + .where( + and( + eq(knowledgeBase.workspaceId, effectiveWorkspaceId), + eq(knowledgeBase.name, updates.name), + isNull(knowledgeBase.deletedAt), + ne(knowledgeBase.id, knowledgeBaseId) + ) + ) + .limit(1) + + if (duplicate.length > 0) { + throw new KnowledgeBaseConflictError(updates.name) + } + } + } + + await tx + .update(knowledgeBase) + .set(updateData) + .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) + }) } catch (error: unknown) { if (getPostgresErrorCode(error) === '23505' && updates.name !== undefined) { throw new KnowledgeBaseConflictError(updates.name) diff --git a/apps/sim/lib/mcp/client.ts b/apps/sim/lib/mcp/client.ts index 93588aecdd3..bbc5cb19e00 100644 --- a/apps/sim/lib/mcp/client.ts +++ b/apps/sim/lib/mcp/client.ts @@ -18,6 +18,7 @@ import { import { createLogger } from '@sim/logger' import { getErrorMessage } from '@sim/utils/errors' import { getMaxExecutionTimeout } from '@/lib/core/execution-limits' +import { createMcpPinnedFetch } from '@/lib/mcp/pinned-fetch' import { type McpClientOptions, McpConnectionError, @@ -51,34 +52,15 @@ export class McpClient { '2024-11-05', // Initial stable release ] - /** - * Creates a new MCP client. - * - * Accepts either the legacy (config, securityPolicy?) signature - * or a single McpClientOptions object with an optional onToolsChanged callback. - */ - constructor(config: McpServerConfig, securityPolicy?: McpSecurityPolicy) - constructor(options: McpClientOptions) - constructor( - configOrOptions: McpServerConfig | McpClientOptions, - securityPolicy?: McpSecurityPolicy - ) { - if ('config' in configOrOptions) { - this.config = configOrOptions.config - this.securityPolicy = configOrOptions.securityPolicy ?? { - requireConsent: true, - auditLevel: 'basic', - maxToolExecutionsPerHour: 1000, - } - this.onToolsChanged = configOrOptions.onToolsChanged - } else { - this.config = configOrOptions - this.securityPolicy = securityPolicy ?? { - requireConsent: true, - auditLevel: 'basic', - maxToolExecutionsPerHour: 1000, - } + constructor(options: McpClientOptions) { + this.config = options.config + this.securityPolicy = options.securityPolicy ?? { + requireConsent: true, + auditLevel: 'basic', + maxToolExecutionsPerHour: 1000, } + this.onToolsChanged = options.onToolsChanged + const resolvedIP = options.resolvedIP this.connectionStatus = { connected: false } @@ -90,6 +72,7 @@ export class McpClient { requestInit: { headers: this.config.headers, }, + ...(resolvedIP ? { fetch: createMcpPinnedFetch(resolvedIP) } : {}), }) this.client = new Client( diff --git a/apps/sim/lib/mcp/connection-manager.ts b/apps/sim/lib/mcp/connection-manager.ts index 3d6627be57b..a150b194a87 100644 --- a/apps/sim/lib/mcp/connection-manager.ts +++ b/apps/sim/lib/mcp/connection-manager.ts @@ -71,7 +71,8 @@ export class McpConnectionManager { async connect( config: McpServerConfig, userId: string, - workspaceId: string + workspaceId: string, + resolvedIP?: string | null ): Promise<{ supportsListChanged: boolean }> { if (this.disposed) { logger.warn('Connection manager is disposed, ignoring connect request') @@ -106,6 +107,7 @@ export class McpConnectionManager { maxToolExecutionsPerHour: 1000, }, onToolsChanged, + resolvedIP: resolvedIP ?? undefined, }) try { diff --git a/apps/sim/lib/mcp/domain-check.test.ts b/apps/sim/lib/mcp/domain-check.test.ts index 6cc76716ca0..ff559caa8cf 100644 --- a/apps/sim/lib/mcp/domain-check.test.ts +++ b/apps/sim/lib/mcp/domain-check.test.ts @@ -4,13 +4,17 @@ import { inputValidationMock, inputValidationMockFns } from '@sim/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockGetAllowedMcpDomainsFromEnv, mockDnsLookup } = vi.hoisted(() => ({ +const { mockGetAllowedMcpDomainsFromEnv, mockDnsLookup, hostedFlag } = vi.hoisted(() => ({ mockGetAllowedMcpDomainsFromEnv: vi.fn<() => string[] | null>(), mockDnsLookup: vi.fn(), + hostedFlag: { value: false }, })) vi.mock('@/lib/core/config/feature-flags', () => ({ getAllowedMcpDomainsFromEnv: mockGetAllowedMcpDomainsFromEnv, + get isHosted() { + return hostedFlag.value + }, })) vi.mock('@/lib/core/security/input-validation.server', () => inputValidationMock) @@ -331,41 +335,44 @@ describe('validateMcpServerSsrf', () => { beforeEach(() => { vi.clearAllMocks() mockGetAllowedMcpDomainsFromEnv.mockReturnValue(null) + hostedFlag.value = false }) - it('does nothing for undefined URL', async () => { - await expect(validateMcpServerSsrf(undefined)).resolves.toBeUndefined() + it('returns null for undefined URL', async () => { + await expect(validateMcpServerSsrf(undefined)).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('skips validation for env var URLs', async () => { - await expect(validateMcpServerSsrf('{{MCP_SERVER_URL}}')).resolves.toBeUndefined() + it('returns null and skips validation for env var URLs', async () => { + await expect(validateMcpServerSsrf('{{MCP_SERVER_URL}}')).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('skips validation for URLs with env var in hostname', async () => { - await expect(validateMcpServerSsrf('https://{{MCP_HOST}}/mcp')).resolves.toBeUndefined() + it('returns null and skips validation for URLs with env var in hostname', async () => { + await expect(validateMcpServerSsrf('https://{{MCP_HOST}}/mcp')).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('allows localhost URLs without DNS lookup', async () => { - await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeUndefined() + it('returns null for localhost URLs without DNS lookup', async () => { + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('allows 127.0.0.1 URLs without DNS lookup', async () => { - await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).resolves.toBeUndefined() + it('returns null for 127.0.0.1 literal without DNS lookup', async () => { + await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('allows URLs that resolve to public IPs', async () => { + it('returns resolved IP for URLs that resolve to public IPs', async () => { mockDnsLookup.mockResolvedValue({ address: '93.184.216.34' }) - await expect(validateMcpServerSsrf('https://example.com/mcp')).resolves.toBeUndefined() + await expect(validateMcpServerSsrf('https://example.com/mcp')).resolves.toBe('93.184.216.34') }) - it('allows HTTP URLs on non-localhost hosts', async () => { + it('returns resolved IP for HTTP URLs on non-localhost hosts', async () => { mockDnsLookup.mockResolvedValue({ address: '93.184.216.34' }) - await expect(validateMcpServerSsrf('http://example.com:3000/mcp')).resolves.toBeUndefined() + await expect(validateMcpServerSsrf('http://example.com:3000/mcp')).resolves.toBe( + '93.184.216.34' + ) }) it('throws McpSsrfError for cloud metadata IP literal', async () => { @@ -402,21 +409,97 @@ describe('validateMcpServerSsrf', () => { ) }) - it('allows URLs resolving to loopback (localhost alias)', async () => { + it('returns resolved IP for URLs resolving to loopback on self-hosted (localhost alias)', async () => { mockDnsLookup.mockResolvedValue({ address: '127.0.0.1' }) - await expect(validateMcpServerSsrf('http://my-local-alias:3000/mcp')).resolves.toBeUndefined() + await expect(validateMcpServerSsrf('http://my-local-alias:3000/mcp')).resolves.toBe('127.0.0.1') }) it('throws for malformed URLs', async () => { await expect(validateMcpServerSsrf('not-a-url')).rejects.toThrow(McpSsrfError) }) + describe('hosted environment', () => { + beforeEach(() => { + hostedFlag.value = true + }) + + it('rejects localhost URLs on hosted', async () => { + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('rejects 127.0.0.1 URLs on hosted', async () => { + await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('rejects [::1] URLs on hosted', async () => { + await expect(validateMcpServerSsrf('http://[::1]:8080/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('rejects URLs resolving to loopback on hosted', async () => { + mockDnsLookup.mockResolvedValue({ address: '127.0.0.1' }) + await expect(validateMcpServerSsrf('http://my-local-alias:3000/mcp')).rejects.toThrow( + McpSsrfError + ) + }) + + it('returns resolved IP for public IP resolutions on hosted', async () => { + mockDnsLookup.mockResolvedValue({ address: '93.184.216.34' }) + await expect(validateMcpServerSsrf('https://example.com/mcp')).resolves.toBe('93.184.216.34') + }) + + it('skips loopback check on hosted when allowlist is configured', async () => { + mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['localhost']) + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeNull() + }) + + it('still blocks RFC-1918 IP literals on hosted (regression)', async () => { + await expect(validateMcpServerSsrf('http://10.0.0.1/mcp')).rejects.toThrow(McpSsrfError) + await expect(validateMcpServerSsrf('http://192.168.1.1/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('still blocks cloud metadata IP on hosted (regression)', async () => { + await expect( + validateMcpServerSsrf('http://169.254.169.254/latest/meta-data/') + ).rejects.toThrow(McpSsrfError) + }) + + it('still blocks DNS resolutions to private IPs on hosted (regression)', async () => { + mockDnsLookup.mockResolvedValue({ address: '10.0.0.5' }) + await expect(validateMcpServerSsrf('https://internal.corp/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('still skips env var hostnames on hosted', async () => { + await expect(validateMcpServerSsrf('{{MCP_SERVER_URL}}')).resolves.toBeNull() + await expect(validateMcpServerSsrf('https://{{MCP_HOST}}/mcp')).resolves.toBeNull() + expect(mockDnsLookup).not.toHaveBeenCalled() + }) + }) + + describe('self-hosted environment (regression)', () => { + beforeEach(() => { + hostedFlag.value = false + }) + + it('still allows localhost URLs (returns null, no pinning needed)', async () => { + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeNull() + }) + + it('still allows 127.0.0.1 URLs (returns null, no pinning needed)', async () => { + await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).resolves.toBeNull() + }) + + it('returns resolved loopback IP for DNS aliases (caller pins)', async () => { + mockDnsLookup.mockResolvedValue({ address: '127.0.0.1' }) + await expect(validateMcpServerSsrf('http://my-local-alias/mcp')).resolves.toBe('127.0.0.1') + }) + }) + it('skips all checks when ALLOWED_MCP_DOMAINS is configured', async () => { mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['internal.corp']) - await expect(validateMcpServerSsrf('http://10.0.0.1/mcp')).resolves.toBeUndefined() + await expect(validateMcpServerSsrf('http://10.0.0.1/mcp')).resolves.toBeNull() await expect( validateMcpServerSsrf('http://169.254.169.254/latest/meta-data/') - ).resolves.toBeUndefined() + ).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) }) diff --git a/apps/sim/lib/mcp/domain-check.ts b/apps/sim/lib/mcp/domain-check.ts index 83ec36c69f5..9e57b23c7f4 100644 --- a/apps/sim/lib/mcp/domain-check.ts +++ b/apps/sim/lib/mcp/domain-check.ts @@ -2,7 +2,7 @@ import dns from 'dns/promises' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' import * as ipaddr from 'ipaddr.js' -import { getAllowedMcpDomainsFromEnv } from '@/lib/core/config/feature-flags' +import { getAllowedMcpDomainsFromEnv, isHosted } from '@/lib/core/config/feature-flags' import { isPrivateOrReservedIP } from '@/lib/core/security/input-validation.server' import { createEnvVarPattern } from '@/executor/utils/reference-validation' @@ -133,16 +133,25 @@ function isLocalhostHostname(hostname: string): boolean { * Does NOT enforce protocol (HTTP is allowed) or block service ports — MCP * servers legitimately run on HTTP and on arbitrary ports. * - * Localhost/loopback is always allowed for local dev MCP servers. + * Localhost/loopback is allowed for local dev MCP servers in self-hosted + * deployments, but blocked on the hosted environment (sim.ai) where users + * must not be able to reach the server's own loopback interface. * URLs with env var references in the hostname are skipped — they will be * validated after resolution at execution time. * + * Returns the resolved IP address when DNS resolution was performed (so the + * caller can pin subsequent connections to that IP and prevent DNS-rebinding + * TOCTOU attacks). Returns null in cases where pinning is unnecessary or + * impossible: no URL, allowlist-only mode, env-var hostnames (validated later), + * IP literals (no DNS to rebind), and localhost on self-hosted (no rebinding + * risk against a fixed loopback). + * * @throws McpSsrfError if the URL resolves to a blocked IP address */ -export async function validateMcpServerSsrf(url: string | undefined): Promise { - if (!url) return - if (getAllowedMcpDomainsFromEnv() !== null) return - if (hasEnvVarInHostname(url)) return +export async function validateMcpServerSsrf(url: string | undefined): Promise { + if (!url) return null + if (getAllowedMcpDomainsFromEnv() !== null) return null + if (hasEnvVarInHostname(url)) return null let hostname: string try { @@ -154,28 +163,47 @@ export async function validateMcpServerSsrf(url: string | undefined): Promise { + const capturedAgentOptions: unknown[] = [] + class MockAgent { + constructor(options: unknown) { + capturedAgentOptions.push(options) + } + } + return { + mockAgent: MockAgent, + mockCreatePinnedLookup: vi.fn(), + mockUndiciFetch: vi.fn(), + capturedAgentOptions, + } + } +) + +vi.mock('undici', () => ({ Agent: mockAgent, fetch: mockUndiciFetch })) +vi.mock('@/lib/core/security/input-validation.server', () => ({ + createPinnedLookup: mockCreatePinnedLookup, +})) + +import { createMcpPinnedFetch } from '@/lib/mcp/pinned-fetch' + +describe('createMcpPinnedFetch', () => { + beforeEach(() => { + vi.clearAllMocks() + capturedAgentOptions.length = 0 + mockCreatePinnedLookup.mockReturnValue('pinned-lookup-fn') + mockUndiciFetch.mockResolvedValue(new Response('ok')) + }) + + it('builds an undici Agent with the pinned lookup for the resolved IP', () => { + createMcpPinnedFetch('203.0.113.10') + expect(mockCreatePinnedLookup).toHaveBeenCalledWith('203.0.113.10') + expect(capturedAgentOptions).toHaveLength(1) + expect(capturedAgentOptions[0]).toEqual({ connect: { lookup: 'pinned-lookup-fn' } }) + }) + + it('forwards the dispatcher on every fetch call', async () => { + const fetchLike = createMcpPinnedFetch('203.0.113.10') + await fetchLike('https://example.com/mcp', { method: 'POST' }) + expect(mockUndiciFetch).toHaveBeenCalledTimes(1) + const [url, init] = mockUndiciFetch.mock.calls[0] + expect(url).toBe('https://example.com/mcp') + expect((init as { dispatcher?: unknown }).dispatcher).toBeInstanceOf(mockAgent) + expect((init as { method?: string }).method).toBe('POST') + }) + + it('preserves caller-provided init options (headers, signal)', async () => { + const fetchLike = createMcpPinnedFetch('203.0.113.10') + const controller = new AbortController() + await fetchLike('https://example.com/mcp', { + method: 'GET', + headers: { 'x-test': '1' }, + signal: controller.signal, + }) + const init = mockUndiciFetch.mock.calls[0][1] as RequestInit & { dispatcher?: unknown } + expect(init.headers).toEqual({ 'x-test': '1' }) + expect(init.signal).toBe(controller.signal) + expect(init.dispatcher).toBeInstanceOf(mockAgent) + }) + + it('handles undefined init gracefully', async () => { + const fetchLike = createMcpPinnedFetch('203.0.113.10') + await fetchLike('https://example.com/mcp') + const init = mockUndiciFetch.mock.calls[0][1] as { dispatcher?: unknown } + expect(init.dispatcher).toBeInstanceOf(mockAgent) + }) + + it('reuses the same dispatcher across calls (one Agent per fetch instance)', async () => { + const fetchLike = createMcpPinnedFetch('203.0.113.10') + await fetchLike('https://example.com/a') + await fetchLike('https://example.com/b') + expect(capturedAgentOptions).toHaveLength(1) + const d1 = (mockUndiciFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher + const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher + expect(d1).toBe(d2) + }) +}) diff --git a/apps/sim/lib/mcp/pinned-fetch.ts b/apps/sim/lib/mcp/pinned-fetch.ts new file mode 100644 index 00000000000..798de5710e6 --- /dev/null +++ b/apps/sim/lib/mcp/pinned-fetch.ts @@ -0,0 +1,38 @@ +import type { FetchLike } from '@modelcontextprotocol/sdk/shared/transport.js' +import { Agent, type RequestInit as UndiciRequestInit, fetch as undiciFetch } from 'undici' +import { createPinnedLookup } from '@/lib/core/security/input-validation.server' + +/** + * Creates a FetchLike that pins all outbound HTTP connections to a pre-resolved + * IP address. Used by the MCP transport to prevent DNS-rebinding (TOCTOU) + * attacks: validation performs DNS once and confirms the IP is allowed; this + * fetch then forces every subsequent request (initial POST, SSE GET, redirects) + * to use that same IP, regardless of what the hostname now resolves to. + * + * Uses undici's `fetch` directly so the `dispatcher` option is part of the + * real type contract — not a cast that would silently break if a future + * runtime swapped out the implementation. + * + * The original hostname is preserved on the request so TLS SNI and the Host + * header continue to match the certificate. + */ +export function createMcpPinnedFetch(resolvedIP: string): FetchLike { + const dispatcher = new Agent({ + connect: { lookup: createPinnedLookup(resolvedIP) }, + }) + + return (async (url, init) => { + // DOM `RequestInit` and undici's `RequestInit` are structurally compatible + // at runtime (Node's global fetch IS undici) but differ in TS types. + // Cast the init through unknown to bridge the typing without losing the + // critical `dispatcher` typing on the call itself. + const undiciInit: UndiciRequestInit = { + // double-cast-allowed: DOM RequestInit and undici RequestInit are structurally compatible at runtime (Node's global fetch IS undici) but the TS types differ + ...(init as unknown as UndiciRequestInit), + dispatcher, + } + const response = await undiciFetch(url as string | URL, undiciInit) + // double-cast-allowed: undici Response and DOM Response are structurally compatible at runtime; bridging the types is required to satisfy the FetchLike contract + return response as unknown as Response + }) satisfies FetchLike +} diff --git a/apps/sim/lib/mcp/service.ts b/apps/sim/lib/mcp/service.ts index 4ec764c3d41..7838f682822 100644 --- a/apps/sim/lib/mcp/service.ts +++ b/apps/sim/lib/mcp/service.ts @@ -69,13 +69,13 @@ class McpService { config: McpServerConfig, userId: string, workspaceId?: string - ): Promise { + ): Promise<{ config: McpServerConfig; resolvedIP: string | null }> { const { config: resolvedConfig } = await resolveMcpConfigEnvVars(config, userId, workspaceId, { strict: true, }) validateMcpDomain(resolvedConfig.url) - await validateMcpServerSsrf(resolvedConfig.url) - return resolvedConfig + const resolvedIP = await validateMcpServerSsrf(resolvedConfig.url) + return { config: resolvedConfig, resolvedIP } } /** @@ -156,7 +156,10 @@ class McpService { /** * Create and connect to an MCP client */ - private async createClient(config: McpServerConfig): Promise { + private async createClient( + config: McpServerConfig, + resolvedIP: string | null + ): Promise { const securityPolicy = { requireConsent: true, auditLevel: 'basic' as const, @@ -164,7 +167,11 @@ class McpService { allowedOrigins: config.url ? [new URL(config.url).origin] : undefined, } - const client = new McpClient(config, securityPolicy) + const client = new McpClient({ + config, + securityPolicy, + resolvedIP: resolvedIP ?? undefined, + }) await client.connect() return client } @@ -194,11 +201,15 @@ class McpService { throw new Error(`Server ${serverId} not found or not accessible`) } - const resolvedConfig = await this.resolveConfigEnvVars(config, userId, workspaceId) + const { config: resolvedConfig, resolvedIP } = await this.resolveConfigEnvVars( + config, + userId, + workspaceId + ) if (extraHeaders && Object.keys(extraHeaders).length > 0) { resolvedConfig.headers = { ...resolvedConfig.headers, ...extraHeaders } } - const client = await this.createClient(resolvedConfig) + const client = await this.createClient(resolvedConfig, resolvedIP) try { const result = await client.callTool(toolCall) @@ -348,14 +359,18 @@ class McpService { const allTools: McpTool[] = [] const results = await Promise.allSettled( servers.map(async (config) => { - const resolvedConfig = await this.resolveConfigEnvVars(config, userId, workspaceId) - const client = await this.createClient(resolvedConfig) + const { config: resolvedConfig, resolvedIP } = await this.resolveConfigEnvVars( + config, + userId, + workspaceId + ) + const client = await this.createClient(resolvedConfig, resolvedIP) try { const tools = await client.listTools() logger.debug( `[${requestId}] Discovered ${tools.length} tools from server ${config.name}` ) - return { serverId: config.id, tools, resolvedConfig } + return { serverId: config.id, tools, resolvedConfig, resolvedIP } } finally { await client.disconnect() } @@ -394,13 +409,15 @@ class McpService { if (mcpConnectionManager) { for (const [index, result] of results.entries()) { if (result.status === 'fulfilled') { - const { resolvedConfig } = result.value - mcpConnectionManager.connect(resolvedConfig, userId, workspaceId).catch((err) => { - logger.warn( - `[${requestId}] Persistent connection failed for ${servers[index].name}:`, - err - ) - }) + const { resolvedConfig, resolvedIP } = result.value + mcpConnectionManager + .connect(resolvedConfig, userId, workspaceId, resolvedIP) + .catch((err) => { + logger.warn( + `[${requestId}] Persistent connection failed for ${servers[index].name}:`, + err + ) + }) } } } @@ -450,8 +467,12 @@ class McpService { throw new Error(`Server ${serverId} not found or not accessible`) } - const resolvedConfig = await this.resolveConfigEnvVars(config, userId, workspaceId) - const client = await this.createClient(resolvedConfig) + const { config: resolvedConfig, resolvedIP } = await this.resolveConfigEnvVars( + config, + userId, + workspaceId + ) + const client = await this.createClient(resolvedConfig, resolvedIP) try { const tools = await client.listTools() @@ -490,8 +511,12 @@ class McpService { for (const config of servers) { try { - const resolvedConfig = await this.resolveConfigEnvVars(config, userId, workspaceId) - const client = await this.createClient(resolvedConfig) + const { config: resolvedConfig, resolvedIP } = await this.resolveConfigEnvVars( + config, + userId, + workspaceId + ) + const client = await this.createClient(resolvedConfig, resolvedIP) const tools = await client.listTools() await client.disconnect() diff --git a/apps/sim/lib/mcp/types.ts b/apps/sim/lib/mcp/types.ts index db9ac11fd0a..f4f5c939efd 100644 --- a/apps/sim/lib/mcp/types.ts +++ b/apps/sim/lib/mcp/types.ts @@ -161,6 +161,14 @@ export interface McpClientOptions { config: McpServerConfig securityPolicy?: McpSecurityPolicy onToolsChanged?: McpToolsChangedCallback + /** + * Pre-resolved IP address to pin all transport HTTP connections to. When + * set, the SDK transport uses a custom fetch backed by an undici Agent with + * a fixed DNS lookup, preventing DNS-rebinding (TOCTOU) attacks between + * URL validation and connection. Should be supplied by callers that have + * just validated the URL via `validateMcpServerSsrf`. + */ + resolvedIP?: string } /** diff --git a/apps/sim/package.json b/apps/sim/package.json index 71521b166bb..f175df18d51 100644 --- a/apps/sim/package.json +++ b/apps/sim/package.json @@ -196,6 +196,7 @@ "three": "0.177.0", "tldts": "7.0.30", "twilio": "5.9.0", + "undici": "7.25.0", "unified": "11.0.5", "unpdf": "1.4.0", "xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz", diff --git a/apps/sim/tools/agiloft/utils.server.ts b/apps/sim/tools/agiloft/utils.server.ts new file mode 100644 index 00000000000..3aaa0c62b71 --- /dev/null +++ b/apps/sim/tools/agiloft/utils.server.ts @@ -0,0 +1,79 @@ +import { createLogger } from '@sim/logger' +import { + type SecureFetchResponse, + secureFetchWithPinnedIP, + validateUrlWithDNS, +} from '@/lib/core/security/input-validation.server' +import type { AgiloftBaseParams } from '@/tools/agiloft/types' + +const logger = createLogger('AgiloftAuthServer') + +/** + * Validates the Agiloft instance URL and resolves its DNS once, returning the + * resolved IP so subsequent requests can pin to it. This prevents DNS-rebinding + * (TOCTOU) SSRF where the hostname could resolve to a private IP on a later + * lookup. Server-only — uses node:dns/promises. + */ +export async function resolveAgiloftInstance(instanceUrl: string): Promise { + const validation = await validateUrlWithDNS(instanceUrl, 'instanceUrl') + if (!validation.isValid || !validation.resolvedIP) { + throw new Error(validation.error || 'Invalid Agiloft instance URL') + } + return validation.resolvedIP +} + +/** + * DNS-pinned variant of agiloftLogin. Requires a pre-resolved IP so the + * connection cannot be steered to a different host between validation and + * the actual TCP connection. + */ +export async function agiloftLoginPinned( + params: AgiloftBaseParams, + resolvedIP: string +): Promise { + const base = params.instanceUrl.replace(/\/$/, '') + const kb = encodeURIComponent(params.knowledgeBase) + const login = encodeURIComponent(params.login) + const password = encodeURIComponent(params.password) + + const url = `${base}/ewws/EWLogin?$KB=${kb}&$login=${login}&$password=${password}` + const response = await secureFetchWithPinnedIP(url, resolvedIP, { method: 'POST' }) + + if (!response.ok) { + const errorText = await response.text() + throw new Error(`Agiloft login failed: ${response.status} - ${errorText}`) + } + + const data = (await response.json()) as { access_token?: string } + const token = data.access_token + + if (!token) { + throw new Error('Agiloft login did not return an access token') + } + + return token +} + +/** + * DNS-pinned variant of agiloftLogout. Best-effort — failures are logged but + * not thrown. + */ +export async function agiloftLogoutPinned( + instanceUrl: string, + knowledgeBase: string, + token: string, + resolvedIP: string +): Promise { + try { + const base = instanceUrl.replace(/\/$/, '') + const kb = encodeURIComponent(knowledgeBase) + await secureFetchWithPinnedIP(`${base}/ewws/EWLogout?$KB=${kb}`, resolvedIP, { + method: 'POST', + headers: { Authorization: `Bearer ${token}` }, + }) + } catch (error) { + logger.warn('Agiloft logout failed (best-effort)', { error }) + } +} + +export type { SecureFetchResponse } diff --git a/apps/sim/tools/agiloft/utils.test.ts b/apps/sim/tools/agiloft/utils.test.ts new file mode 100644 index 00000000000..b80eb2a33ba --- /dev/null +++ b/apps/sim/tools/agiloft/utils.test.ts @@ -0,0 +1,136 @@ +/** + * @vitest-environment node + */ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { executeAgiloftRequest } from '@/tools/agiloft/utils' + +const baseParams = { + instanceUrl: 'https://example.agiloft.com', + knowledgeBase: 'demo', + login: 'admin', + password: 'secret', + table: 'contracts', +} + +function mockFetchResponse(body: { ok?: boolean; status?: number; json?: unknown; text?: string }) { + return { + ok: body.ok ?? true, + status: body.status ?? 200, + statusText: '', + headers: new Headers(), + text: async () => body.text ?? '', + json: async () => body.json ?? {}, + } as unknown as Response +} + +const fetchSpy = vi.fn() + +beforeEach(() => { + fetchSpy.mockReset() + vi.stubGlobal('fetch', fetchSpy) +}) + +afterEach(() => { + vi.unstubAllGlobals() +}) + +describe('executeAgiloftRequest', () => { + it('logs in, runs the operation with the bearer token, then logs out', async () => { + fetchSpy + .mockResolvedValueOnce(mockFetchResponse({ json: { access_token: 'tok-1' } })) + .mockResolvedValueOnce(mockFetchResponse({ json: { id: 42, fields: { name: 'foo' } } })) + .mockResolvedValueOnce(mockFetchResponse({})) + + const result = await executeAgiloftRequest( + baseParams, + (base) => ({ + url: `${base}/ewws/REST/demo/contracts/42`, + method: 'GET', + headers: { Accept: 'application/json' }, + }), + async (response) => { + const data = (await response.json()) as { id: number; fields: Record } + return { + success: response.ok, + output: { id: String(data.id), fields: data.fields }, + } + } + ) + + expect(result).toEqual({ success: true, output: { id: '42', fields: { name: 'foo' } } }) + + const calls = fetchSpy.mock.calls + expect(calls).toHaveLength(3) + expect(calls[0][0]).toBe( + 'https://example.agiloft.com/ewws/EWLogin?$KB=demo&$login=admin&$password=secret' + ) + expect(calls[1][0]).toBe('https://example.agiloft.com/ewws/REST/demo/contracts/42') + expect(calls[1][1]).toMatchObject({ + method: 'GET', + headers: { Accept: 'application/json', Authorization: 'Bearer tok-1' }, + }) + expect(calls[2][0]).toBe('https://example.agiloft.com/ewws/EWLogout?$KB=demo') + }) + + it('still calls logout when the operation throws', async () => { + fetchSpy + .mockResolvedValueOnce(mockFetchResponse({ json: { access_token: 'tok-2' } })) + .mockResolvedValueOnce(mockFetchResponse({ ok: false, status: 500 })) + .mockResolvedValueOnce(mockFetchResponse({})) + + await expect( + executeAgiloftRequest( + baseParams, + (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), + async (response) => { + if (!response.ok) throw new Error('operation failed') + return { success: true, output: {} } + } + ) + ).rejects.toThrow('operation failed') + + expect(fetchSpy).toHaveBeenCalledTimes(3) + expect(fetchSpy.mock.calls[2][0]).toContain('/ewws/EWLogout') + }) + + it('swallows logout failures (best-effort)', async () => { + fetchSpy + .mockResolvedValueOnce(mockFetchResponse({ json: { access_token: 'tok-3' } })) + .mockResolvedValueOnce(mockFetchResponse({ json: { ok: true } })) + .mockRejectedValueOnce(new Error('logout network error')) + + const result = await executeAgiloftRequest( + baseParams, + (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), + async () => ({ success: true, output: {} }) + ) + + expect(result.success).toBe(true) + }) + + it('throws when login does not return an access token', async () => { + fetchSpy.mockResolvedValueOnce(mockFetchResponse({ json: {} })) + + await expect( + executeAgiloftRequest( + baseParams, + (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), + async () => ({ success: true, output: {} }) + ) + ).rejects.toThrow('Agiloft login did not return an access token') + + expect(fetchSpy).toHaveBeenCalledTimes(1) + }) + + it('rejects an instance URL that fails synchronous URL validation', async () => { + await expect( + executeAgiloftRequest( + { ...baseParams, instanceUrl: 'not-a-valid-url' }, + (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), + async () => ({ success: true, output: {} }) + ) + ).rejects.toThrow(/Invalid Agiloft instance URL/) + + expect(fetchSpy).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/tools/agiloft/utils.ts b/apps/sim/tools/agiloft/utils.ts index 47184deb5fb..811187ab833 100644 --- a/apps/sim/tools/agiloft/utils.ts +++ b/apps/sim/tools/agiloft/utils.ts @@ -47,7 +47,7 @@ async function agiloftLogin(params: AgiloftBaseParams): Promise { throw new Error(`Agiloft login failed: ${response.status} - ${errorText}`) } - const data = await response.json() + const data = (await response.json()) as { access_token?: string } const token = data.access_token if (!token) { diff --git a/apps/sim/tools/grafana/update_alert_rule.ts b/apps/sim/tools/grafana/update_alert_rule.ts index 9ca23bff773..19f2bf8164d 100644 --- a/apps/sim/tools/grafana/update_alert_rule.ts +++ b/apps/sim/tools/grafana/update_alert_rule.ts @@ -1,3 +1,4 @@ +import { validateExternalUrl } from '@/lib/core/security/input-validation' import { ALERT_RULE_OUTPUT_FIELDS, type GrafanaUpdateAlertRuleParams } from '@/tools/grafana/types' import { mapAlertRule } from '@/tools/grafana/utils' import type { ToolConfig, ToolResponse } from '@/tools/types' @@ -269,6 +270,15 @@ export const updateAlertRuleTool: ToolConfig return { success: true, output: mapAlertRule(data) } }, diff --git a/apps/sim/tools/grafana/update_dashboard.ts b/apps/sim/tools/grafana/update_dashboard.ts index 23449f36830..99a5f4352d3 100644 --- a/apps/sim/tools/grafana/update_dashboard.ts +++ b/apps/sim/tools/grafana/update_dashboard.ts @@ -1,3 +1,4 @@ +import { validateExternalUrl } from '@/lib/core/security/input-validation' import type { GrafanaUpdateDashboardParams } from '@/tools/grafana/types' import type { ToolConfig, ToolResponse } from '@/tools/types' @@ -183,6 +184,15 @@ export const updateDashboardTool: ToolConfig