feat(translation): contextual batch translation with strict validation
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
@@ -13,6 +15,25 @@ def _mock_client(handler):
|
||||
return httpx.Client(transport=httpx.MockTransport(handler))
|
||||
|
||||
|
||||
def _mock_batch_response(translations):
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": json.dumps({"translations": translations}, ensure_ascii=False),
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _read_request_json(request: httpx.Request):
|
||||
return json.loads(request.read().decode("utf-8"))
|
||||
|
||||
|
||||
def test_translation_config_normalizes_base_url():
|
||||
config = TranslationConfig.from_env(base_url="http://127.0.0.1:1234")
|
||||
|
||||
@@ -21,33 +42,78 @@ def test_translation_config_normalizes_base_url():
|
||||
assert config.model == "gemma-3-4b-it"
|
||||
|
||||
|
||||
def test_build_payload_includes_model_and_prompt():
|
||||
def test_build_contextual_batch_payload_includes_neighboring_segments():
|
||||
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
|
||||
|
||||
payload = translator.build_payload("Hello world", "en", "es")
|
||||
batches = translator.build_contextual_batches(
|
||||
["one", "two", "three", "four", "five", "six", "seven"],
|
||||
)
|
||||
|
||||
assert [len(batch.segments) for batch in batches] == [4, 3]
|
||||
|
||||
payload = translator.build_contextual_batch_payload(batches[0], "en", "es")
|
||||
user_payload = json.loads(payload["messages"][1]["content"])
|
||||
|
||||
assert payload["model"] == "gemma-3-4b-it"
|
||||
assert payload["messages"][0]["role"] == "system"
|
||||
assert "Translate the user-provided text from en to es." in payload["messages"][0]["content"]
|
||||
assert payload["messages"][1]["content"] == "Hello world"
|
||||
assert "expert audiovisual translator for dubbed video content" in payload["messages"][0]["content"]
|
||||
assert payload["temperature"] == 0.0
|
||||
assert payload["top_p"] == 1.0
|
||||
assert user_payload == {
|
||||
"source_language": "en",
|
||||
"target_language": "es",
|
||||
"previous_segments": [],
|
||||
"segments": [
|
||||
{"id": "0", "text": "one"},
|
||||
{"id": "1", "text": "two"},
|
||||
{"id": "2", "text": "three"},
|
||||
{"id": "3", "text": "four"},
|
||||
],
|
||||
"next_segments": [
|
||||
{"id": "4", "text": "five"},
|
||||
{"id": "5", "text": "six"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_translate_segments_preserves_order_and_blank_segments():
|
||||
def test_translate_segments_batches_context_and_preserves_exact_mapping():
|
||||
requests = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
text = request.read().decode("utf-8")
|
||||
if "first" in text:
|
||||
content = "primero"
|
||||
elif "third" in text:
|
||||
content = "tercero"
|
||||
else:
|
||||
content = "desconocido"
|
||||
return httpx.Response(200, json={"choices": [{"message": {"content": content}}]})
|
||||
payload = _read_request_json(request)
|
||||
batch_request = json.loads(payload["messages"][1]["content"])
|
||||
requests.append(batch_request)
|
||||
|
||||
translations = []
|
||||
for item in batch_request["segments"]:
|
||||
translated_text = "" if not item["text"].strip() else f"es::{item['id']}::{item['text']}"
|
||||
translations.append({"id": item["id"], "translated_text": translated_text})
|
||||
|
||||
return _mock_batch_response(translations)
|
||||
|
||||
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler))
|
||||
|
||||
translated = translator.translate_segments(["first", "", "third"], target_language="es", source_language="en")
|
||||
translated = translator.translate_segments(
|
||||
["first", "second", "", "fourth", "fifth", "sixth", "seventh"],
|
||||
target_language="es",
|
||||
source_language="en",
|
||||
)
|
||||
|
||||
assert translated == ["primero", "", "tercero"]
|
||||
assert translated == [
|
||||
"es::0::first",
|
||||
"es::1::second",
|
||||
"",
|
||||
"es::3::fourth",
|
||||
"es::4::fifth",
|
||||
"es::5::sixth",
|
||||
"es::6::seventh",
|
||||
]
|
||||
assert len(requests) == 2
|
||||
assert [item["id"] for item in requests[0]["segments"]] == ["0", "1", "2", "3"]
|
||||
assert requests[0]["previous_segments"] == []
|
||||
assert [item["id"] for item in requests[0]["next_segments"]] == ["4", "5"]
|
||||
assert [item["id"] for item in requests[1]["previous_segments"]] == ["2", "3"]
|
||||
assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"]
|
||||
|
||||
|
||||
def test_retry_on_transient_http_error_then_succeeds():
|
||||
@@ -76,6 +142,77 @@ def test_parse_response_content_rejects_empty_content():
|
||||
LMStudioTranslator.parse_response_content({"choices": [{"message": {"content": " "}}]})
|
||||
|
||||
|
||||
def test_parse_batch_translation_response_rejects_missing_ids():
|
||||
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
|
||||
batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0]
|
||||
|
||||
with pytest.raises(TranslationError, match="ids did not match the request"):
|
||||
LMStudioTranslator.parse_batch_translation_response(
|
||||
json.dumps(
|
||||
{
|
||||
"translations": [
|
||||
{"id": "0", "translated_text": "uno"},
|
||||
{"id": "2", "translated_text": "dos"},
|
||||
{"id": "2", "translated_text": "tres"},
|
||||
]
|
||||
}
|
||||
),
|
||||
batch,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_batch_translation_response_rejects_out_of_order_ids():
|
||||
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
|
||||
batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0]
|
||||
|
||||
with pytest.raises(TranslationError, match="out of order"):
|
||||
LMStudioTranslator.parse_batch_translation_response(
|
||||
json.dumps(
|
||||
{
|
||||
"translations": [
|
||||
{"id": "1", "translated_text": "dos"},
|
||||
{"id": "0", "translated_text": "uno"},
|
||||
{"id": "2", "translated_text": "tres"},
|
||||
]
|
||||
}
|
||||
),
|
||||
batch,
|
||||
)
|
||||
|
||||
|
||||
def test_translate_segments_retries_on_malformed_json_batch_response():
|
||||
attempts = {"count": 0}
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
attempts["count"] += 1
|
||||
payload = _read_request_json(request)
|
||||
batch_request = json.loads(payload["messages"][1]["content"])
|
||||
|
||||
if attempts["count"] == 1:
|
||||
return httpx.Response(200, json={"choices": [{"message": {"content": "not-json"}}]})
|
||||
|
||||
translations = [
|
||||
{"id": item["id"], "translated_text": f"ok::{item['text']}"}
|
||||
for item in batch_request["segments"]
|
||||
]
|
||||
return _mock_batch_response(translations)
|
||||
|
||||
translator = LMStudioTranslator(
|
||||
TranslationConfig(max_retries=2),
|
||||
client=_mock_client(handler),
|
||||
sleeper=lambda _: None,
|
||||
)
|
||||
|
||||
translated = translator.translate_segments(
|
||||
["alpha", "beta", "gamma"],
|
||||
target_language="es",
|
||||
source_language="en",
|
||||
)
|
||||
|
||||
assert translated == ["ok::alpha", "ok::beta", "ok::gamma"]
|
||||
assert attempts["count"] == 2
|
||||
|
||||
|
||||
def test_translate_text_raises_on_malformed_response():
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(200, json={"choices": []})
|
||||
|
||||
Reference in New Issue
Block a user