Don't rephrase first chat query (#1907)

This commit is contained in:
Yuhong Sun 2024-07-23 16:20:11 -07:00 committed by GitHub
parent 866bc803b1
commit 2470c68506
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 63 additions and 63 deletions

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -187,37 +187,46 @@ def _handle_internet_search_tool_response_summary(
)
def _check_should_force_search(
new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:
# If files are already provided, don't run the search tool
def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
) -> ForceUseTool:
internet_search_available = any(
isinstance(tool, InternetSearchTool) for tool in tools
)
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
if not internet_search_available and not search_tool_available:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
# Currently, the internet search tool does not support query override
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override and tool_name == SearchTool._NAME
else None
)
if new_msg_req.file_descriptors:
return None
# If user has uploaded files they're using, don't run any of the search tools
return ForceUseTool(force_use=False, tool_name=tool_name)
if (
new_msg_req.query_override
or (
should_force_search = any(
[
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
)
or new_msg_req.search_doc_ids
or DISABLE_LLM_CHOOSE_SEARCH
):
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override
else None
)
# if we are using selected docs, just put something here so the Tool doesn't need
# to build its own args via an LLM call
if new_msg_req.search_doc_ids:
args = {"query": new_msg_req.message}
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
new_msg_req.search_doc_ids,
DISABLE_LLM_CHOOSE_SEARCH,
]
)
return ForceUseTool(
tool_name=SearchTool._NAME,
args=args,
)
return None
if should_force_search:
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
ChatPacket = (
@ -360,6 +369,14 @@ def stream_chat_message_objects(
"when the last message is not a user message."
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
# leads to worse search quality
if not history_msgs:
new_msg_req.query_override = (
new_msg_req.query_override or new_msg_req.message
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
@ -575,11 +592,7 @@ def stream_chat_message_objects(
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
tools=tools,
force_use_tool=(
_check_should_force_search(new_msg_req)
if search_tool and len(tools) == 1
else None
),
force_use_tool=_get_force_search_settings(new_msg_req, tools),
)
reference_db_search_docs = None

View File

@ -99,6 +99,7 @@ class Answer:
answer_style_config: AnswerStyleConfig,
llm: LLM,
prompt_config: PromptConfig,
force_use_tool: ForceUseTool,
# must be the same length as `docs`. If None, all docs are considered "relevant"
message_history: list[PreviousMessage] | None = None,
single_message_history: str | None = None,
@ -107,10 +108,8 @@ class Answer:
latest_query_files: list[InMemoryChatFile] | None = None,
files: list[InMemoryChatFile] | None = None,
tools: list[Tool] | None = None,
# if specified, tells the LLM to always this tool
# NOTE: for native tool-calling, this is only supported by OpenAI atm,
# but we only support them anyways
force_use_tool: ForceUseTool | None = None,
# if set to True, then never use the LLM's provided tool-calling functionality
skip_explicit_tool_calling: bool = False,
# Returns the full document sections text from the search tool
@ -129,6 +128,7 @@ class Answer:
self.tools = tools or []
self.force_use_tool = force_use_tool
self.skip_explicit_tool_calling = skip_explicit_tool_calling
self.message_history = message_history or []
@ -187,7 +187,7 @@ class Answer:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool and self.force_use_tool.args is not None:
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
# / need to generate the args
tool_call_chunk = AIMessageChunk(
@ -221,7 +221,7 @@ class Answer:
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool else None,
tool_choice="required" if self.force_use_tool.force_use else None,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
@ -245,7 +245,8 @@ class Answer:
][0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool and self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
@ -303,7 +304,7 @@ class Answer:
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=self.question,
history=self.message_history,

View File

@ -118,8 +118,6 @@ class EmbeddingModel:
text_batches = batch_list(texts, batch_size)
embeddings: list[list[float]] = []
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Embedding Content Texts batch {idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,

View File

@ -206,6 +206,7 @@ def stream_answer_objects(
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
force_use=True,
tool_name=search_tool.name,
args={"query": rephrased_query},
),

View File

@ -94,7 +94,7 @@ def history_based_query_rephrase(
llm: LLM,
size_heuristic: int = 200,
punctuation_heuristic: int = 10,
skip_first_rephrase: bool = False,
skip_first_rephrase: bool = True,
prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> str:
# Globally disabled, just use the exact user query

View File

@ -90,7 +90,7 @@ class CreateChatMessageRequest(ChunkContext):
parent_message_id: int | None
# New message contents
message: str
# file's that we should attach to this message
# Files that we should attach to this message
file_descriptors: list[FileDescriptor]
# If no prompt provided, uses the largest prompt of the chat session
# but really this should be explicitly specified, only in the simplified APIs is this inferred

View File

@ -1,13 +1,15 @@
from typing import Any
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
from danswer.tools.tool import Tool
class ForceUseTool(BaseModel):
# Could be not a forced usage of the tool but still have args, in which case
# if the tool is called, then those args are applied instead of what the LLM
# wanted to call it with
force_use: bool
tool_name: str
args: dict[str, Any] | None = None
@ -16,25 +18,10 @@ class ForceUseTool(BaseModel):
return {"type": "function", "function": {"name": self.tool_name}}
def modify_message_chain_for_force_use_tool(
messages: list[BaseMessage], force_use_tool: ForceUseTool | None = None
) -> list[BaseMessage]:
"""NOTE: modifies `messages` in place."""
if not force_use_tool:
return messages
for message in messages:
if isinstance(message, AIMessage) and message.tool_calls:
for tool_call in message.tool_calls:
tool_call["args"] = force_use_tool.args or {}
return messages
def filter_tools_for_force_tool_use(
tools: list[Tool], force_use_tool: ForceUseTool | None = None
tools: list[Tool], force_use_tool: ForceUseTool
) -> list[Tool]:
if not force_use_tool:
if not force_use_tool.force_use:
return tools
return [tool for tool in tools if tool.name == force_use_tool.tool_name]