mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-04 09:58:32 +02:00
Danswer Helper QA Flow Backend (#90)
This commit is contained in:
parent
1facd58938
commit
f10ece4411
@ -7,6 +7,7 @@ from danswer.search.search_utils import get_default_intent_model
|
||||
from danswer.search.search_utils import get_default_intent_model_tokenizer
|
||||
from danswer.search.search_utils import get_default_tokenizer
|
||||
from danswer.server.models import HelperResponse
|
||||
from danswer.utils.timing import log_function_time
|
||||
from transformers import AutoTokenizer # type:ignore
|
||||
|
||||
|
||||
@ -17,6 +18,7 @@ def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
|
||||
return len([token for token in tokenized_text if token == tokenizer.unk_token])
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
|
||||
tokenizer = get_default_intent_model_tokenizer()
|
||||
intent_model = get_default_intent_model()
|
||||
@ -47,7 +49,7 @@ def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
|
||||
def recommend_search_flow(
|
||||
query: str,
|
||||
keyword: bool,
|
||||
max_percent_stopwords: float = 0.33, # Every third word max, ie "effects of caffeine" still viable keyword search
|
||||
max_percent_stopwords: float = 0.30, # ~Every third word max, ie "effects of caffeine" still viable keyword search
|
||||
) -> HelperResponse:
|
||||
heuristic_search_type: SearchType | None = None
|
||||
message: str | None = None
|
||||
@ -61,24 +63,21 @@ def recommend_search_flow(
|
||||
if count_unk_tokens(query, get_default_tokenizer()) > 0:
|
||||
if not keyword:
|
||||
heuristic_search_type = SearchType.KEYWORD
|
||||
message = (
|
||||
"Query contains words that the AI model cannot understand, "
|
||||
"Keyword Search may yield better results."
|
||||
)
|
||||
message = "Unknown tokens in query."
|
||||
|
||||
# Too many stop words, most likely a Semantic query (still may be valid QA)
|
||||
if non_stopword_percent < 1 - max_percent_stopwords:
|
||||
if keyword:
|
||||
heuristic_search_type = SearchType.SEMANTIC
|
||||
message = "Query contains stopwords, AI Search is likely more suitable."
|
||||
message = "Stopwords in query"
|
||||
|
||||
# Model based decisions
|
||||
model_search_type, flow = query_intent(query)
|
||||
if not message:
|
||||
if model_search_type == SearchType.SEMANTIC and keyword:
|
||||
message = "Query may yield better results with Semantic Search"
|
||||
message = "Intent model classified Semantic Search"
|
||||
if model_search_type == SearchType.KEYWORD and not keyword:
|
||||
message = "Query may yield better results with Keyword Search."
|
||||
message = "Intent model classified Keyword Search."
|
||||
|
||||
return HelperResponse(
|
||||
values={
|
||||
|
@ -10,6 +10,8 @@ from danswer.connectors.models import InputType
|
||||
from danswer.datastores.interfaces import IndexFilter
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.search.models import QueryFlow
|
||||
from danswer.search.models import SearchType
|
||||
from pydantic import BaseModel
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
@ -90,15 +92,14 @@ class QuestionRequest(BaseModel):
|
||||
class SearchResponse(BaseModel):
|
||||
# For semantic search, top docs are reranked, the remaining are as ordered from retrieval
|
||||
top_ranked_docs: list[SearchDoc] | None
|
||||
semi_ranked_docs: list[SearchDoc] | None
|
||||
lower_ranked_docs: list[SearchDoc] | None
|
||||
|
||||
|
||||
class QAResponse(BaseModel):
|
||||
class QAResponse(SearchResponse):
|
||||
answer: str | None
|
||||
quotes: dict[str, dict[str, str | int | None]] | None
|
||||
ranked_documents: list[SearchDoc] | None
|
||||
# for performance, only a few top documents are cross-encoded for rerank, the rest follow retrieval order
|
||||
unranked_documents: list[SearchDoc] | None
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
|
||||
|
||||
class UserByEmail(BaseModel):
|
||||
|
@ -10,8 +10,11 @@ from danswer.datastores.typesense.store import TypesenseIndex
|
||||
from danswer.db.models import User
|
||||
from danswer.direct_qa import get_default_backend_qa_model
|
||||
from danswer.direct_qa.question_answer import get_json_line
|
||||
from danswer.search.danswer_helper import query_intent
|
||||
from danswer.search.danswer_helper import recommend_search_flow
|
||||
from danswer.search.keyword_search import retrieve_keyword_documents
|
||||
from danswer.search.models import QueryFlow
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.semantic_search import chunks_to_search_docs
|
||||
from danswer.search.semantic_search import retrieve_ranked_documents
|
||||
from danswer.server.models import HelperResponse
|
||||
@ -51,12 +54,12 @@ def semantic_search(
|
||||
query, user_id, filters, QdrantIndex(collection)
|
||||
)
|
||||
if not ranked_chunks:
|
||||
return SearchResponse(top_ranked_docs=None, semi_ranked_docs=None)
|
||||
return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
|
||||
|
||||
top_docs = chunks_to_search_docs(ranked_chunks)
|
||||
other_top_docs = chunks_to_search_docs(unranked_chunks)
|
||||
|
||||
return SearchResponse(top_ranked_docs=top_docs, semi_ranked_docs=other_top_docs)
|
||||
return SearchResponse(top_ranked_docs=top_docs, lower_ranked_docs=other_top_docs)
|
||||
|
||||
|
||||
@router.post("/keyword-search")
|
||||
@ -73,10 +76,10 @@ def keyword_search(
|
||||
query, user_id, filters, TypesenseIndex(collection)
|
||||
)
|
||||
if not ranked_chunks:
|
||||
return SearchResponse(top_ranked_docs=None, semi_ranked_docs=None)
|
||||
return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
|
||||
|
||||
top_docs = chunks_to_search_docs(ranked_chunks)
|
||||
return SearchResponse(top_ranked_docs=top_docs, semi_ranked_docs=None)
|
||||
return SearchResponse(top_ranked_docs=top_docs, lower_ranked_docs=None)
|
||||
|
||||
|
||||
@router.post("/direct-qa")
|
||||
@ -92,6 +95,10 @@ def direct_qa(
|
||||
offset_count = question.offset if question.offset is not None else 0
|
||||
logger.info(f"Received QA query: {query}")
|
||||
|
||||
predicted_search, predicted_flow = query_intent(query)
|
||||
if use_keyword is None:
|
||||
use_keyword = predicted_search == SearchType.KEYWORD
|
||||
|
||||
user_id = None if user is None else int(user.id)
|
||||
if use_keyword:
|
||||
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
|
||||
@ -104,7 +111,12 @@ def direct_qa(
|
||||
)
|
||||
if not ranked_chunks:
|
||||
return QAResponse(
|
||||
answer=None, quotes=None, ranked_documents=None, unranked_documents=None
|
||||
answer=None,
|
||||
quotes=None,
|
||||
top_ranked_docs=None,
|
||||
lower_ranked_docs=None,
|
||||
predicted_flow=predicted_flow,
|
||||
predicted_search=predicted_search,
|
||||
)
|
||||
|
||||
qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
|
||||
@ -125,8 +137,10 @@ def direct_qa(
|
||||
return QAResponse(
|
||||
answer=answer,
|
||||
quotes=quotes,
|
||||
ranked_documents=chunks_to_search_docs(ranked_chunks),
|
||||
unranked_documents=chunks_to_search_docs(unranked_chunks),
|
||||
top_ranked_docs=chunks_to_search_docs(ranked_chunks),
|
||||
lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
|
||||
predicted_flow=predicted_flow,
|
||||
predicted_search=predicted_search,
|
||||
)
|
||||
|
||||
|
||||
@ -136,6 +150,8 @@ def stream_direct_qa(
|
||||
) -> StreamingResponse:
|
||||
top_documents_key = "top_documents"
|
||||
unranked_top_docs_key = "unranked_top_documents"
|
||||
predicted_flow_key = "predicted_flow"
|
||||
predicted_search_key = "predicted_search"
|
||||
|
||||
def stream_qa_portions() -> Generator[str, None, None]:
|
||||
query = question.query
|
||||
@ -145,6 +161,10 @@ def stream_direct_qa(
|
||||
offset_count = question.offset if question.offset is not None else 0
|
||||
logger.info(f"Received QA query: {query}")
|
||||
|
||||
predicted_search, predicted_flow = query_intent(query)
|
||||
if use_keyword is None:
|
||||
use_keyword = predicted_search == SearchType.KEYWORD
|
||||
|
||||
user_id = None if user is None else int(user.id)
|
||||
if use_keyword:
|
||||
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
|
||||
@ -156,16 +176,25 @@ def stream_direct_qa(
|
||||
query, user_id, filters, QdrantIndex(collection)
|
||||
)
|
||||
if not ranked_chunks:
|
||||
yield get_json_line({top_documents_key: None, unranked_top_docs_key: None})
|
||||
yield get_json_line(
|
||||
{
|
||||
top_documents_key: None,
|
||||
unranked_top_docs_key: None,
|
||||
predicted_flow_key: predicted_flow,
|
||||
predicted_search_key: predicted_search,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
top_docs = chunks_to_search_docs(ranked_chunks)
|
||||
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
|
||||
top_docs_dict = {
|
||||
initial_response_dict = {
|
||||
top_documents_key: [top_doc.json() for top_doc in top_docs],
|
||||
unranked_top_docs_key: [doc.json() for doc in unranked_top_docs],
|
||||
predicted_flow_key: predicted_flow,
|
||||
predicted_search_key: predicted_search,
|
||||
}
|
||||
yield get_json_line(top_docs_dict)
|
||||
yield get_json_line(initial_response_dict)
|
||||
|
||||
qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
|
||||
chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
|
||||
|
@ -2,7 +2,7 @@ import { DanswerDocument, SearchRequestArgs } from "./interfaces";
|
||||
|
||||
interface KeywordResponse {
|
||||
top_ranked_docs: DanswerDocument[];
|
||||
semi_ranked_docs: DanswerDocument[];
|
||||
lower_ranked_docs: DanswerDocument[];
|
||||
}
|
||||
|
||||
export const keywordSearch = async ({
|
||||
@ -37,8 +37,8 @@ export const keywordSearch = async ({
|
||||
const keywordResults = (await response.json()) as KeywordResponse;
|
||||
|
||||
let matchingDocs = keywordResults.top_ranked_docs;
|
||||
if (keywordResults.semi_ranked_docs) {
|
||||
matchingDocs = matchingDocs.concat(keywordResults.semi_ranked_docs);
|
||||
if (keywordResults.lower_ranked_docs) {
|
||||
matchingDocs = matchingDocs.concat(keywordResults.lower_ranked_docs);
|
||||
}
|
||||
|
||||
updateDocs(matchingDocs);
|
||||
|
Loading…
x
Reference in New Issue
Block a user