243 lines
8.4 KiB
Python
243 lines
8.4 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Iterable, Iterator, List
|
|
|
|
from .config import BASE_INSTRUCTIONS, GPT5_CODEX_INSTRUCTIONS
|
|
from .fast_mode import ServiceTierResolution, resolve_service_tier
|
|
from .model_registry import (
|
|
allowed_efforts_for_model,
|
|
extract_reasoning_from_model_name,
|
|
normalize_model_name,
|
|
uses_codex_instructions,
|
|
)
|
|
from .reasoning import build_reasoning_param
|
|
from .session import ensure_session_id
|
|
|
|
|
|
@dataclass(unsafe_hash=True)
class ResponsesRequestError(Exception):
    """Error raised when a Responses request payload cannot be accepted.

    Attributes are documented inline below. Instances compare and hash by
    their field values.
    """

    # NOTE: this dataclass is deliberately *not* frozen. A frozen dataclass
    # rejects every attribute assignment, including ``exc.__traceback__ = tb``,
    # which tooling such as ``contextlib.contextmanager`` performs on
    # exceptions passing through it. ``unsafe_hash=True`` keeps the value-based
    # ``__eq__``/``__hash__`` pair that the frozen variant provided.

    # Human-readable description; also what ``str(exc)`` returns.
    message: str
    # Numeric status associated with the error (defaults to 400).
    status_code: int = 400
    # Optional machine-readable error code.
    code: str | None = None

    def __post_init__(self) -> None:
        # The dataclass-generated __init__ does not call Exception.__init__,
        # which would leave ``args`` empty and break copy/pickle round-trips.
        super().__init__(self.message)

    def __str__(self) -> str:
        return self.message
|
|
|
|
|
|
@dataclass(frozen=True)
class NormalizedResponsesRequest:
    """Immutable result bundle produced by ``normalize_responses_payload``."""

    # The normalized request body.
    payload: Dict[str, Any]
    # The ``model`` value the client sent, or None when it was not a string.
    requested_model: str | None
    # The model name after normalization.
    normalized_model: str
    # Session identifier associated with this request.
    session_id: str
    # Outcome of service-tier resolution for this request.
    service_tier_resolution: ServiceTierResolution
|
|
|
|
|
|
def instructions_for_model(config: Dict[str, Any], model: str) -> str:
    """Pick the instruction text to use for ``model``.

    The base text comes from ``config["BASE_INSTRUCTIONS"]`` when that key is
    present, otherwise from the module-level ``BASE_INSTRUCTIONS``. When
    ``uses_codex_instructions(model)`` is true, the codex text (config value,
    or the module-level ``GPT5_CODEX_INSTRUCTIONS`` if the config value is
    falsy) is returned instead, provided it is a non-blank string.
    """
    default_prompt = config.get("BASE_INSTRUCTIONS", BASE_INSTRUCTIONS)
    if not uses_codex_instructions(model):
        return default_prompt

    codex_prompt = config.get("GPT5_CODEX_INSTRUCTIONS") or GPT5_CODEX_INSTRUCTIONS
    codex_usable = isinstance(codex_prompt, str) and codex_prompt.strip()
    # A blank or non-string codex prompt falls back to the base text.
    return codex_prompt if codex_usable else default_prompt
|
|
|
|
|
|
def extract_client_session_id(headers: Any) -> str | None:
    """Return the first truthy session id found in ``headers``, else None.

    ``X-Session-Id`` is consulted before ``session_id``. Any exception raised
    while reading ``headers`` (for example when it has no ``get`` method)
    yields None rather than propagating.
    """
    try:
        for header_name in ("X-Session-Id", "session_id"):
            candidate = headers.get(header_name)
            if candidate:
                return candidate
    except Exception:
        return None
    return None
|
|
|
|
|
|
def _input_items_for_session(raw_input: Any) -> List[Dict[str, Any]]:
|
|
if isinstance(raw_input, list):
|
|
return [item for item in raw_input if isinstance(item, dict)]
|
|
if isinstance(raw_input, dict):
|
|
return [raw_input]
|
|
if isinstance(raw_input, str) and raw_input.strip():
|
|
return [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [{"type": "input_text", "text": raw_input}],
|
|
}
|
|
]
|
|
return []
|
|
|
|
|
|
def canonicalize_responses_input(raw_input: Any) -> Any:
    """Convert list, dict and str inputs to a list of dict items.

    Those three types are handed to ``_input_items_for_session``, which
    filters lists down to their dict members, wraps a dict in a list, and
    turns a string into a single user message (or an empty list when the
    string is blank). Every other value is returned unchanged.
    """
    if isinstance(raw_input, (list, dict, str)):
        return _input_items_for_session(raw_input)
    return raw_input
