"""Tests for the LM Studio translation layer.""" from __future__ import annotations import json import httpx import pytest from src.core_utils import TranslationError from src.translation import LMStudioTranslator, TranslationConfig def _mock_client(handler): return httpx.Client(transport=httpx.MockTransport(handler)) def _mock_batch_response(translations): return httpx.Response( 200, json={ "choices": [ { "message": { "content": json.dumps({"translations": translations}, ensure_ascii=False), } } ] }, ) def _read_request_json(request: httpx.Request): return json.loads(request.read().decode("utf-8")) def test_translation_config_normalizes_base_url(): config = TranslationConfig.from_env(base_url="http://127.0.0.1:1234") assert config.base_url == "http://127.0.0.1:1234/v1" assert config.chat_completions_url == "http://127.0.0.1:1234/v1/chat/completions" assert config.model == "gemma-3-4b-it" def test_build_contextual_batch_payload_includes_neighboring_segments(): translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None)) batches = translator.build_contextual_batches( ["one", "two", "three", "four", "five", "six", "seven"], ) assert [len(batch.segments) for batch in batches] == [4, 3] payload = translator.build_contextual_batch_payload(batches[0], "en", "es") user_payload = json.loads(payload["messages"][1]["content"]) assert payload["model"] == "gemma-3-4b-it" assert payload["messages"][0]["role"] == "system" assert "expert audiovisual translator for dubbed video content" in payload["messages"][0]["content"] assert payload["temperature"] == 0.0 assert payload["top_p"] == 1.0 assert user_payload == { "source_language": "en", "target_language": "es", "previous_segments": [], "segments": [ {"id": "0", "text": "one"}, {"id": "1", "text": "two"}, {"id": "2", "text": "three"}, {"id": "3", "text": "four"}, ], "next_segments": [ {"id": "4", "text": "five"}, {"id": "5", "text": "six"}, ], } def test_translate_segments_batches_context_and_preserves_exact_mapping(): requests = [] def handler(request: httpx.Request) -> httpx.Response: payload = _read_request_json(request) batch_request = json.loads(payload["messages"][1]["content"]) requests.append(batch_request) translations = [] for item in batch_request["segments"]: translated_text = "" if not item["text"].strip() else f"es::{item['id']}::{item['text']}" translations.append({"id": item["id"], "translated_text": translated_text}) return _mock_batch_response(translations) translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler)) translated = translator.translate_segments( ["first", "second", "", "fourth", "fifth", "sixth", "seventh"], target_language="es", source_language="en", ) assert translated == [ "es::0::first", "es::1::second", "", "es::3::fourth", "es::4::fifth", "es::5::sixth", "es::6::seventh", ] assert len(requests) == 2 assert [item["id"] for item in requests[0]["segments"]] == ["0", "1", "2", "3"] assert requests[0]["previous_segments"] == [] assert [item["id"] for item in requests[0]["next_segments"]] == ["4", "5"] assert [item["id"] for item in requests[1]["previous_segments"]] == ["2", "3"] assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"] def test_retry_on_transient_http_error_then_succeeds(): attempts = {"count": 0} def handler(request: httpx.Request) -> httpx.Response: attempts["count"] += 1 if attempts["count"] == 1: return httpx.Response(503, json={"error": {"message": "busy"}}) return httpx.Response(200, json={"choices": [{"message": {"content": "hola"}}]}) translator = LMStudioTranslator( TranslationConfig(max_retries=2), client=_mock_client(handler), sleeper=lambda _: None, ) translated = translator.translate_text("hello", target_language="es", source_language="en") assert translated == "hola" assert attempts["count"] == 2 def test_parse_response_content_rejects_empty_content(): with pytest.raises(TranslationError, match="empty translation"): LMStudioTranslator.parse_response_content({"choices": [{"message": {"content": " "}}]}) def test_parse_batch_translation_response_rejects_missing_ids(): translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None)) batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0] with pytest.raises(TranslationError, match="ids did not match the request"): LMStudioTranslator.parse_batch_translation_response( json.dumps( { "translations": [ {"id": "0", "translated_text": "uno"}, {"id": "2", "translated_text": "dos"}, {"id": "2", "translated_text": "tres"}, ] } ), batch, ) def test_parse_batch_translation_response_rejects_out_of_order_ids(): translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None)) batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0] with pytest.raises(TranslationError, match="out of order"): LMStudioTranslator.parse_batch_translation_response( json.dumps( { "translations": [ {"id": "1", "translated_text": "dos"}, {"id": "0", "translated_text": "uno"}, {"id": "2", "translated_text": "tres"}, ] } ), batch, ) def test_translate_segments_retries_on_malformed_json_batch_response(): attempts = {"count": 0} def handler(request: httpx.Request) -> httpx.Response: attempts["count"] += 1 payload = _read_request_json(request) batch_request = json.loads(payload["messages"][1]["content"]) if attempts["count"] == 1: return httpx.Response(200, json={"choices": [{"message": {"content": "not-json"}}]}) translations = [ {"id": item["id"], "translated_text": f"ok::{item['text']}"} for item in batch_request["segments"] ] return _mock_batch_response(translations) translator = LMStudioTranslator( TranslationConfig(max_retries=2), client=_mock_client(handler), sleeper=lambda _: None, ) translated = translator.translate_segments( ["alpha", "beta", "gamma"], target_language="es", source_language="en", ) assert translated == ["ok::alpha", "ok::beta", "ok::gamma"] assert attempts["count"] == 2 def test_translate_text_raises_on_malformed_response(): def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(200, json={"choices": []}) translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler)) with pytest.raises(TranslationError, match="did not contain a chat completion message"): translator.translate_text("hello", target_language="es", source_language="en") def test_translate_text_falls_back_to_user_only_prompt_for_template_error(): attempts = {"count": 0} def handler(request: httpx.Request) -> httpx.Response: attempts["count"] += 1 body = request.read().decode("utf-8") if attempts["count"] == 1: return httpx.Response( 400, text='{"error":"Error rendering prompt with jinja template: \\"Conversations must start with a user prompt.\\""}', ) assert '"role":"user"' in body return httpx.Response(200, json={"choices": [{"message": {"content": "hola"}}]}) translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler)) translated = translator.translate_text("hello", target_language="es", source_language="en") assert translated == "hola" assert attempts["count"] == 2 def test_translate_text_falls_back_to_structured_prompt_for_custom_template(): attempts = {"count": 0} def handler(request: httpx.Request) -> httpx.Response: attempts["count"] += 1 body = request.read().decode("utf-8") if attempts["count"] == 1: return httpx.Response( 400, text='{"error":"Error rendering prompt with jinja template: \\"Conversations must start with a user prompt.\\""}', ) if attempts["count"] == 2: return httpx.Response( 400, text='{"error":"Error rendering prompt with jinja template: \\"User role must provide `content` as an iterable with exactly one item. That item must be a mapping(type:\'text\' | \'image\', source_lang_code:string, target_lang_code:string, text:string | none, image:string | none).\\""}', ) assert '"source_lang_code":"en"' in body assert '"target_lang_code":"es"' in body return httpx.Response(200, json={"choices": [{"message": {"content": "hola"}}]}) translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler)) translated = translator.translate_text("hello", target_language="es", source_language="en") assert translated == "hola" assert attempts["count"] == 3