Add translation and TTS caching
This commit is contained in:
@@ -82,6 +82,8 @@ You can also upload a local `.mp4` instead of entering a YouTube URL. Uploaded v
|
|||||||
|
|
||||||
The web UI automatically refreshes job status, progress, steps, and output choices every few seconds while it is open. The manual **Refresh** button is still available.
|
The web UI automatically refreshes job status, progress, steps, and output choices every few seconds while it is open. The manual **Refresh** button is still available.
|
||||||
|
|
||||||
|
Translations and raw TTS clips are cached under `.cache/translations` and `.cache/tts`, so reruns can skip work that already succeeded; this is especially useful after transient TTS failures. To disable a cache, set `TRANSLATION_CACHE_ENABLED` or `TTS_CACHE_ENABLED` to `0` (`false`, `no`, and `off` are also accepted).
|
||||||
|
|
||||||
### Docker
|
### Docker
|
||||||
|
|
||||||
Build and run the Gradio UI in a container:
|
Build and run the Gradio UI in a container:
|
||||||
|
|||||||
@@ -16,8 +16,10 @@ import torch
|
|||||||
import asyncio
|
import asyncio
|
||||||
import edge_tts
|
import edge_tts
|
||||||
import gc
|
import gc
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -42,6 +44,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent
|
|||||||
CACHE_DIR = BASE_DIR / ".cache"
|
CACHE_DIR = BASE_DIR / ".cache"
|
||||||
OUTPUT_DIR = BASE_DIR / "output"
|
OUTPUT_DIR = BASE_DIR / "output"
|
||||||
TEMP_DIR = BASE_DIR / "temp"
|
TEMP_DIR = BASE_DIR / "temp"
|
||||||
|
TTS_CACHE_DIR = CACHE_DIR / "tts"
|
||||||
|
|
||||||
# Configuration files
|
# Configuration files
|
||||||
LANG_MAP_FILE = BASE_DIR / "language_map.json"
|
LANG_MAP_FILE = BASE_DIR / "language_map.json"
|
||||||
@@ -56,6 +59,23 @@ AUDIO_CHANNELS = 1
|
|||||||
DEFAULT_TTS_MAX_RETRIES = int(os.getenv("TTS_MAX_RETRIES", "4"))
|
DEFAULT_TTS_MAX_RETRIES = int(os.getenv("TTS_MAX_RETRIES", "4"))
|
||||||
DEFAULT_TTS_RETRY_BACKOFF_SECONDS = float(os.getenv("TTS_RETRY_BACKOFF_SECONDS", "2.0"))
|
DEFAULT_TTS_RETRY_BACKOFF_SECONDS = float(os.getenv("TTS_RETRY_BACKOFF_SECONDS", "2.0"))
|
||||||
|
|
||||||
|
|
||||||
|
def _cache_enabled(env_name: str) -> bool:
    """Report whether the cache controlled by ``env_name`` is switched on.

    The cache counts as enabled unless the environment variable holds one of
    the "off" spellings ``0``, ``false``, ``no`` or ``off``, compared
    case-insensitively after trimming surrounding whitespace. An unset or
    empty variable leaves the cache enabled.
    """
    raw_value = os.getenv(env_name, "1")
    if not raw_value:
        raw_value = ""
    normalized = raw_value.strip().lower()
    disabled_spellings = ("0", "false", "no", "off")
    return normalized not in disabled_spellings
