hopefully patch token counting?

This commit is contained in:
Game_Time
2025-08-22 02:01:15 +05:00
parent c8c6540d23
commit cadd959778
20 changed files with 112 additions and 6 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -63,6 +63,8 @@ def chat_completions() -> Response:
content = sys_msg.get("content") if isinstance(sys_msg, dict) else "" content = sys_msg.get("content") if isinstance(sys_msg, dict) else ""
messages.insert(0, {"role": "user", "content": content}) messages.insert(0, {"role": "user", "content": content})
is_stream = bool(payload.get("stream")) is_stream = bool(payload.get("stream"))
stream_options = payload.get("stream_options") if isinstance(payload.get("stream_options"), dict) else {}
include_usage = bool(stream_options.get("include_usage", False))
tools_responses = convert_tools_chat_to_responses(payload.get("tools")) tools_responses = convert_tools_chat_to_responses(payload.get("tools"))
tool_choice = payload.get("tool_choice", "auto") tool_choice = payload.get("tool_choice", "auto")
@@ -85,6 +87,7 @@ def chat_completions() -> Response:
tool_choice=tool_choice, tool_choice=tool_choice,
parallel_tool_calls=parallel_tool_calls, parallel_tool_calls=parallel_tool_calls,
reasoning_param=reasoning_param, reasoning_param=reasoning_param,
include_usage=(not is_stream) or include_usage,
) )
if error_resp is not None: if error_resp is not None:
return error_resp return error_resp
@@ -112,6 +115,7 @@ def chat_completions() -> Response:
verbose=verbose, verbose=verbose,
vlog=print if verbose else None, vlog=print if verbose else None,
reasoning_compat=reasoning_compat, reasoning_compat=reasoning_compat,
include_usage=include_usage,
), ),
status=upstream.status_code, status=upstream.status_code,
mimetype="text/event-stream", mimetype="text/event-stream",
@@ -127,6 +131,19 @@ def chat_completions() -> Response:
response_id = "chatcmpl" response_id = "chatcmpl"
tool_calls: List[Dict[str, Any]] = [] tool_calls: List[Dict[str, Any]] = []
error_message: str | None = None error_message: str | None = None
usage_obj: Dict[str, int] | None = None
def _extract_usage(evt: Dict[str, Any]) -> Dict[str, int] | None:
try:
usage = (evt.get("response") or {}).get("usage")
if not isinstance(usage, dict):
return None
pt = int(usage.get("input_tokens") or 0)
ct = int(usage.get("output_tokens") or 0)
tt = int(usage.get("total_tokens") or (pt + ct))
return {"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": tt}
except Exception:
return None
try: try:
for raw in upstream.iter_lines(decode_unicode=False): for raw in upstream.iter_lines(decode_unicode=False):
if not raw: if not raw:
@@ -144,6 +161,9 @@ def chat_completions() -> Response:
except Exception: except Exception:
continue continue
kind = evt.get("type") kind = evt.get("type")
mu = _extract_usage(evt)
if mu:
usage_obj = mu
if isinstance(evt.get("response"), dict) and isinstance(evt["response"].get("id"), str): if isinstance(evt.get("response"), dict) and isinstance(evt["response"].get("id"), str):
response_id = evt["response"].get("id") or response_id response_id = evt["response"].get("id") or response_id
if kind == "response.output_text.delta": if kind == "response.output_text.delta":
@@ -183,7 +203,6 @@ def chat_completions() -> Response:
if tool_calls: if tool_calls:
message["tool_calls"] = tool_calls message["tool_calls"] = tool_calls
message = apply_reasoning_to_message(message, reasoning_summary_text, reasoning_full_text, reasoning_compat) message = apply_reasoning_to_message(message, reasoning_summary_text, reasoning_full_text, reasoning_compat)
completion = { completion = {
"id": response_id or "chatcmpl", "id": response_id or "chatcmpl",
"object": "chat.completion", "object": "chat.completion",
@@ -196,6 +215,7 @@ def chat_completions() -> Response:
"finish_reason": "stop", "finish_reason": "stop",
} }
], ],
**({"usage": usage_obj} if usage_obj else {}),
} }
resp = make_response(jsonify(completion), upstream.status_code) resp = make_response(jsonify(completion), upstream.status_code)
for k, v in build_cors_headers().items(): for k, v in build_cors_headers().items():
@@ -223,6 +243,8 @@ def completions() -> Response:
if not isinstance(prompt, str): if not isinstance(prompt, str):
prompt = payload.get("suffix") or "" prompt = payload.get("suffix") or ""
stream_req = bool(payload.get("stream", False)) stream_req = bool(payload.get("stream", False))
stream_options = payload.get("stream_options") if isinstance(payload.get("stream_options"), dict) else {}
include_usage = bool(stream_options.get("include_usage", False))
messages = [{"role": "user", "content": prompt or ""}] messages = [{"role": "user", "content": prompt or ""}]
input_items = convert_chat_messages_to_responses_input(messages) input_items = convert_chat_messages_to_responses_input(messages)
@@ -234,6 +256,7 @@ def completions() -> Response:
input_items, input_items,
instructions=BASE_INSTRUCTIONS, instructions=BASE_INSTRUCTIONS,
reasoning_param=reasoning_param, reasoning_param=reasoning_param,
include_usage=(not stream_req) or include_usage,
) )
if error_resp is not None: if error_resp is not None:
return error_resp return error_resp
@@ -251,7 +274,14 @@ def completions() -> Response:
if stream_req: if stream_req:
resp = Response( resp = Response(
sse_translate_text(upstream, model, created, verbose=verbose, vlog=(print if verbose else None)), sse_translate_text(
upstream,
model,
created,
verbose=verbose,
vlog=(print if verbose else None),
include_usage=include_usage,
),
status=upstream.status_code, status=upstream.status_code,
mimetype="text/event-stream", mimetype="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
@@ -262,6 +292,18 @@ def completions() -> Response:
full_text = "" full_text = ""
response_id = "cmpl" response_id = "cmpl"
usage_obj: Dict[str, int] | None = None
def _extract_usage(evt: Dict[str, Any]) -> Dict[str, int] | None:
try:
usage = (evt.get("response") or {}).get("usage")
if not isinstance(usage, dict):
return None
pt = int(usage.get("input_tokens") or 0)
ct = int(usage.get("output_tokens") or 0)
tt = int(usage.get("total_tokens") or (pt + ct))
return {"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": tt}
except Exception:
return None
try: try:
for raw_line in upstream.iter_lines(decode_unicode=False): for raw_line in upstream.iter_lines(decode_unicode=False):
if not raw_line: if not raw_line:
@@ -280,6 +322,9 @@ def completions() -> Response:
continue continue
if isinstance(evt.get("response"), dict) and isinstance(evt["response"].get("id"), str): if isinstance(evt.get("response"), dict) and isinstance(evt["response"].get("id"), str):
response_id = evt["response"].get("id") or response_id response_id = evt["response"].get("id") or response_id
mu = _extract_usage(evt)
if mu:
usage_obj = mu
kind = evt.get("type") kind = evt.get("type")
if kind == "response.output_text.delta": if kind == "response.output_text.delta":
full_text += evt.get("delta") or "" full_text += evt.get("delta") or ""
@@ -296,6 +341,7 @@ def completions() -> Response:
"choices": [ "choices": [
{"index": 0, "text": full_text, "finish_reason": "stop", "logprobs": None} {"index": 0, "text": full_text, "finish_reason": "stop", "logprobs": None}
], ],
**({"usage": usage_obj} if usage_obj else {}),
} }
resp = make_response(jsonify(completion), upstream.status_code) resp = make_response(jsonify(completion), upstream.status_code)
for k, v in build_cors_headers().items(): for k, v in build_cors_headers().items():
@@ -310,4 +356,3 @@ def list_models() -> Response:
for k, v in build_cors_headers().items(): for k, v in build_cors_headers().items():
resp.headers.setdefault(k, v) resp.headers.setdefault(k, v)
return resp return resp

