diff --git a/optillm/server.py b/optillm/server.py index bc2c88f..0d9c6bb 100644 --- a/optillm/server.py +++ b/optillm/server.py @@ -203,24 +203,38 @@ def count_reasoning_tokens(text: str, tokenizer=None) -> int: plugin_approaches = {} +def _is_multimodal_content(content: list) -> bool: + for item in content: + if not isinstance(item, dict): + continue + item_type = item.get('type') + if item_type and item_type != 'text': + return True + if 'image_url' in item or 'image' in item: + return True + return False + def normalize_message_content(messages): """ - Ensure all message content fields are strings, not lists. - Some models don't handle list-format content correctly. + Flatten text-only list content to strings for models that expect plain text. + + Multimodal content (image_url, etc.) is preserved as a list so vision-capable + upstream models receive images intact. """ normalized_messages = [] for message in messages: normalized_message = message.copy() content = message.get('content', '') - # Convert list content to string if needed if isinstance(content, list): - # Extract text content from the list - text_content = ' '.join( - item.get('text', '') for item in content - if isinstance(item, dict) and item.get('type') == 'text' - ) - normalized_message['content'] = text_content + if _is_multimodal_content(content): + normalized_message['content'] = content + else: + text_content = ' '.join( + item.get('text', '') for item in content + if isinstance(item, dict) and item.get('type') == 'text' + ) + normalized_message['content'] = text_content normalized_messages.append(normalized_message) @@ -548,6 +562,67 @@ def execute_n_times(n: int, approaches, operation: str, system_prompt: str, init return responses[0], total_tokens return responses, total_tokens +def promote_tool_calls_to_first_choice(response_dict): + """Merge tool_calls from a non-zero choice into choices[0]. + + Some OpenAI-compatible providers return assistant text in choices[0] and + tool_calls in a later choice. Clients that only read choices[0] miss tools. + """ + if not isinstance(response_dict, dict): + return response_dict + + choices = response_dict.get('choices') + if not choices or len(choices) < 2: + return response_dict + + tool_idx = None + for idx, choice in enumerate(choices): + message = choice.get('message') or {} + if message.get('tool_calls'): + tool_idx = idx + break + + if tool_idx is None or tool_idx == 0: + return response_dict + + promoted = dict(response_dict) + text_choice = dict(choices[0]) + tool_choice = choices[tool_idx] + merged_message = dict(text_choice.get('message') or {}) + tool_message = tool_choice.get('message') or {} + + if tool_message.get('tool_calls'): + merged_message['tool_calls'] = tool_message['tool_calls'] + + text_choice['message'] = merged_message + text_choice['finish_reason'] = tool_choice.get('finish_reason') or 'tool_calls' + promoted['choices'] = [text_choice] + return promoted + +def generate_stream_passthrough(messages, client, model, kwargs, request_id=None): + """Stream upstream SSE chunks verbatim for transparent none-approach proxying.""" + if model.startswith('none-'): + model = model[5:] + + passthrough_kwargs = kwargs.copy() if kwargs else {} + passthrough_kwargs['stream'] = True + normalized_messages = normalize_message_content(messages) + + stream = client.chat.completions.create( + model=model, + messages=normalized_messages, + **passthrough_kwargs + ) + + for chunk in stream: + chunk_dict = chunk.model_dump() if hasattr(chunk, 'model_dump') else chunk + yield "data: " + json.dumps(chunk_dict) + "\n\n" + + yield "data: [DONE]\n\n" + + if request_id: + logger.info(f'Request {request_id}: Completed (streaming passthrough)') + def generate_streaming_response(final_response, model): # Generate a unique response ID response_id = f"chatcmpl-{int(time.time()*1000)}" @@ -821,24 +896,35 @@ def proxy(): contains_none = any(approach == 'none' for approach in approaches) if operation == 'SINGLE' and approaches[0] == 'none': - # Pass through the request including the n parameter - result, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config, request_id) + passthrough_kwargs = request_config.copy() if request_config else {} + + if stream: + if request_id: + logger.info(f'Request {request_id}: Starting streaming passthrough') + return Response( + generate_stream_passthrough(messages, client, model, passthrough_kwargs, request_id), + content_type='text/event-stream' + ) + + passthrough_kwargs.pop('stream', None) + result = none_approach( + original_messages=messages, + client=client, + model=model, + request_id=request_id, + **passthrough_kwargs + ) + result = promote_tool_calls_to_first_choice(result) logger.debug(f'Direct proxy response: {result}') - # Log the final response and finalize conversation logging if conversation_logger and request_id: conversation_logger.log_final_response(request_id, result) conversation_logger.finalize_conversation(request_id) - if stream: - if request_id: - logger.info(f'Request {request_id}: Completed (streaming response)') - return Response(generate_streaming_response(extract_contents(result), model), content_type='text/event-stream') - else : - if request_id: - logger.info(f'Request {request_id}: Completed') - return jsonify(result), 200 + if request_id: + logger.info(f'Request {request_id}: Completed') + return jsonify(result), 200 elif operation == 'AND' or operation == 'OR': if contains_none: