feat(translation): contextual batch translation with strict validation

This commit is contained in:
2026-03-30 19:05:50 +01:00
parent 3c9b3c8090
commit d83c57cda3
2 changed files with 462 additions and 31 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import json
import os
import time
from dataclasses import dataclass
@@ -16,6 +17,9 @@ DEFAULT_LM_STUDIO_BASE_URL = "http://127.0.0.1:1234/v1"
DEFAULT_LM_STUDIO_API_KEY = "lm-studio"
DEFAULT_LM_STUDIO_MODEL = "gemma-3-4b-it"
DEFAULT_TRANSLATION_BACKEND = "lmstudio"
DEFAULT_CONTEXTUAL_BATCH_SIZE = 5
MIN_CONTEXTUAL_BATCH_SIZE = 3
DEFAULT_CONTEXT_SEGMENTS = 2
def _normalize_base_url(base_url: str) -> str:
@@ -118,6 +122,53 @@ def _build_system_prompt(source_language: str, target_language: str) -> str:
)
def _build_contextual_system_prompt(source_language: str, target_language: str) -> str:
source_descriptor = source_language or "auto"
return (
"You are an expert audiovisual translator for dubbed video content.\n\n"
f"Translate subtitle segments from {source_descriptor} to {target_language} for natural spoken dubbing.\n\n"
"Rules:\n"
"- Preserve meaning, intent, tone, and subtext.\n"
"- Use surrounding subtitle context to resolve ambiguity.\n"
"- Do not summarize.\n"
"- Do not simplify unless needed for natural speech.\n"
"- Do not add explanations, notes, or commentary.\n"
"- Preserve humor, sarcasm, emotional tone, and register.\n"
"- Keep names, brands, URLs, emails, file paths, code, and product names unchanged unless transliteration is clearly needed.\n"
"- Keep the translation natural for spoken dubbing.\n"
"- Preserve segment boundaries exactly.\n"
"- Return exactly one translated item per input segment.\n"
"- Output order must match input order exactly.\n"
"- If an input segment is empty or whitespace-only, return an empty translated_text for that id.\n"
"- Output valid JSON only.\n"
'- Return this exact schema: {"translations":[{"id":"...","translated_text":"..."}]}.'
)
@dataclass(frozen=True)
class TranslationSegment:
    """An immutable subtitle segment prepared for contextual batch translation."""

    # Stable identifier used to match translated items back to inputs.
    id: str
    # Raw subtitle text; may be empty or whitespace-only.
    text: str

    def as_payload(self) -> Dict[str, str]:
        """Return the JSON-serializable form sent to the model."""
        return dict(id=self.id, text=self.text)
@dataclass(frozen=True)
class TranslationBatch:
    """A contextual subtitle translation batch.

    Only ``segments`` is translated; the neighbor lists supply read-only
    context around the batch.
    """

    # Context-only segments immediately before the batch.
    previous_segments: List[TranslationSegment]
    # Segments the model must translate, in order.
    segments: List[TranslationSegment]
    # Context-only segments immediately after the batch.
    next_segments: List[TranslationSegment]

    @property
    def segment_ids(self) -> List[str]:
        """Ordered ids of the segments to translate (context excluded)."""
        ids: List[str] = []
        for entry in self.segments:
            ids.append(entry.id)
        return ids
class LMStudioTranslator:
"""OpenAI-style chat completions client for LM Studio."""
@@ -133,6 +184,14 @@ class LMStudioTranslator:
self._owns_client = client is None
self._sleeper = sleeper
@staticmethod
def _generation_settings() -> Dict[str, Any]:
return {
"temperature": 0.0,
"top_p": 1.0,
"stream": False,
}
def build_payload(self, text: str, source_language: str, target_language: str) -> Dict[str, Any]:
"""Build the OpenAI-compatible chat completions payload."""
return {
@@ -141,9 +200,7 @@ class LMStudioTranslator:
{"role": "system", "content": _build_system_prompt(source_language, target_language)},
{"role": "user", "content": text},
],
"temperature": 0.1,
"top_p": 1,
"stream": False,
**self._generation_settings(),
}
def build_user_only_payload(
@@ -160,9 +217,7 @@ class LMStudioTranslator:
"messages": [
{"role": "user", "content": merged_prompt},
],
"temperature": 0.1,
"top_p": 1,
"stream": False,
**self._generation_settings(),
}
def build_structured_translation_payload(
@@ -188,9 +243,125 @@ class LMStudioTranslator:
],
}
],
"temperature": 0.1,
"top_p": 1,
"stream": False,
**self._generation_settings(),
}
@staticmethod
def _build_translation_segments(texts: List[str]) -> List[TranslationSegment]:
    """Wrap raw texts as segments, using each text's list index as its id."""
    segments: List[TranslationSegment] = []
    for position, text in enumerate(texts):
        segments.append(TranslationSegment(id=str(position), text=text))
    return segments
