Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions agentex/src/api/routes/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from src.api.schemas.authorization_types import (
AgentexResourceType,
AuthorizedOperationType,
TaskChildResourceType,
)
from src.api.schemas.events import Event
from src.domain.services.authorization_service import DAuthorizationService
from src.domain.use_cases.events_use_case import DEventUseCase
from src.utils.authorization_shortcuts import DAuthorizedId, DAuthorizedQuery
from src.utils.agent_authorization import check_agent_or_collapse_to_404
from src.utils.authorization_shortcuts import DAuthorizedQuery
from src.utils.logging import make_logger

logger = make_logger(__name__)
Expand All @@ -20,10 +21,15 @@
response_model=Event,
)
async def get_event(
event_id: DAuthorizedId(TaskChildResourceType.event, AuthorizedOperationType.read),
event_id: str,
event_use_case: DEventUseCase,
authorization: DAuthorizationService,
) -> Event:
# Events delegate authz to the parent agent.
event_entity = await event_use_case.get(event_id)
await check_agent_or_collapse_to_404(
authorization, event_entity.agent_id, AuthorizedOperationType.read
)
return Event.model_validate(event_entity)


Expand Down
1 change: 0 additions & 1 deletion agentex/src/api/schemas/authorization_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class AgentexResourceType(StrEnum):
class TaskChildResourceType(StrEnum):
"""Resources that inherit permissions from their parent task."""

event = "event"
state = "state"
message = "message"

Expand Down
26 changes: 26 additions & 0 deletions agentex/src/utils/agent_authorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from src.adapters.authorization.exceptions import AuthorizationError
from src.adapters.crud_store.exceptions import ItemDoesNotExist
from src.api.schemas.authorization_types import (
AgentexResource,
AuthorizedOperationType,
)


async def check_agent_or_collapse_to_404(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldnt this alr be implemented somwhere? How do we make this check for tasks?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it currently lives in #249

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can approve but rebase on that one when merged then so we dont have dupe code

authorization,
agent_id: str,
operation: AuthorizedOperationType,
) -> None:
"""Check an agent resource; collapse any denial to 404 to avoid leaking
cross-tenant existence. Mirrors ``check_task_or_collapse_to_404`` in
``task_authorization.py`` — see that docstring for the full rationale.

TODO(AGX1-290): restore 403/404 split once agents carry tenant scope.
"""
try:
await authorization.check(
resource=AgentexResource.agent(agent_id),
operation=operation,
)
except AuthorizationError:
raise ItemDoesNotExist(f"Item with id '{agent_id}' does not exist.") from None
Comment thread
greptile-apps[bot] marked this conversation as resolved.
9 changes: 1 addition & 8 deletions agentex/src/utils/authorization_shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
TaskChildResourceType,
)
from src.domain.repositories.agent_repository import DAgentRepository
from src.domain.repositories.event_repository import DEventRepository
from src.domain.repositories.task_message_repository import DTaskMessageRepository
from src.domain.repositories.task_repository import DTaskRepository
from src.domain.repositories.task_state_repository import DTaskStateRepository
Expand All @@ -22,14 +21,12 @@
async def _get_parent_task_id(
resource_type: TaskChildResourceType,
resource_id: str,
event_repository: DEventRepository,
state_repository: DTaskStateRepository,
message_repository: DTaskMessageRepository,
) -> str:
"""Get the parent task ID for a child resource."""
"""Get the parent task ID for a task-child resource."""
registry = {
TaskChildResourceType.state: state_repository,
TaskChildResourceType.event: event_repository,
TaskChildResourceType.message: message_repository,
}

