feat: add responses api, websocket support, and fast mode
This commit is contained in:
225
chatmock/websocket_routes.py
Normal file
225
chatmock/websocket_routes.py
Normal file
@@ -0,0 +1,225 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
from typing import Any, Dict
|
||||
|
||||
import certifi
|
||||
from flask import current_app, request
|
||||
from flask_sock import Sock
|
||||
from websockets.sync.client import connect as websocket_connect
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from .responses_api import (
|
||||
ResponsesRequestError,
|
||||
extract_client_session_id,
|
||||
normalize_responses_payload,
|
||||
)
|
||||
from .session import (
|
||||
clear_responses_reuse_state,
|
||||
note_responses_stream_event,
|
||||
prepare_responses_request_for_session,
|
||||
)
|
||||
from .upstream import build_upstream_headers, build_upstream_websocket_url
|
||||
from .utils import get_effective_chatgpt_auth
|
||||
|
||||
|
||||
def _log_json(prefix: str, payload: Any) -> None:
|
||||
try:
|
||||
print(f"{prefix}\n{json.dumps(payload, indent=2, ensure_ascii=False)}")
|
||||
except Exception:
|
||||
try:
|
||||
print(f"{prefix}\n{payload}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _error_event(message: str, *, status_code: int = 400, code: str | None = None) -> Dict[str, Any]:
|
||||
error: Dict[str, Any] = {"message": message}
|
||||
if code:
|
||||
error["code"] = code
|
||||
return {"type": "error", "status_code": status_code, "error": error}
|
||||
|
||||
|
||||
def _is_terminal_event(event: Any) -> bool:
|
||||
if not isinstance(event, dict):
|
||||
return False
|
||||
kind = event.get("type")
|
||||
return kind in ("response.completed", "response.failed", "error")
|
||||
|
||||
|
||||
def _build_websocket_ssl_context() -> ssl.SSLContext:
|
||||
cafile = (
|
||||
os.getenv("CODEX_CA_CERTIFICATE")
|
||||
or os.getenv("SSL_CERT_FILE")
|
||||
or certifi.where()
|
||||
)
|
||||
return ssl.create_default_context(cafile=cafile)
|
||||
|
||||
|
||||
def connect_upstream_websocket(url: str, headers: Dict[str, str]):
|
||||
return websocket_connect(
|
||||
url,
|
||||
additional_headers=headers,
|
||||
open_timeout=15,
|
||||
ssl=_build_websocket_ssl_context(),
|
||||
)
|
||||
|
||||
|
||||
def register_websocket_routes(sock: Sock) -> None:
|
||||
@sock.route("/v1/responses")
|
||||
def responses_websocket(ws) -> None:
|
||||
verbose = bool(current_app.config.get("VERBOSE"))
|
||||
upstream_ws = None
|
||||
upstream_session_id: str | None = None
|
||||
active_session_id: str | None = None
|
||||
|
||||
def _send_error(message: str, *, status_code: int = 400, code: str | None = None) -> None:
|
||||
evt = _error_event(message, status_code=status_code, code=code)
|
||||
if verbose:
|
||||
_log_json("STREAM OUT WS /v1/responses (error)", evt)
|
||||
try:
|
||||
ws.send(json.dumps(evt))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
while True:
|
||||
incoming = ws.receive()
|
||||
if incoming is None:
|
||||
break
|
||||
|
||||
if isinstance(incoming, bytes):
|
||||
incoming_text = incoming.decode("utf-8", errors="ignore")
|
||||
else:
|
||||
incoming_text = str(incoming)
|
||||
if verbose:
|
||||
print("IN WS /v1/responses\n" + incoming_text)
|
||||
|
||||
try:
|
||||
payload = json.loads(incoming_text)
|
||||
except Exception:
|
||||
_send_error("Websocket frames must be valid JSON objects.", status_code=400)
|
||||
break
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
_send_error("Websocket frames must be JSON objects.", status_code=400)
|
||||
break
|
||||
|
||||
client_session_id = extract_client_session_id(request.headers)
|
||||
outbound_text = incoming_text
|
||||
session_id = upstream_session_id
|
||||
|
||||
if payload.get("type") == "response.create":
|
||||
try:
|
||||
normalized = normalize_responses_payload(
|
||||
payload,
|
||||
config=current_app.config,
|
||||
client_session_id=client_session_id,
|
||||
)
|
||||
except ResponsesRequestError as exc:
|
||||
_send_error(str(exc), status_code=exc.status_code, code=exc.code)
|
||||
continue
|
||||
|
||||
if normalized.service_tier_resolution.warning_message and verbose:
|
||||
print(f"[FastMode] {normalized.service_tier_resolution.warning_message}")
|
||||
prepared = prepare_responses_request_for_session(
|
||||
normalized.session_id,
|
||||
normalized.payload,
|
||||
allow_previous_response_id=True,
|
||||
)
|
||||
outbound_text = json.dumps(prepared.payload)
|
||||
session_id = normalized.session_id
|
||||
active_session_id = normalized.session_id
|
||||
if verbose:
|
||||
_log_json("OUTBOUND >> ChatGPT Responses WS payload", prepared.payload)
|
||||
elif upstream_ws is None:
|
||||
_send_error(
|
||||
"The first websocket message must be a response.create request.",
|
||||
status_code=400,
|
||||
)
|
||||
break
|
||||
|
||||
if upstream_ws is None or (session_id and session_id != upstream_session_id):
|
||||
access_token, account_id = get_effective_chatgpt_auth()
|
||||
if not access_token or not account_id:
|
||||
if session_id:
|
||||
clear_responses_reuse_state(session_id)
|
||||
_send_error(
|
||||
"Missing ChatGPT credentials. Run 'python3 chatmock.py login' first.",
|
||||
status_code=401,
|
||||
)
|
||||
break
|
||||
|
||||
if upstream_ws is not None:
|
||||
try:
|
||||
upstream_ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
effective_session_id = session_id or client_session_id or ""
|
||||
try:
|
||||
upstream_ws = connect_upstream_websocket(
|
||||
build_upstream_websocket_url(),
|
||||
build_upstream_headers(
|
||||
access_token,
|
||||
account_id,
|
||||
effective_session_id,
|
||||
accept="application/json",
|
||||
),
|
||||
)
|
||||
except Exception as exc:
|
||||
if session_id:
|
||||
clear_responses_reuse_state(session_id)
|
||||
_send_error(
|
||||
f"Upstream websocket connection failed: {exc}",
|
||||
status_code=502,
|
||||
)
|
||||
break
|
||||
upstream_session_id = effective_session_id
|
||||
|
||||
upstream_ws.send(outbound_text)
|
||||
|
||||
while True:
|
||||
try:
|
||||
upstream_message = upstream_ws.recv()
|
||||
except ConnectionClosed:
|
||||
if active_session_id:
|
||||
clear_responses_reuse_state(active_session_id)
|
||||
_send_error("Upstream websocket closed unexpectedly.", status_code=502)
|
||||
return
|
||||
if upstream_message is None:
|
||||
if active_session_id:
|
||||
clear_responses_reuse_state(active_session_id)
|
||||
_send_error("Upstream websocket closed unexpectedly.", status_code=502)
|
||||
return
|
||||
if verbose:
|
||||
try:
|
||||
print("STREAM OUT WS /v1/responses\n" + str(upstream_message))
|
||||
except Exception:
|
||||
pass
|
||||
ws.send(upstream_message)
|
||||
|
||||
try:
|
||||
parsed = json.loads(upstream_message)
|
||||
except Exception:
|
||||
parsed = None
|
||||
if isinstance(parsed, dict) and active_session_id:
|
||||
note_responses_stream_event(active_session_id, parsed)
|
||||
if _is_terminal_event(parsed):
|
||||
if isinstance(parsed, dict) and parsed.get("type") in ("response.failed", "error"):
|
||||
if upstream_ws is not None:
|
||||
try:
|
||||
upstream_ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
upstream_ws = None
|
||||
upstream_session_id = None
|
||||
break
|
||||
finally:
|
||||
if upstream_ws is not None:
|
||||
try:
|
||||
upstream_ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
Reference in New Issue
Block a user