Add translation and TTS caching
This commit is contained in:
@@ -82,6 +82,8 @@ You can also upload a local `.mp4` instead of entering a YouTube URL. Uploaded v
|
||||
|
||||
The web UI automatically refreshes job status, progress, steps, and output choices every few seconds while it is open. The manual **Refresh** button is still available.
|
||||
|
||||
Translations and raw TTS clips are cached under `.cache/translations` and `.cache/tts`. This lets reruns skip work that already succeeded, which is especially useful after transient TTS failures. Set `TRANSLATION_CACHE_ENABLED=0` or `TTS_CACHE_ENABLED=0` to disable those caches.
|
||||
|
||||
### Docker
|
||||
|
||||
Build and run the Gradio UI in a container:
|
||||
|
||||
@@ -16,8 +16,10 @@ import torch
|
||||
import asyncio
|
||||
import edge_tts
|
||||
import gc
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from abc import ABC
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
@@ -42,6 +44,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
CACHE_DIR = BASE_DIR / ".cache"
|
||||
OUTPUT_DIR = BASE_DIR / "output"
|
||||
TEMP_DIR = BASE_DIR / "temp"
|
||||
TTS_CACHE_DIR = CACHE_DIR / "tts"
|
||||
|
||||
# Configuration files
|
||||
LANG_MAP_FILE = BASE_DIR / "language_map.json"
|
||||
@@ -56,6 +59,23 @@ AUDIO_CHANNELS = 1
|
||||
DEFAULT_TTS_MAX_RETRIES = int(os.getenv("TTS_MAX_RETRIES", "4"))
|
||||
DEFAULT_TTS_RETRY_BACKOFF_SECONDS = float(os.getenv("TTS_RETRY_BACKOFF_SECONDS", "2.0"))
|
||||
|
||||
|
||||
def _cache_enabled(env_name: str) -> bool:
|
||||
return (os.getenv(env_name, "1") or "").strip().lower() not in {"0", "false", "no", "off"}
|
||||
|
||||
|
||||
def _tts_cache_key(text: str, target_lang: str, voice: str, rate: str) -> str:
|
||||
payload = {
|
||||
"version": "edge-tts-v1",
|
||||
"text": text,
|
||||
"target_lang": target_lang,
|
||||
"voice": voice,
|
||||
"rate": rate,
|
||||
"sample_rate": SAMPLE_RATE,
|
||||
}
|
||||
serialized = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
|
||||
|
||||
def _select_optimal_whisper_model(device: str = "cpu") -> str:
|
||||
"""Select optimal Whisper model based on available VRAM and device.
|
||||
|
||||
@@ -495,6 +515,12 @@ class Engine(PipelineComponent):
|
||||
lang_cfg = self._getLangConfig(target_lang)
|
||||
voice_pool = self.config_manager.getVoicePool(target_lang, gender)
|
||||
voice = voice_pool[0] if voice_pool else DEFAULT_VOICE
|
||||
cache_path = TTS_CACHE_DIR / f"{_tts_cache_key(text, target_lang, voice, rate)}.mp3"
|
||||
|
||||
if _cache_enabled("TTS_CACHE_ENABLED") and cache_path.exists() and cache_path.stat().st_size >= 1024:
|
||||
print(f"[*] TTS cache hit: {cache_path.name}")
|
||||
shutil.copyfile(cache_path, out_path)
|
||||
return
|
||||
|
||||
try:
|
||||
communicate = edge_tts.Communicate(text, voice=voice, rate=rate)
|
||||
@@ -503,6 +529,10 @@ class Engine(PipelineComponent):
|
||||
if not out_path.exists() or out_path.stat().st_size < 1024:
|
||||
raise RuntimeError("TTS file invalid")
|
||||
|
||||
if _cache_enabled("TTS_CACHE_ENABLED"):
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copyfile(out_path, cache_path)
|
||||
|
||||
return
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -20,6 +22,9 @@ DEFAULT_TRANSLATION_BACKEND = "lmstudio"
|
||||
DEFAULT_CONTEXTUAL_BATCH_SIZE = 5
|
||||
MIN_CONTEXTUAL_BATCH_SIZE = 3
|
||||
DEFAULT_CONTEXT_SEGMENTS = 2
|
||||
PROMPT_VERSION = "gpt54-dub-v2"
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
TRANSLATION_CACHE_DIR = BASE_DIR / ".cache" / "translations"
|
||||
|
||||
|
||||
def _normalize_base_url(base_url: str) -> str:
|
||||
@@ -125,26 +130,56 @@ def _build_system_prompt(source_language: str, target_language: str) -> str:
|
||||
def _build_contextual_system_prompt(source_language: str, target_language: str) -> str:
|
||||
source_descriptor = source_language or "auto"
|
||||
return (
|
||||
"You are an expert audiovisual translator for dubbed video content.\n\n"
|
||||
"You are an expert audiovisual translator and dubbing script adapter.\n\n"
|
||||
f"Translate subtitle segments from {source_descriptor} to {target_language} for natural spoken dubbing.\n\n"
|
||||
"Rules:\n"
|
||||
"- Preserve meaning, intent, tone, and subtext.\n"
|
||||
"- Use surrounding subtitle context to resolve ambiguity.\n"
|
||||
"- Do not summarize.\n"
|
||||
"- Do not simplify unless needed for natural speech.\n"
|
||||
"- Do not add explanations, notes, or commentary.\n"
|
||||
"- Preserve humor, sarcasm, emotional tone, and register.\n"
|
||||
"- Keep names, brands, URLs, emails, file paths, code, and product names unchanged unless transliteration is clearly needed.\n"
|
||||
"- Keep the translation natural for spoken dubbing.\n"
|
||||
"- Preserve segment boundaries exactly.\n"
|
||||
"Primary objective:\n"
|
||||
"- Produce faithful, idiomatic spoken lines that can be read aloud naturally in the target language.\n"
|
||||
"- Preserve intent, tone, emotion, register, humor, sarcasm, hesitation, and subtext.\n"
|
||||
"- Use the provided previous and next segments only as context; translate only the current segments.\n\n"
|
||||
"Dubbing adaptation rules:\n"
|
||||
"- Prefer natural speech over literal word-for-word phrasing when the literal version sounds stiff.\n"
|
||||
"- Keep each translated segment close to the source segment length when possible, because it will be timed to video.\n"
|
||||
"- Do not add new claims, soften meaning, moralize, censor, summarize, or omit content.\n"
|
||||
"- Preserve speaker deixis and continuity across adjacent segments.\n"
|
||||
"- Keep names, brands, URLs, emails, file paths, code, product names, and quoted UI text unchanged unless transliteration is clearly required.\n"
|
||||
"- Preserve numbers, units, dates, and technical terms accurately.\n"
|
||||
"- If a phrase is slang, idiom, or a joke, translate the effect rather than the literal wording.\n\n"
|
||||
"Output contract:\n"
|
||||
"- Return valid JSON only, with no markdown fences or commentary.\n"
|
||||
"- Return exactly one translated item per input segment.\n"
|
||||
"- Output order must match input order exactly.\n"
|
||||
"- If an input segment is empty or whitespace-only, return an empty translated_text for that id.\n"
|
||||
"- Output valid JSON only.\n"
|
||||
"- Preserve segment ids and output order exactly.\n"
|
||||
"- Preserve empty or whitespace-only segments as an empty translated_text.\n"
|
||||
"- Do not include previous_segments or next_segments in the output.\n"
|
||||
'- Return this exact schema: {"translations":[{"id":"...","translated_text":"..."}]}.'
|
||||
)
|
||||
|
||||
|
||||
def _cache_enabled(env_name: str) -> bool:
|
||||
return (os.getenv(env_name, "1") or "").strip().lower() not in {"0", "false", "no", "off"}
|
||||
|
||||
|
||||
def _json_cache_key(payload: Dict[str, Any]) -> str:
|
||||
serialized = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _read_json_cache(cache_path: Path) -> Optional[Dict[str, Any]]:
|
||||
if not cache_path.exists():
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(cache_path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
return payload if isinstance(payload, dict) else None
|
||||
|
||||
|
||||
def _write_json_cache(cache_path: Path, payload: Dict[str, Any]) -> None:
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = cache_path.with_suffix(".tmp")
|
||||
tmp_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
tmp_path.replace(cache_path)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TranslationSegment:
|
||||
"""A subtitle segment prepared for contextual batch translation."""
|
||||
@@ -183,6 +218,7 @@ class LMStudioTranslator:
|
||||
self._client = client or httpx.Client(timeout=httpx.Timeout(self.config.timeout_seconds))
|
||||
self._owns_client = client is None
|
||||
self._sleeper = sleeper
|
||||
self._cache_enabled = _cache_enabled("TRANSLATION_CACHE_ENABLED")
|
||||
|
||||
@staticmethod
|
||||
def _generation_settings() -> Dict[str, Any]:
|
||||
@@ -551,19 +587,43 @@ class LMStudioTranslator:
|
||||
) -> List[str]:
|
||||
"""Translate a single contextual subtitle batch with validation and retries."""
|
||||
payload = self.build_contextual_batch_payload(batch, source_language, target_language)
|
||||
cache_payload = {
|
||||
"prompt_version": PROMPT_VERSION,
|
||||
"backend": self.config.backend,
|
||||
"base_url": self.config.base_url,
|
||||
"model": self.config.model,
|
||||
"source_language": source_language or "auto",
|
||||
"target_language": target_language,
|
||||
"request": self.build_contextual_batch_request(batch, source_language, target_language),
|
||||
}
|
||||
cache_path = TRANSLATION_CACHE_DIR / f"{_json_cache_key(cache_payload)}.json"
|
||||
if self._cache_enabled:
|
||||
cached_payload = _read_json_cache(cache_path)
|
||||
cached_translations = cached_payload.get("translations") if cached_payload else None
|
||||
if isinstance(cached_translations, list) and all(isinstance(item, str) for item in cached_translations):
|
||||
if len(cached_translations) == len(batch.segments):
|
||||
print(f"[*] Translation cache hit: {cache_path.name}")
|
||||
return cached_translations
|
||||
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for attempt in range(1, self.config.max_retries + 1):
|
||||
try:
|
||||
response_content = self._post_chat_completion(payload)
|
||||
return self.parse_batch_translation_response(response_content, batch)
|
||||
translations = self.parse_batch_translation_response(response_content, batch)
|
||||
if self._cache_enabled:
|
||||
_write_json_cache(cache_path, {"translations": translations})
|
||||
return translations
|
||||
except (httpx.HTTPError, ValueError, TranslationError) as exc:
|
||||
last_error = exc
|
||||
if self._should_retry_with_user_only_prompt(exc):
|
||||
try:
|
||||
fallback_payload = self.build_contextual_user_only_payload(batch, source_language, target_language)
|
||||
fallback_content = self._post_chat_completion(fallback_payload)
|
||||
return self.parse_batch_translation_response(fallback_content, batch)
|
||||
translations = self.parse_batch_translation_response(fallback_content, batch)
|
||||
if self._cache_enabled:
|
||||
_write_json_cache(cache_path, {"translations": translations})
|
||||
return translations
|
||||
except (httpx.HTTPError, ValueError, TranslationError) as fallback_exc:
|
||||
last_error = fallback_exc
|
||||
if self._should_retry_with_structured_translation_prompt(last_error):
|
||||
@@ -574,7 +634,10 @@ class LMStudioTranslator:
|
||||
target_language,
|
||||
)
|
||||
structured_content = self._post_chat_completion(structured_payload)
|
||||
return self.parse_batch_translation_response(structured_content, batch)
|
||||
translations = self.parse_batch_translation_response(structured_content, batch)
|
||||
if self._cache_enabled:
|
||||
_write_json_cache(cache_path, {"translations": translations})
|
||||
return translations
|
||||
except (httpx.HTTPError, ValueError, TranslationError) as structured_exc:
|
||||
last_error = structured_exc
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import httpx
|
||||
import pytest
|
||||
|
||||
from src.core_utils import TranslationError
|
||||
from src import translation
|
||||
from src.translation import LMStudioTranslator, TranslationConfig
|
||||
|
||||
|
||||
@@ -56,7 +57,8 @@ def test_build_contextual_batch_payload_includes_neighboring_segments():
|
||||
|
||||
assert payload["model"] == "gemma-3-4b-it"
|
||||
assert payload["messages"][0]["role"] == "system"
|
||||
assert "expert audiovisual translator for dubbed video content" in payload["messages"][0]["content"]
|
||||
assert "expert audiovisual translator and dubbing script adapter" in payload["messages"][0]["content"]
|
||||
assert "Preserve segment ids and output order exactly" in payload["messages"][0]["content"]
|
||||
assert payload["temperature"] == 0.0
|
||||
assert payload["top_p"] == 1.0
|
||||
assert user_payload == {
|
||||
@@ -116,6 +118,29 @@ def test_translate_segments_batches_context_and_preserves_exact_mapping():
|
||||
assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"]
|
||||
|
||||
|
||||
def test_translate_segments_uses_persistent_cache(tmp_path, monkeypatch):
|
||||
requests = {"count": 0}
|
||||
monkeypatch.setattr(translation, "TRANSLATION_CACHE_DIR", tmp_path)
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
requests["count"] += 1
|
||||
payload = _read_request_json(request)
|
||||
batch_request = json.loads(payload["messages"][1]["content"])
|
||||
translations = [
|
||||
{"id": item["id"], "translated_text": f"cached::{item['text']}"}
|
||||
for item in batch_request["segments"]
|
||||
]
|
||||
return _mock_batch_response(translations)
|
||||
|
||||
config = TranslationConfig(model="cache-model")
|
||||
first_translator = LMStudioTranslator(config, client=_mock_client(handler))
|
||||
second_translator = LMStudioTranslator(config, client=_mock_client(handler))
|
||||
|
||||
assert first_translator.translate_segments(["hello", "world"], "fr", "en") == ["cached::hello", "cached::world"]
|
||||
assert second_translator.translate_segments(["hello", "world"], "fr", "en") == ["cached::hello", "cached::world"]
|
||||
assert requests["count"] == 1
|
||||
|
||||
|
||||
def test_retry_on_transient_http_error_then_succeeds():
|
||||
attempts = {"count": 0}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from src import engines
|
||||
from src.engines import Engine
|
||||
from src.translation import TranslationConfig
|
||||
|
||||
@@ -46,3 +47,39 @@ def test_synthesize_retries_transient_edge_tts_failure(tmp_path, monkeypatch):
|
||||
assert calls["count"] == 2
|
||||
assert out_path.exists()
|
||||
assert out_path.stat().st_size == 2048
|
||||
|
||||
|
||||
def test_synthesize_uses_tts_cache(tmp_path, monkeypatch):
|
||||
calls = {"count": 0}
|
||||
cache_dir = tmp_path / "tts-cache"
|
||||
monkeypatch.setattr(engines, "TTS_CACHE_DIR", cache_dir)
|
||||
|
||||
class FakeCommunicate:
|
||||
def __init__(self, text, voice, rate):
|
||||
self.text = text
|
||||
self.voice = voice
|
||||
self.rate = rate
|
||||
|
||||
async def save(self, out_path):
|
||||
calls["count"] += 1
|
||||
with open(out_path, "wb") as audio_file:
|
||||
audio_file.write(b"1" * 2048)
|
||||
|
||||
monkeypatch.setattr("src.engines.edge_tts.Communicate", FakeCommunicate)
|
||||
|
||||
engine = Engine(
|
||||
"cpu",
|
||||
translation_config=TranslationConfig(
|
||||
base_url="http://127.0.0.1:1234/v1",
|
||||
api_key="test-key",
|
||||
model="test-model",
|
||||
),
|
||||
)
|
||||
first_out = tmp_path / "first.mp3"
|
||||
second_out = tmp_path / "second.mp3"
|
||||
|
||||
asyncio.run(engine.synthesize("Bonjour", "fr", first_out))
|
||||
asyncio.run(engine.synthesize("Bonjour", "fr", second_out))
|
||||
|
||||
assert calls["count"] == 1
|
||||
assert first_out.read_bytes() == second_out.read_bytes()
|
||||
|
||||
Reference in New Issue
Block a user