Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 132 additions & 31 deletions tests/integ/sagemaker/jumpstart/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import os
import pathlib
from datetime import datetime, timedelta, timezone

import boto3
import pytest
from filelock import FileLock
from botocore.config import Config
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.hub.hub import Hub
Expand All @@ -39,19 +44,23 @@
)


def _setup():
# Only delete leftover hubs from previous test runs that are older than this many
# hours. This guards against deleting a hub that another concurrent test run (or
# xdist worker) is actively using.
STALE_HUB_AGE_HOURS = 3


def _setup(test_suite_id=None, test_hub_name=None):
print("Setting up...")
test_suite_id = get_test_suite_id()
test_hub_name = f"{HUB_NAME_PREFIX}{test_suite_id}"
test_suite_id = test_suite_id or get_test_suite_id()
test_hub_name = test_hub_name or f"{HUB_NAME_PREFIX}{test_suite_id}"
test_hub_description = "PySDK Integ Test Private Hub"

os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: test_suite_id})
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: test_hub_name})

# Create a private hub to use for the test session
hub = Hub(
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
)
hub = Hub(hub_name=test_hub_name, sagemaker_session=get_sm_session())

# Check if hub already exists before creating
try:
Expand All @@ -73,14 +82,14 @@ def _setup():
raise


def _teardown():
def _teardown(test_suite_id=None, test_hub_name=None):
print("Tearing down...")

test_cache_bucket = get_test_artifact_bucket()

test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]
test_suite_id = test_suite_id or os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]

test_hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
test_hub_name = test_hub_name or os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]

boto3_session = boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)

Expand Down Expand Up @@ -156,30 +165,41 @@ def _teardown():
_delete_hubs(sagemaker_session, test_hub_name)


def _cleanup_old_hubs(sagemaker_session, active_hub_name=None):
    """Delete leftover test hubs that are old enough to be considered stale.

    A hub is removed only when its name starts with ``HUB_NAME_PREFIX``, it is
    not the hub of the current run, and its ``CreationTime`` is strictly older
    than ``STALE_HUB_AGE_HOURS`` hours. Hubs without a creation time are kept.

    Args:
        sagemaker_session: Session object providing ``list_hubs()``; it is also
            forwarded to ``_delete_hubs``.
        active_hub_name: Name of the hub that must never be deleted. When falsy,
            the value of the ``ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME`` environment
            variable is used instead (which may itself be unset).

    This is best-effort: every failure is printed and swallowed, never raised.
    """
    try:
        protected_hub = active_hub_name or os.environ.get(ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME)
        stale_before = datetime.now(timezone.utc) - timedelta(hours=STALE_HUB_AGE_HOURS)

        summaries = sagemaker_session.list_hubs().get("HubSummaries", [])
        for summary in summaries:
            hub_name = summary["HubName"]
            # Skip hubs that are not ours, and the one this run is using.
            if not hub_name.startswith(HUB_NAME_PREFIX) or hub_name == protected_hub:
                continue

            created_at = summary.get("CreationTime")
            # Without a creation time the age cannot be confirmed: keep the hub.
            if created_at is None:
                continue
            # Naive timestamps are treated as UTC so they compare with the cutoff.
            if created_at.tzinfo is None:
                created_at = created_at.replace(tzinfo=timezone.utc)
            # Anything created at or after the cutoff may still be in use.
            if created_at >= stale_before:
                continue

            # A failure on one hub must not stop the remaining deletions.
            try:
                print(f"Deleting stale hub: {hub_name}")
                _delete_hubs(sagemaker_session, hub_name)
            except Exception as e:
                print(f"Failed to delete hub {hub_name}: {e}")
    except Exception as e:
        print(f"Failed to cleanup old hubs: {e}")

Expand Down Expand Up @@ -210,8 +230,89 @@ def _delete_hub_contents(sagemaker_session, hub_name, model):
)


def _hub_state_root(config):
    """Locate the run-wide temp directory used for the shared hub state file.

    The base temp directory is read from ``config._tmp_path_factory`` when that
    attribute is present, otherwise from the legacy ``config._tmpdirhandler``.
    If the directory's own name starts with ``popen-gw`` (the per-worker
    subdirectory that xdist workers get, while the controller gets the run
    root), its parent is returned so that all processes resolve the same path.

    Args:
        config: pytest config object exposing one of the two attributes above.

    Returns:
        pathlib.Path: the run-level temp directory.
    """
    tmp_factory = getattr(config, "_tmp_path_factory", None)
    # Fall back to the older handler only when the newer factory is missing.
    provider = config._tmpdirhandler if tmp_factory is None else tmp_factory
    base_dir = pathlib.Path(str(provider.getbasetemp()))

    # A ``popen-gw*`` leaf is a worker-specific subdirectory of the run root.
    in_worker_subdir = base_dir.name.startswith("popen-gw")
    return base_dir.parent if in_worker_subdir else base_dir


