DAN-60 Add streaming for chat model (#46)

Yuhong Sun 2023-05-13 23:05:06 -07:00 committed by GitHub
parent 17bc0f89ff
commit dc4fc02ba5
3 changed files with 68 additions and 25 deletions

@@ -20,6 +20,6 @@ BATCH_SIZE_ENCODE_CHUNKS = 8
 # QA Model API Configs
 # https://platform.openai.com/docs/models/model-endpoint-compatibility
-INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-completion")
-OPENAPI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "text-davinci-003")
+INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-chat-completion")
+OPENAPI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "gpt-4")
 OPENAI_MAX_OUTPUT_TOKENS = 512
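
With these new defaults, a deployment that sets neither environment variable now goes through OpenAI's chat-completion path with gpt-4. A minimal sketch of overriding them, assuming the defaults are resolved at import time via os.environ.get as in the diff (the values shown are illustrative):

import os

# Must be set before the configs module is imported, since the
# defaults are read with os.environ.get at module load time.
os.environ["INTERNAL_MODEL"] = "openai-completion"  # revert to the old completion path
os.environ["OPENAI_MODEL_VERSION"] = "gpt-3.5-turbo"  # any chat-capable model name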

@@ -6,17 +6,12 @@ ANSWER_PAT = "Answer:"
 UNCERTAINTY_PAT = "?"
 QUOTE_PAT = "Quote:"
 SYSTEM_ROLE = "You are a Question Answering system that answers queries based on provided documents. "
-BASE_PROMPT = (
-    f"Answer the query based on provided documents and quote relevant sections. "
-    f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents. "
-    f"The quotes must be EXACT substrings from the documents.\n"
-)
-UNABLE_TO_FIND_JSON_MSG = (
-    "If the query cannot be answered based on the documents, respond with {}. "
-)
 SAMPLE_QUESTION = "Where is the Eiffel Tower?"
@@ -106,12 +101,21 @@ def freeform_processor(question: str, documents: list[str]) -> str:
 def json_chat_processor(question: str, documents: list[str]) -> list[dict[str, str]]:
-    role_msg = (
-        SYSTEM_ROLE
-        + 'Start by reading the following documents and responding with "Acknowledged"'
+    intro_msg = (
+        "You are a Question Answering assistant that answers queries based on provided documents.\n"
+        'Start by reading the following documents and responding with "Acknowledged".'
     )
-    messages = [{"role": "system", "content": role_msg}]
+    task_msg = (
+        "Now answer the next user query based on documents above and quote relevant sections.\n"
+        "Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
+        "If the query cannot be answered based on the documents, do not provide an answer.\n"
+        "All quotes MUST be EXACT substrings from provided documents.\n"
+        "Your responses should be informative and concise.\n"
+        "You MUST prioritize information from provided documents over internal knowledge.\n"
+        f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
+    )
+    messages = [{"role": "system", "content": intro_msg}]
     for document in documents:
         messages.extend(
@@ -123,18 +127,10 @@ def json_chat_processor(question: str, documents: list[str]) -> list[dict[str, str]]:
                 {"role": "assistant", "content": "Acknowledged"},
             ]
         )
-    sample_msg = (
-        f"Now answer the user query based on documents above and quote relevant sections. "
-        f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents.\n"
-        f"Sample response: {json.dumps(SAMPLE_JSON_RESPONSE)}"
-    )
-    messages.append({"role": "system", "content": sample_msg})
+    messages.append({"role": "system", "content": task_msg})
     messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})
     # Note that the below will be dropped in reflexion if used
     messages.append({"role": "assistant", "content": "Answer Json:\n"})
     return messages
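
Taken together, the two hunks above restructure how json_chat_processor front-loads its instructions. For a single document, the returned message list is shaped roughly as follows (a sketch assembled from the diff; intro_msg, task_msg, and QUESTION_PAT are the names defined above, and the sample JSON is elided):

# Sketch: shape of the list returned by json_chat_processor(question, [document])
[
    {"role": "system", "content": intro_msg},           # QA assistant role + "Acknowledged" handshake
    {"role": "user", "content": document},              # repeated once per provided document
    {"role": "assistant", "content": "Acknowledged"},   # repeated once per provided document
    {"role": "system", "content": task_msg},            # answer/quote rules + sample JSON response
    {"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"},
    {"role": "assistant", "content": "Answer Json:\n"}, # dropped in reflexion if used
]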

@@ -324,10 +324,57 @@ class OpenAIChatCompletionQA(QAModel):
         answer, quotes_dict = process_answer(model_output, context_docs)
         return answer, quotes_dict

     @log_function_time()
     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Any:
-        raise NotImplementedError(
-            "Danswer with chat completion does not support streaming yet"
-        )
+    ) -> Generator[dict[str, Any] | None, None, None]:
+        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
+        messages = self.prompt_processor(query, top_contents)
+        logger.debug(messages)
+
+        try:
+            response = openai.ChatCompletion.create(
+                messages=messages,
+                temperature=0,
+                top_p=1,
+                frequency_penalty=0,
+                presence_penalty=0,
+                model=self.model_version,
+                max_tokens=self.max_output_tokens,
+                stream=True,
+            )
+
+            model_output = ""
+            found_answer_start = False
+            found_answer_end = False
+            for event in response:
+                event_dict = event["choices"][0]["delta"]
+                if (
+                    "content" not in event_dict
+                ):  # could be a role message or empty termination
+                    continue
+                event_text = event_dict["content"]
+                model_previous = model_output
+                model_output += event_text
+
+                if not found_answer_start and '{"answer":"' in model_output.replace(
+                    " ", ""
+                ).replace("\n", ""):
+                    found_answer_start = True
+                    continue
+
+                if found_answer_start and not found_answer_end:
+                    if stream_answer_end(model_previous, event_text):
+                        found_answer_end = True
+                        continue
+                    yield {"answer_data": event_text}
+
+        except Exception as e:
+            logger.exception(e)
+            model_output = "Model Failure"
+
+        logger.debug(model_output)
+
+        answer, quotes_dict = process_answer(model_output, context_docs)
+
+        logger.info(answer)
+        yield quotes_dict
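
For reference, a minimal consumer sketch of the new generator, assuming it behaves exactly as the diff reads: it yields dicts of the form {"answer_data": token} while the answer field streams, then one final quotes dict from process_answer, which may be None (the construction of qa_model and ranked_chunks is hypothetical):

for packet in qa_model.answer_question_stream(query, ranked_chunks):
    if packet is None:
        # Quote processing can yield None on failure per the return annotation.
        continue
    if "answer_data" in packet:
        # A token from inside the streamed answer field; display incrementally.
        print(packet["answer_data"], end="", flush=True)
    else:
        # Final yield: the quotes dict extracted by process_answer.
        quotes = packet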