diff --git a/src/dstack/_internal/server/background/scheduled_tasks/probes.py b/src/dstack/_internal/server/background/scheduled_tasks/probes.py index d2c550b5d..2d1821128 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/probes.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/probes.py @@ -12,9 +12,9 @@ from dstack._internal.server.models import InstanceModel, JobModel, ProbeModel from dstack._internal.server.services.jobs import get_job_spec from dstack._internal.server.services.jobs.job_replica_http_client import ( - SSH_CONNECT_TIMEOUT, get_service_replica_client, ) +from dstack._internal.server.services.jobs.job_replica_tunnel import SSH_CONNECT_TIMEOUT from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt from dstack._internal.utils.common import get_current_datetime diff --git a/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py b/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py index bc6f6cffe..1827274a1 100644 --- a/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py +++ b/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py @@ -2,25 +2,14 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from datetime import timedelta from pathlib import Path -from tempfile import TemporaryDirectory from typing import Any import grpc -from dstack._internal.core.services.ssh.tunnel import ( - SSH_DEFAULT_OPTIONS, - IPSocket, - SocketPair, - UnixSocket, -) from dstack._internal.server.models import JobModel -from dstack._internal.server.services.jobs import get_job_spec -from dstack._internal.server.services.ssh import container_ssh_tunnel -from dstack._internal.utils.common import get_or_error +from dstack._internal.server.services.jobs.job_replica_tunnel import get_service_replica_tunnel -SSH_CONNECT_TIMEOUT = timedelta(seconds=10) # Match router_worker_sync HTTP server_info cap (_MAX_SERVER_INFO_RESPONSE_BYTES). _MAX_GRPC_MESSAGE_BYTES = 256 * 1024 _GRPC_CHANNEL_OPTIONS = ( @@ -29,29 +18,20 @@ ) +@asynccontextmanager +async def get_service_replica_grpc_channel_over_uds( + uds_path: Path, +) -> AsyncGenerator[Any, None]: + target = f"unix://{uds_path}" + channel = grpc.aio.insecure_channel(target, options=_GRPC_CHANNEL_OPTIONS) + try: + yield channel + finally: + await channel.close() + + @asynccontextmanager async def get_service_replica_grpc_client(job: JobModel) -> AsyncGenerator[Any, None]: - options = { - **SSH_DEFAULT_OPTIONS, - "ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())), - } - job_spec = get_job_spec(job) - with TemporaryDirectory() as temp_dir: - # Keep the same socket file name as the HTTP helper for consistency. - app_socket_path = (Path(temp_dir) / "replica.sock").absolute() - async with container_ssh_tunnel( - job=job, - forwarded_sockets=[ - SocketPair( - remote=IPSocket("localhost", get_or_error(job_spec.service_port)), - local=UnixSocket(app_socket_path), - ), - ], - options=options, - ): - target = f"unix://{app_socket_path}" - channel = grpc.aio.insecure_channel(target, options=_GRPC_CHANNEL_OPTIONS) - try: - yield channel - finally: - await channel.close() + async with get_service_replica_tunnel(job) as uds_path: + async with get_service_replica_grpc_channel_over_uds(uds_path) as channel: + yield channel diff --git a/src/dstack/_internal/server/services/jobs/job_replica_http_client.py b/src/dstack/_internal/server/services/jobs/job_replica_http_client.py index 1497fe5e0..ec63ae868 100644 --- a/src/dstack/_internal/server/services/jobs/job_replica_http_client.py +++ b/src/dstack/_internal/server/services/jobs/job_replica_http_client.py @@ -2,48 +2,26 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from datetime import timedelta from pathlib import Path -from tempfile import TemporaryDirectory from httpx import AsyncClient, AsyncHTTPTransport -from dstack._internal.core.services.ssh.tunnel import ( - SSH_DEFAULT_OPTIONS, - IPSocket, - SocketPair, - UnixSocket, -) from dstack._internal.server.models import JobModel -from dstack._internal.server.services.jobs import get_job_spec -from dstack._internal.server.services.ssh import container_ssh_tunnel -from dstack._internal.utils.common import get_or_error +from dstack._internal.server.services.jobs.job_replica_tunnel import get_service_replica_tunnel -SSH_CONNECT_TIMEOUT = timedelta(seconds=10) + +@asynccontextmanager +async def get_service_replica_http_client_over_uds( + uds_path: Path, +) -> AsyncGenerator[AsyncClient, None]: + async with AsyncClient(transport=AsyncHTTPTransport(uds=str(uds_path))) as client: + yield client @asynccontextmanager async def get_service_replica_client( job: JobModel, ) -> AsyncGenerator[AsyncClient, None]: - options = { - **SSH_DEFAULT_OPTIONS, - "ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())), - } - job_spec = get_job_spec(job) - with TemporaryDirectory() as temp_dir: - app_socket_path = (Path(temp_dir) / "replica.sock").absolute() - async with container_ssh_tunnel( - job=job, - forwarded_sockets=[ - SocketPair( - remote=IPSocket("localhost", get_or_error(job_spec.service_port)), - local=UnixSocket(app_socket_path), - ), - ], - options=options, - ): - async with AsyncClient( - transport=AsyncHTTPTransport(uds=str(app_socket_path)) - ) as client: - yield client + async with get_service_replica_tunnel(job) as uds_path: + async with get_service_replica_http_client_over_uds(uds_path) as client: + yield client diff --git a/src/dstack/_internal/server/services/jobs/job_replica_tunnel.py b/src/dstack/_internal/server/services/jobs/job_replica_tunnel.py new file mode 100644 index 000000000..b6466255f --- /dev/null +++ b/src/dstack/_internal/server/services/jobs/job_replica_tunnel.py @@ -0,0 +1,43 @@ +"""SSH tunnel to a job replica's service port, exposed as a local Unix domain socket.""" + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from datetime import timedelta +from pathlib import Path +from tempfile import TemporaryDirectory + +from dstack._internal.core.services.ssh.tunnel import ( + SSH_DEFAULT_OPTIONS, + IPSocket, + SocketPair, + UnixSocket, +) +from dstack._internal.server.models import JobModel +from dstack._internal.server.services.jobs import get_job_spec +from dstack._internal.server.services.ssh import container_ssh_tunnel +from dstack._internal.utils.common import get_or_error + +SSH_CONNECT_TIMEOUT = timedelta(seconds=10) +_REPLICA_SOCKET_NAME = "replica.sock" + + +@asynccontextmanager +async def get_service_replica_tunnel(job: JobModel) -> AsyncGenerator[Path, None]: + options = { + **SSH_DEFAULT_OPTIONS, + "ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())), + } + job_spec = get_job_spec(job) + with TemporaryDirectory() as temp_dir: + app_socket_path = (Path(temp_dir) / _REPLICA_SOCKET_NAME).absolute() + async with container_ssh_tunnel( + job=job, + forwarded_sockets=[ + SocketPair( + remote=IPSocket("localhost", get_or_error(job_spec.service_port)), + local=UnixSocket(app_socket_path), + ), + ], + options=options, + ): + yield app_socket_path diff --git a/src/dstack/_internal/server/services/runs/router_worker_sync.py b/src/dstack/_internal/server/services/runs/router_worker_sync.py index 910dc8d57..c9960f54d 100644 --- a/src/dstack/_internal/server/services/runs/router_worker_sync.py +++ b/src/dstack/_internal/server/services/runs/router_worker_sync.py @@ -28,11 +28,14 @@ from dstack._internal.server.models import JobModel, RunModel from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_spec from dstack._internal.server.services.jobs.job_replica_grpc_client import ( + get_service_replica_grpc_channel_over_uds, get_service_replica_grpc_client, ) from dstack._internal.server.services.jobs.job_replica_http_client import ( get_service_replica_client, + get_service_replica_http_client_over_uds, ) +from dstack._internal.server.services.jobs.job_replica_tunnel import get_service_replica_tunnel from dstack._internal.server.services.logging import fmt from dstack._internal.utils.logging import get_logger @@ -366,50 +369,49 @@ def _is_expected_grpc_discovery_error(error: Exception) -> bool: return False -async def _get_http_worker(job_model: JobModel, *, worker_url: str) -> _WorkerPayloadResult: +async def _probe_http_worker(client: AsyncClient, *, worker_url: str) -> _WorkerPayloadResult: try: - async with get_service_replica_client(job_model) as client: - data = await _request_json_limited( - client, - "GET", - f"{_ROUTER_HTTP}/server_info", - max_response_bytes=_MAX_SERVER_INFO_RESPONSE_BYTES, - ok_statuses={200}, - ) - if isinstance(data, dict): - if data.get("status") != "ready": - return {"status": "not_ready", "worker": None} - mode = data.get("disaggregation_mode", "") - if mode == "prefill": - bootstrap_port = data.get("disaggregation_bootstrap_port") - worker: _TargetWorker = { - "url": worker_url, - "worker_type": "prefill", - "connection_mode": "http", - "runtime_type": "sglang", - } - if bootstrap_port is not None: - worker["bootstrap_port"] = bootstrap_port - return {"status": "ready", "worker": worker} - if mode == "decode": - return { - "status": "ready", - "worker": { - "url": worker_url, - "worker_type": "decode", - "connection_mode": "http", - "runtime_type": "sglang", - }, - } + data = await _request_json_limited( + client, + "GET", + f"{_ROUTER_HTTP}/server_info", + max_response_bytes=_MAX_SERVER_INFO_RESPONSE_BYTES, + ok_statuses={200}, + ) + if isinstance(data, dict): + if data.get("status") != "ready": + return {"status": "not_ready", "worker": None} + mode = data.get("disaggregation_mode", "") + if mode == "prefill": + bootstrap_port = data.get("disaggregation_bootstrap_port") + worker: _TargetWorker = { + "url": worker_url, + "worker_type": "prefill", + "connection_mode": "http", + "runtime_type": "sglang", + } + if bootstrap_port is not None: + worker["bootstrap_port"] = bootstrap_port + return {"status": "ready", "worker": worker} + if mode == "decode": return { "status": "ready", "worker": { "url": worker_url, - "worker_type": "regular", + "worker_type": "decode", "connection_mode": "http", "runtime_type": "sglang", }, } + return { + "status": "ready", + "worker": { + "url": worker_url, + "worker_type": "regular", + "connection_mode": "http", + "runtime_type": "sglang", + }, + } except _ResponseTooLargeError: logger.warning("server_info response too large for worker %s", worker_url) except RemoteProtocolError as e: @@ -419,6 +421,11 @@ async def _get_http_worker(job_model: JobModel, *, worker_url: str) -> _WorkerPa return {"status": "not_ready", "worker": None} +async def _get_http_worker(job_model: JobModel, *, worker_url: str) -> _WorkerPayloadResult: + async with get_service_replica_client(job_model) as client: + return await _probe_http_worker(client, worker_url=worker_url) + + async def _get_grpc_server_info( channel: grpc.aio.Channel, runtime_type: _RuntimeType, @@ -487,6 +494,30 @@ def _grpc_server_info_to_worker( return worker +async def _probe_grpc_worker( + channel: grpc.aio.Channel, + *, + worker_url: str, + runtime_type: Optional[_RuntimeType] = None, +) -> _WorkerPayloadResult: + if runtime_type is not None: + try: + response = await _get_grpc_server_info(channel, runtime_type) + except Exception as e: + if _is_expected_grpc_discovery_error(e): + logger.debug("gRPC worker %s not ready (GetServerInfo)", worker_url) + return {"status": "not_ready", "worker": None} + raise + else: + runtime_type, response = await _discover_grpc_server_info(channel) + if runtime_type is None or response is None: + logger.debug("gRPC worker %s not ready (GetServerInfo)", worker_url) + return {"status": "not_ready", "worker": None} + + worker = _grpc_server_info_to_worker(worker_url, runtime_type, response) + return {"status": "ready", "worker": worker} + + async def _get_grpc_worker( job_model: JobModel, *, @@ -495,19 +526,9 @@ async def _get_grpc_worker( ) -> _WorkerPayloadResult: try: async with get_service_replica_grpc_client(job_model) as channel: - if runtime_type is not None: - try: - response = await _get_grpc_server_info(channel, runtime_type) - except Exception as e: - if _is_expected_grpc_discovery_error(e): - logger.debug("gRPC worker %s not ready (GetServerInfo)", worker_url) - return {"status": "not_ready", "worker": None} - raise - else: - runtime_type, response = await _discover_grpc_server_info(channel) - if runtime_type is None or response is None: - logger.debug("gRPC worker %s not ready (GetServerInfo)", worker_url) - return {"status": "not_ready", "worker": None} + return await _probe_grpc_worker( + channel, worker_url=worker_url, runtime_type=runtime_type + ) except Exception as e: logger.exception( "Could not fetch gRPC GetServerInfo for worker %s: %r", @@ -516,9 +537,6 @@ async def _get_grpc_worker( ) return {"status": "not_ready", "worker": None} - worker = _grpc_server_info_to_worker(worker_url, runtime_type, response) - return {"status": "ready", "worker": worker} - async def _get_worker( job_model: JobModel, @@ -535,18 +553,25 @@ async def _get_worker( if connection_mode == "http": return await _get_http_worker(job_model, worker_url=http_worker_url) # Router workers list is empty and no connection_mode discovered. - try: - result = await _get_http_worker(job_model, worker_url=http_worker_url) - except RemoteProtocolError as e: - logger.debug( - "HTTP server_info probe failed for %s (trying gRPC): %r", - http_worker_url, - e, - ) - result: _WorkerPayloadResult = {"status": "not_ready", "worker": None} - if result["status"] == "ready": - return result - return await _get_grpc_worker(job_model, worker_url=grpc_worker_url, runtime_type=runtime_type) + async with get_service_replica_tunnel(job_model) as uds_path: + async with get_service_replica_http_client_over_uds(uds_path) as client: + result = await _probe_http_worker(client, worker_url=http_worker_url) + if result["status"] == "ready": + return result + async with get_service_replica_grpc_channel_over_uds(uds_path) as channel: + try: + return await _probe_grpc_worker( + channel, + worker_url=grpc_worker_url, + runtime_type=runtime_type, + ) + except Exception as e: + logger.exception( + "Could not fetch gRPC GetServerInfo for worker %s: %r", + grpc_worker_url, + e, + ) + return {"status": "not_ready", "worker": None} async def _build_target_workers( diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_probes.py b/src/tests/_internal/server/background/scheduled_tasks/test_probes.py index bfd569ab1..22a9ba338 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_probes.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_probes.py @@ -10,9 +10,9 @@ from dstack._internal.core.models.runs import JobStatus from dstack._internal.server.background.scheduled_tasks.probes import ( PROCESSING_OVERHEAD_TIMEOUT, - SSH_CONNECT_TIMEOUT, process_probes, ) +from dstack._internal.server.services.jobs.job_replica_tunnel import SSH_CONNECT_TIMEOUT from dstack._internal.server.testing.common import ( create_instance, create_job, diff --git a/src/tests/_internal/server/services/runs/test_router_worker_sync.py b/src/tests/_internal/server/services/runs/test_router_worker_sync.py index 2cf027563..093293994 100644 --- a/src/tests/_internal/server/services/runs/test_router_worker_sync.py +++ b/src/tests/_internal/server/services/runs/test_router_worker_sync.py @@ -1,4 +1,5 @@ from contextlib import asynccontextmanager, contextmanager +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -229,3 +230,79 @@ async def test_get_worker_grpc_preference_skips_http(): assert result == grpc_not_ready grpc_mock.assert_awaited_once() http_mock.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_worker_bootstrap_uses_single_tunnel(): + job = MagicMock() + uds_path = Path("/tmp/replica.sock") + grpc_ready: dict = { + "status": "ready", + "worker": { + "url": "grpc://10.0.0.1:8000", + "worker_type": "prefill", + "connection_mode": "grpc", + "runtime_type": "vllm", + }, + } + + tunnel_cm = AsyncMock() + tunnel_cm.__aenter__.return_value = uds_path + tunnel_cm.__aexit__.return_value = None + + http_cm = AsyncMock() + http_cm.__aenter__.return_value = MagicMock() + http_cm.__aexit__.return_value = None + + grpc_cm = AsyncMock() + grpc_cm.__aenter__.return_value = MagicMock() + grpc_cm.__aexit__.return_value = None + + with ( + patch( + "dstack._internal.server.services.runs.router_worker_sync.get_service_replica_tunnel", + return_value=tunnel_cm, + ) as tunnel_mock, + patch( + "dstack._internal.server.services.runs.router_worker_sync" + ".get_service_replica_http_client_over_uds", + return_value=http_cm, + ) as http_over_uds_mock, + patch( + "dstack._internal.server.services.runs.router_worker_sync" + ".get_service_replica_grpc_channel_over_uds", + return_value=grpc_cm, + ) as grpc_over_uds_mock, + patch( + "dstack._internal.server.services.runs.router_worker_sync._probe_http_worker", + new_callable=AsyncMock, + return_value={"status": "not_ready", "worker": None}, + ) as http_probe_mock, + patch( + "dstack._internal.server.services.runs.router_worker_sync._probe_grpc_worker", + new_callable=AsyncMock, + return_value=grpc_ready, + ) as grpc_probe_mock, + patch( + "dstack._internal.server.services.runs.router_worker_sync._get_http_worker", + new_callable=AsyncMock, + ) as get_http_mock, + patch( + "dstack._internal.server.services.runs.router_worker_sync._get_grpc_worker", + new_callable=AsyncMock, + ) as get_grpc_mock, + ): + result = await _get_worker( + job, + http_worker_url="http://10.0.0.1:8000", + grpc_worker_url="grpc://10.0.0.1:8000", + ) + + assert result == grpc_ready + tunnel_mock.assert_called_once_with(job) + http_over_uds_mock.assert_called_once_with(uds_path) + grpc_over_uds_mock.assert_called_once_with(uds_path) + http_probe_mock.assert_awaited_once() + grpc_probe_mock.assert_awaited_once() + get_http_mock.assert_not_awaited() + get_grpc_mock.assert_not_awaited()