|
||||||
|
|
||||||
|
|
||||||
|
def _tts_cache_key(text: str, target_lang: str, voice: str, rate: str) -> str:
    """Derive the cache key for one synthesized clip.

    The key is the SHA-256 hex digest of a compact, key-sorted JSON document
    holding the text, target language, voice, rate, the module-level
    ``SAMPLE_RATE`` and a fixed version tag, so changing any of them yields
    a different key.
    """
    fingerprint = dict(
        version="edge-tts-v1",
        text=text,
        target_lang=target_lang,
        voice=voice,
        rate=rate,
        sample_rate=SAMPLE_RATE,
    )
    # Sorted keys and fixed separators keep the serialization canonical.
    canonical_json = json.dumps(
        fingerprint,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    digest = hashlib.sha256(canonical_json.encode("utf-8"))
    return digest.hexdigest()
|
||||||
|
|
||||||
def _select_optimal_whisper_model(device: str = "cpu") -> str:
|
def _select_optimal_whisper_model(device: str = "cpu") -> str:
|
||||||
"""Select optimal Whisper model based on available VRAM and device.
|
"""Select optimal Whisper model based on available VRAM and device.
|
||||||
|
|
||||||
@@ -495,6 +515,12 @@ class Engine(PipelineComponent):
|
|||||||
lang_cfg = self._getLangConfig(target_lang)
|
lang_cfg = self._getLangConfig(target_lang)
|
||||||
voice_pool = self.config_manager.getVoicePool(target_lang, gender)
|
voice_pool = self.config_manager.getVoicePool(target_lang, gender)
|
||||||
voice = voice_pool[0] if voice_pool else DEFAULT_VOICE
|
voice = voice_pool[0] if voice_pool else DEFAULT_VOICE
|
||||||
|
cache_path = TTS_CACHE_DIR / f"{_tts_cache_key(text, target_lang, voice, rate)}.mp3"
|
||||||
|
|
||||||
|
if _cache_enabled("TTS_CACHE_ENABLED") and cache_path.exists() and cache_path.stat().st_size >= 1024:
|
||||||
|
print(f"[*] TTS cache hit: {cache_path.name}")
|
||||||
|
shutil.copyfile(cache_path, out_path)
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
communicate = edge_tts.Communicate(text, voice=voice, rate=rate)
|
communicate = edge_tts.Communicate(text, voice=voice, rate=rate)
|
||||||
@@ -503,6 +529,10 @@ class Engine(PipelineComponent):
|
|||||||
if not out_path.exists() or out_path.stat().st_size < 1024:
|
if not out_path.exists() or out_path.stat().st_size < 1024:
|
||||||
raise RuntimeError("TTS file invalid")
|
raise RuntimeError("TTS file invalid")
|
||||||
|
|
||||||
|
if _cache_enabled("TTS_CACHE_ENABLED"):
|
||||||
|
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.copyfile(out_path, cache_path)
|
||||||
|
|
||||||
return
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
last_error = exc
|
last_error = exc
|
||||||
|
|||||||
@@ -2,10 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@@ -20,6 +22,9 @@ DEFAULT_TRANSLATION_BACKEND = "lmstudio"
|
|||||||
DEFAULT_CONTEXTUAL_BATCH_SIZE = 5
|
DEFAULT_CONTEXTUAL_BATCH_SIZE = 5
|
||||||
MIN_CONTEXTUAL_BATCH_SIZE = 3
|
MIN_CONTEXTUAL_BATCH_SIZE = 3
|
||||||
DEFAULT_CONTEXT_SEGMENTS = 2
|
DEFAULT_CONTEXT_SEGMENTS = 2
|
||||||
|
PROMPT_VERSION = "gpt54-dub-v2"
|
||||||
|
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||||
|
TRANSLATION_CACHE_DIR = BASE_DIR / ".cache" / "translations"
|
||||||
|
|
||||||
|
|
||||||
def _normalize_base_url(base_url: str) -> str:
|
def _normalize_base_url(base_url: str) -> str:
|
||||||
@@ -125,26 +130,56 @@ def _build_system_prompt(source_language: str, target_language: str) -> str:
|
|||||||
def _build_contextual_system_prompt(source_language: str, target_language: str) -> str:
    """Compose the system prompt used for contextual batch translation.

    The prompt has four parts: the role and task statement, the primary
    objective, the dubbing adaptation rules, and the JSON output contract.
    A falsy ``source_language`` is rendered as ``auto``.
    """
    source_descriptor = source_language if source_language else "auto"

    role_and_task = (
        "You are an expert audiovisual translator and dubbing script adapter.\n\n"
        f"Translate subtitle segments from {source_descriptor} to {target_language} for natural spoken dubbing.\n\n"
    )
    primary_objective = (
        "Primary objective:\n"
        "- Produce faithful, idiomatic spoken lines that can be read aloud naturally in the target language.\n"
        "- Preserve intent, tone, emotion, register, humor, sarcasm, hesitation, and subtext.\n"
        "- Use the provided previous and next segments only as context; translate only the current segments.\n\n"
    )
    adaptation_rules = (
        "Dubbing adaptation rules:\n"
        "- Prefer natural speech over literal word-for-word phrasing when the literal version sounds stiff.\n"
        "- Keep each translated segment close to the source segment length when possible, because it will be timed to video.\n"
        "- Do not add new claims, soften meaning, moralize, censor, summarize, or omit content.\n"
        "- Preserve speaker deixis and continuity across adjacent segments.\n"
        "- Keep names, brands, URLs, emails, file paths, code, product names, and quoted UI text unchanged unless transliteration is clearly required.\n"
        "- Preserve numbers, units, dates, and technical terms accurately.\n"
        "- If a phrase is slang, idiom, or a joke, translate the effect rather than the literal wording.\n\n"
    )
    output_contract = (
        "Output contract:\n"
        "- Return valid JSON only, with no markdown fences or commentary.\n"
        "- Return exactly one translated item per input segment.\n"
        "- Preserve segment ids and output order exactly.\n"
        "- Preserve empty or whitespace-only segments as an empty translated_text.\n"
        "- Do not include previous_segments or next_segments in the output.\n"
        '- Return this exact schema: {"translations":[{"id":"...","translated_text":"..."}]}.'
    )
    return role_and_task + primary_objective + adaptation_rules + output_contract
|
||||||
|
|
||||||
|
|
||||||
|
def _cache_enabled(env_name: str) -> bool:
    """Tell whether the environment leaves the named cache switch on.

    Only the values ``0``, ``false``, ``no`` and ``off`` (any letter case,
    surrounding whitespace ignored) turn the cache off; an unset or empty
    variable keeps it on.
    """
    configured = os.getenv(env_name, "1") or ""
    is_disabled = configured.strip().lower() in {"0", "false", "no", "off"}
    return not is_disabled
|
||||||
|
|
||||||
|
|
||||||
|
def _json_cache_key(payload: Dict[str, Any]) -> str:
    """Return the SHA-256 hex digest of ``payload`` in canonical JSON form.

    Keys are sorted and separators are fixed, so two payloads with the same
    content hash to the same key regardless of dictionary insertion order.
    """
    canonical = json.dumps(
        payload,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    hasher = hashlib.sha256()
    hasher.update(canonical.encode("utf-8"))
    return hasher.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _read_json_cache(cache_path: Path) -> Optional[Dict[str, Any]]:
    """Load a cached JSON object, treating any unusable file as a cache miss.

    Args:
        cache_path: Location of the cache entry.

    Returns:
        The decoded JSON object when the file exists, is valid UTF-8, parses
        as JSON and its top-level value is a dict; otherwise ``None``.
    """
    # Read directly instead of checking exists() first: a missing file raises
    # FileNotFoundError (an OSError), so no separate, race-prone probe is needed.
    try:
        raw_text = cache_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        # UnicodeDecodeError covers a truncated or otherwise corrupt entry
        # whose bytes are not valid UTF-8; that is a miss, not a failure.
        return None

    try:
        payload = json.loads(raw_text)
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass.
        return None

    return payload if isinstance(payload, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _write_json_cache(cache_path: Path, payload: Dict[str, Any]) -> None:
    """Write ``payload`` to ``cache_path`` as pretty-printed UTF-8 JSON.

    The document is first written to a uniquely named temporary file in the
    same directory and then moved over ``cache_path``, so readers never see a
    partially written entry and concurrent writers of the same key do not
    share a temporary file. If writing or moving fails, the temporary file is
    removed and the original exception propagates.

    Args:
        cache_path: Destination of the cache entry; parent directories are
            created when missing.
        payload: JSON-serializable object to store.
    """
    import uuid  # local import: only needed for the temporary file name

    cache_path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    tmp_path = cache_path.with_name(f"{cache_path.name}.{uuid.uuid4().hex}.tmp")
    try:
        tmp_path.write_text(serialized, encoding="utf-8")
        tmp_path.replace(cache_path)
    except BaseException:
        # Do not leave an orphaned temporary file behind on failure.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TranslationSegment:
|
class TranslationSegment:
|
||||||
"""A subtitle segment prepared for contextual batch translation."""
|
"""A subtitle segment prepared for contextual batch translation."""
|
||||||
@@ -183,6 +218,7 @@ class LMStudioTranslator:
|
|||||||
self._client = client or httpx.Client(timeout=httpx.Timeout(self.config.timeout_seconds))
|
self._client = client or httpx.Client(timeout=httpx.Timeout(self.config.timeout_seconds))
|
||||||
self._owns_client = client is None
|
self._owns_client = client is None
|
||||||
self._sleeper = sleeper
|
self._sleeper = sleeper
|
||||||
|
self._cache_enabled = _cache_enabled("TRANSLATION_CACHE_ENABLED")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generation_settings() -> Dict[str, Any]:
|
def _generation_settings() -> Dict[str, Any]:
|
||||||
@@ -551,19 +587,43 @@ class LMStudioTranslator:
|
|||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Translate a single contextual subtitle batch with validation and retries."""
|
"""Translate a single contextual subtitle batch with validation and retries."""
|
||||||
payload = self.build_contextual_batch_payload(batch, source_language, target_language)
|
payload = self.build_contextual_batch_payload(batch, source_language, target_language)
|
||||||
|
cache_payload = {
|
||||||
|
"prompt_version": PROMPT_VERSION,
|
||||||
|
"backend": self.config.backend,
|
||||||
|
"base_url": self.config.base_url,
|
||||||
|
"model": self.config.model,
|
||||||
|
"source_language": source_language or "auto",
|
||||||
|
"target_language": target_language,
|
||||||
|
"request": self.build_contextual_batch_request(batch, source_language, target_language),
|
||||||
|
}
|
||||||
|
cache_path = TRANSLATION_CACHE_DIR / f"{_json_cache_key(cache_payload)}.json"
|
||||||
|
if self._cache_enabled:
|
||||||
|
cached_payload = _read_json_cache(cache_path)
|
||||||
|
cached_translations = cached_payload.get("translations") if cached_payload else None
|
||||||
|
if isinstance(cached_translations, list) and all(isinstance(item, str) for item in cached_translations):
|
||||||
|
if len(cached_translations) == len(batch.segments):
|
||||||
|
print(f"[*] Translation cache hit: {cache_path.name}")
|
||||||
|
return cached_translations
|
||||||
|
|
||||||
last_error: Optional[Exception] = None
|
last_error: Optional[Exception] = None
|
||||||
|
|
||||||
for attempt in range(1, self.config.max_retries + 1):
|
for attempt in range(1, self.config.max_retries + 1):
|
||||||
try:
|
try:
|
||||||
response_content = self._post_chat_completion(payload)
|
response_content = self._post_chat_completion(payload)
|
||||||
return self.parse_batch_translation_response(response_content, batch)
|
translations = self.parse_batch_translation_response(response_content, batch)
|
||||||
|
if self._cache_enabled:
|
||||||
|
_write_json_cache(cache_path, {"translations": translations})
|
||||||
|
return translations
|
||||||
except (httpx.HTTPError, ValueError, TranslationError) as exc:
|
except (httpx.HTTPError, ValueError, TranslationError) as exc:
|
||||||
last_error = exc
|
last_error = exc
|
||||||
if self._should_retry_with_user_only_prompt(exc):
|
if self._should_retry_with_user_only_prompt(exc):
|
||||||
try:
|
try:
|
||||||
fallback_payload = self.build_contextual_user_only_payload(batch, source_language, target_language)
|
fallback_payload = self.build_contextual_user_only_payload(batch, source_language, target_language)
|
||||||
fallback_content = self._post_chat_completion(fallback_payload)
|
fallback_content = self._post_chat_completion(fallback_payload)
|
||||||
return self.parse_batch_translation_response(fallback_content, batch)
|
translations = self.parse_batch_translation_response(fallback_content, batch)
|
||||||
|
if self._cache_enabled:
|
||||||
|
_write_json_cache(cache_path, {"translations": translations})
|
||||||
|
return translations
|
||||||
except (httpx.HTTPError, ValueError, TranslationError) as fallback_exc:
|
except (httpx.HTTPError, ValueError, TranslationError) as fallback_exc:
|
||||||
last_error = fallback_exc
|
last_error = fallback_exc
|
||||||
if self._should_retry_with_structured_translation_prompt(last_error):
|
if self._should_retry_with_structured_translation_prompt(last_error):
|
||||||
@@ -574,7 +634,10 @@ class LMStudioTranslator:
|
|||||||
target_language,
|
target_language,
|
||||||
)
|
)
|
||||||
structured_content = self._post_chat_completion(structured_payload)
|
structured_content = self._post_chat_completion(structured_payload)
|
||||||
return self.parse_batch_translation_response(structured_content, batch)
|
translations = self.parse_batch_translation_response(structured_content, batch)
|
||||||
|
if self._cache_enabled:
|
||||||
|
_write_json_cache(cache_path, {"translations": translations})
|
||||||
|
return translations
|
||||||
except (httpx.HTTPError, ValueError, TranslationError) as structured_exc:
|
except (httpx.HTTPError, ValueError, TranslationError) as structured_exc:
|
||||||
last_error = structured_exc
|
last_error = structured_exc
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import httpx
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from src.core_utils import TranslationError
|
from src.core_utils import TranslationError
|
||||||
|
from src import translation
|
||||||
from src.translation import LMStudioTranslator, TranslationConfig
|
from src.translation import LMStudioTranslator, TranslationConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -56,7 +57,8 @@ def test_build_contextual_batch_payload_includes_neighboring_segments():
|
|||||||
|
|
||||||
assert payload["model"] == "gemma-3-4b-it"
|
assert payload["model"] == "gemma-3-4b-it"
|
||||||
assert payload["messages"][0]["role"] == "system"
|
assert payload["messages"][0]["role"] == "system"
|
||||||
assert "expert audiovisual translator for dubbed video content" in payload["messages"][0]["content"]
|
assert "expert audiovisual translator and dubbing script adapter" in payload["messages"][0]["content"]
|
||||||
|
assert "Preserve segment ids and output order exactly" in payload["messages"][0]["content"]
|
||||||
assert payload["temperature"] == 0.0
|
assert payload["temperature"] == 0.0
|
||||||
assert payload["top_p"] == 1.0
|
assert payload["top_p"] == 1.0
|
||||||
assert user_payload == {
|
assert user_payload == {
|
||||||
@@ -116,6 +118,29 @@ def test_translate_segments_batches_context_and_preserves_exact_mapping():
|
|||||||
assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"]
|
assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_translate_segments_uses_persistent_cache(tmp_path, monkeypatch):
    """A second translator reuses the on-disk translation instead of calling the backend."""
    # Pin the switch on: the translator reads TRANSLATION_CACHE_ENABLED when it
    # is constructed, so an ambient "0" would otherwise disable the behaviour
    # under test and make this test fail for environmental reasons.
    monkeypatch.setenv("TRANSLATION_CACHE_ENABLED", "1")
    # Redirect the cache so the test never touches the real .cache directory.
    monkeypatch.setattr(translation, "TRANSLATION_CACHE_DIR", tmp_path)
    requests = {"count": 0}

    def handler(request: httpx.Request) -> httpx.Response:
        requests["count"] += 1
        payload = _read_request_json(request)
        batch_request = json.loads(payload["messages"][1]["content"])
        translations = [
            {"id": item["id"], "translated_text": f"cached::{item['text']}"}
            for item in batch_request["segments"]
        ]
        return _mock_batch_response(translations)

    config = TranslationConfig(model="cache-model")
    first_translator = LMStudioTranslator(config, client=_mock_client(handler))
    second_translator = LMStudioTranslator(config, client=_mock_client(handler))

    assert first_translator.translate_segments(["hello", "world"], "fr", "en") == ["cached::hello", "cached::world"]
    assert second_translator.translate_segments(["hello", "world"], "fr", "en") == ["cached::hello", "cached::world"]
    # Only the first translator reached the backend; the second was served from disk.
    assert requests["count"] == 1
|
||||||
|
|
||||||
|
|
||||||
def test_retry_on_transient_http_error_then_succeeds():
|
def test_retry_on_transient_http_error_then_succeeds():
|
||||||
attempts = {"count": 0}
|
attempts = {"count": 0}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from src import engines
|
||||||
from src.engines import Engine
|
from src.engines import Engine
|
||||||
from src.translation import TranslationConfig
|
from src.translation import TranslationConfig
|
||||||
|
|
||||||
@@ -46,3 +47,39 @@ def test_synthesize_retries_transient_edge_tts_failure(tmp_path, monkeypatch):
|
|||||||
assert calls["count"] == 2
|
assert calls["count"] == 2
|
||||||
assert out_path.exists()
|
assert out_path.exists()
|
||||||
assert out_path.stat().st_size == 2048
|
assert out_path.stat().st_size == 2048
|
||||||
|
|
||||||
|
|
||||||
|
def test_synthesize_uses_tts_cache(tmp_path, monkeypatch):
    """A repeated synthesis request is served from the TTS cache without calling edge-tts again."""
    # Pin the switch on so an ambient TTS_CACHE_ENABLED=0 cannot disable the
    # behaviour under test.
    monkeypatch.setenv("TTS_CACHE_ENABLED", "1")
    # Redirect the cache so the test never touches the real .cache directory.
    cache_dir = tmp_path / "tts-cache"
    monkeypatch.setattr(engines, "TTS_CACHE_DIR", cache_dir)
    calls = {"count": 0}

    class FakeCommunicate:
        def __init__(self, text, voice, rate):
            self.text = text
            self.voice = voice
            self.rate = rate

        async def save(self, out_path):
            calls["count"] += 1
            # 2048 bytes clears the 1024-byte minimum the engine requires of a clip.
            with open(out_path, "wb") as audio_file:
                audio_file.write(b"1" * 2048)

    monkeypatch.setattr("src.engines.edge_tts.Communicate", FakeCommunicate)

    engine = Engine(
        "cpu",
        translation_config=TranslationConfig(
            base_url="http://127.0.0.1:1234/v1",
            api_key="test-key",
            model="test-model",
        ),
    )
    first_out = tmp_path / "first.mp3"
    second_out = tmp_path / "second.mp3"

    asyncio.run(engine.synthesize("Bonjour", "fr", first_out))
    asyncio.run(engine.synthesize("Bonjour", "fr", second_out))

    # The backend ran once; the second request was a cache hit.
    assert calls["count"] == 1
    assert first_out.read_bytes() == second_out.read_bytes()
    # The clip was persisted exactly once in the redirected cache directory.
    assert len(list(cache_dir.glob("*.mp3"))) == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user