DAN-21 Modularize QA model for easy swapping (#9)
commit c00d37a7d7 (parent e7b901f292)
backend/danswer/configs/model_configs.py:

@@ -18,6 +18,7 @@ MODEL_CACHE_FOLDER = os.environ.get("TRANSFORMERS_CACHE")
 # Purely an optimization, memory limitation consideration
 BATCH_SIZE_ENCODE_CHUNKS = 8
 
-# OpenAI Model API Configs
+# QA Model API Configs
+INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-completion")
 OPENAPI_MODEL_VERSION = "text-davinci-003"
 OPENAI_MAX_OUTPUT_TOKENS = 200
Qdrant datastore:

@@ -7,7 +7,7 @@ from danswer.datastores.qdrant.indexing import index_chunks
 from danswer.embedding.biencoder import get_default_model
 from danswer.utils.clients import get_qdrant_client
 from danswer.utils.logging import setup_logger
-from danswer.utils.timing import build_timing_wrapper
+from danswer.utils.timing import log_function_time
 from qdrant_client.http.exceptions import ResponseHandlingException
 from qdrant_client.http.exceptions import UnexpectedResponse
 from qdrant_client.http.models import FieldCondition

@@ -28,7 +28,7 @@ class QdrantDatastore(Datastore):
             chunks=chunks, collection=self.collection, client=self.client
         )
 
-    @build_timing_wrapper()
+    @log_function_time()
     def semantic_retrieval(
         self, query: str, filters: list[DatastoreFilter] | None, num_to_retrieve: int
     ) -> list[InferenceChunk]:
backend/danswer/direct_qa/__init__.py (new file, 6 lines):

@@ -0,0 +1,6 @@
+from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.question_answer import OpenAICompletionQA
+
+
+def get_default_backend_qa_model() -> QAModel:
+    return OpenAICompletionQA()
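This factory is the single swap point for QA backends. In this commit it unconditionally returns OpenAICompletionQA; INTERNAL_MODEL_VERSION is defined in model_configs but not yet consulted. A minimal sketch of how the factory could dispatch on that setting once more backends exist (an assumption, not part of this commit):

# Hypothetical sketch only: the committed factory ignores INTERNAL_MODEL_VERSION.
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.question_answer import OpenAICompletionQA


def get_default_backend_qa_model() -> QAModel:
    # Dispatch on the configured backend name; the default stays "openai-completion"
    if INTERNAL_MODEL_VERSION == "openai-completion":
        return OpenAICompletionQA()
    raise ValueError(f"Unknown QA model version: {INTERNAL_MODEL_VERSION}")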
backend/danswer/direct_qa/interfaces.py (new file, 12 lines):
@@ -0,0 +1,12 @@
+import abc
+from typing import *
+
+from danswer.chunking.models import InferenceChunk
+
+
+class QAModel:
+    @abc.abstractmethod
+    def answer_question(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
+        raise NotImplementedError
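Any backend implementing this interface can be handed out by the factory. A minimal sketch of an alternative backend, assuming a made-up EchoQA class that short-circuits the model call (not part of the codebase):

from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.interfaces import QAModel


class EchoQA(QAModel):
    """Hypothetical stub backend: returns the top chunk verbatim, with no quotes."""

    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
        if not context_docs:
            return None, None
        # No model call; just surface the highest-ranked chunk's content
        return context_docs[0].content, None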
backend/danswer/direct_qa/qa_prompts.py:

@@ -21,8 +21,3 @@ def generic_prompt_processor(question: str, documents: list[str]) -> str:
     prompt += f"{QUESTION_PAT}\n{question}\n"
     prompt += f"{ANSWER_PAT}\n"
     return prompt
-
-
-BASIC_QA_PROMPTS = {
-    "generic-qa": generic_prompt_processor,
-}
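With the BASIC_QA_PROMPTS registry removed, prompt selection moves to the OpenAICompletionQA constructor, which accepts any Callable[[str, list[str]], str]. A sketch of a drop-in replacement processor (the prompt wording here is illustrative, not from the repo):

def bullet_prompt_processor(question: str, documents: list[str]) -> str:
    # Hypothetical alternative to generic_prompt_processor: same signature,
    # different prompt layout.
    prompt = "Answer using only the documents below.\n"
    for doc in documents:
        prompt += f"- {doc}\n"
    prompt += f"Question: {question}\nAnswer:\n"
    return prompt

Such a function could then be passed as OpenAICompletionQA(prompt_processor=bullet_prompt_processor).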
backend/danswer/direct_qa/question_answer.py:

@@ -16,12 +16,15 @@ from danswer.configs.constants import SOURCE_LINK
 from danswer.configs.constants import SOURCE_TYPE
 from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import OPENAPI_MODEL_VERSION
+from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import ANSWER_PAT
+from danswer.direct_qa.qa_prompts import generic_prompt_processor
 from danswer.direct_qa.qa_prompts import QUOTE_PAT
 from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
 from danswer.utils.logging import setup_logger
 from danswer.utils.text_processing import clean_model_quote
 from danswer.utils.text_processing import shared_precompare_cleanup
+from danswer.utils.timing import log_function_time
 
 
 logger = setup_logger()
@@ -29,39 +32,6 @@ logger = setup_logger()
 openai.api_key = OPENAI_API_KEY
 
 
-def ask_openai(
-    complete_qa_prompt: str,
-    model: str = OPENAPI_MODEL_VERSION,
-    max_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
-) -> str:
-    try:
-        response = openai.Completion.create(
-            prompt=complete_qa_prompt,
-            temperature=0,
-            top_p=1,
-            frequency_penalty=0,
-            presence_penalty=0,
-            model=model,
-            max_tokens=max_tokens,
-        )
-        model_answer = response["choices"][0]["text"].strip()
-        logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
-        return model_answer
-    except Exception as e:
-        logger.exception(e)
-        return "Model Failure"
-
-
-def answer_question(
-    query: str,
-    context_docs: list[str],
-    prompt_processor: Callable[[str, list[str]], str],
-) -> str:
-    formatted_prompt = prompt_processor(query, context_docs)
-    logger.debug(formatted_prompt)
-    return ask_openai(formatted_prompt)
-
-
 def separate_answer_quotes(
     answer_raw: str,
 ) -> Tuple[Optional[str], Optional[list[str]]]:
@@ -158,9 +128,52 @@ def match_quotes_to_docs(
 
 def process_answer(
     answer_raw: str, chunks: list[InferenceChunk]
-) -> Tuple[Optional[str], Optional[Dict[str, Dict[str, Union[str, int, None]]]]]:
+) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
     answer, quote_strings = separate_answer_quotes(answer_raw)
     if not answer or not quote_strings:
         return None, None
     quotes_dict = match_quotes_to_docs(quote_strings, chunks)
     return answer, quotes_dict
+
+
+class OpenAICompletionQA(QAModel):
+    def __init__(
+        self,
+        prompt_processor: Callable[[str, list[str]], str] = generic_prompt_processor,
+        model_version: str = OPENAPI_MODEL_VERSION,
+        max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
+    ) -> None:
+        self.prompt_processor = prompt_processor
+        self.model_version = model_version
+        self.max_output_tokens = max_output_tokens
+
+    @log_function_time()
+    def answer_question(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
+        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
+        filled_prompt = self.prompt_processor(query, top_contents)
+        logger.debug(filled_prompt)
+
+        try:
+            response = openai.Completion.create(
+                prompt=filled_prompt,
+                temperature=0,
+                top_p=1,
+                frequency_penalty=0,
+                presence_penalty=0,
+                model=self.model_version,
+                max_tokens=self.max_output_tokens,
+            )
+            model_output = response["choices"][0]["text"].strip()
+            logger.info(
+                "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
+            )
+        except Exception as e:
+            logger.exception(e)
+            model_output = "Model Failure"
+
+        logger.debug(model_output)
+
+        answer, quotes_dict = process_answer(model_output, context_docs)
+        return answer, quotes_dict
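The constructor defaults mean get_default_backend_qa_model() needs no arguments, but callers can still override the config-driven values. A usage sketch (the chunk list would come from semantic_search and is elided here):

from danswer.direct_qa.qa_prompts import generic_prompt_processor
from danswer.direct_qa.question_answer import OpenAICompletionQA

# Sketch: overriding the config-driven defaults at construction time.
qa_model = OpenAICompletionQA(
    prompt_processor=generic_prompt_processor,
    model_version="text-davinci-003",  # default comes from OPENAPI_MODEL_VERSION
    max_output_tokens=200,             # default comes from OPENAI_MAX_OUTPUT_TOKENS
)

# ranked_chunks would come from semantic_search(...); elided.
# answer, quotes = qa_model.answer_question(query, ranked_chunks)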
backend/danswer/direct_qa/semantic_search.py:

@@ -14,7 +14,7 @@ from danswer.configs.model_configs import QUERY_EMBEDDING_CONTEXT_SIZE
 from danswer.datastores.interfaces import Datastore
 from danswer.datastores.interfaces import DatastoreFilter
 from danswer.utils.logging import setup_logger
-from danswer.utils.timing import build_timing_wrapper
+from danswer.utils.timing import log_function_time
 from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
 
@@ -32,7 +32,7 @@ cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
 cross_encoder.max_length = CROSS_EMBED_CONTEXT_SIZE
 
 
-@build_timing_wrapper()
+@log_function_time()
 def semantic_reranking(
     query: str,
     chunks: List[InferenceChunk],
@@ -50,6 +50,7 @@ def semantic_reranking(
     return ranked_chunks[:filtered_result_set_size]
 
 
+@log_function_time()
 def semantic_search(
     query: str,
     filters: list[DatastoreFilter] | None,
@@ -65,4 +66,9 @@ def semantic_search(
         )
         return None
     ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
+
+    top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
+    files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
+    logger.info(files_log_msg)
+
     return ranked_chunks
backend/danswer/server/models.py (new file, 21 lines):
@@ -0,0 +1,21 @@
+from danswer.datastores.interfaces import DatastoreFilter
+from pydantic import BaseModel
+
+
+class ServerStatus(BaseModel):
+    status: str
+
+
+class QAQuestion(BaseModel):
+    query: str
+    collection: str
+    filters: list[DatastoreFilter] | None
+
+
+class QAResponse(BaseModel):
+    answer: str | None
+    quotes: dict[str, dict[str, str | int | None]] | None
+
+
+class KeywordResponse(BaseModel):
+    results: list[str] | None
Server API router:

@@ -1,20 +1,19 @@
 import time
 from http import HTTPStatus
 
-from danswer.configs.app_configs import DEFAULT_PROMPT
 from danswer.configs.app_configs import KEYWORD_MAX_HITS
 from danswer.configs.constants import CONTENT
 from danswer.configs.constants import SOURCE_LINKS
 from danswer.datastores import create_datastore
-from danswer.datastores.interfaces import DatastoreFilter
-from danswer.direct_qa.qa_prompts import BASIC_QA_PROMPTS
-from danswer.direct_qa.question_answer import answer_question
-from danswer.direct_qa.question_answer import process_answer
+from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.semantic_search import semantic_search
+from danswer.server.models import KeywordResponse
+from danswer.server.models import QAQuestion
+from danswer.server.models import QAResponse
+from danswer.server.models import ServerStatus
 from danswer.utils.clients import TSClient
 from danswer.utils.logging import setup_logger
 from fastapi import APIRouter
-from pydantic import BaseModel
 
 
 logger = setup_logger()
@@ -22,34 +21,16 @@ logger = setup_logger()
 router = APIRouter()
 
 
-class ServerStatus(BaseModel):
-    status: str
-
-
-class QAQuestion(BaseModel):
-    query: str
-    collection: str
-    filters: list[DatastoreFilter] | None
-
-
-class QAResponse(BaseModel):
-    answer: str | None
-    quotes: dict[str, dict[str, str]] | None
-
-
-class KeywordResponse(BaseModel):
-    results: list[str] | None
-
-
 @router.get("/", response_model=ServerStatus)
 @router.get("/status", response_model=ServerStatus)
 def read_server_status():
-    return {"status": HTTPStatus.OK.value}
+    return ServerStatus(status=HTTPStatus.OK.value)
 
 
 @router.post("/direct-qa", response_model=QAResponse)
 def direct_qa(question: QAQuestion):
-    prompt_processor = BASIC_QA_PROMPTS[DEFAULT_PROMPT]
+    start_time = time.time()
+    qa_model = get_default_backend_qa_model()
     query = question.query
     collection = question.collection
     filters = question.filters
@@ -58,36 +39,14 @@ def direct_qa(question: QAQuestion):
 
     logger.info(f"Received semantic query: {query}")
 
-    start_time = time.time()
     ranked_chunks = semantic_search(query, filters, datastore)
-    sem_search_time = time.time()
-
-    logger.info(f"Semantic search took {sem_search_time - start_time} seconds")
-
     if not ranked_chunks:
         return {"answer": None, "quotes": None}
 
-    top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
-    top_contents = [ranked_chunk.content for ranked_chunk in ranked_chunks]
+    answer, quotes = qa_model.answer_question(query, ranked_chunks)
+    logger.info(f"Total QA took {time.time() - start_time} seconds")
 
-    files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
-    logger.info(files_log_msg)
-
-    qa_answer = answer_question(query, top_contents, prompt_processor)
-    qa_time = time.time()
-    logger.debug(qa_answer)
-    logger.info(f"GPT QA took {qa_time - sem_search_time} seconds")
-
-    # Postprocessing, no more models involved, purely rule based
-    answer, quotes_dict = process_answer(qa_answer, ranked_chunks)
-    postprocess_time = time.time()
-    logger.info(f"Postprocessing took {postprocess_time - qa_time} seconds")
-
-    total_time = time.time() - start_time
-    logger.info(f"Total QA took {total_time} seconds")
-
-    qa_response = {"answer": answer, "quotes": quotes_dict}
-    return qa_response
+    return QAResponse(answer=answer, quotes=quotes)
 
 
 @router.post("/keyword-search", response_model=KeywordResponse)
@@ -114,4 +73,4 @@ def keyword_search(question: QAQuestion):
     total_time = time.time() - start_time
     logger.info(f"Total Keyword Search took {total_time} seconds")
 
-    return {"results": sources}
+    return KeywordResponse(results=sources)
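End to end, /direct-qa now validates its request body against QAQuestion and serializes its reply as QAResponse. A client-side sketch, assuming a local dev server and a made-up collection name (host, port, and collection are assumptions, not from this commit):

import requests

# Assumed local dev server; adjust host/port to your deployment.
resp = requests.post(
    "http://localhost:8080/direct-qa",
    json={
        "query": "How do I configure the QA model?",
        "collection": "semantic_search",  # assumed collection name
        "filters": None,
    },
)
print(resp.json())  # {"answer": ..., "quotes": {...}} per QAResponse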
backend/danswer/utils/timing.py:

@@ -6,13 +6,13 @@ from danswer.utils.logging import setup_logger
 logger = setup_logger()
 
 
-def build_timing_wrapper(
+def log_function_time(
     func_name: str | None = None,
 ) -> Callable[[Callable], Callable]:
     """Build a timing wrapper for a function. Logs how long the function took to run.
     Use like:
 
-    @build_timing_wrapper()
+    @log_function_time()
     def my_func():
         ...
     """