226 lines
8.7 KiB
Python
226 lines
8.7 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import ssl
|
|
from typing import Any, Dict
|
|
|
|
import certifi
|
|
from flask import current_app, request
|
|
from flask_sock import Sock
|
|
from websockets.sync.client import connect as websocket_connect
|
|
from websockets.exceptions import ConnectionClosed
|
|
|
|
from .responses_api import (
|
|
ResponsesRequestError,
|
|
extract_client_session_id,
|
|
normalize_responses_payload,
|
|
)
|
|
from .session import (
|
|
clear_responses_reuse_state,
|
|
note_responses_stream_event,
|
|
prepare_responses_request_for_session,
|
|
)
|
|
from .upstream import build_upstream_headers, build_upstream_websocket_url
|
|
from .utils import get_effective_chatgpt_auth
|
|
|
|
|
|
def _log_json(prefix: str, payload: Any) -> None:
    """Print *prefix* and then *payload* on the following line(s), best effort.

    The payload is first rendered as indented JSON (non-ASCII kept as-is).
    If rendering or printing fails, the payload's plain string form is tried
    instead. If that also fails, the message is silently dropped. Logging
    must never break the caller.
    """

    def _pretty() -> str:
        return json.dumps(payload, indent=2, ensure_ascii=False)

    def _plain() -> str:
        return f"{payload}"

    for render in (_pretty, _plain):
        try:
            print(f"{prefix}\n{render()}")
        except Exception:
            # Fall through to the next (simpler) rendering, or give up.
            continue
        return
|
|
|
|
|
|
def _error_event(message: str, *, status_code: int = 400, code: str | None = None) -> Dict[str, Any]:
    """Build the ``{"type": "error", ...}`` event sent to websocket clients.

    Args:
        message: Human-readable description placed under ``error.message``.
        status_code: HTTP-like status carried at the top level of the event.
        code: Optional machine-readable code; included under ``error.code``
            only when truthy (``None`` and ``""`` are omitted).

    Returns:
        A JSON-serializable dict with keys ``type``, ``status_code`` and ``error``.
    """
    detail: Dict[str, Any] = {"message": message, **({"code": code} if code else {})}
    return {"type": "error", "status_code": status_code, "error": detail}
|
|
|
|
|
|
def _is_terminal_event(event: Any) -> bool:
    """Return True if *event* is the last event of an upstream response stream.

    Terminal types are ``response.completed``, ``response.failed``,
    ``response.incomplete`` and ``error``. ``response.incomplete`` is the
    final event the Responses API sends when a response stops early (for
    example when ``max_output_tokens`` is reached). Treating it as
    non-terminal would leave a caller waiting for events that never arrive.

    Args:
        event: A decoded upstream message. Anything that is not a dict
            (including ``None`` for unparseable frames) is never terminal.
    """
    if not isinstance(event, dict):
        return False
    return event.get("type") in (
        "response.completed",
        "response.failed",
        "response.incomplete",
        "error",
    )
|
|
|
|
|
|
def _build_websocket_ssl_context() -> ssl.SSLContext:
    """Create the client TLS context used for upstream websocket connections.

    The CA bundle is taken from the first non-empty value of
    ``$CODEX_CA_CERTIFICATE`` and then ``$SSL_CERT_FILE``. If neither is set,
    certifi's bundled CA file is used.
    """
    for env_name in ("CODEX_CA_CERTIFICATE", "SSL_CERT_FILE"):
        bundle_path = os.getenv(env_name)
        if bundle_path:
            return ssl.create_default_context(cafile=bundle_path)
    return ssl.create_default_context(cafile=certifi.where())
|
|
|
|
|
|
def connect_upstream_websocket(url: str, headers: Dict[str, str]):
    """Open a synchronous websocket client connection to *url*.

    Args:
        url: Upstream websocket URL.
        headers: Extra HTTP headers sent with the opening handshake.

    The connection uses ``open_timeout=15`` and a TLS context from
    ``_build_websocket_ssl_context``. Returns whatever ``websocket_connect``
    returns.
    """
    tls_context = _build_websocket_ssl_context()
    return websocket_connect(
        url,
        ssl=tls_context,
        open_timeout=15,
        additional_headers=headers,
    )
