Don't rephrase first chat query (#1907)

This commit is contained in:
Yuhong Sun
2024-07-23 16:20:11 -07:00
committed by GitHub
parent 866bc803b1
commit 2470c68506
8 changed files with 63 additions and 63 deletions

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
revision = "795b20b85b4b" revision = "795b20b85b4b"
down_revision = "05c07bf07c00" down_revision = "05c07bf07c00"
branch_labels = None branch_labels: None = None
depends_on = None depends_on: None = None
def upgrade() -> None: def upgrade() -> None:

View File

@ -187,37 +187,46 @@ def _handle_internet_search_tool_response_summary(
) )
def _check_should_force_search( def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest, new_msg_req: CreateChatMessageRequest, tools: list[Tool]
) -> ForceUseTool | None: ) -> ForceUseTool:
# If files are already provided, don't run the search tool internet_search_available = any(
isinstance(tool, InternetSearchTool) for tool in tools
)
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
if not internet_search_available and not search_tool_available:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
# Currently, the internet search tool does not support query override
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override and tool_name == SearchTool._NAME
else None
)
if new_msg_req.file_descriptors: if new_msg_req.file_descriptors:
return None # If user has uploaded files they're using, don't run any of the search tools
return ForceUseTool(force_use=False, tool_name=tool_name)
if ( should_force_search = any(
new_msg_req.query_override [
or (
new_msg_req.retrieval_options new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS and new_msg_req.retrieval_options.run_search
) == OptionalSearchSetting.ALWAYS,
or new_msg_req.search_doc_ids new_msg_req.search_doc_ids,
or DISABLE_LLM_CHOOSE_SEARCH DISABLE_LLM_CHOOSE_SEARCH,
): ]
args = ( )
{"query": new_msg_req.query_override}
if new_msg_req.query_override
else None
)
# if we are using selected docs, just put something here so the Tool doesn't need
# to build its own args via an LLM call
if new_msg_req.search_doc_ids:
args = {"query": new_msg_req.message}
return ForceUseTool( if should_force_search:
tool_name=SearchTool._NAME, # If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
args=args, args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
) return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
return None
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
ChatPacket = ( ChatPacket = (
@ -360,6 +369,14 @@ def stream_chat_message_objects(
"when the last message is not a user message." "when the last message is not a user message."
) )
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
# leads to worse search quality # leads to worse search quality
if not history_msgs:
new_msg_req.query_override = (
new_msg_req.query_override or new_msg_req.message
)
# load all files needed for this chat chain in memory # load all files needed for this chat chain in memory
files = load_all_chat_files( files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session history_msgs, new_msg_req.file_descriptors, db_session
@ -575,11 +592,7 @@ def stream_chat_message_objects(
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
], ],
tools=tools, tools=tools,
force_use_tool=( force_use_tool=_get_force_search_settings(new_msg_req, tools),
_check_should_force_search(new_msg_req)
if search_tool and len(tools) == 1
else None
),
) )
reference_db_search_docs = None reference_db_search_docs = None

View File

