import string
from collections.abc import Callable

import nltk  # type:ignore
from nltk.corpus import stopwords  # type:ignore
from nltk.stem import WordNetLemmatizer  # type:ignore
from nltk.tokenize import word_tokenize  # type:ignore
from sqlalchemy.orm import Session

from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType


logger = setup_logger()


def download_nltk_data() -> None:
    resources = {
        "stopwords": "corpora/stopwords",
        "wordnet": "corpora/wordnet",
        "punkt": "tokenizers/punkt",
    }

    for resource_name, resource_path in resources.items():
        try:
            nltk.data.find(resource_path)
            logger.info(f"{resource_name} is already downloaded.")
        except LookupError:
            try:
                logger.info(f"Downloading {resource_name}...")
                nltk.download(resource_name, quiet=True)
                logger.info(f"{resource_name} downloaded successfully.")
            except Exception as e:
                logger.error(f"Failed to download {resource_name}. Error: {e}")


def lemmatize_text(text: str) -> list[str]:
    try:
        lemmatizer = WordNetLemmatizer()
        word_tokens = word_tokenize(text)
        return [lemmatizer.lemmatize(word) for word in word_tokens]
    except Exception:
        return text.split(" ")


def remove_stop_words_and_punctuation(text: str) -> list[str]:
    try:
        stop_words = set(stopwords.words("english"))
        word_tokens = word_tokenize(text)
        text_trimmed = [
            word
            for word in word_tokens
            if (word.casefold() not in stop_words and word not in string.punctuation)
        ]
        return text_trimmed or word_tokens
    except Exception:
        return text.split(" ")


def query_processing(
    query: str,
) -> str:
    query = " ".join(remove_stop_words_and_punctuation(query))
    query = " ".join(lemmatize_text(query))
    return query
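
# Rough illustration (exact output depends on the installed NLTK data): for a
# query like "What are the effects of caching?", stopword/punctuation removal
# yields "effects caching" and lemmatization then yields "effect caching":
#
#     query_processing("What are the effects of caching?")  # -> "effect caching"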


def combine_retrieval_results(
    chunk_sets: list[list[InferenceChunk]],
) -> list[InferenceChunk]:
    all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set]

    unique_chunks: dict[tuple[str, int], InferenceChunk] = {}
    for chunk in all_chunks:
        key = (chunk.document_id, chunk.chunk_id)
        if key not in unique_chunks:
            unique_chunks[key] = chunk
            continue

        stored_chunk_score = unique_chunks[key].score or 0
        this_chunk_score = chunk.score or 0
        if stored_chunk_score < this_chunk_score:
            unique_chunks[key] = chunk

    sorted_chunks = sorted(
        unique_chunks.values(), key=lambda x: x.score or 0, reverse=True
    )

    return sorted_chunks
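
# Dedup sketch (illustrative values): if two expanded searches both return the
# chunk keyed by ("doc_a", 0), once with score 0.4 and once with 0.7, only the
# 0.7 copy is kept, and the merged list comes back sorted by score descending.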


@log_function_time(print_only=True)
def doc_index_retrieval(
    query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
    hybrid_alpha: float = HYBRID_ALPHA,
) -> list[InferenceChunk]:
    if query.search_type == SearchType.KEYWORD:
        top_chunks = document_index.keyword_retrieval(
            query=query.query,
            filters=query.filters,
            time_decay_multiplier=query.recency_bias_multiplier,
            num_to_retrieve=query.num_hits,
        )
    else:
        db_embedding_model = get_current_db_embedding_model(db_session)

        model = EmbeddingModel(
            model_name=db_embedding_model.model_name,
            query_prefix=db_embedding_model.query_prefix,
            passage_prefix=db_embedding_model.passage_prefix,
            normalize=db_embedding_model.normalize,
            api_key=db_embedding_model.api_key,
            provider_type=db_embedding_model.provider_type,
            # The below are set globally; this flow always uses the indexing model server
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
        )

        query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]

        if query.search_type == SearchType.SEMANTIC:
            top_chunks = document_index.semantic_retrieval(
                query=query.query,
                query_embedding=query_embedding,
                filters=query.filters,
                time_decay_multiplier=query.recency_bias_multiplier,
                num_to_retrieve=query.num_hits,
            )

        elif query.search_type == SearchType.HYBRID:
            top_chunks = document_index.hybrid_retrieval(
                query=query.query,
                query_embedding=query_embedding,
                filters=query.filters,
                time_decay_multiplier=query.recency_bias_multiplier,
                num_to_retrieve=query.num_hits,
                offset=query.offset,
                hybrid_alpha=hybrid_alpha,
            )

        else:
            raise RuntimeError("Invalid Search Flow")

    return cleanup_chunks(top_chunks)
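
# Note on hybrid_alpha (an assumption from the HYBRID_ALPHA config name, not
# stated in this file): it is the keyword-vs-vector weighting handed to
# DocumentIndex.hybrid_retrieval; which component a given value favors is up
# to the DocumentIndex implementation.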


def _simplify_text(text: str) -> str:
    return "".join(
        char for char in text if char not in string.punctuation and not char.isspace()
    ).lower()
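
# e.g. _simplify_text("Hello, World!") -> "helloworld", so rephrasings that
# differ only in casing, spacing, or punctuation collapse to a single search.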


def retrieve_chunks(
    query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
) -> list[InferenceChunk]:
    """Returns a list of the best chunks from an initial keyword/semantic/hybrid search."""
    # Don't do query expansion on complex queries; rephrasings likely would not work well
    if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query:
        top_chunks = doc_index_retrieval(
            query=query,
            document_index=document_index,
            db_session=db_session,
            hybrid_alpha=hybrid_alpha,
        )
    else:
        simplified_queries = set()
        run_queries: list[tuple[Callable, tuple]] = []

        # Currently only uses query expansion on multilingual use cases
        query_rephrases = multilingual_query_expansion(
            query.query, multilingual_expansion_str
        )
        # Just to be extra sure, add the original query.
        query_rephrases.append(query.query)
        for rephrase in set(query_rephrases):
            # Sometimes the model rephrases the query in the same language with minor changes.
            # Avoid doing an extra search with the minor changes as this biases the results.
            simplified_rephrase = _simplify_text(rephrase)
            if simplified_rephrase in simplified_queries:
                continue
            simplified_queries.add(simplified_rephrase)

            q_copy = query.copy(update={"query": rephrase}, deep=True)
            run_queries.append(
                (
                    doc_index_retrieval,
                    (q_copy, document_index, db_session, hybrid_alpha),
                )
            )
        parallel_search_results = run_functions_tuples_in_parallel(run_queries)
        top_chunks = combine_retrieval_results(parallel_search_results)

    if not top_chunks:
        logger.info(
            f"{query.search_type.value.capitalize()} search returned no results "
            f"with filters: {query.filters}"
        )
        return []

    if retrieval_metrics_callback is not None:
        chunk_metrics = [
            ChunkMetric(
                document_id=chunk.document_id,
                chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
                first_link=chunk.source_links[0] if chunk.source_links else None,
                score=chunk.score if chunk.score is not None else 0,
            )
            for chunk in top_chunks
        ]
        retrieval_metrics_callback(
            RetrievalMetricsContainer(
                search_type=query.search_type, metrics=chunk_metrics
            )
        )

    return top_chunks
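
# Hypothetical expansion flow (illustrative values): with
# multilingual_expansion_str="English,German", a query "reset password" might
# fan out to ["reset password", "Passwort zurücksetzen"]; each rephrase is
# searched in parallel and the results are merged by combine_retrieval_results.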


def inference_sections_from_ids(
    doc_identifiers: list[tuple[str, int]],
    document_index: DocumentIndex,
) -> list[InferenceSection]:
    # Currently only fetches whole docs
    doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers)

    # No need for ACL here because the doc ids were validated beforehand
    filters = IndexFilters(access_control_list=None)

    functions_with_args: list[tuple[Callable, tuple]] = [
        (document_index.id_based_retrieval, (doc_id, None, None, filters))
        for doc_id in doc_ids_set
    ]

    parallel_results = run_functions_tuples_in_parallel(
        functions_with_args, allow_failures=True
    )

    # Any failure to retrieve gives back None; drop the Nones and empty lists
    inference_chunks_sets = [res for res in parallel_results if res]

    return [
        inference_section
        for inference_section in [
            inference_section_from_chunks(
                # The scores will always be 0 because fetching by id gives back
                # no search scores. That is fine here, since the user is
                # explicitly selecting a document.
                center_chunk=chunk_set[0],
                chunks=chunk_set,
            )
            for chunk_set in inference_chunks_sets
        ]
        if inference_section is not None
    ]