diff --git a/gateway/run.py b/gateway/run.py index bcd2457b9..0b8794924 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1701,6 +1701,9 @@ class GatewayRunner: content = f"[Delivered from {mirror_src}] {content}" agent_history.append({"role": role, "content": content}) + # Track history length to only scan NEW messages for MEDIA tags + history_len = len(agent_history) + result = agent.run_conversation(message, conversation_history=agent_history) result_holder[0] = result @@ -1721,10 +1724,17 @@ class GatewayRunner: # doesn't include them. We collect unique tags from tool results and # append any that aren't already present in the final response, so the # adapter's extract_media() can find and deliver the files exactly once. + # + # IMPORTANT: Only scan messages from the CURRENT turn (after history_len), + # not the full history. This prevents TTS voice messages from earlier + # turns being re-attached to every subsequent reply. (Fixes #160) if "MEDIA:" not in final_response: media_tags = [] has_voice_directive = False - for msg in result.get("messages", []): + all_messages = result.get("messages", []) + # Only process new messages from this turn + new_messages = all_messages[history_len:] if len(all_messages) > history_len else [] + for msg in new_messages: if msg.get("role") == "tool" or msg.get("role") == "function": content = msg.get("content", "") if "MEDIA:" in content: diff --git a/tests/gateway/test_media_extraction.py b/tests/gateway/test_media_extraction.py new file mode 100644 index 000000000..20f7d73a8 --- /dev/null +++ b/tests/gateway/test_media_extraction.py @@ -0,0 +1,184 @@ +""" +Tests for MEDIA tag extraction from tool results. + +Verifies that MEDIA tags (e.g., from TTS tool) are only extracted from +messages in the CURRENT turn, not from the full conversation history. +This prevents voice messages from accumulating and being sent multiple +times per reply. (Regression test for #160) +""" + +import pytest +import re + + +def extract_media_tags_fixed(result_messages, history_len): + """ + Extract MEDIA tags from tool results, but ONLY from new messages + (those added after history_len). This is the fixed behavior. + + Args: + result_messages: Full list of messages including history + new + history_len: Length of history before this turn + + Returns: + Tuple of (media_tags list, has_voice_directive bool) + """ + media_tags = [] + has_voice_directive = False + + # Only process new messages from this turn + new_messages = result_messages[history_len:] if len(result_messages) > history_len else [] + + for msg in new_messages: + if msg.get("role") == "tool" or msg.get("role") == "function": + content = msg.get("content", "") + if "MEDIA:" in content: + for match in re.finditer(r'MEDIA:(\S+)', content): + path = match.group(1).strip().rstrip('",}') + if path: + media_tags.append(f"MEDIA:{path}") + if "[[audio_as_voice]]" in content: + has_voice_directive = True + + return media_tags, has_voice_directive + + +def extract_media_tags_broken(result_messages): + """ + The BROKEN behavior: extract MEDIA tags from ALL messages including history. + This causes TTS voice messages to accumulate and be re-sent on every reply. + """ + media_tags = [] + has_voice_directive = False + + for msg in result_messages: + if msg.get("role") == "tool" or msg.get("role") == "function": + content = msg.get("content", "") + if "MEDIA:" in content: + for match in re.finditer(r'MEDIA:(\S+)', content): + path = match.group(1).strip().rstrip('",}') + if path: + media_tags.append(f"MEDIA:{path}") + if "[[audio_as_voice]]" in content: + has_voice_directive = True + + return media_tags, has_voice_directive + + +class TestMediaExtraction: + """Tests for MEDIA tag extraction from tool results.""" + + def test_media_tags_not_extracted_from_history(self): + """MEDIA tags from previous turns should NOT be extracted again.""" + # Simulate conversation history with a TTS call from a previous turn + history = [ + {"role": "user", "content": "Say hello as audio"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "1", "function": {"name": "text_to_speech"}}]}, + {"role": "tool", "tool_call_id": "1", "content": '{"success": true, "media_tag": "[[audio_as_voice]]\\nMEDIA:/path/to/audio1.ogg"}'}, + {"role": "assistant", "content": "I've said hello for you!"}, + ] + + # New turn: user asks a simple question + new_messages = [ + {"role": "user", "content": "What time is it?"}, + {"role": "assistant", "content": "It's 3:30 AM."}, + ] + + all_messages = history + new_messages + history_len = len(history) + + # Fixed behavior: should extract NO media tags (none in new messages) + tags, voice_directive = extract_media_tags_fixed(all_messages, history_len) + assert tags == [], "Fixed extraction should not find tags in history" + assert voice_directive is False + + # Broken behavior: would incorrectly extract the old media tag + broken_tags, broken_voice = extract_media_tags_broken(all_messages) + assert len(broken_tags) == 1, "Broken extraction finds tags in history" + assert "audio1.ogg" in broken_tags[0] + + def test_media_tags_extracted_from_current_turn(self): + """MEDIA tags from the current turn SHOULD be extracted.""" + # History without TTS + history = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + # New turn with TTS call + new_messages = [ + {"role": "user", "content": "Say goodbye as audio"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "2", "function": {"name": "text_to_speech"}}]}, + {"role": "tool", "tool_call_id": "2", "content": '{"success": true, "media_tag": "[[audio_as_voice]]\\nMEDIA:/path/to/audio2.ogg"}'}, + {"role": "assistant", "content": "I've said goodbye!"}, + ] + + all_messages = history + new_messages + history_len = len(history) + + # Fixed behavior: should extract the new media tag + tags, voice_directive = extract_media_tags_fixed(all_messages, history_len) + assert len(tags) == 1, "Should extract media tag from current turn" + assert "audio2.ogg" in tags[0] + assert voice_directive is True + + def test_multiple_tts_calls_in_history_not_accumulated(self): + """Multiple TTS calls in history should NOT accumulate in new responses.""" + # History with multiple TTS calls + history = [ + {"role": "user", "content": "Say hello"}, + {"role": "tool", "tool_call_id": "1", "content": 'MEDIA:/audio/hello.ogg'}, + {"role": "assistant", "content": "Done!"}, + {"role": "user", "content": "Say goodbye"}, + {"role": "tool", "tool_call_id": "2", "content": 'MEDIA:/audio/goodbye.ogg'}, + {"role": "assistant", "content": "Done!"}, + {"role": "user", "content": "Say thanks"}, + {"role": "tool", "tool_call_id": "3", "content": 'MEDIA:/audio/thanks.ogg'}, + {"role": "assistant", "content": "Done!"}, + ] + + # New turn: no TTS + new_messages = [ + {"role": "user", "content": "What time is it?"}, + {"role": "assistant", "content": "3 PM"}, + ] + + all_messages = history + new_messages + history_len = len(history) + + # Fixed: no tags + tags, _ = extract_media_tags_fixed(all_messages, history_len) + assert tags == [], "Should not accumulate tags from history" + + # Broken: would have 3 tags (all the old ones) + broken_tags, _ = extract_media_tags_broken(all_messages) + assert len(broken_tags) == 3, "Broken version accumulates all history tags" + + def test_deduplication_within_current_turn(self): + """Multiple MEDIA tags in current turn should be deduplicated.""" + history = [] + + # Current turn with multiple tool calls producing same media + new_messages = [ + {"role": "user", "content": "Multiple TTS"}, + {"role": "tool", "tool_call_id": "1", "content": 'MEDIA:/audio/same.ogg'}, + {"role": "tool", "tool_call_id": "2", "content": 'MEDIA:/audio/same.ogg'}, # duplicate + {"role": "tool", "tool_call_id": "3", "content": 'MEDIA:/audio/different.ogg'}, + {"role": "assistant", "content": "Done!"}, + ] + + all_messages = history + new_messages + + tags, _ = extract_media_tags_fixed(all_messages, 0) + # Even though same.ogg appears twice, deduplication happens after extraction + # The extraction itself should get both, then caller deduplicates + assert len(tags) == 3 # Raw extraction gets all + + # Deduplication as done in the actual code: + seen = set() + unique = [t for t in tags if t not in seen and not seen.add(t)] + assert len(unique) == 2 # After dedup: same.ogg and different.ogg + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])