mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 20:38:32 +02:00
Prep for Hybrid Search (#648)
This commit is contained in:
@@ -38,9 +38,11 @@ from danswer.llm.llm import LLM
|
|||||||
from danswer.llm.utils import get_default_llm_tokenizer
|
from danswer.llm.utils import get_default_llm_tokenizer
|
||||||
from danswer.llm.utils import translate_danswer_msg_to_langchain
|
from danswer.llm.utils import translate_danswer_msg_to_langchain
|
||||||
from danswer.search.access_filters import build_access_filters_for_user
|
from danswer.search.access_filters import build_access_filters_for_user
|
||||||
|
from danswer.search.models import IndexFilters
|
||||||
|
from danswer.search.models import SearchQuery
|
||||||
|
from danswer.search.models import SearchType
|
||||||
from danswer.search.search_runner import chunks_to_search_docs
|
from danswer.search.search_runner import chunks_to_search_docs
|
||||||
from danswer.search.search_runner import retrieve_ranked_documents
|
from danswer.search.search_runner import search_chunks
|
||||||
from danswer.server.models import IndexFilters
|
|
||||||
from danswer.server.models import RetrievalDocs
|
from danswer.server.models import RetrievalDocs
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from danswer.utils.text_processing import extract_embedded_json
|
from danswer.utils.text_processing import extract_embedded_json
|
||||||
@@ -130,13 +132,18 @@ def danswer_chat_retrieval(
|
|||||||
else:
|
else:
|
||||||
reworded_query = query_message.message
|
reworded_query = query_message.message
|
||||||
|
|
||||||
# Good Debug/Breakpoint
|
search_query = SearchQuery(
|
||||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
|
||||||
query=reworded_query,
|
query=reworded_query,
|
||||||
|
search_type=SearchType.HYBRID,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
favor_recent=False,
|
favor_recent=False,
|
||||||
datastore=get_default_document_index(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Good Debug/Breakpoint
|
||||||
|
ranked_chunks, unranked_chunks = search_chunks(
|
||||||
|
query=search_query, document_index=get_default_document_index()
|
||||||
|
)
|
||||||
|
|
||||||
if not ranked_chunks:
|
if not ranked_chunks:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@@ -26,9 +26,9 @@ from danswer.danswerbot.slack.utils import respond_in_thread
|
|||||||
from danswer.db.engine import get_sqlalchemy_engine
|
from danswer.db.engine import get_sqlalchemy_engine
|
||||||
from danswer.db.models import SlackBotConfig
|
from danswer.db.models import SlackBotConfig
|
||||||
from danswer.direct_qa.answer_question import answer_qa_query
|
from danswer.direct_qa.answer_question import answer_qa_query
|
||||||
|
from danswer.search.models import BaseFilters
|
||||||
from danswer.server.models import QAResponse
|
from danswer.server.models import QAResponse
|
||||||
from danswer.server.models import QuestionRequest
|
from danswer.server.models import QuestionRequest
|
||||||
from danswer.server.models import RequestFilters
|
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
logger_base = setup_logger()
|
logger_base = setup_logger()
|
||||||
@@ -182,7 +182,7 @@ def handle_message(
|
|||||||
try:
|
try:
|
||||||
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
||||||
# it allows the slack flow to extract out filters from the user query
|
# it allows the slack flow to extract out filters from the user query
|
||||||
filters = RequestFilters(
|
filters = BaseFilters(
|
||||||
source_type=None,
|
source_type=None,
|
||||||
document_set=document_set_names,
|
document_set=document_set_names,
|
||||||
time_cutoff=None,
|
time_cutoff=None,
|
||||||
@@ -193,7 +193,6 @@ def handle_message(
|
|||||||
QuestionRequest(
|
QuestionRequest(
|
||||||
query=msg,
|
query=msg,
|
||||||
collection=DOCUMENT_INDEX_NAME,
|
collection=DOCUMENT_INDEX_NAME,
|
||||||
use_keyword=False, # always use semantic search when handling Slack messages
|
|
||||||
enable_auto_detect_filters=not disable_auto_detect_filters,
|
enable_auto_detect_filters=not disable_auto_detect_filters,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
favor_recent=None,
|
favor_recent=None,
|
||||||
|
@@ -96,14 +96,14 @@ def update_document_hidden(db_session: Session, document_id: str, hidden: bool)
|
|||||||
def create_query_event(
|
def create_query_event(
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
query: str,
|
query: str,
|
||||||
selected_flow: SearchType | None,
|
search_type: SearchType | None,
|
||||||
llm_answer: str | None,
|
llm_answer: str | None,
|
||||||
user_id: UUID | None,
|
user_id: UUID | None,
|
||||||
retrieved_document_ids: list[str] | None = None,
|
retrieved_document_ids: list[str] | None = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
query_event = QueryEvent(
|
query_event = QueryEvent(
|
||||||
query=query,
|
query=query,
|
||||||
selected_search_flow=selected_flow,
|
selected_search_flow=search_type,
|
||||||
llm_answer=llm_answer,
|
llm_answer=llm_answer,
|
||||||
retrieved_document_ids=retrieved_document_ids,
|
retrieved_document_ids=retrieved_document_ids,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@@ -15,19 +15,17 @@ from danswer.direct_qa.llm_utils import get_default_qa_model
|
|||||||
from danswer.direct_qa.models import LLMMetricsContainer
|
from danswer.direct_qa.models import LLMMetricsContainer
|
||||||
from danswer.direct_qa.qa_utils import get_usable_chunks
|
from danswer.direct_qa.qa_utils import get_usable_chunks
|
||||||
from danswer.document_index import get_default_document_index
|
from danswer.document_index import get_default_document_index
|
||||||
from danswer.indexing.models import InferenceChunk
|
|
||||||
from danswer.search.access_filters import build_access_filters_for_user
|
from danswer.search.access_filters import build_access_filters_for_user
|
||||||
from danswer.search.danswer_helper import query_intent
|
from danswer.search.danswer_helper import query_intent
|
||||||
|
from danswer.search.models import IndexFilters
|
||||||
from danswer.search.models import QueryFlow
|
from danswer.search.models import QueryFlow
|
||||||
from danswer.search.models import RerankMetricsContainer
|
from danswer.search.models import RerankMetricsContainer
|
||||||
from danswer.search.models import RetrievalMetricsContainer
|
from danswer.search.models import RetrievalMetricsContainer
|
||||||
from danswer.search.models import SearchType
|
from danswer.search.models import SearchQuery
|
||||||
from danswer.search.search_runner import chunks_to_search_docs
|
from danswer.search.search_runner import chunks_to_search_docs
|
||||||
from danswer.search.search_runner import retrieve_keyword_documents
|
from danswer.search.search_runner import search_chunks
|
||||||
from danswer.search.search_runner import retrieve_ranked_documents
|
|
||||||
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
|
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
|
||||||
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
|
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
|
||||||
from danswer.server.models import IndexFilters
|
|
||||||
from danswer.server.models import QAResponse
|
from danswer.server.models import QAResponse
|
||||||
from danswer.server.models import QuestionRequest
|
from danswer.server.models import QuestionRequest
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
@@ -51,7 +49,6 @@ def answer_qa_query(
|
|||||||
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
|
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
|
||||||
) -> QAResponse:
|
) -> QAResponse:
|
||||||
query = question.query
|
query = question.query
|
||||||
use_keyword = question.use_keyword
|
|
||||||
offset_count = question.offset if question.offset is not None else 0
|
offset_count = question.offset if question.offset is not None else 0
|
||||||
logger.info(f"Received QA query: {query}")
|
logger.info(f"Received QA query: {query}")
|
||||||
|
|
||||||
@@ -61,18 +58,12 @@ def answer_qa_query(
|
|||||||
|
|
||||||
query_event_id = create_query_event(
|
query_event_id = create_query_event(
|
||||||
query=query,
|
query=query,
|
||||||
selected_flow=SearchType.KEYWORD
|
search_type=question.search_type,
|
||||||
if question.use_keyword
|
|
||||||
else SearchType.SEMANTIC,
|
|
||||||
llm_answer=None,
|
llm_answer=None,
|
||||||
user_id=user.id if user is not None else None,
|
user_id=user.id if user is not None else None,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
)
|
)
|
||||||
|
|
||||||
predicted_search, predicted_flow = query_intent(query)
|
|
||||||
if use_keyword is None:
|
|
||||||
use_keyword = predicted_search == SearchType.KEYWORD
|
|
||||||
|
|
||||||
user_id = None if user is None else user.id
|
user_id = None if user is None else user.id
|
||||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||||
final_filters = IndexFilters(
|
final_filters = IndexFilters(
|
||||||
@@ -81,24 +72,23 @@ def answer_qa_query(
|
|||||||
time_cutoff=time_cutoff,
|
time_cutoff=time_cutoff,
|
||||||
access_control_list=user_acl_filters,
|
access_control_list=user_acl_filters,
|
||||||
)
|
)
|
||||||
if use_keyword:
|
search_query = SearchQuery(
|
||||||
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
|
query=query,
|
||||||
query=query,
|
search_type=question.search_type,
|
||||||
filters=final_filters,
|
filters=final_filters,
|
||||||
favor_recent=favor_recent,
|
favor_recent=True if question.favor_recent is None else question.favor_recent,
|
||||||
datastore=get_default_document_index(),
|
)
|
||||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
|
||||||
)
|
# TODO retire this
|
||||||
unranked_chunks: list[InferenceChunk] | None = []
|
predicted_search, predicted_flow = query_intent(query)
|
||||||
else:
|
|
||||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
ranked_chunks, unranked_chunks = search_chunks(
|
||||||
query=query,
|
query=search_query,
|
||||||
filters=final_filters,
|
document_index=get_default_document_index(),
|
||||||
favor_recent=favor_recent,
|
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||||
datastore=get_default_document_index(),
|
rerank_metrics_callback=rerank_metrics_callback,
|
||||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
)
|
||||||
rerank_metrics_callback=rerank_metrics_callback,
|
|
||||||
)
|
|
||||||
if not ranked_chunks:
|
if not ranked_chunks:
|
||||||
return QAResponse(
|
return QAResponse(
|
||||||
answer=None,
|
answer=None,
|
||||||
@@ -114,6 +104,7 @@ def answer_qa_query(
|
|||||||
|
|
||||||
top_docs = chunks_to_search_docs(ranked_chunks)
|
top_docs = chunks_to_search_docs(ranked_chunks)
|
||||||
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
|
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
|
||||||
|
|
||||||
update_query_event_retrieved_documents(
|
update_query_event_retrieved_documents(
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
retrieved_document_ids=[doc.document_id for doc in top_docs],
|
retrieved_document_ids=[doc.document_id for doc in top_docs],
|
||||||
|
@@ -4,10 +4,9 @@ from datetime import datetime
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from danswer.access.models import DocumentAccess
|
from danswer.access.models import DocumentAccess
|
||||||
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
|
|
||||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||||
from danswer.indexing.models import InferenceChunk
|
from danswer.indexing.models import InferenceChunk
|
||||||
from danswer.server.models import IndexFilters
|
from danswer.search.models import IndexFilters
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -97,7 +96,6 @@ class VectorCapable(abc.ABC):
|
|||||||
filters: IndexFilters,
|
filters: IndexFilters,
|
||||||
favor_recent: bool,
|
favor_recent: bool,
|
||||||
num_to_retrieve: int,
|
num_to_retrieve: int,
|
||||||
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
|
||||||
) -> list[InferenceChunk]:
|
) -> list[InferenceChunk]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@@ -48,12 +48,13 @@ from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
|
|||||||
from danswer.document_index.document_index_utils import get_uuid_from_chunk
|
from danswer.document_index.document_index_utils import get_uuid_from_chunk
|
||||||
from danswer.document_index.interfaces import DocumentIndex
|
from danswer.document_index.interfaces import DocumentIndex
|
||||||
from danswer.document_index.interfaces import DocumentInsertionRecord
|
from danswer.document_index.interfaces import DocumentInsertionRecord
|
||||||
from danswer.document_index.interfaces import IndexFilters
|
|
||||||
from danswer.document_index.interfaces import UpdateRequest
|
from danswer.document_index.interfaces import UpdateRequest
|
||||||
from danswer.document_index.vespa.utils import remove_invalid_unicode_chars
|
from danswer.document_index.vespa.utils import remove_invalid_unicode_chars
|
||||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||||
from danswer.indexing.models import InferenceChunk
|
from danswer.indexing.models import InferenceChunk
|
||||||
|
from danswer.search.models import IndexFilters
|
||||||
from danswer.search.search_runner import embed_query
|
from danswer.search.search_runner import embed_query
|
||||||
|
from danswer.search.search_runner import query_processing
|
||||||
from danswer.search.search_runner import remove_stop_words
|
from danswer.search.search_runner import remove_stop_words
|
||||||
from danswer.utils.batching import batch_generator
|
from danswer.utils.batching import batch_generator
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
@@ -585,6 +586,7 @@ class VespaIndex(DocumentIndex):
|
|||||||
filters: IndexFilters,
|
filters: IndexFilters,
|
||||||
favor_recent: bool,
|
favor_recent: bool,
|
||||||
num_to_retrieve: int = NUM_RETURNED_HITS,
|
num_to_retrieve: int = NUM_RETURNED_HITS,
|
||||||
|
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||||
) -> list[InferenceChunk]:
|
) -> list[InferenceChunk]:
|
||||||
decay_multiplier = FAVOR_RECENT_DECAY_MULTIPLIER if favor_recent else 1
|
decay_multiplier = FAVOR_RECENT_DECAY_MULTIPLIER if favor_recent else 1
|
||||||
vespa_where_clauses = _build_vespa_filters(filters)
|
vespa_where_clauses = _build_vespa_filters(filters)
|
||||||
@@ -599,9 +601,11 @@ class VespaIndex(DocumentIndex):
|
|||||||
+ _build_vespa_limit(num_to_retrieve)
|
+ _build_vespa_limit(num_to_retrieve)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
final_query = query_processing(query) if edit_keyword_query else query
|
||||||
|
|
||||||
params: dict[str, str | int] = {
|
params: dict[str, str | int] = {
|
||||||
"yql": yql,
|
"yql": yql,
|
||||||
"query": query,
|
"query": final_query,
|
||||||
"input.query(decay_factor)": str(DOC_TIME_DECAY * decay_multiplier),
|
"input.query(decay_factor)": str(DOC_TIME_DECAY * decay_multiplier),
|
||||||
"hits": num_to_retrieve,
|
"hits": num_to_retrieve,
|
||||||
"num_to_rerank": 10 * num_to_retrieve,
|
"num_to_rerank": 10 * num_to_retrieve,
|
||||||
@@ -617,6 +621,7 @@ class VespaIndex(DocumentIndex):
|
|||||||
favor_recent: bool,
|
favor_recent: bool,
|
||||||
num_to_retrieve: int = NUM_RETURNED_HITS,
|
num_to_retrieve: int = NUM_RETURNED_HITS,
|
||||||
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
||||||
|
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||||
) -> list[InferenceChunk]:
|
) -> list[InferenceChunk]:
|
||||||
decay_multiplier = FAVOR_RECENT_DECAY_MULTIPLIER if favor_recent else 1
|
decay_multiplier = FAVOR_RECENT_DECAY_MULTIPLIER if favor_recent else 1
|
||||||
vespa_where_clauses = _build_vespa_filters(filters)
|
vespa_where_clauses = _build_vespa_filters(filters)
|
||||||
@@ -634,7 +639,7 @@ class VespaIndex(DocumentIndex):
|
|||||||
query_embedding = embed_query(query)
|
query_embedding = embed_query(query)
|
||||||
|
|
||||||
query_keywords = (
|
query_keywords = (
|
||||||
" ".join(remove_stop_words(query)) if EDIT_KEYWORD_QUERY else query
|
" ".join(remove_stop_words(query)) if edit_keyword_query else query
|
||||||
)
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
@@ -653,31 +658,11 @@ class VespaIndex(DocumentIndex):
|
|||||||
filters: IndexFilters,
|
filters: IndexFilters,
|
||||||
favor_recent: bool,
|
favor_recent: bool,
|
||||||
num_to_retrieve: int,
|
num_to_retrieve: int,
|
||||||
|
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
||||||
|
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||||
) -> list[InferenceChunk]:
|
) -> list[InferenceChunk]:
|
||||||
vespa_where_clauses = _build_vespa_filters(filters)
|
# TODO introduce the real hybrid search
|
||||||
yql = (
|
return self.semantic_retrieval(query, filters, favor_recent, num_to_retrieve)
|
||||||
VespaIndex.yql_base
|
|
||||||
+ vespa_where_clauses
|
|
||||||
+ f"({{targetHits: {10 * num_to_retrieve}}}nearestNeighbor(embeddings, query_embedding)) or "
|
|
||||||
+ '({grammar: "weakAnd"}userInput(@query) '
|
|
||||||
# `({defaultIndex: "content_summary"}userInput(@query))` section is
|
|
||||||
# needed for highlighting while the N-gram highlighting is broken /
|
|
||||||
# not working as desired
|
|
||||||
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
|
|
||||||
+ _build_vespa_limit(num_to_retrieve)
|
|
||||||
)
|
|
||||||
|
|
||||||
query_embedding = embed_query(query)
|
|
||||||
|
|
||||||
params = {
|
|
||||||
"yql": yql,
|
|
||||||
"query": query,
|
|
||||||
"input.query(query_embedding)": str(query_embedding),
|
|
||||||
"input.query(decay_factor)": str(DOC_TIME_DECAY),
|
|
||||||
"ranking.profile": "hybrid_search",
|
|
||||||
}
|
|
||||||
|
|
||||||
return _query_vespa(params)
|
|
||||||
|
|
||||||
def admin_retrieval(
|
def admin_retrieval(
|
||||||
self,
|
self,
|
||||||
|
@@ -2,7 +2,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from danswer.access.access import get_acl_for_user
|
from danswer.access.access import get_acl_for_user
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from danswer.server.models import IndexFilters
|
from danswer.search.models import IndexFilters
|
||||||
|
|
||||||
|
|
||||||
def build_access_filters_for_user(user: User | None, session: Session) -> list[str]:
|
def build_access_filters_for_user(user: User | None, session: Session) -> list[str]:
|
||||||
|
@@ -66,7 +66,7 @@ def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
|
|||||||
|
|
||||||
def recommend_search_flow(
|
def recommend_search_flow(
|
||||||
query: str,
|
query: str,
|
||||||
keyword: bool,
|
keyword: bool = False,
|
||||||
max_percent_stopwords: float = 0.30, # ~Every third word max, ie "effects of caffeine" still viable keyword search
|
max_percent_stopwords: float = 0.30, # ~Every third word max, ie "effects of caffeine" still viable keyword search
|
||||||
) -> HelperResponse:
|
) -> HelperResponse:
|
||||||
heuristic_search_type: SearchType | None = None
|
heuristic_search_type: SearchType | None = None
|
||||||
|
@@ -1,11 +1,14 @@
|
|||||||
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
|
||||||
|
from danswer.configs.app_configs import NUM_RETURNED_HITS
|
||||||
|
from danswer.configs.model_configs import SKIP_RERANKING
|
||||||
from danswer.indexing.models import DocAwareChunk
|
from danswer.indexing.models import DocAwareChunk
|
||||||
from danswer.indexing.models import IndexChunk
|
from danswer.indexing.models import IndexChunk
|
||||||
|
|
||||||
|
|
||||||
MAX_METRICS_CONTENT = (
|
MAX_METRICS_CONTENT = (
|
||||||
200 # Just need enough characters to identify where in the doc the chunk is
|
200 # Just need enough characters to identify where in the doc the chunk is
|
||||||
)
|
)
|
||||||
@@ -27,6 +30,16 @@ class Embedder:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFilters(BaseModel):
|
||||||
|
source_type: list[str] | None = None
|
||||||
|
document_set: list[str] | None = None
|
||||||
|
time_cutoff: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IndexFilters(BaseFilters):
|
||||||
|
access_control_list: list[str]
|
||||||
|
|
||||||
|
|
||||||
class ChunkMetric(BaseModel):
|
class ChunkMetric(BaseModel):
|
||||||
document_id: str
|
document_id: str
|
||||||
chunk_content_start: str
|
chunk_content_start: str
|
||||||
@@ -34,6 +47,17 @@ class ChunkMetric(BaseModel):
|
|||||||
score: float
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
class SearchQuery(BaseModel):
|
||||||
|
query: str
|
||||||
|
search_type: SearchType
|
||||||
|
filters: IndexFilters
|
||||||
|
favor_recent: bool
|
||||||
|
num_hits: int = NUM_RETURNED_HITS
|
||||||
|
skip_rerank: bool = SKIP_RERANKING
|
||||||
|
# Only used if not skip_rerank
|
||||||
|
num_rerank: int | None = NUM_RERANKED_RESULTS
|
||||||
|
|
||||||
|
|
||||||
class RetrievalMetricsContainer(BaseModel):
|
class RetrievalMetricsContainer(BaseModel):
|
||||||
keyword_search: bool # False for Vector Search
|
keyword_search: bool # False for Vector Search
|
||||||
metrics: list[ChunkMetric] # This contains the scores for retrieval as well
|
metrics: list[ChunkMetric] # This contains the scores for retrieval as well
|
||||||
|
@@ -6,26 +6,23 @@ from nltk.stem import WordNetLemmatizer # type:ignore
|
|||||||
from nltk.tokenize import word_tokenize # type:ignore
|
from nltk.tokenize import word_tokenize # type:ignore
|
||||||
from sentence_transformers import SentenceTransformer # type: ignore
|
from sentence_transformers import SentenceTransformer # type: ignore
|
||||||
|
|
||||||
from danswer.configs.app_configs import EDIT_KEYWORD_QUERY
|
|
||||||
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
|
|
||||||
from danswer.configs.app_configs import NUM_RETURNED_HITS
|
|
||||||
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
|
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
|
||||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
|
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
|
||||||
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||||
from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
|
from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
|
||||||
from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
|
from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
|
||||||
from danswer.configs.model_configs import SKIP_RERANKING
|
|
||||||
from danswer.document_index.document_index_utils import (
|
from danswer.document_index.document_index_utils import (
|
||||||
translate_boost_count_to_multiplier,
|
translate_boost_count_to_multiplier,
|
||||||
)
|
)
|
||||||
from danswer.document_index.interfaces import DocumentIndex
|
from danswer.document_index.interfaces import DocumentIndex
|
||||||
from danswer.document_index.interfaces import IndexFilters
|
|
||||||
from danswer.indexing.models import InferenceChunk
|
from danswer.indexing.models import InferenceChunk
|
||||||
from danswer.search.models import ChunkMetric
|
from danswer.search.models import ChunkMetric
|
||||||
from danswer.search.models import MAX_METRICS_CONTENT
|
from danswer.search.models import MAX_METRICS_CONTENT
|
||||||
from danswer.search.models import RerankMetricsContainer
|
from danswer.search.models import RerankMetricsContainer
|
||||||
from danswer.search.models import RetrievalMetricsContainer
|
from danswer.search.models import RetrievalMetricsContainer
|
||||||
|
from danswer.search.models import SearchQuery
|
||||||
|
from danswer.search.models import SearchType
|
||||||
from danswer.search.search_nlp_models import get_default_embedding_model
|
from danswer.search.search_nlp_models import get_default_embedding_model
|
||||||
from danswer.search.search_nlp_models import get_default_reranking_model_ensemble
|
from danswer.search.search_nlp_models import get_default_reranking_model_ensemble
|
||||||
from danswer.server.models import SearchDoc
|
from danswer.server.models import SearchDoc
|
||||||
@@ -56,6 +53,24 @@ def query_processing(
|
|||||||
return query
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
def embed_query(
|
||||||
|
query: str,
|
||||||
|
embedding_model: SentenceTransformer | None = None,
|
||||||
|
prefix: str = ASYM_QUERY_PREFIX,
|
||||||
|
normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
|
||||||
|
) -> list[float]:
|
||||||
|
model = embedding_model or get_default_embedding_model()
|
||||||
|
prefixed_query = prefix + query
|
||||||
|
query_embedding = model.encode(
|
||||||
|
prefixed_query, normalize_embeddings=normalize_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(query_embedding, list):
|
||||||
|
query_embedding = query_embedding.tolist()
|
||||||
|
|
||||||
|
return query_embedding
|
||||||
|
|
||||||
|
|
||||||
def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
|
def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
|
||||||
search_docs = (
|
search_docs = (
|
||||||
[
|
[
|
||||||
@@ -233,73 +248,44 @@ def apply_boost(
|
|||||||
return final_chunks
|
return final_chunks
|
||||||
|
|
||||||
|
|
||||||
@log_function_time()
|
def search_chunks(
|
||||||
def retrieve_keyword_documents(
|
query: SearchQuery,
|
||||||
query: str,
|
document_index: DocumentIndex,
|
||||||
filters: IndexFilters,
|
|
||||||
favor_recent: bool,
|
|
||||||
datastore: DocumentIndex,
|
|
||||||
num_hits: int = NUM_RETURNED_HITS,
|
|
||||||
edit_query: bool = EDIT_KEYWORD_QUERY,
|
|
||||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
|
||||||
| None = None,
|
|
||||||
) -> list[InferenceChunk] | None:
|
|
||||||
edited_query = query_processing(query) if edit_query else query
|
|
||||||
|
|
||||||
top_chunks = datastore.keyword_retrieval(
|
|
||||||
edited_query, filters, favor_recent, num_hits
|
|
||||||
)
|
|
||||||
|
|
||||||
if not top_chunks:
|
|
||||||
logger.warning(
|
|
||||||
f"Keyword search returned no results - Filters: {filters}\tEdited Query: {edited_query}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if retrieval_metrics_callback is not None:
|
|
||||||
chunk_metrics = [
|
|
||||||
ChunkMetric(
|
|
||||||
document_id=chunk.document_id,
|
|
||||||
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
|
|
||||||
first_link=chunk.source_links[0] if chunk.source_links else None,
|
|
||||||
score=chunk.score if chunk.score is not None else 0,
|
|
||||||
)
|
|
||||||
for chunk in top_chunks
|
|
||||||
]
|
|
||||||
retrieval_metrics_callback(
|
|
||||||
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
|
|
||||||
)
|
|
||||||
|
|
||||||
return top_chunks
|
|
||||||
|
|
||||||
|
|
||||||
@log_function_time()
|
|
||||||
def retrieve_ranked_documents(
|
|
||||||
query: str,
|
|
||||||
filters: IndexFilters,
|
|
||||||
favor_recent: bool,
|
|
||||||
datastore: DocumentIndex,
|
|
||||||
num_hits: int = NUM_RETURNED_HITS,
|
|
||||||
num_rerank: int = NUM_RERANKED_RESULTS,
|
|
||||||
skip_rerank: bool = SKIP_RERANKING,
|
|
||||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||||
| None = None,
|
| None = None,
|
||||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||||
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
|
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
|
||||||
"""Uses vector similarity to fetch the top num_hits document chunks with a distance cutoff.
|
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
|
||||||
Reranks the top num_rerank out of those (instead of all due to latency)"""
|
top_links = [
|
||||||
|
c.source_links[0] if c.source_links is not None else "No Link"
|
||||||
|
for c in chunks
|
||||||
|
]
|
||||||
|
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
|
||||||
|
|
||||||
def _log_top_chunk_links(chunks: list[InferenceChunk]) -> None:
|
if query.search_type == SearchType.KEYWORD:
|
||||||
doc_links = [c.source_links[0] for c in chunks if c.source_links is not None]
|
top_chunks = document_index.keyword_retrieval(
|
||||||
|
query.query, query.filters, query.favor_recent, query.num_hits
|
||||||
|
)
|
||||||
|
|
||||||
files_log_msg = f"Top links from semantic search: {', '.join(doc_links)}"
|
elif query.search_type == SearchType.SEMANTIC:
|
||||||
logger.info(files_log_msg)
|
top_chunks = document_index.semantic_retrieval(
|
||||||
|
query.query, query.filters, query.favor_recent, query.num_hits
|
||||||
|
)
|
||||||
|
|
||||||
|
elif query.search_type == SearchType.HYBRID:
|
||||||
|
top_chunks = document_index.hybrid_retrieval(
|
||||||
|
query.query, query.filters, query.favor_recent, query.num_hits
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Invalid Search Flow")
|
||||||
|
|
||||||
top_chunks = datastore.semantic_retrieval(query, filters, favor_recent, num_hits)
|
|
||||||
if not top_chunks:
|
if not top_chunks:
|
||||||
logger.info(f"Semantic search returned no results with filters: {filters}")
|
logger.info(
|
||||||
|
f"{query.search_type.value.capitalize()} search returned no results "
|
||||||
|
f"with filters: {query.filters}"
|
||||||
|
)
|
||||||
return None, None
|
return None, None
|
||||||
logger.debug(top_chunks)
|
|
||||||
|
|
||||||
if retrieval_metrics_callback is not None:
|
if retrieval_metrics_callback is not None:
|
||||||
chunk_metrics = [
|
chunk_metrics = [
|
||||||
@@ -315,40 +301,24 @@ def retrieve_ranked_documents(
|
|||||||
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
|
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
|
||||||
)
|
)
|
||||||
|
|
||||||
if skip_rerank:
|
# Keyword Search should never do reranking, no transformers involved in this flow
|
||||||
|
if query.search_type == SearchType.KEYWORD:
|
||||||
|
_log_top_chunk_links(query.search_type.value, top_chunks)
|
||||||
|
return top_chunks, None
|
||||||
|
|
||||||
|
if query.skip_rerank:
|
||||||
# Need the range of values to not be too spread out for applying boost
|
# Need the range of values to not be too spread out for applying boost
|
||||||
boosted_chunks = apply_boost(top_chunks[:num_rerank])
|
# Therefore pass in smaller set of chunks to limit the range for norm-ing
|
||||||
_log_top_chunk_links(boosted_chunks)
|
boosted_chunks = apply_boost(top_chunks[: query.num_rerank])
|
||||||
return boosted_chunks, top_chunks[num_rerank:]
|
_log_top_chunk_links(query.search_type.value, boosted_chunks)
|
||||||
|
return boosted_chunks, top_chunks[query.num_rerank :]
|
||||||
|
|
||||||
ranked_chunks = (
|
ranked_chunks = semantic_reranking(
|
||||||
semantic_reranking(
|
query.query,
|
||||||
query,
|
top_chunks[: query.num_rerank],
|
||||||
top_chunks[:num_rerank],
|
rerank_metrics_callback=rerank_metrics_callback,
|
||||||
rerank_metrics_callback=rerank_metrics_callback,
|
|
||||||
)
|
|
||||||
if not skip_rerank
|
|
||||||
else []
|
|
||||||
)
|
)
|
||||||
|
|
||||||
_log_top_chunk_links(ranked_chunks)
|
_log_top_chunk_links(query.search_type.value, ranked_chunks)
|
||||||
|
|
||||||
return ranked_chunks, top_chunks[num_rerank:]
|
return ranked_chunks, top_chunks[query.num_rerank :]
|
||||||
|
|
||||||
|
|
||||||
def embed_query(
|
|
||||||
query: str,
|
|
||||||
embedding_model: SentenceTransformer | None = None,
|
|
||||||
prefix: str = ASYM_QUERY_PREFIX,
|
|
||||||
normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
|
|
||||||
) -> list[float]:
|
|
||||||
model = embedding_model or get_default_embedding_model()
|
|
||||||
prefixed_query = prefix + query
|
|
||||||
query_embedding = model.encode(
|
|
||||||
prefixed_query, normalize_embeddings=normalize_embeddings
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(query_embedding, list):
|
|
||||||
query_embedding = query_embedding.tolist()
|
|
||||||
|
|
||||||
return query_embedding
|
|
||||||
|
@@ -27,6 +27,7 @@ from danswer.db.models import IndexAttempt
|
|||||||
from danswer.db.models import IndexingStatus
|
from danswer.db.models import IndexingStatus
|
||||||
from danswer.db.models import TaskStatus
|
from danswer.db.models import TaskStatus
|
||||||
from danswer.direct_qa.interfaces import DanswerQuote
|
from danswer.direct_qa.interfaces import DanswerQuote
|
||||||
|
from danswer.search.models import BaseFilters
|
||||||
from danswer.search.models import QueryFlow
|
from danswer.search.models import QueryFlow
|
||||||
from danswer.search.models import SearchType
|
from danswer.search.models import SearchType
|
||||||
from danswer.server.utils import mask_credential_dict
|
from danswer.server.utils import mask_credential_dict
|
||||||
@@ -189,25 +190,14 @@ class CreateChatSessionID(BaseModel):
|
|||||||
chat_session_id: int
|
chat_session_id: int
|
||||||
|
|
||||||
|
|
||||||
class RequestFilters(BaseModel):
|
|
||||||
source_type: list[str] | None
|
|
||||||
document_set: list[str] | None
|
|
||||||
time_cutoff: datetime | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class IndexFilters(RequestFilters):
|
|
||||||
access_control_list: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
class QuestionRequest(BaseModel):
|
class QuestionRequest(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
collection: str
|
collection: str
|
||||||
filters: RequestFilters
|
filters: BaseFilters
|
||||||
offset: int | None
|
offset: int | None
|
||||||
enable_auto_detect_filters: bool
|
enable_auto_detect_filters: bool
|
||||||
favor_recent: bool | None = None
|
favor_recent: bool | None = None
|
||||||
use_keyword: bool | None # TODO remove this for hybrid search
|
search_type: SearchType = SearchType.HYBRID
|
||||||
search_flow: SearchType | None = None # Default hybrid
|
|
||||||
|
|
||||||
|
|
||||||
class QAFeedbackRequest(BaseModel):
|
class QAFeedbackRequest(BaseModel):
|
||||||
|
@@ -27,20 +27,18 @@ from danswer.direct_qa.llm_utils import get_default_qa_model
|
|||||||
from danswer.direct_qa.qa_utils import get_usable_chunks
|
from danswer.direct_qa.qa_utils import get_usable_chunks
|
||||||
from danswer.document_index import get_default_document_index
|
from danswer.document_index import get_default_document_index
|
||||||
from danswer.document_index.vespa.index import VespaIndex
|
from danswer.document_index.vespa.index import VespaIndex
|
||||||
from danswer.indexing.models import InferenceChunk
|
|
||||||
from danswer.search.access_filters import build_access_filters_for_user
|
from danswer.search.access_filters import build_access_filters_for_user
|
||||||
from danswer.search.danswer_helper import query_intent
|
from danswer.search.danswer_helper import query_intent
|
||||||
from danswer.search.danswer_helper import recommend_search_flow
|
from danswer.search.danswer_helper import recommend_search_flow
|
||||||
|
from danswer.search.models import IndexFilters
|
||||||
from danswer.search.models import QueryFlow
|
from danswer.search.models import QueryFlow
|
||||||
from danswer.search.models import SearchType
|
from danswer.search.models import SearchQuery
|
||||||
from danswer.search.search_runner import chunks_to_search_docs
|
from danswer.search.search_runner import chunks_to_search_docs
|
||||||
from danswer.search.search_runner import retrieve_keyword_documents
|
from danswer.search.search_runner import search_chunks
|
||||||
from danswer.search.search_runner import retrieve_ranked_documents
|
|
||||||
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
|
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
|
||||||
from danswer.secondary_llm_flows.query_validation import get_query_answerability
|
from danswer.secondary_llm_flows.query_validation import get_query_answerability
|
||||||
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
|
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
|
||||||
from danswer.server.models import HelperResponse
|
from danswer.server.models import HelperResponse
|
||||||
from danswer.server.models import IndexFilters
|
|
||||||
from danswer.server.models import QAFeedbackRequest
|
from danswer.server.models import QAFeedbackRequest
|
||||||
from danswer.server.models import QAResponse
|
from danswer.server.models import QAResponse
|
||||||
from danswer.server.models import QueryValidationResponse
|
from danswer.server.models import QueryValidationResponse
|
||||||
@@ -114,8 +112,7 @@ def get_search_type(
|
|||||||
question: QuestionRequest, _: User = Depends(current_user)
|
question: QuestionRequest, _: User = Depends(current_user)
|
||||||
) -> HelperResponse:
|
) -> HelperResponse:
|
||||||
query = question.query
|
query = question.query
|
||||||
use_keyword = question.use_keyword if question.use_keyword is not None else False
|
return recommend_search_flow(query)
|
||||||
return recommend_search_flow(query, use_keyword)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/query-validation")
|
@router.post("/query-validation")
|
||||||
@@ -137,14 +134,14 @@ def stream_query_validation(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/keyword-search")
|
@router.post("/document-search")
|
||||||
def keyword_search(
|
def handle_search_request(
|
||||||
question: QuestionRequest,
|
question: QuestionRequest,
|
||||||
user: User = Depends(current_user),
|
user: User = Depends(current_user),
|
||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> SearchResponse:
|
) -> SearchResponse:
|
||||||
query = question.query
|
query = question.query
|
||||||
logger.info(f"Received keyword search query: {query}")
|
logger.info(f"Received {question.search_type.value} " f"search query: {query}")
|
||||||
|
|
||||||
time_cutoff, favor_recent = extract_question_time_filters(question)
|
time_cutoff, favor_recent = extract_question_time_filters(question)
|
||||||
question.filters.time_cutoff = time_cutoff
|
question.filters.time_cutoff = time_cutoff
|
||||||
@@ -152,7 +149,7 @@ def keyword_search(
|
|||||||
|
|
||||||
query_event_id = create_query_event(
|
query_event_id = create_query_event(
|
||||||
query=query,
|
query=query,
|
||||||
selected_flow=SearchType.KEYWORD,
|
search_type=question.search_type,
|
||||||
llm_answer=None,
|
llm_answer=None,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
@@ -166,12 +163,18 @@ def keyword_search(
|
|||||||
time_cutoff=filters.time_cutoff,
|
time_cutoff=filters.time_cutoff,
|
||||||
access_control_list=user_acl_filters,
|
access_control_list=user_acl_filters,
|
||||||
)
|
)
|
||||||
ranked_chunks = retrieve_keyword_documents(
|
|
||||||
|
search_query = SearchQuery(
|
||||||
query=query,
|
query=query,
|
||||||
|
search_type=question.search_type,
|
||||||
filters=final_filters,
|
filters=final_filters,
|
||||||
favor_recent=favor_recent,
|
favor_recent=favor_recent,
|
||||||
datastore=get_default_document_index(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ranked_chunks, unranked_chunks = search_chunks(
|
||||||
|
query=search_query, document_index=get_default_document_index()
|
||||||
|
)
|
||||||
|
|
||||||
if not ranked_chunks:
|
if not ranked_chunks:
|
||||||
return SearchResponse(
|
return SearchResponse(
|
||||||
top_ranked_docs=None,
|
top_ranked_docs=None,
|
||||||
@@ -182,6 +185,8 @@ def keyword_search(
|
|||||||
)
|
)
|
||||||
|
|
||||||
top_docs = chunks_to_search_docs(ranked_chunks)
|
top_docs = chunks_to_search_docs(ranked_chunks)
|
||||||
|
lower_top_docs = chunks_to_search_docs(unranked_chunks)
|
||||||
|
|
||||||
update_query_event_retrieved_documents(
|
update_query_event_retrieved_documents(
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
retrieved_document_ids=[doc.document_id for doc in top_docs],
|
retrieved_document_ids=[doc.document_id for doc in top_docs],
|
||||||
@@ -191,132 +196,7 @@ def keyword_search(
|
|||||||
|
|
||||||
return SearchResponse(
|
return SearchResponse(
|
||||||
top_ranked_docs=top_docs,
|
top_ranked_docs=top_docs,
|
||||||
lower_ranked_docs=None,
|
lower_ranked_docs=lower_top_docs or None,
|
||||||
query_event_id=query_event_id,
|
|
||||||
time_cutoff=time_cutoff,
|
|
||||||
favor_recent=favor_recent,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/semantic-search")
|
|
||||||
def semantic_search(
|
|
||||||
question: QuestionRequest,
|
|
||||||
user: User = Depends(current_user),
|
|
||||||
db_session: Session = Depends(get_session),
|
|
||||||
) -> SearchResponse:
|
|
||||||
query = question.query
|
|
||||||
logger.info(f"Received semantic search query: {query}")
|
|
||||||
|
|
||||||
time_cutoff, favor_recent = extract_question_time_filters(question)
|
|
||||||
question.filters.time_cutoff = time_cutoff
|
|
||||||
filters = question.filters
|
|
||||||
|
|
||||||
query_event_id = create_query_event(
|
|
||||||
query=query,
|
|
||||||
selected_flow=SearchType.SEMANTIC,
|
|
||||||
llm_answer=None,
|
|
||||||
user_id=user.id,
|
|
||||||
db_session=db_session,
|
|
||||||
)
|
|
||||||
|
|
||||||
user_id = None if user is None else user.id
|
|
||||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
|
||||||
final_filters = IndexFilters(
|
|
||||||
source_type=filters.source_type,
|
|
||||||
document_set=filters.document_set,
|
|
||||||
time_cutoff=filters.time_cutoff,
|
|
||||||
access_control_list=user_acl_filters,
|
|
||||||
)
|
|
||||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
|
||||||
query=query,
|
|
||||||
filters=final_filters,
|
|
||||||
favor_recent=favor_recent,
|
|
||||||
datastore=get_default_document_index(),
|
|
||||||
)
|
|
||||||
if not ranked_chunks:
|
|
||||||
return SearchResponse(
|
|
||||||
top_ranked_docs=None,
|
|
||||||
lower_ranked_docs=None,
|
|
||||||
query_event_id=query_event_id,
|
|
||||||
time_cutoff=time_cutoff,
|
|
||||||
favor_recent=favor_recent,
|
|
||||||
)
|
|
||||||
|
|
||||||
top_docs = chunks_to_search_docs(ranked_chunks)
|
|
||||||
other_top_docs = chunks_to_search_docs(unranked_chunks)
|
|
||||||
update_query_event_retrieved_documents(
|
|
||||||
db_session=db_session,
|
|
||||||
retrieved_document_ids=[doc.document_id for doc in top_docs],
|
|
||||||
query_id=query_event_id,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return SearchResponse(
|
|
||||||
top_ranked_docs=top_docs,
|
|
||||||
lower_ranked_docs=other_top_docs,
|
|
||||||
query_event_id=query_event_id,
|
|
||||||
time_cutoff=time_cutoff,
|
|
||||||
favor_recent=favor_recent,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO don't use this, not done yet
|
|
||||||
@router.post("/hybrid-search")
|
|
||||||
def hybrid_search(
|
|
||||||
question: QuestionRequest,
|
|
||||||
user: User = Depends(current_user),
|
|
||||||
db_session: Session = Depends(get_session),
|
|
||||||
) -> SearchResponse:
|
|
||||||
query = question.query
|
|
||||||
logger.info(f"Received hybrid search query: {query}")
|
|
||||||
|
|
||||||
time_cutoff, favor_recent = extract_question_time_filters(question)
|
|
||||||
question.filters.time_cutoff = time_cutoff
|
|
||||||
filters = question.filters
|
|
||||||
|
|
||||||
query_event_id = create_query_event(
|
|
||||||
query=query,
|
|
||||||
selected_flow=SearchType.HYBRID,
|
|
||||||
llm_answer=None,
|
|
||||||
user_id=user.id,
|
|
||||||
db_session=db_session,
|
|
||||||
)
|
|
||||||
|
|
||||||
user_id = None if user is None else user.id
|
|
||||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
|
||||||
final_filters = IndexFilters(
|
|
||||||
source_type=filters.source_type,
|
|
||||||
document_set=filters.document_set,
|
|
||||||
time_cutoff=filters.time_cutoff,
|
|
||||||
access_control_list=user_acl_filters,
|
|
||||||
)
|
|
||||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
|
||||||
query=query,
|
|
||||||
filters=final_filters,
|
|
||||||
favor_recent=favor_recent,
|
|
||||||
datastore=get_default_document_index(),
|
|
||||||
)
|
|
||||||
if not ranked_chunks:
|
|
||||||
return SearchResponse(
|
|
||||||
top_ranked_docs=None,
|
|
||||||
lower_ranked_docs=None,
|
|
||||||
query_event_id=query_event_id,
|
|
||||||
time_cutoff=time_cutoff,
|
|
||||||
favor_recent=favor_recent,
|
|
||||||
)
|
|
||||||
|
|
||||||
top_docs = chunks_to_search_docs(ranked_chunks)
|
|
||||||
other_top_docs = chunks_to_search_docs(unranked_chunks)
|
|
||||||
update_query_event_retrieved_documents(
|
|
||||||
db_session=db_session,
|
|
||||||
retrieved_document_ids=[doc.document_id for doc in top_docs],
|
|
||||||
query_id=query_event_id,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return SearchResponse(
|
|
||||||
top_ranked_docs=top_docs,
|
|
||||||
lower_ranked_docs=other_top_docs,
|
|
||||||
query_event_id=query_event_id,
|
query_event_id=query_event_id,
|
||||||
time_cutoff=time_cutoff,
|
time_cutoff=time_cutoff,
|
||||||
favor_recent=favor_recent,
|
favor_recent=favor_recent,
|
||||||
@@ -347,10 +227,10 @@ def stream_direct_qa(
|
|||||||
predicted_search_key = "predicted_search"
|
predicted_search_key = "predicted_search"
|
||||||
query_event_id_key = "query_event_id"
|
query_event_id_key = "query_event_id"
|
||||||
|
|
||||||
logger.debug(f"Received QA query: {question.query}")
|
logger.debug(
|
||||||
|
f"Received QA query ({question.search_type.value} search): {question.query}"
|
||||||
|
)
|
||||||
logger.debug(f"Query filters: {question.filters}")
|
logger.debug(f"Query filters: {question.filters}")
|
||||||
if question.use_keyword:
|
|
||||||
logger.debug("User selected Keyword Search")
|
|
||||||
|
|
||||||
@log_generator_function_time()
|
@log_generator_function_time()
|
||||||
def stream_qa_portions(
|
def stream_qa_portions(
|
||||||
@@ -358,40 +238,34 @@ def stream_direct_qa(
|
|||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
answer_so_far: str = ""
|
answer_so_far: str = ""
|
||||||
query = question.query
|
query = question.query
|
||||||
use_keyword = question.use_keyword
|
|
||||||
offset_count = question.offset if question.offset is not None else 0
|
offset_count = question.offset if question.offset is not None else 0
|
||||||
|
|
||||||
time_cutoff, favor_recent = extract_question_time_filters(question)
|
time_cutoff, favor_recent = extract_question_time_filters(question)
|
||||||
question.filters.time_cutoff = time_cutoff
|
question.filters.time_cutoff = time_cutoff # not used but just in case
|
||||||
filters = question.filters
|
filters = question.filters
|
||||||
|
|
||||||
predicted_search, predicted_flow = query_intent(query)
|
|
||||||
if use_keyword is None:
|
|
||||||
use_keyword = predicted_search == SearchType.KEYWORD
|
|
||||||
|
|
||||||
user_id = None if user is None else user.id
|
|
||||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||||
final_filters = IndexFilters(
|
final_filters = IndexFilters(
|
||||||
source_type=filters.source_type,
|
source_type=filters.source_type,
|
||||||
document_set=filters.document_set,
|
document_set=filters.document_set,
|
||||||
time_cutoff=filters.time_cutoff,
|
time_cutoff=time_cutoff,
|
||||||
access_control_list=user_acl_filters,
|
access_control_list=user_acl_filters,
|
||||||
)
|
)
|
||||||
if use_keyword:
|
|
||||||
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
|
search_query = SearchQuery(
|
||||||
query=query,
|
query=query,
|
||||||
filters=final_filters,
|
search_type=question.search_type,
|
||||||
favor_recent=favor_recent,
|
filters=final_filters,
|
||||||
datastore=get_default_document_index(),
|
favor_recent=favor_recent,
|
||||||
)
|
)
|
||||||
unranked_chunks: list[InferenceChunk] | None = []
|
|
||||||
else:
|
ranked_chunks, unranked_chunks = search_chunks(
|
||||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
query=search_query, document_index=get_default_document_index()
|
||||||
query=query,
|
)
|
||||||
filters=final_filters,
|
|
||||||
favor_recent=favor_recent,
|
# TODO retire this
|
||||||
datastore=get_default_document_index(),
|
predicted_search, predicted_flow = query_intent(query)
|
||||||
)
|
|
||||||
if not ranked_chunks:
|
if not ranked_chunks:
|
||||||
logger.debug("No Documents Found")
|
logger.debug("No Documents Found")
|
||||||
empty_docs_result = {
|
empty_docs_result = {
|
||||||
@@ -473,17 +347,14 @@ def stream_direct_qa(
|
|||||||
|
|
||||||
query_event_id = create_query_event(
|
query_event_id = create_query_event(
|
||||||
query=query,
|
query=query,
|
||||||
selected_flow=SearchType.KEYWORD
|
search_type=question.search_type,
|
||||||
if question.use_keyword
|
|
||||||
else SearchType.SEMANTIC,
|
|
||||||
llm_answer=answer_so_far,
|
llm_answer=answer_so_far,
|
||||||
retrieved_document_ids=[doc.document_id for doc in top_docs],
|
retrieved_document_ids=[doc.document_id for doc in top_docs],
|
||||||
user_id=user_id,
|
user_id=None if user is None else user.id,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield get_json_line({query_event_id_key: query_event_id})
|
yield get_json_line({query_event_id_key: query_event_id})
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
return StreamingResponse(stream_qa_portions(), media_type="application/json")
|
return StreamingResponse(stream_qa_portions(), media_type="application/json")
|
||||||
|
@@ -14,27 +14,19 @@ if __name__ == "__main__":
|
|||||||
previous_query = None
|
previous_query = None
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"-f",
|
|
||||||
"--flow",
|
|
||||||
type=str,
|
|
||||||
default="QA",
|
|
||||||
help='"Search" or "QA", defaults to "QA"',
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-t",
|
"-t",
|
||||||
"--type",
|
"--type",
|
||||||
type=str,
|
type=str,
|
||||||
default="Semantic",
|
default="hybrid",
|
||||||
help='"Semantic" or "Keyword", defaults to "Semantic"',
|
help='"hybrid" "semantic" or "keyword", defaults to "hybrid"',
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-s",
|
"-s",
|
||||||
"--stream",
|
"--stream",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help='Enable streaming response, only for flow="QA"',
|
help="Enable streaming response",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -50,7 +42,7 @@ if __name__ == "__main__":
|
|||||||
try:
|
try:
|
||||||
user_input = input(
|
user_input = input(
|
||||||
"\n\nAsk any question:\n"
|
"\n\nAsk any question:\n"
|
||||||
" - Use -f (QA/Search) and -t (Semantic/Keyword) flags to set endpoint.\n"
|
" - Use -t (hybrid/semantic/keyword) flag to choose search flow.\n"
|
||||||
" - prefix with -s to stream answer, --filters web,slack etc. for filters.\n"
|
" - prefix with -s to stream answer, --filters web,slack etc. for filters.\n"
|
||||||
" - input an empty string to rerun last query.\n\t"
|
" - input an empty string to rerun last query.\n\t"
|
||||||
)
|
)
|
||||||
@@ -66,25 +58,15 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = parser.parse_args(user_input.split())
|
args = parser.parse_args(user_input.split())
|
||||||
|
|
||||||
flow = str(args.flow).lower()
|
search_type = str(args.type).lower()
|
||||||
flow_type = str(args.type).lower()
|
|
||||||
stream = args.stream
|
stream = args.stream
|
||||||
source_types = args.filters.split(",") if args.filters else None
|
source_types = args.filters.split(",") if args.filters else None
|
||||||
if source_types and len(source_types) == 1:
|
|
||||||
source_types = source_types[0]
|
|
||||||
query = " ".join(args.query)
|
query = " ".join(args.query)
|
||||||
|
|
||||||
if flow not in ["qa", "search"]:
|
if search_type not in ["hybrid", "semantic", "keyword"]:
|
||||||
raise ValueError("Flow value must be QA or Search")
|
raise ValueError("Invalid Search")
|
||||||
if flow_type not in ["keyword", "semantic"]:
|
|
||||||
raise ValueError("Type value must be keyword or semantic")
|
|
||||||
if flow != "qa" and stream:
|
|
||||||
raise ValueError("Can only stream results for QA")
|
|
||||||
|
|
||||||
if (flow, flow_type) == ("search", "keyword"):
|
|
||||||
path = "keyword-search"
|
|
||||||
elif (flow, flow_type) == ("search", "semantic"):
|
|
||||||
path = "semantic-search"
|
|
||||||
elif stream:
|
elif stream:
|
||||||
path = "stream-direct-qa"
|
path = "stream-direct-qa"
|
||||||
else:
|
else:
|
||||||
@@ -95,8 +77,9 @@ if __name__ == "__main__":
|
|||||||
query_json = {
|
query_json = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"collection": DOCUMENT_INDEX_NAME,
|
"collection": DOCUMENT_INDEX_NAME,
|
||||||
"use_keyword": flow_type == "keyword", # Ignore if not QA Endpoints
|
"filters": {SOURCE_TYPE: source_types},
|
||||||
"filters": [{SOURCE_TYPE: source_types}],
|
"enable_auto_detect_filters": True,
|
||||||
|
"search_type": search_type,
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.stream:
|
if args.stream:
|
||||||
|
@@ -12,9 +12,9 @@ from danswer.access.access import get_acl_for_user
|
|||||||
from danswer.db.engine import get_sqlalchemy_engine
|
from danswer.db.engine import get_sqlalchemy_engine
|
||||||
from danswer.direct_qa.answer_question import answer_qa_query
|
from danswer.direct_qa.answer_question import answer_qa_query
|
||||||
from danswer.direct_qa.models import LLMMetricsContainer
|
from danswer.direct_qa.models import LLMMetricsContainer
|
||||||
|
from danswer.search.models import IndexFilters
|
||||||
from danswer.search.models import RerankMetricsContainer
|
from danswer.search.models import RerankMetricsContainer
|
||||||
from danswer.search.models import RetrievalMetricsContainer
|
from danswer.search.models import RetrievalMetricsContainer
|
||||||
from danswer.server.models import IndexFilters
|
|
||||||
from danswer.server.models import QuestionRequest
|
from danswer.server.models import QuestionRequest
|
||||||
from danswer.utils.callbacks import MetricsHander
|
from danswer.utils.callbacks import MetricsHander
|
||||||
|
|
||||||
@@ -85,7 +85,6 @@ def get_answer_for_question(
|
|||||||
question = QuestionRequest(
|
question = QuestionRequest(
|
||||||
query=query,
|
query=query,
|
||||||
collection="danswer_index",
|
collection="danswer_index",
|
||||||
use_keyword=False,
|
|
||||||
filters=filters,
|
filters=filters,
|
||||||
enable_auto_detect_filters=False,
|
enable_auto_detect_filters=False,
|
||||||
offset=None,
|
offset=None,
|
||||||
|
@@ -52,44 +52,6 @@ const getAssistantMessage = ({
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
|
||||||
selectedSearchType !== SearchType.AUTOMATIC &&
|
|
||||||
searchResponse.suggestedSearchType !== selectedSearchType
|
|
||||||
) {
|
|
||||||
if (searchResponse.suggestedSearchType === SearchType.SEMANTIC) {
|
|
||||||
return (
|
|
||||||
<div>
|
|
||||||
Your query looks more like natural language, Semantic Search may yield
|
|
||||||
better results. Would you like to{" "}
|
|
||||||
<span
|
|
||||||
className={CLICKABLE_CLASS_NAME}
|
|
||||||
onClick={() => {
|
|
||||||
setSelectedSearchType(SearchType.SEMANTIC);
|
|
||||||
restartSearch({ searchType: SearchType.SEMANTIC });
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
try AI search?
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return (
|
|
||||||
<div>
|
|
||||||
Your query seems to be a better fit for keyword search. Would you like
|
|
||||||
to{" "}
|
|
||||||
<span
|
|
||||||
className={CLICKABLE_CLASS_NAME}
|
|
||||||
onClick={() => {
|
|
||||||
setSelectedSearchType(SearchType.KEYWORD);
|
|
||||||
restartSearch({ searchType: SearchType.KEYWORD });
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
try Keyword search?
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(searchResponse.suggestedFlowType === FlowType.QUESTION_ANSWER ||
|
(searchResponse.suggestedFlowType === FlowType.QUESTION_ANSWER ||
|
||||||
defaultOverrides.forceDisplayQA) &&
|
defaultOverrides.forceDisplayQA) &&
|
||||||
|
@@ -228,14 +228,6 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="w-[800px] mx-auto">
|
<div className="w-[800px] mx-auto">
|
||||||
<SearchTypeSelector
|
|
||||||
selectedSearchType={selectedSearchType}
|
|
||||||
setSelectedSearchType={(searchType) => {
|
|
||||||
Cookies.set("searchType", searchType);
|
|
||||||
setSelectedSearchType(searchType);
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<SearchBar
|
<SearchBar
|
||||||
query={query}
|
query={query}
|
||||||
setQuery={setQuery}
|
setQuery={setQuery}
|
||||||
|
@@ -21,11 +21,6 @@ export const searchRequest = async ({
|
|||||||
selectedSearchType,
|
selectedSearchType,
|
||||||
offset,
|
offset,
|
||||||
}: SearchRequestArgs) => {
|
}: SearchRequestArgs) => {
|
||||||
let useKeyword = null;
|
|
||||||
if (selectedSearchType !== SearchType.AUTOMATIC) {
|
|
||||||
useKeyword = selectedSearchType === SearchType.KEYWORD ? true : false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let answer = "";
|
let answer = "";
|
||||||
let quotes: Quote[] | null = null;
|
let quotes: Quote[] | null = null;
|
||||||
let relevantDocuments: DanswerDocument[] | null = null;
|
let relevantDocuments: DanswerDocument[] | null = null;
|
||||||
@@ -36,7 +31,6 @@ export const searchRequest = async ({
|
|||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
query,
|
query,
|
||||||
collection: "danswer_index",
|
collection: "danswer_index",
|
||||||
use_keyword: useKeyword,
|
|
||||||
filters,
|
filters,
|
||||||
enable_auto_detect_filters: false,
|
enable_auto_detect_filters: false,
|
||||||
offset: offset,
|
offset: offset,
|
||||||
|
@@ -64,14 +64,8 @@ export const searchRequestStreamed = async ({
|
|||||||
updateSuggestedFlowType,
|
updateSuggestedFlowType,
|
||||||
updateError,
|
updateError,
|
||||||
updateQueryEventId,
|
updateQueryEventId,
|
||||||
selectedSearchType,
|
|
||||||
offset,
|
offset,
|
||||||
}: SearchRequestArgs) => {
|
}: SearchRequestArgs) => {
|
||||||
let useKeyword = null;
|
|
||||||
if (selectedSearchType !== SearchType.AUTOMATIC) {
|
|
||||||
useKeyword = selectedSearchType === SearchType.KEYWORD ? true : false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let answer = "";
|
let answer = "";
|
||||||
let quotes: Quote[] | null = null;
|
let quotes: Quote[] | null = null;
|
||||||
let relevantDocuments: DanswerDocument[] | null = null;
|
let relevantDocuments: DanswerDocument[] | null = null;
|
||||||
@@ -82,7 +76,6 @@ export const searchRequestStreamed = async ({
|
|||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
query,
|
query,
|
||||||
collection: "danswer_index",
|
collection: "danswer_index",
|
||||||
use_keyword: useKeyword,
|
|
||||||
filters,
|
filters,
|
||||||
enable_auto_detect_filters: false,
|
enable_auto_detect_filters: false,
|
||||||
offset: offset,
|
offset: offset,
|
||||||
|
@@ -21,7 +21,6 @@ export const questionValidationStreamed = async <T>({
|
|||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
query,
|
query,
|
||||||
collection: "danswer_index",
|
collection: "danswer_index",
|
||||||
use_keyword: null,
|
|
||||||
filters: emptyFilters,
|
filters: emptyFilters,
|
||||||
enable_auto_detect_filters: false,
|
enable_auto_detect_filters: false,
|
||||||
offset: null,
|
offset: null,
|
||||||
|
Reference in New Issue
Block a user