Mirror of https://github.com/danswer-ai/danswer.git
Wrap errors in an object instead of plain dict (#623)
parent 890eb7901e · commit 17bd68be4c
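Across the chat and direct-QA streaming paths, failures were previously signaled with a bare string constant (LLM_CHAT_FAILURE_MSG) or an ad-hoc {"error": ...} dict. This commit wraps them in a small Pydantic model, StreamingError, so every packet on a response stream is a typed object. A minimal sketch of the before/after pattern, assuming get_json_line serializes a dict as one JSON line (its definition is outside this diff):

import json

from pydantic import BaseModel


class StreamingError(BaseModel):  # matches the model added in this commit
    error: str


def get_json_line(json_dict: dict) -> str:
    # Assumed behavior of the existing helper; its body is not shown in this diff.
    return json.dumps(json_dict) + "\n"


# Before: a plain dict with no shared schema
before = get_json_line({"error": "model unavailable"})

# After: the error is validated by the model first, then serialized
after = get_json_line(StreamingError(error="model unavailable").dict())

assert before == after  # identical wire format, but typed at the source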
@@ -31,6 +31,7 @@ from danswer.db.models import Persona
 from danswer.db.models import User
 from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerChatModelOut
+from danswer.direct_qa.interfaces import StreamingError
 from danswer.direct_qa.qa_utils import get_usable_chunks
 from danswer.llm.build import get_default_llm
 from danswer.llm.llm import LLM
@@ -187,41 +188,6 @@ def _drop_messages_history_overflow(
     return prompt


-def llm_contextless_chat_answer(
-    messages: list[ChatMessage],
-    system_text: str | None = None,
-    tokenizer: Callable | None = None,
-) -> Iterator[str]:
-    try:
-        prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages]
-
-        if system_text:
-            tokenizer = tokenizer or get_default_llm_tokenizer()
-            system_tokens = len(tokenizer(system_text))
-            system_msg = SystemMessage(content=system_text)
-
-            message_tokens = [msg.token_count for msg in messages] + [system_tokens]
-        else:
-            message_tokens = [msg.token_count for msg in messages]
-
-        last_msg_ind = _find_last_index(message_tokens)
-
-        remaining_user_msgs = prompt_msgs[last_msg_ind:]
-        if not remaining_user_msgs:
-            raise ValueError("Last user message is too long!")
-
-        if system_text:
-            all_msgs = [system_msg] + remaining_user_msgs
-        else:
-            all_msgs = remaining_user_msgs
-
-        return get_default_llm().stream(all_msgs)
-
-    except Exception as e:
-        logger.error(f"LLM failed to produce valid chat message, error: {e}")
-        return (msg for msg in [LLM_CHAT_FAILURE_MSG])  # needs to be an Iterator
-
-
 def extract_citations_from_stream(
     tokens: Iterator[str], links: list[str | None]
 ) -> Iterator[str]:
@@ -277,6 +243,42 @@ def extract_citations_from_stream(
         yield curr_segment


+def llm_contextless_chat_answer(
+    messages: list[ChatMessage],
+    system_text: str | None = None,
+    tokenizer: Callable | None = None,
+) -> Iterator[DanswerAnswerPiece | StreamingError]:
+    try:
+        prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages]
+
+        if system_text:
+            tokenizer = tokenizer or get_default_llm_tokenizer()
+            system_tokens = len(tokenizer(system_text))
+            system_msg = SystemMessage(content=system_text)
+
+            message_tokens = [msg.token_count for msg in messages] + [system_tokens]
+        else:
+            message_tokens = [msg.token_count for msg in messages]
+
+        last_msg_ind = _find_last_index(message_tokens)
+
+        remaining_user_msgs = prompt_msgs[last_msg_ind:]
+        if not remaining_user_msgs:
+            raise ValueError("Last user message is too long!")
+
+        if system_text:
+            all_msgs = [system_msg] + remaining_user_msgs
+        else:
+            all_msgs = remaining_user_msgs
+
+        for token in get_default_llm().stream(all_msgs):
+            yield DanswerAnswerPiece(answer_piece=token)
+
+    except Exception as e:
+        logger.exception(f"LLM failed to produce valid chat message, error: {e}")
+        yield StreamingError(error=str(e))
+
+
 def llm_contextual_chat_answer(
     messages: list[ChatMessage],
     persona: Persona,
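Since llm_contextless_chat_answer now yields typed packets instead of raw token strings, callers can branch on the packet type. A hypothetical consumer sketch (consume_stream is illustrative, not part of this commit; it assumes this module's names — ChatMessage, DanswerAnswerPiece, StreamingError, Iterator — are in scope):

def consume_stream(messages: list[ChatMessage]) -> Iterator[str]:
    # Hypothetical caller: forward answer tokens, surface failures.
    for packet in llm_contextless_chat_answer(messages):
        if isinstance(packet, DanswerAnswerPiece):
            yield packet.answer_piece          # forward the raw token
        elif isinstance(packet, StreamingError):
            raise RuntimeError(packet.error)   # or report it to the client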
@@ -284,7 +286,7 @@ def llm_contextual_chat_answer(
     tokenizer: Callable,
     db_session: Session,
     run_search_system_text: str = REQUIRE_DANSWER_SYSTEM_MSG,
-) -> Iterator[str | list[InferenceChunk]]:
+) -> Iterator[DanswerAnswerPiece | RetrievalDocs | StreamingError]:
     last_message = messages[-1]
     final_query_text = last_message.message
     previous_messages = messages[:-1]
@@ -337,7 +339,8 @@ def llm_contextual_chat_answer(
         llm=llm,
         filters=final_filters,
     )
-    yield retrieved_chunks
+
+    yield RetrievalDocs(top_documents=chunks_to_search_docs(retrieved_chunks))

     tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
@@ -373,11 +376,12 @@ def llm_contextual_chat_answer(
             for chunk in retrieved_chunks
         ]

-        yield from extract_citations_from_stream(tokens, links)
+        for segment in extract_citations_from_stream(tokens, links):
+            yield DanswerAnswerPiece(answer_piece=segment)

     except Exception as e:
-        logger.error(f"LLM failed to produce valid chat message, error: {e}")
-        yield LLM_CHAT_FAILURE_MSG  # needs to be an Iterator
+        logger.exception(f"LLM failed to produce valid chat message, error: {e}")
+        yield StreamingError(error=str(e))


 def llm_tools_enabled_chat_answer(
@@ -386,7 +390,7 @@ def llm_tools_enabled_chat_answer(
     user: User | None,
     tokenizer: Callable,
     db_session: Session,
-) -> Iterator[str | list[InferenceChunk]]:
+) -> Iterator[DanswerAnswerPiece | RetrievalDocs | StreamingError]:
     retrieval_enabled = persona.retrieval_enabled
     system_text = build_system_text_from_persona(persona)
     hint_text = persona.hint_text
@@ -441,7 +445,7 @@ def llm_tools_enabled_chat_answer(

         for result in _parse_embedded_json_streamed_response(tokens):
             if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
-                yield result.answer_piece
+                yield result
                 final_answer_streamed = True

             if isinstance(result, DanswerChatModelOut):
@@ -474,7 +478,7 @@ def llm_tools_enabled_chat_answer(
                 llm=llm,
                 filters=final_filters,
             )
-            yield retrieved_chunks
+            yield RetrievalDocs(top_documents=chunks_to_search_docs(retrieved_chunks))

             tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
         else:
@@ -512,23 +516,14 @@ def llm_tools_enabled_chat_answer(

         for result in _parse_embedded_json_streamed_response(tokens):
             if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
-                yield result.answer_piece
+                yield result
                 final_answer_streamed = True

         if final_answer_streamed is False:
             raise RuntimeError("LLM did not to produce a Final Answer after tool call")
     except Exception as e:
-        logger.error(f"LLM failed to produce valid chat message, error: {e}")
-        yield LLM_CHAT_FAILURE_MSG
-
-
-def wrap_chat_package_in_model(
-    package: str | list[InferenceChunk],
-) -> DanswerAnswerPiece | RetrievalDocs:
-    if isinstance(package, str):
-        return DanswerAnswerPiece(answer_piece=package)
-    elif isinstance(package, list):
-        return RetrievalDocs(top_documents=chunks_to_search_docs(package))
+        logger.exception(f"LLM failed to produce valid chat message, error: {e}")
+        yield StreamingError(error=str(e))


 def llm_chat_answer(
@@ -537,7 +532,7 @@ def llm_chat_answer(
     tokenizer: Callable,
     user: User | None,
     db_session: Session,
-) -> Iterator[DanswerAnswerPiece | RetrievalDocs]:
+) -> Iterator[DanswerAnswerPiece | RetrievalDocs | StreamingError]:
     # Common error cases to keep in mind:
     # - User asks question about something long ago, due to context limit, the message is dropped
     # - Tool use gives wrong/irrelevant results, model gets confused by the noise
@@ -547,36 +542,33 @@ def llm_chat_answer(

     # No setting/persona available therefore no retrieval and no additional tools
     if persona is None:
-        for token in llm_contextless_chat_answer(messages):
-            yield DanswerAnswerPiece(answer_piece=token)
+        return llm_contextless_chat_answer(messages)

     # Persona is configured but with retrieval off and no tools
     # therefore cannot retrieve any context so contextless
     elif persona.retrieval_enabled is False and not persona.tools:
-        for token in llm_contextless_chat_answer(
+        return llm_contextless_chat_answer(
             messages, system_text=persona.system_text, tokenizer=tokenizer
-        ):
-            yield DanswerAnswerPiece(answer_piece=token)
+        )

     # No additional tools outside of Danswer retrieval, can use a more basic prompt
     # Doesn't require tool calling output format (all LLM outputs are therefore valid)
     elif persona.retrieval_enabled and not persona.tools and not FORCE_TOOL_PROMPT:
-        for package in llm_contextual_chat_answer(
+        return llm_contextual_chat_answer(
             messages=messages,
             persona=persona,
             tokenizer=tokenizer,
             user=user,
             db_session=db_session,
-        ):
-            yield wrap_chat_package_in_model(package)
+        )

-    # Use most flexible/complex prompt format
     else:
-        for package in llm_tools_enabled_chat_answer(
-            messages=messages,
-            persona=persona,
-            tokenizer=tokenizer,
-            user=user,
-            db_session=db_session,
-        ):
-            yield wrap_chat_package_in_model(package)
+        # Use most flexible/complex prompt format that allows arbitrary tool calls
+        # that are configured in the persona file
+        # WARNING: this flow does not work well with weaker LLMs (anything below GPT-4)
+        return llm_tools_enabled_chat_answer(
+            messages=messages,
+            persona=persona,
+            tokenizer=tokenizer,
+            user=user,
+            db_session=db_session,
+        )
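A note on the dispatch rewrite above: because all three sub-generators now share the same element types, llm_chat_answer can hand back the sub-iterator directly instead of re-wrapping each packet (which is also why wrap_chat_package_in_model could be deleted). With no yield left in its body it becomes a plain function returning an iterator, which callers consume exactly as before. A minimal illustration of that pattern, with toy names not taken from the codebase:

from collections.abc import Iterator


def produce() -> Iterator[int]:  # stands in for a sub-generator
    yield 1
    yield 2


def dispatch_old() -> Iterator[str]:
    # Old style: re-wrap every element, so the dispatcher must itself be a generator.
    for n in produce():
        yield f"wrapped:{n}"


def dispatch_new() -> Iterator[int]:
    # New style: element types already match, so just return the iterator.
    return produce()


assert list(dispatch_new()) == [1, 2]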
@@ -8,6 +8,10 @@ from danswer.chunking.models import InferenceChunk
 from danswer.direct_qa.models import LLMMetricsContainer


+class StreamingError(BaseModel):
+    error: str
+
+
 class DanswerAnswer(BaseModel):
     answer: str | None

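StreamingError is an ordinary Pydantic model, so it serializes like the other packet types. A quick sketch of the (pydantic v1-style) round trip this commit relies on:

from pydantic import BaseModel


class StreamingError(BaseModel):
    error: str


err = StreamingError(error="LLM request timed out")
assert err.dict() == {"error": "LLM request timed out"}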
@@ -4,6 +4,7 @@ from collections.abc import Iterator
 from danswer.configs.constants import CODE_BLOCK_PAT
 from danswer.configs.constants import GENERAL_SEP_PAT
 from danswer.direct_qa.interfaces import DanswerAnswerPiece
+from danswer.direct_qa.interfaces import StreamingError
 from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
 from danswer.llm.build import get_default_llm
 from danswer.server.models import QueryValidationResponse
@@ -135,6 +136,7 @@ def stream_query_answerability(user_query: str) -> Iterator[str]:
         )
     except Exception as e:
         # exception is logged in the answer_question method, no need to re-log
-        yield get_json_line({"error": str(e)})
+        error = StreamingError(error=str(e))
+        yield get_json_line(error.dict())
         logger.exception("Failed to validate Query")
         return
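The wire format is unchanged by this hunk: the handler still emits one JSON object per line, the dict just passes through the model first. A sketch of the emitted error line, assuming get_json_line is json.dumps plus a trailing newline (its body is outside this diff):

import json

from pydantic import BaseModel


class StreamingError(BaseModel):  # mirrors danswer.direct_qa.interfaces.StreamingError
    error: str


def get_json_line(json_dict: dict) -> str:
    # Assumed behavior of the existing helper; its body is not shown in this diff.
    return json.dumps(json_dict) + "\n"


line = get_json_line(StreamingError(error="Failed to validate Query").dict())
assert line == '{"error": "Failed to validate Query"}\n'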
@@ -25,6 +25,7 @@ from danswer.direct_qa.answer_question import answer_qa_query
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.interfaces import DanswerAnswerPiece
+from danswer.direct_qa.interfaces import StreamingError
 from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.direct_qa.qa_utils import get_usable_chunks
 from danswer.search.access_filters import build_access_filters_for_user
@@ -366,7 +367,8 @@ def stream_direct_qa(
         qa_model = get_default_qa_model()
     except (UnknownModelError, OpenAIKeyMissing) as e:
         logger.exception("Unable to get QA model")
-        yield get_json_line({"error": str(e)})
+        error = StreamingError(error=str(e))
+        yield get_json_line(error.dict())
         return

     # remove chunks marked as not applicable for QA (e.g. Google Drive file
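On the consuming side, a client can keep a single code path for every streaming endpoint touched here: parse each line as JSON and treat a packet carrying an "error" key as a StreamingError. A hedged client-side sketch (the parsing helper is illustrative, not part of the repo):

import json
from collections.abc import Iterable, Iterator


def read_stream(lines: Iterable[str]) -> Iterator[dict]:
    # Illustrative client helper: raise on error packets, pass the rest through.
    for line in lines:
        packet = json.loads(line)
        if "error" in packet:
            raise RuntimeError(f"stream failed: {packet['error']}")
        yield packet


# Example against the error line format shown above:
try:
    list(read_stream(['{"answer_piece": "Hi"}', '{"error": "boom"}']))
except RuntimeError as exc:
    print(exc)  # stream failed: boom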