View File

@@ -40,6 +40,7 @@ def start_upstream_request(
tool_choice: Any | None = None, tool_choice: Any | None = None,
parallel_tool_calls: bool = False, parallel_tool_calls: bool = False,
reasoning_param: Dict[str, Any] | None = None, reasoning_param: Dict[str, Any] | None = None,
include_usage: bool | None = None,
): ):
access_token, account_id = get_effective_chatgpt_auth() access_token, account_id = get_effective_chatgpt_auth()
if not access_token or not account_id: if not access_token or not account_id:
@@ -81,9 +82,10 @@ def start_upstream_request(
"parallel_tool_calls": bool(parallel_tool_calls), "parallel_tool_calls": bool(parallel_tool_calls),
"store": False, "store": False,
"stream": True, "stream": True,
"include": include,
"prompt_cache_key": session_id, "prompt_cache_key": session_id,
} }
if include:
responses_payload["include"] = include
if reasoning_param is not None: if reasoning_param is not None:
responses_payload["reasoning"] = reasoning_param responses_payload["reasoning"] = reasoning_param

View File

@@ -239,6 +239,8 @@ def sse_translate_chat(
verbose: bool = False, verbose: bool = False,
vlog=None, vlog=None,
reasoning_compat: str = "think-tags", reasoning_compat: str = "think-tags",
*,
include_usage: bool = False,
): ):
response_id = "chatcmpl-stream" response_id = "chatcmpl-stream"
compat = (reasoning_compat or "think-tags").strip().lower() compat = (reasoning_compat or "think-tags").strip().lower()
@@ -247,6 +249,19 @@ def sse_translate_chat(
saw_output = False saw_output = False
saw_any_summary = False saw_any_summary = False
pending_summary_paragraph = False pending_summary_paragraph = False
upstream_usage = None
def _extract_usage(evt: Dict[str, Any]) -> Dict[str, int] | None:
try:
usage = (evt.get("response") or {}).get("usage")
if not isinstance(usage, dict):
return None
pt = int(usage.get("input_tokens") or 0)
ct = int(usage.get("output_tokens") or 0)
tt = int(usage.get("total_tokens") or (pt + ct))
return {"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": tt}
except Exception:
return None
try: try:
for raw in upstream.iter_lines(decode_unicode=False): for raw in upstream.iter_lines(decode_unicode=False):
if not raw: if not raw:
@@ -442,6 +457,9 @@ def sse_translate_chat(
chunk = {"error": {"message": err}} chunk = {"error": {"message": err}}
yield f"data: {json.dumps(chunk)}\n\n".encode("utf-8") yield f"data: {json.dumps(chunk)}\n\n".encode("utf-8")
elif kind == "response.completed": elif kind == "response.completed":
m = _extract_usage(evt)
if m:
upstream_usage = m
if compat == "think-tags" and think_open and not think_closed: if compat == "think-tags" and think_open and not think_closed:
close_chunk = { close_chunk = {
"id": response_id, "id": response_id,
@@ -453,14 +471,40 @@ def sse_translate_chat(
yield f"data: {json.dumps(close_chunk)}\n\n".encode("utf-8") yield f"data: {json.dumps(close_chunk)}\n\n".encode("utf-8")
think_open = False think_open = False
think_closed = True think_closed = True
if include_usage and upstream_usage:
try:
usage_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": None}],
"usage": upstream_usage,
}
yield f"data: {json.dumps(usage_chunk)}\n\n".encode("utf-8")
except Exception:
pass
yield b"data: [DONE]\n\n" yield b"data: [DONE]\n\n"
break break
finally: finally:
upstream.close() upstream.close()
def sse_translate_text(upstream, model: str, created: int, verbose: bool = False, vlog=None): def sse_translate_text(upstream, model: str, created: int, verbose: bool = False, vlog=None, *, include_usage: bool = False):
response_id = "cmpl-stream" response_id = "cmpl-stream"
upstream_usage = None
def _extract_usage(evt: Dict[str, Any]) -> Dict[str, int] | None:
try:
usage = (evt.get("response") or {}).get("usage")
if not isinstance(usage, dict):
return None
pt = int(usage.get("input_tokens") or 0)
ct = int(usage.get("output_tokens") or 0)
tt = int(usage.get("total_tokens") or (pt + ct))
return {"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": tt}
except Exception:
return None
try: try:
for raw_line in upstream.iter_lines(decode_unicode=False): for raw_line in upstream.iter_lines(decode_unicode=False):
if not raw_line: if not raw_line:
@@ -509,8 +553,23 @@ def sse_translate_text(upstream, model: str, created: int, verbose: bool = False
} }
yield f"data: {json.dumps(chunk)}\n\n".encode("utf-8") yield f"data: {json.dumps(chunk)}\n\n".encode("utf-8")
elif kind == "response.completed": elif kind == "response.completed":
m = _extract_usage(evt)
if m:
upstream_usage = m
if include_usage and upstream_usage:
try:
usage_chunk = {
"id": response_id,
"object": "text_completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "text": "", "finish_reason": None}],
"usage": upstream_usage,
}
yield f"data: {json.dumps(usage_chunk)}\n\n".encode("utf-8")
except Exception:
pass
yield b"data: [DONE]\n\n" yield b"data: [DONE]\n\n"
break break
finally: finally:
upstream.close() upstream.close()