diff --git a/src/google/adk/memory/in_memory_memory_service.py b/src/google/adk/memory/in_memory_memory_service.py index 1d666a39b1d..af4b4c26d92 100644 --- a/src/google/adk/memory/in_memory_memory_service.py +++ b/src/google/adk/memory/in_memory_memory_service.py @@ -38,8 +38,8 @@ def _user_key(app_name: str, user_id: str) -> str: def _extract_words_lower(text: str) -> set[str]: - """Extracts words from a string and converts them to lowercase.""" - return set([word.lower() for word in re.findall(r'\w+', text, re.UNICODE)]) + """Extracts Unicode-aware tokens from a string in lowercase.""" + return set(word.lower() for word in re.findall(r'\w+', text)) class InMemoryMemoryService(BaseMemoryService): @@ -116,13 +116,19 @@ async def search_memory( for event in session_events: if not event.content or not event.content.parts: continue - words_in_event = _extract_words_lower( - ' '.join([part.text for part in event.content.parts if part.text]) + event_text = ' '.join( + [part.text for part in event.content.parts if part.text] ) + words_in_event = _extract_words_lower(event_text) if not words_in_event: continue - if any(query_word in words_in_event for query_word in words_in_query): + event_text_lower = event_text.lower() + if any( + query_word in words_in_event + or (not query_word.isascii() and query_word in event_text_lower) + for query_word in words_in_query + ): response.memories.append( MemoryEntry( content=event.content, diff --git a/tests/unittests/memory/test_in_memory_memory_service.py b/tests/unittests/memory/test_in_memory_memory_service.py index c80fd832b18..6c590cddec4 100644 --- a/tests/unittests/memory/test_in_memory_memory_service.py +++ b/tests/unittests/memory/test_in_memory_memory_service.py @@ -329,30 +329,51 @@ async def test_search_memory_is_scoped_by_user(): ) +# --- Non-Latin language tests --- + + @pytest.mark.asyncio -async def test_search_memory_matches_non_latin_text(): - """Tests that search matches non-Latin (e.g. Cyrillic) text.""" - memory_service = InMemoryMemoryService() +@pytest.mark.parametrize( + 'event_text,query,expected_count', + [ + # Japanese (no space delimiters — substring fallback) + ('私の名前は太郎です', '太郎', 1), + ('私の名前は太郎です', '天気', 0), + # Chinese (no space delimiters — substring fallback) + ('我喜欢机器学习', '机器学习', 1), + ('我喜欢机器学习', '天气预报', 0), + # Korean (space-delimited — token match) + ('제 이름은 민수입니다', '민수입니다', 1), + # Cyrillic (space-delimited — token match) + ('Меня зовут Алексей', 'Алексей', 1), + # Mixed: non-Latin substring + Latin token in same event + ('太郎 works at ABC Corp', '太郎', 1), + ('太郎 works at ABC Corp', 'ABC', 1), + # Latin partial-word must NOT match (regression guard) + ('I like to code in Python.', 'thon', 0), + ], +) +async def test_search_memory_non_latin(event_text, query, expected_count): + """Tests search_memory with non-Latin scripts and mixed content.""" session = Session( app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, - id='session-non-latin', - last_update_time=5000, + id='session-i18n', + last_update_time=7000, events=[ Event( - id='event-non-latin', - invocation_id='inv-non-latin', + id='event-i18n', + invocation_id='inv-i18n', author='user', - timestamp=70000, - content=types.Content(parts=[types.Part(text='Привет мир')]), + timestamp=90000, + content=types.Content(parts=[types.Part(text=event_text)]), ), ], ) + memory_service = InMemoryMemoryService() await memory_service.add_session_to_memory(session) result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='привет' + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query=query ) - - assert len(result.memories) == 1 - assert result.memories[0].content.parts[0].text == 'Привет мир' + assert len(result.memories) == expected_count