Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .oagen-manifest.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"version": 2,
"language": "python",
"generatedAt": "2026-05-06T23:30:58.090Z",
"generatedAt": "2026-05-11T15:56:51.952Z",
"files": [
"src/workos/_client.py",
"src/workos/admin_portal/__init__.py",
Expand Down
79 changes: 59 additions & 20 deletions src/workos/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import random
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any, Dict, Optional, Type, cast, overload
from typing import Any, Dict, Optional, Sequence, Type, cast, overload
from urllib.parse import quote

import httpx

Expand Down Expand Up @@ -53,8 +54,15 @@ def __init__(
request_timeout: Optional[int] = None,
jwt_leeway: float = 0.0,
max_retries: int = MAX_RETRIES,
is_public: bool = False,
) -> None:
self._api_key = api_key or os.environ.get("WORKOS_API_KEY")
self._is_public = is_public
# Public clients (PKCE / browser / mobile / CLI) must never attach
# an API key, even if WORKOS_API_KEY is present in the environment.
if is_public:
self._api_key: Optional[str] = None
else:
self._api_key = api_key or os.environ.get("WORKOS_API_KEY")
self.client_id = client_id or os.environ.get("WORKOS_CLIENT_ID")
if not self._api_key and not self.client_id:
raise ValueError(
Expand All @@ -80,12 +88,14 @@ def base_url(self) -> str:
"""The base URL for API requests."""
return self._base_url

def build_url(self, path: str, params: Optional[Dict[str, Any]] = None) -> str:
def build_url(
self, path: Sequence[str], params: Optional[Dict[str, Any]] = None
) -> str:
"""Build a full URL with query parameters for redirect/authorization endpoints."""
from urllib.parse import urlencode

base = self._base_url.rstrip("/")
url = f"{base}/{path}"
url = f"{base}/{self._encode_path(path)}"
if params:
url = f"{url}?{urlencode(params)}"
return url
Expand Down Expand Up @@ -128,6 +138,27 @@ def _resolve_base_url(self, request_options: Optional[RequestOptions]) -> str:
return str(base_url).rstrip("/")
return self._base_url.rstrip("/")

@staticmethod
def _encode_path(path: Sequence[str]) -> str:
"""Percent-encode each path segment and join with ``/``.

Callers pass each path component as a separate element (e.g.
``("organizations", organization_id)``). Each element is URL-encoded
with ``safe=""`` so a caller-supplied id containing ``/``, ``?``,
``#``, ``%``, or ``..`` cannot escape its intended segment — this is
the structural protection against forged cross-resource API requests
under the application's API key.

A bare string would be silently iterable as a sequence of single
characters; we reject it explicitly so a forgotten tuple wrapper at a
call site fails loudly instead of producing a per-character URL.
"""
if isinstance(path, str):
raise TypeError(
"path must be a sequence of segments (e.g. a tuple), not a str"
)
return "/".join(quote(str(seg), safe="") for seg in path)

def _resolve_timeout(self, request_options: Optional[RequestOptions]) -> float:
timeout = self._request_timeout
if request_options:
Expand Down Expand Up @@ -332,6 +363,7 @@ def __init__(
request_timeout: Optional[int] = None,
jwt_leeway: float = 0.0,
max_retries: int = MAX_RETRIES,
is_public: bool = False,
) -> None:
"""Initialize the WorkOS client.

Expand All @@ -342,6 +374,10 @@ def __init__(
request_timeout: HTTP request timeout in seconds. Falls back to WORKOS_REQUEST_TIMEOUT or 60.
jwt_leeway: JWT clock skew leeway in seconds.
max_retries: Maximum number of retries for failed requests. Defaults to 3.
is_public: When True, mark this client as public (PKCE / browser
/ mobile / CLI). The API key is forced to None and the
``WORKOS_API_KEY`` environment variable is ignored. Use
``create_public_client`` instead of setting this directly.

Raises:
ValueError: If neither api_key nor client_id is provided, directly or via environment variables.
Expand All @@ -353,6 +389,7 @@ def __init__(
request_timeout=request_timeout,
jwt_leeway=jwt_leeway,
max_retries=max_retries,
is_public=is_public,
)
self._client = httpx.Client(
timeout=self._request_timeout, follow_redirects=True
Expand All @@ -372,7 +409,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
def request(
self,
method: str,
path: str,
path: Sequence[str],
*,
model: Type[D],
params: Optional[Dict[str, Any]] = ...,
Expand All @@ -385,7 +422,7 @@ def request(
def request(
self,
method: str,
path: str,
path: Sequence[str],
*,
model: None = ...,
params: Optional[Dict[str, Any]] = ...,
Expand All @@ -397,7 +434,7 @@ def request(
def request(
self,
method: str,
path: str,
path: Sequence[str],
*,
params: Optional[Dict[str, Any]] = None,
body: Optional[Dict[str, Any]] = None,
Expand All @@ -406,7 +443,7 @@ def request(
request_options: Optional[RequestOptions] = None,
) -> Any:
"""Make an HTTP request with retry logic."""
url = f"{self._resolve_base_url(request_options)}/{path}"
url = f"{self._resolve_base_url(request_options)}/{self._encode_path(path)}"
headers = self._build_headers(method, idempotency_key, request_options)
timeout = self._resolve_timeout(request_options)
max_retries = self._resolve_max_retries(request_options)
Expand Down Expand Up @@ -453,7 +490,7 @@ def request(
def request_raw(
self,
method: str,
path: str,
path: Sequence[str],
*,
params: Optional[Dict[str, Any]] = None,
body: Optional[Dict[str, Any]] = None,
Expand All @@ -478,7 +515,7 @@ def request_raw(
def request_list(
self,
method: str,
path: str,
path: Sequence[str],
*,
params: Optional[Dict[str, Any]] = None,
body: Optional[Dict[str, Any]] = None,
Expand All @@ -500,14 +537,14 @@ def request_list(
)
if not isinstance(result, list):
raise WorkOSError(
f"Expected array response from {method.upper()} /{path}, got {type(result).__name__}"
f"Expected array response from {method.upper()} /{'/'.join(path)}, got {type(result).__name__}"
)
return result

def request_page(
self,
method: str,
path: str,
path: Sequence[str],
*,
model: Type[D],
params: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -557,6 +594,7 @@ def __init__(
request_timeout: Optional[int] = None,
jwt_leeway: float = 0.0,
max_retries: int = MAX_RETRIES,
is_public: bool = False,
) -> None:
"""Initialize the async WorkOS client.

Expand All @@ -578,6 +616,7 @@ def __init__(
request_timeout=request_timeout,
jwt_leeway=jwt_leeway,
max_retries=max_retries,
is_public=is_public,
)
self._client = httpx.AsyncClient(
timeout=self._request_timeout, follow_redirects=True
Expand All @@ -597,7 +636,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
async def request(
self,
method: str,
path: str,
path: Sequence[str],
*,
model: Type[D],
params: Optional[Dict[str, Any]] = ...,
Expand All @@ -610,7 +649,7 @@ async def request(
async def request(
self,
method: str,
path: str,
path: Sequence[str],
*,
model: None = ...,
params: Optional[Dict[str, Any]] = ...,
Expand All @@ -622,7 +661,7 @@ async def request(
async def request(
self,
method: str,
path: str,
path: Sequence[str],
*,
params: Optional[Dict[str, Any]] = None,
body: Optional[Dict[str, Any]] = None,
Expand All @@ -631,7 +670,7 @@ async def request(
request_options: Optional[RequestOptions] = None,
) -> Any:
"""Make an async HTTP request with retry logic."""
url = f"{self._resolve_base_url(request_options)}/{path}"
url = f"{self._resolve_base_url(request_options)}/{self._encode_path(path)}"
headers = self._build_headers(method, idempotency_key, request_options)
timeout = self._resolve_timeout(request_options)
max_retries = self._resolve_max_retries(request_options)
Expand Down Expand Up @@ -678,7 +717,7 @@ async def request(
async def request_raw(
self,
method: str,
path: str,
path: Sequence[str],
*,
params: Optional[Dict[str, Any]] = None,
body: Optional[Dict[str, Any]] = None,
Expand All @@ -703,7 +742,7 @@ async def request_raw(
async def request_list(
self,
method: str,
path: str,
path: Sequence[str],
*,
params: Optional[Dict[str, Any]] = None,
body: Optional[Dict[str, Any]] = None,
Expand All @@ -725,14 +764,14 @@ async def request_list(
)
if not isinstance(result, list):
raise WorkOSError(
f"Expected array response from {method.upper()} /{path}, got {type(result).__name__}"
f"Expected array response from {method.upper()} /{'/'.join(path)}, got {type(result).__name__}"
)
return result

async def request_page(
self,
method: str,
path: str,
path: Sequence[str],
*,
model: Type[D],
params: Optional[Dict[str, Any]] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/workos/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _verify_signature(
timestamp_in_seconds = int(issued_timestamp) / 1000
seconds_since_issued = current_time - timestamp_in_seconds

if seconds_since_issued > tolerance:
if abs(seconds_since_issued) > tolerance:
raise ValueError("Timestamp outside the tolerance zone")

body_str = payload.decode("utf-8") if isinstance(payload, bytes) else payload
Expand Down
4 changes: 2 additions & 2 deletions src/workos/admin_portal/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def generate_link(
}
return self._client.request(
method="post",
path="portal/generate_link",
path=("portal", "generate_link"),
body=body,
model=PortalLinkResponse,
request_options=request_options,
Expand Down Expand Up @@ -149,7 +149,7 @@ async def generate_link(
}
return await self._client.request(
method="post",
path="portal/generate_link",
path=("portal", "generate_link"),
body=body,
model=PortalLinkResponse,
request_options=request_options,
Expand Down
17 changes: 8 additions & 9 deletions src/workos/api_keys/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from urllib.parse import quote

if TYPE_CHECKING:
from .._client import AsyncWorkOSClient, WorkOSClient
Expand Down Expand Up @@ -67,7 +66,7 @@ def list_organization_api_keys(
}
return self._client.request_page(
method="get",
path=f"organizations/{quote(str(organization_id), safe='')}/api_keys",
path=("organizations", str(organization_id), "api_keys"),
model=OrganizationApiKey,
params=params,
request_options=request_options,
Expand Down Expand Up @@ -111,7 +110,7 @@ def create_organization_api_key(
}
return self._client.request(
method="post",
path=f"organizations/{quote(str(organization_id), safe='')}/api_keys",
path=("organizations", str(organization_id), "api_keys"),
body=body,
model=OrganizationApiKeyWithValue,
request_options=request_options,
Expand Down Expand Up @@ -145,7 +144,7 @@ def create_validation(
}
return self._client.request(
method="post",
path="api_keys/validations",
path=("api_keys", "validations"),
body=body,
model=ApiKeyValidationResponse,
request_options=request_options,
Expand Down Expand Up @@ -173,7 +172,7 @@ def delete_api_key(
"""
self._client.request(
method="delete",
path=f"api_keys/{quote(str(id), safe='')}",
path=("api_keys", str(id)),
request_options=request_options,
)

Expand Down Expand Up @@ -227,7 +226,7 @@ async def list_organization_api_keys(
}
return await self._client.request_page(
method="get",
path=f"organizations/{quote(str(organization_id), safe='')}/api_keys",
path=("organizations", str(organization_id), "api_keys"),
model=OrganizationApiKey,
params=params,
request_options=request_options,
Expand Down Expand Up @@ -271,7 +270,7 @@ async def create_organization_api_key(
}
return await self._client.request(
method="post",
path=f"organizations/{quote(str(organization_id), safe='')}/api_keys",
path=("organizations", str(organization_id), "api_keys"),
body=body,
model=OrganizationApiKeyWithValue,
request_options=request_options,
Expand Down Expand Up @@ -305,7 +304,7 @@ async def create_validation(
}
return await self._client.request(
method="post",
path="api_keys/validations",
path=("api_keys", "validations"),
body=body,
model=ApiKeyValidationResponse,
request_options=request_options,
Expand Down Expand Up @@ -333,6 +332,6 @@ async def delete_api_key(
"""
await self._client.request(
method="delete",
path=f"api_keys/{quote(str(id), safe='')}",
path=("api_keys", str(id)),
request_options=request_options,
)
Loading
Loading