Add translation and TTS caching

This commit is contained in:
2026-05-24 16:49:21 +01:00
parent 803f532ff3
commit 9fbb7c1756
5 changed files with 175 additions and 18 deletions

View File

@@ -8,6 +8,7 @@ import httpx
import pytest
from src.core_utils import TranslationError
from src import translation
from src.translation import LMStudioTranslator, TranslationConfig
@@ -56,7 +57,8 @@ def test_build_contextual_batch_payload_includes_neighboring_segments():
assert payload["model"] == "gemma-3-4b-it"
assert payload["messages"][0]["role"] == "system"
assert "expert audiovisual translator for dubbed video content" in payload["messages"][0]["content"]
assert "expert audiovisual translator and dubbing script adapter" in payload["messages"][0]["content"]
assert "Preserve segment ids and output order exactly" in payload["messages"][0]["content"]
assert payload["temperature"] == 0.0
assert payload["top_p"] == 1.0
assert user_payload == {
@@ -116,6 +118,29 @@ def test_translate_segments_batches_context_and_preserves_exact_mapping():
assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"]
def test_translate_segments_uses_persistent_cache(tmp_path, monkeypatch):
    """Check that a repeated translation is served without a second HTTP call.

    Two translators sharing one config translate the same segments; the mock
    HTTP handler must be hit exactly once across both calls.
    """
    http_calls = {"count": 0}
    # Point the translation module's cache directory at the per-test temp dir.
    monkeypatch.setattr(translation, "TRANSLATION_CACHE_DIR", tmp_path)

    def handler(request: httpx.Request) -> httpx.Response:
        # Count every request that reaches the mocked endpoint.
        http_calls["count"] += 1
        body = _read_request_json(request)
        batch = json.loads(body["messages"][1]["content"])
        translated = []
        for segment in batch["segments"]:
            translated.append(
                {
                    "id": segment["id"],
                    "translated_text": f"cached::{segment['text']}",
                }
            )
        return _mock_batch_response(translated)

    config = TranslationConfig(model="cache-model")
    expected = ["cached::hello", "cached::world"]

    # Build both translators up front, before any translation happens.
    translators = [
        LMStudioTranslator(config, client=_mock_client(handler))
        for _ in range(2)
    ]
    for translator in translators:
        result = translator.translate_segments(["hello", "world"], "fr", "en")
        assert result == expected

    assert http_calls["count"] == 1
def test_retry_on_transient_http_error_then_succeeds():
attempts = {"count": 0}

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from src import engines
from src.engines import Engine
from src.translation import TranslationConfig
@@ -46,3 +47,39 @@ def test_synthesize_retries_transient_edge_tts_failure(tmp_path, monkeypatch):
assert calls["count"] == 2
assert out_path.exists()
assert out_path.stat().st_size == 2048
def test_synthesize_uses_tts_cache(tmp_path, monkeypatch):
    """Check that synthesizing the same text twice saves audio only once.

    The second call targets a different output path; both files must end up
    with identical bytes while the fake backend's ``save`` runs a single time.
    """
    save_calls = {"count": 0}
    monkeypatch.setattr(engines, "TTS_CACHE_DIR", tmp_path / "tts-cache")

    class FakeCommunicate:
        """Stand-in for ``edge_tts.Communicate`` that counts ``save`` calls."""

        def __init__(self, text, voice, rate):
            self.text, self.voice, self.rate = text, voice, rate

        async def save(self, out_path):
            save_calls["count"] += 1
            # Write a fixed 2048-byte payload as the fake audio content.
            with open(out_path, "wb") as sink:
                sink.write(b"1" * 2048)

    monkeypatch.setattr("src.engines.edge_tts.Communicate", FakeCommunicate)

    translation_config = TranslationConfig(
        base_url="http://127.0.0.1:1234/v1",
        api_key="test-key",
        model="test-model",
    )
    engine = Engine("cpu", translation_config=translation_config)

    first_out, second_out = tmp_path / "first.mp3", tmp_path / "second.mp3"
    for target in (first_out, second_out):
        asyncio.run(engine.synthesize("Bonjour", "fr", target))

    assert save_calls["count"] == 1
    assert first_out.read_bytes() == second_out.read_bytes()