DAN-60 Add streaming for chat model (#46)

This commit is contained in:
Yuhong Sun 2023-05-13 23:05:06 -07:00 committed by GitHub
parent 17bc0f89ff
commit dc4fc02ba5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 68 additions and 25 deletions

View File

# QA Model API Configs
# https://platform.openai.com/docs/models/model-endpoint-compatibility
# Selects the internal QA pipeline; "openai-chat-completion" routes queries
# through the chat-style (messages-based) OpenAI endpoint.
INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-chat-completion")
# Model name passed to the OpenAI API (e.g. "gpt-4").
# NOTE(review): env var name says OPENAI but constant says OPENAPI — kept for
# backward compatibility; confirm before renaming.
OPENAPI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "gpt-4")
# Upper bound on tokens generated per completion request.
OPENAI_MAX_OUTPUT_TOKENS = 512

View File

# Sentinel patterns used when building prompts and parsing model output.
ANSWER_PAT = "Answer:"
UNCERTAINTY_PAT = "?"
QUOTE_PAT = "Quote:"

# Shared instruction for completion-style prompts. Quotes must be verbatim
# substrings so they can be located back in the source documents.
# (f-prefixes removed: the strings contain no placeholders.)
BASE_PROMPT = (
    "Answer the query based on provided documents and quote relevant sections. "
    "Respond with a json containing a concise answer and up to three most relevant quotes from the documents. "
    "The quotes must be EXACT substrings from the documents.\n"
)

SAMPLE_QUESTION = "Where is the Eiffel Tower?"
@ -106,12 +101,21 @@ def freeform_processor(question: str, documents: list[str]) -> str:
def json_chat_processor(question: str, documents: list[str]) -> list[dict[str, str]]: def json_chat_processor(question: str, documents: list[str]) -> list[dict[str, str]]:
role_msg = ( intro_msg = (
SYSTEM_ROLE "You are a Question Answering assistant that answers queries based on provided documents.\n"
+ 'Start by reading the following documents and responding with "Acknowledged"' 'Start by reading the following documents and responding with "Acknowledged".'
) )
messages = [{"role": "system", "content": role_msg}] task_msg = (
"Now answer the next user query based on documents above and quote relevant sections.\n"
"Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
"If the query cannot be answered based on the documents, do not provide an answer.\n"
"All quotes MUST be EXACT substrings from provided documents.\n"
"Your responses should be informative and concise.\n"
"You MUST prioritize information from provided documents over internal knowledge.\n"
f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
)
messages = [{"role": "system", "content": intro_msg}]
for document in documents: for document in documents:
messages.extend( messages.extend(
@ -123,18 +127,10 @@ def json_chat_processor(question: str, documents: list[str]) -> list[dict[str, s
{"role": "assistant", "content": "Acknowledged"}, {"role": "assistant", "content": "Acknowledged"},
] ]
) )
sample_msg = ( messages.append({"role": "system", "content": task_msg})
f"Now answer the user query based on documents above and quote relevant sections. "
f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents.\n"
f"Sample response: {json.dumps(SAMPLE_JSON_RESPONSE)}"
)
messages.append({"role": "system", "content": sample_msg})
messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"}) messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})
# Note that the below will be dropped in reflexion if used
messages.append({"role": "assistant", "content": "Answer Json:\n"})
return messages return messages

View File

@ -324,10 +324,57 @@ class OpenAIChatCompletionQA(QAModel):
answer, quotes_dict = process_answer(model_output, context_docs) answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict return answer, quotes_dict
@log_function_time()
def answer_question_stream(
    self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
    """Stream an answer to `query` grounded in `context_docs`.

    Yields `{"answer_data": token_text}` dicts for each streamed token of the
    answer portion of the model's JSON response, then yields the quotes dict
    produced by `process_answer` as the final item.

    On any API/parsing failure the error is logged and the stream still ends
    with a final quotes payload (derived from the "Model Failure" sentinel).
    """
    top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
    messages = self.prompt_processor(query, top_contents)
    logger.debug(messages)

    try:
        response = openai.ChatCompletion.create(
            messages=messages,
            temperature=0,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            model=self.model_version,
            max_tokens=self.max_output_tokens,
            stream=True,
        )

        model_output = ""
        found_answer_start = False
        found_answer_end = False
        for event in response:
            event_dict = event["choices"][0]["delta"]
            # Deltas without "content" are role headers or the empty
            # termination chunk -- nothing to accumulate.
            if "content" not in event_dict:
                continue
            event_text = event_dict["content"]
            model_previous = model_output
            model_output += event_text
            # Whitespace-insensitive scan for the opening of the JSON
            # "answer" field; tokens before it are structural JSON and are
            # not streamed to the caller.
            if not found_answer_start and '{"answer":"' in model_output.replace(
                " ", ""
            ).replace("\n", ""):
                found_answer_start = True
                continue
            if found_answer_start and not found_answer_end:
                if stream_answer_end(model_previous, event_text):
                    found_answer_end = True
                    continue
                yield {"answer_data": event_text}
    except Exception as e:
        # Best-effort: log and fall through so the caller still receives a
        # final quotes payload instead of a broken stream.
        logger.exception(e)
        model_output = "Model Failure"

    logger.debug(model_output)
    answer, quotes_dict = process_answer(model_output, context_docs)
    logger.info(answer)
    yield quotes_dict