DAN-26 Enable GPT-4 through chat completion endpoint (#10)

Also touched up the front-page README, which had a typo
Yuhong Sun
2023-05-02 12:08:44 -07:00
committed by GitHub
parent c00d37a7d7
commit 22b7f7e89f
6 changed files with 102 additions and 7 deletions

View File

@@ -19,6 +19,7 @@ MODEL_CACHE_FOLDER = os.environ.get("TRANSFORMERS_CACHE")
 BATCH_SIZE_ENCODE_CHUNKS = 8

 # QA Model API Configs
 # https://platform.openai.com/docs/models/model-endpoint-compatibility
+INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-completion")
-OPENAPI_MODEL_VERSION = "text-davinci-003"
-OPENAI_MAX_OUTPUT_TOKENS = 200
+OPENAPI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "text-davinci-003")
+OPENAI_MAX_OUTPUT_TOKENS = 400
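
Note: with these changes both the QA backend and the model name are selectable via environment variables read at import time. A minimal sketch of opting into the chat endpoint (the "gpt-4" value is an example consistent with the commit title, not a default set by this commit):

    import os

    # Must be set before danswer.configs.model_configs is imported,
    # since the module reads the environment at import time.
    os.environ["INTERNAL_MODEL"] = "openai-chat-completion"
    os.environ["OPENAI_MODEL_VERSION"] = "gpt-4"  # example value, not a commit default

    from danswer.configs.model_configs import (
        INTERNAL_MODEL_VERSION,
        OPENAPI_MODEL_VERSION,
    )
    assert INTERNAL_MODEL_VERSION == "openai-chat-completion"
    assert OPENAPI_MODEL_VERSION == "gpt-4"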

View File

@@ -1,6 +1,15 @@
+from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.question_answer import OpenAIChatCompletionQA
 from danswer.direct_qa.question_answer import OpenAICompletionQA


-def get_default_backend_qa_model() -> QAModel:
-    return OpenAICompletionQA()
+def get_default_backend_qa_model(
+    internal_model: str = INTERNAL_MODEL_VERSION,
+) -> QAModel:
+    if internal_model == "openai-completion":
+        return OpenAICompletionQA()
+    elif internal_model == "openai-chat-completion":
+        return OpenAIChatCompletionQA()
+    else:
+        raise ValueError("Wrong internal QA model set.")
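
Note: the factory is a plain string dispatch, so callers can rely on the environment-driven default or pass a backend name explicitly. A usage sketch (the import path is an assumption; it is not visible in this diff):

    # Assumed import path for the factory shown above.
    from danswer.direct_qa import get_default_backend_qa_model

    qa_model = get_default_backend_qa_model()  # honors INTERNAL_MODEL_VERSION
    chat_qa_model = get_default_backend_qa_model("openai-chat-completion")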

View File

@@ -21,3 +21,42 @@ def generic_prompt_processor(question: str, documents: list[str]) -> str:
     prompt += f"{QUESTION_PAT}\n{question}\n"
     prompt += f"{ANSWER_PAT}\n"
     return prompt
+
+
+def openai_chat_completion_processor(
+    question: str, documents: list[str]
+) -> list[dict[str, str]]:
+    sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each."
+    role_msg = (
+        f"You are a Question Answering assistant that answers queries based on provided documents. "
+        f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and '
+        f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. '
+        f"Answer the question directly and concisely. "
+        f"Each quote should be a single continuous segment from a document. "
+        f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
+        f"An example of quote sections may look like:\n{sample_quote}"
+    )
+
+    messages = [
+        {"role": "system", "content": role_msg},
+    ]
+    for document in documents:
+        messages.extend(
+            [
+                {
+                    "role": "user",
+                    "content": f"Acknowledge the following document:\n{document}",
+                },
+                {"role": "assistant", "content": "Acknowledged"},
+            ]
+        )
+
+    messages.append(
+        {
+            "role": "user",
+            "content": f"Please now answer the following query based on the previously provided "
+            f"documents and quote the relevant sections of the documents\n{question}",
+        },
+    )
+
+    return messages
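
Note: the acknowledge/"Acknowledged" turns feed each document as its own exchange, keeping the final user turn focused on the question. To make the layout concrete, a sketch of what the processor returns for one document (import path assumed from this file's module):

    from danswer.direct_qa.qa_prompts import openai_chat_completion_processor

    messages = openai_chat_completion_processor(
        "How much do the hotdogs cost?",
        ["The hotdogs are freshly cooked. They are very cheap at only a dollar each."],
    )
    # Roughly:
    # [
    #     {"role": "system", "content": "You are a Question Answering assistant ..."},
    #     {"role": "user", "content": "Acknowledge the following document:\n..."},
    #     {"role": "assistant", "content": "Acknowledged"},
    #     {"role": "user", "content": "Please now answer the following query ..."},
    # ]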

View File

@@ -19,6 +19,7 @@ from danswer.configs.model_configs import OPENAPI_MODEL_VERSION
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import ANSWER_PAT
 from danswer.direct_qa.qa_prompts import generic_prompt_processor
+from danswer.direct_qa.qa_prompts import openai_chat_completion_processor
 from danswer.direct_qa.qa_prompts import QUOTE_PAT
 from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
 from danswer.utils.logging import setup_logger
@@ -177,3 +178,48 @@ class OpenAICompletionQA(QAModel):
         answer, quotes_dict = process_answer(model_output, context_docs)
         return answer, quotes_dict
+
+
+class OpenAIChatCompletionQA(QAModel):
+    def __init__(
+        self,
+        prompt_processor: Callable[
+            [str, list[str]], list[dict[str, str]]
+        ] = openai_chat_completion_processor,
+        model_version: str = OPENAPI_MODEL_VERSION,
+        max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
+    ) -> None:
+        self.prompt_processor = prompt_processor
+        self.model_version = model_version
+        self.max_output_tokens = max_output_tokens
+
+    @log_function_time()
+    def answer_question(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
+        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
+        messages = self.prompt_processor(query, top_contents)
+        logger.debug(messages)
+
+        try:
+            response = openai.ChatCompletion.create(
+                messages=messages,
+                temperature=0,
+                top_p=1,
+                frequency_penalty=0,
+                presence_penalty=0,
+                model=self.model_version,
+                max_tokens=self.max_output_tokens,
+            )
+            model_output = response["choices"][0]["message"]["content"].strip()
+            logger.info(
+                "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
+            )
+        except Exception as e:
+            logger.exception(e)
+            model_output = "Model Failure"
+
+        logger.debug(model_output)
+
+        answer, quotes_dict = process_answer(model_output, context_docs)
+        return answer, quotes_dict
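
Note: the chat model is a drop-in QAModel, sharing process_answer with the completion path. A minimal sketch of calling it directly (assumes the OpenAI API key is configured elsewhere; retrieved_chunks is a placeholder for a list[InferenceChunk] from the retrieval pipeline, and "gpt-4" is an example per the commit title):

    qa = OpenAIChatCompletionQA(model_version="gpt-4")
    answer, quotes = qa.answer_question(
        "How much do the hotdogs cost?",
        retrieved_chunks,  # placeholder: list[InferenceChunk] from retrieval
    )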

View File

@@ -4,7 +4,7 @@ filelock==3.12.0
 google-api-python-client==2.86.0
 google-auth-httplib2==0.1.0
 google-auth-oauthlib==1.0.0
-openai==0.27.2
+openai==0.27.6
 playwright==1.32.1
 pydantic==1.10.7
 PyPDF2==3.0.1