Prep for Hybrid Search (#648)

Yuhong Sun
2023-10-29 00:13:21 -07:00
committed by GitHub
parent bfa338e142
commit 26b491fb0c
19 changed files with 199 additions and 442 deletions

View File

@@ -38,9 +38,11 @@ from danswer.llm.llm import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.models import IndexFilters
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import retrieve_ranked_documents
from danswer.server.models import IndexFilters
from danswer.search.search_runner import search_chunks
from danswer.server.models import RetrievalDocs
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import extract_embedded_json
@@ -130,13 +132,18 @@ def danswer_chat_retrieval(
else:
reworded_query = query_message.message
# Good Debug/Breakpoint
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
search_query = SearchQuery(
query=reworded_query,
search_type=SearchType.HYBRID,
filters=filters,
favor_recent=False,
datastore=get_default_document_index(),
)
# Good Debug/Breakpoint
ranked_chunks, unranked_chunks = search_chunks(
query=search_query, document_index=get_default_document_index()
)
if not ranked_chunks:
return []
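With this change, chat retrieval builds a SearchQuery and hands it to the unified search_chunks runner instead of calling retrieve_ranked_documents directly. A minimal sketch of the new call pattern, using only names imported above:

    from danswer.document_index import get_default_document_index
    from danswer.search.models import SearchQuery
    from danswer.search.models import SearchType
    from danswer.search.search_runner import search_chunks

    # filters is an IndexFilters built from the user's access rights
    # (see build_access_filters_for_user)
    search_query = SearchQuery(
        query=reworded_query,
        search_type=SearchType.HYBRID,
        filters=filters,
        favor_recent=False,
    )
    ranked_chunks, unranked_chunks = search_chunks(
        query=search_query,
        document_index=get_default_document_index(),
    )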

View File

@@ -26,9 +26,9 @@ from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.search.models import BaseFilters
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import RequestFilters
from danswer.utils.logger import setup_logger
logger_base = setup_logger()
@@ -182,7 +182,7 @@ def handle_message(
try:
# By leaving time_cutoff and favor_recent as None and setting enable_auto_detect_filters,
# the Slack flow can extract filters from the user query
filters = RequestFilters(
filters = BaseFilters(
source_type=None,
document_set=document_set_names,
time_cutoff=None,
@@ -193,7 +193,6 @@ def handle_message(
QuestionRequest(
query=msg,
collection=DOCUMENT_INDEX_NAME,
use_keyword=False, # always use semantic search when handling Slack messages
enable_auto_detect_filters=not disable_auto_detect_filters,
filters=filters,
favor_recent=None,
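The comment above is the behavioral contract here: with time_cutoff and favor_recent left as None and auto-detection enabled, the QA flow derives them from the query itself. Roughly, as seen later in this commit:

    # Inside the QA handler (sketch; see answer_qa_query / stream_direct_qa)
    time_cutoff, favor_recent = extract_question_time_filters(question)
    question.filters.time_cutoff = time_cutoff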

View File

@@ -96,14 +96,14 @@ def update_document_hidden(db_session: Session, document_id: str, hidden: bool)
def create_query_event(
db_session: Session,
query: str,
selected_flow: SearchType | None,
search_type: SearchType | None,
llm_answer: str | None,
user_id: UUID | None,
retrieved_document_ids: list[str] | None = None,
) -> int:
query_event = QueryEvent(
query=query,
selected_search_flow=selected_flow,
selected_search_flow=search_type,
llm_answer=llm_answer,
retrieved_document_ids=retrieved_document_ids,
user_id=user_id,

View File

@@ -15,19 +15,17 @@ from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.document_index import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.danswer_helper import query_intent
from danswer.search.models import IndexFilters
from danswer.search.models import QueryFlow
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchType
from danswer.search.models import SearchQuery
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import retrieve_keyword_documents
from danswer.search.search_runner import retrieve_ranked_documents
from danswer.search.search_runner import search_chunks
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
from danswer.server.models import IndexFilters
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.utils.logger import setup_logger
@@ -51,7 +49,6 @@ def answer_qa_query(
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> QAResponse:
query = question.query
use_keyword = question.use_keyword
offset_count = question.offset if question.offset is not None else 0
logger.info(f"Received QA query: {query}")
@@ -61,18 +58,12 @@ def answer_qa_query(
query_event_id = create_query_event(
query=query,
selected_flow=SearchType.KEYWORD
if question.use_keyword
else SearchType.SEMANTIC,
search_type=question.search_type,
llm_answer=None,
user_id=user.id if user is not None else None,
db_session=db_session,
)
predicted_search, predicted_flow = query_intent(query)
if use_keyword is None:
use_keyword = predicted_search == SearchType.KEYWORD
user_id = None if user is None else user.id
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = IndexFilters(
@@ -81,24 +72,23 @@ def answer_qa_query(
time_cutoff=time_cutoff,
access_control_list=user_acl_filters,
)
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
query=query,
filters=final_filters,
favor_recent=favor_recent,
datastore=get_default_document_index(),
retrieval_metrics_callback=retrieval_metrics_callback,
)
unranked_chunks: list[InferenceChunk] | None = []
else:
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query=query,
filters=final_filters,
favor_recent=favor_recent,
datastore=get_default_document_index(),
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
)
search_query = SearchQuery(
query=query,
search_type=question.search_type,
filters=final_filters,
favor_recent=True if question.favor_recent is None else question.favor_recent,
)
# TODO retire this
predicted_search, predicted_flow = query_intent(query)
ranked_chunks, unranked_chunks = search_chunks(
query=search_query,
document_index=get_default_document_index(),
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
)
if not ranked_chunks:
return QAResponse(
answer=None,
@@ -114,6 +104,7 @@ def answer_qa_query(
top_docs = chunks_to_search_docs(ranked_chunks)
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
update_query_event_retrieved_documents(
db_session=db_session,
retrieved_document_ids=[doc.document_id for doc in top_docs],

View File

@@ -4,10 +4,9 @@ from datetime import datetime
from typing import Any
from danswer.access.models import DocumentAccess
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.indexing.models import InferenceChunk
from danswer.server.models import IndexFilters
from danswer.search.models import IndexFilters
@dataclass(frozen=True)
@@ -97,7 +96,6 @@ class VectorCapable(abc.ABC):
filters: IndexFilters,
favor_recent: bool,
num_to_retrieve: int,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
) -> list[InferenceChunk]:
raise NotImplementedError
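For orientation, retrieval capabilities are small abstract mixins that the concrete document index implements. A simplified sketch of the shape after this change (note distance_cutoff is removed from the abstract signature):

    import abc

    from danswer.indexing.models import InferenceChunk
    from danswer.search.models import IndexFilters

    class VectorCapable(abc.ABC):
        @abc.abstractmethod
        def semantic_retrieval(
            self,
            query: str,
            filters: IndexFilters,
            favor_recent: bool,
            num_to_retrieve: int,
        ) -> list[InferenceChunk]:
            raise NotImplementedError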

View File

@@ -48,12 +48,13 @@ from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.document_index.document_index_utils import get_uuid_from_chunk
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentInsertionRecord
from danswer.document_index.interfaces import IndexFilters
from danswer.document_index.interfaces import UpdateRequest
from danswer.document_index.vespa.utils import remove_invalid_unicode_chars
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.indexing.models import InferenceChunk
from danswer.search.models import IndexFilters
from danswer.search.search_runner import embed_query
from danswer.search.search_runner import query_processing
from danswer.search.search_runner import remove_stop_words
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
@@ -585,6 +586,7 @@ class VespaIndex(DocumentIndex):
filters: IndexFilters,
favor_recent: bool,
num_to_retrieve: int = NUM_RETURNED_HITS,
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
) -> list[InferenceChunk]:
decay_multiplier = FAVOR_RECENT_DECAY_MULTIPLIER if favor_recent else 1
vespa_where_clauses = _build_vespa_filters(filters)
@@ -599,9 +601,11 @@ class VespaIndex(DocumentIndex):
+ _build_vespa_limit(num_to_retrieve)
)
final_query = query_processing(query) if edit_keyword_query else query
params: dict[str, str | int] = {
"yql": yql,
"query": query,
"query": final_query,
"input.query(decay_factor)": str(DOC_TIME_DECAY * decay_multiplier),
"hits": num_to_retrieve,
"num_to_rerank": 10 * num_to_retrieve,
@@ -617,6 +621,7 @@ class VespaIndex(DocumentIndex):
favor_recent: bool,
num_to_retrieve: int = NUM_RETURNED_HITS,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
) -> list[InferenceChunk]:
decay_multiplier = FAVOR_RECENT_DECAY_MULTIPLIER if favor_recent else 1
vespa_where_clauses = _build_vespa_filters(filters)
@@ -634,7 +639,7 @@ class VespaIndex(DocumentIndex):
query_embedding = embed_query(query)
query_keywords = (
" ".join(remove_stop_words(query)) if EDIT_KEYWORD_QUERY else query
" ".join(remove_stop_words(query)) if edit_keyword_query else query
)
params = {
@@ -653,31 +658,11 @@ class VespaIndex(DocumentIndex):
filters: IndexFilters,
favor_recent: bool,
num_to_retrieve: int,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
) -> list[InferenceChunk]:
vespa_where_clauses = _build_vespa_filters(filters)
yql = (
VespaIndex.yql_base
+ vespa_where_clauses
+ f"({{targetHits: {10 * num_to_retrieve}}}nearestNeighbor(embeddings, query_embedding)) or "
+ '({grammar: "weakAnd"}userInput(@query) '
# `({defaultIndex: "content_summary"}userInput(@query))` section is
# needed for highlighting while the N-gram highlighting is broken /
# not working as desired
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
+ _build_vespa_limit(num_to_retrieve)
)
query_embedding = embed_query(query)
params = {
"yql": yql,
"query": query,
"input.query(query_embedding)": str(query_embedding),
"input.query(decay_factor)": str(DOC_TIME_DECAY),
"ranking.profile": "hybrid_search",
}
return _query_vespa(params)
# TODO introduce the real hybrid search
return self.semantic_retrieval(query, filters, favor_recent, num_to_retrieve)
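The new edit_keyword_query parameter gates stopword stripping for keyword-style queries. A rough standalone approximation of what remove_stop_words does, using the same NLTK helpers search_runner imports:

    from nltk.corpus import stopwords  # type: ignore
    from nltk.tokenize import word_tokenize  # type: ignore

    def remove_stop_words_sketch(query: str) -> list[str]:
        # Approximation of danswer.search.search_runner.remove_stop_words
        stop_words = set(stopwords.words("english"))
        return [word for word in word_tokenize(query) if word.lower() not in stop_words]

    # " ".join(remove_stop_words_sketch("what are the effects of caffeine"))
    # -> "effects caffeine"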
def admin_retrieval(
self,

View File

@@ -2,7 +2,7 @@ from sqlalchemy.orm import Session
from danswer.access.access import get_acl_for_user
from danswer.db.models import User
from danswer.server.models import IndexFilters
from danswer.search.models import IndexFilters
def build_access_filters_for_user(user: User | None, session: Session) -> list[str]:

View File

@@ -66,7 +66,7 @@ def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
def recommend_search_flow(
query: str,
keyword: bool,
keyword: bool = False,
max_percent_stopwords: float = 0.30, # ~Every third word max, i.e. "effects of caffeine" still viable keyword search
) -> HelperResponse:
heuristic_search_type: SearchType | None = None

View File

@@ -1,11 +1,14 @@
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
MAX_METRICS_CONTENT = (
200 # Just need enough characters to identify where in the doc the chunk is
)
@@ -27,6 +30,16 @@ class Embedder:
raise NotImplementedError
class BaseFilters(BaseModel):
source_type: list[str] | None = None
document_set: list[str] | None = None
time_cutoff: datetime | None = None
class IndexFilters(BaseFilters):
access_control_list: list[str]
class ChunkMetric(BaseModel):
document_id: str
chunk_content_start: str
@@ -34,6 +47,17 @@ class ChunkMetric(BaseModel):
score: float
class SearchQuery(BaseModel):
query: str
search_type: SearchType
filters: IndexFilters
favor_recent: bool
num_hits: int = NUM_RETURNED_HITS
skip_rerank: bool = SKIP_RERANKING
# Only used if not skip_rerank
num_rerank: int | None = NUM_RERANKED_RESULTS
class RetrievalMetricsContainer(BaseModel):
keyword_search: bool # False for Vector Search
metrics: list[ChunkMetric] # This contains the scores for retrieval as well
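Putting the new models together: BaseFilters carries the user-supplied filters, IndexFilters layers the ACL on top, and SearchQuery bundles everything the search runner needs. A minimal construction; the ACL entry is illustrative only:

    from danswer.search.models import IndexFilters
    from danswer.search.models import SearchQuery
    from danswer.search.models import SearchType

    filters = IndexFilters(
        source_type=["web"],
        document_set=None,
        time_cutoff=None,
        access_control_list=["PUBLIC"],  # placeholder ACL value for illustration
    )
    search_query = SearchQuery(
        query="how does hybrid search rank documents?",
        search_type=SearchType.HYBRID,
        filters=filters,
        favor_recent=False,
        # num_hits, skip_rerank, num_rerank fall back to the configured defaults
    )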

View File

@@ -6,26 +6,23 @@ from nltk.stem import WordNetLemmatizer # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sentence_transformers import SentenceTransformer # type: ignore
from danswer.configs.app_configs import EDIT_KEYWORD_QUERY
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import IndexFilters
from danswer.indexing.models import InferenceChunk
from danswer.search.models import ChunkMetric
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import get_default_embedding_model
from danswer.search.search_nlp_models import get_default_reranking_model_ensemble
from danswer.server.models import SearchDoc
@@ -56,6 +53,24 @@ def query_processing(
return query
def embed_query(
query: str,
embedding_model: SentenceTransformer | None = None,
prefix: str = ASYM_QUERY_PREFIX,
normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
) -> list[float]:
model = embedding_model or get_default_embedding_model()
prefixed_query = prefix + query
query_embedding = model.encode(
prefixed_query, normalize_embeddings=normalize_embeddings
)
if not isinstance(query_embedding, list):
query_embedding = query_embedding.tolist()
return query_embedding
def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
search_docs = (
[
@@ -233,73 +248,44 @@ def apply_boost(
return final_chunks
@log_function_time()
def retrieve_keyword_documents(
query: str,
filters: IndexFilters,
favor_recent: bool,
datastore: DocumentIndex,
num_hits: int = NUM_RETURNED_HITS,
edit_query: bool = EDIT_KEYWORD_QUERY,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
) -> list[InferenceChunk] | None:
edited_query = query_processing(query) if edit_query else query
top_chunks = datastore.keyword_retrieval(
edited_query, filters, favor_recent, num_hits
)
if not top_chunks:
logger.warning(
f"Keyword search returned no results - Filters: {filters}\tEdited Query: {edited_query}"
)
return None
if retrieval_metrics_callback is not None:
chunk_metrics = [
ChunkMetric(
document_id=chunk.document_id,
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
first_link=chunk.source_links[0] if chunk.source_links else None,
score=chunk.score if chunk.score is not None else 0,
)
for chunk in top_chunks
]
retrieval_metrics_callback(
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
)
return top_chunks
@log_function_time()
def retrieve_ranked_documents(
query: str,
filters: IndexFilters,
favor_recent: bool,
datastore: DocumentIndex,
num_hits: int = NUM_RETURNED_HITS,
num_rerank: int = NUM_RERANKED_RESULTS,
skip_rerank: bool = SKIP_RERANKING,
def search_chunks(
query: SearchQuery,
document_index: DocumentIndex,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
"""Uses vector similarity to fetch the top num_hits document chunks with a distance cutoff.
Reranks the top num_rerank out of those (instead of all due to latency)"""
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
top_links = [
c.source_links[0] if c.source_links is not None else "No Link"
for c in chunks
]
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
def _log_top_chunk_links(chunks: list[InferenceChunk]) -> None:
doc_links = [c.source_links[0] for c in chunks if c.source_links is not None]
if query.search_type == SearchType.KEYWORD:
top_chunks = document_index.keyword_retrieval(
query.query, query.filters, query.favor_recent, query.num_hits
)
files_log_msg = f"Top links from semantic search: {', '.join(doc_links)}"
logger.info(files_log_msg)
elif query.search_type == SearchType.SEMANTIC:
top_chunks = document_index.semantic_retrieval(
query.query, query.filters, query.favor_recent, query.num_hits
)
elif query.search_type == SearchType.HYBRID:
top_chunks = document_index.hybrid_retrieval(
query.query, query.filters, query.favor_recent, query.num_hits
)
else:
raise RuntimeError("Invalid Search Flow")
top_chunks = datastore.semantic_retrieval(query, filters, favor_recent, num_hits)
if not top_chunks:
logger.info(f"Semantic search returned no results with filters: {filters}")
logger.info(
f"{query.search_type.value.capitalize()} search returned no results "
f"with filters: {query.filters}"
)
return None, None
logger.debug(top_chunks)
if retrieval_metrics_callback is not None:
chunk_metrics = [
@@ -315,40 +301,24 @@ def retrieve_ranked_documents(
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
)
if skip_rerank:
# Keyword Search should never do reranking, no transformers involved in this flow
if query.search_type == SearchType.KEYWORD:
_log_top_chunk_links(query.search_type.value, top_chunks)
return top_chunks, None
if query.skip_rerank:
# Need the range of values to not be too spread out for applying boost
boosted_chunks = apply_boost(top_chunks[:num_rerank])
_log_top_chunk_links(boosted_chunks)
return boosted_chunks, top_chunks[num_rerank:]
# Therefore pass in smaller set of chunks to limit the range for norm-ing
boosted_chunks = apply_boost(top_chunks[: query.num_rerank])
_log_top_chunk_links(query.search_type.value, boosted_chunks)
return boosted_chunks, top_chunks[query.num_rerank :]
ranked_chunks = (
semantic_reranking(
query,
top_chunks[:num_rerank],
rerank_metrics_callback=rerank_metrics_callback,
)
if not skip_rerank
else []
ranked_chunks = semantic_reranking(
query.query,
top_chunks[: query.num_rerank],
rerank_metrics_callback=rerank_metrics_callback,
)
_log_top_chunk_links(ranked_chunks)
_log_top_chunk_links(query.search_type.value, ranked_chunks)
return ranked_chunks, top_chunks[num_rerank:]
def embed_query(
query: str,
embedding_model: SentenceTransformer | None = None,
prefix: str = ASYM_QUERY_PREFIX,
normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
) -> list[float]:
model = embedding_model or get_default_embedding_model()
prefixed_query = prefix + query
query_embedding = model.encode(
prefixed_query, normalize_embeddings=normalize_embeddings
)
if not isinstance(query_embedding, list):
query_embedding = query_embedding.tolist()
return query_embedding
return ranked_chunks, top_chunks[query.num_rerank :]
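embed_query, now hosted earlier in this file, just prefixes the query and encodes it with the default bi-encoder, normalizing if configured. A standalone equivalent with sentence-transformers; the model name and "query: " prefix are placeholders, since the real values come from model_configs:

    from sentence_transformers import SentenceTransformer  # type: ignore

    # Placeholder model; Danswer resolves the real one via get_default_embedding_model()
    model = SentenceTransformer("intfloat/e5-base-v2")
    prefixed_query = "query: " + "prep for hybrid search"  # "query: " stands in for ASYM_QUERY_PREFIX
    embedding = model.encode(prefixed_query, normalize_embeddings=True)
    query_embedding: list[float] = embedding.tolist()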

View File

@@ -27,6 +27,7 @@ from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import TaskStatus
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.search.models import BaseFilters
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.server.utils import mask_credential_dict
@@ -189,25 +190,14 @@ class CreateChatSessionID(BaseModel):
chat_session_id: int
class RequestFilters(BaseModel):
source_type: list[str] | None
document_set: list[str] | None
time_cutoff: datetime | None = None
class IndexFilters(RequestFilters):
access_control_list: list[str]
class QuestionRequest(BaseModel):
query: str
collection: str
filters: RequestFilters
filters: BaseFilters
offset: int | None
enable_auto_detect_filters: bool
favor_recent: bool | None = None
use_keyword: bool | None # TODO remove this for hybrid search
search_flow: SearchType | None = None # Default hybrid
search_type: SearchType = SearchType.HYBRID
class QAFeedbackRequest(BaseModel):

View File

@@ -27,20 +27,18 @@ from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.document_index import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex
from danswer.indexing.models import InferenceChunk
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.models import IndexFilters
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.search.models import SearchQuery
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import retrieve_keyword_documents
from danswer.search.search_runner import retrieve_ranked_documents
from danswer.search.search_runner import search_chunks
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
from danswer.server.models import HelperResponse
from danswer.server.models import IndexFilters
from danswer.server.models import QAFeedbackRequest
from danswer.server.models import QAResponse
from danswer.server.models import QueryValidationResponse
@@ -114,8 +112,7 @@ def get_search_type(
question: QuestionRequest, _: User = Depends(current_user)
) -> HelperResponse:
query = question.query
use_keyword = question.use_keyword if question.use_keyword is not None else False
return recommend_search_flow(query, use_keyword)
return recommend_search_flow(query)
@router.post("/query-validation")
@@ -137,14 +134,14 @@ def stream_query_validation(
)
@router.post("/keyword-search")
def keyword_search(
@router.post("/document-search")
def handle_search_request(
question: QuestionRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SearchResponse:
query = question.query
logger.info(f"Received keyword search query: {query}")
logger.info(f"Received {question.search_type.value} " f"search query: {query}")
time_cutoff, favor_recent = extract_question_time_filters(question)
question.filters.time_cutoff = time_cutoff
@@ -152,7 +149,7 @@ def keyword_search(
query_event_id = create_query_event(
query=query,
selected_flow=SearchType.KEYWORD,
search_type=question.search_type,
llm_answer=None,
user_id=user.id,
db_session=db_session,
@@ -166,12 +163,18 @@ def keyword_search(
time_cutoff=filters.time_cutoff,
access_control_list=user_acl_filters,
)
ranked_chunks = retrieve_keyword_documents(
search_query = SearchQuery(
query=query,
search_type=question.search_type,
filters=final_filters,
favor_recent=favor_recent,
datastore=get_default_document_index(),
)
ranked_chunks, unranked_chunks = search_chunks(
query=search_query, document_index=get_default_document_index()
)
if not ranked_chunks:
return SearchResponse(
top_ranked_docs=None,
@@ -182,6 +185,8 @@ def keyword_search(
)
top_docs = chunks_to_search_docs(ranked_chunks)
lower_top_docs = chunks_to_search_docs(unranked_chunks)
update_query_event_retrieved_documents(
db_session=db_session,
retrieved_document_ids=[doc.document_id for doc in top_docs],
@@ -191,132 +196,7 @@ def keyword_search(
return SearchResponse(
top_ranked_docs=top_docs,
lower_ranked_docs=None,
query_event_id=query_event_id,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
)
@router.post("/semantic-search")
def semantic_search(
question: QuestionRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SearchResponse:
query = question.query
logger.info(f"Received semantic search query: {query}")
time_cutoff, favor_recent = extract_question_time_filters(question)
question.filters.time_cutoff = time_cutoff
filters = question.filters
query_event_id = create_query_event(
query=query,
selected_flow=SearchType.SEMANTIC,
llm_answer=None,
user_id=user.id,
db_session=db_session,
)
user_id = None if user is None else user.id
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = IndexFilters(
source_type=filters.source_type,
document_set=filters.document_set,
time_cutoff=filters.time_cutoff,
access_control_list=user_acl_filters,
)
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query=query,
filters=final_filters,
favor_recent=favor_recent,
datastore=get_default_document_index(),
)
if not ranked_chunks:
return SearchResponse(
top_ranked_docs=None,
lower_ranked_docs=None,
query_event_id=query_event_id,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
)
top_docs = chunks_to_search_docs(ranked_chunks)
other_top_docs = chunks_to_search_docs(unranked_chunks)
update_query_event_retrieved_documents(
db_session=db_session,
retrieved_document_ids=[doc.document_id for doc in top_docs],
query_id=query_event_id,
user_id=user_id,
)
return SearchResponse(
top_ranked_docs=top_docs,
lower_ranked_docs=other_top_docs,
query_event_id=query_event_id,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
)
# TODO don't use this, not done yet
@router.post("/hybrid-search")
def hybrid_search(
question: QuestionRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SearchResponse:
query = question.query
logger.info(f"Received hybrid search query: {query}")
time_cutoff, favor_recent = extract_question_time_filters(question)
question.filters.time_cutoff = time_cutoff
filters = question.filters
query_event_id = create_query_event(
query=query,
selected_flow=SearchType.HYBRID,
llm_answer=None,
user_id=user.id,
db_session=db_session,
)
user_id = None if user is None else user.id
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = IndexFilters(
source_type=filters.source_type,
document_set=filters.document_set,
time_cutoff=filters.time_cutoff,
access_control_list=user_acl_filters,
)
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query=query,
filters=final_filters,
favor_recent=favor_recent,
datastore=get_default_document_index(),
)
if not ranked_chunks:
return SearchResponse(
top_ranked_docs=None,
lower_ranked_docs=None,
query_event_id=query_event_id,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
)
top_docs = chunks_to_search_docs(ranked_chunks)
other_top_docs = chunks_to_search_docs(unranked_chunks)
update_query_event_retrieved_documents(
db_session=db_session,
retrieved_document_ids=[doc.document_id for doc in top_docs],
query_id=query_event_id,
user_id=user_id,
)
return SearchResponse(
top_ranked_docs=top_docs,
lower_ranked_docs=other_top_docs,
lower_ranked_docs=lower_top_docs or None,
query_event_id=query_event_id,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
@@ -347,10 +227,10 @@ def stream_direct_qa(
predicted_search_key = "predicted_search"
query_event_id_key = "query_event_id"
logger.debug(f"Received QA query: {question.query}")
logger.debug(
f"Received QA query ({question.search_type.value} search): {question.query}"
)
logger.debug(f"Query filters: {question.filters}")
if question.use_keyword:
logger.debug("User selected Keyword Search")
@log_generator_function_time()
def stream_qa_portions(
@@ -358,40 +238,34 @@ def stream_direct_qa(
) -> Generator[str, None, None]:
answer_so_far: str = ""
query = question.query
use_keyword = question.use_keyword
offset_count = question.offset if question.offset is not None else 0
time_cutoff, favor_recent = extract_question_time_filters(question)
question.filters.time_cutoff = time_cutoff
question.filters.time_cutoff = time_cutoff # not used but just in case
filters = question.filters
predicted_search, predicted_flow = query_intent(query)
if use_keyword is None:
use_keyword = predicted_search == SearchType.KEYWORD
user_id = None if user is None else user.id
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = IndexFilters(
source_type=filters.source_type,
document_set=filters.document_set,
time_cutoff=filters.time_cutoff,
time_cutoff=time_cutoff,
access_control_list=user_acl_filters,
)
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
query=query,
filters=final_filters,
favor_recent=favor_recent,
datastore=get_default_document_index(),
)
unranked_chunks: list[InferenceChunk] | None = []
else:
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query=query,
filters=final_filters,
favor_recent=favor_recent,
datastore=get_default_document_index(),
)
search_query = SearchQuery(
query=query,
search_type=question.search_type,
filters=final_filters,
favor_recent=favor_recent,
)
ranked_chunks, unranked_chunks = search_chunks(
query=search_query, document_index=get_default_document_index()
)
# TODO retire this
predicted_search, predicted_flow = query_intent(query)
if not ranked_chunks:
logger.debug("No Documents Found")
empty_docs_result = {
@@ -473,17 +347,14 @@ def stream_direct_qa(
query_event_id = create_query_event(
query=query,
selected_flow=SearchType.KEYWORD
if question.use_keyword
else SearchType.SEMANTIC,
search_type=question.search_type,
llm_answer=answer_so_far,
retrieved_document_ids=[doc.document_id for doc in top_docs],
user_id=user_id,
user_id=None if user is None else user.id,
db_session=db_session,
)
yield get_json_line({query_event_id_key: query_event_id})
return
return StreamingResponse(stream_qa_portions(), media_type="application/json")
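The keyword-search, semantic-search, and hybrid-search endpoints collapse into a single /document-search route; clients pick the flow via search_type in the body. An illustrative request, with host, port, and router prefix depending on the deployment:

    import requests

    response = requests.post(
        "http://localhost:8080/document-search",  # route prefix/port are deployment-specific
        json={
            "query": "vespa hybrid ranking",
            "collection": "danswer_index",
            "filters": {"source_type": None, "document_set": None, "time_cutoff": None},
            "offset": 0,
            "enable_auto_detect_filters": True,
            "search_type": "hybrid",  # or "semantic" / "keyword"
        },
    )
    print(response.json())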

View File

@@ -14,27 +14,19 @@ if __name__ == "__main__":
previous_query = None
parser = argparse.ArgumentParser()
parser.add_argument(
"-f",
"--flow",
type=str,
default="QA",
help='"Search" or "QA", defaults to "QA"',
)
parser.add_argument(
"-t",
"--type",
type=str,
default="Semantic",
help='"Semantic" or "Keyword", defaults to "Semantic"',
default="hybrid",
help='"hybrid" "semantic" or "keyword", defaults to "hybrid"',
)
parser.add_argument(
"-s",
"--stream",
action="store_true",
help='Enable streaming response, only for flow="QA"',
help="Enable streaming response",
)
parser.add_argument(
@@ -50,7 +42,7 @@ if __name__ == "__main__":
try:
user_input = input(
"\n\nAsk any question:\n"
" - Use -f (QA/Search) and -t (Semantic/Keyword) flags to set endpoint.\n"
" - Use -t (hybrid/semantic/keyword) flag to choose search flow.\n"
" - prefix with -s to stream answer, --filters web,slack etc. for filters.\n"
" - input an empty string to rerun last query.\n\t"
)
@@ -66,25 +58,15 @@ if __name__ == "__main__":
args = parser.parse_args(user_input.split())
flow = str(args.flow).lower()
flow_type = str(args.type).lower()
search_type = str(args.type).lower()
stream = args.stream
source_types = args.filters.split(",") if args.filters else None
if source_types and len(source_types) == 1:
source_types = source_types[0]
query = " ".join(args.query)
if flow not in ["qa", "search"]:
raise ValueError("Flow value must be QA or Search")
if flow_type not in ["keyword", "semantic"]:
raise ValueError("Type value must be keyword or semantic")
if flow != "qa" and stream:
raise ValueError("Can only stream results for QA")
if search_type not in ["hybrid", "semantic", "keyword"]:
raise ValueError("Invalid Search")
if (flow, flow_type) == ("search", "keyword"):
path = "keyword-search"
elif (flow, flow_type) == ("search", "semantic"):
path = "semantic-search"
elif stream:
path = "stream-direct-qa"
else:
@@ -95,8 +77,9 @@ if __name__ == "__main__":
query_json = {
"query": query,
"collection": DOCUMENT_INDEX_NAME,
"use_keyword": flow_type == "keyword", # Ignore if not QA Endpoints
"filters": [{SOURCE_TYPE: source_types}],
"filters": {SOURCE_TYPE: source_types},
"enable_auto_detect_filters": True,
"search_type": search_type,
}
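For example, the input "-t keyword --filters web what is hybrid search" would now produce a payload along these lines:

    query_json = {
        "query": "what is hybrid search",
        "collection": DOCUMENT_INDEX_NAME,
        "filters": {"source_type": "web"},  # a single filter collapses from list to str above
        "enable_auto_detect_filters": True,
        "search_type": "keyword",
    }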
if args.stream:

View File

@@ -12,9 +12,9 @@ from danswer.access.access import get_acl_for_user
from danswer.db.engine import get_sqlalchemy_engine
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.search.models import IndexFilters
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.server.models import IndexFilters
from danswer.server.models import QuestionRequest
from danswer.utils.callbacks import MetricsHander
@@ -85,7 +85,6 @@ def get_answer_for_question(
question = QuestionRequest(
query=query,
collection="danswer_index",
use_keyword=False,
filters=filters,
enable_auto_detect_filters=False,
offset=None,