hopefully patch token counting?

This commit is contained in:
Game_Time
2025-08-22 02:01:15 +05:00
parent c8c6540d23
commit cadd959778
20 changed files with 112 additions and 6 deletions

View File

@@ -63,6 +63,8 @@ def chat_completions() -> Response:
content = sys_msg.get("content") if isinstance(sys_msg, dict) else ""
messages.insert(0, {"role": "user", "content": content})
is_stream = bool(payload.get("stream"))
stream_options = payload.get("stream_options") if isinstance(payload.get("stream_options"), dict) else {}
include_usage = bool(stream_options.get("include_usage", False))
tools_responses = convert_tools_chat_to_responses(payload.get("tools"))
tool_choice = payload.get("tool_choice", "auto")
@@ -85,6 +87,7 @@ def chat_completions() -> Response:
tool_choice=tool_choice,
parallel_tool_calls=parallel_tool_calls,
reasoning_param=reasoning_param,
include_usage=(not is_stream) or include_usage,
)
if error_resp is not None:
return error_resp
@@ -112,6 +115,7 @@ def chat_completions() -> Response:
verbose=verbose,
vlog=print if verbose else None,
reasoning_compat=reasoning_compat,
include_usage=include_usage,
),
status=upstream.status_code,
mimetype="text/event-stream",
@@ -127,6 +131,19 @@ def chat_completions() -> Response:
response_id = "chatcmpl"
tool_calls: List[Dict[str, Any]] = []
error_message: str | None = None
usage_obj: Dict[str, int] | None = None
def _extract_usage(evt: Dict[str, Any]) -> Dict[str, int] | None:
try:
usage = (evt.get("response") or {}).get("usage")
if not isinstance(usage, dict):
return None
pt = int(usage.get("input_tokens") or 0)
ct = int(usage.get("output_tokens") or 0)
tt = int(usage.get("total_tokens") or (pt + ct))
return {"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": tt}
except Exception:
return None
try:
for raw in upstream.iter_lines(decode_unicode=False):
if not raw:
@@ -144,6 +161,9 @@ def chat_completions() -> Response:
except Exception:
continue
kind = evt.get("type")
mu = _extract_usage(evt)
if mu:
usage_obj = mu
if isinstance(evt.get("response"), dict) and isinstance(evt["response"].get("id"), str):
response_id = evt["response"].get("id") or response_id
if kind == "response.output_text.delta":
@@ -183,7 +203,6 @@ def chat_completions() -> Response:
if tool_calls:
message["tool_calls"] = tool_calls
message = apply_reasoning_to_message(message, reasoning_summary_text, reasoning_full_text, reasoning_compat)
completion = {
"id": response_id or "chatcmpl",
"object": "chat.completion",
@@ -196,6 +215,7 @@ def chat_completions() -> Response:
"finish_reason": "stop",
}
],
**({"usage": usage_obj} if usage_obj else {}),
}
resp = make_response(jsonify(completion), upstream.status_code)
for k, v in build_cors_headers().items():
@@ -223,6 +243,8 @@ def completions() -> Response:
if not isinstance(prompt, str):
prompt = payload.get("suffix") or ""
stream_req = bool(payload.get("stream", False))
stream_options = payload.get("stream_options") if isinstance(payload.get("stream_options"), dict) else {}
include_usage = bool(stream_options.get("include_usage", False))
messages = [{"role": "user", "content": prompt or ""}]
input_items = convert_chat_messages_to_responses_input(messages)
@@ -234,6 +256,7 @@ def completions() -> Response:
input_items,
instructions=BASE_INSTRUCTIONS,
reasoning_param=reasoning_param,
include_usage=(not stream_req) or include_usage,
)
if error_resp is not None:
return error_resp
@@ -251,7 +274,14 @@ def completions() -> Response:
if stream_req:
resp = Response(
sse_translate_text(upstream, model, created, verbose=verbose, vlog=(print if verbose else None)),
sse_translate_text(
upstream,
model,
created,
verbose=verbose,
vlog=(print if verbose else None),
include_usage=include_usage,
),
status=upstream.status_code,
mimetype="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
@@ -262,6 +292,18 @@ def completions() -> Response:
full_text = ""
response_id = "cmpl"
usage_obj: Dict[str, int] | None = None
def _extract_usage(evt: Dict[str, Any]) -> Dict[str, int] | None:
try:
usage = (evt.get("response") or {}).get("usage")
if not isinstance(usage, dict):
return None
pt = int(usage.get("input_tokens") or 0)
ct = int(usage.get("output_tokens") or 0)
tt = int(usage.get("total_tokens") or (pt + ct))
return {"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": tt}
except Exception:
return None
try:
for raw_line in upstream.iter_lines(decode_unicode=False):
if not raw_line:
@@ -280,6 +322,9 @@ def completions() -> Response:
continue
if isinstance(evt.get("response"), dict) and isinstance(evt["response"].get("id"), str):
response_id = evt["response"].get("id") or response_id
mu = _extract_usage(evt)
if mu:
usage_obj = mu
kind = evt.get("type")
if kind == "response.output_text.delta":
full_text += evt.get("delta") or ""
@@ -296,6 +341,7 @@ def completions() -> Response:
"choices": [
{"index": 0, "text": full_text, "finish_reason": "stop", "logprobs": None}
],
**({"usage": usage_obj} if usage_obj else {}),
}
resp = make_response(jsonify(completion), upstream.status_code)
for k, v in build_cors_headers().items():
@@ -310,4 +356,3 @@ def list_models() -> Response:
for k, v in build_cors_headers().items():
resp.headers.setdefault(k, v)
return resp