Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-03 09:28:25 +02:00)

Multilingual Query Expansion (#737)
Commit: 6fb07d20cc
Parent: b258ec1bed

backend/danswer/configs/app_configs.py

@@ -219,7 +219,9 @@ else:
 EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
 # Weighting factor between Vector and Keyword Search, 1 for completely vector search
 HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.6)))
+# A list of languages passed to the LLM to rephrase the query
+# For example "English,French,Spanish"; be sure to use the "," separator
+MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None

 #####
 # Model Server Configs

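Since the setting is read straight from the environment, an unset variable and an empty string both fall through to None; only a non-empty value turns the multilingual flow on. A quick sketch of that parsing behavior (the helper name is illustrative, not part of the commit):

import os

def read_multilingual_setting() -> str | None:  # hypothetical helper, for illustration
    # Same `or None` pattern as the config line above: empty string -> None
    return os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None

os.environ["MULTILINGUAL_QUERY_EXPANSION"] = ""
assert read_multilingual_setting() is None  # empty counts as unset

os.environ["MULTILINGUAL_QUERY_EXPANSION"] = "English,French"
assert read_multilingual_setting() == "English,French"
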
backend/danswer/direct_qa/qa_block.py

@@ -6,6 +6,7 @@ from collections.abc import Iterator
 from langchain.schema.messages import BaseMessage
 from langchain.schema.messages import HumanMessage

+from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
 from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import DanswerAnswer

@@ -22,6 +23,7 @@ from danswer.llm.utils import tokenizer_trim_chunks
 from danswer.prompts.constants import CODE_BLOCK_PAT
 from danswer.prompts.direct_qa_prompts import COT_PROMPT
 from danswer.prompts.direct_qa_prompts import JSON_PROMPT
+from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
 from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_up_code_blocks

@@ -88,15 +90,20 @@ class SingleMessageQAHandler(QAHandler):
         return True

     def build_prompt(
-        self, query: str, context_chunks: list[InferenceChunk]
+        self,
+        query: str,
+        context_chunks: list[InferenceChunk],
+        use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
     ) -> list[BaseMessage]:
         context_docs_str = "\n".join(
             f"\n{CODE_BLOCK_PAT.format(c.content)}\n" for c in context_chunks
         )

         single_message = JSON_PROMPT.format(
-            context_docs_str=context_docs_str, user_query=query
-        )
+            context_docs_str=context_docs_str,
+            user_query=query,
+            language_hint_or_none=LANGUAGE_HINT if use_language_hint else "",
+        ).strip()

         prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
         return prompt

@@ -111,15 +118,20 @@ class SingleMessageScratchpadHandler(QAHandler):
         return True

     def build_prompt(
-        self, query: str, context_chunks: list[InferenceChunk]
+        self,
+        query: str,
+        context_chunks: list[InferenceChunk],
+        use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
     ) -> list[BaseMessage]:
         context_docs_str = "\n".join(
             f"\n{CODE_BLOCK_PAT.format(c.content)}\n" for c in context_chunks
         )

         single_message = COT_PROMPT.format(
-            context_docs_str=context_docs_str, user_query=query
-        )
+            context_docs_str=context_docs_str,
+            user_query=query,
+            language_hint_or_none=LANGUAGE_HINT if use_language_hint else "",
+        ).strip()

         prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
         return prompt

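To see what the new use_language_hint flag does, here is a minimal self-contained sketch with a toy template; the real JSON_PROMPT and COT_PROMPT are longer and live in danswer/prompts/direct_qa_prompts.py:

LANGUAGE_HINT = "IMPORTANT: Respond in the same language as my query!"

# Toy stand-in for JSON_PROMPT / COT_PROMPT, for illustration only
TOY_PROMPT = "Context: {context_docs_str}\nQuestion: {user_query}\n{language_hint_or_none}"

def build_single_message(query: str, context: str, use_language_hint: bool) -> str:
    # Same pattern as build_prompt above: the hint slot gets the hint text
    # when the multilingual flow is on, otherwise an empty string, and
    # .strip() removes the trailing blank line left by the empty slot
    return TOY_PROMPT.format(
        context_docs_str=context,
        user_query=query,
        language_hint_or_none=LANGUAGE_HINT if use_language_hint else "",
    ).strip()

print(build_single_message("Qu'est-ce que Danswer?", "some docs", use_language_hint=True))
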
backend/danswer/main.py

@@ -21,6 +21,7 @@ from danswer.configs.app_configs import AUTH_TYPE
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import MODEL_SERVER_HOST
 from danswer.configs.app_configs import MODEL_SERVER_PORT
+from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.app_configs import OAUTH_CLIENT_ID
 from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
 from danswer.configs.app_configs import SECRET

@@ -175,6 +176,11 @@ def get_application() -> FastAPI:
     if GEN_AI_API_ENDPOINT:
         logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")

+    if MULTILINGUAL_QUERY_EXPANSION:
+        logger.info(
+            f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
+        )
+
     if SKIP_RERANKING:
         logger.info("Reranking step of search flow is disabled")

backend/danswer/prompts/direct_qa_prompts.py

@@ -27,6 +27,11 @@ Quotes MUST be EXACT substrings from provided documents!
 """.strip()


+LANGUAGE_HINT = """
+IMPORTANT: Respond in the same language as my query!
+""".strip()
+
+
 # This has to be doubly escaped due to json containing { } which are also used for format strings
 EMPTY_SAMPLE_JSON = {
     "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",

@@ -54,6 +59,7 @@ SAMPLE_RESPONSE:
 ```
 {QUESTION_PAT} {{user_query}}
 {JSON_HELPFUL_HINT}
+{{language_hint_or_none}}
 """.strip()

@@ -75,6 +81,7 @@ You MUST respond in the following format:

 {QUESTION_PAT} {{user_query}}
 {JSON_HELPFUL_HINT}
+{{language_hint_or_none}}
 """.strip()

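As the comment about double escaping notes, these templates are interpolated twice: the singly-braced names ({QUESTION_PAT}, {JSON_HELPFUL_HINT}) are filled when the constant is built, which collapses the doubled braces so that {user_query} and {language_hint_or_none} survive for build_prompt's later .format call. A toy demonstration of the brace mechanics (illustrative strings, not the real prompts):

QUESTION_PAT = "Query:"
JSON_HELPFUL_HINT = "Hint: answer in valid JSON."

template = "{QUESTION_PAT} {{user_query}}\n{JSON_HELPFUL_HINT}\n{{language_hint_or_none}}"

# Pass one: fills QUESTION_PAT / JSON_HELPFUL_HINT and turns {{...}} into {...}
stage_one = template.format(
    QUESTION_PAT=QUESTION_PAT, JSON_HELPFUL_HINT=JSON_HELPFUL_HINT
)

# Pass two: what build_prompt does with the surviving placeholders
final = stage_one.format(
    user_query="¿Qué es Danswer?",
    language_hint_or_none="",
).strip()
print(final)
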
backend/danswer/prompts/secondary_llm_flows.py

@@ -155,6 +155,18 @@ Respond with EXACTLY AND ONLY: "{USEFUL_PAT}" or "{NONUSEFUL_PAT}"
 """.strip()


+LANGUAGE_REPHRASE_PROMPT = """
+Rephrase the query in {target_language}.
+If the query is already in the correct language, \
+simply repeat the ORIGINAL query back to me, EXACTLY as is with no rephrasing.
+NEVER change proper nouns, technical terms, acronyms, or terms you are not familiar with.
+IMPORTANT: if the query is already in the target language, DO NOT REPHRASE OR EDIT the query!
+
+Query:
+{query}
+""".strip()
+
+
 # Use the following for easy viewing of prompts
 if __name__ == "__main__":
     print(ANSWERABLE_PROMPT)

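Filled in, one rephrase prompt is produced per configured language. A short sketch (the import path is confirmed by the new query_expansion.py below):

from danswer.prompts.secondary_llm_flows import LANGUAGE_REPHRASE_PROMPT

# Render the prompt for a single target language
prompt_text = LANGUAGE_REPHRASE_PROMPT.format(
    target_language="French",
    query="How do I reset my password?",
)
print(prompt_text)  # ends with "Query:" followed by the original query
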
backend/danswer/search/search_runner.py

@@ -1,4 +1,6 @@
 from collections.abc import Callable
+from copy import deepcopy
+from typing import Any
 from typing import cast

 import numpy

@@ -10,6 +12,7 @@ from sqlalchemy.orm import Session

 from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
 from danswer.configs.app_configs import HYBRID_ALPHA
+from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX

@@ -36,11 +39,13 @@ from danswer.search.models import SearchType
 from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
 from danswer.search.search_nlp_models import EmbeddingModel
 from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
+from danswer.secondary_llm_flows.query_expansion import rephrase_query
 from danswer.server.models import QuestionRequest
 from danswer.server.models import SearchDoc
 from danswer.utils.logger import setup_logger
 from danswer.utils.threadpool_concurrency import FunctionCall
 from danswer.utils.threadpool_concurrency import run_functions_in_parallel
+from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
 from danswer.utils.timing import log_function_time

@@ -108,6 +113,30 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc
     return search_docs


+def combine_retrieval_results(
+    chunk_sets: list[list[InferenceChunk]],
+) -> list[InferenceChunk]:
+    all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set]
+
+    unique_chunks: dict[tuple[str, int], InferenceChunk] = {}
+    for chunk in all_chunks:
+        key = (chunk.document_id, chunk.chunk_id)
+        if key not in unique_chunks:
+            unique_chunks[key] = chunk
+            continue
+
+        stored_chunk_score = unique_chunks[key].score or 0
+        this_chunk_score = chunk.score or 0
+        if stored_chunk_score < this_chunk_score:
+            unique_chunks[key] = chunk
+
+    sorted_chunks = sorted(
+        unique_chunks.values(), key=lambda x: x.score or 0, reverse=True
+    )
+
+    return sorted_chunks
+
+
 @log_function_time()
 def doc_index_retrieval(
     query: SearchQuery,

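combine_retrieval_results merges the per-language result lists, keeps the best-scoring copy of any chunk retrieved more than once, and re-sorts by score. A small self-contained sketch of the same logic, using a stand-in dataclass instead of InferenceChunk:

from dataclasses import dataclass

@dataclass
class Chunk:  # stand-in for InferenceChunk, for illustration
    document_id: str
    chunk_id: int
    score: float | None

def combine(chunk_sets: list[list[Chunk]]) -> list[Chunk]:
    # Same shape as combine_retrieval_results: dedupe on (document_id,
    # chunk_id), keep the higher-scoring duplicate, sort descending by score
    unique: dict[tuple[str, int], Chunk] = {}
    for chunk in (c for cs in chunk_sets for c in cs):
        key = (chunk.document_id, chunk.chunk_id)
        if key not in unique or (unique[key].score or 0) < (chunk.score or 0):
            unique[key] = chunk
    return sorted(unique.values(), key=lambda c: c.score or 0, reverse=True)

english = [Chunk("doc-a", 0, 0.9), Chunk("doc-b", 3, 0.4)]
french = [Chunk("doc-a", 0, 0.7), Chunk("doc-c", 1, 0.8)]
print(combine([english, french]))
# doc-a/0 appears once with its higher 0.9 score; order: 0.9, 0.8, 0.4
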
@@ -313,6 +342,7 @@ def search_chunks(
     query: SearchQuery,
     document_index: DocumentIndex,
     hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
+    multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
     | None = None,
     rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,

@@ -331,9 +361,25 @@
     ]
     logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")

-    top_chunks = doc_index_retrieval(
-        query=query, document_index=document_index, hybrid_alpha=hybrid_alpha
-    )
+    # Don't do query expansion on complex queries; rephrasings likely would not work well
+    if not multilingual_query_expansion or "\n" in query.query or "\r" in query.query:
+        top_chunks = doc_index_retrieval(
+            query=query, document_index=document_index, hybrid_alpha=hybrid_alpha
+        )
+    else:
+        run_queries: list[tuple[Callable, tuple]] = []
+        # Currently only uses query expansion on multilingual use cases
+        query_rephrases = rephrase_query(query.query, multilingual_query_expansion)
+        # Just to be extra sure, add the original query
+        query_rephrases.append(query.query)
+        for rephrase in set(query_rephrases):
+            q_copy = deepcopy(query)
+            q_copy.query = rephrase
+            run_queries.append(
+                (doc_index_retrieval, (q_copy, document_index, hybrid_alpha))
+            )
+        parallel_search_results = run_functions_tuples_in_parallel(run_queries)
+        top_chunks = combine_retrieval_results(parallel_search_results)

     if not top_chunks:
         logger.info(

@@ -384,7 +430,9 @@ def search_chunks(
         functions_to_run.append(run_llm_filter)
         run_llm_filter_id = run_llm_filter.result_id

-    parallel_results = run_functions_in_parallel(functions_to_run)
+    parallel_results: dict[str, Any] = {}
+    if functions_to_run:
+        parallel_results = run_functions_in_parallel(functions_to_run)

     ranked_results = parallel_results.get(str(run_rerank_id))
     if ranked_results is None:

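The guard at the top of the new branch is worth spelling out: expansion only runs when it is configured and the query is a single line, on the stated assumption that rephrasing complex multi-line queries would not work well. A restatement of that condition as a standalone predicate (the helper name is illustrative, not from the commit):

def should_expand(multilingual_query_expansion: str | None, query: str) -> bool:
    # Mirrors the `if not multilingual_query_expansion or "\n" in query...`
    # check in search_chunks, inverted to answer "should we expand?"
    return (
        bool(multilingual_query_expansion)
        and "\n" not in query
        and "\r" not in query
    )

assert should_expand("English,French", "reset password steps") is True
assert should_expand("English,French", "line one\nline two") is False
assert should_expand(None, "reset password steps") is False
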
backend/danswer/secondary_llm_flows/query_expansion.py (new file, 48 lines)

@@ -0,0 +1,48 @@
+from collections.abc import Callable
+
+from danswer.llm.factory import get_default_llm
+from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
+from danswer.prompts.secondary_llm_flows import LANGUAGE_REPHRASE_PROMPT
+from danswer.utils.logger import setup_logger
+from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
+
+logger = setup_logger()
+
+
+def llm_rephrase_query(query: str, language: str) -> str:
+    def _get_rephrase_messages() -> list[dict[str, str]]:
+        messages = [
+            {
+                "role": "user",
+                "content": LANGUAGE_REPHRASE_PROMPT.format(
+                    query=query, target_language=language
+                ),
+            },
+        ]
+
+        return messages
+
+    messages = _get_rephrase_messages()
+    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
+    model_output = get_default_llm().invoke(filled_llm_prompt)
+    logger.debug(model_output)
+
+    return model_output
+
+
+def rephrase_query(
+    query: str,
+    multilingual_query_expansion: str,
+    use_threads: bool = True,
+) -> list[str]:
+    languages = multilingual_query_expansion.split(",")
+    languages = [language.strip() for language in languages]
+    if use_threads:
+        functions_with_args: list[tuple[Callable, tuple]] = [
+            (llm_rephrase_query, (query, language)) for language in languages
+        ]
+
+        return run_functions_tuples_in_parallel(functions_with_args)
+
+    else:
+        return [llm_rephrase_query(query, language) for language in languages]

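Typical usage of the new module; this sketch assumes a default LLM is configured, since llm_rephrase_query calls get_default_llm() internally:

from danswer.secondary_llm_flows.query_expansion import rephrase_query

# One LLM call per configured language, run in parallel threads by default;
# returns one rephrasing per language
rephrases = rephrase_query(
    query="How do I reset my password?",
    multilingual_query_expansion="English, French",
)
# e.g. ["How do I reset my password?", "Comment réinitialiser mon mot de passe ?"]
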
deployment/docker_compose/docker-compose.dev.yml

@@ -47,6 +47,7 @@ services:
       - SKIP_RERANKING=${SKIP_RERANKING:-}
       - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
       - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
+      - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
       - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
       - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
       # Leave this on pretty please? Nothing sensitive is collected!

@@ -54,6 +55,8 @@ services:
       - DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
       # Set to debug to get more fine-grained logs
       - LOG_LEVEL=${LOG_LEVEL:-info}
+      # Log all of the prompts to the LLM
+      - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-info}
     volumes:
       - local_dynamic_storage:/home/storage
       - file_connector_tmp_storage:/home/file_connector_storage

@@ -106,11 +109,17 @@ services:
       - SKIP_RERANKING=${SKIP_RERANKING:-}
       - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
       - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
+      - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
       - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
       - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
       - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
       # Leave this on pretty please? Nothing sensitive is collected!
       # https://docs.danswer.dev/more/telemetry
       - DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
+      # Log all of the prompts to the LLM
+      - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-info}
     volumes:
       - local_dynamic_storage:/home/storage
       - file_connector_tmp_storage:/home/file_connector_storage

deployment/docker_compose/env.multilingual.template (new file, 41 lines)

@@ -0,0 +1,41 @@
+# This env template shows how to configure Danswer for multilingual use
+# In this case, it is configured for French and English
+# To use it, copy it to .env in the docker_compose directory.
+# Feel free to combine it with the other templates to suit your needs
+
+
+# A recent MIT-licensed multilingual model: https://huggingface.co/intfloat/multilingual-e5-small
+DOCUMENT_ENCODER_MODEL="intfloat/multilingual-e5-small"
+
+# The model above is trained with the following prefixes for queries and passages to improve retrieval
+# by letting the model know which of the two types is currently being embedded
+ASYM_QUERY_PREFIX="query: "
+ASYM_PASSAGE_PREFIX="passage: "
+
+# This varies from model to model; the one above is tuned with this set to True
+NORMALIZE_EMBEDDINGS="True"
+
+# Due to the loss function used in training, this model outputs similarity scores in the range ~0.6 to 1
+SIM_SCORE_RANGE_LOW="0.6"
+SIM_SCORE_RANGE_HIGH="0.8"
+
+# No recent multilingual reranking models are small enough to run on CPU, so reranking is turned off
+SKIP_RERANKING="True"
+
+# The LLM chunk filter (using the LLM to determine if chunks are relevant to the query)
+# may not work well for languages with little representation in the LLM's training data
+DISABLE_LLM_CHUNK_FILTER="True"
+
+# Rephrase the user query in the specified languages using the LLM; use comma-separated values
+MULTILINGUAL_QUERY_EXPANSION="English, French"
+
+# Enables fine-grained embeddings for better retrieval,
+# at the cost of indexing speed (~5x slower); query-time speed is unchanged
+ENABLE_MINI_CHUNK="True"
+
+# A stronger model will help with multilingual tasks
+GEN_AI_MODEL_VERSION="gpt-4"
+GEN_AI_API_KEY=<provide your api key>
+
+# More verbose logging if desired
+LOG_LEVEL="debug"