danswer/backend/onyx/chat/llm_response_handler.py
2025-02-03 20:07:57 -08:00

59 lines
2.3 KiB
Python

from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from langchain_core.messages import BaseMessage
from onyx.chat.models import ResponsePart
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
class LLMResponseHandlerManager:
    """Post-process the stream of messages coming back from the LLM.

    In particular, we:

    1. handle the tool call requests
    2. handle citations
    3. pass through answers generated by the LLM
    4. stop yielding if the client disconnects (as reported by the
       ``is_cancelled`` callback)
    """

    def __init__(
        self,
        tool_handler: ToolResponseHandler | None,
        answer_handler: AnswerResponseHandler | None,
        is_cancelled: Callable[[], bool],
    ) -> None:
        """Store the handlers, substituting defaults for any that are ``None``.

        Args:
            tool_handler: Handler for tool calls. ``None`` selects a
                ``ToolResponseHandler`` built from an empty list.
            answer_handler: Handler for answer content. ``None`` selects a
                ``DummyAnswerResponseHandler``.
            is_cancelled: Zero-argument callable checked before each streamed
                message is processed; a true result ends the stream.
        """
        # Compare against None explicitly rather than relying on truthiness:
        # a supplied handler that happens to evaluate falsy (e.g. it defines
        # __bool__ or __len__) must not be silently swapped for the default.
        self.tool_handler = (
            tool_handler if tool_handler is not None else ToolResponseHandler([])
        )
        self.answer_handler = (
            answer_handler
            if answer_handler is not None
            else DummyAnswerResponseHandler()
        )
        self.is_cancelled = is_cancelled

    def handle_llm_response(
        self,
        stream: Iterator[BaseMessage],
    ) -> Generator[ResponsePart, None, None]:
        """Run every streamed message through the tool and answer handlers.

        Args:
            stream: Iterator of messages produced by the LLM.

        Yields:
            For each message, whatever the answer handler yields for it.
            Once the stream is exhausted, both handlers are called one last
            time with ``None`` as the message and their output is yielded,
            tool handler first. If ``is_cancelled()`` returns true before a
            message is processed, a single ``StreamStopInfo`` carrying
            ``StreamStopReason.CANCELLED`` is yielded and the generator ends
            without that final pass.
        """
        all_messages: list[BaseMessage | str] = []
        for message in stream:
            if self.is_cancelled():
                yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
                return

            # tool handler doesn't do anything until the full message is received
            # NOTE: the returned iterable still has to be consumed for the
            # handler to run; its per-message output is intentionally dropped.
            for _ in self.tool_handler.handle_response_part(message, all_messages):
                pass

            yield from self.answer_handler.handle_response_part(message, all_messages)

            # Appended only after both handlers ran, so during handling
            # `all_messages` holds just the messages that came before.
            all_messages.append(message)

        # potentially give back all info on the selected tool call + its result
        yield from self.tool_handler.handle_response_part(None, all_messages)
        yield from self.answer_handler.handle_response_part(None, all_messages)

    def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
        """Delegate to the tool handler's ``next_llm_call`` and return its result.

        Args:
            llm_call: The LLM call passed through to the tool handler.
        """
        return self.tool_handler.next_llm_call(llm_call)