From c00d37a7d718584a290eff275147e240e907806f Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Mon, 1 May 2023 23:12:47 -0700
Subject: [PATCH] DAN-21 Modularize QA model for easy swapping (#9)

---
 backend/danswer/configs/model_configs.py     |  3 +-
 backend/danswer/datastores/qdrant/store.py   |  4 +-
 backend/danswer/direct_qa/__init__.py        |  6 ++
 backend/danswer/direct_qa/interfaces.py      | 12 +++
 backend/danswer/direct_qa/qa_prompts.py      |  5 --
 backend/danswer/direct_qa/question_answer.py | 81 ++++++++++++--------
 backend/danswer/direct_qa/semantic_search.py | 10 ++-
 backend/danswer/server/models.py             | 21 +++++
 backend/danswer/server/search_backend.py     | 65 +++------------
 backend/danswer/utils/timing.py              |  4 +-
 10 files changed, 112 insertions(+), 99 deletions(-)
 create mode 100644 backend/danswer/direct_qa/interfaces.py
 create mode 100644 backend/danswer/server/models.py

diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index 403b2986d..d6f050624 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -18,6 +18,7 @@ MODEL_CACHE_FOLDER = os.environ.get("TRANSFORMERS_CACHE")
 # Purely an optimization, memory limitation consideration
 BATCH_SIZE_ENCODE_CHUNKS = 8
 
-# OpenAI Model API Configs
+# QA Model API Configs
+INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-completion")
 OPENAPI_MODEL_VERSION = "text-davinci-003"
 OPENAI_MAX_OUTPUT_TOKENS = 200
diff --git a/backend/danswer/datastores/qdrant/store.py b/backend/danswer/datastores/qdrant/store.py
index d6dc92d51..a8b02c99a 100644
--- a/backend/danswer/datastores/qdrant/store.py
+++ b/backend/danswer/datastores/qdrant/store.py
@@ -7,7 +7,7 @@ from danswer.datastores.qdrant.indexing import index_chunks
 from danswer.embedding.biencoder import get_default_model
 from danswer.utils.clients import get_qdrant_client
 from danswer.utils.logging import setup_logger
-from danswer.utils.timing import build_timing_wrapper
+from danswer.utils.timing import log_function_time
 from qdrant_client.http.exceptions import ResponseHandlingException
 from qdrant_client.http.exceptions import UnexpectedResponse
 from qdrant_client.http.models import FieldCondition
@@ -28,7 +28,7 @@ class QdrantDatastore(Datastore):
             chunks=chunks, collection=self.collection, client=self.client
         )
 
-    @build_timing_wrapper()
+    @log_function_time()
     def semantic_retrieval(
         self, query: str, filters: list[DatastoreFilter] | None, num_to_retrieve: int
     ) -> list[InferenceChunk]:
diff --git a/backend/danswer/direct_qa/__init__.py b/backend/danswer/direct_qa/__init__.py
index e69de29bb..0f82721a2 100644
--- a/backend/danswer/direct_qa/__init__.py
+++ b/backend/danswer/direct_qa/__init__.py
@@ -0,0 +1,6 @@
+from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.question_answer import OpenAICompletionQA
+
+
+def get_default_backend_qa_model() -> QAModel:
+    return OpenAICompletionQA()
diff --git a/backend/danswer/direct_qa/interfaces.py b/backend/danswer/direct_qa/interfaces.py
new file mode 100644
index 000000000..d1d9b939b
--- /dev/null
+++ b/backend/danswer/direct_qa/interfaces.py
@@ -0,0 +1,12 @@
+import abc
+from typing import *
+
+from danswer.chunking.models import InferenceChunk
+
+
+class QAModel:
+    @abc.abstractmethod
+    def answer_question(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
+        raise NotImplementedError
diff --git a/backend/danswer/direct_qa/qa_prompts.py b/backend/danswer/direct_qa/qa_prompts.py
index 2a36624b2..199f5f205 100644
--- a/backend/danswer/direct_qa/qa_prompts.py
+++ b/backend/danswer/direct_qa/qa_prompts.py
@@ -21,8 +21,3 @@ def generic_prompt_processor(question: str, documents: list[str]) -> str:
     prompt += f"{QUESTION_PAT}\n{question}\n"
     prompt += f"{ANSWER_PAT}\n"
     return prompt
-
-
-BASIC_QA_PROMPTS = {
-    "generic-qa": generic_prompt_processor,
-}
diff --git a/backend/danswer/direct_qa/question_answer.py b/backend/danswer/direct_qa/question_answer.py
index afb3a29ef..ff05c0858 100644
--- a/backend/danswer/direct_qa/question_answer.py
+++ b/backend/danswer/direct_qa/question_answer.py
@@ -16,12 +16,15 @@ from danswer.configs.constants import SOURCE_LINK
 from danswer.configs.constants import SOURCE_TYPE
 from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import OPENAPI_MODEL_VERSION
+from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import ANSWER_PAT
+from danswer.direct_qa.qa_prompts import generic_prompt_processor
 from danswer.direct_qa.qa_prompts import QUOTE_PAT
 from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
 from danswer.utils.logging import setup_logger
 from danswer.utils.text_processing import clean_model_quote
 from danswer.utils.text_processing import shared_precompare_cleanup
+from danswer.utils.timing import log_function_time
 
 
 logger = setup_logger()
@@ -29,39 +32,6 @@ logger = setup_logger()
 openai.api_key = OPENAI_API_KEY
 
 
-def ask_openai(
-    complete_qa_prompt: str,
-    model: str = OPENAPI_MODEL_VERSION,
-    max_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
-) -> str:
-    try:
-        response = openai.Completion.create(
-            prompt=complete_qa_prompt,
-            temperature=0,
-            top_p=1,
-            frequency_penalty=0,
-            presence_penalty=0,
-            model=model,
-            max_tokens=max_tokens,
-        )
-        model_answer = response["choices"][0]["text"].strip()
-        logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
-        return model_answer
-    except Exception as e:
-        logger.exception(e)
-        return "Model Failure"
-
-
-def answer_question(
-    query: str,
-    context_docs: list[str],
-    prompt_processor: Callable[[str, list[str]], str],
-) -> str:
-    formatted_prompt = prompt_processor(query, context_docs)
-    logger.debug(formatted_prompt)
-    return ask_openai(formatted_prompt)
-
-
 def separate_answer_quotes(
     answer_raw: str,
 ) -> Tuple[Optional[str], Optional[list[str]]]:
@@ -158,9 +128,52 @@ def match_quotes_to_docs(
 
 def process_answer(
     answer_raw: str, chunks: list[InferenceChunk]
-) -> Tuple[Optional[str], Optional[Dict[str, Dict[str, Union[str, int, None]]]]]:
+) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
     answer, quote_strings = separate_answer_quotes(answer_raw)
     if not answer or not quote_strings:
         return None, None
     quotes_dict = match_quotes_to_docs(quote_strings, chunks)
     return answer, quotes_dict
+
+
+class OpenAICompletionQA(QAModel):
+    def __init__(
+        self,
+        prompt_processor: Callable[[str, list[str]], str] = generic_prompt_processor,
+        model_version: str = OPENAPI_MODEL_VERSION,
+        max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
+    ) -> None:
+        self.prompt_processor = prompt_processor
+        self.model_version = model_version
+        self.max_output_tokens = max_output_tokens
+
+    @log_function_time()
+    def answer_question(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
+        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
+        filled_prompt = self.prompt_processor(query, top_contents)
+        logger.debug(filled_prompt)
+
+        try:
+            response = openai.Completion.create(
+                prompt=filled_prompt,
+                temperature=0,
+                top_p=1,
+                frequency_penalty=0,
+                presence_penalty=0,
+                model=self.model_version,
+                max_tokens=self.max_output_tokens,
+            )
+            model_output = response["choices"][0]["text"].strip()
+            logger.info(
+                "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
+            )
+        except Exception as e:
+            logger.exception(e)
+            model_output = "Model Failure"
+
+        logger.debug(model_output)
+
+        answer, quotes_dict = process_answer(model_output, context_docs)
+        return answer, quotes_dict
diff --git a/backend/danswer/direct_qa/semantic_search.py b/backend/danswer/direct_qa/semantic_search.py
index 8c8c31d94..3b033c5cb 100644
--- a/backend/danswer/direct_qa/semantic_search.py
+++ b/backend/danswer/direct_qa/semantic_search.py
@@ -14,7 +14,7 @@ from danswer.configs.model_configs import QUERY_EMBEDDING_CONTEXT_SIZE
 from danswer.datastores.interfaces import Datastore
 from danswer.datastores.interfaces import DatastoreFilter
 from danswer.utils.logging import setup_logger
-from danswer.utils.timing import build_timing_wrapper
+from danswer.utils.timing import log_function_time
 from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
 
@@ -32,7 +32,7 @@ cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
 cross_encoder.max_length = CROSS_EMBED_CONTEXT_SIZE
 
 
-@build_timing_wrapper()
+@log_function_time()
 def semantic_reranking(
     query: str,
     chunks: List[InferenceChunk],
@@ -50,6 +50,7 @@ def semantic_reranking(
     return ranked_chunks[:filtered_result_set_size]
 
 
+@log_function_time()
 def semantic_search(
     query: str,
     filters: list[DatastoreFilter] | None,
@@ -65,4 +66,9 @@ def semantic_search(
         )
         return None
     ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
+
+    top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
+    files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
+    logger.info(files_log_msg)
+
     return ranked_chunks
diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py
new file mode 100644
index 000000000..32cffe090
--- /dev/null
+++ b/backend/danswer/server/models.py
@@ -0,0 +1,21 @@
+from danswer.datastores.interfaces import DatastoreFilter
+from pydantic import BaseModel
+
+
+class ServerStatus(BaseModel):
+    status: str
+
+
+class QAQuestion(BaseModel):
+    query: str
+    collection: str
+    filters: list[DatastoreFilter] | None
+
+
+class QAResponse(BaseModel):
+    answer: str | None
+    quotes: dict[str, dict[str, str | int | None]] | None
+
+
+class KeywordResponse(BaseModel):
+    results: list[str] | None
diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py
index 1068d421c..f5e777791 100644
--- a/backend/danswer/server/search_backend.py
+++ b/backend/danswer/server/search_backend.py
@@ -1,20 +1,19 @@
 import time
 from http import HTTPStatus
 
-from danswer.configs.app_configs import DEFAULT_PROMPT
 from danswer.configs.app_configs import KEYWORD_MAX_HITS
 from danswer.configs.constants import CONTENT
 from danswer.configs.constants import SOURCE_LINKS
 from danswer.datastores import create_datastore
-from danswer.datastores.interfaces import DatastoreFilter
-from danswer.direct_qa.qa_prompts import BASIC_QA_PROMPTS
-from danswer.direct_qa.question_answer import answer_question
-from danswer.direct_qa.question_answer import process_answer
+from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.semantic_search import semantic_search
+from danswer.server.models import KeywordResponse
+from danswer.server.models import QAQuestion
+from danswer.server.models import QAResponse
+from danswer.server.models import ServerStatus
 from danswer.utils.clients import TSClient
 from danswer.utils.logging import setup_logger
 from fastapi import APIRouter
-from pydantic import BaseModel
 
 
 logger = setup_logger()
@@ -22,34 +21,16 @@ logger = setup_logger()
 router = APIRouter()
 
 
-class ServerStatus(BaseModel):
-    status: str
-
-
-class QAQuestion(BaseModel):
-    query: str
-    collection: str
-    filters: list[DatastoreFilter] | None
-
-
-class QAResponse(BaseModel):
-    answer: str | None
-    quotes: dict[str, dict[str, str]] | None
-
-
-class KeywordResponse(BaseModel):
-    results: list[str] | None
-
-
 @router.get("/", response_model=ServerStatus)
 @router.get("/status", response_model=ServerStatus)
 def read_server_status():
-    return {"status": HTTPStatus.OK.value}
+    return ServerStatus(status=HTTPStatus.OK.value)
 
 
 @router.post("/direct-qa", response_model=QAResponse)
 def direct_qa(question: QAQuestion):
-    prompt_processor = BASIC_QA_PROMPTS[DEFAULT_PROMPT]
+    start_time = time.time()
+    qa_model = get_default_backend_qa_model()
     query = question.query
     collection = question.collection
     filters = question.filters
@@ -58,36 +39,14 @@ def direct_qa(question: QAQuestion):
 
     logger.info(f"Received semantic query: {query}")
 
-    start_time = time.time()
     ranked_chunks = semantic_search(query, filters, datastore)
-    sem_search_time = time.time()
-
-    logger.info(f"Semantic search took {sem_search_time - start_time} seconds")
-
     if not ranked_chunks:
         return {"answer": None, "quotes": None}
 
-    top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
-    top_contents = [ranked_chunk.content for ranked_chunk in ranked_chunks]
+    answer, quotes = qa_model.answer_question(query, ranked_chunks)
+    logger.info(f"Total QA took {time.time() - start_time} seconds")
 
-    files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
-    logger.info(files_log_msg)
-
-    qa_answer = answer_question(query, top_contents, prompt_processor)
-    qa_time = time.time()
-    logger.debug(qa_answer)
-    logger.info(f"GPT QA took {qa_time - sem_search_time} seconds")
-
-    # Postprocessing, no more models involved, purely rule based
-    answer, quotes_dict = process_answer(qa_answer, ranked_chunks)
-    postprocess_time = time.time()
-    logger.info(f"Postprocessing took {postprocess_time - qa_time} seconds")
-
-    total_time = time.time() - start_time
-    logger.info(f"Total QA took {total_time} seconds")
-
-    qa_response = {"answer": answer, "quotes": quotes_dict}
-    return qa_response
+    return QAResponse(answer=answer, quotes=quotes)
 
 
 @router.post("/keyword-search", response_model=KeywordResponse)
@@ -114,4 +73,4 @@ def keyword_search(question: QAQuestion):
     total_time = time.time() - start_time
     logger.info(f"Total Keyword Search took {total_time} seconds")
 
-    return {"results": sources}
+    return KeywordResponse(results=sources)
diff --git a/backend/danswer/utils/timing.py b/backend/danswer/utils/timing.py
index c6d1ab702..c7d82a8ad 100644
--- a/backend/danswer/utils/timing.py
+++ b/backend/danswer/utils/timing.py
@@ -6,13 +6,13 @@ from danswer.utils.logging import setup_logger
 logger = setup_logger()
 
 
-def build_timing_wrapper(
+def log_function_time(
     func_name: str | None = None,
 ) -> Callable[[Callable], Callable]:
     """Build a timing wrapper for a function.
     Logs how long the function took to run.
     Use like:
-    @build_timing_wrapper()
+    @log_function_time()
     def my_func():
         ...
     """
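
Illustration (not part of the patch): the point of this change is that the API layer now depends only on the QAModel interface, so swapping QA backends means returning a different implementation from get_default_backend_qa_model rather than editing the route handlers. Below is a minimal sketch of a custom backend, assuming the patched danswer modules above are importable; DummyQA is a hypothetical stand-in, not a class from this repository.

from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.interfaces import QAModel


class DummyQA(QAModel):
    """Hypothetical stand-in backend: echoes the query and returns no quotes."""

    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
        # A real backend would fill a prompt with the chunk contents and call
        # a model, as OpenAICompletionQA does; here we just echo the query.
        return f"Echo: {query}", None


def get_default_backend_qa_model() -> QAModel:
    # Mirrors backend/danswer/direct_qa/__init__.py: the factory is the single
    # swap point. Branching on INTERNAL_MODEL_VERSION here would let the
    # INTERNAL_MODEL env var introduced in model_configs.py pick the backend.
    return DummyQA()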