From 9fbb7c17560a24ef496b5161bdfc0752186faf3f Mon Sep 17 00:00:00 2001 From: oimwiodev Date: Sun, 24 May 2026 16:49:21 +0100 Subject: [PATCH] Add translation and TTS caching --- README.md | 2 + src/engines.py | 30 ++++++++++++ src/translation.py | 97 ++++++++++++++++++++++++++++++++------- tests/test_translation.py | 27 ++++++++++- tests/test_tts_retry.py | 37 +++++++++++++++ 5 files changed, 175 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 70fc155..62144a6 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,8 @@ You can also upload a local `.mp4` instead of entering a YouTube URL. Uploaded v The web UI automatically refreshes job status, progress, steps, and output choices every few seconds while it is open. The manual **Refresh** button is still available. +Translations and raw TTS clips are cached under `.cache/translations` and `.cache/tts`. This lets reruns skip work that already succeeded, which is especially useful after transient TTS failures. Set `TRANSLATION_CACHE_ENABLED=0` or `TTS_CACHE_ENABLED=0` to disable those caches. + ### Docker Build and run the Gradio UI in a container: diff --git a/src/engines.py b/src/engines.py index 4e5caef..f3cda2b 100644 --- a/src/engines.py +++ b/src/engines.py @@ -16,8 +16,10 @@ import torch import asyncio import edge_tts import gc +import hashlib import json import os +import shutil from abc import ABC import numpy as np from pathlib import Path @@ -42,6 +44,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent CACHE_DIR = BASE_DIR / ".cache" OUTPUT_DIR = BASE_DIR / "output" TEMP_DIR = BASE_DIR / "temp" +TTS_CACHE_DIR = CACHE_DIR / "tts" # Configuration files LANG_MAP_FILE = BASE_DIR / "language_map.json" @@ -56,6 +59,23 @@ AUDIO_CHANNELS = 1 DEFAULT_TTS_MAX_RETRIES = int(os.getenv("TTS_MAX_RETRIES", "4")) DEFAULT_TTS_RETRY_BACKOFF_SECONDS = float(os.getenv("TTS_RETRY_BACKOFF_SECONDS", "2.0")) + +def _cache_enabled(env_name: str) -> bool: + return (os.getenv(env_name, "1") or "").strip().lower() not in {"0", "false", "no", "off"} + + +def _tts_cache_key(text: str, target_lang: str, voice: str, rate: str) -> str: + payload = { + "version": "edge-tts-v1", + "text": text, + "target_lang": target_lang, + "voice": voice, + "rate": rate, + "sample_rate": SAMPLE_RATE, + } + serialized = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(serialized.encode("utf-8")).hexdigest() + def _select_optimal_whisper_model(device: str = "cpu") -> str: """Select optimal Whisper model based on available VRAM and device. @@ -495,6 +515,12 @@ class Engine(PipelineComponent): lang_cfg = self._getLangConfig(target_lang) voice_pool = self.config_manager.getVoicePool(target_lang, gender) voice = voice_pool[0] if voice_pool else DEFAULT_VOICE + cache_path = TTS_CACHE_DIR / f"{_tts_cache_key(text, target_lang, voice, rate)}.mp3" + + if _cache_enabled("TTS_CACHE_ENABLED") and cache_path.exists() and cache_path.stat().st_size >= 1024: + print(f"[*] TTS cache hit: {cache_path.name}") + shutil.copyfile(cache_path, out_path) + return try: communicate = edge_tts.Communicate(text, voice=voice, rate=rate) @@ -503,6 +529,10 @@ class Engine(PipelineComponent): if not out_path.exists() or out_path.stat().st_size < 1024: raise RuntimeError("TTS file invalid") + if _cache_enabled("TTS_CACHE_ENABLED"): + cache_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copyfile(out_path, cache_path) + return except Exception as exc: last_error = exc diff --git a/src/translation.py b/src/translation.py index 91998dd..47b185e 100644 --- a/src/translation.py +++ b/src/translation.py @@ -2,10 +2,12 @@ from __future__ import annotations +import hashlib import json import os import time from dataclasses import dataclass +from pathlib import Path from typing import Any, Dict, List, Optional from urllib.parse import urlparse @@ -20,6 +22,9 @@ DEFAULT_TRANSLATION_BACKEND = "lmstudio" DEFAULT_CONTEXTUAL_BATCH_SIZE = 5 MIN_CONTEXTUAL_BATCH_SIZE = 3 DEFAULT_CONTEXT_SEGMENTS = 2 +PROMPT_VERSION = "gpt54-dub-v2" +BASE_DIR = Path(__file__).resolve().parent.parent +TRANSLATION_CACHE_DIR = BASE_DIR / ".cache" / "translations" def _normalize_base_url(base_url: str) -> str: @@ -125,26 +130,56 @@ def _build_system_prompt(source_language: str, target_language: str) -> str: def _build_contextual_system_prompt(source_language: str, target_language: str) -> str: source_descriptor = source_language or "auto" return ( - "You are an expert audiovisual translator for dubbed video content.\n\n" + "You are an expert audiovisual translator and dubbing script adapter.\n\n" f"Translate subtitle segments from {source_descriptor} to {target_language} for natural spoken dubbing.\n\n" - "Rules:\n" - "- Preserve meaning, intent, tone, and subtext.\n" - "- Use surrounding subtitle context to resolve ambiguity.\n" - "- Do not summarize.\n" - "- Do not simplify unless needed for natural speech.\n" - "- Do not add explanations, notes, or commentary.\n" - "- Preserve humor, sarcasm, emotional tone, and register.\n" - "- Keep names, brands, URLs, emails, file paths, code, and product names unchanged unless transliteration is clearly needed.\n" - "- Keep the translation natural for spoken dubbing.\n" - "- Preserve segment boundaries exactly.\n" + "Primary objective:\n" + "- Produce faithful, idiomatic spoken lines that can be read aloud naturally in the target language.\n" + "- Preserve intent, tone, emotion, register, humor, sarcasm, hesitation, and subtext.\n" + "- Use the provided previous and next segments only as context; translate only the current segments.\n\n" + "Dubbing adaptation rules:\n" + "- Prefer natural speech over literal word-for-word phrasing when the literal version sounds stiff.\n" + "- Keep each translated segment close to the source segment length when possible, because it will be timed to video.\n" + "- Do not add new claims, soften meaning, moralize, censor, summarize, or omit content.\n" + "- Preserve speaker deixis and continuity across adjacent segments.\n" + "- Keep names, brands, URLs, emails, file paths, code, product names, and quoted UI text unchanged unless transliteration is clearly required.\n" + "- Preserve numbers, units, dates, and technical terms accurately.\n" + "- If a phrase is slang, idiom, or a joke, translate the effect rather than the literal wording.\n\n" + "Output contract:\n" + "- Return valid JSON only, with no markdown fences or commentary.\n" "- Return exactly one translated item per input segment.\n" - "- Output order must match input order exactly.\n" - "- If an input segment is empty or whitespace-only, return an empty translated_text for that id.\n" - "- Output valid JSON only.\n" + "- Preserve segment ids and output order exactly.\n" + "- Preserve empty or whitespace-only segments as an empty translated_text.\n" + "- Do not include previous_segments or next_segments in the output.\n" '- Return this exact schema: {"translations":[{"id":"...","translated_text":"..."}]}.' ) +def _cache_enabled(env_name: str) -> bool: + return (os.getenv(env_name, "1") or "").strip().lower() not in {"0", "false", "no", "off"} + + +def _json_cache_key(payload: Dict[str, Any]) -> str: + serialized = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(serialized.encode("utf-8")).hexdigest() + + +def _read_json_cache(cache_path: Path) -> Optional[Dict[str, Any]]: + if not cache_path.exists(): + return None + try: + payload = json.loads(cache_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + return None + return payload if isinstance(payload, dict) else None + + +def _write_json_cache(cache_path: Path, payload: Dict[str, Any]) -> None: + cache_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = cache_path.with_suffix(".tmp") + tmp_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + tmp_path.replace(cache_path) + + @dataclass(frozen=True) class TranslationSegment: """A subtitle segment prepared for contextual batch translation.""" @@ -183,6 +218,7 @@ class LMStudioTranslator: self._client = client or httpx.Client(timeout=httpx.Timeout(self.config.timeout_seconds)) self._owns_client = client is None self._sleeper = sleeper + self._cache_enabled = _cache_enabled("TRANSLATION_CACHE_ENABLED") @staticmethod def _generation_settings() -> Dict[str, Any]: @@ -551,19 +587,43 @@ class LMStudioTranslator: ) -> List[str]: """Translate a single contextual subtitle batch with validation and retries.""" payload = self.build_contextual_batch_payload(batch, source_language, target_language) + cache_payload = { + "prompt_version": PROMPT_VERSION, + "backend": self.config.backend, + "base_url": self.config.base_url, + "model": self.config.model, + "source_language": source_language or "auto", + "target_language": target_language, + "request": self.build_contextual_batch_request(batch, source_language, target_language), + } + cache_path = TRANSLATION_CACHE_DIR / f"{_json_cache_key(cache_payload)}.json" + if self._cache_enabled: + cached_payload = _read_json_cache(cache_path) + cached_translations = cached_payload.get("translations") if cached_payload else None + if isinstance(cached_translations, list) and all(isinstance(item, str) for item in cached_translations): + if len(cached_translations) == len(batch.segments): + print(f"[*] Translation cache hit: {cache_path.name}") + return cached_translations + last_error: Optional[Exception] = None for attempt in range(1, self.config.max_retries + 1): try: response_content = self._post_chat_completion(payload) - return self.parse_batch_translation_response(response_content, batch) + translations = self.parse_batch_translation_response(response_content, batch) + if self._cache_enabled: + _write_json_cache(cache_path, {"translations": translations}) + return translations except (httpx.HTTPError, ValueError, TranslationError) as exc: last_error = exc if self._should_retry_with_user_only_prompt(exc): try: fallback_payload = self.build_contextual_user_only_payload(batch, source_language, target_language) fallback_content = self._post_chat_completion(fallback_payload) - return self.parse_batch_translation_response(fallback_content, batch) + translations = self.parse_batch_translation_response(fallback_content, batch) + if self._cache_enabled: + _write_json_cache(cache_path, {"translations": translations}) + return translations except (httpx.HTTPError, ValueError, TranslationError) as fallback_exc: last_error = fallback_exc if self._should_retry_with_structured_translation_prompt(last_error): @@ -574,7 +634,10 @@ class LMStudioTranslator: target_language, ) structured_content = self._post_chat_completion(structured_payload) - return self.parse_batch_translation_response(structured_content, batch) + translations = self.parse_batch_translation_response(structured_content, batch) + if self._cache_enabled: + _write_json_cache(cache_path, {"translations": translations}) + return translations except (httpx.HTTPError, ValueError, TranslationError) as structured_exc: last_error = structured_exc diff --git a/tests/test_translation.py b/tests/test_translation.py index 9c6bd7c..77cbba0 100644 --- a/tests/test_translation.py +++ b/tests/test_translation.py @@ -8,6 +8,7 @@ import httpx import pytest from src.core_utils import TranslationError +from src import translation from src.translation import LMStudioTranslator, TranslationConfig @@ -56,7 +57,8 @@ def test_build_contextual_batch_payload_includes_neighboring_segments(): assert payload["model"] == "gemma-3-4b-it" assert payload["messages"][0]["role"] == "system" - assert "expert audiovisual translator for dubbed video content" in payload["messages"][0]["content"] + assert "expert audiovisual translator and dubbing script adapter" in payload["messages"][0]["content"] + assert "Preserve segment ids and output order exactly" in payload["messages"][0]["content"] assert payload["temperature"] == 0.0 assert payload["top_p"] == 1.0 assert user_payload == { @@ -116,6 +118,29 @@ def test_translate_segments_batches_context_and_preserves_exact_mapping(): assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"] +def test_translate_segments_uses_persistent_cache(tmp_path, monkeypatch): + requests = {"count": 0} + monkeypatch.setattr(translation, "TRANSLATION_CACHE_DIR", tmp_path) + + def handler(request: httpx.Request) -> httpx.Response: + requests["count"] += 1 + payload = _read_request_json(request) + batch_request = json.loads(payload["messages"][1]["content"]) + translations = [ + {"id": item["id"], "translated_text": f"cached::{item['text']}"} + for item in batch_request["segments"] + ] + return _mock_batch_response(translations) + + config = TranslationConfig(model="cache-model") + first_translator = LMStudioTranslator(config, client=_mock_client(handler)) + second_translator = LMStudioTranslator(config, client=_mock_client(handler)) + + assert first_translator.translate_segments(["hello", "world"], "fr", "en") == ["cached::hello", "cached::world"] + assert second_translator.translate_segments(["hello", "world"], "fr", "en") == ["cached::hello", "cached::world"] + assert requests["count"] == 1 + + def test_retry_on_transient_http_error_then_succeeds(): attempts = {"count": 0} diff --git a/tests/test_tts_retry.py b/tests/test_tts_retry.py index 250e485..af3f1cd 100644 --- a/tests/test_tts_retry.py +++ b/tests/test_tts_retry.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio +from src import engines from src.engines import Engine from src.translation import TranslationConfig @@ -46,3 +47,39 @@ def test_synthesize_retries_transient_edge_tts_failure(tmp_path, monkeypatch): assert calls["count"] == 2 assert out_path.exists() assert out_path.stat().st_size == 2048 + + +def test_synthesize_uses_tts_cache(tmp_path, monkeypatch): + calls = {"count": 0} + cache_dir = tmp_path / "tts-cache" + monkeypatch.setattr(engines, "TTS_CACHE_DIR", cache_dir) + + class FakeCommunicate: + def __init__(self, text, voice, rate): + self.text = text + self.voice = voice + self.rate = rate + + async def save(self, out_path): + calls["count"] += 1 + with open(out_path, "wb") as audio_file: + audio_file.write(b"1" * 2048) + + monkeypatch.setattr("src.engines.edge_tts.Communicate", FakeCommunicate) + + engine = Engine( + "cpu", + translation_config=TranslationConfig( + base_url="http://127.0.0.1:1234/v1", + api_key="test-key", + model="test-model", + ), + ) + first_out = tmp_path / "first.mp3" + second_out = tmp_path / "second.mp3" + + asyncio.run(engine.synthesize("Bonjour", "fr", first_out)) + asyncio.run(engine.synthesize("Bonjour", "fr", second_out)) + + assert calls["count"] == 1 + assert first_out.read_bytes() == second_out.read_bytes()