mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-06 01:40:32 +02:00
DAN-21 Modularize QA model for easy swapping (#9)
This commit is contained in:
parent
e7b901f292
commit
c00d37a7d7
@ -18,6 +18,7 @@ MODEL_CACHE_FOLDER = os.environ.get("TRANSFORMERS_CACHE")
|
||||
# Purely an optimization, memory limitation consideration
|
||||
BATCH_SIZE_ENCODE_CHUNKS = 8
|
||||
|
||||
# OpenAI Model API Configs
|
||||
# QA Model API Configs
|
||||
INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-completion")
|
||||
OPENAPI_MODEL_VERSION = "text-davinci-003"
|
||||
OPENAI_MAX_OUTPUT_TOKENS = 200
|
||||
|
@ -7,7 +7,7 @@ from danswer.datastores.qdrant.indexing import index_chunks
|
||||
from danswer.embedding.biencoder import get_default_model
|
||||
from danswer.utils.clients import get_qdrant_client
|
||||
from danswer.utils.logging import setup_logger
|
||||
from danswer.utils.timing import build_timing_wrapper
|
||||
from danswer.utils.timing import log_function_time
|
||||
from qdrant_client.http.exceptions import ResponseHandlingException
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from qdrant_client.http.models import FieldCondition
|
||||
@ -28,7 +28,7 @@ class QdrantDatastore(Datastore):
|
||||
chunks=chunks, collection=self.collection, client=self.client
|
||||
)
|
||||
|
||||
@build_timing_wrapper()
|
||||
@log_function_time()
|
||||
def semantic_retrieval(
|
||||
self, query: str, filters: list[DatastoreFilter] | None, num_to_retrieve: int
|
||||
) -> list[InferenceChunk]:
|
||||
|
@ -0,0 +1,6 @@
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.question_answer import OpenAICompletionQA
|
||||
|
||||
|
||||
def get_default_backend_qa_model() -> QAModel:
|
||||
return OpenAICompletionQA()
|
12
backend/danswer/direct_qa/interfaces.py
Normal file
12
backend/danswer/direct_qa/interfaces.py
Normal file
@ -0,0 +1,12 @@
|
||||
import abc
|
||||
from typing import *
|
||||
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
|
||||
|
||||
class QAModel:
|
||||
@abc.abstractmethod
|
||||
def answer_question(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
|
||||
raise NotImplementedError
|
@ -21,8 +21,3 @@ def generic_prompt_processor(question: str, documents: list[str]) -> str:
|
||||
prompt += f"{QUESTION_PAT}\n{question}\n"
|
||||
prompt += f"{ANSWER_PAT}\n"
|
||||
return prompt
|
||||
|
||||
|
||||
BASIC_QA_PROMPTS = {
|
||||
"generic-qa": generic_prompt_processor,
|
||||
}
|
||||
|
@ -16,12 +16,15 @@ from danswer.configs.constants import SOURCE_LINK
|
||||
from danswer.configs.constants import SOURCE_TYPE
|
||||
from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
|
||||
from danswer.configs.model_configs import OPENAPI_MODEL_VERSION
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.qa_prompts import ANSWER_PAT
|
||||
from danswer.direct_qa.qa_prompts import generic_prompt_processor
|
||||
from danswer.direct_qa.qa_prompts import QUOTE_PAT
|
||||
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
|
||||
from danswer.utils.logging import setup_logger
|
||||
from danswer.utils.text_processing import clean_model_quote
|
||||
from danswer.utils.text_processing import shared_precompare_cleanup
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -29,39 +32,6 @@ logger = setup_logger()
|
||||
openai.api_key = OPENAI_API_KEY
|
||||
|
||||
|
||||
def ask_openai(
|
||||
complete_qa_prompt: str,
|
||||
model: str = OPENAPI_MODEL_VERSION,
|
||||
max_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
|
||||
) -> str:
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
prompt=complete_qa_prompt,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
model_answer = response["choices"][0]["text"].strip()
|
||||
logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
|
||||
return model_answer
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return "Model Failure"
|
||||
|
||||
|
||||
def answer_question(
|
||||
query: str,
|
||||
context_docs: list[str],
|
||||
prompt_processor: Callable[[str, list[str]], str],
|
||||
) -> str:
|
||||
formatted_prompt = prompt_processor(query, context_docs)
|
||||
logger.debug(formatted_prompt)
|
||||
return ask_openai(formatted_prompt)
|
||||
|
||||
|
||||
def separate_answer_quotes(
|
||||
answer_raw: str,
|
||||
) -> Tuple[Optional[str], Optional[list[str]]]:
|
||||
@ -158,9 +128,52 @@ def match_quotes_to_docs(
|
||||
|
||||
def process_answer(
|
||||
answer_raw: str, chunks: list[InferenceChunk]
|
||||
) -> Tuple[Optional[str], Optional[Dict[str, Dict[str, Union[str, int, None]]]]]:
|
||||
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
|
||||
answer, quote_strings = separate_answer_quotes(answer_raw)
|
||||
if not answer or not quote_strings:
|
||||
return None, None
|
||||
quotes_dict = match_quotes_to_docs(quote_strings, chunks)
|
||||
return answer, quotes_dict
|
||||
|
||||
|
||||
class OpenAICompletionQA(QAModel):
|
||||
def __init__(
|
||||
self,
|
||||
prompt_processor: Callable[[str, list[str]], str] = generic_prompt_processor,
|
||||
model_version: str = OPENAPI_MODEL_VERSION,
|
||||
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
|
||||
) -> None:
|
||||
self.prompt_processor = prompt_processor
|
||||
self.model_version = model_version
|
||||
self.max_output_tokens = max_output_tokens
|
||||
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
|
||||
top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
|
||||
filled_prompt = self.prompt_processor(query, top_contents)
|
||||
logger.debug(filled_prompt)
|
||||
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
prompt=filled_prompt,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
)
|
||||
model_output = response["choices"][0]["text"].strip()
|
||||
logger.info(
|
||||
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
model_output = "Model Failure"
|
||||
|
||||
logger.debug(model_output)
|
||||
|
||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
||||
return answer, quotes_dict
|
||||
|
@ -14,7 +14,7 @@ from danswer.configs.model_configs import QUERY_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.datastores.interfaces import Datastore
|
||||
from danswer.datastores.interfaces import DatastoreFilter
|
||||
from danswer.utils.logging import setup_logger
|
||||
from danswer.utils.timing import build_timing_wrapper
|
||||
from danswer.utils.timing import log_function_time
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
@ -32,7 +32,7 @@ cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
|
||||
cross_encoder.max_length = CROSS_EMBED_CONTEXT_SIZE
|
||||
|
||||
|
||||
@build_timing_wrapper()
|
||||
@log_function_time()
|
||||
def semantic_reranking(
|
||||
query: str,
|
||||
chunks: List[InferenceChunk],
|
||||
@ -50,6 +50,7 @@ def semantic_reranking(
|
||||
return ranked_chunks[:filtered_result_set_size]
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def semantic_search(
|
||||
query: str,
|
||||
filters: list[DatastoreFilter] | None,
|
||||
@ -65,4 +66,9 @@ def semantic_search(
|
||||
)
|
||||
return None
|
||||
ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
|
||||
|
||||
top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
|
||||
files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
|
||||
logger.info(files_log_msg)
|
||||
|
||||
return ranked_chunks
|
||||
|
21
backend/danswer/server/models.py
Normal file
21
backend/danswer/server/models.py
Normal file
@ -0,0 +1,21 @@
|
||||
from danswer.datastores.interfaces import DatastoreFilter
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ServerStatus(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class QAQuestion(BaseModel):
|
||||
query: str
|
||||
collection: str
|
||||
filters: list[DatastoreFilter] | None
|
||||
|
||||
|
||||
class QAResponse(BaseModel):
|
||||
answer: str | None
|
||||
quotes: dict[str, dict[str, str | int | None]] | None
|
||||
|
||||
|
||||
class KeywordResponse(BaseModel):
|
||||
results: list[str] | None
|
@ -1,20 +1,19 @@
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
|
||||
from danswer.configs.app_configs import DEFAULT_PROMPT
|
||||
from danswer.configs.app_configs import KEYWORD_MAX_HITS
|
||||
from danswer.configs.constants import CONTENT
|
||||
from danswer.configs.constants import SOURCE_LINKS
|
||||
from danswer.datastores import create_datastore
|
||||
from danswer.datastores.interfaces import DatastoreFilter
|
||||
from danswer.direct_qa.qa_prompts import BASIC_QA_PROMPTS
|
||||
from danswer.direct_qa.question_answer import answer_question
|
||||
from danswer.direct_qa.question_answer import process_answer
|
||||
from danswer.direct_qa import get_default_backend_qa_model
|
||||
from danswer.direct_qa.semantic_search import semantic_search
|
||||
from danswer.server.models import KeywordResponse
|
||||
from danswer.server.models import QAQuestion
|
||||
from danswer.server.models import QAResponse
|
||||
from danswer.server.models import ServerStatus
|
||||
from danswer.utils.clients import TSClient
|
||||
from danswer.utils.logging import setup_logger
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -22,34 +21,16 @@ logger = setup_logger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ServerStatus(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class QAQuestion(BaseModel):
|
||||
query: str
|
||||
collection: str
|
||||
filters: list[DatastoreFilter] | None
|
||||
|
||||
|
||||
class QAResponse(BaseModel):
|
||||
answer: str | None
|
||||
quotes: dict[str, dict[str, str]] | None
|
||||
|
||||
|
||||
class KeywordResponse(BaseModel):
|
||||
results: list[str] | None
|
||||
|
||||
|
||||
@router.get("/", response_model=ServerStatus)
|
||||
@router.get("/status", response_model=ServerStatus)
|
||||
def read_server_status():
|
||||
return {"status": HTTPStatus.OK.value}
|
||||
return ServerStatus(status=HTTPStatus.OK.value)
|
||||
|
||||
|
||||
@router.post("/direct-qa", response_model=QAResponse)
|
||||
def direct_qa(question: QAQuestion):
|
||||
prompt_processor = BASIC_QA_PROMPTS[DEFAULT_PROMPT]
|
||||
start_time = time.time()
|
||||
qa_model = get_default_backend_qa_model()
|
||||
query = question.query
|
||||
collection = question.collection
|
||||
filters = question.filters
|
||||
@ -58,36 +39,14 @@ def direct_qa(question: QAQuestion):
|
||||
|
||||
logger.info(f"Received semantic query: {query}")
|
||||
|
||||
start_time = time.time()
|
||||
ranked_chunks = semantic_search(query, filters, datastore)
|
||||
sem_search_time = time.time()
|
||||
|
||||
logger.info(f"Semantic search took {sem_search_time - start_time} seconds")
|
||||
|
||||
if not ranked_chunks:
|
||||
return {"answer": None, "quotes": None}
|
||||
|
||||
top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
|
||||
top_contents = [ranked_chunk.content for ranked_chunk in ranked_chunks]
|
||||
answer, quotes = qa_model.answer_question(query, ranked_chunks)
|
||||
logger.info(f"Total QA took {time.time() - start_time} seconds")
|
||||
|
||||
files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
|
||||
logger.info(files_log_msg)
|
||||
|
||||
qa_answer = answer_question(query, top_contents, prompt_processor)
|
||||
qa_time = time.time()
|
||||
logger.debug(qa_answer)
|
||||
logger.info(f"GPT QA took {qa_time - sem_search_time} seconds")
|
||||
|
||||
# Postprocessing, no more models involved, purely rule based
|
||||
answer, quotes_dict = process_answer(qa_answer, ranked_chunks)
|
||||
postprocess_time = time.time()
|
||||
logger.info(f"Postprocessing took {postprocess_time - qa_time} seconds")
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.info(f"Total QA took {total_time} seconds")
|
||||
|
||||
qa_response = {"answer": answer, "quotes": quotes_dict}
|
||||
return qa_response
|
||||
return QAResponse(answer=answer, quotes=quotes)
|
||||
|
||||
|
||||
@router.post("/keyword-search", response_model=KeywordResponse)
|
||||
@ -114,4 +73,4 @@ def keyword_search(question: QAQuestion):
|
||||
total_time = time.time() - start_time
|
||||
logger.info(f"Total Keyword Search took {total_time} seconds")
|
||||
|
||||
return {"results": sources}
|
||||
return KeywordResponse(results=sources)
|
||||
|
@ -6,13 +6,13 @@ from danswer.utils.logging import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_timing_wrapper(
|
||||
def log_function_time(
|
||||
func_name: str | None = None,
|
||||
) -> Callable[[Callable], Callable]:
|
||||
"""Build a timing wrapper for a function. Logs how long the function took to run.
|
||||
Use like:
|
||||
|
||||
@build_timing_wrapper()
|
||||
@log_function_time()
|
||||
def my_func():
|
||||
...
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user