Danswer Helper QA Flow Backend (#90)
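Summary: query-intent predictions are threaded through the QA flow. query_intent is now timed with @log_function_time() and its output is surfaced as predicted_search / predicted_flow on both the direct QA and streaming QA endpoints; QAResponse now extends SearchResponse; the helper messages returned by recommend_search_flow are shortened; and semi_ranked_docs is renamed to lower_ranked_docs across the backend responses and the web client.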
@@ -7,6 +7,7 @@ from danswer.search.search_utils import get_default_intent_model
 from danswer.search.search_utils import get_default_intent_model_tokenizer
 from danswer.search.search_utils import get_default_tokenizer
 from danswer.server.models import HelperResponse
+from danswer.utils.timing import log_function_time
 from transformers import AutoTokenizer  # type:ignore
 
 
@@ -17,6 +18,7 @@ def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
     return len([token for token in tokenized_text if token == tokenizer.unk_token])
 
 
+@log_function_time()
 def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
     tokenizer = get_default_intent_model_tokenizer()
     intent_model = get_default_intent_model()
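The @log_function_time() decorator comes from danswer.utils.timing, whose implementation is not part of this diff. A minimal sketch of what such a decorator factory plausibly looks like (the real one presumably writes to the project logger rather than stdout):

    import time
    from collections.abc import Callable
    from functools import wraps
    from typing import Any

    def log_function_time() -> Callable[[Callable], Callable]:
        # Hypothetical stand-in for danswer.utils.timing.log_function_time
        def decorator(func: Callable) -> Callable:
            @wraps(func)
            def wrapped(*args: Any, **kwargs: Any) -> Any:
                start = time.monotonic()
                result = func(*args, **kwargs)
                print(f"{func.__name__} took {time.monotonic() - start:.3f}s")
                return result
            return wrapped
        return decorator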
@@ -47,7 +49,7 @@ def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
 def recommend_search_flow(
     query: str,
     keyword: bool,
-    max_percent_stopwords: float = 0.33,  # Every third word max, ie "effects of caffeine" still viable keyword search
+    max_percent_stopwords: float = 0.30,  # ~Every third word max, ie "effects of caffeine" still viable keyword search
 ) -> HelperResponse:
     heuristic_search_type: SearchType | None = None
     message: str | None = None
@@ -61,24 +63,21 @@ def recommend_search_flow(
     if count_unk_tokens(query, get_default_tokenizer()) > 0:
         if not keyword:
             heuristic_search_type = SearchType.KEYWORD
-            message = (
-                "Query contains words that the AI model cannot understand, "
-                "Keyword Search may yield better results."
-            )
+            message = "Unknown tokens in query."
 
     # Too many stop words, most likely a Semantic query (still may be valid QA)
     if non_stopword_percent < 1 - max_percent_stopwords:
         if keyword:
             heuristic_search_type = SearchType.SEMANTIC
-            message = "Query contains stopwords, AI Search is likely more suitable."
+            message = "Stopwords in query"
 
     # Model based decisions
     model_search_type, flow = query_intent(query)
     if not message:
         if model_search_type == SearchType.SEMANTIC and keyword:
-            message = "Query may yield better results with Semantic Search"
+            message = "Intent model classified Semantic Search"
         if model_search_type == SearchType.KEYWORD and not keyword:
-            message = "Query may yield better results with Keyword Search."
+            message = "Intent model classified Keyword Search."
 
     return HelperResponse(
         values={
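For intuition on the stopword heuristic above, a rough sketch assuming whitespace tokenization and a generic stopword list (the actual computation of non_stopword_percent is outside this diff):

    # Hypothetical recreation of the heuristic; the real stopword list and
    # tokenization used by recommend_search_flow are not shown here.
    STOPWORDS = {"how", "to", "my", "of", "the", "a", "an", "is", "what"}

    def non_stopword_percent(query: str) -> float:
        words = query.lower().split()
        return sum(word not in STOPWORDS for word in words) / len(words)

    max_percent_stopwords = 0.30
    query = "how to fix my docker container"
    # 3 of 6 words are stopwords, so non_stopword_percent = 0.5 < 1 - 0.30 = 0.7
    # and a keyword-search request gets nudged toward Semantic Search.
    print(non_stopword_percent(query) < 1 - max_percent_stopwords)  # True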
@@ -10,6 +10,8 @@ from danswer.connectors.models import InputType
 from danswer.datastores.interfaces import IndexFilter
 from danswer.db.models import Connector
 from danswer.db.models import IndexingStatus
+from danswer.search.models import QueryFlow
+from danswer.search.models import SearchType
 from pydantic import BaseModel
 from pydantic.generics import GenericModel
 
@@ -90,15 +92,14 @@ class QuestionRequest(BaseModel):
 class SearchResponse(BaseModel):
     # For semantic search, top docs are reranked, the remaining are as ordered from retrieval
     top_ranked_docs: list[SearchDoc] | None
-    semi_ranked_docs: list[SearchDoc] | None
+    lower_ranked_docs: list[SearchDoc] | None
 
 
-class QAResponse(BaseModel):
+class QAResponse(SearchResponse):
     answer: str | None
     quotes: dict[str, dict[str, str | int | None]] | None
-    ranked_documents: list[SearchDoc] | None
-    # for performance, only a few top documents are cross-encoded for rerank, the rest follow retrieval order
-    unranked_documents: list[SearchDoc] | None
+    predicted_flow: QueryFlow
+    predicted_search: SearchType
 
 
 class UserByEmail(BaseModel):
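Since QAResponse now subclasses SearchResponse, QA payloads carry the document fields under the same names as plain search responses. A condensed, self-contained sketch of the resulting shape (the SearchDoc stand-in and str-typed fields are simplifications; the real code uses the QueryFlow and SearchType enums):

    from pydantic import BaseModel

    class SearchDoc(BaseModel):  # stand-in; the real SearchDoc lives elsewhere
        document_id: str

    class SearchResponse(BaseModel):
        top_ranked_docs: list[SearchDoc] | None
        lower_ranked_docs: list[SearchDoc] | None

    class QAResponse(SearchResponse):  # inherits both doc fields above
        answer: str | None
        quotes: dict[str, dict[str, str | int | None]] | None
        predicted_flow: str    # QueryFlow enum in the real models
        predicted_search: str  # SearchType enum in the real models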
@@ -10,8 +10,11 @@ from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
 from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.question_answer import get_json_line
+from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
+from danswer.search.models import QueryFlow
+from danswer.search.models import SearchType
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
 from danswer.server.models import HelperResponse
@@ -51,12 +54,12 @@ def semantic_search(
         query, user_id, filters, QdrantIndex(collection)
     )
     if not ranked_chunks:
-        return SearchResponse(top_ranked_docs=None, semi_ranked_docs=None)
+        return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
 
     top_docs = chunks_to_search_docs(ranked_chunks)
     other_top_docs = chunks_to_search_docs(unranked_chunks)
 
-    return SearchResponse(top_ranked_docs=top_docs, semi_ranked_docs=other_top_docs)
+    return SearchResponse(top_ranked_docs=top_docs, lower_ranked_docs=other_top_docs)
 
 
 @router.post("/keyword-search")
@@ -73,10 +76,10 @@ def keyword_search(
         query, user_id, filters, TypesenseIndex(collection)
     )
     if not ranked_chunks:
-        return SearchResponse(top_ranked_docs=None, semi_ranked_docs=None)
+        return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
 
     top_docs = chunks_to_search_docs(ranked_chunks)
-    return SearchResponse(top_ranked_docs=top_docs, semi_ranked_docs=None)
+    return SearchResponse(top_ranked_docs=top_docs, lower_ranked_docs=None)
 
 
 @router.post("/direct-qa")
@@ -92,6 +95,10 @@ def direct_qa(
     offset_count = question.offset if question.offset is not None else 0
     logger.info(f"Received QA query: {query}")
 
+    predicted_search, predicted_flow = query_intent(query)
+    if use_keyword is None:
+        use_keyword = predicted_search == SearchType.KEYWORD
+
     user_id = None if user is None else int(user.id)
     if use_keyword:
         ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
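With the fallback above, clients may omit use_keyword entirely and let the intent model pick the retrieval path. An illustrative call, with the URL, port, and request fields inferred from the handler rather than taken from this diff:

    # Hypothetical client call; endpoint host/port and the exact
    # QuestionRequest field names are assumptions.
    import requests

    response = requests.post(
        "http://localhost:8080/direct-qa",
        json={"query": "What is Danswer?", "collection": "danswer_index", "filters": None},
    )
    body = response.json()
    print(body["predicted_search"], body["predicted_flow"])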
@@ -104,7 +111,12 @@ def direct_qa(
     )
     if not ranked_chunks:
         return QAResponse(
-            answer=None, quotes=None, ranked_documents=None, unranked_documents=None
+            answer=None,
+            quotes=None,
+            top_ranked_docs=None,
+            lower_ranked_docs=None,
+            predicted_flow=predicted_flow,
+            predicted_search=predicted_search,
         )
 
     qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
@@ -125,8 +137,10 @@ def direct_qa(
     return QAResponse(
         answer=answer,
         quotes=quotes,
-        ranked_documents=chunks_to_search_docs(ranked_chunks),
-        unranked_documents=chunks_to_search_docs(unranked_chunks),
+        top_ranked_docs=chunks_to_search_docs(ranked_chunks),
+        lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
+        predicted_flow=predicted_flow,
+        predicted_search=predicted_search,
     )
 
 
@@ -136,6 +150,8 @@ def stream_direct_qa(
 ) -> StreamingResponse:
     top_documents_key = "top_documents"
     unranked_top_docs_key = "unranked_top_documents"
+    predicted_flow_key = "predicted_flow"
+    predicted_search_key = "predicted_search"
 
     def stream_qa_portions() -> Generator[str, None, None]:
         query = question.query
@@ -145,6 +161,10 @@ def stream_direct_qa(
         offset_count = question.offset if question.offset is not None else 0
         logger.info(f"Received QA query: {query}")
 
+        predicted_search, predicted_flow = query_intent(query)
+        if use_keyword is None:
+            use_keyword = predicted_search == SearchType.KEYWORD
+
         user_id = None if user is None else int(user.id)
         if use_keyword:
             ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
@@ -156,16 +176,25 @@ def stream_direct_qa(
             query, user_id, filters, QdrantIndex(collection)
         )
         if not ranked_chunks:
-            yield get_json_line({top_documents_key: None, unranked_top_docs_key: None})
+            yield get_json_line(
+                {
+                    top_documents_key: None,
+                    unranked_top_docs_key: None,
+                    predicted_flow_key: predicted_flow,
+                    predicted_search_key: predicted_search,
+                }
+            )
             return
 
         top_docs = chunks_to_search_docs(ranked_chunks)
         unranked_top_docs = chunks_to_search_docs(unranked_chunks)
-        top_docs_dict = {
+        initial_response_dict = {
             top_documents_key: [top_doc.json() for top_doc in top_docs],
             unranked_top_docs_key: [doc.json() for doc in unranked_top_docs],
+            predicted_flow_key: predicted_flow,
+            predicted_search_key: predicted_search,
         }
-        yield get_json_line(top_docs_dict)
+        yield get_json_line(initial_response_dict)
 
         qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
         chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
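Because the streamed response is newline-delimited JSON and the first packet now includes the intent predictions, a consumer can branch on them before any answer tokens arrive. A sketch, with the endpoint path and request body assumed as in the previous example:

    # Hypothetical streaming consumer; path and payload are assumptions.
    import json
    import requests

    with requests.post(
        "http://localhost:8080/stream-direct-qa",
        json={"query": "What is Danswer?", "collection": "danswer_index", "filters": None},
        stream=True,
    ) as resp:
        for line in resp.iter_lines():
            if not line:
                continue
            packet = json.loads(line)
            if "predicted_search" in packet:
                print("intent:", packet["predicted_search"], packet["predicted_flow"])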
@@ -2,7 +2,7 @@ import { DanswerDocument, SearchRequestArgs } from "./interfaces";
 
 interface KeywordResponse {
   top_ranked_docs: DanswerDocument[];
-  semi_ranked_docs: DanswerDocument[];
+  lower_ranked_docs: DanswerDocument[];
 }
 
 export const keywordSearch = async ({
@@ -37,8 +37,8 @@ export const keywordSearch = async ({
   const keywordResults = (await response.json()) as KeywordResponse;
 
   let matchingDocs = keywordResults.top_ranked_docs;
-  if (keywordResults.semi_ranked_docs) {
-    matchingDocs = matchingDocs.concat(keywordResults.semi_ranked_docs);
+  if (keywordResults.lower_ranked_docs) {
+    matchingDocs = matchingDocs.concat(keywordResults.lower_ranked_docs);
   }
 
   updateDocs(matchingDocs);