feat: add responses api, websocket support, and fast mode
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Flask, jsonify
|
||||
from flask_sock import Sock
|
||||
|
||||
from .config import BASE_INSTRUCTIONS, GPT5_CODEX_INSTRUCTIONS
|
||||
from .http import build_cors_headers
|
||||
from .routes_openai import openai_bp
|
||||
from .routes_ollama import ollama_bp
|
||||
from .websocket_routes import register_websocket_routes
|
||||
|
||||
|
||||
def create_app(
|
||||
@@ -14,6 +16,7 @@ def create_app(
|
||||
reasoning_effort: str = "medium",
|
||||
reasoning_summary: str = "auto",
|
||||
reasoning_compat: str = "think-tags",
|
||||
fast_mode: bool = False,
|
||||
debug_model: str | None = None,
|
||||
expose_reasoning_models: bool = False,
|
||||
default_web_search: bool = False,
|
||||
@@ -26,6 +29,7 @@ def create_app(
|
||||
REASONING_EFFORT=reasoning_effort,
|
||||
REASONING_SUMMARY=reasoning_summary,
|
||||
REASONING_COMPAT=reasoning_compat,
|
||||
FAST_MODE=bool(fast_mode),
|
||||
DEBUG_MODEL=debug_model,
|
||||
BASE_INSTRUCTIONS=BASE_INSTRUCTIONS,
|
||||
GPT5_CODEX_INSTRUCTIONS=GPT5_CODEX_INSTRUCTIONS,
|
||||
@@ -46,5 +50,7 @@ def create_app(
|
||||
|
||||
app.register_blueprint(openai_bp)
|
||||
app.register_blueprint(ollama_bp)
|
||||
sock = Sock(app)
|
||||
register_websocket_routes(sock)
|
||||
|
||||
return app
|
||||
|
||||
@@ -267,6 +267,7 @@ def cmd_serve(
|
||||
reasoning_effort: str,
|
||||
reasoning_summary: str,
|
||||
reasoning_compat: str,
|
||||
fast_mode: bool,
|
||||
debug_model: str | None,
|
||||
expose_reasoning_models: bool,
|
||||
default_web_search: bool,
|
||||
@@ -277,6 +278,7 @@ def cmd_serve(
|
||||
reasoning_effort=reasoning_effort,
|
||||
reasoning_summary=reasoning_summary,
|
||||
reasoning_compat=reasoning_compat,
|
||||
fast_mode=fast_mode,
|
||||
debug_model=debug_model,
|
||||
expose_reasoning_models=expose_reasoning_models,
|
||||
default_web_search=default_web_search,
|
||||
@@ -309,6 +311,12 @@ def main() -> None:
|
||||
default=os.getenv("CHATGPT_LOCAL_DEBUG_MODEL"),
|
||||
help="Forcibly override requested 'model' with this value",
|
||||
)
|
||||
p_serve.add_argument(
|
||||
"--fast-mode",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=(os.getenv("CHATGPT_LOCAL_FAST_MODE") or "").strip().lower() in ("1", "true", "yes", "on"),
|
||||
help="Enable GPT fast mode by default for supported models; request-level overrides still take precedence.",
|
||||
)
|
||||
p_serve.add_argument(
|
||||
"--reasoning-effort",
|
||||
choices=["none", "minimal", "low", "medium", "high", "xhigh"],
|
||||
@@ -366,6 +374,7 @@ def main() -> None:
|
||||
reasoning_effort=args.reasoning_effort,
|
||||
reasoning_summary=args.reasoning_summary,
|
||||
reasoning_compat=args.reasoning_compat,
|
||||
fast_mode=args.fast_mode,
|
||||
debug_model=args.debug_model,
|
||||
expose_reasoning_models=args.expose_reasoning_models,
|
||||
default_web_search=args.enable_web_search,
|
||||
|
||||
92
chatmock/fast_mode.py
Normal file
92
chatmock/fast_mode.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .model_registry import normalize_model_name
|
||||
|
||||
|
||||
PRIORITY_SUPPORTED_MODELS = frozenset(
|
||||
(
|
||||
"gpt-5.4",
|
||||
"gpt-5.2",
|
||||
"gpt-5.1",
|
||||
"gpt-5",
|
||||
"gpt-5.1-codex",
|
||||
"gpt-5-codex",
|
||||
)
|
||||
)
|
||||
|
||||
_TRUE_STRINGS = {"1", "true", "yes", "on"}
|
||||
_FALSE_STRINGS = {"0", "false", "no", "off"}
|
||||
|
||||
|
||||
def parse_optional_bool(value: Any) -> bool | None:
    """Interpret *value* as a tri-state flag.

    Returns:
        ``value`` itself when it is already a ``bool``; ``True`` / ``False``
        when it is a string whose stripped, lower-cased form appears in
        ``_TRUE_STRINGS`` / ``_FALSE_STRINGS``; ``None`` for anything else
        (other types, unrecognized strings).
    """
    if isinstance(value, bool):
        return value
    if not isinstance(value, str):
        return None

    token = value.strip().lower()
    if token in _TRUE_STRINGS:
        return True
    return False if token in _FALSE_STRINGS else None
|
||||
|
||||
|
||||
def supports_priority_service_tier(model: str | None) -> bool:
    """Return True when *model*, after name normalization, is in PRIORITY_SUPPORTED_MODELS."""
    canonical_name = normalize_model_name(model)
    return canonical_name in PRIORITY_SUPPORTED_MODELS
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ServiceTierResolution:
    """Immutable outcome of deciding which service tier a request should use."""

    # Tier to send upstream, or None when no tier should be set.
    service_tier: str | None
    # Set when an explicit request for fast mode cannot be honoured.
    error_message: str | None = None
    # Set when a non-explicit (server default) fast mode was dropped.
    warning_message: str | None = None
    # True when the tier decision came from the server-wide default.
    used_server_default: bool = False
|
||||
|
||||
|
||||
def resolve_service_tier(
    model: str | None,
    *,
    request_fast_mode: Any = None,
    request_service_tier: Any = None,
    server_fast_mode: bool = False,
) -> ServiceTierResolution:
    """Decide the service tier for a request.

    Precedence, highest first:
      1. ``request_fast_mode`` when it parses to a boolean
         (True -> "priority", False -> no tier);
      2. a non-blank string ``request_service_tier`` (stripped, lower-cased);
      3. ``server_fast_mode`` (-> "priority", flagged as server default).

    When the result is "priority" but *model* does not support it, no tier is
    returned: request-level choices (1, 2) yield ``error_message``, the server
    default (3) yields only ``warning_message``.
    """
    fast_flag = parse_optional_bool(request_fast_mode)
    has_tier_string = isinstance(request_service_tier, str) and bool(request_service_tier.strip())

    if fast_flag is not None:
        tier, explicit_request, used_server_default = ("priority" if fast_flag else None), True, False
    elif has_tier_string:
        tier, explicit_request, used_server_default = request_service_tier.strip().lower(), True, False
    elif server_fast_mode:
        tier, explicit_request, used_server_default = "priority", False, True
    else:
        tier, explicit_request, used_server_default = None, False, False

    # Happy path: nothing to validate, or the model accepts priority.
    if tier != "priority" or supports_priority_service_tier(model):
        return ServiceTierResolution(
            service_tier=tier,
            used_server_default=used_server_default,
        )

    normalized = normalize_model_name(model)
    message = (
        f"Fast mode is not supported for model '{normalized}'. "
        "Use a supported GPT-5 priority-processing model or disable fast mode for this request."
    )
    if explicit_request:
        return ServiceTierResolution(
            service_tier=None,
            error_message=message,
            used_server_default=used_server_default,
        )
    return ServiceTierResolution(
        service_tier=None,
        warning_message=message,
        used_server_default=used_server_default,
    )
|
||||
@@ -62,6 +62,14 @@ _MODEL_SPECS = (
|
||||
variant_efforts=("xhigh", "high", "medium", "low"),
|
||||
uses_codex_instructions=True,
|
||||
),
|
||||
ModelSpec(
|
||||
public_id="gpt-5.3-codex-spark",
|
||||
upstream_id="gpt-5.3-codex-spark",
|
||||
aliases=("gpt5.3-codex-spark", "gpt-5.3-codex-spark-latest"),
|
||||
allowed_efforts=frozenset(("low", "medium", "high", "xhigh")),
|
||||
variant_efforts=("xhigh", "high", "medium", "low"),
|
||||
uses_codex_instructions=True,
|
||||
),
|
||||
ModelSpec(
|
||||
public_id="gpt-5-codex",
|
||||
upstream_id="gpt-5-codex",
|
||||
|
||||
242
chatmock/responses_api.py
Normal file
242
chatmock/responses_api.py
Normal file
@@ -0,0 +1,242 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, Iterator, List
|
||||
|
||||
from .config import BASE_INSTRUCTIONS, GPT5_CODEX_INSTRUCTIONS
|
||||
from .fast_mode import ServiceTierResolution, resolve_service_tier
|
||||
from .model_registry import (
|
||||
allowed_efforts_for_model,
|
||||
extract_reasoning_from_model_name,
|
||||
normalize_model_name,
|
||||
uses_codex_instructions,
|
||||
)
|
||||
from .reasoning import build_reasoning_param
|
||||
from .session import ensure_session_id
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ResponsesRequestError(Exception):
    """Raised when a Responses API request payload cannot be accepted.

    NOTE(review): ``frozen=True`` makes every Python-level attribute
    assignment on the instance raise ``FrozenInstanceError`` -- confirm no
    caller or library assigns attributes such as ``__traceback__`` on it.
    """

    # Human-readable description; also the value of ``str(exc)``.
    message: str
    # HTTP status the route handler answers with.
    status_code: int = 400
    # Optional machine-readable error code included in the error body.
    code: str | None = None

    def __str__(self) -> str:
        return self.message
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NormalizedResponsesRequest:
    """Result of normalize_responses_payload(): the rewritten payload plus derived metadata."""

    # Shallow copy of the client payload with defaults and overrides applied.
    payload: Dict[str, Any]
    # The "model" value exactly as the client sent it (None if absent / not a string).
    requested_model: str | None
    # Model name after normalize_model_name().
    normalized_model: str
    # Session identifier returned by ensure_session_id().
    session_id: str
    # Full outcome of the fast-mode / service-tier decision.
    service_tier_resolution: ServiceTierResolution
|
||||
|
||||
|
||||
def instructions_for_model(config: Dict[str, Any], model: str) -> str:
    """Pick the instruction text for *model*.

    Models for which ``uses_codex_instructions`` is true get the Codex
    instructions (``config["GPT5_CODEX_INSTRUCTIONS"]`` or, when that is
    falsy, the module default) provided the text is a non-blank string.
    Everything else gets ``config["BASE_INSTRUCTIONS"]`` (module default when
    the key is missing).
    """
    if uses_codex_instructions(model):
        codex_text = config.get("GPT5_CODEX_INSTRUCTIONS") or GPT5_CODEX_INSTRUCTIONS
        if isinstance(codex_text, str) and codex_text.strip():
            return codex_text
    return config.get("BASE_INSTRUCTIONS", BASE_INSTRUCTIONS)
|
||||
|
||||
|
||||
def extract_client_session_id(headers: Any) -> str | None:
    """Return the first truthy value of ``X-Session-Id`` / ``session_id`` in *headers*.

    *headers* only needs a ``.get(name)`` method. Any exception raised while
    reading it (including *headers* not having ``.get``) results in ``None``.
    """
    try:
        for header_name in ("X-Session-Id", "session_id"):
            candidate = headers.get(header_name)
            if candidate:
                return candidate
    except Exception:
        return None
    return None
|
||||
|
||||
|
||||
def _input_items_for_session(raw_input: Any) -> List[Dict[str, Any]]:
|
||||
if isinstance(raw_input, list):
|
||||
return [item for item in raw_input if isinstance(item, dict)]
|
||||
if isinstance(raw_input, dict):
|
||||
return [raw_input]
|
||||
if isinstance(raw_input, str) and raw_input.strip():
|
||||
return [
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": raw_input}],
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
def canonicalize_responses_input(raw_input: Any) -> Any:
    """Normalize a request's ``input`` field to list form where possible.

    Strings are converted via ``_input_items_for_session`` (so a blank string
    becomes ``[]``), a dict is wrapped in a list, a list keeps only its dict
    elements, and any other value is returned unchanged.
    """
    if isinstance(raw_input, str):
        return _input_items_for_session(raw_input)
    if isinstance(raw_input, dict):
        return [raw_input]
    if not isinstance(raw_input, list):
        return raw_input
    return [entry for entry in raw_input if isinstance(entry, dict)]
|
||||
|
||||
|
||||
def normalize_responses_payload(
    payload: Dict[str, Any],
    *,
    config: Dict[str, Any],
    client_session_id: str | None = None,
) -> NormalizedResponsesRequest:
    """Build the upstream-ready version of a client ``/v1/responses`` payload.

    Works on a shallow copy of *payload* and applies, in order: model name
    normalization, input canonicalization, ``store`` default (False),
    instruction defaulting, reasoning parameters, the mandatory
    ``reasoning.encrypted_content`` include, the optional default web-search
    tool, service-tier resolution, and a default ``prompt_cache_key`` equal to
    the session id.

    Raises:
        ResponsesRequestError: when service-tier resolution reports an error.
    """
    model_field = payload.get("model")
    requested_model = model_field if isinstance(model_field, str) else None
    normalized_model = normalize_model_name(requested_model, config.get("DEBUG_MODEL"))

    body = dict(payload)
    body["model"] = normalized_model

    if "input" in body:
        body["input"] = canonicalize_responses_input(body.get("input"))

    body.setdefault("store", False)

    # Client instructions win only when they are a non-blank string.
    instructions = body.get("instructions")
    if not (isinstance(instructions, str) and instructions.strip()):
        instructions = instructions_for_model(config, normalized_model)
    body["instructions"] = instructions

    # An explicit "reasoning" object beats hints embedded in the model name.
    request_reasoning = body.get("reasoning")
    if isinstance(request_reasoning, dict):
        reasoning_overrides = request_reasoning
    else:
        reasoning_overrides = extract_reasoning_from_model_name(requested_model)
    body["reasoning"] = build_reasoning_param(
        config.get("REASONING_EFFORT", "medium"),
        config.get("REASONING_SUMMARY", "auto"),
        reasoning_overrides,
        allowed_efforts=allowed_efforts_for_model(normalized_model),
    )

    requested_include = body.get("include")
    include_list = []
    if isinstance(requested_include, list):
        include_list = [entry for entry in requested_include if isinstance(entry, str)]
    if "reasoning.encrypted_content" not in include_list:
        include_list.append("reasoning.encrypted_content")
    body["include"] = include_list

    # Inject web search only when the client supplied no tools and did not
    # opt out with tool_choice == "none".
    tools = body.get("tools")
    has_tools = isinstance(tools, list) and bool(tools)
    if not has_tools and bool(config.get("DEFAULT_WEB_SEARCH")):
        tool_choice = body.get("tool_choice")
        opted_out = isinstance(tool_choice, str) and tool_choice.strip().lower() == "none"
        if not opted_out:
            body["tools"] = [{"type": "web_search"}]

    tier_resolution = resolve_service_tier(
        normalized_model,
        request_fast_mode=body.get("fast_mode"),
        request_service_tier=body.get("service_tier"),
        server_fast_mode=bool(config.get("FAST_MODE")),
    )
    if tier_resolution.error_message:
        raise ResponsesRequestError(tier_resolution.error_message)
    if tier_resolution.service_tier is not None:
        body["service_tier"] = tier_resolution.service_tier
    else:
        body.pop("service_tier", None)
    # "fast_mode" is consumed here and removed from the outgoing payload.
    body.pop("fast_mode", None)

    session_id = ensure_session_id(
        instructions,
        _input_items_for_session(body.get("input")),
        client_session_id,
    )
    cache_key = body.get("prompt_cache_key")
    if not (isinstance(cache_key, str) and cache_key.strip()):
        body["prompt_cache_key"] = session_id

    return NormalizedResponsesRequest(
        payload=body,
        requested_model=requested_model,
        normalized_model=normalized_model,
        session_id=session_id,
        service_tier_resolution=tier_resolution,
    )
|
||||
|
||||
|
||||
def iter_sse_event_payloads(upstream: Any) -> Iterator[Dict[str, Any]]:
    """Yield the JSON objects carried by ``data:`` lines of an SSE stream.

    Args:
        upstream: object exposing ``iter_lines(decode_unicode=False)`` that
            yields one line at a time as ``bytes``/``bytearray`` or ``str``
            (presumably a ``requests.Response`` -- confirm against caller).

    Yields:
        Each ``data:`` payload that decodes to a JSON object (``dict``).
        Blank lines, non-``data`` fields, undecodable payloads and JSON values
        that are not objects are skipped. Iteration stops at ``[DONE]``.
    """
    for raw in upstream.iter_lines(decode_unicode=False):
        if not raw:
            continue
        line = raw.decode("utf-8", errors="ignore") if isinstance(raw, (bytes, bytearray)) else raw
        # The space after the colon is optional in the SSE wire format, so
        # match on "data:" and let strip() remove any separator.
        if not line.startswith("data:"):
            continue
        data = line[len("data:") :].strip()
        if data == "[DONE]":
            break
        if not data:
            continue
        try:
            evt = json.loads(data)
        except Exception:
            # Best effort: a malformed event must not abort the stream.
            continue
        if isinstance(evt, dict):
            yield evt
|
||||
|
||||
|
||||
def aggregate_response_from_sse(
    upstream: Any,
    *,
    on_event: Any | None = None,
) -> tuple[Dict[str, Any] | None, Dict[str, Any] | None]:
    """Consume an SSE stream and reduce it to a final response object.

    Every event is first passed to *on_event* (when callable; exceptions it
    raises are ignored). The most recent ``response`` dict seen on any event
    is remembered. Consumption stops at ``response.completed`` or
    ``response.failed``; the latter also produces an error object. *upstream*
    is always closed.

    Returns:
        ``(response_obj, error_obj)`` -- either may be ``None``.
    """
    latest_response: Dict[str, Any] | None = None
    failure: Dict[str, Any] | None = None
    try:
        for evt in iter_sse_event_payloads(upstream):
            if callable(on_event):
                try:
                    on_event(evt)
                except Exception:
                    # Observer errors are deliberately non-fatal.
                    pass

            response = evt.get("response")
            if isinstance(response, dict):
                latest_response = response

            event_type = evt.get("type")
            if event_type == "response.completed":
                break
            if event_type == "response.failed":
                upstream_error = response.get("error") if isinstance(response, dict) else None
                if isinstance(upstream_error, dict):
                    failure = {"error": upstream_error}
                else:
                    failure = {"error": {"message": "response.failed"}}
                break
    finally:
        upstream.close()
    return latest_response, failure
|
||||
|
||||
|
||||
def stream_upstream_bytes(
    upstream: Any,
    *,
    on_event: Any | None = None,
) -> Iterable[bytes]:
    """Relay upstream chunks unchanged while optionally observing SSE events.

    Args:
        upstream: object exposing ``iter_content(chunk_size=None)`` and
            ``close()`` (presumably a ``requests.Response`` -- confirm).
        on_event: optional callable invoked with every JSON-object payload
            found on a ``data:`` line. It is called before the chunk that
            completes the line is yielded; exceptions it raises are ignored.

    Yields:
        Every non-empty chunk exactly as received.

    Event parsing only happens when *on_event* is callable. Lines may be
    split across chunks; a final line without a trailing newline is parsed
    once the upstream iterator is exhausted. *upstream* is always closed.
    """
    observing = callable(on_event)
    pending = b""

    def _dispatch(line: bytes) -> None:
        # Parse one complete SSE line and notify the observer, best effort.
        line = line.rstrip(b"\r")
        # The space after "data:" is optional in the SSE wire format.
        if not line.startswith(b"data:"):
            return
        data = line[len(b"data:") :].strip()
        if not data or data == b"[DONE]":
            return
        try:
            evt = json.loads(data.decode("utf-8", errors="ignore"))
        except Exception:
            return
        if isinstance(evt, dict):
            try:
                on_event(evt)
            except Exception:
                # Observer failures must never interrupt the byte relay.
                pass

    try:
        for chunk in upstream.iter_content(chunk_size=None):
            if not chunk:
                continue
            if observing:
                if isinstance(chunk, (bytes, bytearray)):
                    pending += bytes(chunk)
                else:
                    pending += str(chunk).encode("utf-8", errors="ignore")
                # Everything before the last newline is complete; the tail
                # stays buffered until more data (or end of stream) arrives.
                *complete_lines, pending = pending.split(b"\n")
                for line in complete_lines:
                    _dispatch(line)
            yield chunk
        if observing and pending:
            _dispatch(pending)
    finally:
        upstream.close()
|
||||
@@ -8,9 +8,11 @@ from typing import Any, Dict, List
|
||||
from flask import Blueprint, Response, current_app, jsonify, make_response, request, stream_with_context
|
||||
|
||||
from .config import BASE_INSTRUCTIONS, GPT5_CODEX_INSTRUCTIONS
|
||||
from .fast_mode import resolve_service_tier
|
||||
from .limits import record_rate_limits_from_response
|
||||
from .http import build_cors_headers
|
||||
from .model_registry import list_public_models, uses_codex_instructions
|
||||
from .responses_api import instructions_for_model
|
||||
from .reasoning import (
|
||||
allowed_efforts_for_model,
|
||||
build_reasoning_param,
|
||||
@@ -71,12 +73,7 @@ def ollama_version() -> Response:
|
||||
|
||||
|
||||
def _instructions_for_model(model: str) -> str:
    """Return the instruction text for *model* using the current Flask app's config.

    Thin wrapper over the shared ``instructions_for_model`` helper; the former
    inline copy of that logic made this delegation unreachable.
    """
    return instructions_for_model(current_app.config, model)
|
||||
|
||||
|
||||
_OLLAMA_FAKE_EVAL = {
|
||||
@@ -254,6 +251,19 @@ def ollama_chat() -> Response:
|
||||
|
||||
model_reasoning = extract_reasoning_from_model_name(model)
|
||||
normalized_model = normalize_model_name(model)
|
||||
service_tier_resolution = resolve_service_tier(
|
||||
normalized_model,
|
||||
request_fast_mode=payload.get("fast_mode"),
|
||||
request_service_tier=payload.get("service_tier"),
|
||||
server_fast_mode=bool(current_app.config.get("FAST_MODE")),
|
||||
)
|
||||
if service_tier_resolution.warning_message and verbose:
|
||||
print(f"[FastMode] {service_tier_resolution.warning_message}")
|
||||
if service_tier_resolution.error_message:
|
||||
err = {"error": service_tier_resolution.error_message}
|
||||
if verbose:
|
||||
_log_json("OUT POST /api/chat", err)
|
||||
return jsonify(err), 400
|
||||
upstream, error_resp = start_upstream_request(
|
||||
normalized_model,
|
||||
input_items,
|
||||
@@ -267,6 +277,7 @@ def ollama_chat() -> Response:
|
||||
model_reasoning,
|
||||
allowed_efforts=allowed_efforts_for_model(model),
|
||||
),
|
||||
service_tier=service_tier_resolution.service_tier,
|
||||
)
|
||||
if error_resp is not None:
|
||||
if verbose:
|
||||
@@ -307,6 +318,7 @@ def ollama_chat() -> Response:
|
||||
model_reasoning,
|
||||
allowed_efforts=allowed_efforts_for_model(model),
|
||||
),
|
||||
service_tier=service_tier_resolution.service_tier,
|
||||
)
|
||||
record_rate_limits_from_response(upstream2)
|
||||
if err2 is None and upstream2 is not None and upstream2.status_code < 400:
|
||||
|
||||
@@ -7,16 +7,31 @@ from typing import Any, Dict, List
|
||||
from flask import Blueprint, Response, current_app, jsonify, make_response, request
|
||||
|
||||
from .config import BASE_INSTRUCTIONS, GPT5_CODEX_INSTRUCTIONS
|
||||
from .fast_mode import resolve_service_tier
|
||||
from .limits import record_rate_limits_from_response
|
||||
from .http import build_cors_headers
|
||||
from .model_registry import list_public_models, uses_codex_instructions
|
||||
from .responses_api import (
|
||||
ResponsesRequestError,
|
||||
aggregate_response_from_sse,
|
||||
extract_client_session_id,
|
||||
instructions_for_model,
|
||||
normalize_responses_payload,
|
||||
stream_upstream_bytes,
|
||||
)
|
||||
from .reasoning import (
|
||||
allowed_efforts_for_model,
|
||||
apply_reasoning_to_message,
|
||||
build_reasoning_param,
|
||||
extract_reasoning_from_model_name,
|
||||
)
|
||||
from .upstream import normalize_model_name, start_upstream_request
|
||||
from .session import (
|
||||
clear_responses_reuse_state,
|
||||
note_responses_final_response,
|
||||
note_responses_stream_event,
|
||||
prepare_responses_request_for_session,
|
||||
)
|
||||
from .upstream import normalize_model_name, start_upstream_raw_request, start_upstream_request
|
||||
from .utils import (
|
||||
convert_chat_messages_to_responses_input,
|
||||
convert_tools_chat_to_responses,
|
||||
@@ -59,12 +74,32 @@ def _wrap_stream_logging(label: str, iterator, enabled: bool):
|
||||
|
||||
|
||||
def _instructions_for_model(model: str) -> str:
    """Return the instruction text for *model* using the current Flask app's config.

    Thin wrapper over the shared ``instructions_for_model`` helper; the former
    inline copy of that logic made this delegation unreachable.
    """
    return instructions_for_model(current_app.config, model)
|
||||
|
||||
|
||||
def _service_tier_from_payload(
    model: str,
    payload: Dict[str, Any],
    *,
    verbose: bool = False,
) -> tuple[str | None, Response | None]:
    """Resolve the service tier for a request payload.

    Reads ``fast_mode`` / ``service_tier`` from *payload* and the server-wide
    ``FAST_MODE`` setting from the current app config.

    Returns:
        ``(service_tier, None)`` on success, or ``(None, response)`` where
        *response* is a ready-to-return HTTP 400 JSON error (with CORS
        headers filled in) when the resolution reports an error.
    """
    resolution = resolve_service_tier(
        model,
        request_fast_mode=payload.get("fast_mode"),
        request_service_tier=payload.get("service_tier"),
        server_fast_mode=bool(current_app.config.get("FAST_MODE")),
    )
    if verbose and resolution.warning_message:
        print(f"[FastMode] {resolution.warning_message}")

    if not resolution.error_message:
        return resolution.service_tier, None

    err = {"error": {"message": resolution.error_message}}
    if verbose:
        _log_json("OUT POST service_tier resolution", err)
    resp = make_response(jsonify(err), 400)
    for header_name, header_value in build_cors_headers().items():
        resp.headers.setdefault(header_name, header_value)
    return None, resp
|
||||
|
||||
|
||||
@openai_bp.route("/v1/chat/completions", methods=["POST"])
|
||||
@@ -178,6 +213,9 @@ def chat_completions() -> Response:
|
||||
reasoning_overrides,
|
||||
allowed_efforts=allowed_efforts_for_model(model),
|
||||
)
|
||||
service_tier, tier_error = _service_tier_from_payload(model, payload, verbose=verbose)
|
||||
if tier_error is not None:
|
||||
return tier_error
|
||||
|
||||
upstream, error_resp = start_upstream_request(
|
||||
model,
|
||||
@@ -187,6 +225,7 @@ def chat_completions() -> Response:
|
||||
tool_choice=tool_choice,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
reasoning_param=reasoning_param,
|
||||
service_tier=service_tier,
|
||||
)
|
||||
if error_resp is not None:
|
||||
if verbose:
|
||||
@@ -224,6 +263,7 @@ def chat_completions() -> Response:
|
||||
tool_choice=safe_choice,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
reasoning_param=reasoning_param,
|
||||
service_tier=service_tier,
|
||||
)
|
||||
record_rate_limits_from_response(upstream2)
|
||||
if err2 is None and upstream2 is not None and upstream2.status_code < 400:
|
||||
@@ -413,11 +453,15 @@ def completions() -> Response:
|
||||
reasoning_overrides,
|
||||
allowed_efforts=allowed_efforts_for_model(model),
|
||||
)
|
||||
service_tier, tier_error = _service_tier_from_payload(model, payload, verbose=verbose)
|
||||
if tier_error is not None:
|
||||
return tier_error
|
||||
upstream, error_resp = start_upstream_request(
|
||||
model,
|
||||
input_items,
|
||||
instructions=_instructions_for_model(model),
|
||||
reasoning_param=reasoning_param,
|
||||
service_tier=service_tier,
|
||||
)
|
||||
if error_resp is not None:
|
||||
if verbose:
|
||||
@@ -529,6 +573,161 @@ def completions() -> Response:
|
||||
return resp
|
||||
|
||||
|
||||
@openai_bp.route("/v1/responses", methods=["POST"])
def responses_create() -> Response:
    """Handle ``POST /v1/responses``.

    Flow: parse and validate the JSON body, normalize it, prepare it against
    the per-session state, then always call upstream with ``stream=True``.
    If the client asked for streaming the upstream bytes are relayed as
    ``text/event-stream``; otherwise the upstream result is reduced to a
    single JSON response. Session reuse state is cleared on every upstream
    failure path.
    """
    verbose = bool(current_app.config.get("VERBOSE"))
    raw = request.get_data(cache=True, as_text=True) or ""
    if verbose:
        try:
            print("IN POST /v1/responses\n" + raw)
        except Exception:
            pass

    # An empty body is treated as an empty JSON object.
    try:
        payload = json.loads(raw) if raw else {}
    except Exception:
        err = {"error": {"message": "Invalid JSON body"}}
        if verbose:
            _log_json("OUT POST /v1/responses", err)
        return jsonify(err), 400

    if not isinstance(payload, dict):
        err = {"error": {"message": "Request body must be a JSON object"}}
        if verbose:
            _log_json("OUT POST /v1/responses", err)
        return jsonify(err), 400

    try:
        normalized = normalize_responses_payload(
            payload,
            config=current_app.config,
            client_session_id=extract_client_session_id(request.headers),
        )
    except ResponsesRequestError as exc:
        # The exception carries its own HTTP status and optional error code.
        err: Dict[str, Any] = {"error": {"message": str(exc)}}
        if exc.code:
            err["error"]["code"] = exc.code
        if verbose:
            _log_json("OUT POST /v1/responses", err)
        return jsonify(err), exc.status_code

    if normalized.service_tier_resolution.warning_message and verbose:
        print(f"[FastMode] {normalized.service_tier_resolution.warning_message}")

    prepared = prepare_responses_request_for_session(
        normalized.session_id,
        normalized.payload,
        allow_previous_response_id=False,
    )
    # Remember what the client asked for; upstream is always called streaming.
    stream_req = bool(prepared.payload.get("stream", False))
    upstream_payload = dict(prepared.payload)
    upstream_payload["stream"] = True
    upstream, error_resp = start_upstream_raw_request(
        upstream_payload,
        session_id=normalized.session_id,
        stream=True,
    )
    if error_resp is not None:
        clear_responses_reuse_state(normalized.session_id)
        if verbose:
            # Logging is best effort; it must not mask the error response.
            try:
                body = error_resp.get_data(as_text=True)
                if body:
                    try:
                        parsed = json.loads(body)
                    except Exception:
                        parsed = body
                    _log_json("OUT POST /v1/responses", parsed)
            except Exception:
                pass
        return error_resp

    record_rate_limits_from_response(upstream)

    if upstream.status_code >= 400:
        # Forward the upstream error body (JSON when parseable) with its status.
        try:
            err_body = json.loads(upstream.content.decode("utf-8", errors="ignore")) if upstream.content else {"error": {"message": upstream.text}}
        except Exception:
            err_body = {"error": {"message": upstream.text or "Upstream error"}}
        finally:
            upstream.close()
        clear_responses_reuse_state(normalized.session_id)
        if verbose:
            _log_json("OUT POST /v1/responses", err_body)
        resp = make_response(jsonify(err_body), upstream.status_code)
        for k, v in build_cors_headers().items():
            resp.headers.setdefault(k, v)
        return resp

    if stream_req:
        if verbose:
            print("OUT POST /v1/responses (streaming response)")
        # Relay raw upstream bytes; each parsed event also updates session state.
        stream_iter = _wrap_stream_logging(
            "STREAM OUT /v1/responses",
            stream_upstream_bytes(
                upstream,
                on_event=lambda evt: note_responses_stream_event(normalized.session_id, evt),
            ),
            verbose,
        )
        resp = Response(
            stream_iter,
            status=upstream.status_code,
            mimetype="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )
        for k, v in build_cors_headers().items():
            resp.headers.setdefault(k, v)
        return resp

    content_type = upstream.headers.get("Content-Type", "")
    if "application/json" in content_type.lower():
        # Upstream answered with plain JSON despite stream=True.
        try:
            body = upstream.json()
        except Exception:
            body = None
        finally:
            upstream.close()
        if isinstance(body, dict):
            note_responses_final_response(normalized.session_id, body)
            if verbose:
                _log_json("OUT POST /v1/responses", body)
            resp = make_response(jsonify(body), upstream.status_code)
            for k, v in build_cors_headers().items():
                resp.headers.setdefault(k, v)
            return resp
        # NOTE(review): a non-object JSON body falls through to SSE
        # aggregation on an upstream that was already closed -- confirm intended.

    # Non-streaming client: collapse the SSE stream into one response object.
    response_obj, error_obj = aggregate_response_from_sse(
        upstream,
        on_event=lambda evt: note_responses_stream_event(normalized.session_id, evt),
    )
    if error_obj is not None:
        clear_responses_reuse_state(normalized.session_id)
        if verbose:
            _log_json("OUT POST /v1/responses", error_obj)
        resp = make_response(jsonify(error_obj), 502)
        for k, v in build_cors_headers().items():
            resp.headers.setdefault(k, v)
        return resp

    if response_obj is None:
        clear_responses_reuse_state(normalized.session_id)
        err = {"error": {"message": "Upstream response stream did not contain a completed response object"}}
        if verbose:
            _log_json("OUT POST /v1/responses", err)
        resp = make_response(jsonify(err), 502)
        for k, v in build_cors_headers().items():
            resp.headers.setdefault(k, v)
        return resp

    if verbose:
        _log_json("OUT POST /v1/responses", response_obj)
    resp = make_response(jsonify(response_obj), upstream.status_code)
    for k, v in build_cors_headers().items():
        resp.headers.setdefault(k, v)
    return resp
|
||||
|
||||
|
||||
@openai_bp.route("/v1/models", methods=["GET"])
|
||||
def list_models() -> Response:
|
||||
expose_variants = bool(current_app.config.get("EXPOSE_REASONING_MODELS"))
|
||||
|
||||
@@ -1,16 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
# Module-wide lock; prepare_responses_request_for_session() holds it while
# reading and updating the per-session Responses state below.
_LOCK = threading.Lock()
# Fingerprint -> session UUID cache, with _ORDER tracking insertion order
# (presumably for oldest-first eviction -- confirm in _remember()).
_FINGERPRINT_TO_UUID: Dict[str, str] = {}
_ORDER: List[str] = []
# Upper bound on tracked entries; _remember_responses_session() evicts the
# oldest session once _RESPONSES_ORDER grows beyond it.
_MAX_ENTRIES = 10000
# Session id -> Responses reuse state, with insertion order kept alongside.
_RESPONSES_SESSION_STATE: Dict[str, "_ResponsesSessionState"] = {}
_RESPONSES_ORDER: List[str] = []
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PreparedResponsesRequest:
    """Outbound Responses payload paired with the session it belongs to."""

    # Payload to send upstream (possibly trimmed to incremental input).
    payload: Dict[str, Any]
    # Session identifier the payload was prepared for.
    session_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ResponsesSessionState:
|
||||
last_request_payload: Dict[str, Any] | None = None
|
||||
last_response_id: str | None = None
|
||||
last_response_items: List[Dict[str, Any]] = field(default_factory=list)
|
||||
inflight_request_payload: Dict[str, Any] | None = None
|
||||
inflight_track_result: bool = False
|
||||
inflight_response_id: str | None = None
|
||||
inflight_response_items: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
def _canonicalize_first_user_message(input_items: List[Dict[str, Any]]) -> Dict[str, Any] | None:
|
||||
@@ -70,6 +91,61 @@ def _remember(fp: str, sid: str) -> None:
|
||||
_FINGERPRINT_TO_UUID.pop(oldest, None)
|
||||
|
||||
|
||||
def _remember_responses_session(session_id: str) -> _ResponsesSessionState:
    """Return the state object for *session_id*, creating it on first use.

    A newly created session is appended to ``_RESPONSES_ORDER``; when that
    pushes the list past ``_MAX_ENTRIES`` the oldest session is evicted.
    No locking is done here.
    """
    existing = _RESPONSES_SESSION_STATE.get(session_id)
    if existing is not None:
        return existing

    fresh = _ResponsesSessionState()
    _RESPONSES_SESSION_STATE[session_id] = fresh
    _RESPONSES_ORDER.append(session_id)
    if len(_RESPONSES_ORDER) > _MAX_ENTRIES:
        evicted_id = _RESPONSES_ORDER.pop(0)
        _RESPONSES_SESSION_STATE.pop(evicted_id, None)
    return fresh
|
||||
|
||||
|
||||
def _request_without_input(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
clone = copy.deepcopy(payload)
|
||||
clone["input"] = []
|
||||
clone.pop("previous_response_id", None)
|
||||
return clone
|
||||
|
||||
|
||||
def _input_list(payload: Dict[str, Any]) -> List[Dict[str, Any]] | None:
|
||||
raw = payload.get("input")
|
||||
if not isinstance(raw, list):
|
||||
return None
|
||||
return [item for item in copy.deepcopy(raw) if isinstance(item, dict)]
|
||||
|
||||
|
||||
def _conversation_output_items(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
reusable: List[Dict[str, Any]] = []
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = item.get("type")
|
||||
if item_type == "reasoning":
|
||||
continue
|
||||
reusable.append(copy.deepcopy(item))
|
||||
return reusable
|
||||
|
||||
|
||||
def _clear_reuse_state(state: _ResponsesSessionState) -> None:
|
||||
state.last_request_payload = None
|
||||
state.last_response_id = None
|
||||
state.last_response_items = []
|
||||
state.inflight_request_payload = None
|
||||
state.inflight_track_result = False
|
||||
state.inflight_response_id = None
|
||||
state.inflight_response_items = []
|
||||
|
||||
|
||||
def _clear_inflight(state: _ResponsesSessionState) -> None:
|
||||
state.inflight_request_payload = None
|
||||
state.inflight_track_result = False
|
||||
state.inflight_response_id = None
|
||||
state.inflight_response_items = []
|
||||
|
||||
|
||||
def ensure_session_id(
|
||||
instructions: str | None,
|
||||
input_items: List[Dict[str, Any]],
|
||||
@@ -87,3 +163,150 @@ def ensure_session_id(
|
||||
_remember(fp, sid)
|
||||
return sid
|
||||
|
||||
|
||||
def prepare_responses_request_for_session(
    session_id: str,
    payload: Dict[str, Any],
    *,
    allow_previous_response_id: bool = True,
) -> PreparedResponsesRequest:
    """Build the outbound payload for a session, chaining onto the last response when possible.

    If the caller already supplied a non-blank ``previous_response_id``, the
    session's reuse state is cleared and the payload is passed through as is.

    Otherwise, when the request matches the previously completed one on
    everything except ``input``, and its ``input`` starts with the previous
    input followed by the previous response's stored items, the outbound
    payload carries only the new trailing items plus the stored
    ``previous_response_id``. In that non-explicit case the full request is
    recorded as in flight so a later completion can be stored.

    Args:
        session_id: Key of the session whose state is consulted and updated.
        payload: Incoming request payload; it is not modified.
        allow_previous_response_id: When false, the payload is never rewritten
            to chain onto the stored response.

    Returns:
        A ``PreparedResponsesRequest`` with the outbound payload and session id.
    """
    snapshot = copy.deepcopy(payload)
    outgoing = copy.deepcopy(payload)

    supplied_id = snapshot.get("previous_response_id")
    caller_chained = isinstance(supplied_id, str) and bool(supplied_id.strip())

    with _LOCK:
        state = _remember_responses_session(session_id)

        if caller_chained:
            # The caller manages chaining itself; stored state no longer applies.
            _clear_reuse_state(state)
            return PreparedResponsesRequest(payload=outgoing, session_id=session_id)

        current_items = _input_list(snapshot)
        can_chain = (
            allow_previous_response_id
            and state.last_request_payload is not None
            and state.last_response_id
            and current_items is not None
            and _request_without_input(state.last_request_payload) == _request_without_input(snapshot)
        )
        if can_chain:
            # Expected prefix: the previous input followed by the stored output items.
            expected_prefix: List[Dict[str, Any]] = _input_list(state.last_request_payload) or []
            expected_prefix.extend(copy.deepcopy(state.last_response_items))
            prefix_len = len(expected_prefix)
            if current_items[:prefix_len] == expected_prefix and prefix_len <= len(current_items):
                outgoing["input"] = copy.deepcopy(current_items[prefix_len:])
                outgoing["previous_response_id"] = state.last_response_id

        # Track the full (un-trimmed) request until its response is noted.
        state.inflight_request_payload = snapshot
        state.inflight_track_result = True
        state.inflight_response_id = None
        state.inflight_response_items = []

        return PreparedResponsesRequest(payload=outgoing, session_id=session_id)
|
||||
|
||||
|
||||
def note_responses_stream_event(session_id: str, event: Dict[str, Any]) -> None:
    """Fold one streamed Responses event into the session's reuse bookkeeping.

    Does nothing for a blank or non-string ``session_id``, a non-dict
    ``event``, or a session with no stored state. Handled event types:

    * ``response.created``: remember the in-flight response id.
    * ``response.output_item.done``: collect a copy of the finished item.
    * ``response.completed``: store the in-flight request as the last
      completed one, or drop the last-completed record when the result is
      not being tracked or has no id, then clear the in-flight fields.
    * ``response.failed`` / ``error``: clear all reuse state.
    """
    if not (isinstance(session_id, str) and session_id.strip()):
        return
    if not isinstance(event, dict):
        return

    with _LOCK:
        state = _RESPONSES_SESSION_STATE.get(session_id)
        if state is None:
            return

        event_type = event.get("type")

        if event_type == "response.created":
            body = event.get("response")
            if isinstance(body, dict) and isinstance(body.get("id"), str):
                state.inflight_response_id = body.get("id")

        elif event_type == "response.output_item.done":
            finished = event.get("item")
            if isinstance(finished, dict):
                state.inflight_response_items.append(copy.deepcopy(finished))

        elif event_type == "response.completed":
            body = event.get("response")
            final_id = None
            # Default to the items collected while streaming.
            final_items: List[Dict[str, Any]] = copy.deepcopy(state.inflight_response_items)
            if isinstance(body, dict):
                if isinstance(body.get("id"), str):
                    final_id = body.get("id")
                reported = body.get("output")
                # A non-empty ``output`` list replaces the collected items.
                if isinstance(reported, list) and reported:
                    final_items = [copy.deepcopy(entry) for entry in reported if isinstance(entry, dict)]
            final_id = final_id or state.inflight_response_id

            if state.inflight_track_result and state.inflight_request_payload is not None and final_id:
                state.last_request_payload = copy.deepcopy(state.inflight_request_payload)
                state.last_response_id = final_id
                state.last_response_items = _conversation_output_items(final_items)
            else:
                state.last_request_payload = None
                state.last_response_id = None
                state.last_response_items = []
            _clear_inflight(state)

        elif event_type in ("response.failed", "error"):
            _clear_reuse_state(state)
|
||||
|
||||
|
||||
def note_responses_final_response(session_id: str, response_obj: Dict[str, Any]) -> None:
    """Record a complete response object as the session's last completed exchange.

    Does nothing for a blank or non-string ``session_id``, a non-dict
    ``response_obj``, or a session with no stored state. When the in-flight
    request is being tracked and the response has a string ``id``, the request
    and the response's reusable output items are stored; otherwise the
    last-completed record is dropped. The in-flight fields are cleared in
    both cases.
    """
    if not (isinstance(session_id, str) and session_id.strip()):
        return
    if not isinstance(response_obj, dict):
        return

    with _LOCK:
        state = _RESPONSES_SESSION_STATE.get(session_id)
        if state is None:
            return

        raw_id = response_obj.get("id")
        final_id = raw_id if isinstance(raw_id, str) else None

        raw_output = response_obj.get("output")
        if isinstance(raw_output, list):
            final_items = [copy.deepcopy(entry) for entry in raw_output if isinstance(entry, dict)]
        else:
            final_items = []

        if state.inflight_track_result and state.inflight_request_payload is not None and final_id:
            state.last_request_payload = copy.deepcopy(state.inflight_request_payload)
            state.last_response_id = final_id
            state.last_response_items = _conversation_output_items(final_items)
        else:
            state.last_request_payload = None
            state.last_response_id = None
            state.last_response_items = []
        _clear_inflight(state)
|
||||
|
||||
|
||||
def clear_responses_reuse_state(session_id: str) -> None:
    """Clear all reuse bookkeeping for ``session_id``, if that session has stored state.

    A blank or non-string ``session_id`` and an unknown session are ignored.
    """
    if isinstance(session_id, str) and session_id.strip():
        with _LOCK:
            state = _RESPONSES_SESSION_STATE.get(session_id)
            if state is not None:
                _clear_reuse_state(state)
|
||||
|
||||
|
||||
def reset_session_state() -> None:
    """Empty every module-level session container while holding the lock."""
    with _LOCK:
        for store in (
            _FINGERPRINT_TO_UUID,
            _ORDER,
            _RESPONSES_SESSION_STATE,
            _RESPONSES_ORDER,
        ):
            store.clear()
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import requests
|
||||
from flask import Response, current_app, jsonify, make_response
|
||||
@@ -33,6 +34,7 @@ def start_upstream_request(
|
||||
tool_choice: Any | None = None,
|
||||
parallel_tool_calls: bool = False,
|
||||
reasoning_param: Dict[str, Any] | None = None,
|
||||
service_tier: str | None = None,
|
||||
):
|
||||
access_token, account_id = get_effective_chatgpt_auth()
|
||||
if not access_token or not account_id:
|
||||
@@ -81,6 +83,62 @@ def start_upstream_request(
|
||||
|
||||
if reasoning_param is not None:
|
||||
responses_payload["reasoning"] = reasoning_param
|
||||
if isinstance(service_tier, str) and service_tier.strip():
|
||||
responses_payload["service_tier"] = service_tier.strip().lower()
|
||||
|
||||
return start_upstream_raw_request(
|
||||
responses_payload,
|
||||
session_id=session_id,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
def build_upstream_headers(
    access_token: str,
    account_id: str,
    session_id: str,
    *,
    accept: str = "text/event-stream",
) -> Dict[str, str]:
    """Assemble the HTTP headers sent with an upstream Responses request.

    Args:
        access_token: Token placed in the ``Authorization`` bearer header.
        account_id: Value of the ``chatgpt-account-id`` header.
        session_id: Value of the ``session_id`` header.
        accept: Value of the ``Accept`` header.

    Returns:
        A new header dict.
    """
    headers: Dict[str, str] = {}
    headers["Authorization"] = f"Bearer {access_token}"
    headers["Content-Type"] = "application/json"
    headers["Accept"] = accept
    headers["chatgpt-account-id"] = account_id
    headers["OpenAI-Beta"] = "responses=experimental"
    headers["session_id"] = session_id
    return headers
|
||||
|
||||
|
||||
def start_upstream_raw_request(
|
||||
responses_payload: Dict[str, Any],
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
stream: bool = True,
|
||||
):
|
||||
access_token, account_id = get_effective_chatgpt_auth()
|
||||
if not access_token or not account_id:
|
||||
resp = make_response(
|
||||
jsonify(
|
||||
{
|
||||
"error": {
|
||||
"message": "Missing ChatGPT credentials. Run 'python3 chatmock.py login' first.",
|
||||
}
|
||||
}
|
||||
),
|
||||
401,
|
||||
)
|
||||
for k, v in build_cors_headers().items():
|
||||
resp.headers.setdefault(k, v)
|
||||
return None, resp
|
||||
|
||||
effective_session_id = session_id
|
||||
if not isinstance(effective_session_id, str) or not effective_session_id.strip():
|
||||
payload_prompt_cache_key = responses_payload.get("prompt_cache_key")
|
||||
if isinstance(payload_prompt_cache_key, str) and payload_prompt_cache_key.strip():
|
||||
effective_session_id = payload_prompt_cache_key.strip()
|
||||
if not isinstance(effective_session_id, str) or not effective_session_id.strip():
|
||||
effective_session_id = str(int(time.time() * 1000))
|
||||
|
||||
verbose = False
|
||||
try:
|
||||
@@ -90,21 +148,19 @@ def start_upstream_request(
|
||||
if verbose:
|
||||
_log_json("OUTBOUND >> ChatGPT Responses API payload", responses_payload)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "text/event-stream",
|
||||
"chatgpt-account-id": account_id,
|
||||
"OpenAI-Beta": "responses=experimental",
|
||||
"session_id": session_id,
|
||||
}
|
||||
headers = build_upstream_headers(
|
||||
access_token,
|
||||
account_id,
|
||||
effective_session_id,
|
||||
accept=("text/event-stream" if stream else "application/json"),
|
||||
)
|
||||
|
||||
try:
|
||||
upstream = requests.post(
|
||||
CHATGPT_RESPONSES_URL,
|
||||
headers=headers,
|
||||
json=responses_payload,
|
||||
stream=True,
|
||||
stream=stream,
|
||||
timeout=600,
|
||||
)
|
||||
except requests.RequestException as e:
|
||||
@@ -113,3 +169,13 @@ def start_upstream_request(
|
||||
resp.headers.setdefault(k, v)
|
||||
return None, resp
|
||||
return upstream, None
|
||||
|
||||
|
||||
def build_upstream_websocket_url() -> str:
    """Return ``CHATGPT_RESPONSES_URL`` with its scheme switched to the websocket form.

    ``https`` becomes ``wss`` and ``http`` becomes ``ws`` (matched
    case-insensitively); any other scheme is left as it is.
    """
    parts = urlparse(CHATGPT_RESPONSES_URL)
    ws_scheme = {"https": "wss", "http": "ws"}.get(parts.scheme.lower())
    if ws_scheme is not None:
        parts = parts._replace(scheme=ws_scheme)
    return urlunparse(parts)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
__version__ = "1.36"
|
||||
__version__ = "1.37"
|
||||
|
||||
225
chatmock/websocket_routes.py
Normal file
225
chatmock/websocket_routes.py
Normal file
@@ -0,0 +1,225 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
from typing import Any, Dict
|
||||
|
||||
import certifi
|
||||
from flask import current_app, request
|
||||
from flask_sock import Sock
|
||||
from websockets.sync.client import connect as websocket_connect
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from .responses_api import (
|
||||
ResponsesRequestError,
|
||||
extract_client_session_id,
|
||||
normalize_responses_payload,
|
||||
)
|
||||
from .session import (
|
||||
clear_responses_reuse_state,
|
||||
note_responses_stream_event,
|
||||
prepare_responses_request_for_session,
|
||||
)
|
||||
from .upstream import build_upstream_headers, build_upstream_websocket_url
|
||||
from .utils import get_effective_chatgpt_auth
|
||||
|
||||
|
||||
def _log_json(prefix: str, payload: Any) -> None:
    """Print ``prefix`` followed by ``payload``, pretty-printed as JSON when possible.

    If the JSON form cannot be produced or printed, the payload's plain
    string form is printed instead; if that fails too, nothing is printed and
    no error is raised.
    """
    renderers = (
        lambda: json.dumps(payload, indent=2, ensure_ascii=False),
        lambda: payload,
    )
    for render in renderers:
        try:
            print(f"{prefix}\n{render()}")
        except Exception:
            # Best-effort logging: fall through to the next rendering.
            continue
        return
|
||||
|
||||
|
||||
def _error_event(message: str, *, status_code: int = 400, code: str | None = None) -> Dict[str, Any]:
    """Build an ``error`` event dict carrying ``message`` and ``status_code``.

    The nested ``error`` object includes ``code`` only when a non-empty code
    is supplied.
    """
    detail: Dict[str, Any] = {"message": message}
    if code:
        detail["code"] = code
    event: Dict[str, Any] = {"type": "error"}
    event["status_code"] = status_code
    event["error"] = detail
    return event
|
||||
|
||||
|
||||
def _is_terminal_event(event: Any) -> bool:
    """Return True when ``event`` is a dict whose ``type`` ends a response.

    The terminal types are ``response.completed``, ``response.failed`` and
    ``error``.
    """
    return isinstance(event, dict) and event.get("type") in (
        "response.completed",
        "response.failed",
        "error",
    )
|
||||
|
||||
|
||||
def _build_websocket_ssl_context() -> ssl.SSLContext:
    """Create the default SSL context used for the upstream websocket.

    The CA bundle is taken from the first non-empty of the
    ``CODEX_CA_CERTIFICATE`` and ``SSL_CERT_FILE`` environment variables,
    falling back to the bundle reported by ``certifi.where()``.
    """
    for variable in ("CODEX_CA_CERTIFICATE", "SSL_CERT_FILE"):
        override = os.getenv(variable)
        if override:
            return ssl.create_default_context(cafile=override)
    return ssl.create_default_context(cafile=certifi.where())
|
||||
|
||||
|
||||
def connect_upstream_websocket(url: str, headers: Dict[str, str]):
    """Open a websocket connection to ``url`` and return it.

    ``headers`` are sent as additional handshake headers, the open timeout is
    15 seconds, and TLS uses the context from ``_build_websocket_ssl_context``.
    """
    tls_context = _build_websocket_ssl_context()
    return websocket_connect(
        url,
        additional_headers=headers,
        open_timeout=15,
        ssl=tls_context,
    )
|
||||
|
||||
|
||||
def register_websocket_routes(sock: Sock) -> None:
    """Register the ``/v1/responses`` websocket route on ``sock``.

    The route relays JSON frames between the client and an upstream
    websocket. ``response.create`` frames are normalized and prepared for the
    session before being forwarded; upstream events are passed back to the
    client and noted in the session state until a terminal event arrives.

    Args:
        sock: The Flask-Sock instance the route is attached to.
    """

    @sock.route("/v1/responses")
    def responses_websocket(ws) -> None:
        verbose = bool(current_app.config.get("VERBOSE"))
        upstream_ws = None
        upstream_session_id: str | None = None
        active_session_id: str | None = None

        def _send_error(message: str, *, status_code: int = 400, code: str | None = None) -> None:
            """Send an error event to the client; delivery failures are ignored."""
            evt = _error_event(message, status_code=status_code, code=code)
            if verbose:
                _log_json("STREAM OUT WS /v1/responses (error)", evt)
            try:
                ws.send(json.dumps(evt))
            except Exception:
                # Best effort: the client may already be gone.
                pass

        def _report_upstream_closed() -> None:
            """Drop reuse state for the active session and tell the client upstream closed."""
            if active_session_id:
                clear_responses_reuse_state(active_session_id)
            _send_error("Upstream websocket closed unexpectedly.", status_code=502)

        try:
            while True:
                incoming = ws.receive()
                if incoming is None:
                    break

                if isinstance(incoming, bytes):
                    incoming_text = incoming.decode("utf-8", errors="ignore")
                else:
                    incoming_text = str(incoming)
                if verbose:
                    print("IN WS /v1/responses\n" + incoming_text)

                try:
                    payload = json.loads(incoming_text)
                except Exception:
                    _send_error("Websocket frames must be valid JSON objects.", status_code=400)
                    break

                if not isinstance(payload, dict):
                    _send_error("Websocket frames must be JSON objects.", status_code=400)
                    break

                client_session_id = extract_client_session_id(request.headers)
                # Frames other than response.create are forwarded verbatim.
                outbound_text = incoming_text
                session_id = upstream_session_id

                if payload.get("type") == "response.create":
                    try:
                        normalized = normalize_responses_payload(
                            payload,
                            config=current_app.config,
                            client_session_id=client_session_id,
                        )
                    except ResponsesRequestError as exc:
                        # A rejected request keeps the connection open for a retry.
                        _send_error(str(exc), status_code=exc.status_code, code=exc.code)
                        continue

                    if normalized.service_tier_resolution.warning_message and verbose:
                        print(f"[FastMode] {normalized.service_tier_resolution.warning_message}")
                    prepared = prepare_responses_request_for_session(
                        normalized.session_id,
                        normalized.payload,
                        allow_previous_response_id=True,
                    )
                    outbound_text = json.dumps(prepared.payload)
                    session_id = normalized.session_id
                    active_session_id = normalized.session_id
                    if verbose:
                        _log_json("OUTBOUND >> ChatGPT Responses WS payload", prepared.payload)
                elif upstream_ws is None:
                    _send_error(
                        "The first websocket message must be a response.create request.",
                        status_code=400,
                    )
                    break

                # (Re)connect when there is no upstream yet or the session changed.
                if upstream_ws is None or (session_id and session_id != upstream_session_id):
                    access_token, account_id = get_effective_chatgpt_auth()
                    if not access_token or not account_id:
                        if session_id:
                            clear_responses_reuse_state(session_id)
                        _send_error(
                            "Missing ChatGPT credentials. Run 'python3 chatmock.py login' first.",
                            status_code=401,
                        )
                        break

                    if upstream_ws is not None:
                        _close_quietly(upstream_ws)
                        # Do not keep a handle to a closed connection if the
                        # reconnect below fails.
                        upstream_ws = None

                    effective_session_id = session_id or client_session_id or ""
                    try:
                        upstream_ws = connect_upstream_websocket(
                            build_upstream_websocket_url(),
                            build_upstream_headers(
                                access_token,
                                account_id,
                                effective_session_id,
                                accept="application/json",
                            ),
                        )
                    except Exception as exc:
                        if session_id:
                            clear_responses_reuse_state(session_id)
                        _send_error(
                            f"Upstream websocket connection failed: {exc}",
                            status_code=502,
                        )
                        break
                    upstream_session_id = effective_session_id

                # A reused upstream connection may have closed between turns;
                # handle that the same way as a close observed while receiving.
                try:
                    upstream_ws.send(outbound_text)
                except ConnectionClosed:
                    _report_upstream_closed()
                    return

                # Relay upstream events until the response reaches a terminal event.
                while True:
                    try:
                        upstream_message = upstream_ws.recv()
                    except ConnectionClosed:
                        _report_upstream_closed()
                        return
                    if upstream_message is None:
                        _report_upstream_closed()
                        return
                    if verbose:
                        try:
                            print("STREAM OUT WS /v1/responses\n" + str(upstream_message))
                        except Exception:
                            pass
                    ws.send(upstream_message)

                    try:
                        parsed = json.loads(upstream_message)
                    except Exception:
                        parsed = None
                    if isinstance(parsed, dict) and active_session_id:
                        note_responses_stream_event(active_session_id, parsed)
                    if _is_terminal_event(parsed):
                        if isinstance(parsed, dict) and parsed.get("type") in ("response.failed", "error"):
                            # After a failure the upstream connection is not reused.
                            _close_quietly(upstream_ws)
                            upstream_ws = None
                            upstream_session_id = None
                        break
        finally:
            _close_quietly(upstream_ws)


def _close_quietly(connection: Any) -> None:
    """Close ``connection`` if there is one, ignoring any error raised by ``close()``."""
    if connection is None:
        return
    try:
        connection.close()
    except Exception:
        pass
|
||||
Reference in New Issue
Block a user