diff --git a/src/translation.py b/src/translation.py index 2fb21c6..91998dd 100644 --- a/src/translation.py +++ b/src/translation.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import os import time from dataclasses import dataclass @@ -16,6 +17,9 @@ DEFAULT_LM_STUDIO_BASE_URL = "http://127.0.0.1:1234/v1" DEFAULT_LM_STUDIO_API_KEY = "lm-studio" DEFAULT_LM_STUDIO_MODEL = "gemma-3-4b-it" DEFAULT_TRANSLATION_BACKEND = "lmstudio" +DEFAULT_CONTEXTUAL_BATCH_SIZE = 5 +MIN_CONTEXTUAL_BATCH_SIZE = 3 +DEFAULT_CONTEXT_SEGMENTS = 2 def _normalize_base_url(base_url: str) -> str: @@ -118,6 +122,53 @@ def _build_system_prompt(source_language: str, target_language: str) -> str: ) +def _build_contextual_system_prompt(source_language: str, target_language: str) -> str: + source_descriptor = source_language or "auto" + return ( + "You are an expert audiovisual translator for dubbed video content.\n\n" + f"Translate subtitle segments from {source_descriptor} to {target_language} for natural spoken dubbing.\n\n" + "Rules:\n" + "- Preserve meaning, intent, tone, and subtext.\n" + "- Use surrounding subtitle context to resolve ambiguity.\n" + "- Do not summarize.\n" + "- Do not simplify unless needed for natural speech.\n" + "- Do not add explanations, notes, or commentary.\n" + "- Preserve humor, sarcasm, emotional tone, and register.\n" + "- Keep names, brands, URLs, emails, file paths, code, and product names unchanged unless transliteration is clearly needed.\n" + "- Keep the translation natural for spoken dubbing.\n" + "- Preserve segment boundaries exactly.\n" + "- Return exactly one translated item per input segment.\n" + "- Output order must match input order exactly.\n" + "- If an input segment is empty or whitespace-only, return an empty translated_text for that id.\n" + "- Output valid JSON only.\n" + '- Return this exact schema: {"translations":[{"id":"...","translated_text":"..."}]}.' + ) + + +@dataclass(frozen=True) +class TranslationSegment: + """A subtitle segment prepared for contextual batch translation.""" + + id: str + text: str + + def as_payload(self) -> Dict[str, str]: + return {"id": self.id, "text": self.text} + + +@dataclass(frozen=True) +class TranslationBatch: + """A contextual subtitle translation batch.""" + + previous_segments: List[TranslationSegment] + segments: List[TranslationSegment] + next_segments: List[TranslationSegment] + + @property + def segment_ids(self) -> List[str]: + return [segment.id for segment in self.segments] + + class LMStudioTranslator: """OpenAI-style chat completions client for LM Studio.""" @@ -133,6 +184,14 @@ class LMStudioTranslator: self._owns_client = client is None self._sleeper = sleeper + @staticmethod + def _generation_settings() -> Dict[str, Any]: + return { + "temperature": 0.0, + "top_p": 1.0, + "stream": False, + } + def build_payload(self, text: str, source_language: str, target_language: str) -> Dict[str, Any]: """Build the OpenAI-compatible chat completions payload.""" return { @@ -141,9 +200,7 @@ class LMStudioTranslator: {"role": "system", "content": _build_system_prompt(source_language, target_language)}, {"role": "user", "content": text}, ], - "temperature": 0.1, - "top_p": 1, - "stream": False, + **self._generation_settings(), } def build_user_only_payload( @@ -160,9 +217,7 @@ class LMStudioTranslator: "messages": [ {"role": "user", "content": merged_prompt}, ], - "temperature": 0.1, - "top_p": 1, - "stream": False, + **self._generation_settings(), } def build_structured_translation_payload( @@ -188,9 +243,125 @@ class LMStudioTranslator: ], } ], - "temperature": 0.1, - "top_p": 1, - "stream": False, + **self._generation_settings(), + } + + @staticmethod + def _build_translation_segments(texts: List[str]) -> List[TranslationSegment]: + return [ + TranslationSegment(id=str(index), text=text) + for index, text in enumerate(texts) + ] + + def build_contextual_batches( + self, + texts: List[str], + batch_size: int = DEFAULT_CONTEXTUAL_BATCH_SIZE, + context_segments: int = DEFAULT_CONTEXT_SEGMENTS, + ) -> List[TranslationBatch]: + """Group subtitle segments into small batches with surrounding context.""" + if batch_size < 1: + raise ValueError("batch_size must be at least 1") + if context_segments < 0: + raise ValueError("context_segments cannot be negative") + + segments = self._build_translation_segments(texts) + if not segments: + return [] + + batches: List[TranslationBatch] = [] + start_index = 0 + total_segments = len(segments) + + while start_index < total_segments: + remaining = total_segments - start_index + current_batch_size = min(batch_size, remaining) + trailing_segments = remaining - current_batch_size + + if 0 < trailing_segments < MIN_CONTEXTUAL_BATCH_SIZE and current_batch_size > MIN_CONTEXTUAL_BATCH_SIZE: + current_batch_size -= MIN_CONTEXTUAL_BATCH_SIZE - trailing_segments + + end_index = start_index + current_batch_size + batches.append( + TranslationBatch( + previous_segments=segments[max(0, start_index - context_segments):start_index], + segments=segments[start_index:end_index], + next_segments=segments[end_index:min(total_segments, end_index + context_segments)], + ) + ) + start_index = end_index + + return batches + + def build_contextual_batch_request( + self, + batch: TranslationBatch, + source_language: str, + target_language: str, + ) -> Dict[str, Any]: + """Build the contextual JSON payload sent to the model.""" + return { + "source_language": source_language or "auto", + "target_language": target_language, + "previous_segments": [segment.as_payload() for segment in batch.previous_segments], + "segments": [segment.as_payload() for segment in batch.segments], + "next_segments": [segment.as_payload() for segment in batch.next_segments], + } + + def build_contextual_batch_payload( + self, + batch: TranslationBatch, + source_language: str, + target_language: str, + ) -> Dict[str, Any]: + """Build the LM Studio request for contextual subtitle batch translation.""" + return { + "model": self.config.model, + "messages": [ + {"role": "system", "content": _build_contextual_system_prompt(source_language, target_language)}, + { + "role": "user", + "content": json.dumps( + self.build_contextual_batch_request(batch, source_language, target_language), + ensure_ascii=False, + indent=2, + ), + }, + ], + **self._generation_settings(), + } + + def _build_contextual_user_prompt( + self, + batch: TranslationBatch, + source_language: str, + target_language: str, + ) -> str: + request_payload = self.build_contextual_batch_request(batch, source_language, target_language) + return ( + f"{_build_contextual_system_prompt(source_language, target_language)}\n\n" + "USER PAYLOAD FORMAT:\n" + f"{json.dumps(request_payload, ensure_ascii=False, indent=2)}\n\n" + "EXPECTED OUTPUT:\n" + '{"translations":[{"id":"...","translated_text":"..."}]}' + ) + + def build_contextual_user_only_payload( + self, + batch: TranslationBatch, + source_language: str, + target_language: str, + ) -> Dict[str, Any]: + """Build a fallback contextual payload for models that require a user first turn.""" + return { + "model": self.config.model, + "messages": [ + { + "role": "user", + "content": self._build_contextual_user_prompt(batch, source_language, target_language), + } + ], + **self._generation_settings(), } @staticmethod @@ -219,6 +390,73 @@ class LMStudioTranslator: return translated + @staticmethod + def parse_batch_translation_response(content: str, batch: TranslationBatch) -> List[str]: + """Parse and validate the batch translation JSON response.""" + try: + response_payload = json.loads(content) + except json.JSONDecodeError as exc: + raise TranslationError("LM Studio returned malformed JSON for batch translation.") from exc + + if not isinstance(response_payload, dict): + raise TranslationError("LM Studio batch translation response must be a JSON object.") + + translations = response_payload.get("translations") + if not isinstance(translations, list): + raise TranslationError("LM Studio batch translation response must include a 'translations' list.") + + expected_ids = batch.segment_ids + actual_ids: List[str] = [] + translated_texts: List[str] = [] + + for item in translations: + if not isinstance(item, dict): + raise TranslationError("LM Studio batch translation response contained a non-object translation item.") + + segment_id = item.get("id") + translated_text = item.get("translated_text") + + if not isinstance(segment_id, str) or not segment_id: + raise TranslationError("LM Studio batch translation response contained an invalid segment id.") + + if not isinstance(translated_text, str): + raise TranslationError( + f"LM Studio batch translation response for segment '{segment_id}' did not contain a string translated_text." + ) + + actual_ids.append(segment_id) + translated_texts.append(translated_text.strip()) + + if len(actual_ids) != len(expected_ids): + raise TranslationError( + f"LM Studio batch translation response returned {len(actual_ids)} items " + f"for {len(expected_ids)} input segments." + ) + + missing_ids = [segment_id for segment_id in expected_ids if segment_id not in actual_ids] + unexpected_ids = [segment_id for segment_id in actual_ids if segment_id not in expected_ids] + if missing_ids or unexpected_ids: + raise TranslationError( + "LM Studio batch translation response ids did not match the request. " + f"Missing: {missing_ids or 'none'}. Unexpected: {unexpected_ids or 'none'}." + ) + + if actual_ids != expected_ids: + raise TranslationError("LM Studio batch translation response ids were out of order.") + + validated_translations: List[str] = [] + for segment, translated_text in zip(batch.segments, translated_texts): + if not segment.text.strip(): + validated_translations.append("") + continue + + if not translated_text: + raise TranslationError(f"LM Studio returned an empty translation for segment '{segment.id}'.") + + validated_translations.append(translated_text) + + return validated_translations + def _headers(self) -> Dict[str, str]: return { "Authorization": f"Bearer {self.config.api_key}", @@ -305,23 +543,79 @@ class LMStudioTranslator: raise TranslationError("LM Studio returned a non-JSON response.") from last_error raise TranslationError(f"LM Studio request failed: {last_error}") from last_error + def _translate_contextual_batch( + self, + batch: TranslationBatch, + target_language: str, + source_language: str = "auto", + ) -> List[str]: + """Translate a single contextual subtitle batch with validation and retries.""" + payload = self.build_contextual_batch_payload(batch, source_language, target_language) + last_error: Optional[Exception] = None + + for attempt in range(1, self.config.max_retries + 1): + try: + response_content = self._post_chat_completion(payload) + return self.parse_batch_translation_response(response_content, batch) + except (httpx.HTTPError, ValueError, TranslationError) as exc: + last_error = exc + if self._should_retry_with_user_only_prompt(exc): + try: + fallback_payload = self.build_contextual_user_only_payload(batch, source_language, target_language) + fallback_content = self._post_chat_completion(fallback_payload) + return self.parse_batch_translation_response(fallback_content, batch) + except (httpx.HTTPError, ValueError, TranslationError) as fallback_exc: + last_error = fallback_exc + if self._should_retry_with_structured_translation_prompt(last_error): + try: + structured_payload = self.build_structured_translation_payload( + self._build_contextual_user_prompt(batch, source_language, target_language), + source_language, + target_language, + ) + structured_content = self._post_chat_completion(structured_payload) + return self.parse_batch_translation_response(structured_content, batch) + except (httpx.HTTPError, ValueError, TranslationError) as structured_exc: + last_error = structured_exc + + should_retry = self._should_retry(exc) or isinstance(last_error, TranslationError) + if attempt >= self.config.max_retries or not should_retry: + break + + self._sleeper(self.config.retry_backoff_seconds * attempt) + + if isinstance(last_error, TranslationError): + raise last_error + if isinstance(last_error, ValueError): + raise TranslationError("LM Studio returned a non-JSON response.") from last_error + raise TranslationError(f"LM Studio request failed: {last_error}") from last_error + def translate_segments( self, texts: List[str], target_language: str, source_language: str = "auto", ) -> List[str]: - """Translate an ordered list of subtitle-like segments.""" - results: List[str] = [] - for text in texts: - results.append( - self.translate_text( - text=text, + """Translate an ordered list of subtitle-like segments in contextual batches.""" + if not texts: + return [] + + translated_segments: List[str] = [] + for batch in self.build_contextual_batches(texts): + translated_segments.extend( + self._translate_contextual_batch( + batch=batch, target_language=target_language, source_language=source_language, ) ) - return results + + if len(translated_segments) != len(texts): + raise TranslationError( + f"LM Studio returned {len(translated_segments)} translated segments for {len(texts)} inputs." + ) + + return translated_segments def close(self) -> None: if self._owns_client: diff --git a/tests/test_translation.py b/tests/test_translation.py index b067615..9c6bd7c 100644 --- a/tests/test_translation.py +++ b/tests/test_translation.py @@ -2,6 +2,8 @@ from __future__ import annotations +import json + import httpx import pytest @@ -13,6 +15,25 @@ def _mock_client(handler): return httpx.Client(transport=httpx.MockTransport(handler)) +def _mock_batch_response(translations): + return httpx.Response( + 200, + json={ + "choices": [ + { + "message": { + "content": json.dumps({"translations": translations}, ensure_ascii=False), + } + } + ] + }, + ) + + +def _read_request_json(request: httpx.Request): + return json.loads(request.read().decode("utf-8")) + + def test_translation_config_normalizes_base_url(): config = TranslationConfig.from_env(base_url="http://127.0.0.1:1234") @@ -21,33 +42,78 @@ def test_translation_config_normalizes_base_url(): assert config.model == "gemma-3-4b-it" -def test_build_payload_includes_model_and_prompt(): +def test_build_contextual_batch_payload_includes_neighboring_segments(): translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None)) - payload = translator.build_payload("Hello world", "en", "es") + batches = translator.build_contextual_batches( + ["one", "two", "three", "four", "five", "six", "seven"], + ) + + assert [len(batch.segments) for batch in batches] == [4, 3] + + payload = translator.build_contextual_batch_payload(batches[0], "en", "es") + user_payload = json.loads(payload["messages"][1]["content"]) assert payload["model"] == "gemma-3-4b-it" assert payload["messages"][0]["role"] == "system" - assert "Translate the user-provided text from en to es." in payload["messages"][0]["content"] - assert payload["messages"][1]["content"] == "Hello world" + assert "expert audiovisual translator for dubbed video content" in payload["messages"][0]["content"] + assert payload["temperature"] == 0.0 + assert payload["top_p"] == 1.0 + assert user_payload == { + "source_language": "en", + "target_language": "es", + "previous_segments": [], + "segments": [ + {"id": "0", "text": "one"}, + {"id": "1", "text": "two"}, + {"id": "2", "text": "three"}, + {"id": "3", "text": "four"}, + ], + "next_segments": [ + {"id": "4", "text": "five"}, + {"id": "5", "text": "six"}, + ], + } -def test_translate_segments_preserves_order_and_blank_segments(): +def test_translate_segments_batches_context_and_preserves_exact_mapping(): + requests = [] + def handler(request: httpx.Request) -> httpx.Response: - text = request.read().decode("utf-8") - if "first" in text: - content = "primero" - elif "third" in text: - content = "tercero" - else: - content = "desconocido" - return httpx.Response(200, json={"choices": [{"message": {"content": content}}]}) + payload = _read_request_json(request) + batch_request = json.loads(payload["messages"][1]["content"]) + requests.append(batch_request) + + translations = [] + for item in batch_request["segments"]: + translated_text = "" if not item["text"].strip() else f"es::{item['id']}::{item['text']}" + translations.append({"id": item["id"], "translated_text": translated_text}) + + return _mock_batch_response(translations) translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler)) - translated = translator.translate_segments(["first", "", "third"], target_language="es", source_language="en") + translated = translator.translate_segments( + ["first", "second", "", "fourth", "fifth", "sixth", "seventh"], + target_language="es", + source_language="en", + ) - assert translated == ["primero", "", "tercero"] + assert translated == [ + "es::0::first", + "es::1::second", + "", + "es::3::fourth", + "es::4::fifth", + "es::5::sixth", + "es::6::seventh", + ] + assert len(requests) == 2 + assert [item["id"] for item in requests[0]["segments"]] == ["0", "1", "2", "3"] + assert requests[0]["previous_segments"] == [] + assert [item["id"] for item in requests[0]["next_segments"]] == ["4", "5"] + assert [item["id"] for item in requests[1]["previous_segments"]] == ["2", "3"] + assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"] def test_retry_on_transient_http_error_then_succeeds(): @@ -76,6 +142,77 @@ def test_parse_response_content_rejects_empty_content(): LMStudioTranslator.parse_response_content({"choices": [{"message": {"content": " "}}]}) +def test_parse_batch_translation_response_rejects_missing_ids(): + translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None)) + batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0] + + with pytest.raises(TranslationError, match="ids did not match the request"): + LMStudioTranslator.parse_batch_translation_response( + json.dumps( + { + "translations": [ + {"id": "0", "translated_text": "uno"}, + {"id": "2", "translated_text": "dos"}, + {"id": "2", "translated_text": "tres"}, + ] + } + ), + batch, + ) + + +def test_parse_batch_translation_response_rejects_out_of_order_ids(): + translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None)) + batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0] + + with pytest.raises(TranslationError, match="out of order"): + LMStudioTranslator.parse_batch_translation_response( + json.dumps( + { + "translations": [ + {"id": "1", "translated_text": "dos"}, + {"id": "0", "translated_text": "uno"}, + {"id": "2", "translated_text": "tres"}, + ] + } + ), + batch, + ) + + +def test_translate_segments_retries_on_malformed_json_batch_response(): + attempts = {"count": 0} + + def handler(request: httpx.Request) -> httpx.Response: + attempts["count"] += 1 + payload = _read_request_json(request) + batch_request = json.loads(payload["messages"][1]["content"]) + + if attempts["count"] == 1: + return httpx.Response(200, json={"choices": [{"message": {"content": "not-json"}}]}) + + translations = [ + {"id": item["id"], "translated_text": f"ok::{item['text']}"} + for item in batch_request["segments"] + ] + return _mock_batch_response(translations) + + translator = LMStudioTranslator( + TranslationConfig(max_retries=2), + client=_mock_client(handler), + sleeper=lambda _: None, + ) + + translated = translator.translate_segments( + ["alpha", "beta", "gamma"], + target_language="es", + source_language="en", + ) + + assert translated == ["ok::alpha", "ok::beta", "ok::gamma"] + assert attempts["count"] == 2 + + def test_translate_text_raises_on_malformed_response(): def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(200, json={"choices": []})