diff --git a/tests/integ/sagemaker/jumpstart/conftest.py b/tests/integ/sagemaker/jumpstart/conftest.py index 50e062e384..938984a9c7 100644 --- a/tests/integ/sagemaker/jumpstart/conftest.py +++ b/tests/integ/sagemaker/jumpstart/conftest.py @@ -12,9 +12,14 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import json import os +import pathlib +from datetime import datetime, timedelta, timezone + import boto3 import pytest +from filelock import FileLock from botocore.config import Config from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.hub.hub import Hub @@ -39,19 +44,23 @@ ) -def _setup(): +# Only delete leftover hubs from previous test runs that are older than this many +# hours. This guards against deleting a hub that another concurrent test run (or +# xdist worker) is actively using. +STALE_HUB_AGE_HOURS = 3 + + +def _setup(test_suite_id=None, test_hub_name=None): print("Setting up...") - test_suite_id = get_test_suite_id() - test_hub_name = f"{HUB_NAME_PREFIX}{test_suite_id}" + test_suite_id = test_suite_id or get_test_suite_id() + test_hub_name = test_hub_name or f"{HUB_NAME_PREFIX}{test_suite_id}" test_hub_description = "PySDK Integ Test Private Hub" os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: test_suite_id}) os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: test_hub_name}) # Create a private hub to use for the test session - hub = Hub( - hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() - ) + hub = Hub(hub_name=test_hub_name, sagemaker_session=get_sm_session()) # Check if hub already exists before creating try: @@ -73,14 +82,14 @@ def _setup(): raise -def _teardown(): +def _teardown(test_suite_id=None, test_hub_name=None): print("Tearing down...") test_cache_bucket = get_test_artifact_bucket() - test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID] + test_suite_id = test_suite_id or os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID] - test_hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] + test_hub_name = test_hub_name or os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] boto3_session = boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME) @@ -156,30 +165,41 @@ def _teardown(): _delete_hubs(sagemaker_session, test_hub_name) -def _cleanup_old_hubs(sagemaker_session): - """Clean up old test hubs to free up resources.""" +def _cleanup_old_hubs(sagemaker_session, active_hub_name=None): + """Clean up stale test hubs from previous runs to free up resources. + + Only deletes hubs that are clearly stale (older than ``STALE_HUB_AGE_HOURS``) + so that hubs actively in use by the current test run or by concurrent xdist + workers are never removed. The hub for the current run (``active_hub_name``) + is always preserved. + """ try: + active_hub_name = active_hub_name or os.environ.get(ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME) + cutoff = datetime.now(timezone.utc) - timedelta(hours=STALE_HUB_AGE_HOURS) + response = sagemaker_session.list_hubs() - test_hubs = [ - hub - for hub in response.get("HubSummaries", []) - if hub["HubName"].startswith(HUB_NAME_PREFIX) - ] - - # Sort by creation time and delete oldest hubs - test_hubs.sort(key=lambda x: x.get("CreationTime", "")) - - # Delete oldest hubs (keep only the most recent 10) - hubs_to_delete = ( - test_hubs[:-10] if len(test_hubs) > 10 else test_hubs[: max(0, len(test_hubs) - 40)] - ) + for hub in response.get("HubSummaries", []): + hub_name = hub["HubName"] + if not hub_name.startswith(HUB_NAME_PREFIX): + continue + if hub_name == active_hub_name: + continue + + creation_time = hub.get("CreationTime") + # Only delete hubs we can confirm are older than the cutoff. If the + # creation time is unavailable, err on the side of keeping the hub. + if creation_time is None: + continue + if creation_time.tzinfo is None: + creation_time = creation_time.replace(tzinfo=timezone.utc) + if creation_time >= cutoff: + continue - for hub in hubs_to_delete: try: - print(f"Deleting old hub: {hub['HubName']}") - _delete_hubs(sagemaker_session, hub["HubName"]) + print(f"Deleting stale hub: {hub_name}") + _delete_hubs(sagemaker_session, hub_name) except Exception as e: - print(f"Failed to delete hub {hub['HubName']}: {e}") + print(f"Failed to delete hub {hub_name}: {e}") except Exception as e: print(f"Failed to cleanup old hubs: {e}") @@ -210,8 +230,89 @@ def _delete_hub_contents(sagemaker_session, hub_name, model): ) +def _hub_state_root(config): + """Return the run-level tmp dir shared by the xdist controller and workers. + + The controller's basetemp is the run root (e.g. ``.../pytest-N``) while each + worker's basetemp is a ``popen-gw*`` subdir of it. Normalizing to the run + root gives every process the same location for the shared state file. + + Works across pytest versions: prefers the ``TempPathFactory`` attached as + ``config._tmp_path_factory`` and falls back to the legacy ``_tmpdirhandler``. + """ + factory = getattr(config, "_tmp_path_factory", None) + if factory is not None: + basetemp = pathlib.Path(str(factory.getbasetemp())) + else: + basetemp = pathlib.Path(str(config._tmpdirhandler.getbasetemp())) + + if basetemp.name.startswith("popen-gw"): + return basetemp.parent + return basetemp + + @pytest.fixture(scope="session", autouse=True) def setup(request): - _setup() - - request.addfinalizer(_teardown) + """Ensure a single shared private hub exists for the whole test run. + + Under pytest-xdist every worker is a separate process, so a naive + ``scope="session"`` fixture would create one hub per worker. With high + parallelism (e.g. ``-n 120``) that quickly exhausts the per-account private + hub limit (100). All workers therefore coordinate through a lock file and a + shared JSON state file: the first worker creates the hub, the rest reuse it. + + The hub is intentionally NOT deleted from a worker finalizer. xdist + distributes tests dynamically, so a worker can finish its whole session (and + run its finalizers) before another worker even reaches its first hub test; + reference counting in that finalizer would delete the hub out from under + workers still using it ("Hub ... does not exist" failures). Teardown instead + runs exactly once, after all workers finish, in ``pytest_sessionfinish`` on + the controller process. + """ + root_tmp_dir = _hub_state_root(request.config) + state_file = root_tmp_dir / "jumpstart_hub_state.json" + lock_file = root_tmp_dir / "jumpstart_hub_state.json.lock" + + with FileLock(str(lock_file)): + if state_file.is_file(): + state = json.loads(state_file.read_text()) + else: + test_suite_id = get_test_suite_id() + test_hub_name = f"{HUB_NAME_PREFIX}{test_suite_id}" + _setup(test_suite_id=test_suite_id, test_hub_name=test_hub_name) + state = { + "test_suite_id": test_suite_id, + "test_hub_name": test_hub_name, + } + state_file.write_text(json.dumps(state)) + + # Ensure this worker's environment points at the shared hub. + os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: state["test_suite_id"]}) + os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: state["test_hub_name"]}) + + +def pytest_sessionfinish(session, exitstatus): + """Tear down the shared hub once, after all xdist workers have finished. + + xdist workers carry a ``workerinput`` attribute on their config; only the + controller (or a non-xdist run, which has no workerinput) performs teardown. + Running here guarantees no worker is still using the hub. + """ + if hasattr(session.config, "workerinput"): + return # xdist worker: the controller handles teardown. + + root_tmp_dir = _hub_state_root(session.config) + state_file = root_tmp_dir / "jumpstart_hub_state.json" + lock_file = root_tmp_dir / "jumpstart_hub_state.json.lock" + + with FileLock(str(lock_file)): + if not state_file.is_file(): + return + state = json.loads(state_file.read_text()) + try: + _teardown( + test_suite_id=state["test_suite_id"], + test_hub_name=state["test_hub_name"], + ) + finally: + state_file.unlink() diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py index d512915343..4c455c3b32 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -17,7 +17,6 @@ import pytest from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME -from sagemaker.jumpstart.hub.hub import Hub from sagemaker.jumpstart.estimator import JumpStartEstimator from sagemaker.jumpstart.utils import get_jumpstart_content_bucket @@ -28,10 +27,9 @@ JUMPSTART_TAG, ) from tests.integ.sagemaker.jumpstart.utils import ( - get_public_hub_model_arn, get_sm_session, - with_exponential_backoff, get_training_dataset_for_model_and_version, + add_model_references_to_hub, ) MAX_INIT_TIME_SECONDS = 5 @@ -43,23 +41,13 @@ } -@with_exponential_backoff() -def create_model_reference(hub_instance, model_arn): - try: - hub_instance.create_model_reference(model_arn=model_arn) - except Exception: - pass - - @pytest.fixture(scope="session") def add_model_references(): - # Create Model References to test in Hub - hub_instance = Hub( - hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() + # Create Model References to test in Hub (idempotent + waits for readiness) + add_model_references_to_hub( + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + model_ids=TEST_MODEL_IDS, ) - for model in TEST_MODEL_IDS: - model_arn = get_public_hub_model_arn(hub_instance, model) - create_model_reference(hub_instance, model_arn) def test_jumpstart_hub_estimator(setup, add_model_references): diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py index 3956c2240d..3737391102 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py @@ -17,7 +17,6 @@ import pytest from sagemaker.enums import EndpointType -from sagemaker.jumpstart.hub.hub import Hub from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.predictor import retrieve_default @@ -30,9 +29,8 @@ JUMPSTART_TAG, ) from tests.integ.sagemaker.jumpstart.utils import ( - get_public_hub_model_arn, get_sm_session, - with_exponential_backoff, + add_model_references_to_hub, ) MAX_INIT_TIME_SECONDS = 5 @@ -46,23 +44,13 @@ } -@with_exponential_backoff() -def create_model_reference(hub_instance, model_arn): - try: - hub_instance.create_model_reference(model_arn=model_arn) - except Exception: - pass - - @pytest.fixture(scope="session") def add_model_references(): - # Create Model References to test in Hub - hub_instance = Hub( - hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() + # Create Model References to test in Hub (idempotent + waits for readiness) + add_model_references_to_hub( + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + model_ids=TEST_MODEL_IDS, ) - for model in TEST_MODEL_IDS: - model_arn = get_public_hub_model_arn(hub_instance, model) - create_model_reference(hub_instance, model_arn) def test_jumpstart_hub_model(setup, add_model_references): diff --git a/tests/integ/sagemaker/jumpstart/utils.py b/tests/integ/sagemaker/jumpstart/utils.py index d439ef7e95..c326b135e0 100644 --- a/tests/integ/sagemaker/jumpstart/utils.py +++ b/tests/integ/sagemaker/jumpstart/utils.py @@ -12,9 +12,11 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import functools +import hashlib import json import random +import tempfile import time import uuid from typing import Any, Dict, List, Tuple @@ -24,6 +26,7 @@ from botocore.config import Config from botocore.exceptions import ClientError +from filelock import FileLock import pytest @@ -149,6 +152,80 @@ def wrapper(*args, **kwargs): return decorator +@with_exponential_backoff() +def _create_model_reference(hub_instance, model_arn): + """Create a model reference in the hub, tolerating an already-existing one.""" + try: + hub_instance.create_model_reference(model_arn=model_arn) + except ClientError as e: + # A reference that already exists is fine (idempotent across xdist + # workers sharing a hub). Anything else should surface. + if e.response["Error"]["Code"] in ("ResourceInUse", "ResourceLimitExceeded"): + return + raise + + +def _wait_for_model_reference(sagemaker_session, hub_name, model_name, timeout=300, poll=10): + """Block until a model reference is resolvable in the hub. + + ``create_hub_content_reference`` is asynchronous, so a test that uses the + reference immediately after creation can race against propagation and see + ``ResourceNotFound``. Poll until the reference is listable (or time out). + """ + deadline = time.time() + timeout + last_error = None + while time.time() < deadline: + try: + response = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, + hub_content_type="ModelReference", + hub_content_name=model_name, + ) + if response.get("HubContentSummaries"): + return + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFound": + raise + last_error = e + time.sleep(poll) + raise TimeoutError( + f"Model reference '{model_name}' was not available in hub '{hub_name}' " + f"within {timeout}s. Last error: {last_error}" + ) + + +def add_model_references_to_hub(hub_name, model_ids): + """Idempotently add model references to a hub and wait until they resolve. + + Safe to call concurrently from multiple xdist workers sharing a hub: a lock + file serializes the creation work and a marker file ensures it only runs + once per hub per test run. The marker is keyed by both the hub name and the + specific set of model ids, so different callers adding different model sets + to the same shared hub each run exactly once. + """ + sagemaker_session = get_sm_session() + hub_instance = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + model_ids = sorted(model_ids) + models_digest = hashlib.md5( + ",".join(model_ids).encode("utf-8"), usedforsecurity=False + ).hexdigest() + marker = os.path.join( + tempfile.gettempdir(), f"jumpstart_model_refs_{hub_name}_{models_digest}.done" + ) + lock_path = f"{marker}.lock" + + with FileLock(lock_path): + if not os.path.exists(marker): + for model in model_ids: + model_arn = get_public_hub_model_arn(hub_instance, model) + _create_model_reference(hub_instance, model_arn) + for model in model_ids: + _wait_for_model_reference(sagemaker_session, hub_name, model) + with open(marker, "w") as f: + f.write("done") + + class EndpointInvoker: def __init__( self, diff --git a/tests/integ/sagemaker/serve/conftest.py b/tests/integ/sagemaker/serve/conftest.py index 5eb3a2ea11..5119dfb3a0 100644 --- a/tests/integ/sagemaker/serve/conftest.py +++ b/tests/integ/sagemaker/serve/conftest.py @@ -18,7 +18,55 @@ import sagemaker import sagemaker_core.helper.session_helper as core_session +from botocore.config import Config +from sagemaker import Session + DEFAULT_REGION = "us-west-2" +CUSTOM_S3_OBJECT_KEY_PREFIX = "session-default-prefix" + + +@pytest.fixture(scope="session") +def sagemaker_session( + sagemaker_client_config, sagemaker_runtime_config, boto_session, sagemaker_metrics_config +): + """Isolated Session for the serve (ModelBuilder) integ tests. + + Overrides the repo-wide ``sagemaker_session`` fixture (defined in + ``tests/conftest.py``) for everything under ``tests/integ/sagemaker/serve``. + + ModelBuilder mutates the global ``session.settings._local_download_dir`` to a + temporary ``/tmp/sagemaker/model-builder/`` path. When the shared + session-scoped fixture is reused by other test modules, that temp dir gets + cleaned up while the polluted setting lingers, breaking unrelated tests such + as ``tests/integ/sagemaker/workflow/test_tuning_steps.py::test_tuning_multi_algos`` + (``ValueError: Inputted directory ... does not exist``). Scoping a dedicated + session to the serve package keeps that mutation contained here. + """ + sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10))) + sagemaker_client = ( + boto_session.client("sagemaker", **sagemaker_client_config) + if sagemaker_client_config + else None + ) + runtime_client = ( + boto_session.client("sagemaker-runtime", **sagemaker_runtime_config) + if sagemaker_runtime_config + else None + ) + metrics_client = ( + boto_session.client("sagemaker-metrics", **sagemaker_metrics_config) + if sagemaker_metrics_config + else None + ) + + return Session( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=runtime_client, + sagemaker_metrics_client=metrics_client, + sagemaker_config={}, + default_bucket_prefix=CUSTOM_S3_OBJECT_KEY_PREFIX, + ) @pytest.fixture(scope="module") diff --git a/tests/integ/sagemaker/serve/constants.py b/tests/integ/sagemaker/serve/constants.py index 3f25f6a575..b2fcb4154f 100644 --- a/tests/integ/sagemaker/serve/constants.py +++ b/tests/integ/sagemaker/serve/constants.py @@ -21,6 +21,10 @@ SERVE_MODEL_PACKAGE_TIMEOUT = 10 SERVE_LOCAL_CONTAINER_TIMEOUT = 10 SERVE_SAGEMAKER_ENDPOINT_TIMEOUT = 15 +# Inference-component deployments of large (7B) JumpStart models pull a big image +# and load the model before the endpoint reaches InService, which routinely takes +# longer than the standard endpoint timeout. Give that flow more headroom. +SERVE_SAGEMAKER_IC_ENDPOINT_TIMEOUT = 30 SERVE_SAVE_TIMEOUT = 2 PYTHON_VERSION_IS_NOT_38 = platform.python_version_tuple()[1] != "8" diff --git a/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py b/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py index 06312a45b1..bb2c1a34c8 100644 --- a/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py +++ b/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py @@ -24,7 +24,7 @@ from sagemaker.utils import unique_name_from_base from tests.integ.sagemaker.serve.constants import ( - SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, + SERVE_SAGEMAKER_IC_ENDPOINT_TIMEOUT, ) from tests.integ.timeout import timeout import logging @@ -88,7 +88,7 @@ def test_model_builder_ic_sagemaker_endpoint( chain.build() - with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + with timeout(minutes=SERVE_SAGEMAKER_IC_ENDPOINT_TIMEOUT): try: logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") endpoint_name = f"llama-ic-endpoint-name-{uuid.uuid1().hex}" diff --git a/tests/integ/test_spark_processing.py b/tests/integ/test_spark_processing.py index ac956be94e..b6443a80bb 100644 --- a/tests/integ/test_spark_processing.py +++ b/tests/integ/test_spark_processing.py @@ -38,6 +38,8 @@ @pytest.fixture(scope="module", autouse=True) def build_jar(): jar_file_path = os.path.join(SPARK_PATH, "code", "java", "hello-java-spark") + jar_file = os.path.join(jar_file_path, "hello-spark-java.jar") + # compile java file java_version = subprocess.check_output(["java", "-version"], stderr=subprocess.STDOUT).decode( "utf-8" @@ -45,30 +47,39 @@ def build_jar(): java_version = re.search(JAVA_VERSION_PATTERN, java_version).groups()[0] if float(java_version) > 1.8: - subprocess.run( - [ - "javac", - "--release", - "8", - os.path.join(jar_file_path, JAVA_FILE_PATH, "HelloJavaSparkApp.java"), - ] - ) + javac_cmd = [ + "javac", + "--release", + "8", + os.path.join(jar_file_path, JAVA_FILE_PATH, "HelloJavaSparkApp.java"), + ] else: - subprocess.run( - ["javac", os.path.join(jar_file_path, JAVA_FILE_PATH, "HelloJavaSparkApp.java")] - ) + javac_cmd = ["javac", os.path.join(jar_file_path, JAVA_FILE_PATH, "HelloJavaSparkApp.java")] + + jar_cmd = [ + "jar", + "cfm", + jar_file, + os.path.join(jar_file_path, "manifest.txt"), + "-C", + jar_file_path, + ".", + ] - subprocess.run( - [ - "jar", - "cfm", - os.path.join(jar_file_path, "hello-spark-java.jar"), - os.path.join(jar_file_path, "manifest.txt"), - "-C", - jar_file_path, - ".", - ] - ) + # Build with check=True so a failing javac/jar surfaces immediately instead + # of being swallowed. The jar (re)build truncates the committed + # hello-spark-java.jar, so a silent failure here would leave the test with a + # missing/corrupt jar and a confusing "code ... wasn't found" error at run + # time (especially under xdist, where this runs per worker). + for cmd in (javac_cmd, jar_cmd): + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError( + f"Failed to build Spark test jar (command: {' '.join(cmd)}).\n" + f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}" + ) + + assert os.path.isfile(jar_file), f"Spark test jar was not produced at {jar_file}" @pytest.fixture(scope="module")