Wrap errors in an object instead of plain dict (#623)

Yuhong Sun 2023-10-24 16:07:45 -07:00 committed by GitHub
parent 890eb7901e
commit 17bd68be4c
4 changed files with 76 additions and 76 deletions
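
In short: the streaming generators now yield typed Pydantic objects (DanswerAnswerPiece, RetrievalDocs, StreamingError) instead of bare strings, lists, or {"error": ...} dicts. A minimal standalone sketch of the pattern, with the model names taken from this diff and the token stream replaced by a stand-in:

from collections.abc import Iterator

from pydantic import BaseModel


class DanswerAnswerPiece(BaseModel):
    answer_piece: str | None


class StreamingError(BaseModel):
    error: str


def stream_answer() -> Iterator[DanswerAnswerPiece | StreamingError]:
    try:
        for token in ["Hello", ", ", "world"]:  # stand-in for the real LLM token stream
            yield DanswerAnswerPiece(answer_piece=token)
    except Exception as e:
        # Previously a failure string or plain dict was yielded; now the error is
        # wrapped in a typed object that callers can isinstance-check or serialize.
        yield StreamingError(error=str(e))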

View File

@@ -31,6 +31,7 @@ from danswer.db.models import Persona
from danswer.db.models import User
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.llm.build import get_default_llm
from danswer.llm.llm import LLM
@@ -187,41 +188,6 @@ def _drop_messages_history_overflow(
return prompt
def llm_contextless_chat_answer(
messages: list[ChatMessage],
system_text: str | None = None,
tokenizer: Callable | None = None,
) -> Iterator[str]:
try:
prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages]
if system_text:
tokenizer = tokenizer or get_default_llm_tokenizer()
system_tokens = len(tokenizer(system_text))
system_msg = SystemMessage(content=system_text)
message_tokens = [msg.token_count for msg in messages] + [system_tokens]
else:
message_tokens = [msg.token_count for msg in messages]
last_msg_ind = _find_last_index(message_tokens)
remaining_user_msgs = prompt_msgs[last_msg_ind:]
if not remaining_user_msgs:
raise ValueError("Last user message is too long!")
if system_text:
all_msgs = [system_msg] + remaining_user_msgs
else:
all_msgs = remaining_user_msgs
return get_default_llm().stream(all_msgs)
except Exception as e:
logger.error(f"LLM failed to produce valid chat message, error: {e}")
return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator
def extract_citations_from_stream(
tokens: Iterator[str], links: list[str | None]
) -> Iterator[str]:
@@ -277,6 +243,42 @@ def extract_citations_from_stream(
yield curr_segment
def llm_contextless_chat_answer(
messages: list[ChatMessage],
system_text: str | None = None,
tokenizer: Callable | None = None,
) -> Iterator[DanswerAnswerPiece | StreamingError]:
try:
prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages]
if system_text:
tokenizer = tokenizer or get_default_llm_tokenizer()
system_tokens = len(tokenizer(system_text))
system_msg = SystemMessage(content=system_text)
message_tokens = [msg.token_count for msg in messages] + [system_tokens]
else:
message_tokens = [msg.token_count for msg in messages]
last_msg_ind = _find_last_index(message_tokens)
remaining_user_msgs = prompt_msgs[last_msg_ind:]
if not remaining_user_msgs:
raise ValueError("Last user message is too long!")
if system_text:
all_msgs = [system_msg] + remaining_user_msgs
else:
all_msgs = remaining_user_msgs
for token in get_default_llm().stream(all_msgs):
yield DanswerAnswerPiece(answer_piece=token)
except Exception as e:
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
yield StreamingError(error=str(e))
def llm_contextual_chat_answer(
messages: list[ChatMessage],
persona: Persona,
@@ -284,7 +286,7 @@ def llm_contextual_chat_answer(
tokenizer: Callable,
db_session: Session,
run_search_system_text: str = REQUIRE_DANSWER_SYSTEM_MSG,
) -> Iterator[str | list[InferenceChunk]]:
) -> Iterator[DanswerAnswerPiece | RetrievalDocs | StreamingError]:
last_message = messages[-1]
final_query_text = last_message.message
previous_messages = messages[:-1]
@@ -337,7 +339,8 @@ def llm_contextual_chat_answer(
llm=llm,
filters=final_filters,
)
yield retrieved_chunks
yield RetrievalDocs(top_documents=chunks_to_search_docs(retrieved_chunks))
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
@@ -373,11 +376,12 @@ def llm_contextual_chat_answer(
for chunk in retrieved_chunks
]
yield from extract_citations_from_stream(tokens, links)
for segment in extract_citations_from_stream(tokens, links):
yield DanswerAnswerPiece(answer_piece=segment)
except Exception as e:
logger.error(f"LLM failed to produce valid chat message, error: {e}")
yield LLM_CHAT_FAILURE_MSG # needs to be an Iterator
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
yield StreamingError(error=str(e))
def llm_tools_enabled_chat_answer(
@@ -386,7 +390,7 @@ def llm_tools_enabled_chat_answer(
user: User | None,
tokenizer: Callable,
db_session: Session,
) -> Iterator[str | list[InferenceChunk]]:
) -> Iterator[DanswerAnswerPiece | RetrievalDocs | StreamingError]:
retrieval_enabled = persona.retrieval_enabled
system_text = build_system_text_from_persona(persona)
hint_text = persona.hint_text
@@ -441,7 +445,7 @@ def llm_tools_enabled_chat_answer(
for result in _parse_embedded_json_streamed_response(tokens):
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
yield result.answer_piece
yield result
final_answer_streamed = True
if isinstance(result, DanswerChatModelOut):
@@ -474,7 +478,7 @@ def llm_tools_enabled_chat_answer(
llm=llm,
filters=final_filters,
)
yield retrieved_chunks
yield RetrievalDocs(top_documents=chunks_to_search_docs(retrieved_chunks))
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
else:
@@ -512,23 +516,14 @@ def llm_tools_enabled_chat_answer(
for result in _parse_embedded_json_streamed_response(tokens):
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
yield result.answer_piece
yield result
final_answer_streamed = True
if final_answer_streamed is False:
raise RuntimeError("LLM did not produce a Final Answer after tool call")
except Exception as e:
logger.error(f"LLM failed to produce valid chat message, error: {e}")
yield LLM_CHAT_FAILURE_MSG
def wrap_chat_package_in_model(
package: str | list[InferenceChunk],
) -> DanswerAnswerPiece | RetrievalDocs:
if isinstance(package, str):
return DanswerAnswerPiece(answer_piece=package)
elif isinstance(package, list):
return RetrievalDocs(top_documents=chunks_to_search_docs(package))
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
yield StreamingError(error=str(e))
def llm_chat_answer(
@@ -537,7 +532,7 @@ def llm_chat_answer(
tokenizer: Callable,
user: User | None,
db_session: Session,
) -> Iterator[DanswerAnswerPiece | RetrievalDocs]:
) -> Iterator[DanswerAnswerPiece | RetrievalDocs | StreamingError]:
# Common error cases to keep in mind:
# - User asks question about something long ago, due to context limit, the message is dropped
# - Tool use gives wrong/irrelevant results, model gets confused by the noise
@@ -547,36 +542,33 @@ def llm_chat_answer(
# No setting/persona available therefore no retrieval and no additional tools
if persona is None:
for token in llm_contextless_chat_answer(messages):
yield DanswerAnswerPiece(answer_piece=token)
return llm_contextless_chat_answer(messages)
# Persona is configured but with retrieval off and no tools
# therefore cannot retrieve any context so contextless
elif persona.retrieval_enabled is False and not persona.tools:
for token in llm_contextless_chat_answer(
return llm_contextless_chat_answer(
messages, system_text=persona.system_text, tokenizer=tokenizer
):
yield DanswerAnswerPiece(answer_piece=token)
)
# No additional tools outside of Danswer retrieval, can use a more basic prompt
# Doesn't require tool calling output format (all LLM outputs are therefore valid)
elif persona.retrieval_enabled and not persona.tools and not FORCE_TOOL_PROMPT:
for package in llm_contextual_chat_answer(
return llm_contextual_chat_answer(
messages=messages,
persona=persona,
tokenizer=tokenizer,
user=user,
db_session=db_session,
):
yield wrap_chat_package_in_model(package)
)
# Use most flexible/complex prompt format
else:
for package in llm_tools_enabled_chat_answer(
messages=messages,
persona=persona,
tokenizer=tokenizer,
user=user,
db_session=db_session,
):
yield wrap_chat_package_in_model(package)
# Use most flexible/complex prompt format that allows arbitrary tool calls
# that are configured in the persona file
# WARNING: this flow does not work well with weaker LLMs (anything below GPT-4)
return llm_tools_enabled_chat_answer(
messages=messages,
persona=persona,
tokenizer=tokenizer,
user=user,
db_session=db_session,
)
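
llm_chat_answer now returns the underlying generators directly, so callers receive typed packages throughout. A hypothetical consumer (not part of this commit) illustrating how the union type is handled; the two imports mirror the ones added above:

from collections.abc import Iterator

from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError


def consume_chat_packets(packets: Iterator) -> None:
    # Branch on the package type instead of guessing from str / list shapes.
    for packet in packets:
        if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
            print(packet.answer_piece, end="")
        elif isinstance(packet, StreamingError):
            print(f"\n[stream failed: {packet.error}]")
        else:
            # e.g. RetrievalDocs carrying the retrieved documents
            print(f"\n[{type(packet).__name__} received]")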

View File

@@ -8,6 +8,10 @@ from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.models import LLMMetricsContainer
class StreamingError(BaseModel):
error: str
class DanswerAnswer(BaseModel):
answer: str | None
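
Because StreamingError is a Pydantic model, the server endpoints below can serialize it the same way they serialize other packets. A small sketch of that round trip; the behaviour of get_json_line is an assumption inferred from its usage in this diff (one JSON object per line):

import json

from pydantic import BaseModel


class StreamingError(BaseModel):
    error: str


def get_json_line(json_dict: dict) -> str:
    # Assumed behaviour of the helper used below: one JSON object per line.
    return json.dumps(json_dict) + "\n"


error = StreamingError(error="LLM request timed out")
print(get_json_line(error.dict()), end="")  # {"error": "LLM request timed out"}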

View File

@@ -4,6 +4,7 @@ from collections.abc import Iterator
from danswer.configs.constants import CODE_BLOCK_PAT
from danswer.configs.constants import GENERAL_SEP_PAT
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
from danswer.llm.build import get_default_llm
from danswer.server.models import QueryValidationResponse
@@ -135,6 +136,7 @@ def stream_query_answerability(user_query: str) -> Iterator[str]:
)
except Exception as e:
# exception is logged in the answer_question method, no need to re-log
yield get_json_line({"error": str(e)})
error = StreamingError(error=str(e))
yield get_json_line(error.dict())
logger.exception("Failed to validate Query")
return
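
The wire format is unchanged here (still a single {"error": "..."} JSON line); it is simply produced from the typed model now. A hypothetical reader of the newline-delimited JSON stream (not part of the codebase):

import json
from collections.abc import Iterator


def read_stream(lines: Iterator[str]) -> None:
    # Each non-empty line is one JSON object, produced via get_json_line on the server.
    for line in lines:
        if not line.strip():
            continue
        packet = json.loads(line)
        if "error" in packet:
            # Matches the StreamingError shape introduced in this commit.
            print(f"stream failed: {packet['error']}")
            break
        print(packet)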

View File

@@ -25,6 +25,7 @@ from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.search.access_filters import build_access_filters_for_user
@@ -366,7 +367,8 @@ def stream_direct_qa(
qa_model = get_default_qa_model()
except (UnknownModelError, OpenAIKeyMissing) as e:
logger.exception("Unable to get QA model")
yield get_json_line({"error": str(e)})
error = StreamingError(error=str(e))
yield get_json_line(error.dict())
return
# remove chunks marked as not applicable for QA (e.g. Google Drive file