Compare commits
13 Commits
fix-vocal-
...
codex/guar
| Author | SHA1 | Date | |
|---|---|---|---|
| 9fbb7c1756 | |||
| 803f532ff3 | |||
| f1e72f27e2 | |||
| 75522ede50 | |||
| a25a60f217 | |||
| 665ea41c65 | |||
| 82d5c3c173 | |||
| f4f1236777 | |||
| 82718e5e84 | |||
| c6363dfa84 | |||
| 68d7ce928f | |||
| adbc687093 | |||
| d83c57cda3 |
13
.dockerignore
Normal file
13
.dockerignore
Normal file
@@ -0,0 +1,13 @@
|
||||
.cache/
|
||||
.git/
|
||||
.pytest_cache/
|
||||
.venv/
|
||||
__pycache__/
|
||||
logs/
|
||||
output/
|
||||
temp/
|
||||
*.log
|
||||
*.mp3
|
||||
*.mp4
|
||||
*.pyc
|
||||
*.wav
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,5 +1,6 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.venv/
|
||||
.cache/
|
||||
temp/
|
||||
output/
|
||||
@@ -7,4 +8,4 @@ output/
|
||||
*.wav
|
||||
*.mp3
|
||||
logs/
|
||||
*.log
|
||||
*.log
|
||||
|
||||
24
Dockerfile
Normal file
24
Dockerfile
Normal file
@@ -0,0 +1,24 @@
|
||||
FROM python:3.10-slim
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
GRADIO_SERVER_NAME=0.0.0.0 \
|
||||
PORT=7860
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ffmpeg ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir --upgrade pip \
|
||||
&& pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN mkdir -p .cache temp output logs/gradio
|
||||
|
||||
EXPOSE 7860
|
||||
|
||||
CMD ["python", "web_app.py"]
|
||||
56
README.md
56
README.md
@@ -4,9 +4,9 @@ YouTube Auto Dub is a Python pipeline that downloads a YouTube video, transcribe
|
||||
|
||||
## What Changed
|
||||
|
||||
- Translation now uses an LM Studio OpenAI-compatible `/v1/chat/completions` endpoint.
|
||||
- Translation now uses an OpenAI-compatible `/v1/chat/completions` endpoint.
|
||||
- Google Translate scraping has been removed from the active runtime path.
|
||||
- LM Studio is now the default and only supported translation backend.
|
||||
- OpenAI compatible backend is now the default with no option for Google Translate.
|
||||
- Translation settings can be configured with environment variables or CLI flags.
|
||||
|
||||
## Requirements
|
||||
@@ -14,7 +14,7 @@ YouTube Auto Dub is a Python pipeline that downloads a YouTube video, transcribe
|
||||
- Python 3.10+
|
||||
- [uv](https://docs.astral.sh/uv/)
|
||||
- FFmpeg and FFprobe available on `PATH`
|
||||
- LM Studio running locally with an OpenAI-compatible server enabled
|
||||
- An OpenAI-compatible server
|
||||
|
||||
## Setup
|
||||
|
||||
@@ -66,6 +66,49 @@ Basic example:
|
||||
.venv\Scripts\python.exe main.py "https://youtube.com/watch?v=VIDEO_ID" --lang es
|
||||
```
|
||||
|
||||
### Gradio Web UI
|
||||
|
||||
Gradio provides a local browser UI for starting dub jobs, watching progress, and downloading finished videos:
|
||||
|
||||
```powershell
|
||||
.venv\Scripts\python.exe web_app.py
|
||||
```
|
||||
|
||||
Open `http://127.0.0.1:7860` and submit a YouTube URL. Jobs run through the same `main.py` pipeline, so the CLI options and environment variables still apply.
|
||||
|
||||
The OpenAI-compatible translation endpoint, API key, and model can be changed in the UI under **OpenAI-Compatible Settings**. Click **Save Settings** to persist them to `.cache/web_settings.json` for future web jobs. Unsaved values in the fields are still used for the next job you start.
|
||||
|
||||
You can also upload a local `.mp4` instead of entering a YouTube URL. Uploaded videos are staged under `.cache/uploads` and processed with the same transcription, translation, dubbing, and render pipeline. Restricted YouTube videos can use the **Upload Cookies File** control instead of typing a local cookies path.
|
||||
|
||||
The web UI automatically refreshes job status, progress, steps, and output choices every few seconds while it is open. The manual **Refresh** button is still available.
|
||||
|
||||
Translations and raw TTS clips are cached under `.cache/translations` and `.cache/tts`. This lets reruns skip work that already succeeded, which is especially useful after transient TTS failures. Set `TRANSLATION_CACHE_ENABLED=0` or `TTS_CACHE_ENABLED=0` to disable those caches.
|
||||
|
||||
### Docker
|
||||
|
||||
Build and run the Gradio UI in a container:
|
||||
|
||||
```powershell
|
||||
docker build -t youtube-auto-dub:gradio .
|
||||
docker run --rm -p 7860:7860 `
|
||||
-e LM_STUDIO_BASE_URL=http://host.docker.internal:1234/v1 `
|
||||
-e LM_STUDIO_API_KEY=lm-studio `
|
||||
-e LM_STUDIO_MODEL=gemma-3-4b-it `
|
||||
-v ${PWD}\.cache:/app/.cache `
|
||||
-v ${PWD}\output:/app/output `
|
||||
-v ${PWD}\logs:/app/logs `
|
||||
-v ${PWD}\temp:/app/temp `
|
||||
youtube-auto-dub:gradio
|
||||
```
|
||||
|
||||
Or use Compose:
|
||||
|
||||
```powershell
|
||||
docker compose up --build
|
||||
```
|
||||
|
||||
When LM Studio runs on the host machine, use `http://host.docker.internal:1234/v1` from inside Docker instead of `http://127.0.0.1:1234/v1`.
|
||||
|
||||
Override the LM Studio endpoint or model from the CLI:
|
||||
|
||||
```powershell
|
||||
@@ -83,11 +126,18 @@ Authentication options for restricted videos still work as before:
|
||||
.venv\Scripts\python.exe main.py "https://youtube.com/watch?v=VIDEO_ID" --lang de --cookies cookies.txt
|
||||
```
|
||||
|
||||
Process a local MP4:
|
||||
|
||||
```powershell
|
||||
.venv\Scripts\python.exe main.py --input-file "C:\path\to\video.mp4" --lang es
|
||||
```
|
||||
|
||||
## CLI Options
|
||||
|
||||
| Option | Description |
|
||||
| --- | --- |
|
||||
| `url` | YouTube video URL to process |
|
||||
| `--input-file` | Local MP4 file to process instead of a YouTube URL |
|
||||
| `--lang`, `-l` | Target language code |
|
||||
| `--browser`, `-b` | Browser name for cookie extraction |
|
||||
| `--cookies`, `-c` | Path to exported cookies file |
|
||||
|
||||
15
docker-compose.yml
Normal file
15
docker-compose.yml
Normal file
@@ -0,0 +1,15 @@
|
||||
services:
|
||||
youtube-auto-dub:
|
||||
build: .
|
||||
image: youtube-auto-dub:gradio
|
||||
ports:
|
||||
- "7860:7860"
|
||||
environment:
|
||||
LM_STUDIO_BASE_URL: "http://host.docker.internal:1234/v1"
|
||||
LM_STUDIO_API_KEY: "lm-studio"
|
||||
LM_STUDIO_MODEL: "gemma-3-4b-it"
|
||||
volumes:
|
||||
- ./.cache:/app/.cache
|
||||
- ./output:/app/output
|
||||
- ./logs:/app/logs
|
||||
- ./temp:/app/temp
|
||||
82
main.py
82
main.py
@@ -7,6 +7,7 @@ import argparse
|
||||
import asyncio
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from src.audio_separation import DEFAULT_MIX_MODE
|
||||
from src.core_utils import ConfigurationError
|
||||
@@ -28,7 +29,11 @@ Examples:
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument("url", help="YouTube video URL to subtitle")
|
||||
parser.add_argument("url", nargs="?", help="YouTube video URL to subtitle")
|
||||
parser.add_argument(
|
||||
"--input-file",
|
||||
help="Path to a local MP4 file to dub instead of downloading from YouTube.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lang",
|
||||
"-l",
|
||||
@@ -148,6 +153,24 @@ def _build_translation_config(args: argparse.Namespace) -> TranslationConfig:
|
||||
)
|
||||
|
||||
|
||||
def _validate_source_args(args: argparse.Namespace) -> None:
|
||||
"""Ensure exactly one source input is configured."""
|
||||
if bool(args.url) == bool(args.input_file):
|
||||
raise SystemExit("Provide either a YouTube URL or --input-file, but not both.")
|
||||
|
||||
|
||||
def _prepare_local_video(input_file: str, media_module, cache_dir: Path) -> tuple[Path, Path]:
|
||||
"""Validate a local MP4 and extract its audio for the shared pipeline."""
|
||||
video_path = Path(input_file).expanduser().resolve()
|
||||
if not video_path.exists():
|
||||
raise FileNotFoundError(f"Input file not found: {video_path}")
|
||||
if video_path.suffix.lower() != ".mp4":
|
||||
raise ValueError("Only MP4 input files are supported.")
|
||||
|
||||
audio_path = cache_dir / f"{video_path.stem}_uploaded.wav"
|
||||
return video_path, media_module.extract_audio_from_video(video_path, audio_path)
|
||||
|
||||
|
||||
def _get_source_language_hint() -> str:
|
||||
"""Read an optional source language override from the environment."""
|
||||
import os
|
||||
@@ -190,6 +213,7 @@ def main() -> None:
|
||||
"""Run the full YouTube Auto Dub pipeline."""
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
_validate_source_args(args)
|
||||
|
||||
import src.engines
|
||||
import src.media
|
||||
@@ -233,32 +257,42 @@ def main() -> None:
|
||||
)
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print("STEP 1: DOWNLOADING CONTENT")
|
||||
print("STEP 1: PREPARING CONTENT")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"[*] Target URL: {args.url}")
|
||||
print(f"[*] Target Language: {args.lang.upper()}")
|
||||
|
||||
try:
|
||||
video_path = src.youtube.downloadVideo(
|
||||
args.url,
|
||||
browser=args.browser,
|
||||
cookies_file=args.cookies,
|
||||
)
|
||||
audio_path = src.youtube.downloadAudio(
|
||||
args.url,
|
||||
browser=args.browser,
|
||||
cookies_file=args.cookies,
|
||||
)
|
||||
print(f"[+] Video downloaded: {video_path}")
|
||||
print(f"[+] Audio extracted: {audio_path}")
|
||||
except Exception as exc:
|
||||
print(f"\n[!] DOWNLOAD FAILED: {exc}")
|
||||
print("\n[-] TROUBLESHOOTING TIPS:")
|
||||
print(" 1. Close all browser windows if using --browser")
|
||||
print(" 2. Export fresh cookies.txt and use --cookies")
|
||||
print(" 3. Check if video is private/region-restricted")
|
||||
print(" 4. Verify YouTube URL is correct")
|
||||
return
|
||||
if args.input_file:
|
||||
print(f"[*] Source MP4: {args.input_file}")
|
||||
try:
|
||||
video_path, audio_path = _prepare_local_video(args.input_file, src.media, src.engines.CACHE_DIR)
|
||||
print(f"[+] Local video ready: {video_path}")
|
||||
print(f"[+] Audio extracted: {audio_path}")
|
||||
except Exception as exc:
|
||||
print(f"\n[!] LOCAL INPUT FAILED: {exc}")
|
||||
return
|
||||
else:
|
||||
print(f"[*] Target URL: {args.url}")
|
||||
try:
|
||||
video_path = src.youtube.downloadVideo(
|
||||
args.url,
|
||||
browser=args.browser,
|
||||
cookies_file=args.cookies,
|
||||
)
|
||||
audio_path = src.youtube.downloadAudio(
|
||||
args.url,
|
||||
browser=args.browser,
|
||||
cookies_file=args.cookies,
|
||||
)
|
||||
print(f"[+] Video downloaded: {video_path}")
|
||||
print(f"[+] Audio extracted: {audio_path}")
|
||||
except Exception as exc:
|
||||
print(f"\n[!] DOWNLOAD FAILED: {exc}")
|
||||
print("\n[-] TROUBLESHOOTING TIPS:")
|
||||
print(" 1. Close all browser windows if using --browser")
|
||||
print(" 2. Export fresh cookies.txt and use --cookies")
|
||||
print(" 3. Check if video is private/region-restricted")
|
||||
print(" 4. Verify YouTube URL is correct")
|
||||
return
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print("STEP 2: SPEECH TRANSCRIPTION")
|
||||
|
||||
@@ -10,3 +10,4 @@ tqdm
|
||||
pathlib
|
||||
typing-extensions
|
||||
pytest
|
||||
gradio
|
||||
|
||||
@@ -16,8 +16,10 @@ import torch
|
||||
import asyncio
|
||||
import edge_tts
|
||||
import gc
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from abc import ABC
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
@@ -42,6 +44,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
CACHE_DIR = BASE_DIR / ".cache"
|
||||
OUTPUT_DIR = BASE_DIR / "output"
|
||||
TEMP_DIR = BASE_DIR / "temp"
|
||||
TTS_CACHE_DIR = CACHE_DIR / "tts"
|
||||
|
||||
# Configuration files
|
||||
LANG_MAP_FILE = BASE_DIR / "language_map.json"
|
||||
@@ -53,6 +56,25 @@ for directory_path in [CACHE_DIR, OUTPUT_DIR, TEMP_DIR]:
|
||||
# Audio processing settings
|
||||
SAMPLE_RATE = 24000
|
||||
AUDIO_CHANNELS = 1
|
||||
DEFAULT_TTS_MAX_RETRIES = int(os.getenv("TTS_MAX_RETRIES", "4"))
|
||||
DEFAULT_TTS_RETRY_BACKOFF_SECONDS = float(os.getenv("TTS_RETRY_BACKOFF_SECONDS", "2.0"))
|
||||
|
||||
|
||||
def _cache_enabled(env_name: str) -> bool:
|
||||
return (os.getenv(env_name, "1") or "").strip().lower() not in {"0", "false", "no", "off"}
|
||||
|
||||
|
||||
def _tts_cache_key(text: str, target_lang: str, voice: str, rate: str) -> str:
|
||||
payload = {
|
||||
"version": "edge-tts-v1",
|
||||
"text": text,
|
||||
"target_lang": target_lang,
|
||||
"voice": voice,
|
||||
"rate": rate,
|
||||
"sample_rate": SAMPLE_RATE,
|
||||
}
|
||||
serialized = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
|
||||
|
||||
def _select_optimal_whisper_model(device: str = "cpu") -> str:
|
||||
"""Select optimal Whisper model based on available VRAM and device.
|
||||
@@ -487,22 +509,47 @@ class Engine(PipelineComponent):
|
||||
) -> None:
|
||||
if not text.strip(): raise ValueError("Text empty")
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
|
||||
last_error: Optional[Exception] = None
|
||||
for attempt in range(1, DEFAULT_TTS_MAX_RETRIES + 1):
|
||||
lang_cfg = self._getLangConfig(target_lang)
|
||||
voice_pool = self.config_manager.getVoicePool(target_lang, gender)
|
||||
voice = voice_pool[0] if voice_pool else DEFAULT_VOICE
|
||||
cache_path = TTS_CACHE_DIR / f"{_tts_cache_key(text, target_lang, voice, rate)}.mp3"
|
||||
|
||||
communicate = edge_tts.Communicate(text, voice=voice, rate=rate)
|
||||
await communicate.save(str(out_path))
|
||||
|
||||
if not out_path.exists() or out_path.stat().st_size < 1024:
|
||||
raise RuntimeError("TTS file invalid")
|
||||
|
||||
except Exception as e:
|
||||
if out_path.exists(): out_path.unlink(missing_ok=True)
|
||||
_handleError(e, "TTS synthesis")
|
||||
raise TTSError(f"TTS failed: {e}") from e
|
||||
if _cache_enabled("TTS_CACHE_ENABLED") and cache_path.exists() and cache_path.stat().st_size >= 1024:
|
||||
print(f"[*] TTS cache hit: {cache_path.name}")
|
||||
shutil.copyfile(cache_path, out_path)
|
||||
return
|
||||
|
||||
try:
|
||||
communicate = edge_tts.Communicate(text, voice=voice, rate=rate)
|
||||
await communicate.save(str(out_path))
|
||||
|
||||
if not out_path.exists() or out_path.stat().st_size < 1024:
|
||||
raise RuntimeError("TTS file invalid")
|
||||
|
||||
if _cache_enabled("TTS_CACHE_ENABLED"):
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copyfile(out_path, cache_path)
|
||||
|
||||
return
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
if out_path.exists():
|
||||
out_path.unlink(missing_ok=True)
|
||||
|
||||
if attempt < DEFAULT_TTS_MAX_RETRIES:
|
||||
wait_time = DEFAULT_TTS_RETRY_BACKOFF_SECONDS * attempt
|
||||
print(
|
||||
f"[!] TTS synthesis failed "
|
||||
f"(attempt {attempt}/{DEFAULT_TTS_MAX_RETRIES}): {exc}. "
|
||||
f"Retrying in {wait_time:.1f}s..."
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
_handleError(last_error or RuntimeError("unknown TTS failure"), "TTS synthesis")
|
||||
raise TTSError(f"TTS failed after {DEFAULT_TTS_MAX_RETRIES} attempts: {last_error}") from last_error
|
||||
|
||||
|
||||
def smartChunk(segments: List[Dict]) -> List[Dict]:
|
||||
|
||||
23
src/media.py
23
src/media.py
@@ -22,6 +22,29 @@ from src.engines import SAMPLE_RATE
|
||||
FINAL_MIX_CHANNELS = 2
|
||||
|
||||
|
||||
def extract_audio_from_video(video_path: Path, output_path: Path) -> Path:
|
||||
"""Extract mono WAV audio from a local video file for transcription."""
|
||||
if not video_path.exists():
|
||||
raise FileNotFoundError(f"Source video is missing: {video_path}")
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cmd = [
|
||||
'ffmpeg', '-y', '-v', 'error',
|
||||
'-i', str(video_path),
|
||||
'-vn',
|
||||
'-acodec', 'pcm_s16le',
|
||||
'-ar', str(SAMPLE_RATE),
|
||||
'-ac', '1',
|
||||
str(output_path),
|
||||
]
|
||||
subprocess.run(cmd, check=True, timeout=None)
|
||||
|
||||
if not output_path.exists() or output_path.stat().st_size < 1024:
|
||||
raise RuntimeError(f"Audio extraction did not create a usable WAV file: {output_path}")
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def _build_subtitle_filter(subtitle_path: Path) -> str:
|
||||
"""Build a Windows-safe FFmpeg subtitles filter expression."""
|
||||
escaped_path = str(subtitle_path.resolve()).replace("\\", "/").replace(":", "\\:")
|
||||
|
||||
@@ -2,9 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -16,6 +19,12 @@ DEFAULT_LM_STUDIO_BASE_URL = "http://127.0.0.1:1234/v1"
|
||||
DEFAULT_LM_STUDIO_API_KEY = "lm-studio"
|
||||
DEFAULT_LM_STUDIO_MODEL = "gemma-3-4b-it"
|
||||
DEFAULT_TRANSLATION_BACKEND = "lmstudio"
|
||||
DEFAULT_CONTEXTUAL_BATCH_SIZE = 5
|
||||
MIN_CONTEXTUAL_BATCH_SIZE = 3
|
||||
DEFAULT_CONTEXT_SEGMENTS = 2
|
||||
PROMPT_VERSION = "gpt54-dub-v2"
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
TRANSLATION_CACHE_DIR = BASE_DIR / ".cache" / "translations"
|
||||
|
||||
|
||||
def _normalize_base_url(base_url: str) -> str:
|
||||
@@ -118,6 +127,83 @@ def _build_system_prompt(source_language: str, target_language: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _build_contextual_system_prompt(source_language: str, target_language: str) -> str:
|
||||
source_descriptor = source_language or "auto"
|
||||
return (
|
||||
"You are an expert audiovisual translator and dubbing script adapter.\n\n"
|
||||
f"Translate subtitle segments from {source_descriptor} to {target_language} for natural spoken dubbing.\n\n"
|
||||
"Primary objective:\n"
|
||||
"- Produce faithful, idiomatic spoken lines that can be read aloud naturally in the target language.\n"
|
||||
"- Preserve intent, tone, emotion, register, humor, sarcasm, hesitation, and subtext.\n"
|
||||
"- Use the provided previous and next segments only as context; translate only the current segments.\n\n"
|
||||
"Dubbing adaptation rules:\n"
|
||||
"- Prefer natural speech over literal word-for-word phrasing when the literal version sounds stiff.\n"
|
||||
"- Keep each translated segment close to the source segment length when possible, because it will be timed to video.\n"
|
||||
"- Do not add new claims, soften meaning, moralize, censor, summarize, or omit content.\n"
|
||||
"- Preserve speaker deixis and continuity across adjacent segments.\n"
|
||||
"- Keep names, brands, URLs, emails, file paths, code, product names, and quoted UI text unchanged unless transliteration is clearly required.\n"
|
||||
"- Preserve numbers, units, dates, and technical terms accurately.\n"
|
||||
"- If a phrase is slang, idiom, or a joke, translate the effect rather than the literal wording.\n\n"
|
||||
"Output contract:\n"
|
||||
"- Return valid JSON only, with no markdown fences or commentary.\n"
|
||||
"- Return exactly one translated item per input segment.\n"
|
||||
"- Preserve segment ids and output order exactly.\n"
|
||||
"- Preserve empty or whitespace-only segments as an empty translated_text.\n"
|
||||
"- Do not include previous_segments or next_segments in the output.\n"
|
||||
'- Return this exact schema: {"translations":[{"id":"...","translated_text":"..."}]}.'
|
||||
)
|
||||
|
||||
|
||||
def _cache_enabled(env_name: str) -> bool:
|
||||
return (os.getenv(env_name, "1") or "").strip().lower() not in {"0", "false", "no", "off"}
|
||||
|
||||
|
||||
def _json_cache_key(payload: Dict[str, Any]) -> str:
|
||||
serialized = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _read_json_cache(cache_path: Path) -> Optional[Dict[str, Any]]:
|
||||
if not cache_path.exists():
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(cache_path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
return payload if isinstance(payload, dict) else None
|
||||
|
||||
|
||||
def _write_json_cache(cache_path: Path, payload: Dict[str, Any]) -> None:
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = cache_path.with_suffix(".tmp")
|
||||
tmp_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
tmp_path.replace(cache_path)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TranslationSegment:
|
||||
"""A subtitle segment prepared for contextual batch translation."""
|
||||
|
||||
id: str
|
||||
text: str
|
||||
|
||||
def as_payload(self) -> Dict[str, str]:
|
||||
return {"id": self.id, "text": self.text}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TranslationBatch:
|
||||
"""A contextual subtitle translation batch."""
|
||||
|
||||
previous_segments: List[TranslationSegment]
|
||||
segments: List[TranslationSegment]
|
||||
next_segments: List[TranslationSegment]
|
||||
|
||||
@property
|
||||
def segment_ids(self) -> List[str]:
|
||||
return [segment.id for segment in self.segments]
|
||||
|
||||
|
||||
class LMStudioTranslator:
|
||||
"""OpenAI-style chat completions client for LM Studio."""
|
||||
|
||||
@@ -132,6 +218,15 @@ class LMStudioTranslator:
|
||||
self._client = client or httpx.Client(timeout=httpx.Timeout(self.config.timeout_seconds))
|
||||
self._owns_client = client is None
|
||||
self._sleeper = sleeper
|
||||
self._cache_enabled = _cache_enabled("TRANSLATION_CACHE_ENABLED")
|
||||
|
||||
@staticmethod
|
||||
def _generation_settings() -> Dict[str, Any]:
|
||||
return {
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
def build_payload(self, text: str, source_language: str, target_language: str) -> Dict[str, Any]:
|
||||
"""Build the OpenAI-compatible chat completions payload."""
|
||||
@@ -141,9 +236,7 @@ class LMStudioTranslator:
|
||||
{"role": "system", "content": _build_system_prompt(source_language, target_language)},
|
||||
{"role": "user", "content": text},
|
||||
],
|
||||
"temperature": 0.1,
|
||||
"top_p": 1,
|
||||
"stream": False,
|
||||
**self._generation_settings(),
|
||||
}
|
||||
|
||||
def build_user_only_payload(
|
||||
@@ -160,9 +253,7 @@ class LMStudioTranslator:
|
||||
"messages": [
|
||||
{"role": "user", "content": merged_prompt},
|
||||
],
|
||||
"temperature": 0.1,
|
||||
"top_p": 1,
|
||||
"stream": False,
|
||||
**self._generation_settings(),
|
||||
}
|
||||
|
||||
def build_structured_translation_payload(
|
||||
@@ -188,9 +279,125 @@ class LMStudioTranslator:
|
||||
],
|
||||
}
|
||||
],
|
||||
"temperature": 0.1,
|
||||
"top_p": 1,
|
||||
"stream": False,
|
||||
**self._generation_settings(),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_translation_segments(texts: List[str]) -> List[TranslationSegment]:
|
||||
return [
|
||||
TranslationSegment(id=str(index), text=text)
|
||||
for index, text in enumerate(texts)
|
||||
]
|
||||
|
||||
def build_contextual_batches(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: int = DEFAULT_CONTEXTUAL_BATCH_SIZE,
|
||||
context_segments: int = DEFAULT_CONTEXT_SEGMENTS,
|
||||
) -> List[TranslationBatch]:
|
||||
"""Group subtitle segments into small batches with surrounding context."""
|
||||
if batch_size < 1:
|
||||
raise ValueError("batch_size must be at least 1")
|
||||
if context_segments < 0:
|
||||
raise ValueError("context_segments cannot be negative")
|
||||
|
||||
segments = self._build_translation_segments(texts)
|
||||
if not segments:
|
||||
return []
|
||||
|
||||
batches: List[TranslationBatch] = []
|
||||
start_index = 0
|
||||
total_segments = len(segments)
|
||||
|
||||
while start_index < total_segments:
|
||||
remaining = total_segments - start_index
|
||||
current_batch_size = min(batch_size, remaining)
|
||||
trailing_segments = remaining - current_batch_size
|
||||
|
||||
if 0 < trailing_segments < MIN_CONTEXTUAL_BATCH_SIZE and current_batch_size > MIN_CONTEXTUAL_BATCH_SIZE:
|
||||
current_batch_size -= MIN_CONTEXTUAL_BATCH_SIZE - trailing_segments
|
||||
|
||||
end_index = start_index + current_batch_size
|
||||
batches.append(
|
||||
TranslationBatch(
|
||||
previous_segments=segments[max(0, start_index - context_segments):start_index],
|
||||
segments=segments[start_index:end_index],
|
||||
next_segments=segments[end_index:min(total_segments, end_index + context_segments)],
|
||||
)
|
||||
)
|
||||
start_index = end_index
|
||||
|
||||
return batches
|
||||
|
||||
def build_contextual_batch_request(
|
||||
self,
|
||||
batch: TranslationBatch,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build the contextual JSON payload sent to the model."""
|
||||
return {
|
||||
"source_language": source_language or "auto",
|
||||
"target_language": target_language,
|
||||
"previous_segments": [segment.as_payload() for segment in batch.previous_segments],
|
||||
"segments": [segment.as_payload() for segment in batch.segments],
|
||||
"next_segments": [segment.as_payload() for segment in batch.next_segments],
|
||||
}
|
||||
|
||||
def build_contextual_batch_payload(
|
||||
self,
|
||||
batch: TranslationBatch,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build the LM Studio request for contextual subtitle batch translation."""
|
||||
return {
|
||||
"model": self.config.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": _build_contextual_system_prompt(source_language, target_language)},
|
||||
{
|
||||
"role": "user",
|
||||
"content": json.dumps(
|
||||
self.build_contextual_batch_request(batch, source_language, target_language),
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
),
|
||||
},
|
||||
],
|
||||
**self._generation_settings(),
|
||||
}
|
||||
|
||||
def _build_contextual_user_prompt(
|
||||
self,
|
||||
batch: TranslationBatch,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
) -> str:
|
||||
request_payload = self.build_contextual_batch_request(batch, source_language, target_language)
|
||||
return (
|
||||
f"{_build_contextual_system_prompt(source_language, target_language)}\n\n"
|
||||
"USER PAYLOAD FORMAT:\n"
|
||||
f"{json.dumps(request_payload, ensure_ascii=False, indent=2)}\n\n"
|
||||
"EXPECTED OUTPUT:\n"
|
||||
'{"translations":[{"id":"...","translated_text":"..."}]}'
|
||||
)
|
||||
|
||||
def build_contextual_user_only_payload(
|
||||
self,
|
||||
batch: TranslationBatch,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build a fallback contextual payload for models that require a user first turn."""
|
||||
return {
|
||||
"model": self.config.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": self._build_contextual_user_prompt(batch, source_language, target_language),
|
||||
}
|
||||
],
|
||||
**self._generation_settings(),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -219,6 +426,73 @@ class LMStudioTranslator:
|
||||
|
||||
return translated
|
||||
|
||||
@staticmethod
|
||||
def parse_batch_translation_response(content: str, batch: TranslationBatch) -> List[str]:
|
||||
"""Parse and validate the batch translation JSON response."""
|
||||
try:
|
||||
response_payload = json.loads(content)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise TranslationError("LM Studio returned malformed JSON for batch translation.") from exc
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
raise TranslationError("LM Studio batch translation response must be a JSON object.")
|
||||
|
||||
translations = response_payload.get("translations")
|
||||
if not isinstance(translations, list):
|
||||
raise TranslationError("LM Studio batch translation response must include a 'translations' list.")
|
||||
|
||||
expected_ids = batch.segment_ids
|
||||
actual_ids: List[str] = []
|
||||
translated_texts: List[str] = []
|
||||
|
||||
for item in translations:
|
||||
if not isinstance(item, dict):
|
||||
raise TranslationError("LM Studio batch translation response contained a non-object translation item.")
|
||||
|
||||
segment_id = item.get("id")
|
||||
translated_text = item.get("translated_text")
|
||||
|
||||
if not isinstance(segment_id, str) or not segment_id:
|
||||
raise TranslationError("LM Studio batch translation response contained an invalid segment id.")
|
||||
|
||||
if not isinstance(translated_text, str):
|
||||
raise TranslationError(
|
||||
f"LM Studio batch translation response for segment '{segment_id}' did not contain a string translated_text."
|
||||
)
|
||||
|
||||
actual_ids.append(segment_id)
|
||||
translated_texts.append(translated_text.strip())
|
||||
|
||||
if len(actual_ids) != len(expected_ids):
|
||||
raise TranslationError(
|
||||
f"LM Studio batch translation response returned {len(actual_ids)} items "
|
||||
f"for {len(expected_ids)} input segments."
|
||||
)
|
||||
|
||||
missing_ids = [segment_id for segment_id in expected_ids if segment_id not in actual_ids]
|
||||
unexpected_ids = [segment_id for segment_id in actual_ids if segment_id not in expected_ids]
|
||||
if missing_ids or unexpected_ids:
|
||||
raise TranslationError(
|
||||
"LM Studio batch translation response ids did not match the request. "
|
||||
f"Missing: {missing_ids or 'none'}. Unexpected: {unexpected_ids or 'none'}."
|
||||
)
|
||||
|
||||
if actual_ids != expected_ids:
|
||||
raise TranslationError("LM Studio batch translation response ids were out of order.")
|
||||
|
||||
validated_translations: List[str] = []
|
||||
for segment, translated_text in zip(batch.segments, translated_texts):
|
||||
if not segment.text.strip():
|
||||
validated_translations.append("")
|
||||
continue
|
||||
|
||||
if not translated_text:
|
||||
raise TranslationError(f"LM Studio returned an empty translation for segment '{segment.id}'.")
|
||||
|
||||
validated_translations.append(translated_text)
|
||||
|
||||
return validated_translations
|
||||
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.config.api_key}",
|
||||
@@ -305,23 +579,106 @@ class LMStudioTranslator:
|
||||
raise TranslationError("LM Studio returned a non-JSON response.") from last_error
|
||||
raise TranslationError(f"LM Studio request failed: {last_error}") from last_error
|
||||
|
||||
def _translate_contextual_batch(
|
||||
self,
|
||||
batch: TranslationBatch,
|
||||
target_language: str,
|
||||
source_language: str = "auto",
|
||||
) -> List[str]:
|
||||
"""Translate a single contextual subtitle batch with validation and retries."""
|
||||
payload = self.build_contextual_batch_payload(batch, source_language, target_language)
|
||||
cache_payload = {
|
||||
"prompt_version": PROMPT_VERSION,
|
||||
"backend": self.config.backend,
|
||||
"base_url": self.config.base_url,
|
||||
"model": self.config.model,
|
||||
"source_language": source_language or "auto",
|
||||
"target_language": target_language,
|
||||
"request": self.build_contextual_batch_request(batch, source_language, target_language),
|
||||
}
|
||||
cache_path = TRANSLATION_CACHE_DIR / f"{_json_cache_key(cache_payload)}.json"
|
||||
if self._cache_enabled:
|
||||
cached_payload = _read_json_cache(cache_path)
|
||||
cached_translations = cached_payload.get("translations") if cached_payload else None
|
||||
if isinstance(cached_translations, list) and all(isinstance(item, str) for item in cached_translations):
|
||||
if len(cached_translations) == len(batch.segments):
|
||||
print(f"[*] Translation cache hit: {cache_path.name}")
|
||||
return cached_translations
|
||||
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for attempt in range(1, self.config.max_retries + 1):
|
||||
try:
|
||||
response_content = self._post_chat_completion(payload)
|
||||
translations = self.parse_batch_translation_response(response_content, batch)
|
||||
if self._cache_enabled:
|
||||
_write_json_cache(cache_path, {"translations": translations})
|
||||
return translations
|
||||
except (httpx.HTTPError, ValueError, TranslationError) as exc:
|
||||
last_error = exc
|
||||
if self._should_retry_with_user_only_prompt(exc):
|
||||
try:
|
||||
fallback_payload = self.build_contextual_user_only_payload(batch, source_language, target_language)
|
||||
fallback_content = self._post_chat_completion(fallback_payload)
|
||||
translations = self.parse_batch_translation_response(fallback_content, batch)
|
||||
if self._cache_enabled:
|
||||
_write_json_cache(cache_path, {"translations": translations})
|
||||
return translations
|
||||
except (httpx.HTTPError, ValueError, TranslationError) as fallback_exc:
|
||||
last_error = fallback_exc
|
||||
if self._should_retry_with_structured_translation_prompt(last_error):
|
||||
try:
|
||||
structured_payload = self.build_structured_translation_payload(
|
||||
self._build_contextual_user_prompt(batch, source_language, target_language),
|
||||
source_language,
|
||||
target_language,
|
||||
)
|
||||
structured_content = self._post_chat_completion(structured_payload)
|
||||
translations = self.parse_batch_translation_response(structured_content, batch)
|
||||
if self._cache_enabled:
|
||||
_write_json_cache(cache_path, {"translations": translations})
|
||||
return translations
|
||||
except (httpx.HTTPError, ValueError, TranslationError) as structured_exc:
|
||||
last_error = structured_exc
|
||||
|
||||
should_retry = self._should_retry(exc) or isinstance(last_error, TranslationError)
|
||||
if attempt >= self.config.max_retries or not should_retry:
|
||||
break
|
||||
|
||||
self._sleeper(self.config.retry_backoff_seconds * attempt)
|
||||
|
||||
if isinstance(last_error, TranslationError):
|
||||
raise last_error
|
||||
if isinstance(last_error, ValueError):
|
||||
raise TranslationError("LM Studio returned a non-JSON response.") from last_error
|
||||
raise TranslationError(f"LM Studio request failed: {last_error}") from last_error
|
||||
|
||||
def translate_segments(
|
||||
self,
|
||||
texts: List[str],
|
||||
target_language: str,
|
||||
source_language: str = "auto",
|
||||
) -> List[str]:
|
||||
"""Translate an ordered list of subtitle-like segments."""
|
||||
results: List[str] = []
|
||||
for text in texts:
|
||||
results.append(
|
||||
self.translate_text(
|
||||
text=text,
|
||||
"""Translate an ordered list of subtitle-like segments in contextual batches."""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
translated_segments: List[str] = []
|
||||
for batch in self.build_contextual_batches(texts):
|
||||
translated_segments.extend(
|
||||
self._translate_contextual_batch(
|
||||
batch=batch,
|
||||
target_language=target_language,
|
||||
source_language=source_language,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
if len(translated_segments) != len(texts):
|
||||
raise TranslationError(
|
||||
f"LM Studio returned {len(translated_segments)} translated segments for {len(texts)} inputs."
|
||||
)
|
||||
|
||||
return translated_segments
|
||||
|
||||
def close(self) -> None:
|
||||
if self._owns_client:
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from src.audio_separation import DEFAULT_MIX_MODE
|
||||
|
||||
from main import _build_translation_config, build_parser
|
||||
from main import _build_translation_config, _validate_source_args, build_parser
|
||||
|
||||
|
||||
def test_parser_accepts_lmstudio_flags():
|
||||
@@ -69,3 +69,37 @@ def test_parser_defaults_to_instrumental_only_mix_mode():
|
||||
args = parser.parse_args(["https://youtube.com/watch?v=demo"])
|
||||
|
||||
assert args.mix_mode == DEFAULT_MIX_MODE
|
||||
|
||||
|
||||
def test_parser_accepts_local_input_file_without_url():
|
||||
parser = build_parser()
|
||||
|
||||
args = parser.parse_args(["--input-file", "demo.mp4", "--lang", "fr"])
|
||||
|
||||
assert args.url is None
|
||||
assert args.input_file == "demo.mp4"
|
||||
assert args.lang == "fr"
|
||||
|
||||
|
||||
def test_validate_source_args_rejects_missing_source():
|
||||
parser = build_parser()
|
||||
args = parser.parse_args([])
|
||||
|
||||
try:
|
||||
_validate_source_args(args)
|
||||
except SystemExit as exc:
|
||||
assert "Provide either" in str(exc)
|
||||
else:
|
||||
raise AssertionError("Expected SystemExit for missing source")
|
||||
|
||||
|
||||
def test_validate_source_args_rejects_two_sources():
|
||||
parser = build_parser()
|
||||
args = parser.parse_args(["https://youtube.com/watch?v=demo", "--input-file", "demo.mp4"])
|
||||
|
||||
try:
|
||||
_validate_source_args(args)
|
||||
except SystemExit as exc:
|
||||
assert "not both" in str(exc)
|
||||
else:
|
||||
raise AssertionError("Expected SystemExit for two sources")
|
||||
|
||||
@@ -2,10 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from src.core_utils import TranslationError
|
||||
from src import translation
|
||||
from src.translation import LMStudioTranslator, TranslationConfig
|
||||
|
||||
|
||||
@@ -13,6 +16,25 @@ def _mock_client(handler):
|
||||
return httpx.Client(transport=httpx.MockTransport(handler))
|
||||
|
||||
|
||||
def _mock_batch_response(translations):
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": json.dumps({"translations": translations}, ensure_ascii=False),
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _read_request_json(request: httpx.Request):
|
||||
return json.loads(request.read().decode("utf-8"))
|
||||
|
||||
|
||||
def test_translation_config_normalizes_base_url():
|
||||
config = TranslationConfig.from_env(base_url="http://127.0.0.1:1234")
|
||||
|
||||
@@ -21,33 +43,102 @@ def test_translation_config_normalizes_base_url():
|
||||
assert config.model == "gemma-3-4b-it"
|
||||
|
||||
|
||||
def test_build_payload_includes_model_and_prompt():
|
||||
def test_build_contextual_batch_payload_includes_neighboring_segments():
|
||||
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
|
||||
|
||||
payload = translator.build_payload("Hello world", "en", "es")
|
||||
batches = translator.build_contextual_batches(
|
||||
["one", "two", "three", "four", "five", "six", "seven"],
|
||||
)
|
||||
|
||||
assert [len(batch.segments) for batch in batches] == [4, 3]
|
||||
|
||||
payload = translator.build_contextual_batch_payload(batches[0], "en", "es")
|
||||
user_payload = json.loads(payload["messages"][1]["content"])
|
||||
|
||||
assert payload["model"] == "gemma-3-4b-it"
|
||||
assert payload["messages"][0]["role"] == "system"
|
||||
assert "Translate the user-provided text from en to es." in payload["messages"][0]["content"]
|
||||
assert payload["messages"][1]["content"] == "Hello world"
|
||||
assert "expert audiovisual translator and dubbing script adapter" in payload["messages"][0]["content"]
|
||||
assert "Preserve segment ids and output order exactly" in payload["messages"][0]["content"]
|
||||
assert payload["temperature"] == 0.0
|
||||
assert payload["top_p"] == 1.0
|
||||
assert user_payload == {
|
||||
"source_language": "en",
|
||||
"target_language": "es",
|
||||
"previous_segments": [],
|
||||
"segments": [
|
||||
{"id": "0", "text": "one"},
|
||||
{"id": "1", "text": "two"},
|
||||
{"id": "2", "text": "three"},
|
||||
{"id": "3", "text": "four"},
|
||||
],
|
||||
"next_segments": [
|
||||
{"id": "4", "text": "five"},
|
||||
{"id": "5", "text": "six"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_translate_segments_preserves_order_and_blank_segments():
|
||||
def test_translate_segments_batches_context_and_preserves_exact_mapping():
|
||||
requests = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
text = request.read().decode("utf-8")
|
||||
if "first" in text:
|
||||
content = "primero"
|
||||
elif "third" in text:
|
||||
content = "tercero"
|
||||
else:
|
||||
content = "desconocido"
|
||||
return httpx.Response(200, json={"choices": [{"message": {"content": content}}]})
|
||||
payload = _read_request_json(request)
|
||||
batch_request = json.loads(payload["messages"][1]["content"])
|
||||
requests.append(batch_request)
|
||||
|
||||
translations = []
|
||||
for item in batch_request["segments"]:
|
||||
translated_text = "" if not item["text"].strip() else f"es::{item['id']}::{item['text']}"
|
||||
translations.append({"id": item["id"], "translated_text": translated_text})
|
||||
|
||||
return _mock_batch_response(translations)
|
||||
|
||||
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler))
|
||||
|
||||
translated = translator.translate_segments(["first", "", "third"], target_language="es", source_language="en")
|
||||
translated = translator.translate_segments(
|
||||
["first", "second", "", "fourth", "fifth", "sixth", "seventh"],
|
||||
target_language="es",
|
||||
source_language="en",
|
||||
)
|
||||
|
||||
assert translated == ["primero", "", "tercero"]
|
||||
assert translated == [
|
||||
"es::0::first",
|
||||
"es::1::second",
|
||||
"",
|
||||
"es::3::fourth",
|
||||
"es::4::fifth",
|
||||
"es::5::sixth",
|
||||
"es::6::seventh",
|
||||
]
|
||||
assert len(requests) == 2
|
||||
assert [item["id"] for item in requests[0]["segments"]] == ["0", "1", "2", "3"]
|
||||
assert requests[0]["previous_segments"] == []
|
||||
assert [item["id"] for item in requests[0]["next_segments"]] == ["4", "5"]
|
||||
assert [item["id"] for item in requests[1]["previous_segments"]] == ["2", "3"]
|
||||
assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"]
|
||||
|
||||
|
||||
def test_translate_segments_uses_persistent_cache(tmp_path, monkeypatch):
|
||||
requests = {"count": 0}
|
||||
monkeypatch.setattr(translation, "TRANSLATION_CACHE_DIR", tmp_path)
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
requests["count"] += 1
|
||||
payload = _read_request_json(request)
|
||||
batch_request = json.loads(payload["messages"][1]["content"])
|
||||
translations = [
|
||||
{"id": item["id"], "translated_text": f"cached::{item['text']}"}
|
||||
for item in batch_request["segments"]
|
||||
]
|
||||
return _mock_batch_response(translations)
|
||||
|
||||
config = TranslationConfig(model="cache-model")
|
||||
first_translator = LMStudioTranslator(config, client=_mock_client(handler))
|
||||
second_translator = LMStudioTranslator(config, client=_mock_client(handler))
|
||||
|
||||
assert first_translator.translate_segments(["hello", "world"], "fr", "en") == ["cached::hello", "cached::world"]
|
||||
assert second_translator.translate_segments(["hello", "world"], "fr", "en") == ["cached::hello", "cached::world"]
|
||||
assert requests["count"] == 1
|
||||
|
||||
|
||||
def test_retry_on_transient_http_error_then_succeeds():
|
||||
@@ -76,6 +167,77 @@ def test_parse_response_content_rejects_empty_content():
|
||||
LMStudioTranslator.parse_response_content({"choices": [{"message": {"content": " "}}]})
|
||||
|
||||
|
||||
def test_parse_batch_translation_response_rejects_missing_ids():
|
||||
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
|
||||
batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0]
|
||||
|
||||
with pytest.raises(TranslationError, match="ids did not match the request"):
|
||||
LMStudioTranslator.parse_batch_translation_response(
|
||||
json.dumps(
|
||||
{
|
||||
"translations": [
|
||||
{"id": "0", "translated_text": "uno"},
|
||||
{"id": "2", "translated_text": "dos"},
|
||||
{"id": "2", "translated_text": "tres"},
|
||||
]
|
||||
}
|
||||
),
|
||||
batch,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_batch_translation_response_rejects_out_of_order_ids():
|
||||
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
|
||||
batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0]
|
||||
|
||||
with pytest.raises(TranslationError, match="out of order"):
|
||||
LMStudioTranslator.parse_batch_translation_response(
|
||||
json.dumps(
|
||||
{
|
||||
"translations": [
|
||||
{"id": "1", "translated_text": "dos"},
|
||||
{"id": "0", "translated_text": "uno"},
|
||||
{"id": "2", "translated_text": "tres"},
|
||||
]
|
||||
}
|
||||
),
|
||||
batch,
|
||||
)
|
||||
|
||||
|
||||
def test_translate_segments_retries_on_malformed_json_batch_response():
|
||||
attempts = {"count": 0}
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
attempts["count"] += 1
|
||||
payload = _read_request_json(request)
|
||||
batch_request = json.loads(payload["messages"][1]["content"])
|
||||
|
||||
if attempts["count"] == 1:
|
||||
return httpx.Response(200, json={"choices": [{"message": {"content": "not-json"}}]})
|
||||
|
||||
translations = [
|
||||
{"id": item["id"], "translated_text": f"ok::{item['text']}"}
|
||||
for item in batch_request["segments"]
|
||||
]
|
||||
return _mock_batch_response(translations)
|
||||
|
||||
translator = LMStudioTranslator(
|
||||
TranslationConfig(max_retries=2),
|
||||
client=_mock_client(handler),
|
||||
sleeper=lambda _: None,
|
||||
)
|
||||
|
||||
translated = translator.translate_segments(
|
||||
["alpha", "beta", "gamma"],
|
||||
target_language="es",
|
||||
source_language="en",
|
||||
)
|
||||
|
||||
assert translated == ["ok::alpha", "ok::beta", "ok::gamma"]
|
||||
assert attempts["count"] == 2
|
||||
|
||||
|
||||
def test_translate_text_raises_on_malformed_response():
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(200, json={"choices": []})
|
||||
|
||||
85
tests/test_tts_retry.py
Normal file
85
tests/test_tts_retry.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Tests for transient Edge TTS retry behavior."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from src import engines
|
||||
from src.engines import Engine
|
||||
from src.translation import TranslationConfig
|
||||
|
||||
|
||||
def test_synthesize_retries_transient_edge_tts_failure(tmp_path, monkeypatch):
|
||||
calls = {"count": 0}
|
||||
|
||||
class FakeCommunicate:
|
||||
def __init__(self, text, voice, rate):
|
||||
self.text = text
|
||||
self.voice = voice
|
||||
self.rate = rate
|
||||
|
||||
async def save(self, out_path):
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 1:
|
||||
raise RuntimeError("transient 503")
|
||||
with open(out_path, "wb") as audio_file:
|
||||
audio_file.write(b"0" * 2048)
|
||||
|
||||
async def no_sleep(_seconds):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("src.engines.edge_tts.Communicate", FakeCommunicate)
|
||||
monkeypatch.setattr("src.engines.asyncio.sleep", no_sleep)
|
||||
monkeypatch.setattr("src.engines.DEFAULT_TTS_MAX_RETRIES", 2)
|
||||
|
||||
engine = Engine(
|
||||
"cpu",
|
||||
translation_config=TranslationConfig(
|
||||
base_url="http://127.0.0.1:1234/v1",
|
||||
api_key="test-key",
|
||||
model="test-model",
|
||||
),
|
||||
)
|
||||
out_path = tmp_path / "tts.mp3"
|
||||
|
||||
asyncio.run(engine.synthesize("Bonjour", "fr", out_path))
|
||||
|
||||
assert calls["count"] == 2
|
||||
assert out_path.exists()
|
||||
assert out_path.stat().st_size == 2048
|
||||
|
||||
|
||||
def test_synthesize_uses_tts_cache(tmp_path, monkeypatch):
|
||||
calls = {"count": 0}
|
||||
cache_dir = tmp_path / "tts-cache"
|
||||
monkeypatch.setattr(engines, "TTS_CACHE_DIR", cache_dir)
|
||||
|
||||
class FakeCommunicate:
|
||||
def __init__(self, text, voice, rate):
|
||||
self.text = text
|
||||
self.voice = voice
|
||||
self.rate = rate
|
||||
|
||||
async def save(self, out_path):
|
||||
calls["count"] += 1
|
||||
with open(out_path, "wb") as audio_file:
|
||||
audio_file.write(b"1" * 2048)
|
||||
|
||||
monkeypatch.setattr("src.engines.edge_tts.Communicate", FakeCommunicate)
|
||||
|
||||
engine = Engine(
|
||||
"cpu",
|
||||
translation_config=TranslationConfig(
|
||||
base_url="http://127.0.0.1:1234/v1",
|
||||
api_key="test-key",
|
||||
model="test-model",
|
||||
),
|
||||
)
|
||||
first_out = tmp_path / "first.mp3"
|
||||
second_out = tmp_path / "second.mp3"
|
||||
|
||||
asyncio.run(engine.synthesize("Bonjour", "fr", first_out))
|
||||
asyncio.run(engine.synthesize("Bonjour", "fr", second_out))
|
||||
|
||||
assert calls["count"] == 1
|
||||
assert first_out.read_bytes() == second_out.read_bytes()
|
||||
151
tests/test_web_app.py
Normal file
151
tests/test_web_app.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Tests for the Gradio web UI command adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
import web_app
|
||||
from web_app import (
|
||||
DubJob,
|
||||
_job_progress,
|
||||
_stage_uploaded_cookies,
|
||||
build_pipeline_command,
|
||||
create_app,
|
||||
load_translation_settings,
|
||||
save_translation_settings,
|
||||
)
|
||||
|
||||
|
||||
def test_build_pipeline_command_uses_cli_parser_defaults():
|
||||
command = build_pipeline_command({"url": "https://youtube.com/watch?v=demo"})
|
||||
|
||||
assert command[:3] == [sys.executable, command[1], "https://youtube.com/watch?v=demo"]
|
||||
assert "--lang" in command
|
||||
assert command[command.index("--lang") + 1] == "es"
|
||||
assert "--mix-mode" in command
|
||||
assert command[command.index("--mix-mode") + 1] == "instrumental-only"
|
||||
|
||||
|
||||
def test_build_pipeline_command_accepts_optional_settings():
|
||||
command = build_pipeline_command(
|
||||
{
|
||||
"url": "https://youtube.com/watch?v=demo",
|
||||
"lang": "fr",
|
||||
"browser": "chrome",
|
||||
"whisper_model": "small",
|
||||
"lmstudio_base_url": "http://localhost:1234/v1",
|
||||
"lmstudio_model": "gemma-custom",
|
||||
"gpu": "on",
|
||||
}
|
||||
)
|
||||
|
||||
assert command[command.index("--lang") + 1] == "fr"
|
||||
assert command[command.index("--browser") + 1] == "chrome"
|
||||
assert command[command.index("--whisper_model") + 1] == "small"
|
||||
assert command[command.index("--lmstudio-base-url") + 1] == "http://localhost:1234/v1"
|
||||
assert command[command.index("--lmstudio-model") + 1] == "gemma-custom"
|
||||
assert "--gpu" in command
|
||||
|
||||
|
||||
def test_build_pipeline_command_accepts_uploaded_mp4():
|
||||
command = build_pipeline_command(
|
||||
{
|
||||
"input_file": "C:\\videos\\demo.mp4",
|
||||
"lang": "de",
|
||||
}
|
||||
)
|
||||
|
||||
assert "https://youtube.com/watch?v=demo" not in command
|
||||
assert "--input-file" in command
|
||||
assert command[command.index("--input-file") + 1] == "C:\\videos\\demo.mp4"
|
||||
assert command[command.index("--lang") + 1] == "de"
|
||||
|
||||
|
||||
def test_create_app_builds_gradio_blocks():
|
||||
app = create_app()
|
||||
|
||||
assert app.title == "Gradio YouTube Auto Dub"
|
||||
|
||||
|
||||
def test_save_and_load_translation_settings(tmp_path, monkeypatch):
|
||||
settings_file = tmp_path / "web_settings.json"
|
||||
monkeypatch.setattr(web_app, "SETTINGS_FILE", settings_file)
|
||||
|
||||
base_url, api_key, model, message = save_translation_settings(
|
||||
"http://openai-compatible.local:8080/v1",
|
||||
"secret-key",
|
||||
"custom-model",
|
||||
)
|
||||
|
||||
assert base_url == "http://openai-compatible.local:8080/v1"
|
||||
assert api_key == "secret-key"
|
||||
assert model == "custom-model"
|
||||
assert str(settings_file) in message
|
||||
assert load_translation_settings() == {
|
||||
"base_url": "http://openai-compatible.local:8080/v1",
|
||||
"api_key": "secret-key",
|
||||
"model": "custom-model",
|
||||
}
|
||||
|
||||
|
||||
def test_load_translation_settings_uses_env_defaults(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(web_app, "SETTINGS_FILE", tmp_path / "missing.json")
|
||||
monkeypatch.setenv("LM_STUDIO_BASE_URL", "http://env-host:1234/v1")
|
||||
monkeypatch.setenv("LM_STUDIO_API_KEY", "env-key")
|
||||
monkeypatch.setenv("LM_STUDIO_MODEL", "env-model")
|
||||
|
||||
assert load_translation_settings() == {
|
||||
"base_url": "http://env-host:1234/v1",
|
||||
"api_key": "env-key",
|
||||
"model": "env-model",
|
||||
}
|
||||
|
||||
|
||||
def test_stage_uploaded_cookies_copies_to_upload_dir(tmp_path, monkeypatch):
|
||||
upload_dir = tmp_path / "uploads"
|
||||
source_file = tmp_path / "cookies.txt"
|
||||
source_file.write_text("# Netscape HTTP Cookie File\n", encoding="utf-8")
|
||||
monkeypatch.setattr(web_app, "UPLOAD_DIR", upload_dir)
|
||||
|
||||
staged_path = _stage_uploaded_cookies(str(source_file))
|
||||
|
||||
assert staged_path.endswith(".txt")
|
||||
assert staged_path != str(source_file)
|
||||
assert upload_dir in web_app.Path(staged_path).parents
|
||||
assert web_app.Path(staged_path).read_text(encoding="utf-8") == "# Netscape HTTP Cookie File\n"
|
||||
|
||||
|
||||
def test_stage_uploaded_cookies_rejects_unsupported_extension(tmp_path):
|
||||
source_file = tmp_path / "cookies.json"
|
||||
source_file.write_text("{}", encoding="utf-8")
|
||||
|
||||
try:
|
||||
_stage_uploaded_cookies(str(source_file))
|
||||
except ValueError as exc:
|
||||
assert "Expected one of" in str(exc)
|
||||
else:
|
||||
raise AssertionError("Expected ValueError for unsupported cookie upload")
|
||||
|
||||
|
||||
def test_job_progress_tracks_pipeline_steps(tmp_path):
|
||||
log_path = tmp_path / "job.log"
|
||||
log_path.write_text("STEP 1: PREPARING CONTENT\nSTEP 2: SPEECH TRANSCRIPTION\n", encoding="utf-8")
|
||||
job = DubJob(id="demo", command=[], log_path=log_path, status="running")
|
||||
|
||||
progress, steps_html = _job_progress(job)
|
||||
|
||||
assert progress == 25
|
||||
assert "[done]" in steps_html
|
||||
assert "[active]" in steps_html
|
||||
assert "Speech transcription" in steps_html
|
||||
|
||||
|
||||
def test_job_progress_marks_succeeded_complete(tmp_path):
|
||||
log_path = tmp_path / "job.log"
|
||||
log_path.write_text("", encoding="utf-8")
|
||||
job = DubJob(id="demo", command=[], log_path=log_path, status="succeeded")
|
||||
|
||||
progress, steps_html = _job_progress(job)
|
||||
|
||||
assert progress == 100
|
||||
assert "[todo]" not in steps_html
|
||||
555
web_app.py
Normal file
555
web_app.py
Normal file
@@ -0,0 +1,555 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Gradio web UI for launching YouTube Auto Dub jobs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
import html
|
||||
import json
|
||||
from pathlib import Path
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from main import build_parser
|
||||
from src.audio_separation import DEFAULT_MIX_MODE
|
||||
from src.engines import OUTPUT_DIR
|
||||
from src.translation import (
|
||||
DEFAULT_LM_STUDIO_API_KEY,
|
||||
DEFAULT_LM_STUDIO_BASE_URL,
|
||||
DEFAULT_LM_STUDIO_MODEL,
|
||||
)
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
LOG_DIR = BASE_DIR / "logs" / "gradio"
|
||||
SETTINGS_FILE = BASE_DIR / ".cache" / "web_settings.json"
|
||||
UPLOAD_DIR = BASE_DIR / ".cache" / "uploads"
|
||||
PIPELINE_STEPS = [
|
||||
("STEP 1", "Preparing content"),
|
||||
("STEP 2", "Speech transcription"),
|
||||
("STEP 3", "Intelligent chunking"),
|
||||
("STEP 4", "Translation"),
|
||||
("STEP 5", "Dub audio synthesis"),
|
||||
("STEP 6", "Subtitle generation"),
|
||||
("STEP 7", "Audio bed preparation"),
|
||||
("STEP 8", "Final video rendering"),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DubJob:
|
||||
"""Runtime state for a web-launched dub job."""
|
||||
|
||||
id: str
|
||||
command: list[str]
|
||||
log_path: Path
|
||||
env_overrides: dict[str, str] = field(default_factory=dict)
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
status: str = "queued"
|
||||
returncode: int | None = None
|
||||
completed_at: datetime | None = None
|
||||
|
||||
|
||||
JOBS: dict[str, DubJob] = {}
|
||||
JOBS_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _default_translation_settings() -> dict[str, str]:
|
||||
return {
|
||||
"base_url": os.getenv("LM_STUDIO_BASE_URL") or DEFAULT_LM_STUDIO_BASE_URL,
|
||||
"api_key": os.getenv("LM_STUDIO_API_KEY") or DEFAULT_LM_STUDIO_API_KEY,
|
||||
"model": os.getenv("LM_STUDIO_MODEL") or DEFAULT_LM_STUDIO_MODEL,
|
||||
}
|
||||
|
||||
|
||||
def load_translation_settings() -> dict[str, str]:
|
||||
"""Load saved OpenAI-compatible translation settings."""
|
||||
settings = _default_translation_settings()
|
||||
if not SETTINGS_FILE.exists():
|
||||
return settings
|
||||
|
||||
try:
|
||||
payload = json.loads(SETTINGS_FILE.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return settings
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
return settings
|
||||
|
||||
for key in settings:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
settings[key] = value.strip()
|
||||
return settings
|
||||
|
||||
|
||||
def save_translation_settings(base_url: str, api_key: str, model: str) -> tuple[str, str, str, str]:
|
||||
"""Persist OpenAI-compatible endpoint settings for future web jobs."""
|
||||
settings = {
|
||||
"base_url": (base_url or "").strip() or DEFAULT_LM_STUDIO_BASE_URL,
|
||||
"api_key": (api_key or "").strip() or DEFAULT_LM_STUDIO_API_KEY,
|
||||
"model": (model or "").strip() or DEFAULT_LM_STUDIO_MODEL,
|
||||
}
|
||||
SETTINGS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
SETTINGS_FILE.write_text(json.dumps(settings, indent=2), encoding="utf-8")
|
||||
return (
|
||||
settings["base_url"],
|
||||
settings["api_key"],
|
||||
settings["model"],
|
||||
f"Saved settings to {SETTINGS_FILE}",
|
||||
)
|
||||
|
||||
|
||||
def _utc_iso(value: datetime | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return value.astimezone(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def build_pipeline_command(form: dict[str, str | bool]) -> list[str]:
|
||||
"""Build a validated command for the existing CLI pipeline."""
|
||||
parser = build_parser()
|
||||
args = parser.parse_args(_form_to_cli_args(form))
|
||||
command = [
|
||||
sys.executable,
|
||||
str(BASE_DIR / "main.py"),
|
||||
"--lang",
|
||||
args.lang,
|
||||
"--mix-mode",
|
||||
args.mix_mode,
|
||||
]
|
||||
if args.url:
|
||||
command.insert(2, args.url)
|
||||
if args.input_file:
|
||||
command.extend(["--input-file", args.input_file])
|
||||
if args.translation_backend:
|
||||
command.extend(["--translation-backend", args.translation_backend])
|
||||
|
||||
optional_flags = {
|
||||
"--browser": args.browser,
|
||||
"--cookies": args.cookies,
|
||||
"--whisper_model": args.whisper_model,
|
||||
"--lmstudio-base-url": args.lmstudio_base_url,
|
||||
"--lmstudio-model": args.lmstudio_model,
|
||||
}
|
||||
for flag, value in optional_flags.items():
|
||||
if value:
|
||||
command.extend([flag, value])
|
||||
|
||||
if args.gpu:
|
||||
command.append("--gpu")
|
||||
|
||||
return command
|
||||
|
||||
|
||||
def _form_to_cli_args(form: dict[str, str | bool]) -> list[str]:
|
||||
url = (form.get("url") or "").strip()
|
||||
input_file = (form.get("input_file") or "").strip()
|
||||
if not url and not input_file:
|
||||
raise ValueError("A YouTube URL or uploaded MP4 is required.")
|
||||
if url and input_file:
|
||||
raise ValueError("Use either a YouTube URL or uploaded MP4, not both.")
|
||||
|
||||
cli_args = [url] if url else []
|
||||
if input_file:
|
||||
cli_args.extend(["--input-file", input_file])
|
||||
field_flags = {
|
||||
"lang": "--lang",
|
||||
"browser": "--browser",
|
||||
"cookies": "--cookies",
|
||||
"whisper_model": "--whisper_model",
|
||||
"mix_mode": "--mix-mode",
|
||||
"translation_backend": "--translation-backend",
|
||||
"lmstudio_base_url": "--lmstudio-base-url",
|
||||
"lmstudio_model": "--lmstudio-model",
|
||||
}
|
||||
|
||||
defaults = {
|
||||
"lang": "es",
|
||||
"mix_mode": DEFAULT_MIX_MODE,
|
||||
"translation_backend": "lmstudio",
|
||||
}
|
||||
|
||||
for field_name, flag in field_flags.items():
|
||||
value = (form.get(field_name) or defaults.get(field_name) or "").strip()
|
||||
if value:
|
||||
cli_args.extend([flag, value])
|
||||
|
||||
gpu_value = form.get("gpu")
|
||||
if gpu_value is True or str(gpu_value).lower() in {"1", "true", "on", "yes"}:
|
||||
cli_args.append("--gpu")
|
||||
|
||||
return cli_args
|
||||
|
||||
|
||||
def _stage_uploaded_mp4(uploaded_file: str | None) -> str:
|
||||
return _stage_uploaded_file(uploaded_file, allowed_suffixes={".mp4"}, fallback_name="upload")
|
||||
|
||||
|
||||
def _stage_uploaded_cookies(uploaded_file: str | None) -> str:
|
||||
return _stage_uploaded_file(
|
||||
uploaded_file,
|
||||
allowed_suffixes={".txt", ".cookies", ".cookie"},
|
||||
fallback_name="cookies",
|
||||
)
|
||||
|
||||
|
||||
def _stage_uploaded_file(
|
||||
uploaded_file: str | None,
|
||||
allowed_suffixes: set[str],
|
||||
fallback_name: str,
|
||||
) -> str:
|
||||
if not uploaded_file:
|
||||
return ""
|
||||
|
||||
source_path = Path(uploaded_file)
|
||||
suffix = source_path.suffix.lower()
|
||||
if suffix not in allowed_suffixes:
|
||||
expected = ", ".join(sorted(allowed_suffixes))
|
||||
raise ValueError(f"Unsupported upload type. Expected one of: {expected}.")
|
||||
if not source_path.exists():
|
||||
raise FileNotFoundError(f"Uploaded file not found: {source_path}")
|
||||
|
||||
safe_stem = "".join(char if char.isalnum() or char in {"-", "_"} else "_" for char in source_path.stem)
|
||||
staged_name = f"{uuid.uuid4().hex[:12]}_{safe_stem or fallback_name}{suffix}"
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
staged_path = UPLOAD_DIR / staged_name
|
||||
shutil.copy2(source_path, staged_path)
|
||||
return str(staged_path)
|
||||
|
||||
|
||||
def _format_job_status(job: DubJob | None) -> str:
|
||||
if job is None:
|
||||
return "Ready"
|
||||
|
||||
lines = [
|
||||
f"Job: {job.id}",
|
||||
f"Status: {job.status}",
|
||||
f"Created: {_utc_iso(job.created_at)}",
|
||||
]
|
||||
if job.completed_at:
|
||||
lines.append(f"Completed: {_utc_iso(job.completed_at)}")
|
||||
if job.returncode is not None:
|
||||
lines.append(f"Return code: {job.returncode}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _read_log_tail(log_path: Path, max_chars: int = 20000) -> str:
|
||||
if not log_path.exists():
|
||||
return ""
|
||||
text = log_path.read_text(encoding="utf-8", errors="replace")
|
||||
return text[-max_chars:]
|
||||
|
||||
|
||||
def _job_progress(job: DubJob | None) -> tuple[int, str]:
|
||||
"""Return a coarse progress percentage and HTML step summary."""
|
||||
if job is None:
|
||||
return 0, _render_steps_html(0, "queued")
|
||||
|
||||
log_text = _read_log_tail(job.log_path)
|
||||
current_step = 0
|
||||
for index, (marker, _) in enumerate(PIPELINE_STEPS, start=1):
|
||||
if marker in log_text:
|
||||
current_step = index
|
||||
|
||||
if job.status == "succeeded":
|
||||
return 100, _render_steps_html(len(PIPELINE_STEPS), job.status)
|
||||
|
||||
progress = int((current_step / len(PIPELINE_STEPS)) * 100)
|
||||
if job.status == "running" and progress == 0:
|
||||
progress = 3
|
||||
return progress, _render_steps_html(current_step, job.status)
|
||||
|
||||
|
||||
def _render_steps_html(current_step: int, status: str) -> str:
|
||||
rows = []
|
||||
failed = status == "failed"
|
||||
for index, (_, label) in enumerate(PIPELINE_STEPS, start=1):
|
||||
if failed and index == max(current_step, 1):
|
||||
state = "failed"
|
||||
elif index < current_step or status == "succeeded":
|
||||
state = "done"
|
||||
elif index == current_step and status in {"queued", "running"}:
|
||||
state = "active"
|
||||
else:
|
||||
state = "todo"
|
||||
|
||||
rows.append(
|
||||
"<li>"
|
||||
f"<strong>[{html.escape(state)}]</strong> "
|
||||
f"{index}. {html.escape(label)}"
|
||||
"</li>"
|
||||
)
|
||||
|
||||
return "<ul>" + "".join(rows) + "</ul>"
|
||||
|
||||
|
||||
def _render_progress_html(progress: int) -> str:
|
||||
bounded_progress = max(0, min(100, int(progress)))
|
||||
return (
|
||||
"<div>"
|
||||
"<label><strong>Progress</strong></label>"
|
||||
f"<progress value='{bounded_progress}' max='100' style='width: 100%; height: 24px;'></progress>"
|
||||
f"<div>{bounded_progress}%</div>"
|
||||
"</div>"
|
||||
)
|
||||
|
||||
|
||||
def _run_job(job: DubJob) -> None:
|
||||
with JOBS_LOCK:
|
||||
job.status = "running"
|
||||
|
||||
env = os.environ.copy()
|
||||
env["PYTHONUNBUFFERED"] = "1"
|
||||
env.update(job.env_overrides)
|
||||
|
||||
with job.log_path.open("w", encoding="utf-8", errors="replace") as log_file:
|
||||
log_file.write("Gradio started a YouTube Auto Dub job.\n")
|
||||
log_file.write(f"Command: {' '.join(job.command)}\n\n")
|
||||
log_file.flush()
|
||||
|
||||
process = subprocess.Popen(
|
||||
job.command,
|
||||
cwd=BASE_DIR,
|
||||
env=env,
|
||||
stdout=log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
returncode = process.wait()
|
||||
|
||||
with JOBS_LOCK:
|
||||
job.returncode = returncode
|
||||
job.completed_at = datetime.now(timezone.utc)
|
||||
job.status = "succeeded" if returncode == 0 else "failed"
|
||||
|
||||
|
||||
def _list_outputs() -> list[Path]:
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return sorted(
|
||||
(path for path in OUTPUT_DIR.glob("*") if path.is_file()),
|
||||
key=lambda path: path.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
def _output_choices() -> list[str]:
|
||||
return [path.name for path in _list_outputs()[:20]]
|
||||
|
||||
|
||||
def _start_job(
|
||||
url: str,
|
||||
uploaded_mp4: str | None,
|
||||
lang: str,
|
||||
whisper_model: str,
|
||||
mix_mode: str,
|
||||
browser: str,
|
||||
cookies_upload: str | None,
|
||||
lmstudio_base_url: str,
|
||||
lmstudio_api_key: str,
|
||||
lmstudio_model: str,
|
||||
gpu: bool,
|
||||
) -> tuple[str, str, str, str, gr.Dropdown]:
|
||||
saved_settings = load_translation_settings()
|
||||
base_url = (lmstudio_base_url or "").strip() or saved_settings["base_url"]
|
||||
api_key = (lmstudio_api_key or "").strip() or saved_settings["api_key"]
|
||||
model = (lmstudio_model or "").strip() or saved_settings["model"]
|
||||
try:
|
||||
input_file = _stage_uploaded_mp4(uploaded_mp4)
|
||||
except (OSError, ValueError) as exc:
|
||||
message = str(exc) or "Invalid uploaded MP4."
|
||||
return "", message, _render_progress_html(0), _render_steps_html(0, "failed"), gr.update(choices=_output_choices())
|
||||
try:
|
||||
cookies = _stage_uploaded_cookies(cookies_upload)
|
||||
except (OSError, ValueError) as exc:
|
||||
message = str(exc) or "Invalid uploaded cookies file."
|
||||
return "", message, _render_progress_html(0), _render_steps_html(0, "failed"), gr.update(choices=_output_choices())
|
||||
|
||||
form = {
|
||||
"url": url,
|
||||
"input_file": input_file,
|
||||
"lang": lang,
|
||||
"whisper_model": whisper_model,
|
||||
"mix_mode": mix_mode,
|
||||
"browser": browser,
|
||||
"cookies": cookies,
|
||||
"translation_backend": "lmstudio",
|
||||
"lmstudio_base_url": base_url,
|
||||
"lmstudio_model": model,
|
||||
"gpu": gpu,
|
||||
}
|
||||
|
||||
try:
|
||||
command = build_pipeline_command(form)
|
||||
except (SystemExit, ValueError) as exc:
|
||||
message = str(exc) or "Invalid job options."
|
||||
return "", message, _render_progress_html(0), _render_steps_html(0, "failed"), gr.update(choices=_output_choices())
|
||||
|
||||
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
job_id = uuid.uuid4().hex[:12]
|
||||
job = DubJob(
|
||||
id=job_id,
|
||||
command=command,
|
||||
log_path=LOG_DIR / f"{job_id}.log",
|
||||
env_overrides={
|
||||
"LM_STUDIO_BASE_URL": base_url,
|
||||
"LM_STUDIO_API_KEY": api_key,
|
||||
"LM_STUDIO_MODEL": model,
|
||||
},
|
||||
)
|
||||
|
||||
with JOBS_LOCK:
|
||||
JOBS[job.id] = job
|
||||
|
||||
thread = threading.Thread(target=_run_job, args=(job,), daemon=True)
|
||||
thread.start()
|
||||
progress_value, steps_html = _job_progress(job)
|
||||
return job.id, _format_job_status(job), _render_progress_html(progress_value), steps_html, gr.update(choices=_output_choices())
|
||||
|
||||
|
||||
def _refresh_job(job_id: str) -> tuple[str, str, str, gr.Dropdown]:
|
||||
with JOBS_LOCK:
|
||||
job = JOBS.get(job_id)
|
||||
|
||||
if job is None:
|
||||
return "Ready", _render_progress_html(0), _render_steps_html(0, "queued"), gr.update(choices=_output_choices())
|
||||
|
||||
progress_value, steps_html = _job_progress(job)
|
||||
return _format_job_status(job), _render_progress_html(progress_value), steps_html, gr.update(choices=_output_choices())
|
||||
|
||||
|
||||
def _select_output(filename: str | None) -> str | None:
|
||||
if not filename:
|
||||
return None
|
||||
|
||||
output_path = OUTPUT_DIR / filename
|
||||
if not output_path.exists() or not output_path.is_file():
|
||||
return None
|
||||
return str(output_path)
|
||||
|
||||
|
||||
def create_app() -> gr.Blocks:
|
||||
"""Create the Gradio app."""
|
||||
saved_settings = load_translation_settings()
|
||||
with gr.Blocks(title="Gradio YouTube Auto Dub") as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
# YouTube Auto Dub
|
||||
Start local dubbing jobs, watch progress, and collect finished videos.
|
||||
"""
|
||||
)
|
||||
job_id = gr.State("")
|
||||
log_timer = gr.Timer(value=2.0, active=True)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=5):
|
||||
url = gr.Textbox(label="YouTube URL", placeholder="https://www.youtube.com/watch?v=...")
|
||||
uploaded_mp4 = gr.File(
|
||||
label="Upload MP4",
|
||||
file_types=[".mp4"],
|
||||
type="filepath",
|
||||
)
|
||||
with gr.Row():
|
||||
lang = gr.Textbox(label="Target Language", value="es", max_lines=1)
|
||||
whisper_model = gr.Dropdown(
|
||||
label="Whisper Model",
|
||||
choices=["", "tiny", "base", "small", "medium", "large-v3"],
|
||||
value="",
|
||||
)
|
||||
with gr.Row():
|
||||
mix_mode = gr.Dropdown(
|
||||
label="Mix Mode",
|
||||
choices=[DEFAULT_MIX_MODE, "dub-only", "original-audio"],
|
||||
value=DEFAULT_MIX_MODE,
|
||||
)
|
||||
browser = gr.Dropdown(
|
||||
label="Browser Cookies",
|
||||
choices=["", "chrome", "edge", "firefox", "brave"],
|
||||
value="",
|
||||
)
|
||||
cookies_upload = gr.File(
|
||||
label="Upload Cookies File",
|
||||
file_types=[".txt", ".cookies", ".cookie"],
|
||||
type="filepath",
|
||||
)
|
||||
|
||||
with gr.Accordion("OpenAI-Compatible Settings", open=False):
|
||||
lmstudio_base_url = gr.Textbox(
|
||||
label="Endpoint",
|
||||
value=saved_settings["base_url"],
|
||||
placeholder=DEFAULT_LM_STUDIO_BASE_URL,
|
||||
)
|
||||
lmstudio_api_key = gr.Textbox(
|
||||
label="API Key",
|
||||
value=saved_settings["api_key"],
|
||||
type="password",
|
||||
)
|
||||
lmstudio_model = gr.Textbox(
|
||||
label="Model",
|
||||
value=saved_settings["model"],
|
||||
placeholder=DEFAULT_LM_STUDIO_MODEL,
|
||||
)
|
||||
with gr.Row():
|
||||
save_settings = gr.Button("Save Settings")
|
||||
settings_status = gr.Textbox(
|
||||
label="Settings Status",
|
||||
value=f"Loaded from {SETTINGS_FILE if SETTINGS_FILE.exists() else 'environment defaults'}",
|
||||
interactive=False,
|
||||
)
|
||||
gpu = gr.Checkbox(label="Prefer GPU", value=False)
|
||||
|
||||
start = gr.Button("Start Dub", variant="primary")
|
||||
|
||||
with gr.Column(scale=7):
|
||||
status = gr.Textbox(label="Job Status", value="Ready", lines=5, interactive=False)
|
||||
progress = gr.HTML(value=_render_progress_html(0))
|
||||
steps = gr.HTML(label="Steps", value=_render_steps_html(0, "queued"))
|
||||
refresh = gr.Button("Refresh")
|
||||
|
||||
with gr.Row():
|
||||
output_choice = gr.Dropdown(label="Finished Outputs", choices=_output_choices(), interactive=True)
|
||||
output_file = gr.File(label="Download Selected Output", interactive=False)
|
||||
|
||||
inputs = [
|
||||
url,
|
||||
uploaded_mp4,
|
||||
lang,
|
||||
whisper_model,
|
||||
mix_mode,
|
||||
browser,
|
||||
cookies_upload,
|
||||
lmstudio_base_url,
|
||||
lmstudio_api_key,
|
||||
lmstudio_model,
|
||||
gpu,
|
||||
]
|
||||
save_settings.click(
|
||||
save_translation_settings,
|
||||
inputs=[lmstudio_base_url, lmstudio_api_key, lmstudio_model],
|
||||
outputs=[lmstudio_base_url, lmstudio_api_key, lmstudio_model, settings_status],
|
||||
)
|
||||
start.click(
|
||||
_start_job,
|
||||
inputs=inputs,
|
||||
outputs=[job_id, status, progress, steps, output_choice],
|
||||
)
|
||||
refresh.click(_refresh_job, inputs=[job_id], outputs=[status, progress, steps, output_choice])
|
||||
log_timer.tick(_refresh_job, inputs=[job_id], outputs=[status, progress, steps, output_choice])
|
||||
output_choice.change(_select_output, inputs=[output_choice], outputs=[output_file])
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
server_name = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
|
||||
server_port = int(os.getenv("PORT", "7860"))
|
||||
app.launch(server_name=server_name, server_port=server_port)
|
||||
Reference in New Issue
Block a user