Expand All @@ -48,7 +45,6 @@ def DAuthorizedId(

async def _ensure_authorized_id(
authorization: DAuthorizationService,
event_repository: DEventRepository,
state_repository: DTaskStateRepository,
message_repository: DTaskMessageRepository,
resource_id: str = Path(..., alias=param_name),
Expand All @@ -60,7 +56,6 @@ async def _ensure_authorized_id(
task_id = await _get_parent_task_id(
resource_type,
resource_id,
event_repository,
state_repository,
message_repository,
)
Expand Down Expand Up @@ -103,7 +98,6 @@ def DAuthorizedQuery(

async def _ensure_authorized_query(
authorization: DAuthorizationService,
event_repository: DEventRepository,
state_repository: DTaskStateRepository,
message_repository: DTaskMessageRepository,
resource_id: str = Query(..., alias=param_name, description=description),
Expand All @@ -115,7 +109,6 @@ async def _ensure_authorized_query(
task_id = await _get_parent_task_id(
resource_type,
resource_id,
event_repository,
state_repository,
message_repository,
)
Expand Down
268 changes: 268 additions & 0 deletions agentex/tests/integration/api/events/test_events_authz_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
"""AGX1-244: event routes delegate authz to the parent agent."""

from typing import Any
from unittest.mock import patch

import pytest
import pytest_asyncio
from src.adapters.authorization.exceptions import AuthorizationError
from src.api.schemas.authorization_types import AgentexResourceType
from src.domain.entities.agents import ACPType, AgentEntity
from src.domain.entities.task_messages import TextContentEntity
from src.domain.entities.tasks import TaskEntity, TaskStatus
from src.utils.ids import orm_id

MOCK_PRINCIPAL_CONTEXT = {
"account_id": "test-account-id",
"user_id": "test-user-id",
}


def _mock_post_factory(
*,
deny_agent_ids: set[str] | None = None,
deny_task_ids: set[str] | None = None,
):
"""Return a side_effect that allows authn + authz, except for agents/tasks
listed in ``deny_*_ids`` for which ``/v1/authz/check`` raises
AuthorizationError.
"""
deny_agent_ids = deny_agent_ids or set()
deny_task_ids = deny_task_ids or set()

async def _side_effect(
base_url: str = "http://test.com",
path: str = "/test",
*,
json: dict | None = None,
headers: dict[str, str] | None = None,
) -> dict[str, Any]:
if path == "/v1/authn":
return MOCK_PRINCIPAL_CONTEXT
if path == "/v1/authz/check":
assert json is not None
resource = json.get("resource", {})
if (
resource.get("type") == AgentexResourceType.agent.value
and resource.get("selector") in deny_agent_ids
):
raise AuthorizationError("Denied by mock")
if (
resource.get("type") == AgentexResourceType.task.value
and resource.get("selector") in deny_task_ids
):
raise AuthorizationError("Denied by mock")
return {"allowed": True}
if path == "/v1/authz/search":
return {"items": []}
raise Exception(f"Unknown path: {path}")

return _side_effect


@pytest.mark.integration
class TestEventsAuthzAPIIntegration:
"""End-to-end integration tests for event-route authorization."""

@pytest_asyncio.fixture
async def test_agent(self, isolated_repositories):
agent_repo = isolated_repositories["agent_repository"]
agent = AgentEntity(
id=orm_id(),
name="test-authz-agent",
description="Agent for event-authz tests",
acp_url="http://test-acp:8000",
acp_type=ACPType.SYNC,
)
return await agent_repo.create(agent)

@pytest_asyncio.fixture
async def test_task(self, isolated_repositories, test_agent):
task_repo = isolated_repositories["task_repository"]
task = TaskEntity(
id=orm_id(),
name="test-authz-task",
status=TaskStatus.RUNNING,
status_reason="Task for event-authz tests",
)
return await task_repo.create(agent_id=test_agent.id, task=task)

@pytest_asyncio.fixture
async def test_event(self, isolated_repositories, test_agent, test_task):
event_repo = isolated_repositories["event_repository"]
content = TextContentEntity(type="text", author="user", content="hello")
return await event_repo.create(
id=orm_id(),
task_id=test_task.id,
agent_id=test_agent.id,
content=content,
)

@pytest.mark.asyncio
@patch(
"src.api.authentication_middleware.AgentexAuthMiddleware.is_enabled",
return_value=True,
)
@patch(
"src.domain.services.authorization_service.AuthorizationService.is_enabled",
return_value=True,
)
@patch(
"src.utils.http_request_handler.HttpRequestHandler.post_with_error_handling",
side_effect=_mock_post_factory(),
)
async def test_get_event_authorized_returns_200(
self,
post_with_error_handling_mock,
is_enabled_authorization_mock,
is_enabled_mock,
isolated_client,
test_event,
test_agent,
):
response = await isolated_client.get(f"/events/{test_event.id}")
assert response.status_code == 200
assert response.json()["id"] == test_event.id

# One check, on the parent agent (not the task).
check_calls = [
call
for call in post_with_error_handling_mock.call_args_list
if call[0][1] == "/v1/authz/check"
]
assert len(check_calls) == 1
authz_data = check_calls[0][1]["json"]
assert authz_data["resource"]["type"] == AgentexResourceType.agent.value
assert authz_data["resource"]["selector"] == test_agent.id
assert authz_data["operation"] == "read"
assert authz_data["principal"] == MOCK_PRINCIPAL_CONTEXT

@pytest.mark.asyncio
@patch(
"src.api.authentication_middleware.AgentexAuthMiddleware.is_enabled",
return_value=True,
)
@patch(
"src.domain.services.authorization_service.AuthorizationService.is_enabled",
return_value=True,
)
async def test_get_event_unauthorized_returns_404(
self,
is_enabled_authorization_mock,
is_enabled_mock,
isolated_client,
test_event,
test_agent,
):
with patch(
"src.utils.http_request_handler.HttpRequestHandler.post_with_error_handling",
side_effect=_mock_post_factory(deny_agent_ids={test_agent.id}),
):
response = await isolated_client.get(f"/events/{test_event.id}")
# Parent-agent denial collapses to 404.
assert response.status_code == 404

@pytest.mark.asyncio
@patch(
"src.api.authentication_middleware.AgentexAuthMiddleware.is_enabled",
return_value=True,
)
@patch(
"src.domain.services.authorization_service.AuthorizationService.is_enabled",
return_value=True,
)
@patch(
"src.utils.http_request_handler.HttpRequestHandler.post_with_error_handling",
side_effect=_mock_post_factory(),
)
async def test_get_event_nonexistent_returns_404(
self,
post_with_error_handling_mock,
is_enabled_authorization_mock,
is_enabled_mock,
isolated_client,
):
response = await isolated_client.get(f"/events/{orm_id()}")
assert response.status_code == 404
# Event lookup 404s before any authz call fires.
assert not any(
call[0][1] == "/v1/authz/check"
for call in post_with_error_handling_mock.call_args_list
)

@pytest.mark.asyncio
@patch(
"src.api.authentication_middleware.AgentexAuthMiddleware.is_enabled",
return_value=True,
)
@patch(
"src.domain.services.authorization_service.AuthorizationService.is_enabled",
return_value=True,
)
@patch(
"src.utils.http_request_handler.HttpRequestHandler.post_with_error_handling",
side_effect=_mock_post_factory(),
)
async def test_list_events_authorized_returns_200(
self,
post_with_error_handling_mock,
is_enabled_authorization_mock,
is_enabled_mock,
isolated_client,
test_event,
test_agent,
test_task,
):
response = await isolated_client.get(
f"/events?task_id={test_task.id}&agent_id={test_agent.id}"
)
assert response.status_code == 200
body = response.json()
assert any(e["id"] == test_event.id for e in body)

# Two checks: one on the task, one on the agent.
check_calls = [
call
for call in post_with_error_handling_mock.call_args_list
if call[0][1] == "/v1/authz/check"
]
assert len(check_calls) == 2
checked = {
(
call[1]["json"]["resource"]["type"],
call[1]["json"]["resource"]["selector"],
)
for call in check_calls
}
assert (AgentexResourceType.task.value, test_task.id) in checked
assert (AgentexResourceType.agent.value, test_agent.id) in checked
for call in check_calls:
assert call[1]["json"]["operation"] == "read"

@pytest.mark.asyncio
@patch(
"src.api.authentication_middleware.AgentexAuthMiddleware.is_enabled",
return_value=True,
)
@patch(
"src.domain.services.authorization_service.AuthorizationService.is_enabled",
return_value=True,
)
async def test_list_events_unauthorized_agent_returns_403(
self,
is_enabled_authorization_mock,
is_enabled_mock,
isolated_client,
test_event,
test_agent,
test_task,
):
"""Direct-resource denials surface as 403 (convention from #249/#255)."""
with patch(
"src.utils.http_request_handler.HttpRequestHandler.post_with_error_handling",
side_effect=_mock_post_factory(deny_agent_ids={test_agent.id}),
):
response = await isolated_client.get(
f"/events?task_id={test_task.id}&agent_id={test_agent.id}"
)
assert response.status_code == 403
Loading