274 lines
9.7 KiB
Python
274 lines
9.7 KiB
Python
"""Tests for the LM Studio translation layer."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from src.core_utils import TranslationError
|
|
from src.translation import LMStudioTranslator, TranslationConfig
|
|
|
|
|
|
def _mock_client(handler):
    """Return an ``httpx.Client`` whose requests are answered by *handler*.

    *handler* receives every outgoing ``httpx.Request`` and is expected to
    return the ``httpx.Response`` to hand back, so no real network I/O happens.
    """
    transport = httpx.MockTransport(handler)
    return httpx.Client(transport=transport)
|
|
|
|
|
|
def _mock_batch_response(translations):
    """Build a 200 chat-completion response carrying a batch translation.

    The assistant message content is the JSON text of
    ``{"translations": translations}``, serialized with ``ensure_ascii=False``
    so non-ASCII translations stay readable in the payload.
    """
    message_content = json.dumps({"translations": translations}, ensure_ascii=False)
    completion = {"choices": [{"message": {"content": message_content}}]}
    return httpx.Response(200, json=completion)
|
|
|
|
|
|
def _read_request_json(request: httpx.Request):
|
|
return json.loads(request.read().decode("utf-8"))
|
|
|
|
|
|
def test_translation_config_normalizes_base_url():
    """A bare host:port base URL gains ``/v1`` and derives the chat endpoint."""
    root = "http://127.0.0.1:1234"
    config = TranslationConfig.from_env(base_url=root)

    assert config.base_url == f"{root}/v1"
    assert config.chat_completions_url == f"{root}/v1/chat/completions"
    # The default model name when none is supplied.
    assert config.model == "gemma-3-4b-it"
|
|
|
|
|
|
def test_build_contextual_batch_payload_includes_neighboring_segments():
    """Seven segments split 4+3, and the first payload carries trailing context.

    The first batch has no leading context and the next two segments as
    trailing context. The request pins temperature 0.0 and top_p 1.0 and
    starts with a system message.
    """
    # No HTTP traffic happens in this test, so the handler is only a placeholder.
    translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda _request: None))
    segments = ["one", "two", "three", "four", "five", "six", "seven"]

    batches = translator.build_contextual_batches(segments)
    batch_sizes = [len(b.segments) for b in batches]
    assert batch_sizes == [4, 3]

    payload = translator.build_contextual_batch_payload(batches[0], "en", "es")
    messages = payload["messages"]
    user_payload = json.loads(messages[1]["content"])
    system_message = messages[0]

    expected_user_payload = {
        "source_language": "en",
        "target_language": "es",
        "previous_segments": [],
        "segments": [
            {"id": "0", "text": "one"},
            {"id": "1", "text": "two"},
            {"id": "2", "text": "three"},
            {"id": "3", "text": "four"},
        ],
        "next_segments": [
            {"id": "4", "text": "five"},
            {"id": "5", "text": "six"},
        ],
    }

    assert payload["model"] == "gemma-3-4b-it"
    assert system_message["role"] == "system"
    assert "expert audiovisual translator for dubbed video content" in system_message["content"]
    assert payload["temperature"] == 0.0
    assert payload["top_p"] == 1.0
    assert user_payload == expected_user_payload
|
|
|
|
|
|
def test_translate_segments_batches_context_and_preserves_exact_mapping():
    """Segments are sent in context-aware batches and mapped back by id.

    Seven segments (index 2 blank) produce two requests:
    - ids 0-3 with ids 4-5 as trailing context;
    - ids 4-6 with ids 2-3 as leading context and nothing trailing.
    Every translation must land back at its original position.
    """
    batch_requests = []

    def handler(request: httpx.Request) -> httpx.Response:
        payload = _read_request_json(request)
        batch_request = json.loads(payload["messages"][1]["content"])
        batch_requests.append(batch_request)

        # Echo each segment tagged with its id; blank input stays blank.
        translations = [
            {
                "id": item["id"],
                "translated_text": f"es::{item['id']}::{item['text']}" if item["text"].strip() else "",
            }
            for item in batch_request["segments"]
        ]
        return _mock_batch_response(translations)

    translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler))

    translated = translator.translate_segments(
        ["first", "second", "", "fourth", "fifth", "sixth", "seventh"],
        target_language="es",
        source_language="en",
    )

    assert translated == [
        "es::0::first",
        "es::1::second",
        "",
        "es::3::fourth",
        "es::4::fifth",
        "es::5::sixth",
        "es::6::seventh",
    ]
    assert len(batch_requests) == 2

    first, second = batch_requests
    assert [item["id"] for item in first["segments"]] == ["0", "1", "2", "3"]
    assert first["previous_segments"] == []
    assert [item["id"] for item in first["next_segments"]] == ["4", "5"]
    assert [item["id"] for item in second["previous_segments"]] == ["2", "3"]
    assert [item["id"] for item in second["segments"]] == ["4", "5", "6"]
    # The last batch has no following segments, mirroring the empty leading
    # context asserted for the first batch.
    assert second["next_segments"] == []
|
|
|
|
|
|
def test_retry_on_transient_http_error_then_succeeds():
    """A 503 on the first attempt is retried, and the second reply is returned."""
    calls = []

    def handler(request: httpx.Request) -> httpx.Response:
        calls.append(request)
        if len(calls) > 1:
            return httpx.Response(200, json={"choices": [{"message": {"content": "hola"}}]})
        # First attempt: simulate a busy server.
        return httpx.Response(503, json={"error": {"message": "busy"}})

    translator = LMStudioTranslator(
        TranslationConfig(max_retries=2),
        client=_mock_client(handler),
        sleeper=lambda _: None,  # skip real backoff delays
    )

    result = translator.translate_text("hello", target_language="es", source_language="en")

    assert result == "hola"
    assert len(calls) == 2
|
|
|
|
|
|
def test_parse_response_content_rejects_empty_content():
    """Whitespace-only assistant content is rejected as an empty translation."""
    blank_completion = {"choices": [{"message": {"content": " "}}]}

    with pytest.raises(TranslationError, match="empty translation"):
        LMStudioTranslator.parse_response_content(blank_completion)
|
|
|
|
|
|
def test_parse_batch_translation_response_rejects_missing_ids():
    """A reply that omits id "1" (repeating id "2" instead) is rejected."""
    # No HTTP traffic happens in this test, so the handler is only a placeholder.
    translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda _request: None))
    batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0]

    mismatched = [
        {"id": "0", "translated_text": "uno"},
        {"id": "2", "translated_text": "dos"},
        {"id": "2", "translated_text": "tres"},
    ]
    raw_reply = json.dumps({"translations": mismatched})

    with pytest.raises(TranslationError, match="ids did not match the request"):
        LMStudioTranslator.parse_batch_translation_response(raw_reply, batch)
|
|
|
|
|
|
def test_parse_batch_translation_response_rejects_out_of_order_ids():
    """A reply with every id present but in the wrong order is rejected."""
    # No HTTP traffic happens in this test, so the handler is only a placeholder.
    translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda _request: None))
    batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0]

    swapped = [
        {"id": "1", "translated_text": "dos"},
        {"id": "0", "translated_text": "uno"},
        {"id": "2", "translated_text": "tres"},
    ]
    raw_reply = json.dumps({"translations": swapped})

    with pytest.raises(TranslationError, match="out of order"):
        LMStudioTranslator.parse_batch_translation_response(raw_reply, batch)
|
|
|
|
|
|
def test_translate_segments_retries_on_malformed_json_batch_response():
    """A non-JSON batch reply is retried, and the next well-formed reply is used."""
    seen = []

    def handler(request: httpx.Request) -> httpx.Response:
        seen.append(request)
        batch_request = json.loads(_read_request_json(request)["messages"][1]["content"])

        if len(seen) == 1:
            # First attempt: assistant content that is not JSON at all.
            return httpx.Response(200, json={"choices": [{"message": {"content": "not-json"}}]})

        return _mock_batch_response(
            [{"id": seg["id"], "translated_text": f"ok::{seg['text']}"} for seg in batch_request["segments"]]
        )

    translator = LMStudioTranslator(
        TranslationConfig(max_retries=2),
        client=_mock_client(handler),
        sleeper=lambda _: None,  # skip real backoff delays
    )

    result = translator.translate_segments(
        ["alpha", "beta", "gamma"],
        target_language="es",
        source_language="en",
    )

    assert result == ["ok::alpha", "ok::beta", "ok::gamma"]
    assert len(seen) == 2
|
|
|
|
|
|
def test_translate_text_raises_on_malformed_response():
    """A completion whose ``choices`` list is empty surfaces as a TranslationError."""
    # A fresh response is built per request, in case the translator retries.
    client = _mock_client(lambda request: httpx.Response(200, json={"choices": []}))
    translator = LMStudioTranslator(TranslationConfig(), client=client)

    with pytest.raises(TranslationError, match="did not contain a chat completion message"):
        translator.translate_text("hello", target_language="es", source_language="en")
|
|
|
|
|
|
def test_translate_text_falls_back_to_user_only_prompt_for_template_error():
    """A jinja "must start with a user prompt" 400 triggers a user-only retry.

    Request bodies are captured and checked after the call rather than inside
    the transport handler, so a failed expectation is reported directly.
    The retry is parsed as JSON, so the check does not depend on the
    whitespace the HTTP client uses when encoding the body.
    """
    bodies = []

    def handler(request: httpx.Request) -> httpx.Response:
        bodies.append(request.read())
        if len(bodies) == 1:
            return httpx.Response(
                400,
                text='{"error":"Error rendering prompt with jinja template: \\"Conversations must start with a user prompt.\\""}',
            )
        return httpx.Response(200, json={"choices": [{"message": {"content": "hola"}}]})

    translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler))

    translated = translator.translate_text("hello", target_language="es", source_language="en")

    assert translated == "hola"
    assert len(bodies) == 2

    retry_roles = [message["role"] for message in json.loads(bodies[1])["messages"]]
    # The template error demands the conversation open with a user turn, and the
    # fallback is user-only, so no system message may remain.
    assert retry_roles[0] == "user"
    assert "system" not in retry_roles
|
|
|
|
|
|
def test_translate_text_falls_back_to_structured_prompt_for_custom_template():
    """Two successive template errors lead to the structured-content prompt.

    The first 400 rejects a leading system message. The second 400 demands
    user content shaped as a single mapping with ``source_lang_code`` and
    ``target_lang_code``. The third request must carry those language codes.
    Bodies are captured and checked after the call.
    """
    bodies = []

    def handler(request: httpx.Request) -> httpx.Response:
        bodies.append(request.read())
        if len(bodies) == 1:
            return httpx.Response(
                400,
                text='{"error":"Error rendering prompt with jinja template: \\"Conversations must start with a user prompt.\\""}',
            )
        if len(bodies) == 2:
            return httpx.Response(
                400,
                text='{"error":"Error rendering prompt with jinja template: \\"User role must provide `content` as an iterable with exactly one item. That item must be a mapping(type:\'text\' | \'image\', source_lang_code:string, target_lang_code:string, text:string | none, image:string | none).\\""}',
            )
        return httpx.Response(200, json={"choices": [{"message": {"content": "hola"}}]})

    translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler))

    translated = translator.translate_text("hello", target_language="es", source_language="en")

    assert translated == "hola"
    assert len(bodies) == 3

    # Re-serialize compactly so the key/value checks do not depend on the
    # separators the HTTP client used when encoding the request body.
    final_body = json.dumps(json.loads(bodies[2]), separators=(",", ":"), ensure_ascii=False)
    assert '"source_lang_code":"en"' in final_body
    assert '"target_lang_code":"es"' in final_body
|