diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py
index d895df098b..ada16ef06a 100644
--- a/backend/danswer/search/models.py
+++ b/backend/danswer/search/models.py
@@ -249,6 +249,13 @@ class SavedSearchDoc(SearchDoc):
         return self.score < other.score
 
 
+class SavedSearchDocWithContent(SavedSearchDoc):
+    """Used for endpoints that need to return the actual contents of the retrieved
+    section in addition to the match_highlights."""
+
+    content: str
+
+
 class RetrievalDocs(BaseModel):
     top_documents: list[SavedSearchDoc]
 
diff --git a/backend/danswer/search/utils.py b/backend/danswer/search/utils.py
index a0b41f9fbe..8b138d2e9b 100644
--- a/backend/danswer/search/utils.py
+++ b/backend/danswer/search/utils.py
@@ -5,10 +5,18 @@ from danswer.db.models import SearchDoc as DBSearchDoc
 from danswer.search.models import InferenceChunk
 from danswer.search.models import InferenceSection
 from danswer.search.models import SavedSearchDoc
+from danswer.search.models import SavedSearchDocWithContent
 from danswer.search.models import SearchDoc
 
 
-T = TypeVar("T", InferenceSection, InferenceChunk, SearchDoc)
+T = TypeVar(
+    "T",
+    InferenceSection,
+    InferenceChunk,
+    SearchDoc,
+    SavedSearchDoc,
+    SavedSearchDocWithContent,
+)
 
 
 def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
diff --git a/backend/ee/danswer/server/query_and_chat/query_backend.py b/backend/ee/danswer/server/query_and_chat/query_backend.py
index b0a9732374..f6cf7297a4 100644
--- a/backend/ee/danswer/server/query_and_chat/query_backend.py
+++ b/backend/ee/danswer/server/query_and_chat/query_backend.py
@@ -1,5 +1,6 @@
 from fastapi import APIRouter
 from fastapi import Depends
+from pydantic import BaseModel
 from sqlalchemy.orm import Session
 
 from danswer.auth.users import current_user
@@ -17,11 +18,9 @@ from danswer.llm.utils import get_max_input_tokens
 from danswer.one_shot_answer.answer_question import get_search_answer
 from danswer.one_shot_answer.models import DirectQARequest
 from danswer.one_shot_answer.models import OneShotQAResponse
-from danswer.search.models import SavedSearchDoc
+from danswer.search.models import SavedSearchDocWithContent
 from danswer.search.models import SearchRequest
-from danswer.search.models import SearchResponse
 from danswer.search.pipeline import SearchPipeline
-from danswer.search.utils import chunks_or_sections_to_search_docs
 from danswer.search.utils import dedupe_documents
 from danswer.search.utils import drop_llm_indices
 from danswer.utils.logger import setup_logger
@@ -32,12 +31,17 @@ logger = setup_logger()
 basic_router = APIRouter(prefix="/query")
 
 
+class DocumentSearchResponse(BaseModel):
+    top_documents: list[SavedSearchDocWithContent]
+    llm_indices: list[int]
+
+
 @basic_router.post("/document-search")
 def handle_search_request(
     search_request: DocumentSearchRequest,
     user: User | None = Depends(current_user),
     db_session: Session = Depends(get_session),
-) -> SearchResponse:
+) -> DocumentSearchResponse:
     """Simple search endpoint, does not create a new message or records in the DB"""
     query = search_request.message
     logger.info(f"Received document search query: {query}")
@@ -67,7 +71,30 @@ def handle_search_request(
     top_sections = search_pipeline.reranked_sections
     # If using surrounding context or full doc, this will be empty
     relevant_section_indices = search_pipeline.relevant_section_indices
-    top_docs = chunks_or_sections_to_search_docs(top_sections)
+    top_docs = [
+        SavedSearchDocWithContent(
+            document_id=section.center_chunk.document_id,
+            chunk_ind=section.center_chunk.chunk_id,
+            content=section.center_chunk.content,
+            semantic_identifier=section.center_chunk.semantic_identifier or "Unknown",
+            link=section.center_chunk.source_links.get(0)
+            if section.center_chunk.source_links
+            else None,
+            blurb=section.center_chunk.blurb,
+            source_type=section.center_chunk.source_type,
+            boost=section.center_chunk.boost,
+            hidden=section.center_chunk.hidden,
+            metadata=section.center_chunk.metadata,
+            score=section.center_chunk.score or 0.0,
+            match_highlights=section.center_chunk.match_highlights,
+            updated_at=section.center_chunk.updated_at,
+            primary_owners=section.center_chunk.primary_owners,
+            secondary_owners=section.center_chunk.secondary_owners,
+            is_internet=False,
+            db_doc_id=0,
+        )
+        for section in top_sections
+    ]
 
     # Deduping happens at the last step to avoid harming quality by dropping content early on
     deduped_docs = top_docs
@@ -75,18 +102,15 @@ if search_request.retrieval_options.dedupe_docs:
     if search_request.retrieval_options.dedupe_docs:
         deduped_docs, dropped_inds = dedupe_documents(top_docs)
 
-    # No need to save the docs for this API
-    fake_saved_docs = [SavedSearchDoc.from_search_doc(doc) for doc in deduped_docs]
-
     if dropped_inds:
         relevant_section_indices = drop_llm_indices(
            llm_indices=relevant_section_indices,
-            search_docs=fake_saved_docs,
+            search_docs=deduped_docs,
             dropped_indices=dropped_inds,
         )
 
-    return SearchResponse(
-        top_documents=fake_saved_docs, llm_indices=relevant_section_indices
+    return DocumentSearchResponse(
+        top_documents=deduped_docs, llm_indices=relevant_section_indices
     )
 
 
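Reviewer note: a minimal sketch of how a client might consume the new response shape. Only `message`, `retrieval_options.dedupe_docs`, `top_documents[*].content`, and `llm_indices` come from this diff; the host/port, auth scheme, and sample query below are assumptions, not part of the change.

```python
# Hypothetical client for the updated /query/document-search endpoint.
# ASSUMED: deployment URL and bearer-token auth; the only request and
# response fields taken from the diff are noted inline.
import requests

resp = requests.post(
    "http://localhost:8080/query/document-search",  # assumed deployment URL
    headers={"Authorization": "Bearer <api-key>"},  # assumed auth scheme
    json={
        "message": "How do I rotate my API key?",       # from the diff: search_request.message
        "retrieval_options": {"dedupe_docs": True},     # from the diff: triggers dedupe_documents
    },
)
resp.raise_for_status()
data = resp.json()  # DocumentSearchResponse: top_documents + llm_indices

# drop_llm_indices re-maps llm_indices after dedupe, so they remain
# valid offsets into the deduped top_documents list.
for i in data["llm_indices"]:
    doc = data["top_documents"][i]
    print(doc["semantic_identifier"], doc["score"])
    print(doc["content"][:200])  # "content" is the new SavedSearchDocWithContent field
```

Returning the section text directly, instead of round-tripping through the fake `SavedSearchDoc`s the old code built, is what lets callers skip a second fetch for document contents.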