Compare commits
1 Commits
fix-vocal-
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| d83c57cda3 |
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
@@ -16,6 +17,9 @@ DEFAULT_LM_STUDIO_BASE_URL = "http://127.0.0.1:1234/v1"
|
||||
DEFAULT_LM_STUDIO_API_KEY = "lm-studio"
|
||||
DEFAULT_LM_STUDIO_MODEL = "gemma-3-4b-it"
|
||||
DEFAULT_TRANSLATION_BACKEND = "lmstudio"
|
||||
DEFAULT_CONTEXTUAL_BATCH_SIZE = 5
|
||||
MIN_CONTEXTUAL_BATCH_SIZE = 3
|
||||
DEFAULT_CONTEXT_SEGMENTS = 2
|
||||
|
||||
|
||||
def _normalize_base_url(base_url: str) -> str:
|
||||
@@ -118,6 +122,53 @@ def _build_system_prompt(source_language: str, target_language: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _build_contextual_system_prompt(source_language: str, target_language: str) -> str:
|
||||
source_descriptor = source_language or "auto"
|
||||
return (
|
||||
"You are an expert audiovisual translator for dubbed video content.\n\n"
|
||||
f"Translate subtitle segments from {source_descriptor} to {target_language} for natural spoken dubbing.\n\n"
|
||||
"Rules:\n"
|
||||
"- Preserve meaning, intent, tone, and subtext.\n"
|
||||
"- Use surrounding subtitle context to resolve ambiguity.\n"
|
||||
"- Do not summarize.\n"
|
||||
"- Do not simplify unless needed for natural speech.\n"
|
||||
"- Do not add explanations, notes, or commentary.\n"
|
||||
"- Preserve humor, sarcasm, emotional tone, and register.\n"
|
||||
"- Keep names, brands, URLs, emails, file paths, code, and product names unchanged unless transliteration is clearly needed.\n"
|
||||
"- Keep the translation natural for spoken dubbing.\n"
|
||||
"- Preserve segment boundaries exactly.\n"
|
||||
"- Return exactly one translated item per input segment.\n"
|
||||
"- Output order must match input order exactly.\n"
|
||||
"- If an input segment is empty or whitespace-only, return an empty translated_text for that id.\n"
|
||||
"- Output valid JSON only.\n"
|
||||
'- Return this exact schema: {"translations":[{"id":"...","translated_text":"..."}]}.'
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TranslationSegment:
|
||||
"""A subtitle segment prepared for contextual batch translation."""
|
||||
|
||||
id: str
|
||||
text: str
|
||||
|
||||
def as_payload(self) -> Dict[str, str]:
|
||||
return {"id": self.id, "text": self.text}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TranslationBatch:
|
||||
"""A contextual subtitle translation batch."""
|
||||
|
||||
previous_segments: List[TranslationSegment]
|
||||
segments: List[TranslationSegment]
|
||||
next_segments: List[TranslationSegment]
|
||||
|
||||
@property
|
||||
def segment_ids(self) -> List[str]:
|
||||
return [segment.id for segment in self.segments]
|
||||
|
||||
|
||||
class LMStudioTranslator:
|
||||
"""OpenAI-style chat completions client for LM Studio."""
|
||||
|
||||
@@ -133,6 +184,14 @@ class LMStudioTranslator:
|
||||
self._owns_client = client is None
|
||||
self._sleeper = sleeper
|
||||
|
||||
@staticmethod
|
||||
def _generation_settings() -> Dict[str, Any]:
|
||||
return {
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
def build_payload(self, text: str, source_language: str, target_language: str) -> Dict[str, Any]:
|
||||
"""Build the OpenAI-compatible chat completions payload."""
|
||||
return {
|
||||
@@ -141,9 +200,7 @@ class LMStudioTranslator:
|
||||
{"role": "system", "content": _build_system_prompt(source_language, target_language)},
|
||||
{"role": "user", "content": text},
|
||||
],
|
||||
"temperature": 0.1,
|
||||
"top_p": 1,
|
||||
"stream": False,
|
||||
**self._generation_settings(),
|
||||
}
|
||||
|
||||
def build_user_only_payload(
|
||||
@@ -160,9 +217,7 @@ class LMStudioTranslator:
|
||||
"messages": [
|
||||
{"role": "user", "content": merged_prompt},
|
||||
],
|
||||
"temperature": 0.1,
|
||||
"top_p": 1,
|
||||
"stream": False,
|
||||
**self._generation_settings(),
|
||||
}
|
||||
|
||||
def build_structured_translation_payload(
|
||||
@@ -188,9 +243,125 @@ class LMStudioTranslator:
|
||||
],
|
||||
}
|
||||
],
|
||||
"temperature": 0.1,
|
||||
"top_p": 1,
|
||||
"stream": False,
|
||||
**self._generation_settings(),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_translation_segments(texts: List[str]) -> List[TranslationSegment]:
|
||||
return [
|
||||
TranslationSegment(id=str(index), text=text)
|
||||
for index, text in enumerate(texts)
|
||||
]
|
||||
|
||||
def build_contextual_batches(
|
||||
self,
|
||||
texts: List[str],
|
||||
batch_size: int = DEFAULT_CONTEXTUAL_BATCH_SIZE,
|
||||
context_segments: int = DEFAULT_CONTEXT_SEGMENTS,
|
||||
) -> List[TranslationBatch]:
|
||||
"""Group subtitle segments into small batches with surrounding context."""
|
||||
if batch_size < 1:
|
||||
raise ValueError("batch_size must be at least 1")
|
||||
if context_segments < 0:
|
||||
raise ValueError("context_segments cannot be negative")
|
||||
|
||||
segments = self._build_translation_segments(texts)
|
||||
if not segments:
|
||||
return []
|
||||
|
||||
batches: List[TranslationBatch] = []
|
||||
start_index = 0
|
||||
total_segments = len(segments)
|
||||
|
||||
while start_index < total_segments:
|
||||
remaining = total_segments - start_index
|
||||
current_batch_size = min(batch_size, remaining)
|
||||
trailing_segments = remaining - current_batch_size
|
||||
|
||||
if 0 < trailing_segments < MIN_CONTEXTUAL_BATCH_SIZE and current_batch_size > MIN_CONTEXTUAL_BATCH_SIZE:
|
||||
current_batch_size -= MIN_CONTEXTUAL_BATCH_SIZE - trailing_segments
|
||||
|
||||
end_index = start_index + current_batch_size
|
||||
batches.append(
|
||||
TranslationBatch(
|
||||
previous_segments=segments[max(0, start_index - context_segments):start_index],
|
||||
segments=segments[start_index:end_index],
|
||||
next_segments=segments[end_index:min(total_segments, end_index + context_segments)],
|
||||
)
|
||||
)
|
||||
start_index = end_index
|
||||
|
||||
return batches
|
||||
|
||||
def build_contextual_batch_request(
|
||||
self,
|
||||
batch: TranslationBatch,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build the contextual JSON payload sent to the model."""
|
||||
return {
|
||||
"source_language": source_language or "auto",
|
||||
"target_language": target_language,
|
||||
"previous_segments": [segment.as_payload() for segment in batch.previous_segments],
|
||||
"segments": [segment.as_payload() for segment in batch.segments],
|
||||
"next_segments": [segment.as_payload() for segment in batch.next_segments],
|
||||
}
|
||||
|
||||
def build_contextual_batch_payload(
|
||||
self,
|
||||
batch: TranslationBatch,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build the LM Studio request for contextual subtitle batch translation."""
|
||||
return {
|
||||
"model": self.config.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": _build_contextual_system_prompt(source_language, target_language)},
|
||||
{
|
||||
"role": "user",
|
||||
"content": json.dumps(
|
||||
self.build_contextual_batch_request(batch, source_language, target_language),
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
),
|
||||
},
|
||||
],
|
||||
**self._generation_settings(),
|
||||
}
|
||||
|
||||
def _build_contextual_user_prompt(
|
||||
self,
|
||||
batch: TranslationBatch,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
) -> str:
|
||||
request_payload = self.build_contextual_batch_request(batch, source_language, target_language)
|
||||
return (
|
||||
f"{_build_contextual_system_prompt(source_language, target_language)}\n\n"
|
||||
"USER PAYLOAD FORMAT:\n"
|
||||
f"{json.dumps(request_payload, ensure_ascii=False, indent=2)}\n\n"
|
||||
"EXPECTED OUTPUT:\n"
|
||||
'{"translations":[{"id":"...","translated_text":"..."}]}'
|
||||
)
|
||||
|
||||
def build_contextual_user_only_payload(
|
||||
self,
|
||||
batch: TranslationBatch,
|
||||
source_language: str,
|
||||
target_language: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build a fallback contextual payload for models that require a user first turn."""
|
||||
return {
|
||||
"model": self.config.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": self._build_contextual_user_prompt(batch, source_language, target_language),
|
||||
}
|
||||
],
|
||||
**self._generation_settings(),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -219,6 +390,73 @@ class LMStudioTranslator:
|
||||
|
||||
return translated
|
||||
|
||||
@staticmethod
|
||||
def parse_batch_translation_response(content: str, batch: TranslationBatch) -> List[str]:
|
||||
"""Parse and validate the batch translation JSON response."""
|
||||
try:
|
||||
response_payload = json.loads(content)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise TranslationError("LM Studio returned malformed JSON for batch translation.") from exc
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
raise TranslationError("LM Studio batch translation response must be a JSON object.")
|
||||
|
||||
translations = response_payload.get("translations")
|
||||
if not isinstance(translations, list):
|
||||
raise TranslationError("LM Studio batch translation response must include a 'translations' list.")
|
||||
|
||||
expected_ids = batch.segment_ids
|
||||
actual_ids: List[str] = []
|
||||
translated_texts: List[str] = []
|
||||
|
||||
for item in translations:
|
||||
if not isinstance(item, dict):
|
||||
raise TranslationError("LM Studio batch translation response contained a non-object translation item.")
|
||||
|
||||
segment_id = item.get("id")
|
||||
translated_text = item.get("translated_text")
|
||||
|
||||
if not isinstance(segment_id, str) or not segment_id:
|
||||
raise TranslationError("LM Studio batch translation response contained an invalid segment id.")
|
||||
|
||||
if not isinstance(translated_text, str):
|
||||
raise TranslationError(
|
||||
f"LM Studio batch translation response for segment '{segment_id}' did not contain a string translated_text."
|
||||
)
|
||||
|
||||
actual_ids.append(segment_id)
|
||||
translated_texts.append(translated_text.strip())
|
||||
|
||||
if len(actual_ids) != len(expected_ids):
|
||||
raise TranslationError(
|
||||
f"LM Studio batch translation response returned {len(actual_ids)} items "
|
||||
f"for {len(expected_ids)} input segments."
|
||||
)
|
||||
|
||||
missing_ids = [segment_id for segment_id in expected_ids if segment_id not in actual_ids]
|
||||
unexpected_ids = [segment_id for segment_id in actual_ids if segment_id not in expected_ids]
|
||||
if missing_ids or unexpected_ids:
|
||||
raise TranslationError(
|
||||
"LM Studio batch translation response ids did not match the request. "
|
||||
f"Missing: {missing_ids or 'none'}. Unexpected: {unexpected_ids or 'none'}."
|
||||
)
|
||||
|
||||
if actual_ids != expected_ids:
|
||||
raise TranslationError("LM Studio batch translation response ids were out of order.")
|
||||
|
||||
validated_translations: List[str] = []
|
||||
for segment, translated_text in zip(batch.segments, translated_texts):
|
||||
if not segment.text.strip():
|
||||
validated_translations.append("")
|
||||
continue
|
||||
|
||||
if not translated_text:
|
||||
raise TranslationError(f"LM Studio returned an empty translation for segment '{segment.id}'.")
|
||||
|
||||
validated_translations.append(translated_text)
|
||||
|
||||
return validated_translations
|
||||
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.config.api_key}",
|
||||
@@ -305,23 +543,79 @@ class LMStudioTranslator:
|
||||
raise TranslationError("LM Studio returned a non-JSON response.") from last_error
|
||||
raise TranslationError(f"LM Studio request failed: {last_error}") from last_error
|
||||
|
||||
def _translate_contextual_batch(
|
||||
self,
|
||||
batch: TranslationBatch,
|
||||
target_language: str,
|
||||
source_language: str = "auto",
|
||||
) -> List[str]:
|
||||
"""Translate a single contextual subtitle batch with validation and retries."""
|
||||
payload = self.build_contextual_batch_payload(batch, source_language, target_language)
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for attempt in range(1, self.config.max_retries + 1):
|
||||
try:
|
||||
response_content = self._post_chat_completion(payload)
|
||||
return self.parse_batch_translation_response(response_content, batch)
|
||||
except (httpx.HTTPError, ValueError, TranslationError) as exc:
|
||||
last_error = exc
|
||||
if self._should_retry_with_user_only_prompt(exc):
|
||||
try:
|
||||
fallback_payload = self.build_contextual_user_only_payload(batch, source_language, target_language)
|
||||
fallback_content = self._post_chat_completion(fallback_payload)
|
||||
return self.parse_batch_translation_response(fallback_content, batch)
|
||||
except (httpx.HTTPError, ValueError, TranslationError) as fallback_exc:
|
||||
last_error = fallback_exc
|
||||
if self._should_retry_with_structured_translation_prompt(last_error):
|
||||
try:
|
||||
structured_payload = self.build_structured_translation_payload(
|
||||
self._build_contextual_user_prompt(batch, source_language, target_language),
|
||||
source_language,
|
||||
target_language,
|
||||
)
|
||||
structured_content = self._post_chat_completion(structured_payload)
|
||||
return self.parse_batch_translation_response(structured_content, batch)
|
||||
except (httpx.HTTPError, ValueError, TranslationError) as structured_exc:
|
||||
last_error = structured_exc
|
||||
|
||||
should_retry = self._should_retry(exc) or isinstance(last_error, TranslationError)
|
||||
if attempt >= self.config.max_retries or not should_retry:
|
||||
break
|
||||
|
||||
self._sleeper(self.config.retry_backoff_seconds * attempt)
|
||||
|
||||
if isinstance(last_error, TranslationError):
|
||||
raise last_error
|
||||
if isinstance(last_error, ValueError):
|
||||
raise TranslationError("LM Studio returned a non-JSON response.") from last_error
|
||||
raise TranslationError(f"LM Studio request failed: {last_error}") from last_error
|
||||
|
||||
def translate_segments(
|
||||
self,
|
||||
texts: List[str],
|
||||
target_language: str,
|
||||
source_language: str = "auto",
|
||||
) -> List[str]:
|
||||
"""Translate an ordered list of subtitle-like segments."""
|
||||
results: List[str] = []
|
||||
for text in texts:
|
||||
results.append(
|
||||
self.translate_text(
|
||||
text=text,
|
||||
"""Translate an ordered list of subtitle-like segments in contextual batches."""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
translated_segments: List[str] = []
|
||||
for batch in self.build_contextual_batches(texts):
|
||||
translated_segments.extend(
|
||||
self._translate_contextual_batch(
|
||||
batch=batch,
|
||||
target_language=target_language,
|
||||
source_language=source_language,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
if len(translated_segments) != len(texts):
|
||||
raise TranslationError(
|
||||
f"LM Studio returned {len(translated_segments)} translated segments for {len(texts)} inputs."
|
||||
)
|
||||
|
||||
return translated_segments
|
||||
|
||||
def close(self) -> None:
|
||||
if self._owns_client:
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
@@ -13,6 +15,25 @@ def _mock_client(handler):
|
||||
return httpx.Client(transport=httpx.MockTransport(handler))
|
||||
|
||||
|
||||
def _mock_batch_response(translations):
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": json.dumps({"translations": translations}, ensure_ascii=False),
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _read_request_json(request: httpx.Request):
|
||||
return json.loads(request.read().decode("utf-8"))
|
||||
|
||||
|
||||
def test_translation_config_normalizes_base_url():
|
||||
config = TranslationConfig.from_env(base_url="http://127.0.0.1:1234")
|
||||
|
||||
@@ -21,33 +42,78 @@ def test_translation_config_normalizes_base_url():
|
||||
assert config.model == "gemma-3-4b-it"
|
||||
|
||||
|
||||
def test_build_payload_includes_model_and_prompt():
|
||||
def test_build_contextual_batch_payload_includes_neighboring_segments():
|
||||
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
|
||||
|
||||
payload = translator.build_payload("Hello world", "en", "es")
|
||||
batches = translator.build_contextual_batches(
|
||||
["one", "two", "three", "four", "five", "six", "seven"],
|
||||
)
|
||||
|
||||
assert [len(batch.segments) for batch in batches] == [4, 3]
|
||||
|
||||
payload = translator.build_contextual_batch_payload(batches[0], "en", "es")
|
||||
user_payload = json.loads(payload["messages"][1]["content"])
|
||||
|
||||
assert payload["model"] == "gemma-3-4b-it"
|
||||
assert payload["messages"][0]["role"] == "system"
|
||||
assert "Translate the user-provided text from en to es." in payload["messages"][0]["content"]
|
||||
assert payload["messages"][1]["content"] == "Hello world"
|
||||
assert "expert audiovisual translator for dubbed video content" in payload["messages"][0]["content"]
|
||||
assert payload["temperature"] == 0.0
|
||||
assert payload["top_p"] == 1.0
|
||||
assert user_payload == {
|
||||
"source_language": "en",
|
||||
"target_language": "es",
|
||||
"previous_segments": [],
|
||||
"segments": [
|
||||
{"id": "0", "text": "one"},
|
||||
{"id": "1", "text": "two"},
|
||||
{"id": "2", "text": "three"},
|
||||
{"id": "3", "text": "four"},
|
||||
],
|
||||
"next_segments": [
|
||||
{"id": "4", "text": "five"},
|
||||
{"id": "5", "text": "six"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_translate_segments_preserves_order_and_blank_segments():
|
||||
def test_translate_segments_batches_context_and_preserves_exact_mapping():
|
||||
requests = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
text = request.read().decode("utf-8")
|
||||
if "first" in text:
|
||||
content = "primero"
|
||||
elif "third" in text:
|
||||
content = "tercero"
|
||||
else:
|
||||
content = "desconocido"
|
||||
return httpx.Response(200, json={"choices": [{"message": {"content": content}}]})
|
||||
payload = _read_request_json(request)
|
||||
batch_request = json.loads(payload["messages"][1]["content"])
|
||||
requests.append(batch_request)
|
||||
|
||||
translations = []
|
||||
for item in batch_request["segments"]:
|
||||
translated_text = "" if not item["text"].strip() else f"es::{item['id']}::{item['text']}"
|
||||
translations.append({"id": item["id"], "translated_text": translated_text})
|
||||
|
||||
return _mock_batch_response(translations)
|
||||
|
||||
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(handler))
|
||||
|
||||
translated = translator.translate_segments(["first", "", "third"], target_language="es", source_language="en")
|
||||
translated = translator.translate_segments(
|
||||
["first", "second", "", "fourth", "fifth", "sixth", "seventh"],
|
||||
target_language="es",
|
||||
source_language="en",
|
||||
)
|
||||
|
||||
assert translated == ["primero", "", "tercero"]
|
||||
assert translated == [
|
||||
"es::0::first",
|
||||
"es::1::second",
|
||||
"",
|
||||
"es::3::fourth",
|
||||
"es::4::fifth",
|
||||
"es::5::sixth",
|
||||
"es::6::seventh",
|
||||
]
|
||||
assert len(requests) == 2
|
||||
assert [item["id"] for item in requests[0]["segments"]] == ["0", "1", "2", "3"]
|
||||
assert requests[0]["previous_segments"] == []
|
||||
assert [item["id"] for item in requests[0]["next_segments"]] == ["4", "5"]
|
||||
assert [item["id"] for item in requests[1]["previous_segments"]] == ["2", "3"]
|
||||
assert [item["id"] for item in requests[1]["segments"]] == ["4", "5", "6"]
|
||||
|
||||
|
||||
def test_retry_on_transient_http_error_then_succeeds():
|
||||
@@ -76,6 +142,77 @@ def test_parse_response_content_rejects_empty_content():
|
||||
LMStudioTranslator.parse_response_content({"choices": [{"message": {"content": " "}}]})
|
||||
|
||||
|
||||
def test_parse_batch_translation_response_rejects_missing_ids():
|
||||
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
|
||||
batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0]
|
||||
|
||||
with pytest.raises(TranslationError, match="ids did not match the request"):
|
||||
LMStudioTranslator.parse_batch_translation_response(
|
||||
json.dumps(
|
||||
{
|
||||
"translations": [
|
||||
{"id": "0", "translated_text": "uno"},
|
||||
{"id": "2", "translated_text": "dos"},
|
||||
{"id": "2", "translated_text": "tres"},
|
||||
]
|
||||
}
|
||||
),
|
||||
batch,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_batch_translation_response_rejects_out_of_order_ids():
|
||||
translator = LMStudioTranslator(TranslationConfig(), client=_mock_client(lambda request: None))
|
||||
batch = translator.build_contextual_batches(["one", "two", "three"], batch_size=3)[0]
|
||||
|
||||
with pytest.raises(TranslationError, match="out of order"):
|
||||
LMStudioTranslator.parse_batch_translation_response(
|
||||
json.dumps(
|
||||
{
|
||||
"translations": [
|
||||
{"id": "1", "translated_text": "dos"},
|
||||
{"id": "0", "translated_text": "uno"},
|
||||
{"id": "2", "translated_text": "tres"},
|
||||
]
|
||||
}
|
||||
),
|
||||
batch,
|
||||
)
|
||||
|
||||
|
||||
def test_translate_segments_retries_on_malformed_json_batch_response():
|
||||
attempts = {"count": 0}
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
attempts["count"] += 1
|
||||
payload = _read_request_json(request)
|
||||
batch_request = json.loads(payload["messages"][1]["content"])
|
||||
|
||||
if attempts["count"] == 1:
|
||||
return httpx.Response(200, json={"choices": [{"message": {"content": "not-json"}}]})
|
||||
|
||||
translations = [
|
||||
{"id": item["id"], "translated_text": f"ok::{item['text']}"}
|
||||
for item in batch_request["segments"]
|
||||
]
|
||||
return _mock_batch_response(translations)
|
||||
|
||||
translator = LMStudioTranslator(
|
||||
TranslationConfig(max_retries=2),
|
||||
client=_mock_client(handler),
|
||||
sleeper=lambda _: None,
|
||||
)
|
||||
|
||||
translated = translator.translate_segments(
|
||||
["alpha", "beta", "gamma"],
|
||||
target_language="es",
|
||||
source_language="en",
|
||||
)
|
||||
|
||||
assert translated == ["ok::alpha", "ok::beta", "ok::gamma"]
|
||||
assert attempts["count"] == 2
|
||||
|
||||
|
||||
def test_translate_text_raises_on_malformed_response():
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(200, json={"choices": []})
|
||||
|
||||
Reference in New Issue
Block a user