|
|
|
|
|
|
def register_websocket_routes(sock: Sock) -> None:
    """Register the ``/v1/responses`` websocket proxy route on *sock*.

    Each client connection is relayed to an upstream websocket opened with
    ``connect_upstream_websocket``. ``response.create`` frames are normalized
    and prepared for session reuse before being forwarded. Other frames are
    forwarded as received text, but only once an upstream connection exists.
    """

    @sock.route("/v1/responses")
    def responses_websocket(ws) -> None:
        """Proxy one client websocket connection to the upstream endpoint.

        Client frames are handled one at a time. After each forwarded frame,
        upstream messages are relayed back to the client until
        ``_is_terminal_event`` reports the end of that response. Only then is
        the next client frame read.
        """
        verbose = bool(current_app.config.get("VERBOSE"))
        upstream_ws = None
        # Session id the current upstream connection was opened with.
        upstream_session_id: str | None = None
        # Session id of the most recent response.create; upstream events are
        # recorded against it.
        active_session_id: str | None = None

        def _send_error(message: str, *, status_code: int = 400, code: str | None = None) -> None:
            """Send an error event to the client; delivery failures are ignored."""
            evt = _error_event(message, status_code=status_code, code=code)
            if verbose:
                _log_json("STREAM OUT WS /v1/responses (error)", evt)
            try:
                ws.send(json.dumps(evt))
            except Exception:
                # The client may already be gone; there is nobody left to tell.
                pass

        def _close_quietly(conn) -> None:
            """Best-effort close of an upstream connection."""
            try:
                conn.close()
            except Exception:
                pass

        def _report_upstream_closed() -> None:
            """Clear reuse state for the active session and notify the client."""
            if active_session_id:
                clear_responses_reuse_state(active_session_id)
            _send_error("Upstream websocket closed unexpectedly.", status_code=502)

        try:
            while True:
                incoming = ws.receive()
                if incoming is None:
                    break

                if isinstance(incoming, bytes):
                    incoming_text = incoming.decode("utf-8", errors="ignore")
                else:
                    incoming_text = str(incoming)
                if verbose:
                    print("IN WS /v1/responses\n" + incoming_text)

                try:
                    payload = json.loads(incoming_text)
                except Exception:
                    _send_error("Websocket frames must be valid JSON objects.", status_code=400)
                    break

                if not isinstance(payload, dict):
                    _send_error("Websocket frames must be JSON objects.", status_code=400)
                    break

                client_session_id = extract_client_session_id(request.headers)
                outbound_text = incoming_text
                session_id = upstream_session_id

                if payload.get("type") == "response.create":
                    try:
                        normalized = normalize_responses_payload(
                            payload,
                            config=current_app.config,
                            client_session_id=client_session_id,
                        )
                    except ResponsesRequestError as exc:
                        # Bad request: report it but keep the connection open.
                        _send_error(str(exc), status_code=exc.status_code, code=exc.code)
                        continue

                    if normalized.service_tier_resolution.warning_message and verbose:
                        print(f"[FastMode] {normalized.service_tier_resolution.warning_message}")
                    prepared = prepare_responses_request_for_session(
                        normalized.session_id,
                        normalized.payload,
                        allow_previous_response_id=True,
                    )
                    outbound_text = json.dumps(prepared.payload)
                    session_id = normalized.session_id
                    active_session_id = normalized.session_id
                    if verbose:
                        _log_json("OUTBOUND >> ChatGPT Responses WS payload", prepared.payload)
                elif upstream_ws is None:
                    _send_error(
                        "The first websocket message must be a response.create request.",
                        status_code=400,
                    )
                    break

                # (Re)connect when there is no upstream yet or the session changed.
                if upstream_ws is None or (session_id and session_id != upstream_session_id):
                    access_token, account_id = get_effective_chatgpt_auth()
                    if not access_token or not account_id:
                        if session_id:
                            clear_responses_reuse_state(session_id)
                        _send_error(
                            "Missing ChatGPT credentials. Run 'python3 chatmock.py login' first.",
                            status_code=401,
                        )
                        break

                    if upstream_ws is not None:
                        # The connection belongs to a different session; drop it.
                        _close_quietly(upstream_ws)
                        upstream_ws = None

                    effective_session_id = session_id or client_session_id or ""
                    try:
                        upstream_ws = connect_upstream_websocket(
                            build_upstream_websocket_url(),
                            build_upstream_headers(
                                access_token,
                                account_id,
                                effective_session_id,
                                accept="application/json",
                            ),
                        )
                    except Exception as exc:
                        if session_id:
                            clear_responses_reuse_state(session_id)
                        _send_error(
                            f"Upstream websocket connection failed: {exc}",
                            status_code=502,
                        )
                        break
                    upstream_session_id = effective_session_id

                try:
                    upstream_ws.send(outbound_text)
                except ConnectionClosed:
                    # A reused upstream connection may already have been closed by
                    # the server. Report it like a closed stream instead of letting
                    # the websockets exception escape without notifying the client.
                    _report_upstream_closed()
                    return

                # Relay upstream events until the current response ends.
                while True:
                    try:
                        upstream_message = upstream_ws.recv()
                    except ConnectionClosed:
                        upstream_message = None
                    if upstream_message is None:
                        _report_upstream_closed()
                        return
                    if verbose:
                        try:
                            print("STREAM OUT WS /v1/responses\n" + str(upstream_message))
                        except Exception:
                            pass
                    ws.send(upstream_message)

                    try:
                        parsed = json.loads(upstream_message)
                    except Exception:
                        parsed = None
                    if isinstance(parsed, dict) and active_session_id:
                        note_responses_stream_event(active_session_id, parsed)
                    if _is_terminal_event(parsed):
                        if isinstance(parsed, dict) and parsed.get("type") in ("response.failed", "error"):
                            # Do not reuse an upstream connection after a failure;
                            # the next request must be a fresh response.create.
                            _close_quietly(upstream_ws)
                            upstream_ws = None
                            upstream_session_id = None
                        break
        finally:
            if upstream_ws is not None:
                _close_quietly(upstream_ws)
|