mirror of https://github.com/danswer-ai/danswer.git
synced 2025-04-08 03:48:14 +02:00

commit 63d10e7482 (parent ff6a15b5af)

    functional search and chat once again!
@@ -752,6 +752,8 @@ def stream_chat_message_objects(
                     tool_name=custom_tool_response.tool_name,
                 )
             elif isinstance(packet, StreamStopInfo):
+                print("PACKET IS ENDING")
+                print(packet)
                 if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
                     break
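Note on the packet types: the branch above only breaks the outer loop when the stop reason is something other than NEW_RESPONSE, so NEW_RESPONSE acts as a soft separator rather than a terminator. A minimal sketch of the models this implies; the NEW_RESPONSE member and the stop_reason field appear in the diff, while the other enum members and the pydantic base are assumptions:

    from enum import Enum

    from pydantic import BaseModel


    class StreamStopReason(Enum):
        CONTEXT_LENGTH = "context_length"  # assumed pre-existing member
        FINISHED = "finished"  # assumed pre-existing member
        NEW_RESPONSE = "new_response"  # separator between tool-call rounds


    class StreamStopInfo(BaseModel):
        stop_reason: StreamStopReason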
@@ -202,7 +202,13 @@ class Answer:
     ) -> Iterator[
         str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
     ]:
-        for i in range(MAX_TOOL_CALLS):
+        tool_calls = 0
+        initiated = False
+        while tool_calls < MAX_TOOL_CALLS:
+            if initiated:
+                yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
+            initiated = True
+
             prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)

             tool_call_chunk: AIMessageChunk | None = None
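The rewritten loop replaces a fixed `for` over MAX_TOOL_CALLS with a `while` that emits a NEW_RESPONSE marker before every round after the first, so one generator can carry several logical responses. A self-contained sketch of that separator pattern, with a plain string standing in for the StreamStopInfo packet:

    from collections.abc import Iterator


    def multi_round_stream(rounds: list[list[str]]) -> Iterator[str]:
        # Every round after the first is prefixed with a separator so a
        # consumer can split one generator into several logical responses.
        initiated = False
        for tokens in rounds:
            if initiated:
                yield "<NEW_RESPONSE>"  # stand-in for StreamStopInfo(NEW_RESPONSE)
            initiated = True
            yield from tokens


    print(list(multi_round_stream([["a", "b"], ["c"]])))
    # ['a', 'b', '<NEW_RESPONSE>', 'c']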
@@ -223,7 +229,10 @@ class Answer:
                 )
             prompt_builder.update_user_prompt(
                 default_build_user_message(
-                    self.question, self.prompt_config, self.latest_query_files
+                    self.question,
+                    self.prompt_config,
+                    self.latest_query_files,
+                    tool_calls,
                 )
             )
             prompt = prompt_builder.build()
@@ -234,6 +243,10 @@ class Answer:
                     self.tools, self.force_use_tool
                 )
             ]
+            print("\n----\nStreaming with initial prompt to get tool calls\n----\n")
+            print(f"\n\n\n----------Initial prompt is {prompt}----------\n\n\n")
+
+            existing_message = ""

             for message in self.llm.stream(
                 prompt=prompt,
@@ -246,15 +259,19 @@ class Answer:
                 if tool_call_chunk is None:
                     tool_call_chunk = message
                 else:
+                    if len(existing_message) > 0:
+                        yield StreamStopInfo(
+                            stop_reason=StreamStopReason.NEW_RESPONSE
+                        )
+                        existing_message = ""
+
                     tool_call_chunk += message  # type: ignore
             else:
                 if message.content:
+                    print(message)
+
-                    if self.is_cancelled:
+                    if self.is_cancelled or tool_calls > 0:
                         return
+                    else:
+                        print("not cancelled")
+
+                    existing_message += cast(str, message.content)
                     yield cast(str, message.content)
                 if (
                     message.additional_kwargs.get("usage_metadata", {}).get(
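The `tool_call_chunk += message` accumulation depends on langchain's chunk semantics: AIMessageChunk implements addition, merging content and tool-call deltas as they stream in. A toy illustration, independent of this codebase:

    from langchain_core.messages import AIMessageChunk

    first = AIMessageChunk(content="Hel")
    second = AIMessageChunk(content="lo")
    merged = first + second  # addition concatenates content and merges kwargs
    assert merged.content == "Hello"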
@@ -271,7 +288,9 @@ class Answer:
                     return

             tool_call_requests = tool_call_chunk.tool_calls

             for tool_call_request in tool_call_requests:
+                tool_calls += 1
                 known_tools_by_name = [
                     tool
                     for tool in self.tools
@@ -299,10 +318,18 @@ class Answer:
                 )

                 tool_runner = ToolRunner(tool, tool_args)
-                yield tool_runner.kickoff()
+                tool_kickoff = tool_runner.kickoff()
+                print("----\n\n\n")
+                print(tool_kickoff)
+                print(type(tool_kickoff))
+                yield tool_kickoff

                 tool_responses = list(tool_runner.tool_responses())
-                yield from tool_responses
+                for response in tool_responses:
+                    print("----\n\n\n")
+                    print(response)
+                    yield response
+                # yield from tool_responses

                 tool_call_summary = ToolCallSummary(
                     tool_call_request=tool_call_chunk,
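The call sites above imply a three-phase tool protocol: kickoff announces the call, tool_responses streams intermediate packets, and tool_final_result closes it out. A rough interface sketch inferred from those calls; the signatures and return types are assumptions, not the project's declarations:

    from collections.abc import Iterator
    from typing import Any, Protocol


    class ToolRunnerLike(Protocol):
        def kickoff(self) -> Any:  # announces the call (a ToolCallKickoff)
            ...

        def tool_responses(self) -> Iterator[Any]:  # intermediate ToolResponse packets
            ...

        def tool_final_result(self) -> Any:  # terminal ToolCallFinalResult packet
            ...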
@@ -313,7 +340,9 @@ class Answer:

                 if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
                     self._update_prompt_builder_for_search_tool(prompt_builder, [])

                 elif tool.name == ImageGenerationTool._NAME:
+                    print("\n----\nUpdating image prompt user message\n----\n")
                     img_urls = [
                         img_generation_result["url"]
                         for img_generation_result in tool_runner.tool_final_result().tool_result
@@ -324,14 +353,26 @@ class Answer:
                         )
                     )

+                print("now streaming with final results")
+
                 yield tool_runner.tool_final_result()

+                # Update message history with tool call and response
+                self.message_history.append(
+                    PreviousMessage(
+                        message=str(self.question),
+                        message_type=MessageType.USER,
+                        token_count=10,
+                        tool_call=None,
+                        files=[],
+                    )
+                )
+
                 self.message_history.append(
                     PreviousMessage(
                         message=str(tool_call_request),
                         message_type=MessageType.ASSISTANT,
-                        token_count=10,  # You may want to implement a token counting method
+                        token_count=10,
                         tool_call=None,
                         files=[],
                     )
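The hard-coded `token_count=10` is clearly a placeholder. If a real count were wanted, a small helper along these lines would do; the tokenizer choice is an assumption, not something this diff specifies:

    import tiktoken


    def count_tokens(text: str, model: str = "gpt-3.5-turbo") -> int:
        # Fall back to a generic encoding if tiktoken does not know the model.
        try:
            enc = tiktoken.encoding_for_model(model)
        except KeyError:
            enc = tiktoken.get_encoding("cl100k_base")
        return len(enc.encode(text))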
@@ -346,16 +387,22 @@ class Answer:
                         )
                     )

+                print("\n----\nBuilding final prompt with Tool call summary\n----\n")
+
                 # Generate response based on updated message history
                 prompt = prompt_builder.build(tool_call_summary=tool_call_summary)

+                response_content = ""
+
-                yield from self._process_llm_stream(
+                for content in self._process_llm_stream(
                     prompt=prompt,
-                    tools=[tool.tool_definition() for tool in self.tools],
-                )
+                    tools=None
+                    # tools=[tool.tool_definition() for tool in self.tools],
+                ):
+                    if isinstance(content, str):
+                        response_content += content
+                        yield content
+
+                print(response_content)
                 # Update message history with LLM response
                 self.message_history.append(
                     PreviousMessage(
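Switching the final pass to `tools=None` stops the model from chaining another tool call while the answer is generated, and `response_content` collects what the user saw so it can be written to history. A compact restatement of that accumulate-and-forward pattern:

    from collections.abc import Iterable, Iterator


    def accumulate_and_forward(stream: Iterable[object]) -> Iterator[str]:
        # Forward string packets while keeping a transcript, mirroring how
        # response_content is built up before being written to history.
        transcript = ""
        for packet in stream:
            if isinstance(packet, str):
                transcript += packet
                yield packet
        # transcript is what would become the ASSISTANT history entry


    print("".join(accumulate_and_forward(["Hel", "lo", 42, "!"])))  # Hello!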
@@ -363,7 +410,7 @@ class Answer:
                         message_type=MessageType.ASSISTANT,
                         token_count=10,
                         tool_call=None,
-                        files=[],  # You may want to implement a token counting method
+                        files=[],
                     )
                 )

@@ -467,13 +514,11 @@ class Answer:
                     self.question, self.prompt_config, self.latest_query_files
                 )
             )
-            print("I am now yielding from here")
             prompt = prompt_builder.build()
             yield from self._process_llm_stream(
                 prompt=prompt,
                 tools=None,
             )
-            print("yielding complete")
             return

         tool, tool_args = chosen_tool_and_args
@@ -558,7 +603,6 @@ class Answer:
         final_context_docs: list[LlmDoc] | None = None

         for message in stream:
-            print(message)
             if isinstance(message, ToolCallKickoff) or isinstance(
                 message, ToolCallFinalResult
             ):
@@ -611,9 +655,6 @@ class Answer:
                     return
                 if isinstance(item, ToolCallKickoff):
                     new_kickoff = item
-                    stream_stop_info = StreamStopInfo(
-                        stop_reason=StreamStopReason.NEW_RESPONSE
-                    )
                     return
                 else:
                     yield cast(str, item)
@@ -623,17 +664,21 @@ class Answer:
             if stream_stop_info:
                 yield stream_stop_info

-            # if new_kickoff: handle new tool call (continuation of message)
+            # handle new tool call (continuation of message)
+            if new_kickoff:
+                self.current_streamed_output = self.processing_stream
+                self.processing_stream = []
+
+                yield new_kickoff

         for processed_packet in _process_stream(output_generator):
             if (
                 isinstance(processed_packet, StreamStopInfo)
                 and processed_packet.stop_reason == StreamStopReason.NEW_RESPONSE
             ):
                 self.current_streamed_output = self.processing_stream
                 self.processing_stream = []

             self.processing_stream.append(processed_packet)
             yield processed_packet

+        self.current_streamed_output = self.processing_stream
         self._processed_stream = self.processing_stream

     @property
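The buffering above splits one packet stream into logical responses: a NEW_RESPONSE packet rolls `processing_stream` over into `current_streamed_output` and starts a fresh buffer. A simplified standalone version of that split (unlike the diff, it drops the marker instead of buffering it):

    from collections.abc import Iterable, Iterator


    def split_on_marker(packets: Iterable[str], marker: str) -> Iterator[list[str]]:
        # Roll the buffer over whenever the separator arrives, yielding one
        # list per logical response.
        buffer: list[str] = []
        for packet in packets:
            if packet == marker:
                yield buffer
                buffer = []
            else:
                buffer.append(packet)
        yield buffer


    print(list(split_on_marker(["a", "b", "<NEW_RESPONSE>", "c"], "<NEW_RESPONSE>")))
    # [['a', 'b'], ['c']]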
@@ -36,7 +36,10 @@ def default_build_system_message(


 def default_build_user_message(
-    user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
+    user_query: str,
+    prompt_config: PromptConfig,
+    files: list[InMemoryChatFile] = [],
+    previous_tool_calls: int = 0,
 ) -> HumanMessage:
     user_prompt = (
         CHAT_USER_CONTEXT_FREE_PROMPT.format(
@@ -45,6 +48,9 @@ def default_build_user_message(
         if prompt_config.task_prompt
         else user_query
     )
+    if previous_tool_calls > 0:
+        user_prompt = f"You have already generated the above but remember the query is: `{user_prompt}`"
+
     user_prompt = user_prompt.strip()
     user_msg = HumanMessage(
         content=build_content_with_imgs(user_prompt, files) if files else user_prompt
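The new `previous_tool_calls` parameter re-anchors the model on the original query once at least one tool round has run. The added branch, restated as a runnable standalone function:

    def reanchor_prompt(user_prompt: str, previous_tool_calls: int) -> str:
        # After at least one tool round, remind the model of the original
        # query instead of repeating the full prompt.
        if previous_tool_calls > 0:
            return (
                "You have already generated the above but remember "
                f"the query is: `{user_prompt}`"
            )
        return user_prompt


    print(reanchor_prompt("What changed in Q3?", previous_tool_calls=1))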
@@ -279,18 +279,14 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
     main_loop = asyncio.get_event_loop()

     def is_disconnected_sync() -> bool:
-        logger.info("Checking if client is disconnected")
         future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
         try:
-            result = not future.result(timeout=0.01)
-            if result:
-                logger.info("Client disconnected")
-            return result
+            return not future.result(timeout=0.01)
         except asyncio.TimeoutError:
             logger.error("Asyncio timed out while checking client connection")
             return True
         except asyncio.CancelledError:
             logger.info("Disconnect check was cancelled")
             return True
         except Exception as e:
             error_msg = str(e)
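The helper bridges sync worker code and the async request: `asyncio.run_coroutine_threadsafe` schedules `request.is_disconnected()` on the main event loop and the worker polls the resulting future with a tight timeout. A generic illustration of that bridge with no FastAPI dependency; the names here are placeholders:

    import asyncio
    import concurrent.futures


    async def is_disconnected() -> bool:
        return False  # stand-in for request.is_disconnected()


    def is_connected_sync(loop: asyncio.AbstractEventLoop) -> bool:
        # Schedule the coroutine on the running loop from a worker thread
        # and poll the future with a tight timeout, as the endpoint does.
        future = asyncio.run_coroutine_threadsafe(is_disconnected(), loop)
        try:
            return not future.result(timeout=0.01)
        except concurrent.futures.TimeoutError:
            return True  # treat a slow check as "still connected"


    async def main() -> None:
        loop = asyncio.get_running_loop()
        print(await asyncio.to_thread(is_connected_sync, loop))  # True


    asyncio.run(main())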
@@ -22,7 +22,7 @@ INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
 INTENT_MODEL_TAG = "v1.0.3"

 # Tool call configs
-MAX_TOOL_CALLS = 2
+MAX_TOOL_CALLS = 3

 # Bi-Encoder, other details
 DOC_EMBEDDING_CONTEXT_SIZE = 512
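Bumping MAX_TOOL_CALLS from 2 to 3 raises the ceiling on tool rounds in the `while tool_calls < MAX_TOOL_CALLS` loop above; with the NEW_RESPONSE separators, a single chat message can now fan out into up to three tool-assisted responses. A toy bound check:

    MAX_TOOL_CALLS = 3  # value set in this diff


    def rounds_executed(requested_tool_calls: int) -> int:
        # Mirrors the loop bound: tool_calls increments once per requested
        # call and the while condition caps the total.
        tool_calls = 0
        while tool_calls < MAX_TOOL_CALLS and tool_calls < requested_tool_calls:
            tool_calls += 1
        return tool_calls


    print(rounds_executed(5))  # 3, capped by MAX_TOOL_CALLS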