Danswer Helper QA Flow Backend (#90)

This commit is contained in:
Yuhong Sun 2023-06-09 17:48:17 -07:00 committed by GitHub
parent 1facd58938
commit f10ece4411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 26 deletions

View File

@ -7,6 +7,7 @@ from danswer.search.search_utils import get_default_intent_model
from danswer.search.search_utils import get_default_intent_model_tokenizer
from danswer.search.search_utils import get_default_tokenizer
from danswer.server.models import HelperResponse
from danswer.utils.timing import log_function_time
from transformers import AutoTokenizer # type:ignore
@ -17,6 +18,7 @@ def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
return len([token for token in tokenized_text if token == tokenizer.unk_token])
@log_function_time()
def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
tokenizer = get_default_intent_model_tokenizer()
intent_model = get_default_intent_model()
@ -47,7 +49,7 @@ def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
def recommend_search_flow(
query: str,
keyword: bool,
max_percent_stopwords: float = 0.33, # Every third word max, ie "effects of caffeine" still viable keyword search
max_percent_stopwords: float = 0.30, # ~Every third word max, ie "effects of caffeine" still viable keyword search
) -> HelperResponse:
heuristic_search_type: SearchType | None = None
message: str | None = None
@ -61,24 +63,21 @@ def recommend_search_flow(
if count_unk_tokens(query, get_default_tokenizer()) > 0:
if not keyword:
heuristic_search_type = SearchType.KEYWORD
message = (
"Query contains words that the AI model cannot understand, "
"Keyword Search may yield better results."
)
message = "Unknown tokens in query."
# Too many stop words, most likely a Semantic query (still may be valid QA)
if non_stopword_percent < 1 - max_percent_stopwords:
if keyword:
heuristic_search_type = SearchType.SEMANTIC
message = "Query contains stopwords, AI Search is likely more suitable."
message = "Stopwords in query"
# Model based decisions
model_search_type, flow = query_intent(query)
if not message:
if model_search_type == SearchType.SEMANTIC and keyword:
message = "Query may yield better results with Semantic Search"
message = "Intent model classified Semantic Search"
if model_search_type == SearchType.KEYWORD and not keyword:
message = "Query may yield better results with Keyword Search."
message = "Intent model classified Keyword Search."
return HelperResponse(
values={

View File

@ -10,6 +10,8 @@ from danswer.connectors.models import InputType
from danswer.datastores.interfaces import IndexFilter
from danswer.db.models import Connector
from danswer.db.models import IndexingStatus
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from pydantic import BaseModel
from pydantic.generics import GenericModel
@ -90,15 +92,14 @@ class QuestionRequest(BaseModel):
class SearchResponse(BaseModel):
# For semantic search, top docs are reranked, the remaining are as ordered from retrieval
top_ranked_docs: list[SearchDoc] | None
semi_ranked_docs: list[SearchDoc] | None
lower_ranked_docs: list[SearchDoc] | None
class QAResponse(BaseModel):
class QAResponse(SearchResponse):
answer: str | None
quotes: dict[str, dict[str, str | int | None]] | None
ranked_documents: list[SearchDoc] | None
# for performance, only a few top documents are cross-encoded for rerank, the rest follow retrieval order
unranked_documents: list[SearchDoc] | None
predicted_flow: QueryFlow
predicted_search: SearchType
class UserByEmail(BaseModel):

View File

@ -10,8 +10,11 @@ from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.question_answer import get_json_line
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.server.models import HelperResponse
@ -51,12 +54,12 @@ def semantic_search(
query, user_id, filters, QdrantIndex(collection)
)
if not ranked_chunks:
return SearchResponse(top_ranked_docs=None, semi_ranked_docs=None)
return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
top_docs = chunks_to_search_docs(ranked_chunks)
other_top_docs = chunks_to_search_docs(unranked_chunks)
return SearchResponse(top_ranked_docs=top_docs, semi_ranked_docs=other_top_docs)
return SearchResponse(top_ranked_docs=top_docs, lower_ranked_docs=other_top_docs)
@router.post("/keyword-search")
@ -73,10 +76,10 @@ def keyword_search(
query, user_id, filters, TypesenseIndex(collection)
)
if not ranked_chunks:
return SearchResponse(top_ranked_docs=None, semi_ranked_docs=None)
return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
top_docs = chunks_to_search_docs(ranked_chunks)
return SearchResponse(top_ranked_docs=top_docs, semi_ranked_docs=None)
return SearchResponse(top_ranked_docs=top_docs, lower_ranked_docs=None)
@router.post("/direct-qa")
@ -92,6 +95,10 @@ def direct_qa(
offset_count = question.offset if question.offset is not None else 0
logger.info(f"Received QA query: {query}")
predicted_search, predicted_flow = query_intent(query)
if use_keyword is None:
use_keyword = predicted_search == SearchType.KEYWORD
user_id = None if user is None else int(user.id)
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
@ -104,7 +111,12 @@ def direct_qa(
)
if not ranked_chunks:
return QAResponse(
answer=None, quotes=None, ranked_documents=None, unranked_documents=None
answer=None,
quotes=None,
top_ranked_docs=None,
lower_ranked_docs=None,
predicted_flow=predicted_flow,
predicted_search=predicted_search,
)
qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
@ -125,8 +137,10 @@ def direct_qa(
return QAResponse(
answer=answer,
quotes=quotes,
ranked_documents=chunks_to_search_docs(ranked_chunks),
unranked_documents=chunks_to_search_docs(unranked_chunks),
top_ranked_docs=chunks_to_search_docs(ranked_chunks),
lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
predicted_flow=predicted_flow,
predicted_search=predicted_search,
)
@ -136,6 +150,8 @@ def stream_direct_qa(
) -> StreamingResponse:
top_documents_key = "top_documents"
unranked_top_docs_key = "unranked_top_documents"
predicted_flow_key = "predicted_flow"
predicted_search_key = "predicted_search"
def stream_qa_portions() -> Generator[str, None, None]:
query = question.query
@ -145,6 +161,10 @@ def stream_direct_qa(
offset_count = question.offset if question.offset is not None else 0
logger.info(f"Received QA query: {query}")
predicted_search, predicted_flow = query_intent(query)
if use_keyword is None:
use_keyword = predicted_search == SearchType.KEYWORD
user_id = None if user is None else int(user.id)
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
@ -156,16 +176,25 @@ def stream_direct_qa(
query, user_id, filters, QdrantIndex(collection)
)
if not ranked_chunks:
yield get_json_line({top_documents_key: None, unranked_top_docs_key: None})
yield get_json_line(
{
top_documents_key: None,
unranked_top_docs_key: None,
predicted_flow_key: predicted_flow,
predicted_search_key: predicted_search,
}
)
return
top_docs = chunks_to_search_docs(ranked_chunks)
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
top_docs_dict = {
initial_response_dict = {
top_documents_key: [top_doc.json() for top_doc in top_docs],
unranked_top_docs_key: [doc.json() for doc in unranked_top_docs],
predicted_flow_key: predicted_flow,
predicted_search_key: predicted_search,
}
yield get_json_line(top_docs_dict)
yield get_json_line(initial_response_dict)
qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS

View File

@ -2,7 +2,7 @@ import { DanswerDocument, SearchRequestArgs } from "./interfaces";
interface KeywordResponse {
top_ranked_docs: DanswerDocument[];
semi_ranked_docs: DanswerDocument[];
lower_ranked_docs: DanswerDocument[];
}
export const keywordSearch = async ({
@ -37,8 +37,8 @@ export const keywordSearch = async ({
const keywordResults = (await response.json()) as KeywordResponse;
let matchingDocs = keywordResults.top_ranked_docs;
if (keywordResults.semi_ranked_docs) {
matchingDocs = matchingDocs.concat(keywordResults.semi_ranked_docs);
if (keywordResults.lower_ranked_docs) {
matchingDocs = matchingDocs.concat(keywordResults.lower_ranked_docs);
}
updateDocs(matchingDocs);