From 6fb07d20ccb6e8385d5c60df9d8f42c9cbcd1553 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 19 Nov 2023 10:55:55 -0800 Subject: [PATCH] Multilingual Query Expansion (#737) --- backend/danswer/configs/app_configs.py | 4 +- backend/danswer/direct_qa/qa_block.py | 24 ++++++-- backend/danswer/main.py | 6 ++ backend/danswer/prompts/direct_qa_prompts.py | 7 +++ .../danswer/prompts/secondary_llm_flows.py | 12 ++++ backend/danswer/search/search_runner.py | 56 +++++++++++++++++-- .../secondary_llm_flows/query_expansion.py | 48 ++++++++++++++++ .../docker_compose/docker-compose.dev.yml | 9 +++ .../docker_compose/env.multilingual.template | 41 ++++++++++++++ 9 files changed, 196 insertions(+), 11 deletions(-) create mode 100644 backend/danswer/secondary_llm_flows/query_expansion.py create mode 100644 deployment/docker_compose/env.multilingual.template diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 16b200d1d..fcba5decc 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -219,7 +219,9 @@ else: EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL") # Weighting factor between Vector and Keyword Search, 1 for completely vector search HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.6))) - +# A list of languages passed to the LLM to rephase the query +# For example "English,French,Spanish", be sure to use the "," separator +MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None ##### # Model Server Configs diff --git a/backend/danswer/direct_qa/qa_block.py b/backend/danswer/direct_qa/qa_block.py index 0ea404c7d..c0df1666d 100644 --- a/backend/danswer/direct_qa/qa_block.py +++ b/backend/danswer/direct_qa/qa_block.py @@ -6,6 +6,7 @@ from collections.abc import Iterator from langchain.schema.messages import BaseMessage from langchain.schema.messages import HumanMessage +from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.direct_qa.interfaces import AnswerQuestionReturn from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn from danswer.direct_qa.interfaces import DanswerAnswer @@ -22,6 +23,7 @@ from danswer.llm.utils import tokenizer_trim_chunks from danswer.prompts.constants import CODE_BLOCK_PAT from danswer.prompts.direct_qa_prompts import COT_PROMPT from danswer.prompts.direct_qa_prompts import JSON_PROMPT +from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT from danswer.utils.logger import setup_logger from danswer.utils.text_processing import clean_up_code_blocks @@ -88,15 +90,20 @@ class SingleMessageQAHandler(QAHandler): return True def build_prompt( - self, query: str, context_chunks: list[InferenceChunk] + self, + query: str, + context_chunks: list[InferenceChunk], + use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), ) -> list[BaseMessage]: context_docs_str = "\n".join( f"\n{CODE_BLOCK_PAT.format(c.content)}\n" for c in context_chunks ) single_message = JSON_PROMPT.format( - context_docs_str=context_docs_str, user_query=query - ) + context_docs_str=context_docs_str, + user_query=query, + language_hint_or_none=LANGUAGE_HINT if use_language_hint else "", + ).strip() prompt: list[BaseMessage] = [HumanMessage(content=single_message)] return prompt @@ -111,15 +118,20 @@ class SingleMessageScratchpadHandler(QAHandler): return True def build_prompt( - self, query: str, context_chunks: list[InferenceChunk] + self, + query: str, + context_chunks: list[InferenceChunk], + use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), ) -> list[BaseMessage]: context_docs_str = "\n".join( f"\n{CODE_BLOCK_PAT.format(c.content)}\n" for c in context_chunks ) single_message = COT_PROMPT.format( - context_docs_str=context_docs_str, user_query=query - ) + context_docs_str=context_docs_str, + user_query=query, + language_hint_or_none=LANGUAGE_HINT if use_language_hint else "", + ).strip() prompt: list[BaseMessage] = [HumanMessage(content=single_message)] return prompt diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 4990fc0f6..222a6eed2 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -21,6 +21,7 @@ from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import MODEL_SERVER_HOST from danswer.configs.app_configs import MODEL_SERVER_PORT +from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.configs.app_configs import OAUTH_CLIENT_ID from danswer.configs.app_configs import OAUTH_CLIENT_SECRET from danswer.configs.app_configs import SECRET @@ -175,6 +176,11 @@ def get_application() -> FastAPI: if GEN_AI_API_ENDPOINT: logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}") + if MULTILINGUAL_QUERY_EXPANSION: + logger.info( + f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}" + ) + if SKIP_RERANKING: logger.info("Reranking step of search flow is disabled") diff --git a/backend/danswer/prompts/direct_qa_prompts.py b/backend/danswer/prompts/direct_qa_prompts.py index f4ffc0798..bf47ff659 100644 --- a/backend/danswer/prompts/direct_qa_prompts.py +++ b/backend/danswer/prompts/direct_qa_prompts.py @@ -27,6 +27,11 @@ Quotes MUST be EXACT substrings from provided documents! """.strip() +LANGUAGE_HINT = """ +IMPORTANT: Respond in the same language as my query! +""".strip() + + # This has to be doubly escaped due to json containing { } which are also used for format strings EMPTY_SAMPLE_JSON = { "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.", @@ -54,6 +59,7 @@ SAMPLE_RESPONSE: ``` {QUESTION_PAT} {{user_query}} {JSON_HELPFUL_HINT} +{{language_hint_or_none}} """.strip() @@ -75,6 +81,7 @@ You MUST respond in the following format: {QUESTION_PAT} {{user_query}} {JSON_HELPFUL_HINT} +{{language_hint_or_none}} """.strip() diff --git a/backend/danswer/prompts/secondary_llm_flows.py b/backend/danswer/prompts/secondary_llm_flows.py index 484c468a6..fecd98597 100644 --- a/backend/danswer/prompts/secondary_llm_flows.py +++ b/backend/danswer/prompts/secondary_llm_flows.py @@ -155,6 +155,18 @@ Respond with EXACTLY AND ONLY: "{USEFUL_PAT}" or "{NONUSEFUL_PAT}" """.strip() +LANGUAGE_REPHRASE_PROMPT = """ +Rephrase the query in {target_language}. +If the query is already in the correct language, \ +simply repeat the ORIGINAL query back to me, EXACTLY as is with no rephrasing. +NEVER change proper nouns, technical terms, acronyms, or terms you are not familiar with. +IMPORTANT, if the query is already in the target language, DO NOT REPHRASE OR EDIT the query! + +Query: +{query} +""".strip() + + # User the following for easy viewing of prompts if __name__ == "__main__": print(ANSWERABLE_PROMPT) diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py index e5c78cce7..807eaadb6 100644 --- a/backend/danswer/search/search_runner.py +++ b/backend/danswer/search/search_runner.py @@ -1,4 +1,6 @@ from collections.abc import Callable +from copy import deepcopy +from typing import Any from typing import cast import numpy @@ -10,6 +12,7 @@ from sqlalchemy.orm import Session from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER from danswer.configs.app_configs import HYBRID_ALPHA +from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.configs.app_configs import NUM_RERANKED_RESULTS from danswer.configs.model_configs import ASYM_QUERY_PREFIX from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX @@ -36,11 +39,13 @@ from danswer.search.models import SearchType from danswer.search.search_nlp_models import CrossEncoderEnsembleModel from danswer.search.search_nlp_models import EmbeddingModel from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks +from danswer.secondary_llm_flows.query_expansion import rephrase_query from danswer.server.models import QuestionRequest from danswer.server.models import SearchDoc from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import FunctionCall from danswer.utils.threadpool_concurrency import run_functions_in_parallel +from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel from danswer.utils.timing import log_function_time @@ -108,6 +113,30 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc return search_docs +def combine_retrieval_results( + chunk_sets: list[list[InferenceChunk]], +) -> list[InferenceChunk]: + all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set] + + unique_chunks: dict[tuple[str, int], InferenceChunk] = {} + for chunk in all_chunks: + key = (chunk.document_id, chunk.chunk_id) + if key not in unique_chunks: + unique_chunks[key] = chunk + continue + + stored_chunk_score = unique_chunks[key].score or 0 + this_chunk_score = chunk.score or 0 + if stored_chunk_score < this_chunk_score: + unique_chunks[key] = chunk + + sorted_chunks = sorted( + unique_chunks.values(), key=lambda x: x.score or 0, reverse=True + ) + + return sorted_chunks + + @log_function_time() def doc_index_retrieval( query: SearchQuery, @@ -313,6 +342,7 @@ def search_chunks( query: SearchQuery, document_index: DocumentIndex, hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search + multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION, retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, @@ -331,9 +361,25 @@ def search_chunks( ] logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}") - top_chunks = doc_index_retrieval( - query=query, document_index=document_index, hybrid_alpha=hybrid_alpha - ) + # Don't do query expansion on complex queries, rephrasings likely would not work well + if not multilingual_query_expansion or "\n" in query.query or "\r" in query.query: + top_chunks = doc_index_retrieval( + query=query, document_index=document_index, hybrid_alpha=hybrid_alpha + ) + else: + run_queries: list[tuple[Callable, tuple]] = [] + # Currently only uses query expansion on multilingual use cases + query_rephrases = rephrase_query(query.query, multilingual_query_expansion) + # Just to be extra sure, add the original query. + query_rephrases.append(query.query) + for rephrase in set(query_rephrases): + q_copy = deepcopy(query) + q_copy.query = rephrase + run_queries.append( + (doc_index_retrieval, (q_copy, document_index, hybrid_alpha)) + ) + parallel_search_results = run_functions_tuples_in_parallel(run_queries) + top_chunks = combine_retrieval_results(parallel_search_results) if not top_chunks: logger.info( @@ -384,7 +430,9 @@ def search_chunks( functions_to_run.append(run_llm_filter) run_llm_filter_id = run_llm_filter.result_id - parallel_results = run_functions_in_parallel(functions_to_run) + parallel_results: dict[str, Any] = {} + if functions_to_run: + parallel_results = run_functions_in_parallel(functions_to_run) ranked_results = parallel_results.get(str(run_rerank_id)) if ranked_results is None: diff --git a/backend/danswer/secondary_llm_flows/query_expansion.py b/backend/danswer/secondary_llm_flows/query_expansion.py new file mode 100644 index 000000000..874ca2131 --- /dev/null +++ b/backend/danswer/secondary_llm_flows/query_expansion.py @@ -0,0 +1,48 @@ +from collections.abc import Callable + +from danswer.llm.factory import get_default_llm +from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.prompts.secondary_llm_flows import LANGUAGE_REPHRASE_PROMPT +from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel + +logger = setup_logger() + + +def llm_rephrase_query(query: str, language: str) -> str: + def _get_rephrase_messages() -> list[dict[str, str]]: + messages = [ + { + "role": "user", + "content": LANGUAGE_REPHRASE_PROMPT.format( + query=query, target_language=language + ), + }, + ] + + return messages + + messages = _get_rephrase_messages() + filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) + model_output = get_default_llm().invoke(filled_llm_prompt) + logger.debug(model_output) + + return model_output + + +def rephrase_query( + query: str, + multilingual_query_expansion: str, + use_threads: bool = True, +) -> list[str]: + languages = multilingual_query_expansion.split(",") + languages = [language.strip() for language in languages] + if use_threads: + functions_with_args: list[tuple[Callable, tuple]] = [ + (llm_rephrase_query, (query, language)) for language in languages + ] + + return run_functions_tuples_in_parallel(functions_with_args) + + else: + return [llm_rephrase_query(query, language) for language in languages] diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 7dd763563..ac809ce84 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -47,6 +47,7 @@ services: - SKIP_RERANKING=${SKIP_RERANKING:-} - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-} - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-} + - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-} - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-} - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} # Leave this on pretty please? Nothing sensitive is collected! @@ -54,6 +55,8 @@ services: - DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-} # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} + # Log all of the prompts to the LLM + - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-info} volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage @@ -106,11 +109,17 @@ services: - SKIP_RERANKING=${SKIP_RERANKING:-} - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-} - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-} + - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-} - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-} - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} + # Leave this on pretty please? Nothing sensitive is collected! + # https://docs.danswer.dev/more/telemetry + - DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-} # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} + # Log all of the prompts to the LLM + - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-info} volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage diff --git a/deployment/docker_compose/env.multilingual.template b/deployment/docker_compose/env.multilingual.template new file mode 100644 index 000000000..2083bddd2 --- /dev/null +++ b/deployment/docker_compose/env.multilingual.template @@ -0,0 +1,41 @@ +# This env template shows how to configure Danswer for multilingual use +# In this case, it is configured for French and English +# To use it, copy it to .env in the docker_compose directory. +# Feel free to combine it with the other templates to suit your needs + + +# A recent MIT license multilingual model: https://huggingface.co/intfloat/multilingual-e5-small +DOCUMENT_ENCODER_MODEL="intfloat/multilingual-e5-small" + +# The model above is trained with the following prefix for queries and passages to improve retrieval +# by letting the model know which of the two type is currently being embedded +ASYM_QUERY_PREFIX="query: " +ASYM_PASSAGE_PREFIX="passage: " + +# Depends model by model, this one is tuned with this as True +NORMALIZE_EMBEDDINGS="True" + +# Due to the loss function used in training, this model outputs similarity scores from range ~0.6 to 1 +SIM_SCORE_RANGE_LOW="0.6" +SIM_SCORE_RANGE_LOW="0.8" + +# No recent multilingual reranking models small enough to run on CPU, so turning it off +SKIP_RERANKING="True" + +# Use LLM to determine if chunks are relevant to the query +# may not work well for languages that do not have much training data in the LLM training set +DISABLE_LLM_CHUNK_FILTER="True" + +# Rephrase the user query in specified languages using LLM, use comma separated values +MULTILINGUAL_QUERY_EXPANSION="English, French" + +# Enables fine-grained embeddings for better retrieval +# At the cost of indexing speed (~5x slower), query time is same speed +ENABLE_MINI_CHUNK="True" + +# Stronger model will help with multilingual tasks +GEN_AI_MODEL_VERSION="gpt-4" +GEN_AI_API_KEY= + +# More verbose logging if desired +LOG_LEVEL="debug"