Add translation and TTS caching

This commit is contained in:
2026-05-24 16:49:21 +01:00
parent 803f532ff3
commit 9fbb7c1756
5 changed files with 175 additions and 18 deletions

View File

@@ -82,6 +82,8 @@ You can also upload a local `.mp4` instead of entering a YouTube URL. Uploaded v
The web UI automatically refreshes job status, progress, steps, and output choices every few seconds while it is open. The manual **Refresh** button is still available. The web UI automatically refreshes job status, progress, steps, and output choices every few seconds while it is open. The manual **Refresh** button is still available.
Translations and raw TTS clips are cached under `.cache/translations` and `.cache/tts`. This lets reruns skip work that already succeeded, which is especially useful after transient TTS failures. Set `TRANSLATION_CACHE_ENABLED=0` or `TTS_CACHE_ENABLED=0` to disable those caches.
### Docker ### Docker
Build and run the Gradio UI in a container: Build and run the Gradio UI in a container:

View File

@@ -16,8 +16,10 @@ import torch
import asyncio import asyncio
import edge_tts import edge_tts
import gc import gc
import hashlib
import json import json
import os import os
import shutil
from abc import ABC from abc import ABC
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
@@ -42,6 +44,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent
CACHE_DIR = BASE_DIR / ".cache" CACHE_DIR = BASE_DIR / ".cache"
OUTPUT_DIR = BASE_DIR / "output" OUTPUT_DIR = BASE_DIR / "output"
TEMP_DIR = BASE_DIR / "temp" TEMP_DIR = BASE_DIR / "temp"
TTS_CACHE_DIR = CACHE_DIR / "tts"
# Configuration files # Configuration files
LANG_MAP_FILE = BASE_DIR / "language_map.json" LANG_MAP_FILE = BASE_DIR / "language_map.json"
@@ -56,6 +59,23 @@ AUDIO_CHANNELS = 1
DEFAULT_TTS_MAX_RETRIES = int(os.getenv("TTS_MAX_RETRIES", "4")) DEFAULT_TTS_MAX_RETRIES = int(os.getenv("TTS_MAX_RETRIES", "4"))
DEFAULT_TTS_RETRY_BACKOFF_SECONDS = float(os.getenv("TTS_RETRY_BACKOFF_SECONDS", "2.0")) DEFAULT_TTS_RETRY_BACKOFF_SECONDS = float(os.getenv("TTS_RETRY_BACKOFF_SECONDS", "2.0"))
def _cache_enabled(env_name: str) -> bool:
return (os.getenv(env_name, "1") or "").strip().lower() not in {"0", "false", "no", "off"}
def _tts_cache_key(text: str, target_lang: str, voice: str, rate: str) -> str:
payload = {
"version": "edge-tts-v1",
"text": text,
"target_lang": target_lang,
"voice": voice,
"rate": rate,
"sample_rate": SAMPLE_RATE,
}
serialized = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
def _select_optimal_whisper_model(device: str = "cpu") -> str: def _select_optimal_whisper_model(device: str = "cpu") -> str:
"""Select optimal Whisper model based on available VRAM and device. """Select optimal Whisper model based on available VRAM and device.
@@ -495,6 +515,12 @@ class Engine(PipelineComponent):
lang_cfg = self._getLangConfig(target_lang) lang_cfg = self._getLangConfig(target_lang)
voice_pool = self.config_manager.getVoicePool(target_lang, gender) voice_pool = self.config_manager.getVoicePool(target_lang, gender)
voice = voice_pool[0] if voice_pool else DEFAULT_VOICE voice = voice_pool[0] if voice_pool else DEFAULT_VOICE
cache_path = TTS_CACHE_DIR / f"{_tts_cache_key(text, target_lang, voice, rate)}.mp3"
if _cache_enabled("TTS_CACHE_ENABLED") and cache_path.exists() and cache_path.stat().st_size >= 1024:
print(f"[*] TTS cache hit: {cache_path.name}")
shutil.copyfile(cache_path, out_path)
return
try: try:
communicate = edge_tts.Communicate(text, voice=voice, rate=rate) communicate = edge_tts.Communicate(text, voice=voice, rate=rate)
@@ -503,6 +529,10 @@ class Engine(PipelineComponent):
if not out_path.exists() or out_path.stat().st_size < 1024: if not out_path.exists() or out_path.stat().st_size < 1024:
raise RuntimeError("TTS file invalid") raise RuntimeError("TTS file invalid")
if _cache_enabled("TTS_CACHE_ENABLED"):
cache_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(out_path, cache_path)
return return
except Exception as exc: except Exception as exc:
last_error = exc last_error = exc