def build_contextual_batches(
    self,
    texts: List[str],
    batch_size: int = DEFAULT_CONTEXTUAL_BATCH_SIZE,
    context_segments: int = DEFAULT_CONTEXT_SEGMENTS,
) -> List[TranslationBatch]:
    """Group subtitle segments into small batches with surrounding context.

    Args:
        texts: Ordered subtitle texts; each becomes a segment whose id is its
            index as a string.
        batch_size: Preferred number of segments per batch (must be >= 1).
        context_segments: Number of neighboring segments included as read-only
            context on each side of a batch (must be >= 0).

    Returns:
        Batches that cover every input segment exactly once, in order. Empty
        input yields an empty list.

    Raises:
        ValueError: If batch_size < 1 or context_segments < 0.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    if context_segments < 0:
        raise ValueError("context_segments cannot be negative")
    segments = self._build_translation_segments(texts)
    if not segments:
        return []
    batches: List[TranslationBatch] = []
    start_index = 0
    total_segments = len(segments)
    while start_index < total_segments:
        remaining = total_segments - start_index
        current_batch_size = min(batch_size, remaining)
        trailing_segments = remaining - current_batch_size
        # Rebalance: when the leftover after this batch would be a runt
        # (smaller than MIN_CONTEXTUAL_BATCH_SIZE), donate segments from the
        # current batch so the trailing batch reaches the minimum.
        # NOTE(review): this can shrink the *current* batch below the minimum
        # (e.g. batch_size=4 over 5 texts yields batches of 2 and 3) — confirm
        # that asymmetry is intended.
        if 0 < trailing_segments < MIN_CONTEXTUAL_BATCH_SIZE and current_batch_size > MIN_CONTEXTUAL_BATCH_SIZE:
            current_batch_size -= MIN_CONTEXTUAL_BATCH_SIZE - trailing_segments
        end_index = start_index + current_batch_size
        batches.append(
            TranslationBatch(
                # Context slices are clamped to the list bounds and never
                # overlap the translated span.
                previous_segments=segments[max(0, start_index - context_segments):start_index],
                segments=segments[start_index:end_index],
                next_segments=segments[end_index:min(total_segments, end_index + context_segments)],
            )
        )
        start_index = end_index
    return batches
def build_contextual_batch_request(
self,
batch: TranslationBatch,
source_language: str,
target_language: str,
) -> Dict[str, Any]:
"""Build the contextual JSON payload sent to the model."""
return {
"source_language": source_language or "auto",
"target_language": target_language,
"previous_segments": [segment.as_payload() for segment in batch.previous_segments],
"segments": [segment.as_payload() for segment in batch.segments],
"next_segments": [segment.as_payload() for segment in batch.next_segments],
}
def build_contextual_batch_payload(
    self,
    batch: TranslationBatch,
    source_language: str,
    target_language: str,
) -> Dict[str, Any]:
    """Build the LM Studio chat-completions request for one contextual batch.

    The system turn carries the translation rules; the user turn carries the
    batch serialized as pretty-printed JSON.
    """
    system_message = {
        "role": "system",
        "content": _build_contextual_system_prompt(source_language, target_language),
    }
    request_body = self.build_contextual_batch_request(batch, source_language, target_language)
    user_message = {
        "role": "user",
        # ensure_ascii=False keeps non-Latin subtitle text readable in the prompt.
        "content": json.dumps(request_body, ensure_ascii=False, indent=2),
    }
    payload: Dict[str, Any] = {
        "model": self.config.model,
        "messages": [system_message, user_message],
    }
    payload.update(self._generation_settings())
    return payload
def _build_contextual_user_prompt(
    self,
    batch: TranslationBatch,
    source_language: str,
    target_language: str,
) -> str:
    """Merge the system rules and the batch JSON into one user-only prompt string."""
    request_payload = self.build_contextual_batch_request(batch, source_language, target_language)
    sections = [
        _build_contextual_system_prompt(source_language, target_language),
        "USER PAYLOAD FORMAT:\n" + json.dumps(request_payload, ensure_ascii=False, indent=2),
        'EXPECTED OUTPUT:\n{"translations":[{"id":"...","translated_text":"..."}]}',
    ]
    return "\n\n".join(sections)
def build_contextual_user_only_payload(
    self,
    batch: TranslationBatch,
    source_language: str,
    target_language: str,
) -> Dict[str, Any]:
    """Build a fallback contextual payload for models that require a user first turn."""
    merged_prompt = self._build_contextual_user_prompt(batch, source_language, target_language)
    payload: Dict[str, Any] = {
        "model": self.config.model,
        # A single user message replaces the usual system+user pair.
        "messages": [{"role": "user", "content": merged_prompt}],
    }
    payload.update(self._generation_settings())
    return payload
@staticmethod
@@ -219,6 +390,73 @@ class LMStudioTranslator:
return translated
@staticmethod
def parse_batch_translation_response(content: str, batch: TranslationBatch) -> List[str]:
"""Parse and validate the batch translation JSON response."""
try:
response_payload = json.loads(content)
except json.JSONDecodeError as exc:
raise TranslationError("LM Studio returned malformed JSON for batch translation.") from exc
if not isinstance(response_payload, dict):
raise TranslationError("LM Studio batch translation response must be a JSON object.")
translations = response_payload.get("translations")
if not isinstance(translations, list):
raise TranslationError("LM Studio batch translation response must include a 'translations' list.")
expected_ids = batch.segment_ids
actual_ids: List[str] = []
translated_texts: List[str] = []
for item in translations:
if not isinstance(item, dict):
raise TranslationError("LM Studio batch translation response contained a non-object translation item.")
segment_id = item.get("id")
translated_text = item.get("translated_text")
if not isinstance(segment_id, str) or not segment_id:
raise TranslationError("LM Studio batch translation response contained an invalid segment id.")
if not isinstance(translated_text, str):
raise TranslationError(
f"LM Studio batch translation response for segment '{segment_id}' did not contain a string translated_text."
)
actual_ids.append(segment_id)
translated_texts.append(translated_text.strip())
if len(actual_ids) != len(expected_ids):
raise TranslationError(
f"LM Studio batch translation response returned {len(actual_ids)} items "
f"for {len(expected_ids)} input segments."
)
missing_ids = [segment_id for segment_id in expected_ids if segment_id not in actual_ids]
unexpected_ids = [segment_id for segment_id in actual_ids if segment_id not in expected_ids]
if missing_ids or unexpected_ids:
raise TranslationError(
"LM Studio batch translation response ids did not match the request. "
f"Missing: {missing_ids or 'none'}. Unexpected: {unexpected_ids or 'none'}."
)
if actual_ids != expected_ids:
raise TranslationError("LM Studio batch translation response ids were out of order.")
validated_translations: List[str] = []
for segment, translated_text in zip(batch.segments, translated_texts):
if not segment.text.strip():
validated_translations.append("")
continue
if not translated_text:
raise TranslationError(f"LM Studio returned an empty translation for segment '{segment.id}'.")
validated_translations.append(translated_text)
return validated_translations
def _headers(self) -> Dict[str, str]:
return {
"Authorization": f"Bearer {self.config.api_key}",
@@ -305,23 +543,79 @@ class LMStudioTranslator:
raise TranslationError("LM Studio returned a non-JSON response.") from last_error
raise TranslationError(f"LM Studio request failed: {last_error}") from last_error
def _translate_contextual_batch(
    self,
    batch: TranslationBatch,
    target_language: str,
    source_language: str = "auto",
) -> List[str]:
    """Translate a single contextual subtitle batch with validation and retries.

    Tries the system+user payload first. Within the same attempt it may fall
    back to a user-only prompt and then to a structured-output prompt, and the
    whole sequence is retried up to ``max_retries`` times with linear backoff.

    Raises:
        TranslationError: When every attempt and fallback is exhausted.
    """
    payload = self.build_contextual_batch_payload(batch, source_language, target_language)
    last_error: Optional[Exception] = None
    for attempt in range(1, self.config.max_retries + 1):
        try:
            response_content = self._post_chat_completion(payload)
            return self.parse_batch_translation_response(response_content, batch)
        except (httpx.HTTPError, ValueError, TranslationError) as exc:
            last_error = exc
            # Fallback 1: models that require a user first turn get the rules
            # merged into a single user message.
            if self._should_retry_with_user_only_prompt(exc):
                try:
                    fallback_payload = self.build_contextual_user_only_payload(batch, source_language, target_language)
                    fallback_content = self._post_chat_completion(fallback_payload)
                    return self.parse_batch_translation_response(fallback_content, batch)
                except (httpx.HTTPError, ValueError, TranslationError) as fallback_exc:
                    last_error = fallback_exc
            # Fallback 2: structured-output request reusing the merged prompt.
            # Checked against last_error so it can also fire after fallback 1
            # has failed, not only after the primary attempt.
            if self._should_retry_with_structured_translation_prompt(last_error):
                try:
                    structured_payload = self.build_structured_translation_payload(
                        self._build_contextual_user_prompt(batch, source_language, target_language),
                        source_language,
                        target_language,
                    )
                    structured_content = self._post_chat_completion(structured_payload)
                    return self.parse_batch_translation_response(structured_content, batch)
                except (httpx.HTTPError, ValueError, TranslationError) as structured_exc:
                    last_error = structured_exc
            # Validation failures (TranslationError) count as retryable: the
            # model may produce a well-formed response on the next attempt.
            should_retry = self._should_retry(exc) or isinstance(last_error, TranslationError)
            if attempt >= self.config.max_retries or not should_retry:
                break
            # Linear backoff: attempt number scales the base delay.
            self._sleeper(self.config.retry_backoff_seconds * attempt)
    # All attempts exhausted — surface the most specific error available.
    if isinstance(last_error, TranslationError):
        raise last_error
    if isinstance(last_error, ValueError):
        raise TranslationError("LM Studio returned a non-JSON response.") from last_error
    raise TranslationError(f"LM Studio request failed: {last_error}") from last_error
def translate_segments(
self,
texts: List[str],
target_language: str,
source_language: str = "auto",
) -> List[str]:
"""Translate an ordered list of subtitle-like segments."""
results: List[str] = []
for text in texts:
results.append(
self.translate_text(
text=text,
"""Translate an ordered list of subtitle-like segments in contextual batches."""
if not texts:
return []
translated_segments: List[str] = []
for batch in self.build_contextual_batches(texts):
translated_segments.extend(
self._translate_contextual_batch(
batch=batch,
target_language=target_language,
source_language=source_language,
)
)
return results
if len(translated_segments) != len(texts):
raise TranslationError(
f"LM Studio returned {len(translated_segments)} translated segments for {len(texts)} inputs."
)
return translated_segments
def close(self) -> None:
if self._owns_client:

View File

@@ -2,6 +2,8 @@
from __future__ import annotations
import json
import httpx
import pytest
@@ -13,6 +15,25 @@ def _mock_client(handler):
return httpx.Client(transport=httpx.MockTransport(handler))
def _mock_batch_response(translations):
    """Return a 200 chat-completions response carrying a translations JSON blob."""
    content = json.dumps({"translations": translations}, ensure_ascii=False)
    body = {"choices": [{"message": {"content": content}}]}
    return httpx.Response(200, json=body)
def _read_request_json(request: httpx.Request):
return json.loads(request.read().decode("utf-8"))
def test_translation_config_normalizes_base_url():
config = TranslationConfig.from_env(base_url="http://127.0.0.1:1234")
@@ -21,33 +42,78 @@ def test_translation_config_normalizes_base_url():
assert config.model == "gemma-3-4b-it"
def test_build_payload_includes_model_and_prompt():
def test_build_contextual_batch_payload_includes_neighboring_segments():
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
payload = translator.build_payload("Hello world", "en", "es")
batches = translator.build_contextual_batches(
["one", "two", "three", "four", "five", "six", "seven"],
)
assert [len(batch.segments) for batch in batches] == [4, 3]
payload = translator.build_contextual_batch_payload(batches[0], "en", "es")
user_payload = json.loads(payload["messages"][1]["content"])
assert payload["model"] == "gemma-3-4b-it"
assert payload["messages"][0]["role"] == "system"
assert "Translate the user-provided text from en to es." in payload["messages"][0]["content"]
assert payload["messages"][1]["content"] == "Hello world"
assert "expert audiovisual translator for dubbed video content" in payload["messages"][0]["content"]
assert payload["temperature"] == 0.0
assert payload["top_p"] == 1.0
assert user_payload == {
"source_language": "en",
"target_language": "es",
"previous_segments": [],
"segments": [
{"id": "0", "text": "one"},
{"id": "1", "text": "two"},
{"id": "2", "text": "three"},
{"id": "3", "text": "four"},
],
"next_segments": [
{"id": "4", "text": "five"},
{"id": "5", "text": "six"},
],
}
def test_translate_segments_preserves_order_and_blank_segments():
def test_translate_segments_batches_context_and_preserves_exact_mapping():
requests = []
def handler(request: httpx.Request) -> httpx.Response:
text = request.read().decode("utf-8")
if "first" in text:
content = "primero"
elif "third" in text:
content = "tercero"
else:
content = "desconocido"
return httpx.Response(200, json={"choices": [{"message": {"content": content}}]})
payload = _read_request_json(request)
batch_request = json.loads(payload["messages"][1]["content"])
requests.append(batch_request)
translations = []
for item in batch_request["segments"]:
translated_text = "" if not item["text"].strip() else f"es::{item['id']}::{item['text']}"
translations.append({"id": item["id"], "translated_text": translated_text})
return _mock_batch_response(translations)
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler))
translated = translator.translate_segments(["first", "", "third"], target_language="es", source_language="en")
translated = translator.translate_segments(
["first", "second", "", "fourth", "fifth", "sixth", "seventh"],
target_language="es",
source_language="en",
)
assert translated == ["primero", "", "tercero"]
assert translated == [
"es::0::first",
"es::1::second",
"",
"es::3::fourth",
"es::4::fifth",
"es::5::sixth",
"es::6::seventh",
]
assert len(requests) == 2
assert [item["id"] for item in requests[0]["segments"]] == ["0", "1", "2", "3"]
assert requests[0]["previous_segments"] == []
assert [item["id"] for item in requests[0]["next_segments"]] == ["4", "5"]
assert [item["id"] for item in requests[1]["previous_segments"]] == ["2", "3"]
assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"]
def test_retry_on_transient_http_error_then_succeeds():
@@ -76,6 +142,77 @@ def test_parse_response_content_rejects_empty_content():
LMStudioTranslator.parse_response_content({"choices": [{"message": {"content": " "}}]})
def test_parse_batch_translation_response_rejects_missing_ids():
    """A response that drops one id and duplicates another must be rejected."""
    translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
    batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0]
    # id "1" is missing and id "2" appears twice.
    response_content = json.dumps(
        {
            "translations": [
                {"id": "0", "translated_text": "uno"},
                {"id": "2", "translated_text": "dos"},
                {"id": "2", "translated_text": "tres"},
            ]
        }
    )
    with pytest.raises(TranslationError, match="ids did not match the request"):
        LMStudioTranslator.parse_batch_translation_response(response_content, batch)
def test_parse_batch_translation_response_rejects_out_of_order_ids():
    """A response with all the right ids in the wrong order must be rejected."""
    translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
    batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0]
    # Same id set as the request, but "1" and "0" are swapped.
    response_content = json.dumps(
        {
            "translations": [
                {"id": "1", "translated_text": "dos"},
                {"id": "0", "translated_text": "uno"},
                {"id": "2", "translated_text": "tres"},
            ]
        }
    )
    with pytest.raises(TranslationError, match="out of order"):
        LMStudioTranslator.parse_batch_translation_response(response_content, batch)
def test_translate_segments_retries_on_malformed_json_batch_response():
    """A malformed first response must trigger a retry that then succeeds."""
    # Mutable counter shared with the handler so the retry count can be asserted.
    attempts = {"count": 0}

    def handler(request: httpx.Request) -> httpx.Response:
        attempts["count"] += 1
        payload = _read_request_json(request)
        # The user turn (messages[1]) carries the batch request as JSON text.
        batch_request = json.loads(payload["messages"][1]["content"])
        # First call returns non-JSON content to force a parse failure.
        if attempts["count"] == 1:
            return httpx.Response(200, json={"choices": [{"message": {"content": "not-json"}}]})
        translations = [
            {"id": item["id"], "translated_text": f"ok::{item['text']}"}
            for item in batch_request["segments"]
        ]
        return _mock_batch_response(translations)

    translator = LMStudioTranslator(
        TranslationConfig(max_retries=2),
        client=_mock_client(handler),
        sleeper=lambda _: None,  # skip real backoff sleeps in tests
    )
    translated = translator.translate_segments(
        ["alpha", "beta", "gamma"],
        target_language="es",
        source_language="en",
    )
    assert translated == ["ok::alpha", "ok::beta", "ok::gamma"]
    # Exactly one failed attempt plus one successful retry.
    assert attempts["count"] == 2
def test_translate_text_raises_on_malformed_response():
def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, json={"choices": []})