functional search and chat once again!

This commit is contained in:
pablodanswer 2024-09-15 12:20:35 -07:00
parent ff6a15b5af
commit 63d10e7482
5 changed files with 81 additions and 32 deletions

View File

@@ -752,6 +752,8 @@ def stream_chat_message_objects(
tool_name=custom_tool_response.tool_name,
)
elif isinstance(packet, StreamStopInfo):
print("PACKET IS ENINDG")
print(packet)
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
break
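
Note on the control flow above: a `StreamStopInfo` packet normally ends the message loop, but a `NEW_RESPONSE` stop reason marks a continuation and the loop keeps going. A minimal consumer-side sketch of the same idea (the enum/class definitions here are stand-ins inferred from this diff, not the project's actual definitions):

```python
from dataclasses import dataclass
from enum import Enum

class StreamStopReason(Enum):
    FINISHED = "finished"
    NEW_RESPONSE = "new_response"

@dataclass
class StreamStopInfo:
    stop_reason: StreamStopReason

def consume(packets):
    for packet in packets:
        if isinstance(packet, StreamStopInfo):
            # NEW_RESPONSE announces a follow-up response; anything else ends the stream
            if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
                break
            continue
        yield packet
```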

View File

@@ -202,7 +202,13 @@ class Answer:
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
for i in range(MAX_TOOL_CALLS):
tool_calls = 0
initiated = False
while tool_calls < MAX_TOOL_CALLS:
if initiated:
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
initiated = True
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_call_chunk: AIMessageChunk | None = None
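
The rework above swaps a plain `for i in range(MAX_TOOL_CALLS)` for a counter plus an `initiated` flag, so every pass after the first announces itself with a `NEW_RESPONSE` stop packet before the prompt is rebuilt. A stripped-down sketch of the pattern, reusing the `StreamStopInfo`/`StreamStopReason` stand-ins above (`run_one_iteration` is a hypothetical placeholder for the loop body):

```python
MAX_TOOL_CALLS = 3  # matches the constant changed at the end of this commit

def answer_stream(run_one_iteration):
    tool_calls = 0
    initiated = False
    while tool_calls < MAX_TOOL_CALLS:
        if initiated:
            # every pass after the first announces a fresh response downstream
            yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
        initiated = True
        # hypothetical generator for the loop body; returns True if a tool ran
        made_tool_call = yield from run_one_iteration()
        if not made_tool_call:
            return  # no tool call: the streamed text was the final answer
        tool_calls += 1
```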
@@ -223,7 +229,10 @@
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
self.question,
self.prompt_config,
self.latest_query_files,
tool_calls,
)
)
prompt = prompt_builder.build()
@@ -234,6 +243,10 @@
self.tools, self.force_use_tool
)
]
print("\n----\nStreaming with initial prompt to get tool calls\n----\n")
print(f"\n\n\n----------Initial prompt is {prompt}----------\n\n\n")
existing_message = ""
for message in self.llm.stream(
prompt=prompt,
@@ -246,15 +259,19 @@
if tool_call_chunk is None:
tool_call_chunk = message
else:
if len(existing_message) > 0:
yield StreamStopInfo(
stop_reason=StreamStopReason.NEW_RESPONSE
)
existing_message = ""
tool_call_chunk += message # type: ignore
else:
if message.content:
print(message)
if self.is_cancelled:
if self.is_cancelled or tool_calls > 0:
return
else:
print("not canncelled")
existing_message += cast(str, message.content)
yield cast(str, message.content)
if (
message.additional_kwargs.get("usage_metadata", {}).get(
@@ -271,7 +288,9 @@ class Answer:
return
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
tool_calls += 1
known_tools_by_name = [
tool
for tool in self.tools
@@ -299,10 +318,18 @@
)
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
tool_kickoff = tool_runner.kickoff()
print("----\n\n\n")
print(tool_kickoff)
print(type(tool_kickoff))
yield tool_kickoff
tool_responses = list(tool_runner.tool_responses())
yield from tool_responses
for response in tool_responses:
print("----\n\n\n")
print(response)
yield response
# yield from tool_responses
tool_call_summary = ToolCallSummary(
tool_call_request=tool_call_chunk,
@@ -313,7 +340,9 @@
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name == ImageGenerationTool._NAME:
print("\n----\nUpdating image prompt user message\n----\n")
img_urls = [
img_generation_result["url"]
for img_generation_result in tool_runner.tool_final_result().tool_result
@@ -324,14 +353,26 @@
)
)
print("now stremign wie fianl results")
yield tool_runner.tool_final_result()
# Update message history with tool call and response
self.message_history.append(
PreviousMessage(
message=str(self.question),
message_type=MessageType.USER,
token_count=10,
tool_call=None,
files=[],
)
)
self.message_history.append(
PreviousMessage(
message=str(tool_call_request),
message_type=MessageType.ASSISTANT,
token_count=10, # You may want to implement a token counting method
token_count=10,
tool_call=None,
files=[],
)
@@ -346,16 +387,22 @@
)
)
print("\n----\nBuilding final prompt with Tool call summary\n----\n")
# Generate response based on updated message history
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
response_content = ""
yield from self._process_llm_stream(
for content in self._process_llm_stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)
tools=None
# tools=[tool.tool_definition() for tool in self.tools],
):
if isinstance(content, str):
response_content += content
yield content
print(response_content)
# Update message history with LLM response
self.message_history.append(
PreviousMessage(
@@ -363,7 +410,7 @@
message_type=MessageType.ASSISTANT,
token_count=10,
tool_call=None,
files=[], # You may want to implement a token counting method
files=[],
)
)
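
Both history updates above append `PreviousMessage` entries with a hardcoded `token_count=10`; the removed comment ("You may want to implement a token counting method") makes clear this is a placeholder. A self-contained sketch of the bookkeeping (the `PreviousMessage` shape is inferred from the call sites in this diff):

```python
from dataclasses import dataclass, field

@dataclass
class PreviousMessage:  # fields inferred from the calls above
    message: str
    message_type: str  # the project uses a MessageType enum
    token_count: int
    tool_call: object | None = None
    files: list = field(default_factory=list)

def record_turn(history: list, question: str, answer: str) -> None:
    # token_count=10 mirrors the placeholder in the diff; a real token
    # counter would replace it
    history.append(PreviousMessage(question, "USER", token_count=10))
    history.append(PreviousMessage(answer, "ASSISTANT", token_count=10))
```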
@@ -467,13 +514,11 @@
self.question, self.prompt_config, self.latest_query_files
)
)
print("I am now yielding from here")
prompt = prompt_builder.build()
yield from self._process_llm_stream(
prompt=prompt,
tools=None,
)
print("yielding complete")
return
tool, tool_args = chosen_tool_and_args
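
This fallback path (no tool chosen) streams the answer directly through `_process_llm_stream`. That method's body is outside this diff, so the following is only an assumed shape consistent with how it is called here: it forwards text deltas from the LLM and, per the diff, may interleave `StreamStopInfo` packets:

```python
from typing import Any, Iterator

def process_llm_stream(llm: Any, prompt: Any, tools: list | None = None) -> Iterator[str]:
    # assumed shape only: forward text chunks as they arrive from the model;
    # the real method also emits StreamStopInfo packets on certain stops
    for chunk in llm.stream(prompt=prompt, tools=tools):
        if getattr(chunk, "content", None):
            yield str(chunk.content)
```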
@@ -558,7 +603,6 @@
final_context_docs: list[LlmDoc] | None = None
for message in stream:
print(message)
if isinstance(message, ToolCallKickoff) or isinstance(
message, ToolCallFinalResult
):
@@ -611,9 +655,6 @@
return
if isinstance(item, ToolCallKickoff):
new_kickoff = item
stream_stop_info = StreamStopInfo(
stop_reason=StreamStopReason.NEW_RESPONSE
)
return
else:
yield cast(str, item)
@@ -623,17 +664,21 @@
if stream_stop_info:
yield stream_stop_info
# if new_kickoff: handle new tool call (continuation of message)
# handle new tool call (continuation of message)
if new_kickoff:
self.current_streamed_output = self.processing_stream
self.processing_stream = []
yield new_kickoff
for processed_packet in _process_stream(output_generator):
if (
isinstance(processed_packet, StreamStopInfo)
and processed_packet.stop_reason == StreamStopReason.NEW_RESPONSE
):
self.current_streamed_output = self.processing_stream
self.processing_stream = []
self.processing_stream.append(processed_packet)
yield processed_packet
self.current_streamed_output = self.processing_stream
self._processed_stream = self.processing_stream
@property
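
The last hunk of this file moves the buffer swap out of the `ToolCallKickoff` branch: whenever a `NEW_RESPONSE` stop packet flows through `_process_stream`, the output streamed so far is snapshotted into `current_streamed_output` and the working buffer restarts. A condensed sketch of that buffering, reusing the stand-in types from the first sketch:

```python
class StreamBuffer:
    def __init__(self) -> None:
        self.processing_stream: list = []
        self.current_streamed_output: list = []

    def buffer(self, packets):
        for packet in packets:
            if (
                isinstance(packet, StreamStopInfo)
                and packet.stop_reason == StreamStopReason.NEW_RESPONSE
            ):
                # snapshot what has streamed so far, then start fresh
                self.current_streamed_output = self.processing_stream
                self.processing_stream = []
            self.processing_stream.append(packet)
            yield packet
        # once exhausted, the working buffer is the final processed stream
        self.current_streamed_output = self.processing_stream
```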

View File

@@ -36,7 +36,10 @@ def default_build_system_message(
def default_build_user_message(
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
user_query: str,
prompt_config: PromptConfig,
files: list[InMemoryChatFile] = [],
previous_tool_calls: int = 0,
) -> HumanMessage:
user_prompt = (
CHAT_USER_CONTEXT_FREE_PROMPT.format(
@@ -45,6 +48,9 @@ def default_build_user_message(
if prompt_config.task_prompt
else user_query
)
if previous_tool_calls > 0:
user_prompt = f"You have already generated the above but remember the query is: `{user_prompt}`"
user_prompt = user_prompt.strip()
user_msg = HumanMessage(
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
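
The new `previous_tool_calls` parameter rewrites the user prompt into a reminder once at least one tool call has run. A standalone re-implementation of just that branch, runnable as-is:

```python
def build_user_prompt(user_query: str, previous_tool_calls: int = 0) -> str:
    user_prompt = user_query
    if previous_tool_calls > 0:
        # after a tool call, re-anchor the model on the original query
        user_prompt = (
            "You have already generated the above but remember "
            f"the query is: `{user_prompt}`"
        )
    return user_prompt.strip()

# first pass keeps the query verbatim; later passes wrap it in the reminder
assert build_user_prompt("find the docs") == "find the docs"
assert "remember the query" in build_user_prompt("find the docs", previous_tool_calls=1)
```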

View File

@@ -279,18 +279,14 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
main_loop = asyncio.get_event_loop()
def is_disconnected_sync() -> bool:
logger.info("Checking if client is disconnected")
future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
try:
result = not future.result(timeout=0.01)
if result:
logger.info("Client disconnected")
return result
except asyncio.TimeoutError:
logger.error("Asyncio timed out while checking client connection")
return True
except asyncio.CancelledError:
logger.info("Disconnect check was cancelled")
return True
except Exception as e:
error_msg = str(e)
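
The check above bridges sync worker code and the async `Request.is_disconnected()` by scheduling the coroutine on the main event loop. A minimal sketch of the pattern (Starlette/FastAPI `Request` assumed, as in the diff; unlike the diff's main path, the result is returned un-negated, since Starlette's `is_disconnected()` already returns True once the client is gone):

```python
import asyncio
from concurrent.futures import CancelledError, TimeoutError as FutureTimeoutError
from typing import Callable

def make_disconnect_checker(request, main_loop: asyncio.AbstractEventLoop) -> Callable[[], bool]:
    def is_disconnected_sync() -> bool:
        # run the async check on the main event loop; a 10ms budget keeps it cheap
        future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
        try:
            return future.result(timeout=0.01)
        except (FutureTimeoutError, CancelledError):
            # no answer in time, or the check was cancelled: assume disconnected,
            # matching the `return True` branches in the diff
            return True
    return is_disconnected_sync
```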

View File

@@ -22,7 +22,7 @@ INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
INTENT_MODEL_TAG = "v1.0.3"
# Tool call configs
MAX_TOOL_CALLS = 2
MAX_TOOL_CALLS = 3
# Bi-Encoder, other details
DOC_EMBEDDING_CONTEXT_SIZE = 512