feat: add responses api, websocket support, and fast mode

This commit is contained in:
Game_Time
2026-03-23 15:41:42 +05:00
parent e96db19538
commit 8754203ec6
22 changed files with 2148 additions and 119 deletions

View File

@@ -1,11 +1,13 @@
from __future__ import annotations
from flask import Flask, jsonify
from flask_sock import Sock
from .config import BASE_INSTRUCTIONS, GPT5_CODEX_INSTRUCTIONS
from .http import build_cors_headers
from .routes_openai import openai_bp
from .routes_ollama import ollama_bp
from .websocket_routes import register_websocket_routes
def create_app(
@@ -14,6 +16,7 @@ def create_app(
reasoning_effort: str = "medium",
reasoning_summary: str = "auto",
reasoning_compat: str = "think-tags",
fast_mode: bool = False,
debug_model: str | None = None,
expose_reasoning_models: bool = False,
default_web_search: bool = False,
@@ -26,6 +29,7 @@ def create_app(
REASONING_EFFORT=reasoning_effort,
REASONING_SUMMARY=reasoning_summary,
REASONING_COMPAT=reasoning_compat,
FAST_MODE=bool(fast_mode),
DEBUG_MODEL=debug_model,
BASE_INSTRUCTIONS=BASE_INSTRUCTIONS,
GPT5_CODEX_INSTRUCTIONS=GPT5_CODEX_INSTRUCTIONS,
@@ -46,5 +50,7 @@ def create_app(
app.register_blueprint(openai_bp)
app.register_blueprint(ollama_bp)
sock = Sock(app)
register_websocket_routes(sock)
return app

View File

@@ -267,6 +267,7 @@ def cmd_serve(
reasoning_effort: str,
reasoning_summary: str,
reasoning_compat: str,
fast_mode: bool,
debug_model: str | None,
expose_reasoning_models: bool,
default_web_search: bool,
@@ -277,6 +278,7 @@ def cmd_serve(
reasoning_effort=reasoning_effort,
reasoning_summary=reasoning_summary,
reasoning_compat=reasoning_compat,
fast_mode=fast_mode,
debug_model=debug_model,
expose_reasoning_models=expose_reasoning_models,
default_web_search=default_web_search,
@@ -309,6 +311,12 @@ def main() -> None:
default=os.getenv("CHATGPT_LOCAL_DEBUG_MODEL"),
help="Forcibly override requested 'model' with this value",
)
p_serve.add_argument(
"--fast-mode",
action=argparse.BooleanOptionalAction,
default=(os.getenv("CHATGPT_LOCAL_FAST_MODE") or "").strip().lower() in ("1", "true", "yes", "on"),
help="Enable GPT fast mode by default for supported models; request-level overrides still take precedence.",
)
p_serve.add_argument(
"--reasoning-effort",
choices=["none", "minimal", "low", "medium", "high", "xhigh"],
@@ -366,6 +374,7 @@ def main() -> None:
reasoning_effort=args.reasoning_effort,
reasoning_summary=args.reasoning_summary,
reasoning_compat=args.reasoning_compat,
fast_mode=args.fast_mode,
debug_model=args.debug_model,
expose_reasoning_models=args.expose_reasoning_models,
default_web_search=args.enable_web_search,

92
chatmock/fast_mode.py Normal file
View File

@@ -0,0 +1,92 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from .model_registry import normalize_model_name
# Normalized model ids that may be sent upstream with the "priority" service
# tier; membership is what supports_priority_service_tier() checks.
PRIORITY_SUPPORTED_MODELS = frozenset(
    {
        "gpt-5",
        "gpt-5-codex",
        "gpt-5.1",
        "gpt-5.1-codex",
        "gpt-5.2",
        "gpt-5.4",
    }
)
# Case-insensitive spellings accepted for an explicit on/off flag.
_TRUE_STRINGS = {"1", "true", "yes", "on"}
_FALSE_STRINGS = {"0", "false", "no", "off"}


def parse_optional_bool(value: Any) -> bool | None:
    """Interpret *value* as a tri-state boolean.

    Returns the value itself for real booleans, ``True``/``False`` for the
    recognised string spellings (surrounding whitespace and case ignored),
    and ``None`` for everything else, meaning "not specified".
    """
    if isinstance(value, bool):
        return value
    if not isinstance(value, str):
        return None
    token = value.strip().lower()
    if token in _TRUE_STRINGS:
        return True
    return False if token in _FALSE_STRINGS else None
def supports_priority_service_tier(model: str | None) -> bool:
    """Return True when *model*, once normalized, is in PRIORITY_SUPPORTED_MODELS."""
    canonical = normalize_model_name(model)
    return canonical in PRIORITY_SUPPORTED_MODELS
@dataclass(frozen=True)
class ServiceTierResolution:
    """Immutable outcome of resolving which service tier a request should use."""

    # Tier to send upstream, or None when no tier should be sent.
    service_tier: str | None
    # Set when an explicit request asked for a tier the model cannot use.
    error_message: str | None = None
    # Set when the server-wide default was dropped for an unsupported model.
    warning_message: str | None = None
    # True when the tier decision came from the server default, not the request.
    used_server_default: bool = False
def resolve_service_tier(
    model: str | None,
    *,
    request_fast_mode: Any = None,
    request_service_tier: Any = None,
    server_fast_mode: bool = False,
) -> ServiceTierResolution:
    """Decide which service tier applies to a request.

    Precedence: an explicit ``fast_mode`` flag on the request, then a
    non-blank ``service_tier`` string on the request, then the server-wide
    fast-mode default. When the chosen tier is ``"priority"`` but the model
    does not support it, the tier is dropped: an explicit request yields
    ``error_message``; the server default yields only ``warning_message``.
    """
    fast_flag = parse_optional_bool(request_fast_mode)
    from_request = True
    from_server_default = False

    if fast_flag is not None:
        tier = "priority" if fast_flag else None
    elif isinstance(request_service_tier, str) and request_service_tier.strip():
        tier = request_service_tier.strip().lower()
    else:
        from_request = False
        from_server_default = bool(server_fast_mode)
        tier = "priority" if from_server_default else None

    # Anything other than an unsupported "priority" request passes through.
    if tier != "priority" or supports_priority_service_tier(model):
        return ServiceTierResolution(
            service_tier=tier,
            used_server_default=from_server_default,
        )

    normalized = normalize_model_name(model)
    message = (
        f"Fast mode is not supported for model '{normalized}'. "
        "Use a supported GPT-5 priority-processing model or disable fast mode for this request."
    )
    # Explicit requests fail loudly; the server default is dropped with a warning.
    message_field = "error_message" if from_request else "warning_message"
    return ServiceTierResolution(
        service_tier=None,
        used_server_default=from_server_default,
        **{message_field: message},
    )

View File

@@ -62,6 +62,14 @@ _MODEL_SPECS = (
variant_efforts=("xhigh", "high", "medium", "low"),
uses_codex_instructions=True,
),
ModelSpec(
public_id="gpt-5.3-codex-spark",
upstream_id="gpt-5.3-codex-spark",
aliases=("gpt5.3-codex-spark", "gpt-5.3-codex-spark-latest"),
allowed_efforts=frozenset(("low", "medium", "high", "xhigh")),
variant_efforts=("xhigh", "high", "medium", "low"),
uses_codex_instructions=True,
),
ModelSpec(
public_id="gpt-5-codex",
upstream_id="gpt-5-codex",

242
chatmock/responses_api.py Normal file
View File

@@ -0,0 +1,242 @@
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, List
from .config import BASE_INSTRUCTIONS, GPT5_CODEX_INSTRUCTIONS
from .fast_mode import ServiceTierResolution, resolve_service_tier
from .model_registry import (
allowed_efforts_for_model,
extract_reasoning_from_model_name,
normalize_model_name,
uses_codex_instructions,
)
from .reasoning import build_reasoning_param
from .session import ensure_session_id
class ResponsesRequestError(Exception):
    """Client-facing validation error raised while normalizing a Responses request.

    Attributes:
        message: Human-readable description returned to the client.
        status_code: HTTP status the route should answer with (default 400).
        code: Optional machine-readable error code.
    """

    def __init__(self, message: str, status_code: int = 400, code: str | None = None) -> None:
        # A frozen dataclass must not be used for exceptions: its __setattr__
        # rejects every assignment, including ``exc.__traceback__ = tb`` that
        # stdlib helpers perform on in-flight exceptions. Passing the message
        # to Exception also keeps ``args`` populated for keyword construction.
        super().__init__(message)
        self.message = message
        self.status_code = status_code
        self.code = code

    def __str__(self) -> str:
        return self.message

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(message={self.message!r}, "
            f"status_code={self.status_code!r}, code={self.code!r})"
        )
@dataclass(frozen=True)
class NormalizedResponsesRequest:
    """Immutable result bundle produced by normalize_responses_payload()."""

    # Upstream-ready request body.
    payload: Dict[str, Any]
    # The "model" string as sent by the client, or None when it was not a string.
    requested_model: str | None
    # Model id after normalize_model_name().
    normalized_model: str
    # Session id from ensure_session_id().
    session_id: str
    # Full service-tier resolution, including any warning for the caller to log.
    service_tier_resolution: ServiceTierResolution
def instructions_for_model(config: Dict[str, Any], model: str) -> str:
    """Pick the default system instructions for *model*.

    Models flagged by uses_codex_instructions() get the Codex instructions
    (config value first, module constant as fallback) when that text is a
    non-blank string. Every other case returns the base instructions.
    """
    fallback = config.get("BASE_INSTRUCTIONS", BASE_INSTRUCTIONS)
    if not uses_codex_instructions(model):
        return fallback
    codex_text = config.get("GPT5_CODEX_INSTRUCTIONS") or GPT5_CODEX_INSTRUCTIONS
    if isinstance(codex_text, str) and codex_text.strip():
        return codex_text
    return fallback
def extract_client_session_id(headers: Any) -> str | None:
    """Return the client-supplied session id from *headers*, if any.

    ``X-Session-Id`` is checked before ``session_id``; the first truthy value
    wins. Any failure while reading the headers object is treated as
    "no session id".
    """
    try:
        for header_name in ("X-Session-Id", "session_id"):
            candidate = headers.get(header_name)
            if candidate:
                return candidate
    except Exception:
        pass
    return None
def _input_items_for_session(raw_input: Any) -> List[Dict[str, Any]]:
    """Coerce a Responses ``input`` value into a list of item dicts.

    A dict becomes a one-element list, a list keeps only its dict entries,
    and a non-blank string is wrapped as a single user message. Anything else
    (including a blank string) yields an empty list.
    """
    if isinstance(raw_input, dict):
        return [raw_input]
    if isinstance(raw_input, list):
        return [entry for entry in raw_input if isinstance(entry, dict)]
    if isinstance(raw_input, str) and raw_input.strip():
        text_part = {"type": "input_text", "text": raw_input}
        return [{"type": "message", "role": "user", "content": [text_part]}]
    return []


def canonicalize_responses_input(raw_input: Any) -> Any:
    """Normalize a Responses ``input`` value to list-of-items form.

    Lists, dicts and strings are converted exactly as
    _input_items_for_session() does. Any other value is returned unchanged.
    """
    if isinstance(raw_input, (list, dict, str)):
        return _input_items_for_session(raw_input)
    return raw_input
# Build the upstream-ready form of a client /v1/responses payload.
#
# Works on a shallow copy (dict(payload)) of the caller's dict and returns a
# NormalizedResponsesRequest holding that copy plus the requested/normalized
# model, the session id and the service-tier resolution.
# Raises ResponsesRequestError when resolve_service_tier() reports an error.
# NOTE(review): indentation is flattened in this view; nesting is inferred from syntax.
def normalize_responses_payload(
payload: Dict[str, Any],
*,
config: Dict[str, Any],
client_session_id: str | None = None,
) -> NormalizedResponsesRequest:
# Only a string "model" counts as a requested model; anything else is treated as absent.
requested_model = payload.get("model") if isinstance(payload.get("model"), str) else None
normalized_model = normalize_model_name(requested_model, config.get("DEBUG_MODEL"))
# Shallow copy: top-level keys are replaced below, nested values stay shared with the caller.
normalized = dict(payload)
normalized["model"] = normalized_model
if "input" in normalized:
normalized["input"] = canonicalize_responses_input(normalized.get("input"))
# Default "store" to False only when the client did not set it.
if "store" not in normalized:
normalized["store"] = False
# Missing, non-string or blank instructions fall back to the per-model default.
instructions = normalized.get("instructions")
if not isinstance(instructions, str) or not instructions.strip():
instructions = instructions_for_model(config, normalized_model)
normalized["instructions"] = instructions
reasoning_effort = config.get("REASONING_EFFORT", "medium")
reasoning_summary = config.get("REASONING_SUMMARY", "auto")
# A "reasoning" dict in the request takes precedence over overrides parsed
# from the requested model name.
reasoning_overrides = (
normalized.get("reasoning")
if isinstance(normalized.get("reasoning"), dict)
else extract_reasoning_from_model_name(requested_model)
)
normalized["reasoning"] = build_reasoning_param(
reasoning_effort,
reasoning_summary,
reasoning_overrides,
allowed_efforts=allowed_efforts_for_model(normalized_model),
)
# Keep only string entries of "include" and always request encrypted reasoning content.
include = normalized.get("include")
include_list = [item for item in include if isinstance(item, str)] if isinstance(include, list) else []
if "reasoning.encrypted_content" not in include_list:
include_list.append("reasoning.encrypted_content")
normalized["include"] = include_list
# Inject the default web_search tool only when the request has no tools,
# DEFAULT_WEB_SEARCH is on, and tool_choice is not the string "none".
tools = normalized.get("tools")
if (not isinstance(tools, list) or not tools) and bool(config.get("DEFAULT_WEB_SEARCH")):
tool_choice = normalized.get("tool_choice")
if not (isinstance(tool_choice, str) and tool_choice.strip().lower() == "none"):
normalized["tools"] = [{"type": "web_search"}]
service_tier_resolution = resolve_service_tier(
normalized_model,
request_fast_mode=normalized.get("fast_mode"),
request_service_tier=normalized.get("service_tier"),
server_fast_mode=bool(config.get("FAST_MODE")),
)
if service_tier_resolution.error_message:
raise ResponsesRequestError(service_tier_resolution.error_message)
# Write back the resolved tier, or drop the key when no tier applies.
if service_tier_resolution.service_tier is None:
normalized.pop("service_tier", None)
else:
normalized["service_tier"] = service_tier_resolution.service_tier
# "fast_mode" was only an input to resolve_service_tier(); it is removed from the outgoing payload.
normalized.pop("fast_mode", None)
input_items = _input_items_for_session(normalized.get("input"))
session_id = ensure_session_id(instructions, input_items, client_session_id)
# Use the session id as prompt_cache_key unless the client gave a non-blank one.
prompt_cache_key = normalized.get("prompt_cache_key")
if not isinstance(prompt_cache_key, str) or not prompt_cache_key.strip():
normalized["prompt_cache_key"] = session_id
return NormalizedResponsesRequest(
payload=normalized,
requested_model=requested_model,
normalized_model=normalized_model,
session_id=session_id,
service_tier_resolution=service_tier_resolution,
)
def iter_sse_event_payloads(upstream: Any) -> Iterator[Dict[str, Any]]:
    """Yield the JSON-object payload of each SSE ``data`` line from *upstream*.

    *upstream* must expose ``iter_lines(decode_unicode=False)``. Non-``data``
    lines, blank payloads, undecodable JSON and non-object JSON are skipped.
    Iteration stops at the ``[DONE]`` sentinel.
    """
    for raw in upstream.iter_lines(decode_unicode=False):
        if not raw:
            continue
        line = raw.decode("utf-8", errors="ignore") if isinstance(raw, (bytes, bytearray)) else raw
        # The SSE format makes the space after the colon optional, so accept
        # both "data: {...}" and "data:{...}"; strip() removes the space.
        if not line.startswith("data:"):
            continue
        data = line[len("data:") :].strip()
        if data == "[DONE]":
            break
        if not data:
            continue
        try:
            evt = json.loads(data)
        except (ValueError, RecursionError):
            # Malformed payloads are skipped so one bad line cannot end the stream.
            continue
        if isinstance(evt, dict):
            yield evt
def aggregate_response_from_sse(
    upstream: Any,
    *,
    on_event: Any | None = None,
) -> tuple[Dict[str, Any] | None, Dict[str, Any] | None]:
    """Consume an upstream SSE stream and collapse it into one result.

    Returns ``(response, error)``. ``response`` is the last ``response``
    dict seen on any event. ``error`` is set only when a ``response.failed``
    event arrives. Reading stops at ``response.completed`` or
    ``response.failed``. *upstream* is always closed. Exceptions raised by
    *on_event* are ignored.
    """
    latest_response: Dict[str, Any] | None = None
    failure: Dict[str, Any] | None = None
    try:
        for event in iter_sse_event_payloads(upstream):
            if callable(on_event):
                try:
                    on_event(event)
                except Exception:
                    pass
            body = event.get("response")
            has_body = isinstance(body, dict)
            if has_body:
                latest_response = body
            event_type = event.get("type")
            if event_type == "response.completed":
                break
            if event_type == "response.failed":
                upstream_error = body.get("error") if has_body else None
                if isinstance(upstream_error, dict):
                    failure = {"error": upstream_error}
                else:
                    failure = {"error": {"message": "response.failed"}}
                break
    finally:
        upstream.close()
    return latest_response, failure
def stream_upstream_bytes(
    upstream: Any,
    *,
    on_event: Any | None = None,
) -> Iterable[bytes]:
    """Relay upstream body chunks unchanged, optionally observing SSE events.

    Every non-empty chunk from ``upstream.iter_content(chunk_size=None)`` is
    yielded as received. When *on_event* is callable, the bytes are also
    reassembled into lines and each ``data`` line holding a JSON object is
    passed to it; ``[DONE]``, blank and malformed payloads are ignored. A
    trailing line without a newline is never dispatched. *upstream* is closed
    when the generator finishes or is closed.
    """
    notify = on_event if callable(on_event) else None
    pending = b""
    try:
        for chunk in upstream.iter_content(chunk_size=None):
            if not chunk:
                continue
            if notify is not None:
                if isinstance(chunk, (bytes, bytearray)):
                    pending += bytes(chunk)
                else:
                    pending += str(chunk).encode("utf-8", errors="ignore")
                # One split yields all complete lines; the last piece is the
                # unterminated remainder kept for the next chunk.
                *complete_lines, pending = pending.split(b"\n")
                for line in complete_lines:
                    line = line.rstrip(b"\r")
                    # The SSE format makes the space after "data:" optional.
                    if not line.startswith(b"data:"):
                        continue
                    data = line[len(b"data:") :].strip()
                    if not data or data == b"[DONE]":
                        continue
                    try:
                        evt = json.loads(data.decode("utf-8", errors="ignore"))
                    except (ValueError, RecursionError):
                        continue
                    if isinstance(evt, dict):
                        try:
                            notify(evt)
                        except Exception:
                            # Deliberate best-effort: an observer failure must
                            # never interrupt the byte relay to the client.
                            pass
            yield chunk
    finally:
        upstream.close()

View File

@@ -8,9 +8,11 @@ from typing import Any, Dict, List
from flask import Blueprint, Response, current_app, jsonify, make_response, request, stream_with_context
from .config import BASE_INSTRUCTIONS, GPT5_CODEX_INSTRUCTIONS
from .fast_mode import resolve_service_tier
from .limits import record_rate_limits_from_response
from .http import build_cors_headers
from .model_registry import list_public_models, uses_codex_instructions
from .responses_api import instructions_for_model
from .reasoning import (
allowed_efforts_for_model,
build_reasoning_param,
@@ -71,12 +73,7 @@ def ollama_version() -> Response:
def _instructions_for_model(model: str) -> str:
base = current_app.config.get("BASE_INSTRUCTIONS", BASE_INSTRUCTIONS)
if uses_codex_instructions(model):
codex = current_app.config.get("GPT5_CODEX_INSTRUCTIONS") or GPT5_CODEX_INSTRUCTIONS
if isinstance(codex, str) and codex.strip():
return codex
return base
return instructions_for_model(current_app.config, model)
_OLLAMA_FAKE_EVAL = {
@@ -254,6 +251,19 @@ def ollama_chat() -> Response:
model_reasoning = extract_reasoning_from_model_name(model)
normalized_model = normalize_model_name(model)
service_tier_resolution = resolve_service_tier(
normalized_model,
request_fast_mode=payload.get("fast_mode"),
request_service_tier=payload.get("service_tier"),
server_fast_mode=bool(current_app.config.get("FAST_MODE")),
)
if service_tier_resolution.warning_message and verbose:
print(f"[FastMode] {service_tier_resolution.warning_message}")
if service_tier_resolution.error_message:
err = {"error": service_tier_resolution.error_message}
if verbose:
_log_json("OUT POST /api/chat", err)
return jsonify(err), 400
upstream, error_resp = start_upstream_request(
normalized_model,
input_items,
@@ -267,6 +277,7 @@ def ollama_chat() -> Response:
model_reasoning,
allowed_efforts=allowed_efforts_for_model(model),
),
service_tier=service_tier_resolution.service_tier,
)
if error_resp is not None:
if verbose:
@@ -307,6 +318,7 @@ def ollama_chat() -> Response:
model_reasoning,
allowed_efforts=allowed_efforts_for_model(model),
),
service_tier=service_tier_resolution.service_tier,
)
record_rate_limits_from_response(upstream2)
if err2 is None and upstream2 is not None and upstream2.status_code < 400:

View File

@@ -7,16 +7,31 @@ from typing import Any, Dict, List
from flask import Blueprint, Response, current_app, jsonify, make_response, request
from .config import BASE_INSTRUCTIONS, GPT5_CODEX_INSTRUCTIONS
from .fast_mode import resolve_service_tier
from .limits import record_rate_limits_from_response
from .http import build_cors_headers
from .model_registry import list_public_models, uses_codex_instructions
from .responses_api import (
ResponsesRequestError,
aggregate_response_from_sse,
extract_client_session_id,
instructions_for_model,
normalize_responses_payload,
stream_upstream_bytes,
)
from .reasoning import (
allowed_efforts_for_model,
apply_reasoning_to_message,
build_reasoning_param,
extract_reasoning_from_model_name,
)
from .upstream import normalize_model_name, start_upstream_request
from .session import (
clear_responses_reuse_state,
note_responses_final_response,
note_responses_stream_event,
prepare_responses_request_for_session,
)
from .upstream import normalize_model_name, start_upstream_raw_request, start_upstream_request
from .utils import (
convert_chat_messages_to_responses_input,
convert_tools_chat_to_responses,
@@ -59,12 +74,32 @@ def _wrap_stream_logging(label: str, iterator, enabled: bool):
def _instructions_for_model(model: str) -> str:
base = current_app.config.get("BASE_INSTRUCTIONS", BASE_INSTRUCTIONS)
if uses_codex_instructions(model):
codex = current_app.config.get("GPT5_CODEX_INSTRUCTIONS") or GPT5_CODEX_INSTRUCTIONS
if isinstance(codex, str) and codex.strip():
return codex
return base
return instructions_for_model(current_app.config, model)
def _service_tier_from_payload(
    model: str,
    payload: Dict[str, Any],
    *,
    verbose: bool = False,
) -> tuple[str | None, Response | None]:
    """Resolve the service tier for a chat/completions-style request.

    Returns ``(service_tier, None)`` on success, or ``(None, response)`` where
    ``response`` is a ready-to-return 400 JSON error with CORS headers when
    the request explicitly asked for an unsupported tier. A warning from the
    resolution is printed only in verbose mode.
    """
    outcome = resolve_service_tier(
        model,
        request_fast_mode=payload.get("fast_mode"),
        request_service_tier=payload.get("service_tier"),
        server_fast_mode=bool(current_app.config.get("FAST_MODE")),
    )
    if outcome.warning_message and verbose:
        print(f"[FastMode] {outcome.warning_message}")
    if not outcome.error_message:
        return outcome.service_tier, None

    error_body = {"error": {"message": outcome.error_message}}
    if verbose:
        _log_json("OUT POST service_tier resolution", error_body)
    error_response = make_response(jsonify(error_body), 400)
    for header, value in build_cors_headers().items():
        error_response.headers.setdefault(header, value)
    return None, error_response
@openai_bp.route("/v1/chat/completions", methods=["POST"])
@@ -178,6 +213,9 @@ def chat_completions() -> Response:
reasoning_overrides,
allowed_efforts=allowed_efforts_for_model(model),
)
service_tier, tier_error = _service_tier_from_payload(model, payload, verbose=verbose)
if tier_error is not None:
return tier_error
upstream, error_resp = start_upstream_request(
model,
@@ -187,6 +225,7 @@ def chat_completions() -> Response:
tool_choice=tool_choice,
parallel_tool_calls=parallel_tool_calls,
reasoning_param=reasoning_param,
service_tier=service_tier,
)
if error_resp is not None:
if verbose:
@@ -224,6 +263,7 @@ def chat_completions() -> Response:
tool_choice=safe_choice,
parallel_tool_calls=parallel_tool_calls,
reasoning_param=reasoning_param,
service_tier=service_tier,
)
record_rate_limits_from_response(upstream2)
if err2 is None and upstream2 is not None and upstream2.status_code < 400:
@@ -413,11 +453,15 @@ def completions() -> Response:
reasoning_overrides,
allowed_efforts=allowed_efforts_for_model(model),
)
service_tier, tier_error = _service_tier_from_payload(model, payload, verbose=verbose)
if tier_error is not None:
return tier_error
upstream, error_resp = start_upstream_request(
model,
input_items,
instructions=_instructions_for_model(model),
reasoning_param=reasoning_param,
service_tier=service_tier,
)
if error_resp is not None:
if verbose:
@@ -529,6 +573,161 @@ def completions() -> Response:
return resp
# POST /v1/responses: validate and normalize the client payload, forward it
# upstream (always as a streaming request), then either relay the SSE bytes
# (client asked for stream) or collapse the stream into one JSON response.
# NOTE(review): indentation is flattened in this view; nesting is inferred from syntax.
@openai_bp.route("/v1/responses", methods=["POST"])
def responses_create() -> Response:
verbose = bool(current_app.config.get("VERBOSE"))
raw = request.get_data(cache=True, as_text=True) or ""
if verbose:
try:
print("IN POST /v1/responses\n" + raw)
except Exception:
pass
# An empty body is treated as an empty JSON object.
try:
payload = json.loads(raw) if raw else {}
except Exception:
err = {"error": {"message": "Invalid JSON body"}}
if verbose:
_log_json("OUT POST /v1/responses", err)
# NOTE(review): this early return (and the two below) does not add
# build_cors_headers(), unlike the later error paths — confirm intended.
return jsonify(err), 400
if not isinstance(payload, dict):
err = {"error": {"message": "Request body must be a JSON object"}}
if verbose:
_log_json("OUT POST /v1/responses", err)
return jsonify(err), 400
try:
normalized = normalize_responses_payload(
payload,
config=current_app.config,
client_session_id=extract_client_session_id(request.headers),
)
except ResponsesRequestError as exc:
err: Dict[str, Any] = {"error": {"message": str(exc)}}
if exc.code:
err["error"]["code"] = exc.code
if verbose:
_log_json("OUT POST /v1/responses", err)
return jsonify(err), exc.status_code
if normalized.service_tier_resolution.warning_message and verbose:
print(f"[FastMode] {normalized.service_tier_resolution.warning_message}")
# allow_previous_response_id=False: the HTTP route never rewrites the request
# into a previous_response_id continuation.
prepared = prepare_responses_request_for_session(
normalized.session_id,
normalized.payload,
allow_previous_response_id=False,
)
# Remember what the client asked for, then force streaming toward upstream.
stream_req = bool(prepared.payload.get("stream", False))
upstream_payload = dict(prepared.payload)
upstream_payload["stream"] = True
upstream, error_resp = start_upstream_raw_request(
upstream_payload,
session_id=normalized.session_id,
stream=True,
)
# The upstream request could not be started: drop session reuse state and
# return the prepared error response as-is.
if error_resp is not None:
clear_responses_reuse_state(normalized.session_id)
if verbose:
try:
body = error_resp.get_data(as_text=True)
if body:
try:
parsed = json.loads(body)
except Exception:
parsed = body
_log_json("OUT POST /v1/responses", parsed)
except Exception:
pass
return error_resp
record_rate_limits_from_response(upstream)
# Upstream answered with an HTTP error: mirror its status and (parsed) body.
if upstream.status_code >= 400:
try:
err_body = json.loads(upstream.content.decode("utf-8", errors="ignore")) if upstream.content else {"error": {"message": upstream.text}}
except Exception:
err_body = {"error": {"message": upstream.text or "Upstream error"}}
finally:
upstream.close()
clear_responses_reuse_state(normalized.session_id)
if verbose:
_log_json("OUT POST /v1/responses", err_body)
resp = make_response(jsonify(err_body), upstream.status_code)
for k, v in build_cors_headers().items():
resp.headers.setdefault(k, v)
return resp
# Streaming client: relay upstream bytes unchanged while feeding each parsed
# event to the session tracker.
if stream_req:
if verbose:
print("OUT POST /v1/responses (streaming response)")
stream_iter = _wrap_stream_logging(
"STREAM OUT /v1/responses",
stream_upstream_bytes(
upstream,
on_event=lambda evt: note_responses_stream_event(normalized.session_id, evt),
),
verbose,
)
resp = Response(
stream_iter,
status=upstream.status_code,
mimetype="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
for k, v in build_cors_headers().items():
resp.headers.setdefault(k, v)
return resp
# Non-streaming client: upstream may still have replied with plain JSON
# despite stream=True; handle that before the SSE aggregation below.
content_type = upstream.headers.get("Content-Type", "")
if "application/json" in content_type.lower():
try:
body = upstream.json()
except Exception:
body = None
finally:
upstream.close()
if isinstance(body, dict):
note_responses_final_response(normalized.session_id, body)
if verbose:
_log_json("OUT POST /v1/responses", body)
resp = make_response(jsonify(body), upstream.status_code)
for k, v in build_cors_headers().items():
resp.headers.setdefault(k, v)
return resp
# Collapse the SSE stream into a single response object or an error object.
response_obj, error_obj = aggregate_response_from_sse(
upstream,
on_event=lambda evt: note_responses_stream_event(normalized.session_id, evt),
)
# A response.failed event from upstream is reported to the client as 502.
if error_obj is not None:
clear_responses_reuse_state(normalized.session_id)
if verbose:
_log_json("OUT POST /v1/responses", error_obj)
resp = make_response(jsonify(error_obj), 502)
for k, v in build_cors_headers().items():
resp.headers.setdefault(k, v)
return resp
# The stream ended without any event carrying a response object.
if response_obj is None:
clear_responses_reuse_state(normalized.session_id)
err = {"error": {"message": "Upstream response stream did not contain a completed response object"}}
if verbose:
_log_json("OUT POST /v1/responses", err)
resp = make_response(jsonify(err), 502)
for k, v in build_cors_headers().items():
resp.headers.setdefault(k, v)
return resp
if verbose:
_log_json("OUT POST /v1/responses", response_obj)
resp = make_response(jsonify(response_obj), upstream.status_code)
for k, v in build_cors_headers().items():
resp.headers.setdefault(k, v)
return resp
@openai_bp.route("/v1/models", methods=["GET"])
def list_models() -> Response:
expose_variants = bool(current_app.config.get("EXPOSE_REASONING_MODELS"))

View File

@@ -1,16 +1,37 @@
from __future__ import annotations
import copy
import hashlib
import json
import threading
import uuid
from typing import Any, Dict, List, Tuple
from dataclasses import dataclass, field
from typing import Any, Dict, List
_LOCK = threading.Lock()
_FINGERPRINT_TO_UUID: Dict[str, str] = {}
_ORDER: List[str] = []
_MAX_ENTRIES = 10000
_RESPONSES_SESSION_STATE: Dict[str, "_ResponsesSessionState"] = {}
_RESPONSES_ORDER: List[str] = []
@dataclass(frozen=True)
class PreparedResponsesRequest:
    """Immutable result of prepare_responses_request_for_session()."""

    # Request body to send upstream (possibly rewritten to a delta + previous_response_id).
    payload: Dict[str, Any]
    # Session the request was prepared for.
    session_id: str
@dataclass
class _ResponsesSessionState:
    """Mutable per-session bookkeeping for Responses request reuse."""

    # Last committed exchange: the request sent, the response id, and its output items.
    last_request_payload: Dict[str, Any] | None = None
    last_response_id: str | None = None
    last_response_items: List[Dict[str, Any]] = field(default_factory=list)

    # Exchange currently in flight; committed or discarded when it finishes.
    inflight_request_payload: Dict[str, Any] | None = None
    inflight_track_result: bool = False
    inflight_response_id: str | None = None
    inflight_response_items: List[Dict[str, Any]] = field(default_factory=list)
def _canonicalize_first_user_message(input_items: List[Dict[str, Any]]) -> Dict[str, Any] | None:
@@ -70,6 +91,61 @@ def _remember(fp: str, sid: str) -> None:
_FINGERPRINT_TO_UUID.pop(oldest, None)
def _remember_responses_session(session_id: str) -> _ResponsesSessionState:
    """Return the state for *session_id*, creating and registering it if new.

    New sessions are appended to the insertion-order list. Once that list
    exceeds ``_MAX_ENTRIES``, the oldest session is evicted. This function
    takes no lock itself.
    """
    existing = _RESPONSES_SESSION_STATE.get(session_id)
    if existing is not None:
        return existing

    created = _ResponsesSessionState()
    _RESPONSES_SESSION_STATE[session_id] = created
    _RESPONSES_ORDER.append(session_id)
    if len(_RESPONSES_ORDER) > _MAX_ENTRIES:
        evicted = _RESPONSES_ORDER.pop(0)
        _RESPONSES_SESSION_STATE.pop(evicted, None)
    return created
def _request_without_input(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return a deep copy of *payload* with the per-turn parts neutralized.

    ``input`` is replaced by an empty list and ``previous_response_id`` is
    removed, so two requests can be compared on everything else.
    """
    stripped = copy.deepcopy(payload)
    stripped.pop("previous_response_id", None)
    stripped["input"] = []
    return stripped
def _input_list(payload: Dict[str, Any]) -> List[Dict[str, Any]] | None:
    """Return a deep-copied list of the dict items in ``payload["input"]``.

    Returns None when ``input`` is missing or is not a list.
    """
    candidate = payload.get("input")
    if isinstance(candidate, list):
        return [entry for entry in copy.deepcopy(candidate) if isinstance(entry, dict)]
    return None
def _conversation_output_items(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return deep copies of the output items that can be replayed as input.

    Non-dict entries and items whose ``type`` is ``"reasoning"`` are dropped.
    """
    return [
        copy.deepcopy(entry)
        for entry in items
        if isinstance(entry, dict) and entry.get("type") != "reasoning"
    ]
def _clear_reuse_state(state: _ResponsesSessionState) -> None:
    """Forget both the last committed exchange and any in-flight one."""
    state.last_request_payload = state.inflight_request_payload = None
    state.last_response_id = state.inflight_response_id = None
    state.last_response_items = []
    state.inflight_response_items = []
    state.inflight_track_result = False
def _clear_inflight(state: _ResponsesSessionState) -> None:
    """Reset only the in-flight fields; the last committed exchange is kept."""
    state.inflight_request_payload = state.inflight_response_id = None
    state.inflight_response_items = []
    state.inflight_track_result = False
def ensure_session_id(
instructions: str | None,
input_items: List[Dict[str, Any]],
@@ -87,3 +163,150 @@ def ensure_session_id(
_remember(fp, sid)
return sid
def prepare_responses_request_for_session(
    session_id: str,
    payload: Dict[str, Any],
    *,
    allow_previous_response_id: bool = True,
) -> PreparedResponsesRequest:
    """Prepare an outbound Responses request and record it as in flight.

    If the caller supplied a non-blank ``previous_response_id``, reuse state
    for the session is cleared and the payload is passed through untouched.
    Otherwise, when *allow_previous_response_id* is true, the previous
    exchange matches this request on everything except ``input``, and the new
    input starts with the previous input plus previous output items, the
    outbound payload is reduced to the new suffix and given
    ``previous_response_id``. The full request is then stored as in flight.
    """
    full_payload = copy.deepcopy(payload)
    outbound_payload = copy.deepcopy(payload)
    caller_previous_id = full_payload.get("previous_response_id")
    caller_chained = isinstance(caller_previous_id, str) and bool(caller_previous_id.strip())

    with _LOCK:
        state = _remember_responses_session(session_id)
        if caller_chained:
            # The client manages chaining itself; our bookkeeping would be stale.
            _clear_reuse_state(state)
        else:
            request_input = _input_list(full_payload)
            if (
                allow_previous_response_id
                and state.last_request_payload is not None
                and state.last_response_id
                and request_input is not None
                and _request_without_input(state.last_request_payload) == _request_without_input(full_payload)
            ):
                # Expected prefix: the prior input followed by the prior output items.
                prior_input = _input_list(state.last_request_payload)
                expected_prefix: List[Dict[str, Any]] = list(prior_input) if prior_input is not None else []
                expected_prefix.extend(copy.deepcopy(state.last_response_items))
                prefix_len = len(expected_prefix)
                if prefix_len <= len(request_input) and request_input[:prefix_len] == expected_prefix:
                    outbound_payload["input"] = copy.deepcopy(request_input[prefix_len:])
                    outbound_payload["previous_response_id"] = state.last_response_id

            state.inflight_request_payload = full_payload
            state.inflight_track_result = True
            state.inflight_response_id = None
            state.inflight_response_items = []

    return PreparedResponsesRequest(
        payload=outbound_payload,
        session_id=session_id,
    )
def note_responses_stream_event(session_id: str, event: Dict[str, Any]) -> None:
    """Update session bookkeeping from one upstream stream event.

    ``response.created`` records the in-flight response id.
    ``response.output_item.done`` collects the finished item.
    ``response.completed`` commits the in-flight exchange as the last one, or
    resets the last exchange if nothing trackable is in flight, then clears
    the in-flight fields. ``response.failed`` and ``error`` clear everything.
    Unknown sessions, blank session ids and non-dict events are ignored.
    """
    if not (isinstance(session_id, str) and session_id.strip()):
        return
    if not isinstance(event, dict):
        return

    with _LOCK:
        state = _RESPONSES_SESSION_STATE.get(session_id)
        if state is None:
            return

        event_type = event.get("type")
        if event_type == "response.created":
            created = event.get("response")
            if isinstance(created, dict) and isinstance(created.get("id"), str):
                state.inflight_response_id = created.get("id")
        elif event_type == "response.output_item.done":
            finished_item = event.get("item")
            if isinstance(finished_item, dict):
                state.inflight_response_items.append(copy.deepcopy(finished_item))
        elif event_type == "response.completed":
            final = event.get("response")
            final_id = None
            # Default to the items collected from output_item.done events.
            final_items: List[Dict[str, Any]] = copy.deepcopy(state.inflight_response_items)
            if isinstance(final, dict):
                if isinstance(final.get("id"), str):
                    final_id = final.get("id")
                output = final.get("output")
                # A non-empty output list on the final response takes precedence.
                if isinstance(output, list) and output:
                    final_items = [copy.deepcopy(entry) for entry in output if isinstance(entry, dict)]
            final_id = final_id or state.inflight_response_id

            if state.inflight_track_result and state.inflight_request_payload is not None and final_id:
                state.last_request_payload = copy.deepcopy(state.inflight_request_payload)
                state.last_response_id = final_id
                state.last_response_items = _conversation_output_items(final_items)
            else:
                state.last_request_payload = None
                state.last_response_id = None
                state.last_response_items = []
            _clear_inflight(state)
        elif event_type in ("response.failed", "error"):
            _clear_reuse_state(state)
def note_responses_final_response(session_id: str, response_obj: Dict[str, Any]) -> None:
    """Commit a complete (non-streamed) response object to session bookkeeping.

    When a trackable request is in flight and the response has a string id,
    the exchange becomes the session's last exchange. Otherwise the last
    exchange is reset. The in-flight fields are cleared either way. Unknown
    sessions, blank session ids and non-dict responses are ignored.
    """
    if not (isinstance(session_id, str) and session_id.strip()):
        return
    if not isinstance(response_obj, dict):
        return

    with _LOCK:
        state = _RESPONSES_SESSION_STATE.get(session_id)
        if state is None:
            return

        raw_id = response_obj.get("id")
        final_id = raw_id if isinstance(raw_id, str) else None
        raw_output = response_obj.get("output")
        final_items = (
            [copy.deepcopy(entry) for entry in raw_output if isinstance(entry, dict)]
            if isinstance(raw_output, list)
            else []
        )

        if state.inflight_track_result and state.inflight_request_payload is not None and final_id:
            state.last_request_payload = copy.deepcopy(state.inflight_request_payload)
            state.last_response_id = final_id
            state.last_response_items = _conversation_output_items(final_items)
        else:
            state.last_request_payload = None
            state.last_response_id = None
            state.last_response_items = []
        _clear_inflight(state)
def clear_responses_reuse_state(session_id: str) -> None:
    """Drop all reuse bookkeeping for *session_id*, if that session is known.

    Blank or non-string session ids are ignored.
    """
    if not (isinstance(session_id, str) and session_id.strip()):
        return
    with _LOCK:
        known = _RESPONSES_SESSION_STATE.get(session_id)
        if known is not None:
            _clear_reuse_state(known)
def reset_session_state() -> None:
    """Empty every module-level session registry while holding the lock."""
    with _LOCK:
        for registry in (
            _FINGERPRINT_TO_UUID,
            _ORDER,
            _RESPONSES_SESSION_STATE,
            _RESPONSES_ORDER,
        ):
            registry.clear()

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import json
import time
from typing import Any, Dict, List, Tuple
from urllib.parse import urlparse, urlunparse
import requests
from flask import Response, current_app, jsonify, make_response
@@ -33,6 +34,7 @@ def start_upstream_request(
tool_choice: Any | None = None,
parallel_tool_calls: bool = False,
reasoning_param: Dict[str, Any] | None = None,
service_tier: str | None = None,
):
access_token, account_id = get_effective_chatgpt_auth()
if not access_token or not account_id:
@@ -81,6 +83,62 @@ def start_upstream_request(
if reasoning_param is not None:
responses_payload["reasoning"] = reasoning_param
if isinstance(service_tier, str) and service_tier.strip():
responses_payload["service_tier"] = service_tier.strip().lower()
return start_upstream_raw_request(
responses_payload,
session_id=session_id,
stream=True,
)
def build_upstream_headers(
    access_token: str,
    account_id: str,
    session_id: str,
    *,
    accept: str = "text/event-stream",
) -> Dict[str, str]:
    """Build the HTTP headers for a request to the upstream Responses endpoint.

    *accept* sets the ``Accept`` header and defaults to the SSE media type.
    """
    headers: Dict[str, str] = {}
    headers["Authorization"] = f"Bearer {access_token}"
    headers["Content-Type"] = "application/json"
    headers["Accept"] = accept
    headers["chatgpt-account-id"] = account_id
    headers["OpenAI-Beta"] = "responses=experimental"
    headers["session_id"] = session_id
    return headers
def start_upstream_raw_request(
responses_payload: Dict[str, Any],
*,
session_id: str | None = None,
stream: bool = True,
):
access_token, account_id = get_effective_chatgpt_auth()
if not access_token or not account_id:
resp = make_response(
jsonify(
{
"error": {
"message": "Missing ChatGPT credentials. Run 'python3 chatmock.py login' first.",
}
}
),
401,
)
for k, v in build_cors_headers().items():
resp.headers.setdefault(k, v)
return None, resp
effective_session_id = session_id
if not isinstance(effective_session_id, str) or not effective_session_id.strip():
payload_prompt_cache_key = responses_payload.get("prompt_cache_key")
if isinstance(payload_prompt_cache_key, str) and payload_prompt_cache_key.strip():
effective_session_id = payload_prompt_cache_key.strip()
if not isinstance(effective_session_id, str) or not effective_session_id.strip():
effective_session_id = str(int(time.time() * 1000))
verbose = False
try:
@@ -90,21 +148,19 @@ def start_upstream_request(
if verbose:
_log_json("OUTBOUND >> ChatGPT Responses API payload", responses_payload)
headers = {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
"Accept": "text/event-stream",
"chatgpt-account-id": account_id,
"OpenAI-Beta": "responses=experimental",
"session_id": session_id,
}
headers = build_upstream_headers(
access_token,
account_id,
effective_session_id,
accept=("text/event-stream" if stream else "application/json"),
)
try:
upstream = requests.post(
CHATGPT_RESPONSES_URL,
headers=headers,
json=responses_payload,
stream=True,
stream=stream,
timeout=600,
)
except requests.RequestException as e:
@@ -113,3 +169,13 @@ def start_upstream_request(
resp.headers.setdefault(k, v)
return None, resp
return upstream, None
def build_upstream_websocket_url() -> str:
    """Return ``CHATGPT_RESPONSES_URL`` rewritten to use a websocket scheme.

    ``https`` is mapped to ``wss`` and ``http`` to ``ws``; a URL with any
    other scheme is returned re-serialized but otherwise unchanged.
    """
    parts = urlparse(CHATGPT_RESPONSES_URL)
    websocket_scheme = {"https": "wss", "http": "ws"}.get(parts.scheme.lower())
    if websocket_scheme is None:
        return urlunparse(parts)
    return urlunparse(parts._replace(scheme=websocket_scheme))

View File

@@ -1,4 +1,4 @@
from __future__ import annotations
__version__ = "1.36"
__version__ = "1.37"

View File

@@ -0,0 +1,225 @@
from __future__ import annotations
import json
import os
import ssl
from typing import Any, Dict
import certifi
from flask import current_app, request
from flask_sock import Sock
from websockets.sync.client import connect as websocket_connect
from websockets.exceptions import ConnectionClosed
from .responses_api import (
ResponsesRequestError,
extract_client_session_id,
normalize_responses_payload,
)
from .session import (
clear_responses_reuse_state,
note_responses_stream_event,
prepare_responses_request_for_session,
)
from .upstream import build_upstream_headers, build_upstream_websocket_url
from .utils import get_effective_chatgpt_auth
def _log_json(prefix: str, payload: Any) -> None:
    """Print *prefix* followed by *payload*, pretty-printed as JSON when possible.

    If JSON rendering (or the print itself) fails, the payload's plain string
    form is printed instead; if that fails too, the error is swallowed so
    that logging never interrupts the caller.
    """
    # Renderers are tried in order; the first one that prints cleanly wins.
    renderers = (
        lambda: json.dumps(payload, indent=2, ensure_ascii=False),
        lambda: payload,
    )
    for render in renderers:
        try:
            print(f"{prefix}\n{render()}")
            return
        except Exception:
            continue
def _error_event(message: str, *, status_code: int = 400, code: str | None = None) -> Dict[str, Any]:
error: Dict[str, Any] = {"message": message}
if code:
error["code"] = code
return {"type": "error", "status_code": status_code, "error": error}
def _is_terminal_event(event: Any) -> bool:
    """Tell whether *event* is a dict whose ``type`` ends a response stream.

    The terminal types are ``response.completed``, ``response.failed`` and
    ``error``; anything that is not a dict is never terminal.
    """
    terminal_kinds = ("response.completed", "response.failed", "error")
    return isinstance(event, dict) and event.get("type") in terminal_kinds
def _build_websocket_ssl_context() -> ssl.SSLContext:
    """Create the TLS context used for the upstream websocket connection.

    The CA bundle path is taken from the first non-empty source, in order:
    the ``CODEX_CA_CERTIFICATE`` environment variable, the ``SSL_CERT_FILE``
    environment variable, then the bundle reported by ``certifi.where()``.
    """
    for env_name in ("CODEX_CA_CERTIFICATE", "SSL_CERT_FILE"):
        override = os.getenv(env_name)
        if override:
            return ssl.create_default_context(cafile=override)
    return ssl.create_default_context(cafile=certifi.where())
def connect_upstream_websocket(url: str, headers: Dict[str, str]):
    """Open a synchronous websocket client connection to *url*.

    Args:
        url: Websocket URL to connect to.
        headers: Extra handshake headers, passed as ``additional_headers``.

    Returns:
        Whatever ``websocket_connect`` returns for the opened connection.
    """
    connect_options = {
        "additional_headers": headers,
        "open_timeout": 15,
        "ssl": _build_websocket_ssl_context(),
    }
    return websocket_connect(url, **connect_options)
def register_websocket_routes(sock: Sock) -> None:
    """Register the ``/v1/responses`` websocket proxy route on *sock*.

    Args:
        sock: The ``flask_sock.Sock`` instance the route is attached to.
    """

    @sock.route("/v1/responses")
    def responses_websocket(ws) -> None:
        """Relay JSON frames between a client websocket and the upstream one.

        Each client frame must be a JSON object. ``response.create`` frames
        are normalized and prepared for their session before being forwarded;
        any other frame is forwarded verbatim, but only once an upstream
        connection exists. After forwarding a frame, upstream messages are
        streamed back to the client until a terminal event is seen, then the
        handler waits for the next client frame.
        """
        verbose = bool(current_app.config.get("VERBOSE"))
        upstream_ws = None
        # Session id the current upstream connection was opened with.
        upstream_session_id: str | None = None
        # Session id of the most recent response.create request.
        active_session_id: str | None = None

        def _send_error(message: str, *, status_code: int = 400, code: str | None = None) -> None:
            """Best-effort delivery of an error event to the client."""
            evt = _error_event(message, status_code=status_code, code=code)
            if verbose:
                _log_json("STREAM OUT WS /v1/responses (error)", evt)
            try:
                ws.send(json.dumps(evt))
            except Exception:
                # The client may already be gone; nothing more can be done.
                pass

        def _report_upstream_closed() -> None:
            """Drop reuse state for the active session and tell the client."""
            if active_session_id:
                clear_responses_reuse_state(active_session_id)
            _send_error("Upstream websocket closed unexpectedly.", status_code=502)

        try:
            while True:
                incoming = ws.receive()
                if incoming is None:
                    break
                if isinstance(incoming, bytes):
                    incoming_text = incoming.decode("utf-8", errors="ignore")
                else:
                    incoming_text = str(incoming)
                if verbose:
                    print("IN WS /v1/responses\n" + incoming_text)
                try:
                    payload = json.loads(incoming_text)
                except Exception:
                    _send_error("Websocket frames must be valid JSON objects.", status_code=400)
                    break
                if not isinstance(payload, dict):
                    _send_error("Websocket frames must be JSON objects.", status_code=400)
                    break

                client_session_id = extract_client_session_id(request.headers)
                # By default the frame is passed through untouched on the
                # connection that is already open.
                outbound_text = incoming_text
                session_id = upstream_session_id
                if payload.get("type") == "response.create":
                    try:
                        normalized = normalize_responses_payload(
                            payload,
                            config=current_app.config,
                            client_session_id=client_session_id,
                        )
                    except ResponsesRequestError as exc:
                        # A rejected request keeps the client connection open.
                        _send_error(str(exc), status_code=exc.status_code, code=exc.code)
                        continue
                    if normalized.service_tier_resolution.warning_message and verbose:
                        print(f"[FastMode] {normalized.service_tier_resolution.warning_message}")
                    prepared = prepare_responses_request_for_session(
                        normalized.session_id,
                        normalized.payload,
                        allow_previous_response_id=True,
                    )
                    outbound_text = json.dumps(prepared.payload)
                    session_id = normalized.session_id
                    active_session_id = normalized.session_id
                    if verbose:
                        _log_json("OUTBOUND >> ChatGPT Responses WS payload", prepared.payload)
                elif upstream_ws is None:
                    _send_error(
                        "The first websocket message must be a response.create request.",
                        status_code=400,
                    )
                    break

                # (Re)connect when there is no upstream yet or the request
                # belongs to a different session than the open connection.
                if upstream_ws is None or (session_id and session_id != upstream_session_id):
                    access_token, account_id = get_effective_chatgpt_auth()
                    if not access_token or not account_id:
                        if session_id:
                            clear_responses_reuse_state(session_id)
                        _send_error(
                            "Missing ChatGPT credentials. Run 'python3 chatmock.py login' first.",
                            status_code=401,
                        )
                        break
                    if upstream_ws is not None:
                        try:
                            upstream_ws.close()
                        except Exception:
                            pass
                    effective_session_id = session_id or client_session_id or ""
                    try:
                        upstream_ws = connect_upstream_websocket(
                            build_upstream_websocket_url(),
                            build_upstream_headers(
                                access_token,
                                account_id,
                                effective_session_id,
                                accept="application/json",
                            ),
                        )
                    except Exception as exc:
                        if session_id:
                            clear_responses_reuse_state(session_id)
                        _send_error(
                            f"Upstream websocket connection failed: {exc}",
                            status_code=502,
                        )
                        break
                    upstream_session_id = effective_session_id

                # A reused upstream connection may have been closed since the
                # previous turn; handle that the same way as a failed recv()
                # instead of letting the exception escape the handler.
                try:
                    upstream_ws.send(outbound_text)
                except ConnectionClosed:
                    _report_upstream_closed()
                    return

                # Stream upstream messages back until this response finishes.
                while True:
                    try:
                        upstream_message = upstream_ws.recv()
                    except ConnectionClosed:
                        _report_upstream_closed()
                        return
                    if upstream_message is None:
                        _report_upstream_closed()
                        return
                    if verbose:
                        try:
                            print("STREAM OUT WS /v1/responses\n" + str(upstream_message))
                        except Exception:
                            pass
                    ws.send(upstream_message)
                    try:
                        parsed = json.loads(upstream_message)
                    except Exception:
                        parsed = None
                    if isinstance(parsed, dict) and active_session_id:
                        note_responses_stream_event(active_session_id, parsed)
                    if _is_terminal_event(parsed):
                        if isinstance(parsed, dict) and parsed.get("type") in ("response.failed", "error"):
                            # After a failure, discard the upstream connection
                            # so the next request opens a fresh one.
                            if upstream_ws is not None:
                                try:
                                    upstream_ws.close()
                                except Exception:
                                    pass
                            upstream_ws = None
                            upstream_session_id = None
                        break
        finally:
            # Always release the upstream connection when the handler exits.
            if upstream_ws is not None:
                try:
                    upstream_ws.close()
                except Exception:
                    pass