View File

@@ -2,10 +2,12 @@
from __future__ import annotations from __future__ import annotations
import hashlib
import json import json
import os import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -20,6 +22,9 @@ DEFAULT_TRANSLATION_BACKEND = "lmstudio"
DEFAULT_CONTEXTUAL_BATCH_SIZE = 5 DEFAULT_CONTEXTUAL_BATCH_SIZE = 5
MIN_CONTEXTUAL_BATCH_SIZE = 3 MIN_CONTEXTUAL_BATCH_SIZE = 3
DEFAULT_CONTEXT_SEGMENTS = 2 DEFAULT_CONTEXT_SEGMENTS = 2
PROMPT_VERSION = "gpt54-dub-v2"
BASE_DIR = Path(__file__).resolve().parent.parent
TRANSLATION_CACHE_DIR = BASE_DIR / ".cache" / "translations"
def _normalize_base_url(base_url: str) -> str: def _normalize_base_url(base_url: str) -> str:
@@ -125,26 +130,56 @@ def _build_system_prompt(source_language: str, target_language: str) -> str:
def _build_contextual_system_prompt(source_language: str, target_language: str) -> str: def _build_contextual_system_prompt(source_language: str, target_language: str) -> str:
source_descriptor = source_language or "auto" source_descriptor = source_language or "auto"
return ( return (
"You are an expert audiovisual translator for dubbed video content.\n\n" "You are an expert audiovisual translator and dubbing script adapter.\n\n"
f"Translate subtitle segments from {source_descriptor} to {target_language} for natural spoken dubbing.\n\n" f"Translate subtitle segments from {source_descriptor} to {target_language} for natural spoken dubbing.\n\n"
"Rules:\n" "Primary objective:\n"
"- Preserve meaning, intent, tone, and subtext.\n" "- Produce faithful, idiomatic spoken lines that can be read aloud naturally in the target language.\n"
"- Use surrounding subtitle context to resolve ambiguity.\n" "- Preserve intent, tone, emotion, register, humor, sarcasm, hesitation, and subtext.\n"
"- Do not summarize.\n" "- Use the provided previous and next segments only as context; translate only the current segments.\n\n"
"- Do not simplify unless needed for natural speech.\n" "Dubbing adaptation rules:\n"
"- Do not add explanations, notes, or commentary.\n" "- Prefer natural speech over literal word-for-word phrasing when the literal version sounds stiff.\n"
"- Preserve humor, sarcasm, emotional tone, and register.\n" "- Keep each translated segment close to the source segment length when possible, because it will be timed to video.\n"
"- Keep names, brands, URLs, emails, file paths, code, and product names unchanged unless transliteration is clearly needed.\n" "- Do not add new claims, soften meaning, moralize, censor, summarize, or omit content.\n"
"- Keep the translation natural for spoken dubbing.\n" "- Preserve speaker deixis and continuity across adjacent segments.\n"
"- Preserve segment boundaries exactly.\n" "- Keep names, brands, URLs, emails, file paths, code, product names, and quoted UI text unchanged unless transliteration is clearly required.\n"
"- Preserve numbers, units, dates, and technical terms accurately.\n"
"- If a phrase is slang, idiom, or a joke, translate the effect rather than the literal wording.\n\n"
"Output contract:\n"
"- Return valid JSON only, with no markdown fences or commentary.\n"
"- Return exactly one translated item per input segment.\n" "- Return exactly one translated item per input segment.\n"
"- Output order must match input order exactly.\n" "- Preserve segment ids and output order exactly.\n"
"- If an input segment is empty or whitespace-only, return an empty translated_text for that id.\n" "- Preserve empty or whitespace-only segments as an empty translated_text.\n"
"- Output valid JSON only.\n" "- Do not include previous_segments or next_segments in the output.\n"
'- Return this exact schema: {"translations":[{"id":"...","translated_text":"..."}]}.' '- Return this exact schema: {"translations":[{"id":"...","translated_text":"..."}]}.'
) )
def _cache_enabled(env_name: str) -> bool:
return (os.getenv(env_name, "1") or "").strip().lower() not in {"0", "false", "no", "off"}
def _json_cache_key(payload: Dict[str, Any]) -> str:
serialized = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
def _read_json_cache(cache_path: Path) -> Optional[Dict[str, Any]]:
if not cache_path.exists():
return None
try:
payload = json.loads(cache_path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
return None
return payload if isinstance(payload, dict) else None
def _write_json_cache(cache_path: Path, payload: Dict[str, Any]) -> None:
cache_path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = cache_path.with_suffix(".tmp")
tmp_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
tmp_path.replace(cache_path)
@dataclass(frozen=True) @dataclass(frozen=True)
class TranslationSegment: class TranslationSegment:
"""A subtitle segment prepared for contextual batch translation.""" """A subtitle segment prepared for contextual batch translation."""
@@ -183,6 +218,7 @@ class LMStudioTranslator:
self._client = client or httpx.Client(timeout=httpx.Timeout(self.config.timeout_seconds)) self._client = client or httpx.Client(timeout=httpx.Timeout(self.config.timeout_seconds))
self._owns_client = client is None self._owns_client = client is None
self._sleeper = sleeper self._sleeper = sleeper
self._cache_enabled = _cache_enabled("TRANSLATION_CACHE_ENABLED")
@staticmethod @staticmethod
def _generation_settings() -> Dict[str, Any]: def _generation_settings() -> Dict[str, Any]:
@@ -551,19 +587,43 @@ class LMStudioTranslator:
) -> List[str]: ) -> List[str]:
"""Translate a single contextual subtitle batch with validation and retries.""" """Translate a single contextual subtitle batch with validation and retries."""
payload = self.build_contextual_batch_payload(batch, source_language, target_language) payload = self.build_contextual_batch_payload(batch, source_language, target_language)
cache_payload = {
"prompt_version": PROMPT_VERSION,
"backend": self.config.backend,
"base_url": self.config.base_url,
"model": self.config.model,
"source_language": source_language or "auto",
"target_language": target_language,
"request": self.build_contextual_batch_request(batch, source_language, target_language),
}
cache_path = TRANSLATION_CACHE_DIR / f"{_json_cache_key(cache_payload)}.json"
if self._cache_enabled:
cached_payload = _read_json_cache(cache_path)
cached_translations = cached_payload.get("translations") if cached_payload else None
if isinstance(cached_translations, list) and all(isinstance(item, str) for item in cached_translations):
if len(cached_translations) == len(batch.segments):
print(f"[*] Translation cache hit: {cache_path.name}")
return cached_translations
last_error: Optional[Exception] = None last_error: Optional[Exception] = None
for attempt in range(1, self.config.max_retries + 1): for attempt in range(1, self.config.max_retries + 1):
try: try:
response_content = self._post_chat_completion(payload) response_content = self._post_chat_completion(payload)
return self.parse_batch_translation_response(response_content, batch) translations = self.parse_batch_translation_response(response_content, batch)
if self._cache_enabled:
_write_json_cache(cache_path, {"translations": translations})
return translations
except (httpx.HTTPError, ValueError, TranslationError) as exc: except (httpx.HTTPError, ValueError, TranslationError) as exc:
last_error = exc last_error = exc
if self._should_retry_with_user_only_prompt(exc): if self._should_retry_with_user_only_prompt(exc):
try: try:
fallback_payload = self.build_contextual_user_only_payload(batch, source_language, target_language) fallback_payload = self.build_contextual_user_only_payload(batch, source_language, target_language)
fallback_content = self._post_chat_completion(fallback_payload) fallback_content = self._post_chat_completion(fallback_payload)
return self.parse_batch_translation_response(fallback_content, batch) translations = self.parse_batch_translation_response(fallback_content, batch)
if self._cache_enabled:
_write_json_cache(cache_path, {"translations": translations})
return translations
except (httpx.HTTPError, ValueError, TranslationError) as fallback_exc: except (httpx.HTTPError, ValueError, TranslationError) as fallback_exc:
last_error = fallback_exc last_error = fallback_exc
if self._should_retry_with_structured_translation_prompt(last_error): if self._should_retry_with_structured_translation_prompt(last_error):
@@ -574,7 +634,10 @@ class LMStudioTranslator:
target_language, target_language,
) )
structured_content = self._post_chat_completion(structured_payload) structured_content = self._post_chat_completion(structured_payload)
return self.parse_batch_translation_response(structured_content, batch) translations = self.parse_batch_translation_response(structured_content, batch)
if self._cache_enabled:
_write_json_cache(cache_path, {"translations": translations})
return translations
except (httpx.HTTPError, ValueError, TranslationError) as structured_exc: except (httpx.HTTPError, ValueError, TranslationError) as structured_exc:
last_error = structured_exc last_error = structured_exc

