"""Tests for Edge TTS synthesis: transient-failure retries and the TTS cache."""

from __future__ import annotations

import asyncio

from src import engines
from src.engines import Engine
from src.translation import TranslationConfig


def _make_engine() -> Engine:
    """Build a CPU ``Engine`` pointed at a dummy local translation endpoint.

    Both tests need an identical engine; the translation settings are
    placeholders because only ``synthesize`` is exercised here.
    """
    return Engine(
        "cpu",
        translation_config=TranslationConfig(
            base_url="http://127.0.0.1:1234/v1",
            api_key="test-key",
            model="test-model",
        ),
    )


def test_synthesize_retries_transient_edge_tts_failure(tmp_path, monkeypatch):
    """A single failing ``save`` call is retried and the audio ends up on disk."""
    calls = {"count": 0}

    # Redirect the TTS cache into tmp_path. ``synthesize`` reuses cached audio
    # (see test_synthesize_uses_tts_cache), so a leftover "Bonjour"/"fr" entry
    # in the shared cache directory would skip synthesis entirely (count == 0)
    # and this test would also pollute that directory with its output.
    monkeypatch.setattr(engines, "TTS_CACHE_DIR", tmp_path / "tts-cache")

    class FakeCommunicate:
        """Stand-in for ``edge_tts.Communicate`` whose first ``save`` fails."""

        def __init__(self, text, voice, rate):
            self.text = text
            self.voice = voice
            self.rate = rate

        async def save(self, out_path):
            calls["count"] += 1
            if calls["count"] == 1:
                # Simulate a transient service error on the first attempt only.
                raise RuntimeError("transient 503")
            with open(out_path, "wb") as audio_file:
                audio_file.write(b"0" * 2048)

    async def no_sleep(_seconds):
        # Skip any retry back-off so the test runs instantly.
        return None

    monkeypatch.setattr("src.engines.edge_tts.Communicate", FakeCommunicate)
    monkeypatch.setattr("src.engines.asyncio.sleep", no_sleep)
    monkeypatch.setattr("src.engines.DEFAULT_TTS_MAX_RETRIES", 2)

    engine = _make_engine()
    out_path = tmp_path / "tts.mp3"

    asyncio.run(engine.synthesize("Bonjour", "fr", out_path))

    # One failed attempt plus one successful retry.
    assert calls["count"] == 2
    assert out_path.exists()
    assert out_path.stat().st_size == 2048


def test_synthesize_uses_tts_cache(tmp_path, monkeypatch):
    """Synthesizing the same text twice calls Edge TTS only once."""
    calls = {"count": 0}
    cache_dir = tmp_path / "tts-cache"
    monkeypatch.setattr(engines, "TTS_CACHE_DIR", cache_dir)

    class FakeCommunicate:
        """Stand-in for ``edge_tts.Communicate`` that counts ``save`` calls."""

        def __init__(self, text, voice, rate):
            self.text = text
            self.voice = voice
            self.rate = rate

        async def save(self, out_path):
            calls["count"] += 1
            with open(out_path, "wb") as audio_file:
                audio_file.write(b"1" * 2048)

    monkeypatch.setattr("src.engines.edge_tts.Communicate", FakeCommunicate)

    engine = _make_engine()
    first_out = tmp_path / "first.mp3"
    second_out = tmp_path / "second.mp3"

    asyncio.run(engine.synthesize("Bonjour", "fr", first_out))
    asyncio.run(engine.synthesize("Bonjour", "fr", second_out))

    # The second request must be served from the cache, not from Edge TTS.
    assert calls["count"] == 1
    assert first_out.read_bytes() == second_out.read_bytes()