feat(translation): contextual batch translation with strict validation

This commit is contained in:
2026-03-30 19:05:50 +01:00
parent 3c9b3c8090
commit d83c57cda3
2 changed files with 462 additions and 31 deletions

View File

@@ -2,6 +2,8 @@
from __future__ import annotations
import json
import httpx
import pytest
@@ -13,6 +15,25 @@ def _mock_client(handler):
return httpx.Client(transport=httpx.MockTransport(handler))
def _mock_batch_response(translations):
return httpx.Response(
200,
json={
"choices": [
{
"message": {
"content": json.dumps({"translations": translations}, ensure_ascii=False),
}
}
]
},
)
def _read_request_json(request: httpx.Request):
return json.loads(request.read().decode("utf-8"))
def test_translation_config_normalizes_base_url():
config = TranslationConfig.from_env(base_url="http://127.0.0.1:1234")
@@ -21,33 +42,78 @@ def test_translation_config_normalizes_base_url():
assert config.model == "gemma-3-4b-it"
def test_build_payload_includes_model_and_prompt():
def test_build_contextual_batch_payload_includes_neighboring_segments():
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
payload = translator.build_payload("Hello world", "en", "es")
batches = translator.build_contextual_batches(
["one", "two", "three", "four", "five", "six", "seven"],
)
assert [len(batch.segments) for batch in batches] == [4, 3]
payload = translator.build_contextual_batch_payload(batches[0], "en", "es")
user_payload = json.loads(payload["messages"][1]["content"])
assert payload["model"] == "gemma-3-4b-it"
assert payload["messages"][0]["role"] == "system"
assert "Translate the user-provided text from en to es." in payload["messages"][0]["content"]
assert payload["messages"][1]["content"] == "Hello world"
assert "expert audiovisual translator for dubbed video content" in payload["messages"][0]["content"]
assert payload["temperature"] == 0.0
assert payload["top_p"] == 1.0
assert user_payload == {
"source_language": "en",
"target_language": "es",
"previous_segments": [],
"segments": [
{"id": "0", "text": "one"},
{"id": "1", "text": "two"},
{"id": "2", "text": "three"},
{"id": "3", "text": "four"},
],
"next_segments": [
{"id": "4", "text": "five"},
{"id": "5", "text": "six"},
],
}
def test_translate_segments_preserves_order_and_blank_segments():
def test_translate_segments_batches_context_and_preserves_exact_mapping():
requests = []
def handler(request: httpx.Request) -> httpx.Response:
text = request.read().decode("utf-8")
if "first" in text:
content = "primero"
elif "third" in text:
content = "tercero"
else:
content = "desconocido"
return httpx.Response(200, json={"choices": [{"message": {"content": content}}]})
payload = _read_request_json(request)
batch_request = json.loads(payload["messages"][1]["content"])
requests.append(batch_request)
translations = []
for item in batch_request["segments"]:
translated_text = "" if not item["text"].strip() else f"es::{item['id']}::{item['text']}"
translations.append({"id": item["id"], "translated_text": translated_text})
return _mock_batch_response(translations)
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler))
translated = translator.translate_segments(["first", "", "third"], target_language="es", source_language="en")
translated = translator.translate_segments(
["first", "second", "", "fourth", "fifth", "sixth", "seventh"],
target_language="es",
source_language="en",
)
assert translated == ["primero", "", "tercero"]
assert translated == [
"es::0::first",
"es::1::second",
"",
"es::3::fourth",
"es::4::fifth",
"es::5::sixth",
"es::6::seventh",
]
assert len(requests) == 2
assert [item["id"] for item in requests[0]["segments"]] == ["0", "1", "2", "3"]
assert requests[0]["previous_segments"] == []
assert [item["id"] for item in requests[0]["next_segments"]] == ["4", "5"]
assert [item["id"] for item in requests[1]["previous_segments"]] == ["2", "3"]
assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"]
def test_retry_on_transient_http_error_then_succeeds():
@@ -76,6 +142,77 @@ def test_parse_response_content_rejects_empty_content():
LMStudioTranslator.parse_response_content({"choices": [{"message": {"content": " "}}]})
def test_parse_batch_translation_response_rejects_missing_ids():
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0]
with pytest.raises(TranslationError, match="ids did not match the request"):
LMStudioTranslator.parse_batch_translation_response(
json.dumps(
{
"translations": [
{"id": "0", "translated_text": "uno"},
{"id": "2", "translated_text": "dos"},
{"id": "2", "translated_text": "tres"},
]
}
),
batch,
)
def test_parse_batch_translation_response_rejects_out_of_order_ids():
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0]
with pytest.raises(TranslationError, match="out of order"):
LMStudioTranslator.parse_batch_translation_response(
json.dumps(
{
"translations": [
{"id": "1", "translated_text": "dos"},
{"id": "0", "translated_text": "uno"},
{"id": "2", "translated_text": "tres"},
]
}
),
batch,
)
def test_translate_segments_retries_on_malformed_json_batch_response():
attempts = {"count": 0}
def handler(request: httpx.Request) -> httpx.Response:
attempts["count"] += 1
payload = _read_request_json(request)
batch_request = json.loads(payload["messages"][1]["content"])
if attempts["count"] == 1:
return httpx.Response(200, json={"choices": [{"message": {"content": "not-json"}}]})
translations = [
{"id": item["id"], "translated_text": f"ok::{item['text']}"}
for item in batch_request["segments"]
]
return _mock_batch_response(translations)
translator = LMStudioTranslator(
TranslationConfig(max_retries=2),
client=_mock_client(handler),
sleeper=lambda _: None,
)
translated = translator.translate_segments(
["alpha", "beta", "gamma"],
target_language="es",
source_language="en",
)
assert translated == ["ok::alpha", "ok::beta", "ok::gamma"]
assert attempts["count"] == 2
def test_translate_text_raises_on_malformed_response():
def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, json={"choices": []})