import string
from collections.abc import Callable
import nltk # type:ignore
from nltk.corpus import stopwords # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from onyx.context.search.models import ChunkMetric
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RetrievalMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_multilingual_expansion
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.secondary_llm_flows.query_expansion import multilingual_query_expansion
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType

logger = setup_logger()


def download_nltk_data() -> None:
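    """Ensure the NLTK resources this module needs are available,
    downloading any that are missing."""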
resources = {
"stopwords": "corpora/stopwords",
# "wordnet": "corpora/wordnet", # Not in use
"punkt": "tokenizers/punkt",
}
for resource_name, resource_path in resources.items():
try:
nltk.data.find(resource_path)
logger.info(f"{resource_name} is already downloaded.")
except LookupError:
try:
logger.info(f"Downloading {resource_name}...")
nltk.download(resource_name, quiet=True)
logger.info(f"{resource_name} downloaded successfully.")
except Exception as e:
logger.error(f"Failed to download {resource_name}. Error: {e}")


def lemmatize_text(keywords: list[str]) -> list[str]:
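    """Lemmatization is intentionally disabled; the commented-out
    implementation below is kept for reference only."""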
raise NotImplementedError("Lemmatization should not be used currently")
# try:
# query = " ".join(keywords)
# lemmatizer = WordNetLemmatizer()
# word_tokens = word_tokenize(query)
# lemmatized_words = [lemmatizer.lemmatize(word) for word in word_tokens]
# combined_keywords = list(set(keywords + lemmatized_words))
# return combined_keywords
# except Exception:
# return keywords


def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
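    """Drop English stop words and punctuation tokens from the keyword list.

    Falls back to the full token list if everything was filtered out, and to
    the original keywords if NLTK itself fails. Illustrative example:

        remove_stop_words_and_punctuation(
            ["what", "is", "the", "capital", "of", "France", "?"]
        )
        # -> ["capital", "France"]
    """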
try:
# Re-tokenize using the NLTK tokenizer for better matching
query = " ".join(keywords)
stop_words = set(stopwords.words("english"))
word_tokens = word_tokenize(query)
text_trimmed = [
word
for word in word_tokens
if (word.casefold() not in stop_words and word not in string.punctuation)
]
return text_trimmed or word_tokens
except Exception:
return keywords


def combine_retrieval_results(
chunk_sets: list[list[InferenceChunk]],
) -> list[InferenceChunk]:
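    """Merge chunk lists from parallel searches into one ranked list.

    Chunks are deduplicated on (document_id, chunk_id); when the same chunk
    appears in multiple result sets, the highest score wins. The merged list
    is returned sorted by score, descending.
    """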
all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set]
unique_chunks: dict[tuple[str, int], InferenceChunk] = {}
for chunk in all_chunks:
key = (chunk.document_id, chunk.chunk_id)
if key not in unique_chunks:
unique_chunks[key] = chunk
continue
stored_chunk_score = unique_chunks[key].score or 0
this_chunk_score = chunk.score or 0
if stored_chunk_score < this_chunk_score:
unique_chunks[key] = chunk
sorted_chunks = sorted(
unique_chunks.values(), key=lambda x: x.score or 0, reverse=True
)
return sorted_chunks


@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
document_index: DocumentIndex,
db_session: Session,
) -> list[InferenceChunk]:
"""
This function performs the search to retrieve the chunks,
extracts chunks from the large chunks, persists the scores
from the large chunks to the referenced chunks,
dedupes the chunks, and cleans the chunks.
"""
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
        # The below are globally set; this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
top_chunks = document_index.hybrid_retrieval(
query=query.query,
query_embedding=query_embedding,
final_keywords=query.processed_keywords,
filters=query.filters,
hybrid_alpha=query.hybrid_alpha,
time_decay_multiplier=query.recency_bias_multiplier,
num_to_retrieve=query.num_hits,
offset=query.offset,
)
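    # Split the results: large chunks (merged spans of regular chunks) are
    # expanded back into the regular chunks they reference, while normal
    # chunks pass through as-is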
retrieval_requests: list[VespaChunkRequest] = []
normal_chunks: list[InferenceChunkUncleaned] = []
referenced_chunk_scores: dict[tuple[str, int], float] = {}
for chunk in top_chunks:
if chunk.large_chunk_reference_ids:
retrieval_requests.append(
VespaChunkRequest(
document_id=replace_invalid_doc_id_characters(chunk.document_id),
min_chunk_ind=chunk.large_chunk_reference_ids[0],
max_chunk_ind=chunk.large_chunk_reference_ids[-1],
)
)
            # For each referenced chunk, persist the highest score
            # seen across the large chunks that reference it
for chunk_id in chunk.large_chunk_reference_ids:
key = (chunk.document_id, chunk_id)
referenced_chunk_scores[key] = max(
referenced_chunk_scores.get(key, 0), chunk.score or 0
)
else:
normal_chunks.append(chunk)
# If there are no large chunks, just return the normal chunks
if not retrieval_requests:
return cleanup_chunks(normal_chunks)
    # Retrieve the normal chunks referenced by the large chunks
retrieved_inference_chunks = document_index.id_based_retrieval(
chunk_requests=retrieval_requests,
filters=query.filters,
batch_retrieval=True,
)
# Apply the scores from the large chunks to the chunks referenced
# by each large chunk
for chunk in retrieved_inference_chunks:
if (chunk.document_id, chunk.chunk_id) in referenced_chunk_scores:
chunk.score = referenced_chunk_scores[(chunk.document_id, chunk.chunk_id)]
referenced_chunk_scores.pop((chunk.document_id, chunk.chunk_id))
else:
logger.error(
f"Chunk {chunk.document_id} {chunk.chunk_id} not found in referenced chunk scores"
)
# Log any chunks that were not found in the retrieved chunks
for reference in referenced_chunk_scores.keys():
logger.error(f"Chunk {reference} not found in retrieved chunks")
unique_chunks: dict[tuple[str, int], InferenceChunkUncleaned] = {
(chunk.document_id, chunk.chunk_id): chunk for chunk in normal_chunks
}
# persist the highest score of each deduped chunk
for chunk in retrieved_inference_chunks:
key = (chunk.document_id, chunk.chunk_id)
# For duplicates, keep the highest score
if key not in unique_chunks or (chunk.score or 0) > (
unique_chunks[key].score or 0
):
unique_chunks[key] = chunk
# Deduplicate the chunks
deduped_chunks = list(unique_chunks.values())
deduped_chunks.sort(key=lambda chunk: chunk.score or 0, reverse=True)
return cleanup_chunks(deduped_chunks)


def _simplify_text(text: str) -> str:
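    """Lowercase the text and strip punctuation and whitespace, e.g.
    _simplify_text("Hello, World!") == "helloworld". Used to detect
    query rephrasings that differ only trivially from the original."""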
return "".join(
char for char in text if char not in string.punctuation and not char.isspace()
).lower()


def retrieve_chunks(
query: SearchQuery,
document_index: DocumentIndex,
db_session: Session,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
) -> list[InferenceChunk]:
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
multilingual_expansion = get_multilingual_expansion(db_session)
    # Don't do query expansion on complex queries; rephrasings likely would not work well
if not multilingual_expansion or "\n" in query.query or "\r" in query.query:
top_chunks = doc_index_retrieval(
query=query, document_index=document_index, db_session=db_session
)
else:
simplified_queries = set()
run_queries: list[tuple[Callable, tuple]] = []
# Currently only uses query expansion on multilingual use cases
query_rephrases = multilingual_query_expansion(
query.query, multilingual_expansion
)
# Just to be extra sure, add the original query.
query_rephrases.append(query.query)
for rephrase in set(query_rephrases):
# Sometimes the model rephrases the query in the same language with minor changes
# Avoid doing an extra search with the minor changes as this biases the results
simplified_rephrase = _simplify_text(rephrase)
if simplified_rephrase in simplified_queries:
continue
simplified_queries.add(simplified_rephrase)
q_copy = query.copy(update={"query": rephrase}, deep=True)
run_queries.append(
(
doc_index_retrieval,
(q_copy, document_index, db_session),
)
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)
if not top_chunks:
logger.warning(
f"Hybrid ({query.search_type.value.capitalize()}) search returned no results "
f"with filters: {query.filters}"
)
return []
if retrieval_metrics_callback is not None:
chunk_metrics = [
ChunkMetric(
document_id=chunk.document_id,
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
first_link=chunk.source_links[0] if chunk.source_links else None,
score=chunk.score if chunk.score is not None else 0,
)
for chunk in top_chunks
]
retrieval_metrics_callback(
RetrievalMetricsContainer(
search_type=query.search_type, metrics=chunk_metrics
)
)
return top_chunks


def inference_sections_from_ids(
doc_identifiers: list[tuple[str, int]],
document_index: DocumentIndex,
) -> list[InferenceSection]:
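    """Build one whole-document InferenceSection per explicitly selected
    document ID."""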
# Currently only fetches whole docs
doc_ids_set = set(doc_id for doc_id, _ in doc_identifiers)
chunk_requests: list[VespaChunkRequest] = [
VespaChunkRequest(document_id=doc_id) for doc_id in doc_ids_set
]
# No need for ACL here because the doc ids were validated beforehand
filters = IndexFilters(access_control_list=None)
retrieved_chunks = document_index.id_based_retrieval(
chunk_requests=chunk_requests,
filters=filters,
)
cleaned_chunks = cleanup_chunks(retrieved_chunks)
if not cleaned_chunks:
return []
# Group chunks by document ID
chunks_by_doc_id: dict[str, list[InferenceChunk]] = {}
for chunk in cleaned_chunks:
chunks_by_doc_id.setdefault(chunk.document_id, []).append(chunk)
inference_sections = [
section
for chunks in chunks_by_doc_id.values()
if chunks
and (
section := inference_section_from_chunks(
                # The scores will always be 0 because fetching by ID returns
                # no search scores. That is fine here since the user explicitly
                # selected these documents, so no ranking is needed.
center_chunk=chunks[0],
chunks=chunks,
)
)
]
return inference_sections
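

# Illustrative usage sketch (hypothetical: `document_index` is a configured
# DocumentIndex and `db_session` an open SQLAlchemy Session):
#
#     top_chunks = retrieve_chunks(
#         query=SearchQuery(query="how do I configure SSO", ...),
#         document_index=document_index,
#         db_session=db_session,
#     )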