mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-05 12:39:33 +02:00)

DAN-60 Add streaming for chat model (#46)

parent 17bc0f89ff
commit dc4fc02ba5
@@ -20,6 +20,6 @@ BATCH_SIZE_ENCODE_CHUNKS = 8

 # QA Model API Configs
 # https://platform.openai.com/docs/models/model-endpoint-compatibility
-INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-completion")
-OPENAPI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "text-davinci-003")
+INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-chat-completion")
+OPENAPI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "gpt-4")
 OPENAI_MAX_OUTPUT_TOKENS = 512
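The new defaults point Danswer at the chat-completion model family with gpt-4. A minimal sketch of reading and validating these settings at startup; the SUPPORTED_INTERNAL_MODELS set and the check are illustrative additions, not code from this commit:

    import os

    # Defaults copied from the diff above; the validation below is
    # illustrative only and not part of this commit.
    INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-chat-completion")
    OPENAPI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "gpt-4")

    SUPPORTED_INTERNAL_MODELS = {"openai-completion", "openai-chat-completion"}
    if INTERNAL_MODEL_VERSION not in SUPPORTED_INTERNAL_MODELS:
        raise ValueError(f"Unsupported INTERNAL_MODEL: {INTERNAL_MODEL_VERSION}")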
@@ -6,17 +6,12 @@ ANSWER_PAT = "Answer:"
 UNCERTAINTY_PAT = "?"
 QUOTE_PAT = "Quote:"

-SYSTEM_ROLE = "You are a Question Answering system that answers queries based on provided documents. "
-
 BASE_PROMPT = (
     f"Answer the query based on provided documents and quote relevant sections. "
     f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents. "
     f"The quotes must be EXACT substrings from the documents.\n"
 )

-UNABLE_TO_FIND_JSON_MSG = (
-    "If the query cannot be answered based on the documents, respond with {}. "
-)
-
 SAMPLE_QUESTION = "Where is the Eiffel Tower?"
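BASE_PROMPT requires quotes to be EXACT substrings of the documents, presumably so process_answer can match them back to the source chunks. A quick way to sanity-check that property (illustrative helper, not part of this commit):

    # Illustrative check of the "EXACT substrings" rule from BASE_PROMPT.
    def quotes_are_exact(quotes: list[str], documents: list[str]) -> bool:
        return all(any(quote in doc for doc in documents) for quote in quotes)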
@@ -106,12 +101,21 @@ def freeform_processor(question: str, documents: list[str]) -> str:


 def json_chat_processor(question: str, documents: list[str]) -> list[dict[str, str]]:
-    role_msg = (
-        SYSTEM_ROLE
-        + 'Start by reading the following documents and responding with "Acknowledged"'
+    intro_msg = (
+        "You are a Question Answering assistant that answers queries based on provided documents.\n"
+        'Start by reading the following documents and responding with "Acknowledged".'
     )

-    messages = [{"role": "system", "content": role_msg}]
+    task_msg = (
+        "Now answer the next user query based on documents above and quote relevant sections.\n"
+        "Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
+        "If the query cannot be answered based on the documents, do not provide an answer.\n"
+        "All quotes MUST be EXACT substrings from provided documents.\n"
+        "Your responses should be informative and concise.\n"
+        "You MUST prioritize information from provided documents over internal knowledge.\n"
+        f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
+    )
+    messages = [{"role": "system", "content": intro_msg}]

     for document in documents:
         messages.extend(
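task_msg embeds json.dumps(SAMPLE_JSON_RESPONSE) to show the model a concrete target format. The constant itself is outside this diff; judging from the prompt wording, its shape is presumably along these lines:

    # Assumed shape only -- the real SAMPLE_JSON_RESPONSE is defined elsewhere
    # in this file and may differ in wording.
    SAMPLE_JSON_RESPONSE = {
        "answer": "The Eiffel Tower is in Paris, France.",
        "quotes": [
            "The Eiffel Tower is an iron lattice tower on the Champ de Mars in Paris.",
        ],
    }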
@@ -123,18 +127,10 @@ def json_chat_processor(question: str, documents: list[str]) -> list[dict[str, str]]:
                 {"role": "assistant", "content": "Acknowledged"},
             ]
         )
-    sample_msg = (
-        f"Now answer the user query based on documents above and quote relevant sections. "
-        f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents.\n"
-        f"Sample response: {json.dumps(SAMPLE_JSON_RESPONSE)}"
-    )
-    messages.append({"role": "system", "content": sample_msg})
+    messages.append({"role": "system", "content": task_msg})

     messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})

-    # Note that the below will be dropped in reflexion if used
-    messages.append({"role": "assistant", "content": "Answer Json:\n"})
-
     return messages
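With both hunks applied, json_chat_processor emits one system intro, a user/assistant exchange per document (the user turn presumably carrying the document text, answered by "Acknowledged"), then the system task message and the user question. A sketch of the result for a single document, with angle-bracket placeholders standing in for the real strings:

    # Placeholders stand in for intro_msg, task_msg, the document text, and
    # f"{QUESTION_PAT}\n{question}\n" from the diff above.
    messages = [
        {"role": "system", "content": "<intro_msg>"},
        {"role": "user", "content": "<document text>"},
        {"role": "assistant", "content": "Acknowledged"},
        {"role": "system", "content": "<task_msg>"},
        {"role": "user", "content": "<QUESTION_PAT>\n<question>\n"},
    ]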
@@ -324,10 +324,57 @@ class OpenAIChatCompletionQA(QAModel):
         answer, quotes_dict = process_answer(model_output, context_docs)
         return answer, quotes_dict

-    @log_function_time()
     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Any:
-        raise NotImplementedError(
-            "Danswer with chat completion does not support streaming yet"
-        )
+    ) -> Generator[dict[str, Any] | None, None, None]:
+        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
+        messages = self.prompt_processor(query, top_contents)
+        logger.debug(messages)
+
+        try:
+            response = openai.ChatCompletion.create(
+                messages=messages,
+                temperature=0,
+                top_p=1,
+                frequency_penalty=0,
+                presence_penalty=0,
+                model=self.model_version,
+                max_tokens=self.max_output_tokens,
+                stream=True,
+            )
+
+            model_output = ""
+            found_answer_start = False
+            found_answer_end = False
+            for event in response:
+                event_dict = event["choices"][0]["delta"]
+                if (
+                    "content" not in event_dict
+                ):  # could be a role message or empty termination
+                    continue
+                event_text = event_dict["content"]
+                model_previous = model_output
+                model_output += event_text
+
+                if not found_answer_start and '{"answer":"' in model_output.replace(
+                    " ", ""
+                ).replace("\n", ""):
+                    found_answer_start = True
+                    continue
+
+                if found_answer_start and not found_answer_end:
+                    if stream_answer_end(model_previous, event_text):
+                        found_answer_end = True
+                        continue
+                    yield {"answer_data": event_text}
+
+        except Exception as e:
+            logger.exception(e)
+            model_output = "Model Failure"
+
+        logger.debug(model_output)
+
+        answer, quotes_dict = process_answer(model_output, context_docs)
+        logger.info(answer)
+
+        yield quotes_dict