diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index e7eff55f..261b4e01 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -540,70 +540,73 @@ async def _parse_multipart_part( :param part: aiohttp BodyPartReader for the part :return: ExecutionResult or None if part is empty/heartbeat """ - # Verify the part has the correct content type + try: + body = await part.text() + body = body.strip() + except UnicodeDecodeError as e: + log.warning(f"Failed to decode part: {ascii(e)}") + return None + content_type = part.headers.get(aiohttp.hdrs.CONTENT_TYPE, "") + + if not content_type and not body: + log.debug("Skipping part with no content-type and no body") + return None + + # Verify the part has the correct content type if not content_type.startswith("application/json"): raise TransportProtocolError( f"Unexpected part content-type: {content_type}. " "Expected 'application/json'." ) - try: - # Read the part content as text - body = await part.text() - body = body.strip() - - if log.isEnabledFor(logging.DEBUG): - log.debug("<<< %s", ascii(body or "(empty body, skipping)")) + if log.isEnabledFor(logging.DEBUG): + log.debug("<<< %s", ascii(body or "(empty body, skipping)")) - if not body: - return None + if not body: + return None - # Parse JSON body using custom deserializer + try: data = self.json_deserialize(body) - - # Handle heartbeats - empty JSON objects - if not data: - log.debug("Received heartbeat, ignoring") - return None - - # The multipart subscription protocol wraps data in a "payload" property - if "payload" not in data: - log.warning("Invalid response: missing 'payload' field") - return None - - payload = data["payload"] - - # Check for transport-level errors (payload is null) - if payload is None: - # If there are errors, this is a transport-level error - errors = data.get("errors") - if errors: - error_messages = [ - error.get("message", "Unknown transport error") - for error in errors - ] - - for message in error_messages: - log.error(f"Transport error: {message}") - - raise TransportServerError("\n\n".join(error_messages)) - else: - # Null payload without errors - just skip this part - return None - - # Extract GraphQL data from payload - return ExecutionResult( - data=payload.get("data"), - errors=payload.get("errors"), - extensions=payload.get("extensions"), - ) except json.JSONDecodeError as e: log.warning( f"Failed to parse JSON: {ascii(e)}, " f"body: {ascii(body[:100]) if body else ''}" ) return None - except UnicodeDecodeError as e: - log.warning(f"Failed to decode part: {ascii(e)}") + + # Handle heartbeats - empty JSON objects + if not data: + log.debug("Received heartbeat, ignoring") + return None + + # The multipart subscription protocol wraps data in a "payload" property + if "payload" not in data: + log.warning("Invalid response: missing 'payload' field") return None + + payload = data["payload"] + + # Check for transport-level errors (payload is null) + if payload is None: + # If there are errors, this is a transport-level error + errors = data.get("errors") + if errors: + error_messages = [ + error.get("message", "Unknown transport error") for error in errors + ] + + for message in error_messages: + log.error(f"Transport error: {message}") + + raise TransportServerError("\n\n".join(error_messages)) + else: + # Null payload without errors - just skip this part + return None + + # Extract GraphQL data from payload + return ExecutionResult( + data=payload.get("data"), + errors=payload.get("errors"), + extensions=payload.get("extensions"), + ) diff --git a/tests/test_aiohttp_multipart.py b/tests/test_aiohttp_multipart.py index d71814a4..d43a2722 100644 --- a/tests/test_aiohttp_multipart.py +++ b/tests/test_aiohttp_multipart.py @@ -480,6 +480,38 @@ async def test_aiohttp_multipart_wrong_part_content_type(multipart_server): assert "text/html" in str(exc_info.value) +@pytest.mark.asyncio +async def test_aiohttp_multipart_empty_part_no_content_type_skipped(multipart_server): + """Test that empty parts with no content-type are skipped.""" + from gql.transport.aiohttp import AIOHTTPTransport + + book1_payload = json.dumps({"payload": {"data": {"book": book1}}}) + + parts = [ + ("--graphql\r\n" "\r\n" "\r\n"), + ( + "--graphql\r\n" + "Content-Type: application/json\r\n" + "\r\n" + f"{book1_payload}\r\n" + ), + "--graphql--\r\n", + ] + + server = await multipart_server(parts) + url = server.make_url("/") + transport = AIOHTTPTransport(url=url) + + async with Client(transport=transport) as session: + query = gql(subscription_str) + results = [] + async for result in session.subscribe(query): + results.append(result) + + assert len(results) == 1 + assert results[0]["book"]["title"] == "Book 1" + + @pytest.mark.asyncio async def test_aiohttp_multipart_response_headers(multipart_server): """Test that response headers are captured in the transport."""