View File

@@ -8,6 +8,7 @@ import httpx
import pytest import pytest
from src.core_utils import TranslationError from src.core_utils import TranslationError
from src import translation
from src.translation import LMStudioTranslator, TranslationConfig from src.translation import LMStudioTranslator, TranslationConfig
@@ -56,7 +57,8 @@ def test_build_contextual_batch_payload_includes_neighboring_segments():
assert payload["model"] == "gemma-3-4b-it" assert payload["model"] == "gemma-3-4b-it"
assert payload["messages"][0]["role"] == "system" assert payload["messages"][0]["role"] == "system"
assert "expert audiovisual translator for dubbed video content" in payload["messages"][0]["content"] assert "expert audiovisual translator and dubbing script adapter" in payload["messages"][0]["content"]
assert "Preserve segment ids and output order exactly" in payload["messages"][0]["content"]
assert payload["temperature"] == 0.0 assert payload["temperature"] == 0.0
assert payload["top_p"] == 1.0 assert payload["top_p"] == 1.0
assert user_payload == { assert user_payload == {
@@ -116,6 +118,29 @@ def test_translate_segments_batches_context_and_preserves_exact_mapping():
assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"] assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"]
def test_translate_segments_uses_persistent_cache(tmp_path, monkeypatch):
requests = {"count": 0}
monkeypatch.setattr(translation, "TRANSLATION_CACHE_DIR", tmp_path)
def handler(request: httpx.Request) -> httpx.Response:
requests["count"] += 1
payload = _read_request_json(request)
batch_request = json.loads(payload["messages"][1]["content"])
translations = [
{"id": item["id"], "translated_text": f"cached::{item['text']}"}
for item in batch_request["segments"]
]
return _mock_batch_response(translations)
config = TranslationConfig(model="cache-model")
first_translator = LMStudioTranslator(config, client=_mock_client(handler))
second_translator = LMStudioTranslator(config, client=_mock_client(handler))
assert first_translator.translate_segments(["hello", "world"], "fr", "en") == ["cached::hello", "cached::world"]
assert second_translator.translate_segments(["hello", "world"], "fr", "en") == ["cached::hello", "cached::world"]
assert requests["count"] == 1
def test_retry_on_transient_http_error_then_succeeds(): def test_retry_on_transient_http_error_then_succeeds():
attempts = {"count": 0} attempts = {"count": 0}

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from src import engines
from src.engines import Engine from src.engines import Engine
from src.translation import TranslationConfig from src.translation import TranslationConfig
@@ -46,3 +47,39 @@ def test_synthesize_retries_transient_edge_tts_failure(tmp_path, monkeypatch):
assert calls["count"] == 2 assert calls["count"] == 2
assert out_path.exists() assert out_path.exists()
assert out_path.stat().st_size == 2048 assert out_path.stat().st_size == 2048
def test_synthesize_uses_tts_cache(tmp_path, monkeypatch):
calls = {"count": 0}
cache_dir = tmp_path / "tts-cache"
monkeypatch.setattr(engines, "TTS_CACHE_DIR", cache_dir)
class FakeCommunicate:
def __init__(self, text, voice, rate):
self.text = text
self.voice = voice
self.rate = rate
async def save(self, out_path):
calls["count"] += 1
with open(out_path, "wb") as audio_file:
audio_file.write(b"1" * 2048)
monkeypatch.setattr("src.engines.edge_tts.Communicate", FakeCommunicate)
engine = Engine(
"cpu",
translation_config=TranslationConfig(
base_url="http://127.0.0.1:1234/v1",
api_key="test-key",
model="test-model",
),
)
first_out = tmp_path / "first.mp3"
second_out = tmp_path / "second.mp3"
asyncio.run(engine.synthesize("Bonjour", "fr", first_out))
asyncio.run(engine.synthesize("Bonjour", "fr", second_out))
assert calls["count"] == 1
assert first_out.read_bytes() == second_out.read_bytes()