Add translation and TTS caching

This commit is contained in:
2026-05-24 16:49:21 +01:00
parent 803f532ff3
commit 9fbb7c1756
5 changed files with 175 additions and 18 deletions

View File

@@ -82,6 +82,8 @@ You can also upload a local `.mp4` instead of entering a YouTube URL. Uploaded v
The web UI automatically refreshes job status, progress, steps, and output choices every few seconds while it is open. The manual **Refresh** button is still available.
Translations and raw TTS clips are cached under `.cache/translations` and `.cache/tts`. This lets reruns skip work that already succeeded, which is especially useful after transient TTS failures. Set `TRANSLATION_CACHE_ENABLED=0` or `TTS_CACHE_ENABLED=0` to disable those caches.
### Docker
Build and run the Gradio UI in a container:

View File

@@ -16,8 +16,10 @@ import torch
import asyncio
import edge_tts
import gc
import hashlib
import json
import os
import shutil
from abc import ABC
import numpy as np
from pathlib import Path
@@ -42,6 +44,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent
CACHE_DIR = BASE_DIR / ".cache"
OUTPUT_DIR = BASE_DIR / "output"
TEMP_DIR = BASE_DIR / "temp"
TTS_CACHE_DIR = CACHE_DIR / "tts"
# Configuration files
LANG_MAP_FILE = BASE_DIR / "language_map.json"
@@ -56,6 +59,23 @@ AUDIO_CHANNELS = 1
DEFAULT_TTS_MAX_RETRIES = int(os.getenv("TTS_MAX_RETRIES", "4"))
DEFAULT_TTS_RETRY_BACKOFF_SECONDS = float(os.getenv("TTS_RETRY_BACKOFF_SECONDS", "2.0"))
def _cache_enabled(env_name: str) -> bool:
return (os.getenv(env_name, "1") or "").strip().lower() not in {"0", "false", "no", "off"}
def _tts_cache_key(text: str, target_lang: str, voice: str, rate: str) -> str:
payload = {
"version": "edge-tts-v1",
"text": text,
"target_lang": target_lang,
"voice": voice,
"rate": rate,
"sample_rate": SAMPLE_RATE,
}
serialized = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
def _select_optimal_whisper_model(device: str = "cpu") -> str:
"""Select optimal Whisper model based on available VRAM and device.
@@ -495,6 +515,12 @@ class Engine(PipelineComponent):
lang_cfg = self._getLangConfig(target_lang)
voice_pool = self.config_manager.getVoicePool(target_lang, gender)
voice = voice_pool[0] if voice_pool else DEFAULT_VOICE
cache_path = TTS_CACHE_DIR / f"{_tts_cache_key(text, target_lang, voice, rate)}.mp3"
if _cache_enabled("TTS_CACHE_ENABLED") and cache_path.exists() and cache_path.stat().st_size >= 1024:
print(f"[*] TTS cache hit: {cache_path.name}")
shutil.copyfile(cache_path, out_path)
return
try:
communicate = edge_tts.Communicate(text, voice=voice, rate=rate)
@@ -503,6 +529,10 @@ class Engine(PipelineComponent):
if not out_path.exists() or out_path.stat().st_size < 1024:
raise RuntimeError("TTS file invalid")
if _cache_enabled("TTS_CACHE_ENABLED"):
cache_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(out_path, cache_path)
return
except Exception as exc:
last_error = exc

View File

@@ -2,10 +2,12 @@
from __future__ import annotations
import hashlib
import json
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
@@ -20,6 +22,9 @@ DEFAULT_TRANSLATION_BACKEND = "lmstudio"
DEFAULT_CONTEXTUAL_BATCH_SIZE = 5
MIN_CONTEXTUAL_BATCH_SIZE = 3
DEFAULT_CONTEXT_SEGMENTS = 2
PROMPT_VERSION = "gpt54-dub-v2"
BASE_DIR = Path(__file__).resolve().parent.parent
TRANSLATION_CACHE_DIR = BASE_DIR / ".cache" / "translations"
def _normalize_base_url(base_url: str) -> str:
@@ -125,26 +130,56 @@ def _build_system_prompt(source_language: str, target_language: str) -> str:
def _build_contextual_system_prompt(source_language: str, target_language: str) -> str:
source_descriptor = source_language or "auto"
return (
"You are an expert audiovisual translator for dubbed video content.\n\n"
"You are an expert audiovisual translator and dubbing script adapter.\n\n"
f"Translate subtitle segments from {source_descriptor} to {target_language} for natural spoken dubbing.\n\n"
"Rules:\n"
"- Preserve meaning, intent, tone, and subtext.\n"
"- Use surrounding subtitle context to resolve ambiguity.\n"
"- Do not summarize.\n"
"- Do not simplify unless needed for natural speech.\n"
"- Do not add explanations, notes, or commentary.\n"
"- Preserve humor, sarcasm, emotional tone, and register.\n"
"- Keep names, brands, URLs, emails, file paths, code, and product names unchanged unless transliteration is clearly needed.\n"
"- Keep the translation natural for spoken dubbing.\n"
"- Preserve segment boundaries exactly.\n"
"Primary objective:\n"
"- Produce faithful, idiomatic spoken lines that can be read aloud naturally in the target language.\n"
"- Preserve intent, tone, emotion, register, humor, sarcasm, hesitation, and subtext.\n"
"- Use the provided previous and next segments only as context; translate only the current segments.\n\n"
"Dubbing adaptation rules:\n"
"- Prefer natural speech over literal word-for-word phrasing when the literal version sounds stiff.\n"
"- Keep each translated segment close to the source segment length when possible, because it will be timed to video.\n"
"- Do not add new claims, soften meaning, moralize, censor, summarize, or omit content.\n"
"- Preserve speaker deixis and continuity across adjacent segments.\n"
"- Keep names, brands, URLs, emails, file paths, code, product names, and quoted UI text unchanged unless transliteration is clearly required.\n"
"- Preserve numbers, units, dates, and technical terms accurately.\n"
"- If a phrase is slang, idiom, or a joke, translate the effect rather than the literal wording.\n\n"
"Output contract:\n"
"- Return valid JSON only, with no markdown fences or commentary.\n"
"- Return exactly one translated item per input segment.\n"
"- Output order must match input order exactly.\n"
"- If an input segment is empty or whitespace-only, return an empty translated_text for that id.\n"
"- Output valid JSON only.\n"
"- Preserve segment ids and output order exactly.\n"
"- Preserve empty or whitespace-only segments as an empty translated_text.\n"
"- Do not include previous_segments or next_segments in the output.\n"
'- Return this exact schema: {"translations":[{"id":"...","translated_text":"..."}]}.'
)
def _cache_enabled(env_name: str) -> bool:
return (os.getenv(env_name, "1") or "").strip().lower() not in {"0", "false", "no", "off"}
def _json_cache_key(payload: Dict[str, Any]) -> str:
serialized = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
def _read_json_cache(cache_path: Path) -> Optional[Dict[str, Any]]:
if not cache_path.exists():
return None
try:
payload = json.loads(cache_path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
return None
return payload if isinstance(payload, dict) else None
def _write_json_cache(cache_path: Path, payload: Dict[str, Any]) -> None:
cache_path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = cache_path.with_suffix(".tmp")
tmp_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
tmp_path.replace(cache_path)
@dataclass(frozen=True)
class TranslationSegment:
"""A subtitle segment prepared for contextual batch translation."""
@@ -183,6 +218,7 @@ class LMStudioTranslator:
self._client = client or httpx.Client(timeout=httpx.Timeout(self.config.timeout_seconds))
self._owns_client = client is None
self._sleeper = sleeper
self._cache_enabled = _cache_enabled("TRANSLATION_CACHE_ENABLED")
@staticmethod
def _generation_settings() -> Dict[str, Any]:
@@ -551,19 +587,43 @@ class LMStudioTranslator:
) -> List[str]:
"""Translate a single contextual subtitle batch with validation and retries."""
payload = self.build_contextual_batch_payload(batch, source_language, target_language)
cache_payload = {
"prompt_version": PROMPT_VERSION,
"backend": self.config.backend,
"base_url": self.config.base_url,
"model": self.config.model,
"source_language": source_language or "auto",
"target_language": target_language,
"request": self.build_contextual_batch_request(batch, source_language, target_language),
}
cache_path = TRANSLATION_CACHE_DIR / f"{_json_cache_key(cache_payload)}.json"
if self._cache_enabled:
cached_payload = _read_json_cache(cache_path)
cached_translations = cached_payload.get("translations") if cached_payload else None
if isinstance(cached_translations, list) and all(isinstance(item, str) for item in cached_translations):
if len(cached_translations) == len(batch.segments):
print(f"[*] Translation cache hit: {cache_path.name}")
return cached_translations
last_error: Optional[Exception] = None
for attempt in range(1, self.config.max_retries + 1):
try:
response_content = self._post_chat_completion(payload)
return self.parse_batch_translation_response(response_content, batch)
translations = self.parse_batch_translation_response(response_content, batch)
if self._cache_enabled:
_write_json_cache(cache_path, {"translations": translations})
return translations
except (httpx.HTTPError, ValueError, TranslationError) as exc:
last_error = exc
if self._should_retry_with_user_only_prompt(exc):
try:
fallback_payload = self.build_contextual_user_only_payload(batch, source_language, target_language)
fallback_content = self._post_chat_completion(fallback_payload)
return self.parse_batch_translation_response(fallback_content, batch)
translations = self.parse_batch_translation_response(fallback_content, batch)
if self._cache_enabled:
_write_json_cache(cache_path, {"translations": translations})
return translations
except (httpx.HTTPError, ValueError, TranslationError) as fallback_exc:
last_error = fallback_exc
if self._should_retry_with_structured_translation_prompt(last_error):
@@ -574,7 +634,10 @@ class LMStudioTranslator:
target_language,
)
structured_content = self._post_chat_completion(structured_payload)
return self.parse_batch_translation_response(structured_content, batch)
translations = self.parse_batch_translation_response(structured_content, batch)
if self._cache_enabled:
_write_json_cache(cache_path, {"translations": translations})
return translations
except (httpx.HTTPError, ValueError, TranslationError) as structured_exc:
last_error = structured_exc

View File

@@ -8,6 +8,7 @@ import httpx
import pytest
from src.core_utils import TranslationError
from src import translation
from src.translation import LMStudioTranslator, TranslationConfig
@@ -56,7 +57,8 @@ def test_build_contextual_batch_payload_includes_neighboring_segments():
assert payload["model"] == "gemma-3-4b-it"
assert payload["messages"][0]["role"] == "system"
assert "expert audiovisual translator for dubbed video content" in payload["messages"][0]["content"]
assert "expert audiovisual translator and dubbing script adapter" in payload["messages"][0]["content"]
assert "Preserve segment ids and output order exactly" in payload["messages"][0]["content"]
assert payload["temperature"] == 0.0
assert payload["top_p"] == 1.0
assert user_payload == {
@@ -116,6 +118,29 @@ def test_translate_segments_batches_context_and_preserves_exact_mapping():
assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"]
def test_translate_segments_uses_persistent_cache(tmp_path, monkeypatch):
requests = {"count": 0}
monkeypatch.setattr(translation, "TRANSLATION_CACHE_DIR", tmp_path)
def handler(request: httpx.Request) -> httpx.Response:
requests["count"] += 1
payload = _read_request_json(request)
batch_request = json.loads(payload["messages"][1]["content"])
translations = [
{"id": item["id"], "translated_text": f"cached::{item['text']}"}
for item in batch_request["segments"]
]
return _mock_batch_response(translations)
config = TranslationConfig(model="cache-model")
first_translator = LMStudioTranslator(config, client=_mock_client(handler))
second_translator = LMStudioTranslator(config, client=_mock_client(handler))
assert first_translator.translate_segments(["hello", "world"], "fr", "en") == ["cached::hello", "cached::world"]
assert second_translator.translate_segments(["hello", "world"], "fr", "en") == ["cached::hello", "cached::world"]
assert requests["count"] == 1
def test_retry_on_transient_http_error_then_succeeds():
attempts = {"count": 0}

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from src import engines
from src.engines import Engine
from src.translation import TranslationConfig
@@ -46,3 +47,39 @@ def test_synthesize_retries_transient_edge_tts_failure(tmp_path, monkeypatch):
assert calls["count"] == 2
assert out_path.exists()
assert out_path.stat().st_size == 2048
def test_synthesize_uses_tts_cache(tmp_path, monkeypatch):
calls = {"count": 0}
cache_dir = tmp_path / "tts-cache"
monkeypatch.setattr(engines, "TTS_CACHE_DIR", cache_dir)
class FakeCommunicate:
def __init__(self, text, voice, rate):
self.text = text
self.voice = voice
self.rate = rate
async def save(self, out_path):
calls["count"] += 1
with open(out_path, "wb") as audio_file:
audio_file.write(b"1" * 2048)
monkeypatch.setattr("src.engines.edge_tts.Communicate", FakeCommunicate)
engine = Engine(
"cpu",
translation_config=TranslationConfig(
base_url="http://127.0.0.1:1234/v1",
api_key="test-key",
model="test-model",
),
)
first_out = tmp_path / "first.mp3"
second_out = tmp_path / "second.mp3"
asyncio.run(engine.synthesize("Bonjour", "fr", first_out))
asyncio.run(engine.synthesize("Bonjour", "fr", second_out))
assert calls["count"] == 1
assert first_out.read_bytes() == second_out.read_bytes()