Multilingual Query Expansion (#737)

This commit is contained in:
Yuhong Sun 2023-11-19 10:55:55 -08:00 committed by GitHub
parent b258ec1bed
commit 6fb07d20cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 196 additions and 11 deletions

View File

@ -219,7 +219,9 @@ else:
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.6)))
# A list of languages passed to the LLM to rephase the query
# For example "English,French,Spanish", be sure to use the "," separator
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
#####
# Model Server Configs

View File

@ -6,6 +6,7 @@ from collections.abc import Iterator
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
@ -22,6 +23,7 @@ from danswer.llm.utils import tokenizer_trim_chunks
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.direct_qa_prompts import COT_PROMPT
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_up_code_blocks
@ -88,15 +90,20 @@ class SingleMessageQAHandler(QAHandler):
return True
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
self,
query: str,
context_chunks: list[InferenceChunk],
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> list[BaseMessage]:
context_docs_str = "\n".join(
f"\n{CODE_BLOCK_PAT.format(c.content)}\n" for c in context_chunks
)
single_message = JSON_PROMPT.format(
context_docs_str=context_docs_str, user_query=query
)
context_docs_str=context_docs_str,
user_query=query,
language_hint_or_none=LANGUAGE_HINT if use_language_hint else "",
).strip()
prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
return prompt
@ -111,15 +118,20 @@ class SingleMessageScratchpadHandler(QAHandler):
return True
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
self,
query: str,
context_chunks: list[InferenceChunk],
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> list[BaseMessage]:
context_docs_str = "\n".join(
f"\n{CODE_BLOCK_PAT.format(c.content)}\n" for c in context_chunks
)
single_message = COT_PROMPT.format(
context_docs_str=context_docs_str, user_query=query
)
context_docs_str=context_docs_str,
user_query=query,
language_hint_or_none=LANGUAGE_HINT if use_language_hint else "",
).strip()
prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
return prompt

View File

@ -21,6 +21,7 @@ from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import SECRET
@ -175,6 +176,11 @@ def get_application() -> FastAPI:
if GEN_AI_API_ENDPOINT:
logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
if MULTILINGUAL_QUERY_EXPANSION:
logger.info(
f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
)
if SKIP_RERANKING:
logger.info("Reranking step of search flow is disabled")

View File

@ -27,6 +27,11 @@ Quotes MUST be EXACT substrings from provided documents!
""".strip()
LANGUAGE_HINT = """
IMPORTANT: Respond in the same language as my query!
""".strip()
# This has to be doubly escaped due to json containing { } which are also used for format strings
EMPTY_SAMPLE_JSON = {
"answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
@ -54,6 +59,7 @@ SAMPLE_RESPONSE:
```
{QUESTION_PAT} {{user_query}}
{JSON_HELPFUL_HINT}
{{language_hint_or_none}}
""".strip()
@ -75,6 +81,7 @@ You MUST respond in the following format:
{QUESTION_PAT} {{user_query}}
{JSON_HELPFUL_HINT}
{{language_hint_or_none}}
""".strip()

View File

@ -155,6 +155,18 @@ Respond with EXACTLY AND ONLY: "{USEFUL_PAT}" or "{NONUSEFUL_PAT}"
""".strip()
LANGUAGE_REPHRASE_PROMPT = """
Rephrase the query in {target_language}.
If the query is already in the correct language, \
simply repeat the ORIGINAL query back to me, EXACTLY as is with no rephrasing.
NEVER change proper nouns, technical terms, acronyms, or terms you are not familiar with.
IMPORTANT, if the query is already in the target language, DO NOT REPHRASE OR EDIT the query!
Query:
{query}
""".strip()
# User the following for easy viewing of prompts
if __name__ == "__main__":
print(ANSWERABLE_PROMPT)

View File

