DAN-55 Intent Model (#89)

Includes:
- Intent Model
- Heuristic Classifications
- GPT self error classification
- Bugfix on finding end of answer stream
@@ -18,7 +18,7 @@ CROSS_ENCODER_MODEL_ENSEMBLE = [
     "cross-encoder/ms-marco-TinyBERT-L-2-v2",
 ]
 
-QUERY_EMBEDDING_CONTEXT_SIZE = 256
+QUERY_MAX_CONTEXT_SIZE = 256
 # The below is correlated with CHUNK_SIZE in app_configs but not strictly calculated
 # To avoid extra overhead of tokenizing for chunking during indexing.
 DOC_EMBEDDING_CONTEXT_SIZE = 512
@@ -32,3 +32,6 @@ BATCH_SIZE_ENCODE_CHUNKS = 8
 INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-chat-completion")
 OPENAI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "gpt-3.5-turbo")
 OPENAI_MAX_OUTPUT_TOKENS = 512
+
+# Danswer custom Deep Learning Models
+INTENT_MODEL_VERSION = "danswer/intent-model"
@@ -11,7 +11,7 @@ from danswer.datastores.datastore_utils import get_uuid_from_chunk
 from danswer.datastores.interfaces import IndexFilter
 from danswer.datastores.interfaces import VectorIndex
 from danswer.datastores.qdrant.indexing import index_qdrant_chunks
-from danswer.search.semantic_search import get_default_embedding_model
+from danswer.search.search_utils import get_default_embedding_model
 from danswer.utils.clients import get_qdrant_client
 from danswer.utils.logging import setup_logger
 from danswer.utils.timing import log_function_time
@@ -109,10 +109,13 @@ def json_chat_processor(question: str, documents: list[str]) -> list[dict[str, str]]:
     task_msg = (
         "Now answer the next user query based on documents above and quote relevant sections.\n"
         "Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
-        "If the query cannot be answered based on the documents, do not provide an answer.\n"
         "All quotes MUST be EXACT substrings from provided documents.\n"
         "Your responses should be informative and concise.\n"
         "You MUST prioritize information from provided documents over internal knowledge.\n"
+        "If the query cannot be answered based on the documents, respond with "
+        '{"answer": "Information not found", "quotes": []}\n'
+        "If the query requires aggregating whole documents, respond with "
+        '{"answer": "Aggregations not supported", "quotes": []}\n'
         f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
     )
     messages = [{"role": "system", "content": intro_msg}]
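The two canned payloads above are presumably what the commit title calls "GPT self error classification": instead of silently refusing, the model reports its own failure mode in a machine-readable form that a client can detect and act on. A minimal sketch of such a client-side check (the helper name is hypothetical, not part of this commit):

    def is_model_reported_failure(answer: str | None) -> bool:
        # These exact strings come from the prompt instructions above.
        return answer in ("Information not found", "Aggregations not supported")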
@@ -185,7 +185,8 @@ def process_answer(
 
 def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
     next_token = next_token.replace('\\"', "")
-    if answer_so_far and answer_so_far[-1] != "\\":
+    # If the previous character is an escape token, don't consider the first character of next_token
+    if answer_so_far and answer_so_far[-1] == "\\":
         next_token = next_token[1:]
     if '"' in next_token:
         return True
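This is the "bugfix on finding end of answer stream" from the commit message: the old condition was inverted, so the first character of next_token was dropped exactly when the previous character was not an escape, and an escaped quote split across two stream tokens could be mistaken for the closing quote of the JSON answer. A worked example of the fixed behavior, assuming the function falls through to a falsy value when the answer has not ended (the tail of the function is outside this hunk):

    # An escaped quote split across two stream tokens must not end the answer...
    assert not stream_answer_end('{"answer": "5\\', '" tall')
    # ...but a genuine closing quote must.
    assert stream_answer_end('{"answer": "Paris', '"}')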
@@ -111,7 +111,7 @@ def get_application() -> FastAPI:
     @application.on_event("startup")
     def startup_event() -> None:
         # To avoid circular imports
-        from danswer.search.semantic_search import (
+        from danswer.search.search_utils import (
             warm_up_models,
         )
         from danswer.datastores.qdrant.indexing import create_qdrant_collection

backend/danswer/search/danswer_helper.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+import numpy as np
+import tensorflow as tf  # type:ignore
+from danswer.search.keyword_search import remove_stop_words
+from danswer.search.models import QueryFlow
+from danswer.search.models import SearchType
+from danswer.search.search_utils import get_default_intent_model
+from danswer.search.search_utils import get_default_intent_model_tokenizer
+from danswer.search.search_utils import get_default_tokenizer
+from danswer.server.models import HelperResponse
+from transformers import AutoTokenizer  # type:ignore
+
+
+def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
+    """Unclear if the wordpiece tokenizer used is actually tokenizing anything as the [UNK] token
+    It splits up even foreign characters and unicode emojis without using UNK"""
+    tokenized_text = tokenizer.tokenize(text)
+    return len([token for token in tokenized_text if token == tokenizer.unk_token])
+
+
+def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
+    tokenizer = get_default_intent_model_tokenizer()
+    intent_model = get_default_intent_model()
+    model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
+
+    predictions = intent_model(model_input)[0]
+    probabilities = tf.nn.softmax(predictions, axis=-1)
+    class_percentages = np.round(probabilities.numpy() * 100, 2)
+
+    keyword, semantic, qa = class_percentages.tolist()[0]
+
+    # Heavily bias towards QA, from user perspective, answering a statement is not as bad as not answering a question
+    if qa > 20:
+        # If one class is very certain, choose it still
+        if keyword > 70:
+            return SearchType.KEYWORD, QueryFlow.SEARCH
+        if semantic > 70:
+            return SearchType.SEMANTIC, QueryFlow.SEARCH
+        # If it's a QA question, it must be a "Semantic" style statement/question
+        return SearchType.SEMANTIC, QueryFlow.QUESTION_ANSWER
+    # If definitely not a QA question, choose between keyword or semantic search
+    elif keyword > semantic:
+        return SearchType.KEYWORD, QueryFlow.SEARCH
+    else:
+        return SearchType.SEMANTIC, QueryFlow.SEARCH
+
+
+def recommend_search_flow(
+    query: str,
+    keyword: bool,
+    max_percent_stopwords: float = 0.33,  # Every third word max, ie "effects of caffeine" still viable keyword search
+) -> HelperResponse:
+    heuristic_search_type: SearchType | None = None
+    message: str | None = None
+
+    # Heuristics based decisions
+    words = query.split()
+    non_stopwords = remove_stop_words(query)
+    non_stopword_percent = len(non_stopwords) / len(words)
+
+    # UNK tokens -> suggest Keyword (still may be valid QA)
+    if count_unk_tokens(query, get_default_tokenizer()) > 0:
+        if not keyword:
+            heuristic_search_type = SearchType.KEYWORD
+            message = (
+                "Query contains words that the AI model cannot understand, "
+                "Keyword Search may yield better results."
+            )
+
+    # Too many stop words, most likely a Semantic query (still may be valid QA)
+    if non_stopword_percent < 1 - max_percent_stopwords:
+        if keyword:
+            heuristic_search_type = SearchType.SEMANTIC
+            message = "Query contains stopwords, AI Search is likely more suitable."
+
+    # Model based decisions
+    model_search_type, flow = query_intent(query)
+    if not message:
+        if model_search_type == SearchType.SEMANTIC and keyword:
+            message = "Query may yield better results with Semantic Search"
+        if model_search_type == SearchType.KEYWORD and not keyword:
+            message = "Query may yield better results with Keyword Search."
+
+    return HelperResponse(
+        values={
+            "flow": flow,
+            "search_type": model_search_type
+            if heuristic_search_type is None
+            else heuristic_search_type,
+        },
+        details=[message] if message else [],
+    )
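A hedged usage sketch of the new helper (the values shown are illustrative; the actual classification depends on the intent model's weights):

    from danswer.search.danswer_helper import recommend_search_flow

    # A natural-language question typed while Keyword Search is toggled on:
    helper = recommend_search_flow("what are the effects of caffeine", keyword=True)
    # helper.values might be {"flow": "question-answer", "search_type": "semantic"},
    # and helper.details would carry the suggestion message, if any.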
@@ -13,23 +13,21 @@ from nltk.tokenize import word_tokenize  # type:ignore
 logger = setup_logger()
 
 
-def lemmatize_text(text: str) -> str:
+def lemmatize_text(text: str) -> list[str]:
     lemmatizer = WordNetLemmatizer()
     word_tokens = word_tokenize(text)
-    lemmatized_text = [lemmatizer.lemmatize(word) for word in word_tokens]
-    return " ".join(lemmatized_text)
+    return [lemmatizer.lemmatize(word) for word in word_tokens]
 
 
-def remove_stop_words(text: str) -> str:
+def remove_stop_words(text: str) -> list[str]:
     stop_words = set(stopwords.words("english"))
     word_tokens = word_tokenize(text)
-    filtered_text = [word for word in word_tokens if word.casefold() not in stop_words]
-    return " ".join(filtered_text)
+    return [word for word in word_tokens if word.casefold() not in stop_words]
 
 
 def query_processing(query: str) -> str:
-    query = remove_stop_words(query)
-    query = lemmatize_text(query)
+    query = " ".join(remove_stop_words(query))
+    query = " ".join(lemmatize_text(query))
     return query
 
 
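The helpers now return token lists so danswer_helper can compute stopword ratios directly, while query_processing keeps its old string contract by re-joining. Illustrative only (exact output depends on NLTK's tokenizer, stopword list, and lemmatizer):

    remove_stop_words("what are the effects of caffeine")
    # -> ["effects", "caffeine"]  ("what", "are", "the", "of" are stopwords)
    query_processing("what are the effects of caffeine")
    # -> "effect caffeine"  (stopwords removed, then lemmatized and re-joined)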
backend/danswer/search/models.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from enum import Enum
+
+from danswer.chunking.models import EmbeddedIndexChunk
+from danswer.chunking.models import IndexChunk
+
+
+class SearchType(str, Enum):
+    KEYWORD = "keyword"  # May be better to also try keyword search if Semantic (AI Search) is on
+    SEMANTIC = "semantic"  # Really should try Semantic (AI Search) if keyword is on
+
+
+class QueryFlow(str, Enum):
+    SEARCH = "search"
+    QUESTION_ANSWER = "question-answer"
+
+
+class Embedder:
+    def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
+        raise NotImplementedError
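Because both enums subclass str, they serialize as plain strings, which is what lets HelperResponse.values stay a dict[str, str]. A quick illustrative check:

    import json
    json.dumps({"flow": QueryFlow.SEARCH, "search_type": SearchType.SEMANTIC})
    # -> '{"flow": "search", "search_type": "semantic"}'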
backend/danswer/search/search_utils.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
+from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
+from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
+from danswer.configs.model_configs import INTENT_MODEL_VERSION
+from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
+from sentence_transformers import CrossEncoder  # type: ignore
+from sentence_transformers import SentenceTransformer  # type: ignore
+from transformers import AutoTokenizer  # type: ignore
+from transformers import TFDistilBertForSequenceClassification  # type: ignore
+
+
+_TOKENIZER: None | AutoTokenizer = None
+_EMBED_MODEL: None | SentenceTransformer = None
+_RERANK_MODELS: None | list[CrossEncoder] = None
+_INTENT_TOKENIZER: None | AutoTokenizer = None
+_INTENT_MODEL: None | TFDistilBertForSequenceClassification = None
+
+
+def get_default_tokenizer() -> AutoTokenizer:
+    global _TOKENIZER
+    if _TOKENIZER is None:
+        _TOKENIZER = AutoTokenizer.from_pretrained(DOCUMENT_ENCODER_MODEL)
+    return _TOKENIZER
+
+
+def get_default_embedding_model() -> SentenceTransformer:
+    global _EMBED_MODEL
+    if _EMBED_MODEL is None:
+        _EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
+        _EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
+    return _EMBED_MODEL
+
+
+def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
+    global _RERANK_MODELS
+    if _RERANK_MODELS is None:
+        _RERANK_MODELS = [
+            CrossEncoder(model_name) for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
+        ]
+        for model in _RERANK_MODELS:
+            model.max_length = CROSS_EMBED_CONTEXT_SIZE
+    return _RERANK_MODELS
+
+
+def get_default_intent_model_tokenizer() -> AutoTokenizer:
+    global _INTENT_TOKENIZER
+    if _INTENT_TOKENIZER is None:
+        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(INTENT_MODEL_VERSION)
+    return _INTENT_TOKENIZER
+
+
+def get_default_intent_model() -> TFDistilBertForSequenceClassification:
+    global _INTENT_MODEL
+    if _INTENT_MODEL is None:
+        _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
+            INTENT_MODEL_VERSION
+        )
+        _INTENT_MODEL.max_seq_length = QUERY_MAX_CONTEXT_SIZE
+    return _INTENT_MODEL
+
+
+def warm_up_models() -> None:
+    warm_up_str = "Danswer is amazing"
+    get_default_tokenizer()(warm_up_str)
+    get_default_embedding_model().encode(warm_up_str)
+    cross_encoders = get_default_reranking_model_ensemble()
+    [
+        cross_encoder.predict((warm_up_str, warm_up_str))
+        for cross_encoder in cross_encoders
+    ]
+    intent_tokenizer = get_default_intent_model_tokenizer()
+    inputs = intent_tokenizer(
+        warm_up_str, return_tensors="tf", truncation=True, padding=True
+    )
+    get_default_intent_model()(inputs)
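All five getters follow the same lazy module-level singleton pattern: the first call loads the model, later calls return the cached instance, and warm_up_models simply forces each first load at server startup so no user request pays the cost. For example (illustrative):

    tokenizer = get_default_tokenizer()          # loads DOCUMENT_ENCODER_MODEL once
    assert tokenizer is get_default_tokenizer()  # subsequent calls hit the cache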
@@ -9,26 +9,19 @@ from danswer.configs.app_configs import MINI_CHUNK_SIZE
 from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.app_configs import NUM_RETURNED_HITS
 from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
-from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
-from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
-from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
-from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
 from danswer.datastores.interfaces import IndexFilter
 from danswer.datastores.interfaces import VectorIndex
-from danswer.search.type_aliases import Embedder
+from danswer.search.models import Embedder
+from danswer.search.search_utils import get_default_embedding_model
+from danswer.search.search_utils import get_default_reranking_model_ensemble
 from danswer.server.models import SearchDoc
 from danswer.utils.logging import setup_logger
 from danswer.utils.timing import log_function_time
-from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
 
 logger = setup_logger()
 
 
-_EMBED_MODEL: None | SentenceTransformer = None
-_RERANK_MODELS: None | list[CrossEncoder] = None
-
-
 def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
     search_docs = (
         [
@@ -46,33 +39,6 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
     return search_docs
 
 
-def get_default_embedding_model() -> SentenceTransformer:
-    global _EMBED_MODEL
-    if _EMBED_MODEL is None:
-        _EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
-        _EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
-
-    return _EMBED_MODEL
-
-
-def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
-    global _RERANK_MODELS
-    if _RERANK_MODELS is None:
-        _RERANK_MODELS = [
-            CrossEncoder(model_name) for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
-        ]
-        for model in _RERANK_MODELS:
-            model.max_length = CROSS_EMBED_CONTEXT_SIZE
-
-    return _RERANK_MODELS
-
-
-def warm_up_models() -> None:
-    get_default_embedding_model().encode("Danswer is so cool")
-    cross_encoders = get_default_reranking_model_ensemble()
-    [cross_encoder.predict(("What is Danswer", "Enterprise QA")) for cross_encoder in cross_encoders]  # type: ignore
-
-
 @log_function_time()
 def semantic_reranking(
     query: str,
@@ -1,7 +0,0 @@
-from danswer.chunking.models import EmbeddedIndexChunk
-from danswer.chunking.models import IndexChunk
-
-
-class Embedder:
-    def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
-        raise NotImplementedError
@@ -27,6 +27,11 @@ class DataRequest(BaseModel):
     data: str
 
 
+class HelperResponse(BaseModel):
+    values: dict[str, str]
+    details: list[str] | None = None
+
+
 class GoogleAppWebCredentials(BaseModel):
     client_id: str
     project_id: str
@@ -79,6 +84,7 @@ class QuestionRequest(BaseModel):
     collection: str
     use_keyword: bool | None
     filters: list[IndexFilter] | None
+    offset: int | None
 
 
 class SearchResponse(BaseModel):
@@ -1,4 +1,3 @@
-import json
 import time
 from collections.abc import Generator
 
@@ -11,9 +10,11 @@ from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
 from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.question_answer import get_json_line
+from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
+from danswer.server.models import HelperResponse
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
 from danswer.server.models import SearchResponse
@@ -27,6 +28,15 @@ logger = setup_logger()
 router = APIRouter()
 
 
+@router.get("/search-intent")
+def get_search_type(
+    question: QuestionRequest = Depends(), _: User = Depends(current_user)
+) -> HelperResponse:
+    query = question.query
+    use_keyword = question.use_keyword if question.use_keyword is not None else False
+    return recommend_search_flow(query, use_keyword)
+
+
 @router.post("/semantic-search")
 def semantic_search(
     question: QuestionRequest, user: User = Depends(current_user)
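Since QuestionRequest is injected with Depends() on a GET route, its fields arrive as query parameters. A hypothetical request (host, port, collection name, and the response values are deployment-specific, not taken from this commit):

    curl "http://localhost:8080/search-intent?query=what+is+danswer&collection=danswer_index"
    # -> {"values": {"flow": "question-answer", "search_type": "semantic"}, "details": [...]}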
@@ -79,6 +89,7 @@ def direct_qa(
     collection = question.collection
     filters = question.filters
     use_keyword = question.use_keyword
+    offset_count = question.offset if question.offset is not None else 0
     logger.info(f"Received QA query: {query}")
 
     user_id = None if user is None else int(user.id)
@@ -97,9 +108,13 @@ def direct_qa(
     )
 
     qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
+    chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
+    if chunk_offset >= len(ranked_chunks):
+        raise ValueError("Chunks offset too large, should not retry this many times")
     try:
         answer, quotes = qa_model.answer_question(
-            query, ranked_chunks[:NUM_GENERATIVE_AI_INPUT_DOCS]
+            query,
+            ranked_chunks[chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS],
         )
     except Exception:
         # exception is logged in the answer_question method, no need to re-log
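How the offset ties the pieces together: each retry advances the window of ranked chunks fed to the LLM, so a client that detects one of the canned "Information not found" answers can re-ask against the next slice. A sketch with an assumed NUM_GENERATIVE_AI_INPUT_DOCS of 5 (the actual value lives in app_configs):

    # offset 0 -> ranked_chunks[0:5], offset 1 -> ranked_chunks[5:10], ...
    # Once chunk_offset reaches len(ranked_chunks) there is nothing left to try,
    # hence the ValueError guard above.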
@@ -127,6 +142,7 @@ def stream_direct_qa(
     collection = question.collection
     filters = question.filters
     use_keyword = question.use_keyword
+    offset_count = question.offset if question.offset is not None else 0
     logger.info(f"Received QA query: {query}")
 
     user_id = None if user is None else int(user.id)
@@ -152,9 +168,17 @@ def stream_direct_qa(
         yield get_json_line(top_docs_dict)
 
         qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
+        chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
+        if chunk_offset >= len(ranked_chunks):
+            raise ValueError(
+                "Chunks offset too large, should not retry this many times"
+            )
         try:
             for response_dict in qa_model.answer_question_stream(
-                query, ranked_chunks[:NUM_GENERATIVE_AI_INPUT_DOCS]
+                query,
+                ranked_chunks[
+                    chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS
+                ],
             ):
                 if response_dict is None:
                     continue
@@ -10,8 +10,8 @@ from danswer.datastores.interfaces import KeywordIndex
 from danswer.datastores.interfaces import VectorIndex
 from danswer.datastores.qdrant.store import QdrantIndex
 from danswer.datastores.typesense.store import TypesenseIndex
+from danswer.search.models import Embedder
 from danswer.search.semantic_search import DefaultEmbedder
-from danswer.search.type_aliases import Embedder
 
 
 class IndexingPipelineProtocol(Protocol):
@@ -28,6 +28,7 @@ rfc3986==1.5.0
 sentence-transformers==2.2.2
 slack-sdk==3.20.2
 SQLAlchemy[mypy]==2.0.12
+tensorflow==2.12.0
 transformers==4.27.3
 types-beautifulsoup4==4.12.0.3
 types-html5lib==1.1.11.13
@@ -96,19 +96,15 @@ if __name__ == "__main__":
         "query": query,
         "collection": QDRANT_DEFAULT_COLLECTION,
         "use_keyword": flow_type == "keyword",  # Ignore if not QA Endpoints
-        "filters": json.dumps([{SOURCE_TYPE: source_types}]),
+        "filters": [{SOURCE_TYPE: source_types}],
     }
 
     if args.stream:
-        with requests.get(
-            endpoint, params=urllib.parse.urlencode(query_json), stream=True
-        ) as r:
+        with requests.post(endpoint, json=query_json, stream=True) as r:
             for json_response in r.iter_lines():
                 pprint(json.loads(json_response.decode()))
     else:
-        response = requests.get(
-            endpoint, params=urllib.parse.urlencode(query_json)
-        )
+        response = requests.post(endpoint, json=query_json)
         contents = json.loads(response.content)
         pprint(contents)
 