Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions cecli/tools/_yield.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import logging

from cecli.helpers.threading import ThreadSafeEvent
from cecli.tools.utils.base_tool import BaseTool
from cecli.tools.utils.helpers import ToolError
from cecli.tools.utils.output import color_markers, tool_footer, tool_header
Expand All @@ -17,10 +16,7 @@ class Tool(BaseTool):
"type": "function",
"function": {
"name": "Yield",
"description": (
"Yield control to subagents, to await their results or back to the user,"
" indicating all sub-goals are complete."
),
"description": "Yield control back to the user, indicating all sub-goals are complete.",
"parameters": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -50,6 +46,12 @@ async def execute(cls, coder, **kwargs):
cls.clear_invocation_cache()

if coder:
reject = getattr(coder, "reject_yield", None)
if callable(reject):
blocked = reject(coder, **kwargs)
if blocked:
return blocked

# Check for active child sub-agents and await their tasks before finishing
try:
agent_service = AgentService.get_instance(coder)
Expand All @@ -69,7 +71,7 @@ async def execute(cls, coder, **kwargs):
# the interrupt event, avoiding nested asyncio.wait() calls.
interrupt_event = coder.interrupt_event
if interrupt_event is None:
interrupt_event = ThreadSafeEvent()
interrupt_event = asyncio.Event()

interrupt_task = asyncio.create_task(interrupt_event.wait())
pending = set(active_tasks) | {interrupt_task}
Expand Down
32 changes: 32 additions & 0 deletions tests/tools/test_yield_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Yield guard on implement turns (reject_yield hook)."""

from __future__ import annotations

import asyncio
import unittest

from cecli.tools._yield import Tool


class _CoderStub:
def __init__(self, *, reject_message: str | None = None):
self.reject_yield = (
(lambda _c, **_k: reject_message) if reject_message is not None else None
)
self.agent_finished = False


class TestYieldGuard(unittest.TestCase):
def test_yield_rejected_when_hook_blocks(self):
coder = _CoderStub(
reject_message="Yield rejected: no file edits saved this implement turn."
)

result = asyncio.run(Tool.execute(coder, summary="done"))

self.assertIn("Yield rejected", result)
self.assertFalse(coder.agent_finished)


if __name__ == "__main__":
unittest.main()
Loading