|
|
|
|
|
|
def normalize_responses_payload(
    payload: Dict[str, Any],
    *,
    config: Dict[str, Any],
    client_session_id: str | None = None,
) -> NormalizedResponsesRequest:
    """Produce a normalized copy of a Responses request payload.

    The caller's ``payload`` is shallow-copied and the copy is adjusted:
    model name normalized, ``input`` canonicalized, ``store`` defaulted to
    False, instructions/reasoning/include filled in, an optional default
    web-search tool added, the service tier resolved, and a
    ``prompt_cache_key`` defaulted to the session id.

    Args:
        payload: The client-supplied request body.
        config: Server configuration mapping read via ``config.get``.
        client_session_id: Session id supplied by the client, if any; passed
            through to ``ensure_session_id``.

    Returns:
        A ``NormalizedResponsesRequest`` bundling the adjusted payload with
        the requested/normalized model names, the session id and the
        service-tier resolution.

    Raises:
        ResponsesRequestError: When service-tier resolution reports an
            ``error_message``.
    """
    model_field = payload.get("model")
    requested_model = model_field if isinstance(model_field, str) else None
    normalized_model = normalize_model_name(requested_model, config.get("DEBUG_MODEL"))

    body = dict(payload)
    body["model"] = normalized_model

    if "input" in body:
        body["input"] = canonicalize_responses_input(body["input"])

    body.setdefault("store", False)

    # Client instructions win only when they are a non-blank string.
    instructions = body.get("instructions")
    if not (isinstance(instructions, str) and instructions.strip()):
        instructions = instructions_for_model(config, normalized_model)
    body["instructions"] = instructions

    default_effort = config.get("REASONING_EFFORT", "medium")
    default_summary = config.get("REASONING_SUMMARY", "auto")
    # An explicit ``reasoning`` dict takes precedence over whatever the
    # requested model name yields.
    overrides = body.get("reasoning")
    if not isinstance(overrides, dict):
        overrides = extract_reasoning_from_model_name(requested_model)
    body["reasoning"] = build_reasoning_param(
        default_effort,
        default_summary,
        overrides,
        allowed_efforts=allowed_efforts_for_model(normalized_model),
    )

    # Keep only string entries from ``include`` and make sure the encrypted
    # reasoning content entry is present.
    requested_include = body.get("include")
    include_list: List[str] = []
    if isinstance(requested_include, list):
        include_list.extend(entry for entry in requested_include if isinstance(entry, str))
    if "reasoning.encrypted_content" not in include_list:
        include_list.append("reasoning.encrypted_content")
    body["include"] = include_list

    # Add the default web-search tool only when no tools were supplied, the
    # server enables it, and the client did not set tool_choice to "none".
    tools = body.get("tools")
    lacks_tools = not isinstance(tools, list) or not tools
    if lacks_tools and bool(config.get("DEFAULT_WEB_SEARCH")):
        tool_choice = body.get("tool_choice")
        tools_disabled = isinstance(tool_choice, str) and tool_choice.strip().lower() == "none"
        if not tools_disabled:
            body["tools"] = [{"type": "web_search"}]

    service_tier_resolution = resolve_service_tier(
        normalized_model,
        request_fast_mode=body.get("fast_mode"),
        request_service_tier=body.get("service_tier"),
        server_fast_mode=bool(config.get("FAST_MODE")),
    )
    if service_tier_resolution.error_message:
        raise ResponsesRequestError(service_tier_resolution.error_message)

    resolved_tier = service_tier_resolution.service_tier
    if resolved_tier is None:
        body.pop("service_tier", None)
    else:
        body["service_tier"] = resolved_tier
    # ``fast_mode`` has been consumed by the resolution above; drop it.
    body.pop("fast_mode", None)

    session_items = _input_items_for_session(body.get("input"))
    session_id = ensure_session_id(instructions, session_items, client_session_id)

    cache_key = body.get("prompt_cache_key")
    if not (isinstance(cache_key, str) and cache_key.strip()):
        body["prompt_cache_key"] = session_id

    return NormalizedResponsesRequest(
        payload=body,
        requested_model=requested_model,
        normalized_model=normalized_model,
        session_id=session_id,
        service_tier_resolution=service_tier_resolution,
    )
