mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-20 04:37:09 +02:00
Prep for Hybrid Search (#648)
This commit is contained in:
@@ -38,9 +38,11 @@ from danswer.llm.llm import LLM
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import translate_danswer_msg_to_langchain
|
||||
from danswer.search.access_filters import build_access_filters_for_user
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.search_runner import chunks_to_search_docs
|
||||
from danswer.search.search_runner import retrieve_ranked_documents
|
||||
from danswer.server.models import IndexFilters
|
||||
from danswer.search.search_runner import search_chunks
|
||||
from danswer.server.models import RetrievalDocs
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import extract_embedded_json
|
||||
@@ -130,13 +132,18 @@ def danswer_chat_retrieval(
|
||||
else:
|
||||
reworded_query = query_message.message
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
||||
search_query = SearchQuery(
|
||||
query=reworded_query,
|
||||
search_type=SearchType.HYBRID,
|
||||
filters=filters,
|
||||
favor_recent=False,
|
||||
datastore=get_default_document_index(),
|
||||
)
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
ranked_chunks, unranked_chunks = search_chunks(
|
||||
query=search_query, document_index=get_default_document_index()
|
||||
)
|
||||
|
||||
if not ranked_chunks:
|
||||
return []
|
||||
|
||||
|
@@ -26,9 +26,9 @@ from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.direct_qa.answer_question import answer_qa_query
|
||||
from danswer.search.models import BaseFilters
|
||||
from danswer.server.models import QAResponse
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.server.models import RequestFilters
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger_base = setup_logger()
|
||||
@@ -182,7 +182,7 @@ def handle_message(
|
||||
try:
|
||||
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
||||
# it allows the slack flow to extract out filters from the user query
|
||||
filters = RequestFilters(
|
||||
filters = BaseFilters(
|
||||
source_type=None,
|
||||
document_set=document_set_names,
|
||||
time_cutoff=None,
|
||||
@@ -193,7 +193,6 @@ def handle_message(
|
||||
QuestionRequest(
|
||||
query=msg,
|
||||
collection=DOCUMENT_INDEX_NAME,
|
||||
use_keyword=False, # always use semantic search when handling Slack messages
|
||||
enable_auto_detect_filters=not disable_auto_detect_filters,
|
||||
filters=filters,
|
||||
favor_recent=None,
|
||||
|
@@ -96,14 +96,14 @@ def update_document_hidden(db_session: Session, document_id: str, hidden: bool)
|
||||
def create_query_event(
|
||||
db_session: Session,
|
||||
query: str,
|
||||
selected_flow: SearchType | None,
|
||||
search_type: SearchType | None,
|
||||
llm_answer: str | None,
|
||||
user_id: UUID | None,
|
||||
retrieved_document_ids: list[str] | None = None,
|
||||
) -> int:
|
||||
query_event = QueryEvent(
|
||||
query=query,
|
||||
selected_search_flow=selected_flow,
|
||||
selected_search_flow=search_type,
|
||||
llm_answer=llm_answer,
|
||||
retrieved_document_ids=retrieved_document_ids,
|
||||
user_id=user_id,
|
||||
|
@@ -15,19 +15,17 @@ from danswer.direct_qa.llm_utils import get_default_qa_model
|
||||
from danswer.direct_qa.models import LLMMetricsContainer
|
||||
from danswer.direct_qa.qa_utils import get_usable_chunks
|
||||
from danswer.document_index import get_default_document_index
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.search.access_filters import build_access_filters_for_user
|
||||
from danswer.search.danswer_helper import query_intent
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import QueryFlow
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.search_runner import chunks_to_search_docs
|
||||
from danswer.search.search_runner import retrieve_keyword_documents
|
||||
from danswer.search.search_runner import retrieve_ranked_documents
|
||||
from danswer.search.search_runner import search_chunks
|
||||
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
|
||||
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
|
||||
from danswer.server.models import IndexFilters
|
||||
from danswer.server.models import QAResponse
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -51,7 +49,6 @@ def answer_qa_query(
|
||||
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
|
||||
) -> QAResponse:
|
||||
query = question.query
|
||||
use_keyword = question.use_keyword
|
||||
offset_count = question.offset if question.offset is not None else 0
|
||||
logger.info(f"Received QA query: {query}")
|
||||
|
||||
@@ -61,18 +58,12 @@ def answer_qa_query(
|
||||
|
||||
query_event_id = create_query_event(
|
||||
query=query,
|
||||
selected_flow=SearchType.KEYWORD
|
||||
if question.use_keyword
|
||||
else SearchType.SEMANTIC,
|
||||
search_type=question.search_type,
|
||||
llm_answer=None,
|
||||
user_id=user.id if user is not None else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
predicted_search, predicted_flow = query_intent(query)
|
||||
if use_keyword is None:
|
||||
use_keyword = predicted_search == SearchType.KEYWORD
|
||||
|
||||
user_id = None if user is None else user.id
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
final_filters = IndexFilters(
|
||||
@@ -81,24 +72,23 @@ def answer_qa_query(
|
||||
time_cutoff=time_cutoff,
|
||||
access_control_list=user_acl_filters,
|
||||
)
|
||||
if use_keyword:
|
||||
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
|
||||
query=query,
|
||||
filters=final_filters,
|
||||
favor_recent=favor_recent,
|
||||
datastore=get_default_document_index(),
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
)
|
||||
unranked_chunks: list[InferenceChunk] | None = []
|
||||
else:
|
||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
||||
query=query,
|
||||
filters=final_filters,
|
||||
favor_recent=favor_recent,
|
||||
datastore=get_default_document_index(),
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
search_query = SearchQuery(
|
||||
query=query,
|
||||
search_type=question.search_type,
|
||||
filters=final_filters,
|
||||
favor_recent=True if question.favor_recent is None else question.favor_recent,
|
||||
)
|
||||
|
||||
# TODO retire this
|
||||
predicted_search, predicted_flow = query_intent(query)
|
||||
|
||||
ranked_chunks, unranked_chunks = search_chunks(
|
||||
query=search_query,
|
||||
document_index=get_default_document_index(),
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
|
||||
if not ranked_chunks:
|
||||
return QAResponse(
|
||||
answer=None,
|
||||
@@ -114,6 +104,7 @@ def answer_qa_query(
|
||||
|
||||
top_docs = chunks_to_search_docs(ranked_chunks)
|
||||
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
|
||||
|
||||
update_query_event_retrieved_documents(
|
||||
db_session=db_session,
|
||||
retrieved_document_ids=[doc.document_id for doc in top_docs],
|
||||
|
@@ -4,10 +4,9 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
|
||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.server.models import IndexFilters
|
||||
from danswer.search.models import IndexFilters
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -97,7 +96,6 @@ class VectorCapable(abc.ABC):
|
||||
filters: IndexFilters,
|
||||
favor_recent: bool,
|
||||
num_to_retrieve: int,
|
||||
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
||||
) -> list[InferenceChunk]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@@ -48,12 +48,13 @@ from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
|
||||
from danswer.document_index.document_index_utils import get_uuid_from_chunk
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import DocumentInsertionRecord
|
||||
from danswer.document_index.interfaces import IndexFilters
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.document_index.vespa.utils import remove_invalid_unicode_chars
|
||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.search_runner import embed_query
|
||||
from danswer.search.search_runner import query_processing
|
||||
from danswer.search.search_runner import remove_stop_words
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -585,6 +586,7 @@ class VespaIndex(DocumentIndex):
|
||||
filters: IndexFilters,
|
||||
favor_recent: bool,
|
||||
num_to_retrieve: int = NUM_RETURNED_HITS,
|
||||
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||
) -> list[InferenceChunk]:
|
||||
decay_multiplier = FAVOR_RECENT_DECAY_MULTIPLIER if favor_recent else 1
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
@@ -599,9 +601,11 @@ class VespaIndex(DocumentIndex):
|
||||
+ _build_vespa_limit(num_to_retrieve)
|
||||
)
|
||||
|
||||
final_query = query_processing(query) if edit_keyword_query else query
|
||||
|
||||
params: dict[str, str | int] = {
|
||||
"yql": yql,
|
||||
"query": query,
|
||||
"query": final_query,
|
||||
"input.query(decay_factor)": str(DOC_TIME_DECAY * decay_multiplier),
|
||||
"hits": num_to_retrieve,
|
||||
"num_to_rerank": 10 * num_to_retrieve,
|
||||
@@ -617,6 +621,7 @@ class VespaIndex(DocumentIndex):
|
||||
favor_recent: bool,
|
||||
num_to_retrieve: int = NUM_RETURNED_HITS,
|
||||
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
||||
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||
) -> list[InferenceChunk]:
|
||||
decay_multiplier = FAVOR_RECENT_DECAY_MULTIPLIER if favor_recent else 1
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
@@ -634,7 +639,7 @@ class VespaIndex(DocumentIndex):
|
||||
query_embedding = embed_query(query)
|
||||
|
||||
query_keywords = (
|
||||
" ".join(remove_stop_words(query)) if EDIT_KEYWORD_QUERY else query
|
||||
" ".join(remove_stop_words(query)) if edit_keyword_query else query
|
||||
)
|
||||
|
||||
params = {
|
||||
@@ -653,31 +658,11 @@ class VespaIndex(DocumentIndex):
|
||||
filters: IndexFilters,
|
||||
favor_recent: bool,
|
||||
num_to_retrieve: int,
|
||||
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
||||
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||
) -> list[InferenceChunk]:
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
yql = (
|
||||
VespaIndex.yql_base
|
||||
+ vespa_where_clauses
|
||||
+ f"({{targetHits: {10 * num_to_retrieve}}}nearestNeighbor(embeddings, query_embedding)) or "
|
||||
+ '({grammar: "weakAnd"}userInput(@query) '
|
||||
# `({defaultIndex: "content_summary"}userInput(@query))` section is
|
||||
# needed for highlighting while the N-gram highlighting is broken /
|
||||
# not working as desired
|
||||
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
|
||||
+ _build_vespa_limit(num_to_retrieve)
|
||||
)
|
||||
|
||||
query_embedding = embed_query(query)
|
||||
|
||||
params = {
|
||||
"yql": yql,
|
||||
"query": query,
|
||||
"input.query(query_embedding)": str(query_embedding),
|
||||
"input.query(decay_factor)": str(DOC_TIME_DECAY),
|
||||
"ranking.profile": "hybrid_search",
|
||||
}
|
||||
|
||||
return _query_vespa(params)
|
||||
# TODO introduce the real hybrid search
|
||||
return self.semantic_retrieval(query, filters, favor_recent, num_to_retrieve)
|
||||
|
||||
def admin_retrieval(
|
||||
self,
|
||||
|
@@ -2,7 +2,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_acl_for_user
|
||||
from danswer.db.models import User
|
||||
from danswer.server.models import IndexFilters
|
||||
from danswer.search.models import IndexFilters
|
||||
|
||||
|
||||
def build_access_filters_for_user(user: User | None, session: Session) -> list[str]:
|
||||
|
@@ -66,7 +66,7 @@ def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
|
||||
|
||||
def recommend_search_flow(
|
||||
query: str,
|
||||
keyword: bool,
|
||||
keyword: bool = False,
|
||||
max_percent_stopwords: float = 0.30, # ~Every third word max, ie "effects of caffeine" still viable keyword search
|
||||
) -> HelperResponse:
|
||||
heuristic_search_type: SearchType | None = None
|
||||
|
@@ -1,11 +1,14 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
|
||||
from danswer.configs.app_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.model_configs import SKIP_RERANKING
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.indexing.models import IndexChunk
|
||||
|
||||
|
||||
MAX_METRICS_CONTENT = (
|
||||
200 # Just need enough characters to identify where in the doc the chunk is
|
||||
)
|
||||
@@ -27,6 +30,16 @@ class Embedder:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseFilters(BaseModel):
    """Filter fields shared by request-level and index-level filters.

    Every field is optional and defaults to None.
    """

    # Optional list of source type strings to filter on
    source_type: list[str] | None = None
    # Optional list of document set strings to filter on
    document_set: list[str] | None = None
    # Optional datetime cutoff; None when no cutoff was supplied
    time_cutoff: datetime | None = None
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters):
    """BaseFilters plus a mandatory access control list."""

    # Required field: there is no default, so it must always be supplied
    access_control_list: list[str]
|
||||
|
||||
|
||||
class ChunkMetric(BaseModel):
|
||||
document_id: str
|
||||
chunk_content_start: str
|
||||
@@ -34,6 +47,17 @@ class ChunkMetric(BaseModel):
|
||||
score: float
|
||||
|
||||
|
||||
class SearchQuery(BaseModel):
    """Single object carrying everything one search call needs.

    Holds the query text, the search type, the index filters, the recency
    preference, and the hit / rerank settings.
    """

    query: str
    search_type: SearchType
    filters: IndexFilters
    favor_recent: bool
    # The remaining fields default to the imported configuration constants
    num_hits: int = NUM_RETURNED_HITS
    skip_rerank: bool = SKIP_RERANKING
    # Only used if not skip_rerank
    num_rerank: int | None = NUM_RERANKED_RESULTS
|
||||
|
||||
|
||||
class RetrievalMetricsContainer(BaseModel):
|
||||
keyword_search: bool # False for Vector Search
|
||||
metrics: list[ChunkMetric] # This contains the scores for retrieval as well
|
||||
|
@@ -6,26 +6,23 @@ from nltk.stem import WordNetLemmatizer # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import EDIT_KEYWORD_QUERY
|
||||
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
|
||||
from danswer.configs.app_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
|
||||
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
|
||||
from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
|
||||
from danswer.configs.model_configs import SKIP_RERANKING
|
||||
from danswer.document_index.document_index_utils import (
|
||||
translate_boost_count_to_multiplier,
|
||||
)
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import IndexFilters
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.search_nlp_models import get_default_embedding_model
|
||||
from danswer.search.search_nlp_models import get_default_reranking_model_ensemble
|
||||
from danswer.server.models import SearchDoc
|
||||
@@ -56,6 +53,24 @@ def query_processing(
|
||||
return query
|
||||
|
||||
|
||||
def embed_query(
    query: str,
    embedding_model: SentenceTransformer | None = None,
    prefix: str = ASYM_QUERY_PREFIX,
    normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
) -> list[float]:
    """Embed a query string and return the embedding as a plain list.

    Args:
        query: Text to embed.
        embedding_model: Model to encode with; when falsy, the default
            embedding model is fetched instead.
        prefix: String prepended to the query before encoding.
        normalize_embeddings: Forwarded unchanged to ``encode``.

    Returns:
        The result of ``encode``; anything that is not already a list is
        converted via its ``tolist()`` method.
    """
    encoder = embedding_model if embedding_model else get_default_embedding_model()
    embedding = encoder.encode(
        prefix + query, normalize_embeddings=normalize_embeddings
    )

    # The encoder may hand back a non-list container; callers expect a list
    if isinstance(embedding, list):
        return embedding
    return embedding.tolist()
|
||||
|
||||
|
||||
def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
|
||||
search_docs = (
|
||||
[
|
||||
@@ -233,73 +248,44 @@ def apply_boost(
|
||||
return final_chunks
|
||||
|
||||
|
||||
@log_function_time()
def retrieve_keyword_documents(
    query: str,
    filters: IndexFilters,
    favor_recent: bool,
    datastore: DocumentIndex,
    num_hits: int = NUM_RETURNED_HITS,
    edit_query: bool = EDIT_KEYWORD_QUERY,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
) -> list[InferenceChunk] | None:
    """Run a keyword retrieval against the datastore and optionally report metrics.

    Args:
        query: Raw user query.
        filters: Index filters forwarded to the datastore.
        favor_recent: Forwarded to the datastore's keyword retrieval.
        datastore: Document index that performs the keyword retrieval.
        num_hits: Number of hits requested from the datastore.
        edit_query: When true, the query is passed through ``query_processing``
            before retrieval.
        retrieval_metrics_callback: If given, called once with a
            ``RetrievalMetricsContainer`` describing the retrieved chunks.

    Returns:
        The retrieved chunks, or None when the datastore returned nothing.
    """
    if edit_query:
        edited_query = query_processing(query)
    else:
        edited_query = query

    hits = datastore.keyword_retrieval(edited_query, filters, favor_recent, num_hits)

    if not hits:
        logger.warning(
            f"Keyword search returned no results - Filters: {filters}\tEdited Query: {edited_query}"
        )
        return None

    # Nothing to report to: hand the chunks straight back
    if retrieval_metrics_callback is None:
        return hits

    collected: list[ChunkMetric] = []
    for hit in hits:
        collected.append(
            ChunkMetric(
                document_id=hit.document_id,
                # Only the leading characters of the content are kept for metrics
                chunk_content_start=hit.content[:MAX_METRICS_CONTENT],
                first_link=hit.source_links[0] if hit.source_links else None,
                # A missing score is reported as 0
                score=0 if hit.score is None else hit.score,
            )
        )

    retrieval_metrics_callback(
        RetrievalMetricsContainer(keyword_search=True, metrics=collected)
    )
    return hits
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def retrieve_ranked_documents(
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
favor_recent: bool,
|
||||
datastore: DocumentIndex,
|
||||
num_hits: int = NUM_RETURNED_HITS,
|
||||
num_rerank: int = NUM_RERANKED_RESULTS,
|
||||
skip_rerank: bool = SKIP_RERANKING,
|
||||
def search_chunks(
|
||||
query: SearchQuery,
|
||||
document_index: DocumentIndex,
|
||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||
| None = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
|
||||
"""Uses vector similarity to fetch the top num_hits document chunks with a distance cutoff.
|
||||
Reranks the top num_rerank out of those (instead of all due to latency)"""
|
||||
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
    """Log the first source link of every chunk for the given search flow.

    Args:
        search_flow: Name of the search flow, interpolated into the log line.
        chunks: Chunks whose first source link should be logged.

    A chunk whose ``source_links`` is None or empty is logged as "No Link".
    """
    # Truthiness check (not `is not None`) so an empty container does not
    # raise when indexed with [0]
    top_links = [
        c.source_links[0] if c.source_links else "No Link" for c in chunks
    ]
    logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
|
||||
|
||||
def _log_top_chunk_links(chunks: list[InferenceChunk]) -> None:
|
||||
doc_links = [c.source_links[0] for c in chunks if c.source_links is not None]
|
||||
if query.search_type == SearchType.KEYWORD:
|
||||
top_chunks = document_index.keyword_retrieval(
|
||||
query.query, query.filters, query.favor_recent, query.num_hits
|
||||
)
|
||||
|
||||
files_log_msg = f"Top links from semantic search: {', '.join(doc_links)}"
|
||||
logger.info(files_log_msg)
|
||||
elif query.search_type == SearchType.SEMANTIC:
|
||||
top_chunks = document_index.semantic_retrieval(
|
||||
query.query, query.filters, query.favor_recent, query.num_hits
|
||||
)
|
||||
|
||||
elif query.search_type == SearchType.HYBRID:
|
||||
top_chunks = document_index.hybrid_retrieval(
|
||||
query.query, query.filters, query.favor_recent, query.num_hits
|
||||
)
|
||||
|
||||
else:
|
||||
raise RuntimeError("Invalid Search Flow")
|
||||
|
||||
top_chunks = datastore.semantic_retrieval(query, filters, favor_recent, num_hits)
|
||||
if not top_chunks:
|
||||
logger.info(f"Semantic search returned no results with filters: {filters}")
|
||||
logger.info(
|
||||
f"{query.search_type.value.capitalize()} search returned no results "
|
||||
f"with filters: {query.filters}"
|
||||
)
|
||||
return None, None
|
||||
logger.debug(top_chunks)
|
||||
|
||||
if retrieval_metrics_callback is not None:
|
||||
chunk_metrics = [
|
||||
@@ -315,40 +301,24 @@ def retrieve_ranked_documents(
|
||||
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
|
||||
)
|
||||
|
||||
if skip_rerank:
|
||||
# Keyword Search should never do reranking, no transformers involved in this flow
|
||||
if query.search_type == SearchType.KEYWORD:
|
||||
_log_top_chunk_links(query.search_type.value, top_chunks)
|
||||
return top_chunks, None
|
||||
|
||||
if query.skip_rerank:
|
||||
# Need the range of values to not be too spread out for applying boost
|
||||
boosted_chunks = apply_boost(top_chunks[:num_rerank])
|
||||
_log_top_chunk_links(boosted_chunks)
|
||||
return boosted_chunks, top_chunks[num_rerank:]
|
||||
# Therefore pass in smaller set of chunks to limit the range for norm-ing
|
||||
boosted_chunks = apply_boost(top_chunks[: query.num_rerank])
|
||||
_log_top_chunk_links(query.search_type.value, boosted_chunks)
|
||||
return boosted_chunks, top_chunks[query.num_rerank :]
|
||||
|
||||
ranked_chunks = (
|
||||
semantic_reranking(
|
||||
query,
|
||||
top_chunks[:num_rerank],
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
if not skip_rerank
|
||||
else []
|
||||
ranked_chunks = semantic_reranking(
|
||||
query.query,
|
||||
top_chunks[: query.num_rerank],
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
|
||||
_log_top_chunk_links(ranked_chunks)
|
||||
_log_top_chunk_links(query.search_type.value, ranked_chunks)
|
||||
|
||||
return ranked_chunks, top_chunks[num_rerank:]
|
||||
|
||||
|
||||
def embed_query(
    query: str,
    embedding_model: SentenceTransformer | None = None,
    prefix: str = ASYM_QUERY_PREFIX,
    normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
) -> list[float]:
    """Embed a query string and return the embedding as a plain list.

    Args:
        query: Text to embed.
        embedding_model: Model to encode with; when falsy, the default
            embedding model is fetched instead.
        prefix: String prepended to the query before encoding.
        normalize_embeddings: Forwarded unchanged to ``encode``.

    Returns:
        The result of ``encode``; anything that is not already a list is
        converted via its ``tolist()`` method.
    """
    encoder = embedding_model if embedding_model else get_default_embedding_model()
    embedding = encoder.encode(
        prefix + query, normalize_embeddings=normalize_embeddings
    )

    # The encoder may hand back a non-list container; callers expect a list
    if isinstance(embedding, list):
        return embedding
    return embedding.tolist()
|
||||
return ranked_chunks, top_chunks[query.num_rerank :]
|
||||
|
@@ -27,6 +27,7 @@ from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import TaskStatus
|
||||
from danswer.direct_qa.interfaces import DanswerQuote
|
||||
from danswer.search.models import BaseFilters
|
||||
from danswer.search.models import QueryFlow
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.server.utils import mask_credential_dict
|
||||
@@ -189,25 +190,14 @@ class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: int
|
||||
|
||||
|
||||
class RequestFilters(BaseModel):
    """Filter fields attached to an incoming request."""

    # Optional list of source type strings (declared without a default)
    source_type: list[str] | None
    # Optional list of document set strings (declared without a default)
    document_set: list[str] | None
    # Optional datetime cutoff; defaults to None
    time_cutoff: datetime | None = None
|
||||
|
||||
|
||||
class IndexFilters(RequestFilters):
    """RequestFilters plus a mandatory access control list."""

    # Required field: there is no default, so it must always be supplied
    access_control_list: list[str]
|
||||
|
||||
|
||||
class QuestionRequest(BaseModel):
|
||||
query: str
|
||||
collection: str
|
||||
filters: RequestFilters
|
||||
filters: BaseFilters
|
||||
offset: int | None
|
||||
enable_auto_detect_filters: bool
|
||||
favor_recent: bool | None = None
|
||||
use_keyword: bool | None # TODO remove this for hybrid search
|
||||
search_flow: SearchType | None = None # Default hybrid
|
||||
search_type: SearchType = SearchType.HYBRID
|
||||
|
||||
|
||||
class QAFeedbackRequest(BaseModel):
|
||||
|
@@ -27,20 +27,18 @@ from danswer.direct_qa.llm_utils import get_default_qa_model
|
||||
from danswer.direct_qa.qa_utils import get_usable_chunks
|
||||
from danswer.document_index import get_default_document_index
|
||||
from danswer.document_index.vespa.index import VespaIndex
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.search.access_filters import build_access_filters_for_user
|
||||
from danswer.search.danswer_helper import query_intent
|
||||
from danswer.search.danswer_helper import recommend_search_flow
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import QueryFlow
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.search_runner import chunks_to_search_docs
|
||||
from danswer.search.search_runner import retrieve_keyword_documents
|
||||
from danswer.search.search_runner import retrieve_ranked_documents
|
||||
from danswer.search.search_runner import search_chunks
|
||||
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
|
||||
from danswer.secondary_llm_flows.query_validation import get_query_answerability
|
||||
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
|
||||
from danswer.server.models import HelperResponse
|
||||
from danswer.server.models import IndexFilters
|
||||
from danswer.server.models import QAFeedbackRequest
|
||||
from danswer.server.models import QAResponse
|
||||
from danswer.server.models import QueryValidationResponse
|
||||
@@ -114,8 +112,7 @@ def get_search_type(
|
||||
question: QuestionRequest, _: User = Depends(current_user)
|
||||
) -> HelperResponse:
|
||||
query = question.query
|
||||
use_keyword = question.use_keyword if question.use_keyword is not None else False
|
||||
return recommend_search_flow(query, use_keyword)
|
||||
return recommend_search_flow(query)
|
||||
|
||||
|
||||
@router.post("/query-validation")
|
||||
@@ -137,14 +134,14 @@ def stream_query_validation(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/keyword-search")
|
||||
def keyword_search(
|
||||
@router.post("/document-search")
|
||||
def handle_search_request(
|
||||
question: QuestionRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchResponse:
|
||||
query = question.query
|
||||
logger.info(f"Received keyword search query: {query}")
|
||||
logger.info(f"Received {question.search_type.value} " f"search query: {query}")
|
||||
|
||||
time_cutoff, favor_recent = extract_question_time_filters(question)
|
||||
question.filters.time_cutoff = time_cutoff
|
||||
@@ -152,7 +149,7 @@ def keyword_search(
|
||||
|
||||
query_event_id = create_query_event(
|
||||
query=query,
|
||||
selected_flow=SearchType.KEYWORD,
|
||||
search_type=question.search_type,
|
||||
llm_answer=None,
|
||||
user_id=user.id,
|
||||
db_session=db_session,
|
||||
@@ -166,12 +163,18 @@ def keyword_search(
|
||||
time_cutoff=filters.time_cutoff,
|
||||
access_control_list=user_acl_filters,
|
||||
)
|
||||
ranked_chunks = retrieve_keyword_documents(
|
||||
|
||||
search_query = SearchQuery(
|
||||
query=query,
|
||||
search_type=question.search_type,
|
||||
filters=final_filters,
|
||||
favor_recent=favor_recent,
|
||||
datastore=get_default_document_index(),
|
||||
)
|
||||
|
||||
ranked_chunks, unranked_chunks = search_chunks(
|
||||
query=search_query, document_index=get_default_document_index()
|
||||
)
|
||||
|
||||
if not ranked_chunks:
|
||||
return SearchResponse(
|
||||
top_ranked_docs=None,
|
||||
@@ -182,6 +185,8 @@ def keyword_search(
|
||||
)
|
||||
|
||||
top_docs = chunks_to_search_docs(ranked_chunks)
|
||||
lower_top_docs = chunks_to_search_docs(unranked_chunks)
|
||||
|
||||
update_query_event_retrieved_documents(
|
||||
db_session=db_session,
|
||||
retrieved_document_ids=[doc.document_id for doc in top_docs],
|
||||
@@ -191,132 +196,7 @@ def keyword_search(
|
||||
|
||||
return SearchResponse(
|
||||
top_ranked_docs=top_docs,
|
||||
lower_ranked_docs=None,
|
||||
query_event_id=query_event_id,
|
||||
time_cutoff=time_cutoff,
|
||||
favor_recent=favor_recent,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/semantic-search")
def semantic_search(
    question: QuestionRequest,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> SearchResponse:
    """Handle a semantic search request and return ranked documents.

    Steps: extract time filters from the question, record a query event,
    build index filters that include the user's access control list, run
    ranked retrieval, then record the retrieved document ids on the event.

    Args:
        question: The incoming request; its ``filters.time_cutoff`` is
            overwritten with the extracted cutoff.
        user: Resolved via ``current_user``; handled as possibly None.
        db_session: Database session used for the query event and ACL lookup.

    Returns:
        A ``SearchResponse``. Both document lists are None when retrieval
        produced no ranked chunks.
    """
    query = question.query
    logger.info(f"Received semantic search query: {query}")

    time_cutoff, favor_recent = extract_question_time_filters(question)
    question.filters.time_cutoff = time_cutoff
    filters = question.filters

    # Resolve the id once, before first use: `user` is treated as optional
    # below, so dereferencing `user.id` directly here could fail on None.
    user_id = None if user is None else user.id

    query_event_id = create_query_event(
        query=query,
        selected_flow=SearchType.SEMANTIC,
        llm_answer=None,
        user_id=user_id,
        db_session=db_session,
    )

    user_acl_filters = build_access_filters_for_user(user, db_session)
    final_filters = IndexFilters(
        source_type=filters.source_type,
        document_set=filters.document_set,
        time_cutoff=filters.time_cutoff,
        access_control_list=user_acl_filters,
    )
    ranked_chunks, unranked_chunks = retrieve_ranked_documents(
        query=query,
        filters=final_filters,
        favor_recent=favor_recent,
        datastore=get_default_document_index(),
    )
    if not ranked_chunks:
        # Nothing retrieved: still report the event id and the time settings used
        return SearchResponse(
            top_ranked_docs=None,
            lower_ranked_docs=None,
            query_event_id=query_event_id,
            time_cutoff=time_cutoff,
            favor_recent=favor_recent,
        )

    top_docs = chunks_to_search_docs(ranked_chunks)
    other_top_docs = chunks_to_search_docs(unranked_chunks)
    update_query_event_retrieved_documents(
        db_session=db_session,
        retrieved_document_ids=[doc.document_id for doc in top_docs],
        query_id=query_event_id,
        user_id=user_id,
    )

    return SearchResponse(
        top_ranked_docs=top_docs,
        lower_ranked_docs=other_top_docs,
        query_event_id=query_event_id,
        time_cutoff=time_cutoff,
        favor_recent=favor_recent,
    )
|
||||
|
||||
|
||||
# TODO don't use this, not done yet
|
||||
@router.post("/hybrid-search")
|
||||
def hybrid_search(
|
||||
question: QuestionRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchResponse:
|
||||
query = question.query
|
||||
logger.info(f"Received hybrid search query: {query}")
|
||||
|
||||
time_cutoff, favor_recent = extract_question_time_filters(question)
|
||||
question.filters.time_cutoff = time_cutoff
|
||||
filters = question.filters
|
||||
|
||||
query_event_id = create_query_event(
|
||||
query=query,
|
||||
selected_flow=SearchType.HYBRID,
|
||||
llm_answer=None,
|
||||
user_id=user.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
user_id = None if user is None else user.id
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
final_filters = IndexFilters(
|
||||
source_type=filters.source_type,
|
||||
document_set=filters.document_set,
|
||||
time_cutoff=filters.time_cutoff,
|
||||
access_control_list=user_acl_filters,
|
||||
)
|
||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
||||
query=query,
|
||||
filters=final_filters,
|
||||
favor_recent=favor_recent,
|
||||
datastore=get_default_document_index(),
|
||||
)
|
||||
if not ranked_chunks:
|
||||
return SearchResponse(
|
||||
top_ranked_docs=None,
|
||||
lower_ranked_docs=None,
|
||||
query_event_id=query_event_id,
|
||||
time_cutoff=time_cutoff,
|
||||
favor_recent=favor_recent,
|
||||
)
|
||||
|
||||
top_docs = chunks_to_search_docs(ranked_chunks)
|
||||
other_top_docs = chunks_to_search_docs(unranked_chunks)
|
||||
update_query_event_retrieved_documents(
|
||||
db_session=db_session,
|
||||
retrieved_document_ids=[doc.document_id for doc in top_docs],
|
||||
query_id=query_event_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return SearchResponse(
|
||||
top_ranked_docs=top_docs,
|
||||
lower_ranked_docs=other_top_docs,
|
||||
lower_ranked_docs=lower_top_docs or None,
|
||||
query_event_id=query_event_id,
|
||||
time_cutoff=time_cutoff,
|
||||
favor_recent=favor_recent,
|
||||
@@ -347,10 +227,10 @@ def stream_direct_qa(
|
||||
predicted_search_key = "predicted_search"
|
||||
query_event_id_key = "query_event_id"
|
||||
|
||||
logger.debug(f"Received QA query: {question.query}")
|
||||
logger.debug(
|
||||
f"Received QA query ({question.search_type.value} search): {question.query}"
|
||||
)
|
||||
logger.debug(f"Query filters: {question.filters}")
|
||||
if question.use_keyword:
|
||||
logger.debug("User selected Keyword Search")
|
||||
|
||||
@log_generator_function_time()
|
||||
def stream_qa_portions(
|
||||
@@ -358,40 +238,34 @@ def stream_direct_qa(
|
||||
) -> Generator[str, None, None]:
|
||||
answer_so_far: str = ""
|
||||
query = question.query
|
||||
use_keyword = question.use_keyword
|
||||
offset_count = question.offset if question.offset is not None else 0
|
||||
|
||||
time_cutoff, favor_recent = extract_question_time_filters(question)
|
||||
question.filters.time_cutoff = time_cutoff
|
||||
question.filters.time_cutoff = time_cutoff # not used but just in case
|
||||
filters = question.filters
|
||||
|
||||
predicted_search, predicted_flow = query_intent(query)
|
||||
if use_keyword is None:
|
||||
use_keyword = predicted_search == SearchType.KEYWORD
|
||||
|
||||
user_id = None if user is None else user.id
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
final_filters = IndexFilters(
|
||||
source_type=filters.source_type,
|
||||
document_set=filters.document_set,
|
||||
time_cutoff=filters.time_cutoff,
|
||||
time_cutoff=time_cutoff,
|
||||
access_control_list=user_acl_filters,
|
||||
)
|
||||
if use_keyword:
|
||||
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
|
||||
query=query,
|
||||
filters=final_filters,
|
||||
favor_recent=favor_recent,
|
||||
datastore=get_default_document_index(),
|
||||
)
|
||||
unranked_chunks: list[InferenceChunk] | None = []
|
||||
else:
|
||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
||||
query=query,
|
||||
filters=final_filters,
|
||||
favor_recent=favor_recent,
|
||||
datastore=get_default_document_index(),
|
||||
)
|
||||
|
||||
search_query = SearchQuery(
|
||||
query=query,
|
||||
search_type=question.search_type,
|
||||
filters=final_filters,
|
||||
favor_recent=favor_recent,
|
||||
)
|
||||
|
||||
ranked_chunks, unranked_chunks = search_chunks(
|
||||
query=search_query, document_index=get_default_document_index()
|
||||
)
|
||||
|
||||
# TODO retire this
|
||||
predicted_search, predicted_flow = query_intent(query)
|
||||
|
||||
if not ranked_chunks:
|
||||
logger.debug("No Documents Found")
|
||||
empty_docs_result = {
|
||||
@@ -473,17 +347,14 @@ def stream_direct_qa(
|
||||
|
||||
query_event_id = create_query_event(
|
||||
query=query,
|
||||
selected_flow=SearchType.KEYWORD
|
||||
if question.use_keyword
|
||||
else SearchType.SEMANTIC,
|
||||
search_type=question.search_type,
|
||||
llm_answer=answer_so_far,
|
||||
retrieved_document_ids=[doc.document_id for doc in top_docs],
|
||||
user_id=user_id,
|
||||
user_id=None if user is None else user.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
yield get_json_line({query_event_id_key: query_event_id})
|
||||
|
||||
return
|
||||
|
||||
return StreamingResponse(stream_qa_portions(), media_type="application/json")
|
||||
|
@@ -14,27 +14,19 @@ if __name__ == "__main__":
|
||||
previous_query = None
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--flow",
|
||||
type=str,
|
||||
default="QA",
|
||||
help='"Search" or "QA", defaults to "QA"',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--type",
|
||||
type=str,
|
||||
default="Semantic",
|
||||
help='"Semantic" or "Keyword", defaults to "Semantic"',
|
||||
default="hybrid",
|
||||
help='"hybrid" "semantic" or "keyword", defaults to "hybrid"',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--stream",
|
||||
action="store_true",
|
||||
help='Enable streaming response, only for flow="QA"',
|
||||
help="Enable streaming response",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -50,7 +42,7 @@ if __name__ == "__main__":
|
||||
try:
|
||||
user_input = input(
|
||||
"\n\nAsk any question:\n"
|
||||
" - Use -f (QA/Search) and -t (Semantic/Keyword) flags to set endpoint.\n"
|
||||
" - Use -t (hybrid/semantic/keyword) flag to choose search flow.\n"
|
||||
" - prefix with -s to stream answer, --filters web,slack etc. for filters.\n"
|
||||
" - input an empty string to rerun last query.\n\t"
|
||||
)
|
||||
@@ -66,25 +58,15 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args(user_input.split())
|
||||
|
||||
flow = str(args.flow).lower()
|
||||
flow_type = str(args.type).lower()
|
||||
search_type = str(args.type).lower()
|
||||
stream = args.stream
|
||||
source_types = args.filters.split(",") if args.filters else None
|
||||
if source_types and len(source_types) == 1:
|
||||
source_types = source_types[0]
|
||||
|
||||
query = " ".join(args.query)
|
||||
|
||||
if flow not in ["qa", "search"]:
|
||||
raise ValueError("Flow value must be QA or Search")
|
||||
if flow_type not in ["keyword", "semantic"]:
|
||||
raise ValueError("Type value must be keyword or semantic")
|
||||
if flow != "qa" and stream:
|
||||
raise ValueError("Can only stream results for QA")
|
||||
if search_type not in ["hybrid", "semantic", "keyword"]:
|
||||
raise ValueError("Invalid Search")
|
||||
|
||||
if (flow, flow_type) == ("search", "keyword"):
|
||||
path = "keyword-search"
|
||||
elif (flow, flow_type) == ("search", "semantic"):
|
||||
path = "semantic-search"
|
||||
elif stream:
|
||||
path = "stream-direct-qa"
|
||||
else:
|
||||
@@ -95,8 +77,9 @@ if __name__ == "__main__":
|
||||
query_json = {
|
||||
"query": query,
|
||||
"collection": DOCUMENT_INDEX_NAME,
|
||||
"use_keyword": flow_type == "keyword", # Ignore if not QA Endpoints
|
||||
"filters": [{SOURCE_TYPE: source_types}],
|
||||
"filters": {SOURCE_TYPE: source_types},
|
||||
"enable_auto_detect_filters": True,
|
||||
"search_type": search_type,
|
||||
}
|
||||
|
||||
if args.stream:
|
||||
|
@@ -12,9 +12,9 @@ from danswer.access.access import get_acl_for_user
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.direct_qa.answer_question import answer_qa_query
|
||||
from danswer.direct_qa.models import LLMMetricsContainer
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.server.models import IndexFilters
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.utils.callbacks import MetricsHander
|
||||
|
||||
@@ -85,7 +85,6 @@ def get_answer_for_question(
|
||||
question = QuestionRequest(
|
||||
query=query,
|
||||
collection="danswer_index",
|
||||
use_keyword=False,
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=False,
|
||||
offset=None,
|
||||
|
Reference in New Issue
Block a user