@ -99,6 +99,7 @@ class Answer:
answer_style_config: AnswerStyleConfig, answer_style_config: AnswerStyleConfig,
llm: LLM, llm: LLM,
prompt_config: PromptConfig, prompt_config: PromptConfig,
force_use_tool: ForceUseTool,
# must be the same length as `docs`. If None, all docs are considered "relevant" # must be the same length as `docs`. If None, all docs are considered "relevant"
message_history: list[PreviousMessage] | None = None, message_history: list[PreviousMessage] | None = None,
single_message_history: str | None = None, single_message_history: str | None = None,
@ -107,10 +108,8 @@ class Answer:
latest_query_files: list[InMemoryChatFile] | None = None, latest_query_files: list[InMemoryChatFile] | None = None,
files: list[InMemoryChatFile] | None = None, files: list[InMemoryChatFile] | None = None,
tools: list[Tool] | None = None, tools: list[Tool] | None = None,
# if specified, tells the LLM to always use this tool
# NOTE: for native tool-calling, this is only supported by OpenAI atm, # NOTE: for native tool-calling, this is only supported by OpenAI atm,
# but we only support them anyways # but we only support them anyways
force_use_tool: ForceUseTool | None = None,
# if set to True, then never use the LLMs provided tool-calling functionality # if set to True, then never use the LLMs provided tool-calling functionality
skip_explicit_tool_calling: bool = False, skip_explicit_tool_calling: bool = False,
# Returns the full document sections text from the search tool # Returns the full document sections text from the search tool
@ -129,6 +128,7 @@ class Answer:
self.tools = tools or [] self.tools = tools or []
self.force_use_tool = force_use_tool self.force_use_tool = force_use_tool
self.skip_explicit_tool_calling = skip_explicit_tool_calling self.skip_explicit_tool_calling = skip_explicit_tool_calling
self.message_history = message_history or [] self.message_history = message_history or []
@ -187,7 +187,7 @@ class Answer:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_call_chunk: AIMessageChunk | None = None tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool and self.force_use_tool.args is not None: if self.force_use_tool.force_use and self.force_use_tool.args is not None:
# if we are forcing a tool WITH args specified, we don't need to check which tools to run # if we are forcing a tool WITH args specified, we don't need to check which tools to run
# / need to generate the args # / need to generate the args
tool_call_chunk = AIMessageChunk( tool_call_chunk = AIMessageChunk(
@ -221,7 +221,7 @@ class Answer:
for message in self.llm.stream( for message in self.llm.stream(
prompt=prompt, prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None, tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool else None, tool_choice="required" if self.force_use_tool.force_use else None,
): ):
if isinstance(message, AIMessageChunk) and ( if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls message.tool_call_chunks or message.tool_calls
@ -245,7 +245,8 @@ class Answer:
][0] ][0]
tool_args = ( tool_args = (
self.force_use_tool.args self.force_use_tool.args
if self.force_use_tool and self.force_use_tool.args if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"] else tool_call_request["args"]
) )
@ -303,7 +304,7 @@ class Answer:
tool_args = ( tool_args = (
self.force_use_tool.args self.force_use_tool.args
if self.force_use_tool.args if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm( else tool.get_args_for_non_tool_calling_llm(
query=self.question, query=self.question,
history=self.message_history, history=self.message_history,

View File

@ -118,8 +118,6 @@ class EmbeddingModel:
text_batches = batch_list(texts, batch_size) text_batches = batch_list(texts, batch_size)
embeddings: list[list[float]] = [] embeddings: list[list[float]] = []
for idx, text_batch in enumerate(text_batches, start=1): for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Embedding Content Texts batch {idx} of {len(text_batches)}")
embed_request = EmbedRequest( embed_request = EmbedRequest(
model_name=self.model_name, model_name=self.model_name,
texts=text_batch, texts=text_batch,

View File

@ -206,6 +206,7 @@ def stream_answer_objects(
single_message_history=history_str, single_message_history=history_str,
tools=[search_tool], tools=[search_tool],
force_use_tool=ForceUseTool( force_use_tool=ForceUseTool(
force_use=True,
tool_name=search_tool.name, tool_name=search_tool.name,
args={"query": rephrased_query}, args={"query": rephrased_query},
), ),

View File

@ -94,7 +94,7 @@ def history_based_query_rephrase(
llm: LLM, llm: LLM,
size_heuristic: int = 200, size_heuristic: int = 200,
punctuation_heuristic: int = 10, punctuation_heuristic: int = 10,
skip_first_rephrase: bool = False, skip_first_rephrase: bool = True,
prompt_template: str = HISTORY_QUERY_REPHRASE, prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> str: ) -> str:
# Globally disabled, just use the exact user query # Globally disabled, just use the exact user query

View File

@ -90,7 +90,7 @@ class CreateChatMessageRequest(ChunkContext):
parent_message_id: int | None parent_message_id: int | None
# New message contents # New message contents
message: str message: str
# file's that we should attach to this message # Files that we should attach to this message
file_descriptors: list[FileDescriptor] file_descriptors: list[FileDescriptor]
# If no prompt provided, uses the largest prompt of the chat session # If no prompt provided, uses the largest prompt of the chat session
# but really this should be explicitly specified, only in the simplified APIs is this inferred # but really this should be explicitly specified, only in the simplified APIs is this inferred

View File

@ -1,13 +1,15 @@
from typing import Any from typing import Any
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from pydantic import BaseModel from pydantic import BaseModel
from danswer.tools.tool import Tool from danswer.tools.tool import Tool
class ForceUseTool(BaseModel): class ForceUseTool(BaseModel):
# Could be not a forced usage of the tool but still have args, in which case
# if the tool is called, then those args are applied instead of what the LLM
# wanted to call it with
force_use: bool
tool_name: str tool_name: str
args: dict[str, Any] | None = None args: dict[str, Any] | None = None
@ -16,25 +18,10 @@ class ForceUseTool(BaseModel):
return {"type": "function", "function": {"name": self.tool_name}} return {"type": "function", "function": {"name": self.tool_name}}
def modify_message_chain_for_force_use_tool(
messages: list[BaseMessage], force_use_tool: ForceUseTool | None = None
) -> list[BaseMessage]:
"""NOTE: modifies `messages` in place."""
if not force_use_tool:
return messages
for message in messages:
if isinstance(message, AIMessage) and message.tool_calls:
for tool_call in message.tool_calls:
tool_call["args"] = force_use_tool.args or {}
return messages
def filter_tools_for_force_tool_use( def filter_tools_for_force_tool_use(
tools: list[Tool], force_use_tool: ForceUseTool | None = None tools: list[Tool], force_use_tool: ForceUseTool
) -> list[Tool]: ) -> list[Tool]:
if not force_use_tool: if not force_use_tool.force_use:
return tools return tools
return [tool for tool in tools if tool.name == force_use_tool.tool_name] return [tool for tool in tools if tool.name == force_use_tool.tool_name]