|
|
|
|
|
|
def iter_sse_event_payloads(upstream: Any) -> Iterator[Dict[str, Any]]:
    """Yield the JSON-object payload of each ``data`` line from ``upstream``.

    Args:
        upstream: Object exposing ``iter_lines(decode_unicode=False)`` that
            yields the lines of a server-sent-event stream as bytes or str.

    Yields:
        The decoded payload of every ``data`` line whose content parses as a
        JSON object. Blank lines, non-``data`` lines, unparseable payloads and
        payloads that are not JSON objects are skipped. Iteration stops at the
        ``[DONE]`` sentinel.
    """
    for raw in upstream.iter_lines(decode_unicode=False):
        if not raw:
            continue
        if isinstance(raw, (bytes, bytearray)):
            line = raw.decode("utf-8", errors="ignore")
        else:
            line = raw
        # The space after the colon is optional in the SSE wire format, so
        # match on "data:" and let strip() remove any leading space.
        if not line.startswith("data:"):
            continue
        data = line[len("data:") :].strip()
        if data == "[DONE]":
            break
        if not data:
            continue
        try:
            evt = json.loads(data)
        except (ValueError, RecursionError):
            # Malformed payloads are skipped rather than aborting the stream.
            continue
        if isinstance(evt, dict):
            yield evt
|
|
|
|
|
|
def aggregate_response_from_sse(
    upstream: Any,
    *,
    on_event: Any | None = None,
) -> tuple[Dict[str, Any] | None, Dict[str, Any] | None]:
    """Consume an SSE stream and return its last response object and error.

    Args:
        upstream: Stream handed to ``iter_sse_event_payloads``; its
            ``close()`` method is always called before returning.
        on_event: Optional callable invoked with every event dict. Exceptions
            it raises are ignored.

    Returns:
        ``(response, error)`` where ``response`` is the most recent dict found
        under an event's ``"response"`` key (or None), and ``error`` is a
        ``{"error": ...}`` dict when a ``response.failed`` event was seen
        (otherwise None). Reading stops at the first ``response.failed`` or
        ``response.completed`` event.
    """
    latest_response: Dict[str, Any] | None = None
    failure: Dict[str, Any] | None = None
    notify = on_event if callable(on_event) else None

    try:
        for event in iter_sse_event_payloads(upstream):
            if notify is not None:
                try:
                    notify(event)
                except Exception:
                    # Observer failures must not interrupt aggregation.
                    pass

            snapshot = event.get("response")
            has_snapshot = isinstance(snapshot, dict)
            if has_snapshot:
                latest_response = snapshot

            event_type = event.get("type")
            if event_type == "response.completed":
                break
            if event_type == "response.failed":
                upstream_error = snapshot.get("error") if has_snapshot else None
                if isinstance(upstream_error, dict):
                    failure = {"error": upstream_error}
                else:
                    # No structured error supplied; use a generic message.
                    failure = {"error": {"message": "response.failed"}}
                break
    finally:
        upstream.close()

    return latest_response, failure
|
|
|
|
|
|
def stream_upstream_bytes(
    upstream: Any,
    *,
    on_event: Any | None = None,
) -> Iterable[bytes]:
    """Relay upstream chunks unchanged while optionally observing SSE events.

    Args:
        upstream: Object exposing ``iter_content(chunk_size=None)`` and
            ``close()``. ``close()`` is always called when the generator
            finishes or is closed.
        on_event: Optional callable. When given, the passing bytes are also
            parsed line by line and every ``data`` line holding a JSON object
            is passed to it *before* the chunk that completed the line is
            yielded. Exceptions raised by the callable are ignored.

    Yields:
        Every non-empty chunk exactly as received from ``upstream``.
    """
    notify = on_event if callable(on_event) else None
    pending = b""  # bytes received so far that do not yet form a full line
    try:
        for chunk in upstream.iter_content(chunk_size=None):
            if not chunk:
                continue
            if notify is not None:
                if isinstance(chunk, (bytes, bytearray)):
                    pending += bytes(chunk)
                else:
                    pending += str(chunk).encode("utf-8", errors="ignore")
                # Split once per chunk; the last piece is the (possibly empty)
                # unterminated tail and is carried over to the next chunk.
                *complete_lines, pending = pending.split(b"\n")
                for line in complete_lines:
                    line = line.rstrip(b"\r")
                    # The space after "data:" is optional in the SSE format.
                    if not line.startswith(b"data:"):
                        continue
                    data = line[len(b"data:") :].strip()
                    if not data or data == b"[DONE]":
                        continue
                    try:
                        evt = json.loads(data.decode("utf-8", errors="ignore"))
                    except (ValueError, RecursionError):
                        continue
                    if not isinstance(evt, dict):
                        continue
                    try:
                        notify(evt)
                    except Exception:
                        # Observer failures must not interrupt the relay.
                        pass
            yield chunk
    finally:
        upstream.close()
|