Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 108 additions & 15 deletions apps/sim/app/api/copilot/chat/stop/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,49 @@ const {
mockFrom,
mockWhereSelect,
mockLimit,
mockForUpdate,
mockUpdate,
mockSet,
mockWhereUpdate,
mockReturning,
mockPublishStatusChanged,
mockSql,
} = vi.hoisted(() => ({
mockSelect: vi.fn(),
mockFrom: vi.fn(),
mockWhereSelect: vi.fn(),
mockLimit: vi.fn(),
mockUpdate: vi.fn(),
mockSet: vi.fn(),
mockWhereUpdate: vi.fn(),
mockReturning: vi.fn(),
mockPublishStatusChanged: vi.fn(),
mockSql: vi.fn((strings: TemplateStringsArray, ...values: unknown[]) => ({ strings, values })),
}))
mockTransaction,
} = vi.hoisted(() => {
const mockSelect = vi.fn()
const mockFrom = vi.fn()
const mockWhereSelect = vi.fn()
const mockLimit = vi.fn()
const mockForUpdate = vi.fn()
const mockUpdate = vi.fn()
const mockSet = vi.fn()
const mockWhereUpdate = vi.fn()
const mockReturning = vi.fn()
const mockPublishStatusChanged = vi.fn()
const mockSql = vi.fn((strings: TemplateStringsArray, ...values: unknown[]) => ({
strings,
values,
}))
const mockTransaction = vi.fn(
(callback: (tx: { select: typeof mockSelect; update: typeof mockUpdate }) => unknown) =>
callback({ select: mockSelect, update: mockUpdate })
)

return {
mockSelect,
mockFrom,
mockWhereSelect,
mockLimit,
mockForUpdate,
mockUpdate,
mockSet,
mockWhereUpdate,
mockReturning,
mockPublishStatusChanged,
mockSql,
mockTransaction,
}
})

