removed print statements, fixed pass through handling

This commit is contained in:
Evan Lohn 2025-01-23 10:38:09 -08:00
parent ca1f176c61
commit 2032fb10da
4 changed files with 24 additions and 27 deletions

View File

@ -42,13 +42,11 @@ def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
print("tool_kickoff", tool_kickoff)
# TODO: custom events for yields
emit_packet(tool_kickoff)
tool_responses = []
for response in tool_runner.tool_responses():
print("response", response.id)
tool_responses.append(response)
emit_packet(response)

View File

@ -9,7 +9,7 @@ from onyx.chat.models import LlmDoc
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.utils.logger import setup_logger
@ -34,8 +34,6 @@ def process_llm_stream(
tool_call_chunk = AIMessageChunk(content="")
# for response in response_handler_manager.handle_llm_response(stream):
print("final_search_results", final_search_results)
print("displayed_search_results", displayed_search_results)
if final_search_results and displayed_search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=final_search_results,
@ -43,9 +41,8 @@ def process_llm_stream(
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
else:
answer_handler = DummyAnswerResponseHandler()
answer_handler = PassThroughAnswerResponseHandler()
print("entering stream")
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for response in stream:
@ -62,7 +59,6 @@ def process_llm_stream(
elif should_stream_answer:
# TODO: handle emitting of CitationInfo
for response_part in answer_handler.handle_response_part(response, []):
print("resp part", response_part)
dispatch_custom_event(
"basic_response",
response_part,

View File

@ -8,6 +8,7 @@ from langchain_core.messages import BaseMessage
from onyx.chat.llm_response_handler import ResponsePart
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.citation_processing import CitationProcessor
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
from onyx.utils.logger import setup_logger
@ -15,6 +16,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# TODO: remove update() once it is no longer needed
class AnswerResponseHandler(abc.ABC):
@abc.abstractmethod
def handle_response_part(
@ -29,6 +31,19 @@ class AnswerResponseHandler(abc.ABC):
raise NotImplementedError
class PassThroughAnswerResponseHandler(AnswerResponseHandler):
def handle_response_part(
self,
response_item: BaseMessage | str | None,
previous_response_items: list[BaseMessage | str],
) -> Generator[ResponsePart, None, None]:
content = _message_to_str(response_item)
yield OnyxAnswerPiece(answer_piece=content)
def update(self, state_update: Any) -> None:
pass
class DummyAnswerResponseHandler(AnswerResponseHandler):
def handle_response_part(
self,
@ -71,16 +86,7 @@ class CitationResponseHandler(AnswerResponseHandler):
if response_item is None:
return
content = (
response_item.content
if isinstance(response_item, BaseMessage)
else response_item
)
# Ensure content is a string
if not isinstance(content, str):
logger.warning(f"Received non-string content: {type(content)}")
content = str(content) if content is not None else ""
content = _message_to_str(response_item)
# Process the new content through the citation processor
yield from self.citation_processor.process_token(content)
@ -100,7 +106,11 @@ class CitationResponseHandler(AnswerResponseHandler):
)
def BaseMessage_to_str(message: BaseMessage) -> str:
def _message_to_str(message: BaseMessage | str | None) -> str:
if message is None:
return ""
if isinstance(message, str):
return message
content = message.content if isinstance(message, BaseMessage) else message
if not isinstance(content, str):
logger.warning(f"Received non-string content: {type(content)}")

View File

@ -145,7 +145,6 @@ def test_answer_with_search_call(
# Process the output
output = list(answer_instance.processed_streamed_output)
print(output)
# Updated assertions
assert len(output) == 7
@ -238,14 +237,9 @@ def test_answer_with_search_no_tool_calling(
# Process the output
output = list(answer_instance.processed_streamed_output)
print("-" * 50)
for v in output:
print(v)
print()
print("-" * 50)
# Assertions
assert len(output) == 7
assert len(output) == 8
assert output[0] == ToolCallKickoff(
tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
)
@ -324,7 +318,6 @@ def test_is_cancelled(answer_instance: Answer) -> None:
if i == 1:
connection_status["connected"] = False
print(output)
assert len(output) == 3
assert output[0] == OnyxAnswerPiece(answer_piece="This is the ")
assert output[1] == OnyxAnswerPiece(answer_piece="first part.")