@ -1,4 +1,6 @@
from collections.abc import Callable
from copy import deepcopy
from typing import Any
from typing import cast
import numpy
@ -10,6 +12,7 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import HYBRID_ALPHA
from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
@ -36,11 +39,13 @@ from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
from danswer.secondary_llm_flows.query_expansion import rephrase_query
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchDoc
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time
@ -108,6 +113,30 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc
return search_docs
def combine_retrieval_results(
chunk_sets: list[list[InferenceChunk]],
) -> list[InferenceChunk]:
all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set]
unique_chunks: dict[tuple[str, int], InferenceChunk] = {}
for chunk in all_chunks:
key = (chunk.document_id, chunk.chunk_id)
if key not in unique_chunks:
unique_chunks[key] = chunk
continue
stored_chunk_score = unique_chunks[key].score or 0
this_chunk_score = chunk.score or 0
if stored_chunk_score < this_chunk_score:
unique_chunks[key] = chunk
sorted_chunks = sorted(
unique_chunks.values(), key=lambda x: x.score or 0, reverse=True
)
return sorted_chunks
@log_function_time()
def doc_index_retrieval(
query: SearchQuery,
@ -313,6 +342,7 @@ def search_chunks(
query: SearchQuery,
document_index: DocumentIndex,
hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search
multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
@ -331,9 +361,25 @@ def search_chunks(
]
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
top_chunks = doc_index_retrieval(
query=query, document_index=document_index, hybrid_alpha=hybrid_alpha
)
# Don't do query expansion on complex queries, rephrasings likely would not work well
if not multilingual_query_expansion or "\n" in query.query or "\r" in query.query:
top_chunks = doc_index_retrieval(
query=query, document_index=document_index, hybrid_alpha=hybrid_alpha
)
else:
run_queries: list[tuple[Callable, tuple]] = []
# Currently only uses query expansion on multilingual use cases
query_rephrases = rephrase_query(query.query, multilingual_query_expansion)
# Just to be extra sure, add the original query.
query_rephrases.append(query.query)
for rephrase in set(query_rephrases):
q_copy = deepcopy(query)
q_copy.query = rephrase
run_queries.append(
(doc_index_retrieval, (q_copy, document_index, hybrid_alpha))
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)
if not top_chunks:
logger.info(
@ -384,7 +430,9 @@ def search_chunks(
functions_to_run.append(run_llm_filter)
run_llm_filter_id = run_llm_filter.result_id
parallel_results = run_functions_in_parallel(functions_to_run)
parallel_results: dict[str, Any] = {}
if functions_to_run:
parallel_results = run_functions_in_parallel(functions_to_run)
ranked_results = parallel_results.get(str(run_rerank_id))
if ranked_results is None:

View File

@ -0,0 +1,48 @@
from collections.abc import Callable
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.prompts.secondary_llm_flows import LANGUAGE_REPHRASE_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
def llm_rephrase_query(query: str, language: str) -> str:
def _get_rephrase_messages() -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": LANGUAGE_REPHRASE_PROMPT.format(
query=query, target_language=language
),
},
]
return messages
messages = _get_rephrase_messages()
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = get_default_llm().invoke(filled_llm_prompt)
logger.debug(model_output)
return model_output
def rephrase_query(
query: str,
multilingual_query_expansion: str,
use_threads: bool = True,
) -> list[str]:
languages = multilingual_query_expansion.split(",")
languages = [language.strip() for language in languages]
if use_threads:
functions_with_args: list[tuple[Callable, tuple]] = [
(llm_rephrase_query, (query, language)) for language in languages
]
return run_functions_tuples_in_parallel(functions_with_args)
else:
return [llm_rephrase_query(query, language) for language in languages]

View File

@ -47,6 +47,7 @@ services:
- SKIP_RERANKING=${SKIP_RERANKING:-}
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
# Leave this on pretty please? Nothing sensitive is collected!
@ -54,6 +55,8 @@ services:
- DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
# Log all of the prompts to the LLM
- LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-info}
volumes:
- local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage
@ -106,11 +109,17 @@ services:
- SKIP_RERANKING=${SKIP_RERANKING:-}
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
# Leave this on pretty please? Nothing sensitive is collected!
# https://docs.danswer.dev/more/telemetry
- DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
# Log all of the prompts to the LLM
- LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-info}
volumes:
- local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage

View File

@ -0,0 +1,41 @@
# This env template shows how to configure Danswer for multilingual use
# In this case, it is configured for French and English
# To use it, copy it to .env in the docker_compose directory.
# Feel free to combine it with the other templates to suit your needs
# A recent MIT license multilingual model: https://huggingface.co/intfloat/multilingual-e5-small
DOCUMENT_ENCODER_MODEL="intfloat/multilingual-e5-small"
# The model above is trained with the following prefix for queries and passages to improve retrieval
# by letting the model know which of the two type is currently being embedded
ASYM_QUERY_PREFIX="query: "
ASYM_PASSAGE_PREFIX="passage: "
# Depends model by model, this one is tuned with this as True
NORMALIZE_EMBEDDINGS="True"
# Due to the loss function used in training, this model outputs similarity scores from range ~0.6 to 1
SIM_SCORE_RANGE_LOW="0.6"
SIM_SCORE_RANGE_LOW="0.8"
# No recent multilingual reranking models small enough to run on CPU, so turning it off
SKIP_RERANKING="True"
# Use LLM to determine if chunks are relevant to the query
# may not work well for languages that do not have much training data in the LLM training set
DISABLE_LLM_CHUNK_FILTER="True"
# Rephrase the user query in specified languages using LLM, use comma separated values
MULTILINGUAL_QUERY_EXPANSION="English, French"
# Enables fine-grained embeddings for better retrieval
# At the cost of indexing speed (~5x slower), query time is same speed
ENABLE_MINI_CHUNK="True"
# Stronger model will help with multilingual tasks
GEN_AI_MODEL_VERSION="gpt-4"
GEN_AI_API_KEY=<provide your api key>
# More verbose logging if desired
LOG_LEVEL="debug"