vi.mock('@sim/db/schema', () => ({
copilotChats: {
Expand All @@ -41,8 +66,7 @@ vi.mock('@sim/db/schema', () => ({

vi.mock('@sim/db', () => ({
db: {
select: mockSelect,
update: mockUpdate,
transaction: mockTransaction,
},
}))

Expand Down Expand Up @@ -78,9 +102,11 @@ describe('copilot chat stop route', () => {
{
workspaceId: 'ws-1',
messages: [{ id: 'stream-1', role: 'user', content: 'hello' }],
conversationId: 'stream-1',
},
])
mockWhereSelect.mockReturnValue({ limit: mockLimit })
mockForUpdate.mockReturnValue({ limit: mockLimit })
mockWhereSelect.mockReturnValue({ for: mockForUpdate })
mockFrom.mockReturnValue({ where: mockWhereSelect })
mockSelect.mockReturnValue({ from: mockFrom })

Expand Down Expand Up @@ -153,4 +179,71 @@ describe('copilot chat stop route', () => {
streamId: 'stream-1',
})
})

it('appends a stopped assistant message if the stream marker was already cleared', async () => {
mockLimit.mockResolvedValueOnce([
{
workspaceId: 'ws-1',
messages: [{ id: 'stream-1', role: 'user', content: 'hello' }],
conversationId: null,
},
])

const response = await POST(
createRequest({
chatId: 'chat-1',
streamId: 'stream-1',
content: 'partial',
})
)

expect(response.status).toBe(200)
expect(await response.json()).toEqual({ success: true })

const setArg = mockSet.mock.calls[0]?.[0]
expect(setArg.messages).toBeTruthy()
const appendedPayload = JSON.parse(setArg.messages.values[1] as string)
expect(appendedPayload[0]).toMatchObject({
role: 'assistant',
content: 'partial',
})

expect(mockPublishStatusChanged).toHaveBeenCalledWith({
workspaceId: 'ws-1',
chatId: 'chat-1',
type: 'completed',
streamId: 'stream-1',
})
})

it('republishes completed status when the assistant was already persisted', async () => {
mockLimit.mockResolvedValueOnce([
{
workspaceId: 'ws-1',
messages: [
{ id: 'stream-1', role: 'user', content: 'hello' },
{ id: 'assistant-1', role: 'assistant', content: 'partial' },
],
conversationId: null,
},
])

const response = await POST(
createRequest({
chatId: 'chat-1',
streamId: 'stream-1',
content: 'partial',
})
)

expect(response.status).toBe(200)
expect(await response.json()).toEqual({ success: true })
expect(mockUpdate).not.toHaveBeenCalled()
expect(mockPublishStatusChanged).toHaveBeenCalledWith({
workspaceId: 'ws-1',
chatId: 'chat-1',
type: 'completed',
streamId: 'stream-1',
})
})
Comment thread
icecrasher321 marked this conversation as resolved.
})
97 changes: 35 additions & 62 deletions apps/sim/app/api/copilot/chat/stop/route.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import { db } from '@sim/db'
import { copilotChats } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { generateId } from '@sim/utils/id'
import { and, eq, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { copilotChatStopContract } from '@/lib/api/contracts/copilot'
import { parseRequest } from '@/lib/api/server'
import { getSession } from '@/lib/auth'
import { normalizeMessage, type PersistedMessage } from '@/lib/copilot/chat/persisted-message'
import { CopilotStopOutcome } from '@/lib/copilot/generated/trace-attribute-values-v1'
import {
normalizeMessage,
type PersistedMessage,
withStoppedContentBlock,
} from '@/lib/copilot/chat/persisted-message'
import { finalizeAssistantTurn } from '@/lib/copilot/chat/terminal-state'
import {
CopilotChatFinalizeOutcome,
CopilotStopOutcome,
} from '@/lib/copilot/generated/trace-attribute-values-v1'
import { TraceAttr } from '@/lib/copilot/generated/trace-attributes-v1'
import { TraceSpan } from '@/lib/copilot/generated/trace-spans-v1'
import { withIncomingGoSpan } from '@/lib/copilot/request/otel'
Expand Down Expand Up @@ -44,81 +49,49 @@ export const POST = withRouteHandler((req: NextRequest) =>
...(requestId ? { [TraceAttr.RequestId]: requestId } : {}),
})

const [row] = await db
.select({
workspaceId: copilotChats.workspaceId,
messages: copilotChats.messages,
})
.from(copilotChats)
.where(and(eq(copilotChats.id, chatId), eq(copilotChats.userId, session.user.id)))
.limit(1)

if (!row) {
span.setAttribute(TraceAttr.CopilotStopOutcome, CopilotStopOutcome.ChatNotFound)
return NextResponse.json({ success: true })
}

const messages: Record<string, unknown>[] = Array.isArray(row.messages) ? row.messages : []
const userIdx = messages.findIndex((message) => message.id === streamId)
const alreadyHasResponse =
userIdx >= 0 &&
userIdx + 1 < messages.length &&
(messages[userIdx + 1] as Record<string, unknown>)?.role === 'assistant'
const canAppendAssistant =
userIdx >= 0 && userIdx === messages.length - 1 && !alreadyHasResponse

const updateWhere = and(
eq(copilotChats.id, chatId),
eq(copilotChats.userId, session.user.id),
eq(copilotChats.conversationId, streamId)
)

const setClause: Record<string, unknown> = {
conversationId: null,
updatedAt: new Date(),
}

const hasContent = content.trim().length > 0
const hasBlocks = Array.isArray(contentBlocks) && contentBlocks.length > 0
const synthesizedStoppedBlocks = hasBlocks
const assistantBlocks = hasBlocks
? contentBlocks
: hasContent
? [{ type: 'text', channel: 'assistant', content }, { type: 'stopped' }]
: [{ type: 'stopped' }]
if (canAppendAssistant) {
const normalized = normalizeMessage({
? [{ type: 'text', channel: 'assistant', content }]
: []
const assistantMessage: PersistedMessage = withStoppedContentBlock(
normalizeMessage({
id: generateId(),
role: 'assistant',
content,
timestamp: new Date().toISOString(),
contentBlocks: synthesizedStoppedBlocks,
// Persist so the UI copy-request-id button survives refetch.
contentBlocks: assistantBlocks,
...(requestId ? { requestId } : {}),
})
const assistantMessage: PersistedMessage = normalized
setClause.messages = sql`${copilotChats.messages} || ${JSON.stringify([assistantMessage])}::jsonb`
}
span.setAttribute(TraceAttr.CopilotStopAppendedAssistant, canAppendAssistant)

const [updated] = await db
.update(copilotChats)
.set(setClause)
.where(updateWhere)
.returning({ workspaceId: copilotChats.workspaceId })
)
const result = await finalizeAssistantTurn({
chatId,
userId: session.user.id,
userMessageId: streamId,
assistantMessage,
streamMarkerPolicy: 'active-or-cleared',
})
span.setAttribute(TraceAttr.CopilotStopAppendedAssistant, result.appendedAssistant)
const stopOutcome = !result.found
? CopilotStopOutcome.ChatNotFound
: result.updated || result.outcome === CopilotChatFinalizeOutcome.AssistantAlreadyPersisted
? CopilotStopOutcome.Persisted
: CopilotStopOutcome.NoMatchingRow
const shouldPublishCompleted =
result.updated || result.outcome === CopilotChatFinalizeOutcome.AssistantAlreadyPersisted

if (updated?.workspaceId) {
if (shouldPublishCompleted && result.workspaceId) {
taskPubSub?.publishStatusChanged({
workspaceId: updated.workspaceId,
workspaceId: result.workspaceId,
chatId,
type: 'completed',
streamId,
})
}

span.setAttribute(
TraceAttr.CopilotStopOutcome,
updated ? CopilotStopOutcome.Persisted : CopilotStopOutcome.NoMatchingRow
)
span.setAttribute(TraceAttr.CopilotStopOutcome, stopOutcome)
return NextResponse.json({ success: true })
} catch (error) {
logger.error('Error stopping chat stream:', error)
Expand Down
39 changes: 39 additions & 0 deletions apps/sim/lib/copilot/chat/persisted-message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,45 @@ export function buildPersistedAssistantMessage(
return message
}

export function withStoppedContentBlock(message: PersistedMessage): PersistedMessage {
const contentBlocks = message.contentBlocks ?? []
const hasAssistantText = contentBlocks.some(
(block) =>
block.type === MothershipStreamV1EventType.text &&
block.channel !== MothershipStreamV1TextChannel.thinking &&
block.content?.trim()
)
if (
contentBlocks.some(
(block) =>
block.type === MothershipStreamV1EventType.complete &&
block.status === MothershipStreamV1CompletionStatus.cancelled
)
) {
return message
}

return normalizeMessage({
...message,
contentBlocks: [
...(hasAssistantText || !message.content.trim()
? []
: [
{
type: MothershipStreamV1EventType.text,
channel: MothershipStreamV1TextChannel.assistant,
content: message.content,
},
]),
...contentBlocks,
{
type: MothershipStreamV1EventType.complete,
status: MothershipStreamV1CompletionStatus.cancelled,
},
],
})
}

export interface UserMessageParams {
id: string
content: string
Expand Down
Loading
Loading