Add Chunk Context options for EE APIs (#68)

Author: Yuhong Sun
Date: 2024-04-14 19:06:22 -07:00
Committed by: Chris Weaver
Parent: 680aca68e5
Commit: c00bd44bcc

3 changed files with 25 additions and 9 deletions

View File

@@ -90,6 +90,9 @@ def handle_simplified_chat_message(
         search_doc_ids=chat_message_req.search_doc_ids,
         retrieval_options=retrieval_options,
         query_override=chat_message_req.query_override,
+        chunks_above=chat_message_req.chunks_above,
+        chunks_below=chat_message_req.chunks_below,
+        full_doc=chat_message_req.full_doc,
     )

     packets = stream_chat_message_objects(
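
The three new fields are read off the incoming request and forwarded into the chat message creation call. A hypothetical client call against this simplified chat API is sketched below; the endpoint path, host, and the "message" field are assumptions for illustration, only chunks_above, chunks_below, and full_doc come from this diff.

import requests

payload = {
    "message": "What is our refund policy?",  # assumed existing request field
    "chunks_above": 1,  # one extra chunk of context above each retrieved chunk
    "chunks_below": 1,  # and one below
    "full_doc": False,  # set True to return whole documents instead of chunks
}
resp = requests.post(
    "http://localhost:8080/chat/send-message-simple-api",  # assumed host and path
    json=payload,
)
print(resp.json())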

View File

@@ -1,10 +1,22 @@
 from pydantic import BaseModel

 from danswer.configs.constants import DocumentSource
+from danswer.search.enums import SearchType
+from danswer.search.models import ChunkContext
 from danswer.search.models import RetrievalDetails


-class BasicCreateChatMessageRequest(BaseModel):
+class DocumentSearchRequest(ChunkContext):
+    message: str
+    search_type: SearchType
+    retrieval_options: RetrievalDetails
+    recency_bias_multiplier: float = 1.0
+    # This is to forcibly skip (or run) the step, if None it uses the system defaults
+    skip_rerank: bool | None = None
+    skip_llm_chunk_filter: bool | None = None
+
+
+class BasicCreateChatMessageRequest(ChunkContext):
     """Before creating messages, be sure to create a chat_session and get an id
     Note, for simplicity this option only allows for a single linear chain of messages
     """

View File

@@ -3,7 +3,6 @@ from fastapi import Depends
 from sqlalchemy.orm import Session

 from danswer.auth.users import current_user
-from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
 from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
 from danswer.db.chat import get_persona_by_id
 from danswer.db.engine import get_session
@@ -20,9 +19,9 @@ from danswer.search.models import SavedSearchDoc
 from danswer.search.models import SearchRequest
 from danswer.search.models import SearchResponse
 from danswer.search.pipeline import SearchPipeline
-from danswer.search.utils import chunks_to_search_docs
-from danswer.server.query_and_chat.models import DocumentSearchRequest
+from danswer.search.utils import chunks_or_sections_to_search_docs
 from danswer.utils.logger import setup_logger
+from ee.danswer.server.query_and_chat.models import DocumentSearchRequest


 logger = setup_logger()
@@ -34,8 +33,6 @@ def handle_search_request(
     search_request: DocumentSearchRequest,
     user: User | None = Depends(current_user),
     db_session: Session = Depends(get_session),
-    # Default to running LLM filter unless globally disabled
-    disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
 ) -> SearchResponse:
     """Simple search endpoint, does not create a new message or records in the DB"""
     query = search_request.message
@@ -52,14 +49,18 @@ def handle_search_request(
             limit=search_request.retrieval_options.limit,
             skip_rerank=search_request.skip_rerank,
             skip_llm_chunk_filter=search_request.skip_llm_chunk_filter,
+            chunks_above=search_request.chunks_above,
+            chunks_below=search_request.chunks_below,
+            full_doc=search_request.full_doc,
         ),
         user=user,
         db_session=db_session,
         bypass_acl=False,
     )
-    top_chunks = search_pipeline.reranked_docs
-    relevant_chunk_indices = search_pipeline.relevant_chunk_indicies
-    top_docs = chunks_to_search_docs(top_chunks)
+    top_sections = search_pipeline.reranked_sections
+    # If using surrounding context or full doc, this will be empty
+    relevant_chunk_indices = search_pipeline.relevant_chunk_indices
+    top_docs = chunks_or_sections_to_search_docs(top_sections)

     # No need to save the docs for this API
     fake_saved_docs = [SavedSearchDoc.from_search_doc(doc) for doc in top_docs]
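
An illustrative construction of the relocated DocumentSearchRequest with the new chunk context options, e.g. in a test; the field names come from this diff, while the concrete values, the SearchType member, and RetrievalDetails being constructible with defaults are assumptions. As the new comment notes, relevant_chunk_indices will be empty when surrounding context or full_doc is used.

from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDetails
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest

# Ask for two chunks of surrounding context on each side of every hit
req = DocumentSearchRequest(
    message="quarterly revenue targets",
    search_type=SearchType.HYBRID,         # assumed enum member
    retrieval_options=RetrievalDetails(),  # assumed to be constructible with defaults
    chunks_above=2,
    chunks_below=2,
    full_doc=False,  # set True to retrieve whole documents instead
)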