Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 53 additions & 50 deletions gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,70 +540,73 @@ async def _parse_multipart_part(
:param part: aiohttp BodyPartReader for the part
:return: ExecutionResult or None if part is empty/heartbeat
"""
# Verify the part has the correct content type
try:
body = await part.text()
body = body.strip()
except UnicodeDecodeError as e:
log.warning(f"Failed to decode part: {ascii(e)}")
return None

content_type = part.headers.get(aiohttp.hdrs.CONTENT_TYPE, "")

if not content_type and not body:
log.debug("Skipping part with no content-type and no body")
return None

# Verify the part has the correct content type
if not content_type.startswith("application/json"):
raise TransportProtocolError(
f"Unexpected part content-type: {content_type}. "
"Expected 'application/json'."
)

try:
# Read the part content as text
body = await part.text()
body = body.strip()

if log.isEnabledFor(logging.DEBUG):
log.debug("<<< %s", ascii(body or "(empty body, skipping)"))
if log.isEnabledFor(logging.DEBUG):
log.debug("<<< %s", ascii(body or "(empty body, skipping)"))

if not body:
return None
if not body:
return None

# Parse JSON body using custom deserializer
try:
data = self.json_deserialize(body)

# Handle heartbeats - empty JSON objects
if not data:
log.debug("Received heartbeat, ignoring")
return None

# The multipart subscription protocol wraps data in a "payload" property
if "payload" not in data:
log.warning("Invalid response: missing 'payload' field")
return None

payload = data["payload"]

# Check for transport-level errors (payload is null)
if payload is None:
# If there are errors, this is a transport-level error
errors = data.get("errors")
if errors:
error_messages = [
error.get("message", "Unknown transport error")
for error in errors
]

for message in error_messages:
log.error(f"Transport error: {message}")

raise TransportServerError("\n\n".join(error_messages))
else:
# Null payload without errors - just skip this part
return None

# Extract GraphQL data from payload
return ExecutionResult(
data=payload.get("data"),
errors=payload.get("errors"),
extensions=payload.get("extensions"),
)
except json.JSONDecodeError as e:
log.warning(
f"Failed to parse JSON: {ascii(e)}, "
f"body: {ascii(body[:100]) if body else ''}"
)
return None
except UnicodeDecodeError as e:
log.warning(f"Failed to decode part: {ascii(e)}")

# Handle heartbeats - empty JSON objects
if not data:
log.debug("Received heartbeat, ignoring")
return None

# The multipart subscription protocol wraps data in a "payload" property
if "payload" not in data:
log.warning("Invalid response: missing 'payload' field")
return None

payload = data["payload"]

# Check for transport-level errors (payload is null)
if payload is None:
# If there are errors, this is a transport-level error
errors = data.get("errors")
if errors:
error_messages = [
error.get("message", "Unknown transport error") for error in errors
]

for message in error_messages:
log.error(f"Transport error: {message}")

raise TransportServerError("\n\n".join(error_messages))
else:
# Null payload without errors - just skip this part
return None

# Extract GraphQL data from payload
return ExecutionResult(
data=payload.get("data"),
errors=payload.get("errors"),
extensions=payload.get("extensions"),
)
32 changes: 32 additions & 0 deletions tests/test_aiohttp_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,38 @@ async def test_aiohttp_multipart_wrong_part_content_type(multipart_server):
assert "text/html" in str(exc_info.value)


@pytest.mark.asyncio
async def test_aiohttp_multipart_empty_part_no_content_type_skipped(multipart_server):
"""Test that empty parts with no content-type are skipped."""
from gql.transport.aiohttp import AIOHTTPTransport

book1_payload = json.dumps({"payload": {"data": {"book": book1}}})

parts = [
("--graphql\r\n" "\r\n" "\r\n"),
(
"--graphql\r\n"
"Content-Type: application/json\r\n"
"\r\n"
f"{book1_payload}\r\n"
),
"--graphql--\r\n",
]

server = await multipart_server(parts)
url = server.make_url("/")
transport = AIOHTTPTransport(url=url)

async with Client(transport=transport) as session:
query = gql(subscription_str)
results = []
async for result in session.subscribe(query):
results.append(result)

assert len(results) == 1
assert results[0]["book"]["title"] == "Book 1"


@pytest.mark.asyncio
async def test_aiohttp_multipart_response_headers(multipart_server):
"""Test that response headers are captured in the transport."""
Expand Down
Loading