@pytest.fixture(scope="session", autouse=True)
def setup(request):
    """Make sure one shared private hub exists for the entire test run.

    With pytest-xdist each worker is its own process, so a plain session-scoped
    fixture would create one hub per worker; at high parallelism (e.g.
    ``-n 120``) that runs into the per-account private hub limit (100). Workers
    therefore coordinate through a lock file and a JSON state file under the
    run-level temp directory: whichever worker finds no state file creates the
    hub and records it, and every other worker reuses the recorded values.

    This fixture registers no finalizer on purpose. xdist hands out tests
    dynamically, so one worker may finish its session before another reaches
    its first hub test; deleting from a worker would remove the hub while it is
    still needed. Deletion is left to ``pytest_sessionfinish``.
    """
    run_root = _hub_state_root(request.config)
    state_path = run_root / "jumpstart_hub_state.json"
    lock_path = run_root / "jumpstart_hub_state.json.lock"

    with FileLock(str(lock_path)):
        if not state_path.is_file():
            # First process to get here: create the hub and publish its identity.
            suite_id = get_test_suite_id()
            hub_name = f"{HUB_NAME_PREFIX}{suite_id}"
            _setup(test_suite_id=suite_id, test_hub_name=hub_name)
            shared = {
                "test_suite_id": suite_id,
                "test_hub_name": hub_name,
            }
            state_path.write_text(json.dumps(shared))
        else:
            # Another process already created the hub: reuse what it recorded.
            shared = json.loads(state_path.read_text())

    # Point this process's environment at the shared hub.
    os.environ.update(
        {
            ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: shared["test_suite_id"],
            ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: shared["test_hub_name"],
        }
    )


def pytest_sessionfinish(session, exitstatus):
    """Delete the shared hub exactly once, from the non-worker process.

    A config carrying a ``workerinput`` attribute belongs to an xdist worker,
    and workers return immediately. The controller (or a run without xdist)
    reads the shared state file and, if it exists, calls ``_teardown`` with the
    recorded suite id and hub name. The state file is removed afterwards even
    when ``_teardown`` raises.

    Args:
        session: pytest session; only ``session.config`` is used.
        exitstatus: Unused; present because the hook signature provides it.
    """
    # NOTE(review): this only fires on the xdist controller if this conftest is
    # actually loaded there; confirm for invocations rooted above this directory.
    if hasattr(session.config, "workerinput"):
        return  # xdist worker: the controller handles teardown.

    run_root = _hub_state_root(session.config)
    state_path = run_root / "jumpstart_hub_state.json"
    lock_path = run_root / "jumpstart_hub_state.json.lock"

    with FileLock(str(lock_path)):
        # No state file means no hub was recorded for this run.
        if state_path.is_file():
            shared = json.loads(state_path.read_text())
            try:
                _teardown(
                    test_suite_id=shared["test_suite_id"],
                    test_hub_name=shared["test_hub_name"],
                )
            finally:
                # Always drop the state file, whether or not teardown succeeded.
                state_path.unlink()
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import pytest
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.hub.hub import Hub

from sagemaker.jumpstart.estimator import JumpStartEstimator
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
Expand All @@ -28,10 +27,9 @@
JUMPSTART_TAG,
)
from tests.integ.sagemaker.jumpstart.utils import (
get_public_hub_model_arn,
get_sm_session,
with_exponential_backoff,
get_training_dataset_for_model_and_version,
add_model_references_to_hub,
)

MAX_INIT_TIME_SECONDS = 5
Expand All @@ -43,23 +41,13 @@
}


@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
try:
hub_instance.create_model_reference(model_arn=model_arn)
except Exception:
pass


@pytest.fixture(scope="session")
def add_model_references():
    """Add references for ``TEST_MODEL_IDS`` to the hub named in the environment.

    The hub name is read from ``ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME``; a missing
    variable raises ``KeyError``.
    """
    # Per the original comment the helper is idempotent and waits for readiness;
    # its implementation is not visible here, so confirm before relying on it.
    target_hub = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
    add_model_references_to_hub(hub_name=target_hub, model_ids=TEST_MODEL_IDS)


def test_jumpstart_hub_estimator(setup, add_model_references):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import pytest
from sagemaker.enums import EndpointType
from sagemaker.jumpstart.hub.hub import Hub
from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs
from sagemaker.predictor import retrieve_default

Expand All @@ -30,9 +29,8 @@
JUMPSTART_TAG,
)
from tests.integ.sagemaker.jumpstart.utils import (
get_public_hub_model_arn,
get_sm_session,
with_exponential_backoff,
add_model_references_to_hub,
)

MAX_INIT_TIME_SECONDS = 5
Expand All @@ -46,23 +44,13 @@
}


@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
try:
hub_instance.create_model_reference(model_arn=model_arn)
except Exception:
pass


@pytest.fixture(scope="session")
def add_model_references():
    """Add references for ``TEST_MODEL_IDS`` to the hub named in the environment.

    The hub name is read from ``ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME``; a missing
    variable raises ``KeyError``.
    """
    # Per the original comment the helper is idempotent and waits for readiness;
    # its implementation is not visible here, so confirm before relying on it.
    target_hub = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
    add_model_references_to_hub(hub_name=target_hub, model_ids=TEST_MODEL_IDS)


def test_jumpstart_hub_model(setup, add_model_references):
Expand Down
Loading
Loading