From 0c58649d550367026bb1cb9689be931b866860cd Mon Sep 17 00:00:00 2001 From: NguyenCong2k <123174148+NguyenCong2k@users.noreply.github.com> Date: Wed, 13 May 2026 10:01:00 +0700 Subject: [PATCH] Validate OIDC allowed hosts before cache reuse --- pymongo/asynchronous/auth_oidc.py | 6 +++--- pymongo/synchronous/auth_oidc.py | 6 +++--- test/asynchronous/test_auth_oidc.py | 20 ++++++++++++++++++++ test/test_auth_oidc.py | 20 ++++++++++++++++++++ 4 files changed, 46 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/auth_oidc.py b/pymongo/asynchronous/auth_oidc.py index f8f046bd94..6001bea48e 100644 --- a/pymongo/asynchronous/auth_oidc.py +++ b/pymongo/asynchronous/auth_oidc.py @@ -49,9 +49,6 @@ def _get_authenticator( credentials: MongoCredential, address: tuple[str, int] ) -> _OIDCAuthenticator: - if credentials.cache.data: - return credentials.cache.data - # Extract values. principal_name = credentials.username properties = credentials.mechanism_properties @@ -70,6 +67,9 @@ def _get_authenticator( f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}" ) + if credentials.cache.data: + return credentials.cache.data + # Get or create the cache data. credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties) return credentials.cache.data diff --git a/pymongo/synchronous/auth_oidc.py b/pymongo/synchronous/auth_oidc.py index 583ee39f67..0db812bb6d 100644 --- a/pymongo/synchronous/auth_oidc.py +++ b/pymongo/synchronous/auth_oidc.py @@ -49,9 +49,6 @@ def _get_authenticator( credentials: MongoCredential, address: tuple[str, int] ) -> _OIDCAuthenticator: - if credentials.cache.data: - return credentials.cache.data - # Extract values. principal_name = credentials.username properties = credentials.mechanism_properties @@ -70,6 +67,9 @@ def _get_authenticator( f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}" ) + if credentials.cache.data: + return credentials.cache.data + # Get or create the cache data. credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties) return credentials.cache.data diff --git a/test/asynchronous/test_auth_oidc.py b/test/asynchronous/test_auth_oidc.py index 3567d7706b..6cfe3b2286 100644 --- a/test/asynchronous/test_auth_oidc.py +++ b/test/asynchronous/test_auth_oidc.py @@ -117,6 +117,26 @@ async def fail_point(self, command_args): await client.close() +class TestOIDCAllowedHostsCache(unittest.TestCase): + class HumanCallback(OIDCCallback): + def fetch(self, context): + return OIDCCallbackResult(access_token="token") + + def test_allowed_hosts_checked_before_cached_authenticator_reuse(self): + props = { + "OIDC_HUMAN_CALLBACK": self.HumanCallback(), + "ALLOWED_HOSTS": ["good.example.com"], + } + extra = {"authmechanismproperties": props} + credentials = _build_credentials_tuple("MONGODB-OIDC", None, "user", None, extra, "test") + + authenticator = _get_authenticator(credentials, ("good.example.com", 27017)) + self.assertIs(authenticator, credentials.cache.data) + + with self.assertRaisesRegex(ConfigurationError, "evil.example.com"): + _get_authenticator(credentials, ("evil.example.com", 27017)) + + class TestAuthOIDCHuman(OIDCTestBase): uri: str diff --git a/test/test_auth_oidc.py b/test/test_auth_oidc.py index e88e067b2c..28b9c0a262 100644 --- a/test/test_auth_oidc.py +++ b/test/test_auth_oidc.py @@ -117,6 +117,26 @@ def fail_point(self, command_args): client.close() +class TestOIDCAllowedHostsCache(unittest.TestCase): + class HumanCallback(OIDCCallback): + def fetch(self, context): + return OIDCCallbackResult(access_token="token") + + def test_allowed_hosts_checked_before_cached_authenticator_reuse(self): + props = { + "OIDC_HUMAN_CALLBACK": self.HumanCallback(), + "ALLOWED_HOSTS": ["good.example.com"], + } + extra = {"authmechanismproperties": props} + credentials = _build_credentials_tuple("MONGODB-OIDC", None, "user", None, extra, "test") + + authenticator = _get_authenticator(credentials, ("good.example.com", 27017)) + self.assertIs(authenticator, credentials.cache.data) + + with self.assertRaisesRegex(ConfigurationError, "evil.example.com"): + _get_authenticator(credentials, ("evil.example.com", 27017)) + + class TestAuthOIDCHuman(OIDCTestBase): uri: str