Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-28 13:53:28 +02:00)
DAN-26 Enable GPT 4 through chat completion endpoint (#10)
Also touched up the front-page README, which had a typo.
@@ -19,6 +19,7 @@ MODEL_CACHE_FOLDER = os.environ.get("TRANSFORMERS_CACHE")
 BATCH_SIZE_ENCODE_CHUNKS = 8
 
 # QA Model API Configs
 # https://platform.openai.com/docs/models/model-endpoint-compatibility
-OPENAPI_MODEL_VERSION = "text-davinci-003"
-OPENAI_MAX_OUTPUT_TOKENS = 200
+INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-completion")
+OPENAPI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "text-davinci-003")
+OPENAI_MAX_OUTPUT_TOKENS = 400
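Note: the new INTERNAL_MODEL and OPENAI_MODEL_VERSION environment variables are what let a deployment opt into GPT-4. A minimal sketch of setting them before the config module is first imported; the exported values here are illustrative, not part of this diff:

    import os

    # Illustrative values: select the chat-completion backend and a
    # chat-capable model before danswer's config module is imported.
    os.environ["INTERNAL_MODEL"] = "openai-chat-completion"
    os.environ["OPENAI_MODEL_VERSION"] = "gpt-4"

    from danswer.configs.model_configs import (
        INTERNAL_MODEL_VERSION,
        OPENAPI_MODEL_VERSION,
    )

    print(INTERNAL_MODEL_VERSION)  # openai-chat-completion
    print(OPENAPI_MODEL_VERSION)   # gpt-4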
@@ -1,6 +1,15 @@
+from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.question_answer import OpenAIChatCompletionQA
 from danswer.direct_qa.question_answer import OpenAICompletionQA
 
 
-def get_default_backend_qa_model() -> QAModel:
-    return OpenAICompletionQA()
+def get_default_backend_qa_model(
+    internal_model: str = INTERNAL_MODEL_VERSION,
+) -> QAModel:
+    if internal_model == "openai-completion":
+        return OpenAICompletionQA()
+    elif internal_model == "openai-chat-completion":
+        return OpenAIChatCompletionQA()
+    else:
+        raise ValueError("Wrong internal QA model set.")
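Note: a usage sketch of the new factory parameter, assuming this hunk is the danswer.direct_qa package init so the function is importable from there (the file path is not shown in this view):

    from danswer.direct_qa import get_default_backend_qa_model

    # Default: backend resolved from the INTERNAL_MODEL env var
    # via INTERNAL_MODEL_VERSION.
    qa_model = get_default_backend_qa_model()

    # Explicit override, e.g. for tests:
    chat_qa = get_default_backend_qa_model(internal_model="openai-chat-completion")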
@@ -21,3 +21,42 @@ def generic_prompt_processor(question: str, documents: list[str]) -> str:
     prompt += f"{QUESTION_PAT}\n{question}\n"
     prompt += f"{ANSWER_PAT}\n"
     return prompt
+
+
+def openai_chat_completion_processor(
+    question: str, documents: list[str]
+) -> list[dict[str, str]]:
+    sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each."
+    role_msg = (
+        f"You are a Question Answering assistant that answers queries based on provided documents. "
+        f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and '
+        f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. '
+        f"Answer the question directly and concisely. "
+        f"Each quote should be a single continuous segment from a document. "
+        f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
+        f"An example of quote sections may look like:\n{sample_quote}"
+    )
+
+    messages = [
+        {"role": "system", "content": role_msg},
+    ]
+    for document in documents:
+        messages.extend(
+            [
+                {
+                    "role": "user",
+                    "content": f"Acknowledge the following document:\n{document}",
+                },
+                {"role": "assistant", "content": "Acknowledged"},
+            ]
+        )
+
+    messages.append(
+        {
+            "role": "user",
+            "content": f"Please now answer the following query based on the previously provided "
+            f"documents and quote the relevant sections of the documents\n{question}",
+        },
+    )
+
+    return messages
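Note: unlike generic_prompt_processor, which builds a single prompt string, this processor returns a ChatCompletion-style message list, feeding each document in as its own user turn with a canned "Acknowledged" assistant reply. Roughly what it produces for one document (contents abbreviated):

    messages = openai_chat_completion_processor(
        "How much do the hotdogs cost?",
        ["The hotdogs are freshly cooked. They are very cheap at only a dollar each."],
    )
    # [
    #     {"role": "system", "content": "You are a Question Answering assistant ..."},
    #     {"role": "user", "content": "Acknowledge the following document:\n..."},
    #     {"role": "assistant", "content": "Acknowledged"},
    #     {"role": "user", "content": "Please now answer the following query ..."},
    # ]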
@@ -19,6 +19,7 @@ from danswer.configs.model_configs import OPENAPI_MODEL_VERSION
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import ANSWER_PAT
 from danswer.direct_qa.qa_prompts import generic_prompt_processor
+from danswer.direct_qa.qa_prompts import openai_chat_completion_processor
 from danswer.direct_qa.qa_prompts import QUOTE_PAT
 from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
 from danswer.utils.logging import setup_logger
@@ -177,3 +178,48 @@ class OpenAICompletionQA(QAModel):
 
         answer, quotes_dict = process_answer(model_output, context_docs)
         return answer, quotes_dict
+
+
+class OpenAIChatCompletionQA(QAModel):
+    def __init__(
+        self,
+        prompt_processor: Callable[
+            [str, list[str]], list[dict[str, str]]
+        ] = openai_chat_completion_processor,
+        model_version: str = OPENAPI_MODEL_VERSION,
+        max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
+    ) -> None:
+        self.prompt_processor = prompt_processor
+        self.model_version = model_version
+        self.max_output_tokens = max_output_tokens
+
+    @log_function_time()
+    def answer_question(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
+        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
+        messages = self.prompt_processor(query, top_contents)
+        logger.debug(messages)
+
+        try:
+            response = openai.ChatCompletion.create(
+                messages=messages,
+                temperature=0,
+                top_p=1,
+                frequency_penalty=0,
+                presence_penalty=0,
+                model=self.model_version,
+                max_tokens=self.max_output_tokens,
+            )
+            model_output = response["choices"][0]["message"]["content"].strip()
+            logger.info(
+                "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
+            )
+        except Exception as e:
+            logger.exception(e)
+            model_output = "Model Failure"
+
+        logger.debug(model_output)
+
+        answer, quotes_dict = process_answer(model_output, context_docs)
+        return answer, quotes_dict
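Note: OpenAIChatCompletionQA mirrors OpenAICompletionQA but calls openai.ChatCompletion.create (available on the pinned openai 0.27 line) instead of the plain completion endpoint. A rough usage sketch; the API-key setup and the InferenceChunk instances come from elsewhere in the codebase and are assumed here:

    # Assumes openai.api_key is already configured and `chunks` is a
    # list[InferenceChunk] produced by the retrieval pipeline.
    qa = OpenAIChatCompletionQA(model_version="gpt-4")  # "gpt-4" per this commit's title
    answer, quotes = qa.answer_question("How much do the hotdogs cost?", chunks)
    # answer may be None if no answer could be parsed from the model output
    print(answer)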
@@ -4,7 +4,7 @@ filelock==3.12.0
 google-api-python-client==2.86.0
 google-auth-httplib2==0.1.0
 google-auth-oauthlib==1.0.0
-openai==0.27.2
+openai==0.27.6
 playwright==1.32.1
 pydantic==1.10.7
 PyPDF2==3.0.1
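Note: the bump from 0.27.2 to 0.27.6 stays on the 0.27 API line, which exposes the openai.ChatCompletion endpoint used above. A quick standard-library check that the installed version matches the pin:

    from importlib.metadata import version

    assert version("openai") == "0.27.6", version("openai")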