DAN-21 Modularize QA model for easy swapping (#9)

Yuhong Sun 2023-05-01 23:12:47 -07:00 committed by GitHub
parent e7b901f292
commit c00d37a7d7
10 changed files with 112 additions and 99 deletions

@@ -18,6 +18,7 @@ MODEL_CACHE_FOLDER = os.environ.get("TRANSFORMERS_CACHE")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# OpenAI Model API Configs
# QA Model API Configs
INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-completion")
OPENAPI_MODEL_VERSION = "text-davinci-003"
OPENAI_MAX_OUTPUT_TOKENS = 200
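
The INTERNAL_MODEL environment variable read above is what a deployment sets to choose the QA backend. A minimal sketch of overriding it before the config module is imported (the value shown is just the default from this diff):

import os

# Must run before danswer.configs.model_configs is imported;
# "openai-completion" is the default backend named above.
os.environ["INTERNAL_MODEL"] = "openai-completion"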

@@ -7,7 +7,7 @@ from danswer.datastores.qdrant.indexing import index_chunks
from danswer.embedding.biencoder import get_default_model
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logging import setup_logger
from danswer.utils.timing import build_timing_wrapper
from danswer.utils.timing import log_function_time
from qdrant_client.http.exceptions import ResponseHandlingException
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import FieldCondition
@@ -28,7 +28,7 @@ class QdrantDatastore(Datastore):
chunks=chunks, collection=self.collection, client=self.client
)
@build_timing_wrapper()
@log_function_time()
def semantic_retrieval(
self, query: str, filters: list[DatastoreFilter] | None, num_to_retrieve: int
) -> list[InferenceChunk]:

@@ -0,0 +1,6 @@
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.question_answer import OpenAICompletionQA
def get_default_backend_qa_model() -> QAModel:
return OpenAICompletionQA()
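
This new factory is the single place where a different backend would be wired in. A hedged sketch of how it might grow once a second QAModel implementation exists; the dispatch on INTERNAL_MODEL_VERSION and the error branch are illustrative, not part of this commit:

from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.question_answer import OpenAICompletionQA


def get_default_backend_qa_model(
    internal_model: str = INTERNAL_MODEL_VERSION,
) -> QAModel:
    # Only the OpenAI completion backend exists in this commit; any other
    # value would map to a hypothetical future implementation.
    if internal_model == "openai-completion":
        return OpenAICompletionQA()
    raise ValueError(f"Unknown internal QA model: {internal_model}")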

@@ -0,0 +1,12 @@
import abc
from typing import *
from danswer.chunking.models import InferenceChunk
class QAModel:
@abc.abstractmethod
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
raise NotImplementedError
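
QAModel is the abstraction the rest of the code now depends on, so swapping backends means implementing answer_question with this exact signature. A minimal illustrative implementation (EchoQA is made up for this example and performs no quote extraction):

from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.interfaces import QAModel


class EchoQA(QAModel):
    """Toy backend that returns the top retrieved chunk as the answer."""

    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
        if not context_docs:
            return None, None
        # No quotes are extracted, so the second element stays None.
        return context_docs[0].content, None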

@@ -21,8 +21,3 @@ def generic_prompt_processor(question: str, documents: list[str]) -> str:
prompt += f"{QUESTION_PAT}\n{question}\n"
prompt += f"{ANSWER_PAT}\n"
return prompt
BASIC_QA_PROMPTS = {
"generic-qa": generic_prompt_processor,
}
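
With the prompt registry removed, callers use the prompt builder directly. A small usage sketch based on the signature above (the question and document strings are invented):

from danswer.direct_qa.qa_prompts import generic_prompt_processor

# Produces the full prompt ending with the QUESTION_PAT / ANSWER_PAT sections.
prompt = generic_prompt_processor(
    "How do I swap the QA backend?",
    ["First document contents...", "Second document contents..."],
)
print(prompt)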

@@ -16,12 +16,15 @@ from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import OPENAPI_MODEL_VERSION
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import generic_prompt_processor
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.utils.logging import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup
from danswer.utils.timing import log_function_time
logger = setup_logger()
@@ -29,39 +32,6 @@ logger = setup_logger()
openai.api_key = OPENAI_API_KEY
def ask_openai(
complete_qa_prompt: str,
model: str = OPENAPI_MODEL_VERSION,
max_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
) -> str:
try:
response = openai.Completion.create(
prompt=complete_qa_prompt,
temperature=0,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
model=model,
max_tokens=max_tokens,
)
model_answer = response["choices"][0]["text"].strip()
logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
return model_answer
except Exception as e:
logger.exception(e)
return "Model Failure"
def answer_question(
query: str,
context_docs: list[str],
prompt_processor: Callable[[str, list[str]], str],
) -> str:
formatted_prompt = prompt_processor(query, context_docs)
logger.debug(formatted_prompt)
return ask_openai(formatted_prompt)
def separate_answer_quotes(
answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
@@ -158,9 +128,52 @@ def match_quotes_to_docs(
def process_answer(
answer_raw: str, chunks: list[InferenceChunk]
) -> Tuple[Optional[str], Optional[Dict[str, Dict[str, Union[str, int, None]]]]]:
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
answer, quote_strings = separate_answer_quotes(answer_raw)
if not answer or not quote_strings:
return None, None
quotes_dict = match_quotes_to_docs(quote_strings, chunks)
return answer, quotes_dict
class OpenAICompletionQA(QAModel):
def __init__(
self,
prompt_processor: Callable[[str, list[str]], str] = generic_prompt_processor,
model_version: str = OPENAPI_MODEL_VERSION,
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
filled_prompt = self.prompt_processor(query, top_contents)
logger.debug(filled_prompt)
try:
response = openai.Completion.create(
prompt=filled_prompt,
temperature=0,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
model=self.model_version,
max_tokens=self.max_output_tokens,
)
model_output = response["choices"][0]["text"].strip()
logger.info(
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
)
except Exception as e:
logger.exception(e)
model_output = "Model Failure"
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
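
End to end, a caller now only needs the factory and a list of retrieved chunks. A sketch assuming OPENAI_API_KEY is configured and the chunks come from semantic_search (the query string is invented):

from danswer.chunking.models import InferenceChunk
from danswer.direct_qa import get_default_backend_qa_model


def answer_with_default_model(query: str, chunks: list[InferenceChunk]) -> None:
    qa_model = get_default_backend_qa_model()  # OpenAICompletionQA by default
    answer, quotes = qa_model.answer_question(query, chunks)
    if answer is None:
        print("No parseable answer was produced")
    else:
        print(answer)
        print(quotes)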

@@ -14,7 +14,7 @@ from danswer.configs.model_configs import QUERY_EMBEDDING_CONTEXT_SIZE
from danswer.datastores.interfaces import Datastore
from danswer.datastores.interfaces import DatastoreFilter
from danswer.utils.logging import setup_logger
from danswer.utils.timing import build_timing_wrapper
from danswer.utils.timing import log_function_time
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
@@ -32,7 +32,7 @@ cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
cross_encoder.max_length = CROSS_EMBED_CONTEXT_SIZE
@build_timing_wrapper()
@log_function_time()
def semantic_reranking(
query: str,
chunks: List[InferenceChunk],
@@ -50,6 +50,7 @@ def semantic_reranking(
return ranked_chunks[:filtered_result_set_size]
@log_function_time()
def semantic_search(
query: str,
filters: list[DatastoreFilter] | None,
@@ -65,4 +66,9 @@ def semantic_search(
)
return None
ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
logger.info(files_log_msg)
return ranked_chunks

@@ -0,0 +1,21 @@
from danswer.datastores.interfaces import DatastoreFilter
from pydantic import BaseModel
class ServerStatus(BaseModel):
status: str
class QAQuestion(BaseModel):
query: str
collection: str
filters: list[DatastoreFilter] | None
class QAResponse(BaseModel):
answer: str | None
quotes: dict[str, dict[str, str | int | None]] | None
class KeywordResponse(BaseModel):
results: list[str] | None
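
These Pydantic models are the request/response contract for the endpoints below. A small construction sketch (field values are illustrative):

from danswer.server.models import QAQuestion, QAResponse

question = QAQuestion(
    query="How do I swap the QA backend?",
    collection="semantic_search",
    filters=None,
)
empty_result = QAResponse(answer=None, quotes=None)  # shape of a no-answer response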

@@ -1,20 +1,19 @@
import time
from http import HTTPStatus
from danswer.configs.app_configs import DEFAULT_PROMPT
from danswer.configs.app_configs import KEYWORD_MAX_HITS
from danswer.configs.constants import CONTENT
from danswer.configs.constants import SOURCE_LINKS
from danswer.datastores import create_datastore
from danswer.datastores.interfaces import DatastoreFilter
from danswer.direct_qa.qa_prompts import BASIC_QA_PROMPTS
from danswer.direct_qa.question_answer import answer_question
from danswer.direct_qa.question_answer import process_answer
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.semantic_search import semantic_search
from danswer.server.models import KeywordResponse
from danswer.server.models import QAQuestion
from danswer.server.models import QAResponse
from danswer.server.models import ServerStatus
from danswer.utils.clients import TSClient
from danswer.utils.logging import setup_logger
from fastapi import APIRouter
from pydantic import BaseModel
logger = setup_logger()
@@ -22,34 +21,16 @@ logger = setup_logger()
router = APIRouter()
class ServerStatus(BaseModel):
status: str
class QAQuestion(BaseModel):
query: str
collection: str
filters: list[DatastoreFilter] | None
class QAResponse(BaseModel):
answer: str | None
quotes: dict[str, dict[str, str]] | None
class KeywordResponse(BaseModel):
results: list[str] | None
@router.get("/", response_model=ServerStatus)
@router.get("/status", response_model=ServerStatus)
def read_server_status():
return {"status": HTTPStatus.OK.value}
return ServerStatus(status=HTTPStatus.OK.value)
@router.post("/direct-qa", response_model=QAResponse)
def direct_qa(question: QAQuestion):
prompt_processor = BASIC_QA_PROMPTS[DEFAULT_PROMPT]
start_time = time.time()
qa_model = get_default_backend_qa_model()
query = question.query
collection = question.collection
filters = question.filters
@@ -58,36 +39,14 @@ def direct_qa(question: QAQuestion):
logger.info(f"Received semantic query: {query}")
start_time = time.time()
ranked_chunks = semantic_search(query, filters, datastore)
sem_search_time = time.time()
logger.info(f"Semantic search took {sem_search_time - start_time} seconds")
if not ranked_chunks:
return {"answer": None, "quotes": None}
top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
top_contents = [ranked_chunk.content for ranked_chunk in ranked_chunks]
answer, quotes = qa_model.answer_question(query, ranked_chunks)
logger.info(f"Total QA took {time.time() - start_time} seconds")
files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
logger.info(files_log_msg)
qa_answer = answer_question(query, top_contents, prompt_processor)
qa_time = time.time()
logger.debug(qa_answer)
logger.info(f"GPT QA took {qa_time - sem_search_time} seconds")
# Postprocessing, no more models involved, purely rule based
answer, quotes_dict = process_answer(qa_answer, ranked_chunks)
postprocess_time = time.time()
logger.info(f"Postprocessing took {postprocess_time - qa_time} seconds")
total_time = time.time() - start_time
logger.info(f"Total QA took {total_time} seconds")
qa_response = {"answer": answer, "quotes": quotes_dict}
return qa_response
return QAResponse(answer=answer, quotes=quotes)
@router.post("/keyword-search", response_model=KeywordResponse)
@@ -114,4 +73,4 @@ def keyword_search(question: QAQuestion):
total_time = time.time() - start_time
logger.info(f"Total Keyword Search took {total_time} seconds")
return {"results": sources}
return KeywordResponse(results=sources)
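
From a client's perspective, /direct-qa now returns the QAResponse shape directly. A hedged example using the requests library; the host, port, and collection name are assumptions, not part of this commit:

import requests

resp = requests.post(
    "http://localhost:8080/direct-qa",  # host/port depend on the deployment
    json={
        "query": "How do I swap the QA backend?",
        "collection": "semantic_search",
        "filters": None,
    },
)
print(resp.json())  # {"answer": ..., "quotes": {...}} per QAResponse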

@@ -6,13 +6,13 @@ from danswer.utils.logging import setup_logger
logger = setup_logger()
def build_timing_wrapper(
def log_function_time(
func_name: str | None = None,
) -> Callable[[Callable], Callable]:
"""Build a timing wrapper for a function. Logs how long the function took to run.
Use like:
@build_timing_wrapper()
@log_function_time()
def my_func():
...
"""