DAN-55 Intent Model (#89)

Includes:
- Intent Model (classifies queries as keyword search, semantic search, or question-answering; see the sketch below)
- Heuristic Classifications (stopword and [UNK]-token checks on the query)
- GPT self-error classification (the prompt now has the model report "Information not found" / "Aggregations not supported" in its JSON answer)
- Bugfix in detecting the end of the answer stream
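
A minimal sketch of exercising the new intent route end to end (route and request fields come from the server hunks below; host, port, and auth are assumptions):

import requests

resp = requests.get(
    "http://127.0.0.1:8080/search-intent",
    params={"query": "effects of caffeine", "collection": "danswer_index", "use_keyword": True},
)
print(resp.json())
# e.g. {"values": {"flow": "search", "search_type": "keyword"}, "details": []}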
Authored by Yuhong Sun on 2023-06-07 15:27:06 -07:00; committed by GitHub.
parent 0f1f16880a
commit 7c97cc4626
16 changed files with 245 additions and 68 deletions

View File

@@ -18,7 +18,7 @@ CROSS_ENCODER_MODEL_ENSEMBLE = [
    "cross-encoder/ms-marco-TinyBERT-L-2-v2",
]
-QUERY_EMBEDDING_CONTEXT_SIZE = 256
+QUERY_MAX_CONTEXT_SIZE = 256
# The below is correlated with CHUNK_SIZE in app_configs but not strictly calculated
# To avoid extra overhead of tokenizing for chunking during indexing.
DOC_EMBEDDING_CONTEXT_SIZE = 512
@@ -32,3 +32,6 @@ BATCH_SIZE_ENCODE_CHUNKS = 8
INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-chat-completion")
OPENAI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "gpt-3.5-turbo")
OPENAI_MAX_OUTPUT_TOKENS = 512
+
+# Danswer custom Deep Learning Models
+INTENT_MODEL_VERSION = "danswer/intent-model"
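
INTENT_MODEL_VERSION is a Hugging Face model id, so the new loaders in search_utils.py below can fetch it through the standard from_pretrained machinery; a minimal sketch:

from transformers import AutoTokenizer, TFDistilBertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("danswer/intent-model")
model = TFDistilBertForSequenceClassification.from_pretrained("danswer/intent-model")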

View File

@@ -11,7 +11,7 @@ from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import VectorIndex
from danswer.datastores.qdrant.indexing import index_qdrant_chunks
-from danswer.search.semantic_search import get_default_embedding_model
+from danswer.search.search_utils import get_default_embedding_model
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logging import setup_logger
from danswer.utils.timing import log_function_time

View File

@@ -109,10 +109,13 @@ def json_chat_processor(question: str, documents: list[str]) -> list[dict[str, str]]:
    task_msg = (
        "Now answer the next user query based on documents above and quote relevant sections.\n"
        "Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
-        "If the query cannot be answered based on the documents, do not provide an answer.\n"
        "All quotes MUST be EXACT substrings from provided documents.\n"
        "Your responses should be informative and concise.\n"
        "You MUST prioritize information from provided documents over internal knowledge.\n"
+        "If the query cannot be answered based on the documents, respond with "
+        '{"answer": "Information not found", "quotes": []}\n'
+        "If the query requires aggregating whole documents, respond with "
+        '{"answer": "Aggregations not supported", "quotes": []}\n'
        f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
    )
    messages = [{"role": "system", "content": intro_msg}]

View File

@@ -185,7 +185,8 @@ def process_answer(
def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
    next_token = next_token.replace('\\"', "")
-    if answer_so_far and answer_so_far[-1] != "\\":
+    # If the previous character is an escape token, don't consider the first character of next_token
+    if answer_so_far and answer_so_far[-1] == "\\":
        next_token = next_token[1:]
    if '"' in next_token:
        return True
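
The flipped comparison is the whole fix: the old check stripped next_token's first character in exactly the wrong case. A worked sketch of the corrected behavior (inputs invented; assumes the function falls through to return False when no unescaped quote is found):

stream_answer_end("escaped: \\", '"s')  # -> False: the quote is escaped content, not the end
stream_answer_end("done", '" }')        # -> True: an unescaped quote closes the answer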

View File

@@ -111,7 +111,7 @@ def get_application() -> FastAPI:
    @application.on_event("startup")
    def startup_event() -> None:
        # To avoid circular imports
-        from danswer.search.semantic_search import (
+        from danswer.search.search_utils import (
            warm_up_models,
        )
        from danswer.datastores.qdrant.indexing import create_qdrant_collection

View File

@@ -0,0 +1,91 @@
import numpy as np
import tensorflow as tf  # type:ignore
from danswer.search.keyword_search import remove_stop_words
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.search.search_utils import get_default_intent_model
from danswer.search.search_utils import get_default_intent_model_tokenizer
from danswer.search.search_utils import get_default_tokenizer
from danswer.server.models import HelperResponse
from transformers import AutoTokenizer  # type:ignore


def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
    """Unclear if the wordpiece tokenizer used is actually tokenizing anything as the [UNK] token
    It splits up even foreign characters and unicode emojis without using UNK"""
    tokenized_text = tokenizer.tokenize(text)
    return len([token for token in tokenized_text if token == tokenizer.unk_token])
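
# Illustrative note: WordPiece vocabularies usually split unfamiliar strings
# (even emojis and foreign scripts) into sub-tokens rather than emitting [UNK],
# per the docstring above, so this usually returns 0, e.g. (hypothetical):
#   count_unk_tokens("caffeine ☕ effects", get_default_tokenizer())  # -> usually 0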


def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
    tokenizer = get_default_intent_model_tokenizer()
    intent_model = get_default_intent_model()
    model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
    predictions = intent_model(model_input)[0]
    probabilities = tf.nn.softmax(predictions, axis=-1)
    class_percentages = np.round(probabilities.numpy() * 100, 2)
    keyword, semantic, qa = class_percentages.tolist()[0]
    # Heavily bias towards QA, from user perspective, answering a statement is not as bad as not answering a question
    if qa > 20:
        # If one class is very certain, choose it still
        if keyword > 70:
            return SearchType.KEYWORD, QueryFlow.SEARCH
        if semantic > 70:
            return SearchType.SEMANTIC, QueryFlow.SEARCH
        # If it's a QA question, it must be a "Semantic" style statement/question
        return SearchType.SEMANTIC, QueryFlow.QUESTION_ANSWER
    # If definitely not a QA question, choose between keyword or semantic search
    elif keyword > semantic:
        return SearchType.KEYWORD, QueryFlow.SEARCH
    else:
        return SearchType.SEMANTIC, QueryFlow.SEARCH
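
# To make the thresholds concrete: hypothetical (keyword, semantic, qa) percentage
# splits and the branch each would take:
#   (75, 3, 22)  -> (SearchType.KEYWORD, QueryFlow.SEARCH)            # qa > 20, but keyword is very certain
#   (10, 30, 60) -> (SearchType.SEMANTIC, QueryFlow.QUESTION_ANSWER)  # confident QA, no dominant search class
#   (55, 40, 5)  -> (SearchType.KEYWORD, QueryFlow.SEARCH)            # not QA; keyword > semantic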


def recommend_search_flow(
    query: str,
    keyword: bool,
    max_percent_stopwords: float = 0.33,  # Every third word max, ie "effects of caffeine" still viable keyword search
) -> HelperResponse:
    heuristic_search_type: SearchType | None = None
    message: str | None = None

    # Heuristics based decisions
    words = query.split()
    non_stopwords = remove_stop_words(query)
    non_stopword_percent = len(non_stopwords) / len(words)

    # UNK tokens -> suggest Keyword (still may be valid QA)
    if count_unk_tokens(query, get_default_tokenizer()) > 0:
        if not keyword:
            heuristic_search_type = SearchType.KEYWORD
            message = (
                "Query contains words that the AI model cannot understand, "
                "Keyword Search may yield better results."
            )

    # Too many stop words, most likely a Semantic query (still may be valid QA)
    if non_stopword_percent < 1 - max_percent_stopwords:
        if keyword:
            heuristic_search_type = SearchType.SEMANTIC
            message = "Query contains stopwords, AI Search is likely more suitable."

    # Model based decisions
    model_search_type, flow = query_intent(query)
    if not message:
        if model_search_type == SearchType.SEMANTIC and keyword:
            message = "Query may yield better results with Semantic Search"
        if model_search_type == SearchType.KEYWORD and not keyword:
            message = "Query may yield better results with Keyword Search."

    return HelperResponse(
        values={
            "flow": flow,
            "search_type": model_search_type
            if heuristic_search_type is None
            else heuristic_search_type,
        },
        details=[message] if message else [],
    )
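
A quick sketch of how the helper composes (query invented; the exact response depends on the intent model's scores):

from danswer.search.danswer_helper import recommend_search_flow

# Keyword search is on, but the query is stopword-heavy natural language:
resp = recommend_search_flow("What are the effects of caffeine?", keyword=True)
# Hypothetical result:
#   values={"flow": "question-answer", "search_type": "semantic"},
#   details=["Query contains stopwords, AI Search is likely more suitable."]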

View File

@@ -13,23 +13,21 @@ from nltk.tokenize import word_tokenize  # type:ignore
logger = setup_logger()


-def lemmatize_text(text: str) -> str:
+def lemmatize_text(text: str) -> list[str]:
    lemmatizer = WordNetLemmatizer()
    word_tokens = word_tokenize(text)
-    lemmatized_text = [lemmatizer.lemmatize(word) for word in word_tokens]
-    return " ".join(lemmatized_text)
+    return [lemmatizer.lemmatize(word) for word in word_tokens]


-def remove_stop_words(text: str) -> str:
+def remove_stop_words(text: str) -> list[str]:
    stop_words = set(stopwords.words("english"))
    word_tokens = word_tokenize(text)
-    filtered_text = [word for word in word_tokens if word.casefold() not in stop_words]
-    return " ".join(filtered_text)
+    return [word for word in word_tokens if word.casefold() not in stop_words]


def query_processing(query: str) -> str:
-    query = remove_stop_words(query)
-    query = lemmatize_text(query)
+    query = " ".join(remove_stop_words(query))
+    query = " ".join(lemmatize_text(query))
    return query
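
For example, the reworked pipeline behaves roughly like this (assumes NLTK's stopwords/wordnet data is downloaded; exact tokens can vary):

query_processing("What are the effects of caffeine?")
# remove_stop_words -> ["effects", "caffeine", "?"]
# lemmatize_text    -> "effect caffeine ?"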

View File

@@ -0,0 +1,19 @@
from enum import Enum

from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk


class SearchType(str, Enum):
    KEYWORD = "keyword"  # May be better to also try keyword search if Semantic (AI Search) is on
    SEMANTIC = "semantic"  # Really should try Semantic (AI Search) if keyword is on


class QueryFlow(str, Enum):
    SEARCH = "search"
    QUESTION_ANSWER = "question-answer"


class Embedder:
    def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
        raise NotImplementedError

View File

@@ -0,0 +1,76 @@
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import INTENT_MODEL_VERSION
from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
from sentence_transformers import CrossEncoder  # type: ignore
from sentence_transformers import SentenceTransformer  # type: ignore
from transformers import AutoTokenizer  # type: ignore
from transformers import TFDistilBertForSequenceClassification  # type: ignore

_TOKENIZER: None | AutoTokenizer = None
_EMBED_MODEL: None | SentenceTransformer = None
_RERANK_MODELS: None | list[CrossEncoder] = None
_INTENT_TOKENIZER: None | AutoTokenizer = None
_INTENT_MODEL: None | TFDistilBertForSequenceClassification = None


def get_default_tokenizer() -> AutoTokenizer:
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = AutoTokenizer.from_pretrained(DOCUMENT_ENCODER_MODEL)
    return _TOKENIZER


def get_default_embedding_model() -> SentenceTransformer:
    global _EMBED_MODEL
    if _EMBED_MODEL is None:
        _EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
        _EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
    return _EMBED_MODEL


def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
    global _RERANK_MODELS
    if _RERANK_MODELS is None:
        _RERANK_MODELS = [
            CrossEncoder(model_name) for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
        ]
        for model in _RERANK_MODELS:
            model.max_length = CROSS_EMBED_CONTEXT_SIZE
    return _RERANK_MODELS


def get_default_intent_model_tokenizer() -> AutoTokenizer:
    global _INTENT_TOKENIZER
    if _INTENT_TOKENIZER is None:
        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(INTENT_MODEL_VERSION)
    return _INTENT_TOKENIZER


def get_default_intent_model() -> TFDistilBertForSequenceClassification:
    global _INTENT_MODEL
    if _INTENT_MODEL is None:
        _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
            INTENT_MODEL_VERSION
        )
        _INTENT_MODEL.max_seq_length = QUERY_MAX_CONTEXT_SIZE
    return _INTENT_MODEL


def warm_up_models() -> None:
    warm_up_str = "Danswer is amazing"
    get_default_tokenizer()(warm_up_str)
    get_default_embedding_model().encode(warm_up_str)
    cross_encoders = get_default_reranking_model_ensemble()
    [
        cross_encoder.predict((warm_up_str, warm_up_str))
        for cross_encoder in cross_encoders
    ]
    intent_tokenizer = get_default_intent_model_tokenizer()
    inputs = intent_tokenizer(
        warm_up_str, return_tensors="tf", truncation=True, padding=True
    )
    get_default_intent_model()(inputs)
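
These loaders memoize on module-level globals, so the first call pays the download/initialization cost and every later call is free. warm_up_models is what the startup hook in the main.py hunk above imports, so that cost is paid once at boot rather than on a user's first query; roughly (the warm_up_models() call line is assumed, not shown in the hunk):

@application.on_event("startup")
def startup_event() -> None:
    # To avoid circular imports
    from danswer.search.search_utils import warm_up_models

    warm_up_models()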

View File

@@ -9,26 +9,19 @@ from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
-from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
-from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
-from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
-from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import VectorIndex
-from danswer.search.type_aliases import Embedder
+from danswer.search.models import Embedder
+from danswer.search.search_utils import get_default_embedding_model
+from danswer.search.search_utils import get_default_reranking_model_ensemble
from danswer.server.models import SearchDoc
from danswer.utils.logging import setup_logger
from danswer.utils.timing import log_function_time
-from sentence_transformers import CrossEncoder  # type: ignore
-from sentence_transformers import SentenceTransformer  # type: ignore

logger = setup_logger()

-_EMBED_MODEL: None | SentenceTransformer = None
-_RERANK_MODELS: None | list[CrossEncoder] = None


def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
    search_docs = (
        [
@@ -46,33 +39,6 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
    return search_docs


-def get_default_embedding_model() -> SentenceTransformer:
-    global _EMBED_MODEL
-    if _EMBED_MODEL is None:
-        _EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
-        _EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
-    return _EMBED_MODEL
-
-
-def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
-    global _RERANK_MODELS
-    if _RERANK_MODELS is None:
-        _RERANK_MODELS = [
-            CrossEncoder(model_name) for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
-        ]
-        for model in _RERANK_MODELS:
-            model.max_length = CROSS_EMBED_CONTEXT_SIZE
-    return _RERANK_MODELS
-
-
-def warm_up_models() -> None:
-    get_default_embedding_model().encode("Danswer is so cool")
-    cross_encoders = get_default_reranking_model_ensemble()
-    [cross_encoder.predict(("What is Danswer", "Enterprise QA")) for cross_encoder in cross_encoders]  # type: ignore


@log_function_time()
def semantic_reranking(
    query: str,

View File

@ -1,7 +0,0 @@
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
class Embedder:
def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
raise NotImplementedError

View File

@@ -27,6 +27,11 @@ class DataRequest(BaseModel):
    data: str


+class HelperResponse(BaseModel):
+    values: dict[str, str]
+    details: list[str] | None = None


class GoogleAppWebCredentials(BaseModel):
    client_id: str
    project_id: str
@@ -79,6 +84,7 @@ class QuestionRequest(BaseModel):
    collection: str
    use_keyword: bool | None
    filters: list[IndexFilter] | None
+    offset: int | None


class SearchResponse(BaseModel):

View File

@@ -1,4 +1,3 @@
import json
import time
from collections.abc import Generator
@@ -11,9 +10,11 @@ from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.question_answer import get_json_line
+from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
+from danswer.server.models import HelperResponse
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchResponse
@@ -27,6 +28,15 @@ logger = setup_logger()
router = APIRouter()


+@router.get("/search-intent")
+def get_search_type(
+    question: QuestionRequest = Depends(), _: User = Depends(current_user)
+) -> HelperResponse:
+    query = question.query
+    use_keyword = question.use_keyword if question.use_keyword is not None else False
+    return recommend_search_flow(query, use_keyword)


@router.post("/semantic-search")
def semantic_search(
    question: QuestionRequest, user: User = Depends(current_user)
@@ -79,6 +89,7 @@ def direct_qa(
    collection = question.collection
    filters = question.filters
    use_keyword = question.use_keyword
+    offset_count = question.offset if question.offset is not None else 0

    logger.info(f"Received QA query: {query}")
    user_id = None if user is None else int(user.id)
@@ -97,9 +108,13 @@
    )

    qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
+    chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
+    if chunk_offset >= len(ranked_chunks):
+        raise ValueError("Chunks offset too large, should not retry this many times")
    try:
        answer, quotes = qa_model.answer_question(
-            query, ranked_chunks[:NUM_GENERATIVE_AI_INPUT_DOCS]
+            query,
+            ranked_chunks[chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS],
        )
    except Exception:
        # exception is logged in the answer_question method, no need to re-log
@@ -127,6 +142,7 @@ def stream_direct_qa(
    collection = question.collection
    filters = question.filters
    use_keyword = question.use_keyword
+    offset_count = question.offset if question.offset is not None else 0

    logger.info(f"Received QA query: {query}")
    user_id = None if user is None else int(user.id)
@@ -152,9 +168,17 @@
        yield get_json_line(top_docs_dict)

        qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
+        chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
+        if chunk_offset >= len(ranked_chunks):
+            raise ValueError(
+                "Chunks offset too large, should not retry this many times"
+            )
        try:
            for response_dict in qa_model.answer_question_stream(
-                query, ranked_chunks[:NUM_GENERATIVE_AI_INPUT_DOCS]
+                query,
+                ranked_chunks[
+                    chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS
+                ],
            ):
                if response_dict is None:
                    continue

View File

@@ -10,8 +10,8 @@ from danswer.datastores.interfaces import KeywordIndex
from danswer.datastores.interfaces import VectorIndex
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
+from danswer.search.models import Embedder
from danswer.search.semantic_search import DefaultEmbedder
-from danswer.search.type_aliases import Embedder


class IndexingPipelineProtocol(Protocol):

View File

@@ -28,6 +28,7 @@ rfc3986==1.5.0
sentence-transformers==2.2.2
slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.12
+tensorflow==2.12.0
transformers==4.27.3
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13

View File

@@ -96,19 +96,15 @@ if __name__ == "__main__":
        "query": query,
        "collection": QDRANT_DEFAULT_COLLECTION,
        "use_keyword": flow_type == "keyword",  # Ignore if not QA Endpoints
-        "filters": json.dumps([{SOURCE_TYPE: source_types}]),
+        "filters": [{SOURCE_TYPE: source_types}],
    }

    if args.stream:
-        with requests.get(
-            endpoint, params=urllib.parse.urlencode(query_json), stream=True
-        ) as r:
+        with requests.post(endpoint, json=query_json, stream=True) as r:
            for json_response in r.iter_lines():
                pprint(json.loads(json_response.decode()))
    else:
-        response = requests.get(
-            endpoint, params=urllib.parse.urlencode(query_json)
-        )
+        response = requests.post(endpoint, json=query_json)
        contents = json.loads(response.content)
        pprint(contents)
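
Switching from urlencoded GET params to a JSON POST body lets filters travel as a real nested list instead of a json.dumps'd string the server must re-parse. An equivalent one-off call outside the script (endpoint path, host, and collection name are assumptions):

import requests

resp = requests.post(
    "http://127.0.0.1:8080/direct-qa",
    json={"query": "What is Danswer?", "collection": "danswer_index", "use_keyword": False},
)
print(resp.json())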