|
77 | 77 | ) |
78 | 78 | from nemoguardrails.kb.kb import KnowledgeBase |
79 | 79 | from nemoguardrails.llm.cache import CacheInterface, LFUCache |
| 80 | +from nemoguardrails.llm.clients._errors import _redact_secrets |
80 | 81 | from nemoguardrails.llm.models.initializer import ( |
81 | 82 | ModelInitializationError, |
82 | 83 | init_llm_model, |
@@ -899,14 +900,9 @@ async def generate_async( |
899 | 900 | log.error("Error in generate_async: %s", e, exc_info=True) |
900 | 901 | streaming_handler = streaming_handler_var.get() |
901 | 902 | if streaming_handler: |
902 | | - # Push an error chunk instead of None. |
903 | | - error_message = str(e) |
904 | | - error_dict = extract_error_json(error_message) |
905 | | - error_payload: str = json.dumps(error_dict) |
| 903 | + error_payload: str = _build_streaming_error_payload(e) |
906 | 904 | await streaming_handler.push_chunk(error_payload) |
907 | | - # push a termination signal |
908 | 905 | await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore |
909 | | - # Re-raise the exact exception |
910 | 906 | raise |
911 | 907 | else: |
912 | 908 | # In generation mode, by default the bot response is an instant action. |
@@ -1265,12 +1261,8 @@ async def _generation_task(): |
1265 | 1261 | state=state, |
1266 | 1262 | ) |
1267 | 1263 | except Exception as e: |
1268 | | - # If an exception occurs during generation, push it to the streaming handler as a json string |
1269 | | - # This ensures the streaming pipeline is properly terminated |
1270 | 1264 | log.error(f"Error in generation task: {e}", exc_info=True) |
1271 | | - error_message = str(e) |
1272 | | - error_dict = extract_error_json(error_message) |
1273 | | - error_payload = json.dumps(error_dict) |
| 1265 | + error_payload = _build_streaming_error_payload(e) |
1274 | 1266 | await streaming_handler.push_chunk(error_payload) |
1275 | 1267 | await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore |
1276 | 1268 |
|
@@ -1931,3 +1923,42 @@ def _get_last_response_content(response: "GenerationResponse") -> str: |
1931 | 1923 | if isinstance(response.response, str): |
1932 | 1924 | return response.response |
1933 | 1925 | return "" |
| 1926 | + |
| 1927 | + |
def _build_streaming_error_payload(e: Exception) -> str:
    """Serialize an exception into the JSON error chunk used for SSE streaming.

    Every shape returned by extract_error_json is normalized to the
    {"error": {"message", "type", "code"}} structure that iorails.py
    relies on when detecting error chunks.
    """
    payload = extract_error_json(str(e))
    if not isinstance(payload, dict):
        payload = {}

    status = getattr(e, "status", None)
    has_status = status is not None
    # Errors carrying an HTTP-style status come from a downstream service;
    # everything else is treated as a local generation failure.
    default_type = "downstream_error" if has_status else "generation_error"
    default_code = status if has_status else "generation_failed"

    inner = payload.get("error")
    if isinstance(inner, dict):
        inner["message"] = _redact_secrets(inner.get("message", ""))
        if has_status:
            # A concrete status overrides whatever the parsed error carried.
            inner["code"] = status
            inner["type"] = "downstream_error"
        else:
            inner.setdefault("type", default_type)
            inner.setdefault("code", default_code)
    else:
        # A plain-string error (or anything unrecognized) is wrapped into the
        # canonical structure; non-string values fall back to str(e).
        message = inner if isinstance(inner, str) else str(e)
        payload["error"] = {
            "message": _redact_secrets(message),
            "type": default_type,
            "code": default_code,
        }

    return json.dumps(payload)
0 commit comments