Fix answer with specified doc ids (#2703)

* Fix

Fix

Refactor

more

more

fix

refactor

Fix circular imports

Refactor

Move tests around

* Add quote support

* Testing

* More testing

* Fix image generation slowness

* Remove unused exception

* Fix UT

* Fix stop generating

* Minor typo

* Minor logging updates for clarity

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
Authored by Chris Weaver on 2024-11-01 12:50:20 -07:00; committed by GitHub
parent d66b81a902
commit ecf4923a3a
41 changed files with 1986 additions and 1109 deletions

View File

@@ -10,7 +10,7 @@ from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.custom.base_tool_types import ToolResultType
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
class LlmDoc(BaseModel):
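Note: a large share of the changed lines in this commit are mechanical import updates, since the built-in tool packages move from `danswer.tools.<tool>` to `danswer.tools.tool_implementations.<tool>`. The hunk above shows the pattern in miniature:

```python
# Before this commit (the removed line in the hunk above):
# from danswer.tools.custom.base_tool_types import ToolResultType

# After this commit (the added line):
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
```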

View File

@@ -18,6 +18,7 @@ from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
@@ -77,31 +78,49 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import (
from danswer.tools.force import ForceUseTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import (
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from danswer.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
from danswer.tools.internet_search.internet_search_tool import (
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
internet_search_response_to_search_docs,
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchResponse,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from danswer.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
@@ -260,6 +279,7 @@ ChatPacket = (
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
| StreamStopInfo
)
ChatPacketStream = Iterator[ChatPacket]
@@ -532,6 +552,13 @@ def stream_chat_message_objects(
if not persona
else PromptConfig.from_model(persona.prompts[0])
)
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
)
# find out what tools to use
search_tool: SearchTool | None = None
@@ -550,13 +577,16 @@ def stream_chat_message_objects(
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
answer_style_config=answer_style_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
@@ -626,7 +656,11 @@ def stream_chat_message_objects(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(api_key=bing_api_key)
InternetSearchTool(
api_key=bing_api_key,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
)
]
continue
@@ -667,13 +701,7 @@ def stream_chat_message_objects(
is_connected=is_connected,
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
),
answer_style_config=answer_style_config,
prompt_config=prompt_config,
llm=(
llm
@@ -777,7 +805,8 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif isinstance(packet, StreamStopInfo):
pass
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
@@ -807,6 +836,7 @@ def stream_chat_message_objects(
# Post-LLM answer processing
try:
logger.debug("Post-LLM answer processing")
message_specific_citations: MessageSpecificCitations | None = None
if reference_db_search_docs:
message_specific_citations = _translate_citations(

View File

@@ -1,72 +1,44 @@
import itertools
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from uuid import uuid4
from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import HumanMessage
from langchain_core.messages import ToolCall
from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import StreamProcessor
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import default_build_system_message
from danswer.llm.answering.prompts.build import default_build_user_message
from danswer.llm.answering.prompts.citations_prompt import (
build_citations_system_message,
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
)
from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
from danswer.llm.answering.stream_processing.citation_processing import (
build_citation_processor,
from danswer.llm.answering.stream_processing.answer_response_handler import (
CitationResponseHandler,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
build_quotes_processor,
from danswer.llm.answering.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
QuotesResponseHandler,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import ToolChoiceOptions
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.custom.custom_tool_prompt_builder import (
build_user_message_for_custom_tool_for_non_tool_calling_llm,
)
from danswer.tools.force import filter_tools_for_force_tool_use
from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.images.prompt import build_image_generation_user_prompt
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.message import build_tool_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import (
check_which_tools_should_run_for_non_tool_calling_llm,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.tools.tool_runner import ToolRunner
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
@@ -74,29 +46,9 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_answer_stream_processor(
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
answer_style_configs: AnswerStyleConfig,
) -> StreamProcessor:
if answer_style_configs.citation_config:
return build_citation_processor(
context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map
)
if answer_style_configs.quotes_config:
return build_quotes_processor(
context_docs=context_docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak")
)
raise RuntimeError("Not implemented yet")
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
logger = setup_logger()
class Answer:
def __init__(
self,
@@ -136,8 +88,6 @@ class Answer:
self.tools = tools or []
self.force_use_tool = force_use_tool
self.skip_explicit_tool_calling = skip_explicit_tool_calling
self.message_history = message_history or []
# used for QA flow where we only want to send a single message
self.single_message_history = single_message_history
@@ -162,335 +112,141 @@ class Answer:
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
self._is_cancelled = False
def _update_prompt_builder_for_search_tool(
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
) -> None:
if self.answer_style_config.citation_config:
prompt_builder.update_system_prompt(
build_citations_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
build_citations_user_message(
question=self.question,
prompt_config=self.prompt_config,
context_docs=final_context_documents,
files=self.latest_query_files,
all_doc_useful=(
self.answer_style_config.citation_config.all_docs_useful
if self.answer_style_config.citation_config
else False
),
history_message=self.single_message_history or "",
)
)
elif self.answer_style_config.quotes_config:
prompt_builder.update_user_prompt(
build_quotes_user_message(
question=self.question,
context_docs=final_context_documents,
history_str=self.single_message_history or "",
prompt=self.prompt_config,
)
self.using_tool_calling_llm = (
explicit_tool_calling_supported(
self.llm.config.model_provider, self.llm.config.model_name
)
and not skip_explicit_tool_calling
)
def _raw_output_for_explicit_tool_calling_llms(
self,
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
def _get_tools_list(self) -> list[Tool]:
if not self.force_use_tool.force_use:
return self.tools
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
# / need to generate the args
tool_call_chunk = AIMessageChunk(
content="",
tool = next(
(t for t in self.tools if t.name == self.force_use_tool.tool_name), None
)
if tool is None:
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
logger.info(
f"Forcefully using tool='{tool.name}'"
+ (
f" with args='{self.force_use_tool.args}'"
if self.force_use_tool.args is not None
else ""
)
tool_call_chunk.tool_calls = [
{
"name": self.force_use_tool.tool_name,
"args": self.force_use_tool.args,
"id": str(uuid4()),
}
]
)
return [tool]
def _handle_specified_tool_call(
self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
) -> AnswerStream:
current_llm_call = llm_calls[-1]
# make a dummy tool handler
tool_handler = ToolResponseHandler([tool])
dummy_tool_call_chunk = AIMessageChunk(content="")
dummy_tool_call_chunk.tool_calls = [
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
]
response_handler_manager = LLMResponseHandlerManager(
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
)
yield from response_handler_manager.handle_llm_response(
iter([dummy_tool_call_chunk])
)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
else:
# if tool calling is supported, first try the raw message
# to see if we don't need to use any tools
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
final_tool_definitions = [
tool.tool_definition()
for tool in filter_tools_for_force_tool_use(
self.tools, self.force_use_tool
)
]
raise RuntimeError("Tool call handler did not return a new LLM call")
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool.force_use else None,
structured_response_format=self.answer_style_config.structured_response_format,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
if tool_call_chunk is None:
tool_call_chunk = message
else:
tool_call_chunk += message # type: ignore
else:
if message.content:
if self.is_cancelled:
return
yield cast(str, message.content)
if (
message.additional_kwargs.get("usage_metadata", {}).get("stop")
== "length"
):
yield StreamStopInfo(
stop_reason=StreamStopReason.CONTEXT_LENGTH
)
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
current_llm_call = llm_calls[-1]
if not tool_call_chunk:
return # no tool call needed
# if we have a tool call, we need to call the tool
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
known_tools_by_name = [
tool for tool in self.tools if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
continue
else:
tool = known_tools_by_name[0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
yield from tool_runner.tool_responses()
tool_call_summary = ToolCallSummary(
tool_call_request=tool_call_chunk,
tool_call_result=build_tool_message(
tool_call_request, tool_runner.tool_message_content()
),
)
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name == ImageGenerationTool._NAME:
img_urls = [
img_generation_result["url"]
for img_generation_result in tool_runner.tool_final_result().tool_result
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question, img_urls=img_urls
)
)
yield tool_runner.tool_final_result()
if not self.skip_gen_ai_answer_generation:
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
yield from self._process_llm_stream(
prompt=prompt,
# as of now, we don't support multiple tool calls in sequence, which is why
# we don't need to pass this in here
# tools=[tool.tool_definition() for tool in self.tools],
)
return
# This method processes the LLM stream and yields the content or stop information
def _process_llm_stream(
self,
prompt: Any,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
) -> Iterator[str | StreamStopInfo]:
for message in self.llm.stream(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
structured_response_format=self.answer_style_config.structured_response_format,
# handle the case where no decision has to be made; we simply run the tool
if (
current_llm_call.force_use_tool.force_use
and current_llm_call.force_use_tool.args is not None
):
if isinstance(message, AIMessageChunk):
if message.content:
if self.is_cancelled:
return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield cast(str, message.content)
if (
message.additional_kwargs.get("usage_metadata", {}).get("stop")
== "length"
):
yield StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH)
def _raw_output_for_non_explicit_tool_calling_llms(
self,
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
if self.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool_name, tool_args = (
current_llm_call.force_use_tool.tool_name,
current_llm_call.force_use_tool.args,
)
tool = next(
iter(
[
tool
for tool in self.tools
if tool.name == self.force_use_tool.tool_name
]
),
None,
(t for t in current_llm_call.tools if t.name == tool_name), None
)
if not tool:
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
raise RuntimeError(f"Tool '{tool_name}' not found")
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=self.question,
history=self.message_history,
llm=self.llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
chosen_tool_and_args = (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=self.tools,
query=self.question,
history=self.message_history,
llm=self.llm,
)
available_tools_and_args = [
(self.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=self.message_history,
query=self.question,
llm=self.llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
if self.skip_gen_ai_answer_generation:
raise ValueError(
"skip_gen_ai_answer_generation is True, but no tool was chosen; no answer will be generated"
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
yield from self._process_llm_stream(
prompt=prompt,
tools=None,
)
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
return
tool, tool_args = chosen_tool_and_args
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
final_context_documents = None
for response in tool_runner.tool_responses():
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_documents = cast(list[LlmDoc], response.response)
yield response
if final_context_documents is None:
raise RuntimeError(
f"{tool.name} did not return final context documents"
# special pre-logic for non-tool calling LLM case
if not self.using_tool_calling_llm and current_llm_call.tools:
chosen_tool_and_args = (
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
current_llm_call, self.llm
)
self._update_prompt_builder_for_search_tool(
prompt_builder, final_context_documents
)
elif tool.name == ImageGenerationTool._NAME:
img_urls = []
for response in tool_runner.tool_responses():
if response.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], response.response
)
img_urls = [img.url for img in img_generation_response]
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
return
yield response
# if we're skipping gen ai answer generation, we should break
# out unless we're forcing a tool call. If we don't, we might generate an
# answer, which is a no-no!
if (
self.skip_gen_ai_answer_generation
and not current_llm_call.force_use_tool.force_use
):
return
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
img_urls=img_urls,
)
# set up "handlers" to listen to the LLM response stream and
# feed back the processed results + handle tool call requests
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
search_result = SearchTool.get_search_result(current_llm_call) or []
answer_handler: AnswerResponseHandler
if self.answer_style_config.citation_config:
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
)
elif self.answer_style_config.quotes_config:
answer_handler = QuotesResponseHandler(
context_docs=search_result,
)
else:
prompt_builder.update_user_prompt(
HumanMessage(
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
self.question,
tool.name,
*tool_runner.tool_responses(),
)
)
)
final = tool_runner.tool_final_result()
raise ValueError("No answer style config provided")
yield final
if not self.skip_gen_ai_answer_generation:
prompt = prompt_builder.build()
response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled
)
yield from self._process_llm_stream(prompt=prompt, tools=None)
# DEBUG: good breakpoint
stream = self.llm.stream(
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(
"required"
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
else None
),
structured_response_format=self.answer_style_config.structured_response_format,
)
yield from response_handler_manager.handle_llm_response(stream)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -498,94 +254,30 @@ class Answer:
yield from self._processed_stream
return
output_generator = (
self._raw_output_for_explicit_tool_calling_llms()
if explicit_tool_calling_supported(
self.llm.config.model_provider, self.llm.config.model_name
)
and not self.skip_explicit_tool_calling
else self._raw_output_for_non_explicit_tool_calling_llms()
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=self.question,
prompt_config=self.prompt_config,
files=self.latest_query_files,
),
message_history=self.message_history,
llm_config=self.llm.config,
single_message_history=self.single_message_history,
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
llm_call = LLMCall(
prompt_builder=prompt_builder,
tools=self._get_tools_list(),
force_use_tool=self.force_use_tool,
files=self.latest_query_files,
tool_call_info=[],
using_tool_calling_llm=self.using_tool_calling_llm,
)
def _process_stream(
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
) -> AnswerStream:
message = None
# special things we need to keep track of for the SearchTool
# raw results that will be displayed to the user
search_results: list[LlmDoc] | None = None
# processed docs to feed into the LLM
final_context_docs: list[LlmDoc] | None = None
for message in stream:
if isinstance(message, ToolCallKickoff) or isinstance(
message, ToolCallFinalResult
):
yield message
elif isinstance(message, ToolResponse):
if message.id == SEARCH_RESPONSE_SUMMARY_ID:
# We don't need to run section merging in this flow, this variable is only used
# below to specify the ordering of the documents for the purpose of matching
# citations to the right search documents. The deduplication logic is more lightweight
# there and we don't need to do it twice
search_results = [
llm_doc_from_inference_section(section)
for section in cast(
SearchResponseSummary, message.response
).top_sections
]
elif message.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_docs = cast(list[LlmDoc], message.response)
yield message
elif (
message.id == SEARCH_DOC_CONTENT_ID
and not self._return_contexts
):
continue
yield message
else:
# assumes all tool responses will come first, then the final answer
break
if not self.skip_gen_ai_answer_generation:
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
doc_id_to_rank_map=map_document_id_order(
search_results or final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
stream_stop_info = None
def _stream() -> Iterator[str]:
nonlocal stream_stop_info
for item in itertools.chain([message], stream):
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return
# this should never happen, but we're seeing weird behavior here so handling for now
if not isinstance(item, str):
logger.error(
f"Received non-string item in answer stream: {item}. Skipping."
)
continue
yield item
yield from process_answer_stream_fn(_stream())
if stream_stop_info:
yield stream_stop_info
processed_stream = []
for processed_packet in _process_stream(output_generator):
for processed_packet in self._get_response([llm_call]):
processed_stream.append(processed_packet)
yield processed_packet
@@ -609,7 +301,6 @@ class Answer:
return citations
@property
def is_cancelled(self) -> bool:
if self._is_cancelled:
return True
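Condensing the scattered hunks above: the two raw-output generators are gone, and `processed_streamed_output` now builds a single `AnswerPromptBuilder`, wraps it in an `LLMCall`, and hands it to the recursive `_get_response`. A rough sketch of the resulting flow, using only names that appear in this diff (setup, caching, and error handling elided; not verbatim):

```python
# Sketch of the refactored entry point in Answer (condensed from the hunks above).
def processed_streamed_output(self) -> AnswerStream:
    prompt_builder = AnswerPromptBuilder(
        user_message=default_build_user_message(
            user_query=self.question,
            prompt_config=self.prompt_config,
            files=self.latest_query_files,
        ),
        message_history=self.message_history,
        llm_config=self.llm.config,
        single_message_history=self.single_message_history,
    )
    prompt_builder.update_system_prompt(default_build_system_message(self.prompt_config))

    initial_call = LLMCall(
        prompt_builder=prompt_builder,
        tools=self._get_tools_list(),
        force_use_tool=self.force_use_tool,
        files=self.latest_query_files,
        tool_call_info=[],
        using_tool_calling_llm=self.using_tool_calling_llm,
    )

    # _get_response runs one LLM round trip through the response-handler manager and,
    # if a tool call produced a follow-up LLMCall, recurses with llm_calls + [new_call].
    yield from self._get_response([initial_call])
```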

View File

@@ -0,0 +1,84 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import TYPE_CHECKING
from langchain_core.messages import BaseMessage
from pydantic.v1 import BaseModel as BaseModel__v1
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
if TYPE_CHECKING:
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
)
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
ResponsePart = (
DanswerAnswerPiece
| CitationInfo
| DanswerQuotes
| ToolCallKickoff
| ToolResponse
| ToolCallFinalResult
| StreamStopInfo
)
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]
force_use_tool: ForceUseTool
files: list[InMemoryChatFile]
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
using_tool_calling_llm: bool
class Config:
arbitrary_types_allowed = True
class LLMResponseHandlerManager:
def __init__(
self,
tool_handler: "ToolResponseHandler",
answer_handler: "AnswerResponseHandler",
is_cancelled: Callable[[], bool],
):
self.tool_handler = tool_handler
self.answer_handler = answer_handler
self.is_cancelled = is_cancelled
def handle_llm_response(
self,
stream: Iterator[BaseMessage],
) -> Generator[ResponsePart, None, None]:
all_messages: list[BaseMessage] = []
for message in stream:
if self.is_cancelled():
yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
return
# tool handler doesn't do anything until the full message is received
# NOTE: still need to run list() to get this to run
list(self.tool_handler.handle_response_part(message, all_messages))
yield from self.answer_handler.handle_response_part(message, all_messages)
all_messages.append(message)
# potentially give back all info on the selected tool call + its result
yield from self.tool_handler.handle_response_part(None, all_messages)
yield from self.answer_handler.handle_response_part(None, all_messages)
def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
return self.tool_handler.next_llm_call(llm_call)
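For reference, this is roughly how the new manager is wired up inside `Answer._get_response` per the hunks earlier in this diff; `llm_stream`, `search_result`, and `handle` below are stand-ins for the real values in that method:

```python
# Rough wiring of the new response-handler pipeline (a sketch mirroring the hunks above).
tool_handler = ToolResponseHandler(current_llm_call.tools)
answer_handler = CitationResponseHandler(
    context_docs=search_result,
    doc_id_to_rank_map=map_document_id_order(search_result),
)
manager = LLMResponseHandlerManager(tool_handler, answer_handler, is_cancelled)

# Every raw BaseMessage chunk from the LLM is fanned out to both handlers;
# ResponsePart objects (answer pieces, citations, tool packets) come back out.
for part in manager.handle_llm_response(llm_stream):
    handle(part)  # hypothetical downstream consumer

# If the tool handler decided a tool must run, this returns the follow-up LLMCall.
next_call = manager.next_llm_call(current_llm_call)
```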

View File

@@ -12,12 +12,12 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
from danswer.tools.message import ToolCallSummary
def default_build_system_message(
@@ -54,18 +54,14 @@ def default_build_user_message(
class AnswerPromptBuilder:
def __init__(
self, message_history: list[PreviousMessage], llm_config: LLMConfig
self,
user_message: HumanMessage,
message_history: list[PreviousMessage],
llm_config: LLMConfig,
single_message_history: str | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
(
self.message_history,
self.history_token_cnts,
) = translate_history_to_basemessages(message_history)
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt: tuple[HumanMessage, int] | None = None
llm_tokenizer = get_tokenizer(
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
@@ -74,6 +70,24 @@ class AnswerPromptBuilder:
Callable[[str], list[int]], llm_tokenizer.encode
)
self.raw_message_history = message_history
(
self.message_history,
self.history_token_cnts,
) = translate_history_to_basemessages(message_history)
# for cases like the QA flow where we want to condense the chat history
# into a single message rather than a sequence of User / Assistant messages
self.single_message_history = single_message_history
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
)
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:
self.system_message_and_token_cnt = None
@@ -85,18 +99,21 @@ class AnswerPromptBuilder:
)
def update_user_prompt(self, user_message: HumanMessage) -> None:
if not user_message:
self.user_message_and_token_cnt = None
return
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
)
def build(
self, tool_call_summary: ToolCallSummary | None = None
) -> list[BaseMessage]:
def append_message(self, message: BaseMessage) -> None:
"""Append a new message to the message history."""
token_count = check_message_tokens(message, self.llm_tokenizer_encode_func)
self.new_messages_and_token_cnts.append((message, token_count))
def get_user_message_content(self) -> str:
query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
return query
def build(self) -> list[BaseMessage]:
if not self.user_message_and_token_cnt:
raise ValueError("User message must be set before building prompt")
@@ -113,25 +130,8 @@ class AnswerPromptBuilder:
final_messages_with_tokens.append(self.user_message_and_token_cnt)
if tool_call_summary:
final_messages_with_tokens.append(
(
tool_call_summary.tool_call_request,
check_message_tokens(
tool_call_summary.tool_call_request,
self.llm_tokenizer_encode_func,
),
)
)
final_messages_with_tokens.append(
(
tool_call_summary.tool_call_result,
check_message_tokens(
tool_call_summary.tool_call_result,
self.llm_tokenizer_encode_func,
),
)
)
if self.new_messages_and_token_cnts:
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens
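Net effect on the builder API: the user message is supplied at construction time, tool-call request/result messages are added through the new `append_message`, and `build()` no longer takes a `ToolCallSummary`. A minimal usage sketch based on the signatures in this hunk (the two tool-call messages are hypothetical placeholders):

```python
# Usage sketch of the reworked AnswerPromptBuilder (signatures taken from this hunk).
builder = AnswerPromptBuilder(
    user_message=default_build_user_message(
        user_query=question, prompt_config=prompt_config, files=files
    ),
    message_history=message_history,
    llm_config=llm.config,
    single_message_history=None,
)
builder.update_system_prompt(default_build_system_message(prompt_config))

# Tool-call messages are now appended one by one instead of being passed
# to build() as a ToolCallSummary.
builder.append_message(tool_call_request_message)  # hypothetical AIMessage with tool_calls
builder.append_message(tool_call_result_message)   # hypothetical ToolMessage with the result

prompt = builder.build()  # still drops history overflow against the token budget
```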

View File

@@ -6,7 +6,6 @@ from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MA
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
from danswer.db.search_settings import get_multilingual_expansion
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
@@ -14,6 +13,7 @@ from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT
@@ -132,10 +132,9 @@ def build_citations_system_message(
def build_citations_user_message(
question: str,
message: HumanMessage,
prompt_config: PromptConfig,
context_docs: list[LlmDoc] | list[InferenceChunk],
files: list[InMemoryChatFile],
all_doc_useful: bool,
history_message: str = "",
) -> HumanMessage:
@@ -149,6 +148,7 @@ def build_citations_user_message(
if history_message
else ""
)
query, img_urls = message_to_prompt_and_imgs(message)
if context_docs:
context_docs_str = build_complete_context_str(context_docs)
@@ -158,20 +158,22 @@ def build_citations_user_message(
optional_ignore_statement=optional_ignore,
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
user_query=question,
user_query=query,
history_block=history_block,
)
else:
# if no context docs provided, assume we're in the tool calling flow
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
task_prompt=task_prompt_with_reminder,
user_query=question,
user_query=query,
history_block=history_block,
)
user_prompt = user_prompt.strip()
user_msg = HumanMessage(
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
content=build_content_with_imgs(user_prompt, img_urls=img_urls)
if img_urls
else user_prompt
)
return user_msg
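Call sites now hand over the already-built `HumanMessage` (from which the query text and image URLs are recovered via `message_to_prompt_and_imgs`) instead of a raw question string plus files. A hedged sketch of what a call now looks like; the argument values are illustrative, not taken from a real call site in this diff:

```python
# Sketch of a call site against the new signature shown above.
user_msg = build_citations_user_message(
    message=HumanMessage(content=question),   # previously: question=..., files=...
    prompt_config=prompt_config,
    context_docs=final_context_documents,
    all_doc_useful=False,
    history_message="",
)
```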

View File

@@ -5,6 +5,7 @@ from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
@@ -75,7 +76,7 @@ def _build_strong_llm_quotes_prompt(
def build_quotes_user_message(
question: str,
message: HumanMessage,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
@@ -86,28 +87,10 @@ def build_quotes_user_message(
else _build_strong_llm_quotes_prompt
)
query, _ = message_to_prompt_and_imgs(message)
return prompt_builder(
question=question,
context_docs=context_docs,
history_str=history_str,
prompt=prompt,
)
def build_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
) -> HumanMessage:
prompt_builder = (
_build_weak_llm_quotes_prompt
if QA_PROMPT_OVERRIDE == "weak"
else _build_strong_llm_quotes_prompt
)
return prompt_builder(
question=question,
question=query,
context_docs=context_docs,
history_str=history_str,
prompt=prompt,

View File

@@ -19,7 +19,7 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.tools.search.search_utils import section_to_dict
from danswer.tools.tool_implementations.search.search_utils import section_to_dict
from danswer.utils.logger import setup_logger

View File

@@ -0,0 +1,91 @@
import abc
from collections.abc import Generator
from langchain_core.messages import BaseMessage
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.answering.stream_processing.citation_processing import (
CitationProcessor,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
QuotesProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
class AnswerResponseHandler(abc.ABC):
@abc.abstractmethod
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
raise NotImplementedError
class DummyAnswerResponseHandler(AnswerResponseHandler):
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
# This is a dummy handler that returns nothing
yield from []
class CitationResponseHandler(AnswerResponseHandler):
def __init__(
self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.citation_processor = CitationProcessor(
context_docs=self.context_docs,
doc_id_to_rank_map=self.doc_id_to_rank_map,
)
self.processed_text = ""
self.citations: list[CitationInfo] = []
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
return
content = (
response_item.content if isinstance(response_item.content, str) else ""
)
# Process the new content through the citation processor
yield from self.citation_processor.process_token(content)
class QuotesResponseHandler(AnswerResponseHandler):
def __init__(
self,
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
):
self.quotes_processor = QuotesProcessor(
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
yield from self.quotes_processor.process_token(None)
return
content = (
response_item.content if isinstance(response_item.content, str) else ""
)
yield from self.quotes_processor.process_token(content)
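All three handlers expose the same push-style interface used by `LLMResponseHandlerManager`: one `handle_response_part` call per streamed message, plus a final call with `None` to signal end of stream (which the quotes handler uses to flush its buffered output). A small driving sketch; the chunk contents and `context_docs` are illustrative:

```python
# Sketch: driving a CitationResponseHandler by hand.
from langchain_core.messages import AIMessageChunk

handler = CitationResponseHandler(
    context_docs=context_docs,                              # list[LlmDoc] from the search tool
    doc_id_to_rank_map=map_document_id_order(context_docs),
)

seen: list[AIMessageChunk] = []
for chunk in (AIMessageChunk(content=t) for t in ["Danswer is ", "open source ", "[1]."]):
    # yields DanswerAnswerPiece objects (with rewritten citations) and CitationInfo objects
    for part in handler.handle_response_part(chunk, seen):
        print(part)
    seen.append(chunk)

# End-of-stream marker: the citation handler returns immediately here, while
# QuotesResponseHandler uses this final call to emit the DanswerQuotes result.
for part in handler.handle_response_part(None, seen):
    print(part)
```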

View File

@@ -1,12 +1,10 @@
import re
from collections.abc import Iterator
from collections.abc import Generator
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.llm.answering.models import StreamProcessor
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.utils.logger import setup_logger
@@ -19,128 +17,104 @@ def in_code_block(llm_text: str) -> bool:
return count % 2 != 0
def extract_citations_from_stream(
tokens: Iterator[str],
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
stop_stream: str | None = STOP_STREAM_PAT,
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
"""
Key aspects:
class CitationProcessor:
def __init__(
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
stop_stream: str | None = STOP_STREAM_PAT,
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.stop_stream = stop_stream
self.order_mapping = doc_id_to_rank_map.order_mapping
self.llm_out = ""
self.max_citation_num = len(context_docs)
self.citation_order: list[int] = []
self.curr_segment = ""
self.cited_inds: set[int] = set()
self.hold = ""
self.current_citations: list[int] = []
self.past_cite_count = 0
1. Stream Processing:
- Processes tokens one by one, allowing for real-time handling of large texts.
def process_token(
self, token: str | None
) -> Generator[DanswerAnswerPiece | CitationInfo, None, None]:
# None -> end of stream
if token is None:
yield DanswerAnswerPiece(answer_piece=self.curr_segment)
return
2. Citation Detection:
- Uses regex to find citations in the format [number].
- Example: [1], [2], etc.
3. Citation Mapping:
- Maps detected citation numbers to actual document ranks using doc_id_to_rank_map.
- Example: [1] might become [3] if doc_id_to_rank_map maps it to 3.
4. Citation Formatting:
- Replaces citations with properly formatted versions.
- Adds links if available: [[1]](https://example.com)
- Handles cases where links are not available: [[1]]()
5. Duplicate Handling:
- Skips consecutive citations of the same document to avoid redundancy.
6. Output Generation:
- Yields DanswerAnswerPiece objects for regular text.
- Yields CitationInfo objects for each unique citation encountered.
7. Context Awareness:
- Uses context_docs to access document information for citations.
This function effectively processes a stream of text, identifies and reformats citations,
and provides both the processed text and citation information as output.
"""
order_mapping = doc_id_to_rank_map.order_mapping
llm_out = ""
max_citation_num = len(context_docs)
citation_order = []
curr_segment = ""
cited_inds = set()
hold = ""
raw_out = ""
current_citations: list[int] = []
past_cite_count = 0
for raw_token in tokens:
raw_out += raw_token
if stop_stream:
next_hold = hold + raw_token
if stop_stream in next_hold:
break
if next_hold == stop_stream[: len(next_hold)]:
hold = next_hold
continue
if self.stop_stream:
next_hold = self.hold + token
if self.stop_stream in next_hold:
return
if next_hold == self.stop_stream[: len(next_hold)]:
self.hold = next_hold
return
token = next_hold
hold = ""
else:
token = raw_token
self.hold = ""
curr_segment += token
llm_out += token
self.curr_segment += token
self.llm_out += token
# Handle code blocks without language tags
if "`" in curr_segment:
if curr_segment.endswith("`"):
continue
elif "```" in curr_segment:
piece_that_comes_after = curr_segment.split("```")[1][0]
if piece_that_comes_after == "\n" and in_code_block(llm_out):
curr_segment = curr_segment.replace("```", "```plaintext")
if "`" in self.curr_segment:
if self.curr_segment.endswith("`"):
return
elif "```" in self.curr_segment:
piece_that_comes_after = self.curr_segment.split("```")[1][0]
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
citation_pattern = r"\[(\d+)\]"
citations_found = list(re.finditer(citation_pattern, curr_segment))
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
possible_citation_found = re.search(
possible_citation_pattern, self.curr_segment
)
# `past_cite_count`: number of characters since past citation
# 5 to ensure a citation hasn't occurred
if len(citations_found) == 0 and len(llm_out) - past_cite_count > 5:
current_citations = []
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
self.current_citations = []
if citations_found and not in_code_block(llm_out):
result = "" # Initialize result here
if citations_found and not in_code_block(self.llm_out):
last_citation_end = 0
length_to_add = 0
while len(citations_found) > 0:
citation = citations_found.pop(0)
numerical_value = int(citation.group(1))
if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[numerical_value - 1]
real_citation_num = order_mapping[context_llm_doc.document_id]
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
real_citation_num = self.order_mapping[context_llm_doc.document_id]
if real_citation_num not in citation_order:
citation_order.append(real_citation_num)
if real_citation_num not in self.citation_order:
self.citation_order.append(real_citation_num)
target_citation_num = citation_order.index(real_citation_num) + 1
target_citation_num = (
self.citation_order.index(real_citation_num) + 1
)
# Skip consecutive citations of the same work
if target_citation_num in current_citations:
if target_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
curr_segment = (
curr_segment[: length_to_add + start]
+ curr_segment[real_start + diff :]
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
# Handle edge case where LLM outputs citation itself
# by allowing it to generate citations on its own.
if curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", curr_segment)
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = context_docs[doc_id - 1]
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
@@ -150,75 +124,57 @@ def extract_citations_from_stream(
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
# Will continue attempting on subsequent loops
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue
link = context_llm_doc.link
# Replace the citation in the current segment
start, end = citation.span()
curr_segment = (
curr_segment[: start + length_to_add]
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[{target_citation_num}]"
+ curr_segment[end + length_to_add :]
+ self.curr_segment[end + length_to_add :]
)
past_cite_count = len(llm_out)
current_citations.append(target_citation_num)
self.past_cite_count = len(self.llm_out)
self.current_citations.append(target_citation_num)
if target_citation_num not in cited_inds:
cited_inds.add(target_citation_num)
if target_citation_num not in self.cited_inds:
self.cited_inds.add(target_citation_num)
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
if link:
prev_length = len(curr_segment)
curr_segment = (
curr_segment[: start + length_to_add]
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]({link})"
+ curr_segment[end + length_to_add :]
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(curr_segment) - prev_length
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(curr_segment)
curr_segment = (
curr_segment[: start + length_to_add]
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]()"
+ curr_segment[end + length_to_add :]
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(curr_segment) - prev_length
length_to_add += len(self.curr_segment) - prev_length
last_citation_end = end + length_to_add
if last_citation_end > 0:
yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end])
curr_segment = curr_segment[last_citation_end:]
if possible_citation_found:
continue
yield DanswerAnswerPiece(answer_piece=curr_segment)
curr_segment = ""
result += self.curr_segment[:last_citation_end]
self.curr_segment = self.curr_segment[last_citation_end:]
if curr_segment:
yield DanswerAnswerPiece(answer_piece=curr_segment)
if not possible_citation_found:
result += self.curr_segment
self.curr_segment = ""
def build_citation_processor(
context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
) -> StreamProcessor:
def stream_processor(
tokens: Iterator[str],
) -> AnswerQuestionStreamReturn:
yield from extract_citations_from_stream(
tokens=tokens,
context_docs=context_docs,
doc_id_to_rank_map=doc_id_to_rank_map,
)
return stream_processor
if result:
yield DanswerAnswerPiece(answer_piece=result)
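The processor keeps the old token-by-token semantics of `extract_citations_from_stream`, but as a stateful class: feed each token through `process_token` and finish with `None`. A usage sketch; `docs`, `llm_token_stream`, and `emit` are stand-ins for the caller's values:

```python
# Usage sketch for the new CitationProcessor (replaces extract_citations_from_stream).
processor = CitationProcessor(
    context_docs=docs,                               # list[LlmDoc], e.g. from the SearchTool
    doc_id_to_rank_map=map_document_id_order(docs),
)

for token in llm_token_stream:                       # plain string tokens from the LLM
    for piece in processor.process_token(token):
        emit(piece)  # DanswerAnswerPiece with rewritten citations, or CitationInfo

# None marks end of stream and flushes any text still held in the buffer.
for piece in processor.process_token(None):
    emit(piece)
```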

View File

@@ -1,14 +1,11 @@
import math
import re
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Optional
import regex
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import DanswerAnswer
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuote
@@ -157,7 +154,7 @@ def separate_answer_quotes(
return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))
def process_answer(
def _process_answer(
answer_raw: str,
docs: list[LlmDoc],
is_json_prompt: bool = True,
@@ -195,7 +192,7 @@ def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
def _extract_quotes_from_completed_token_stream(
model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True
) -> DanswerQuotes:
answer, quotes = process_answer(model_output, context_docs, is_json_prompt)
answer, quotes = _process_answer(model_output, context_docs, is_json_prompt)
if answer:
logger.notice(answer)
elif model_output:
@@ -204,94 +201,101 @@ def _extract_quotes_from_completed_token_stream(
return quotes
def process_model_tokens(
tokens: Iterator[str],
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
"""Used in the streaming case to process the model output
into an Answer and Quotes
class QuotesProcessor:
def __init__(
self,
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
):
self.context_docs = context_docs
self.is_json_prompt = is_json_prompt
Yields Answer tokens back out in a dict for streaming to frontend
When Answer section ends, yields dict with answer_finished key
Collects all the tokens at the end to form the complete model output"""
quote_pat = f"\n{QUOTE_PAT}"
# Sometimes a weaker model outputs a newline instead of ':'
quote_loose = f"\n{quote_pat[:-1]}\n"
# Sometimes the model outputs two newlines before the quote section
quote_pat_full = f"\n{quote_pat}"
model_output: str = ""
found_answer_start = False if is_json_prompt else True
found_answer_end = False
hold_quote = ""
self.found_answer_start = False if is_json_prompt else True
self.found_answer_end = False
self.hold_quote = ""
self.model_output = ""
self.hold = ""
for token in tokens:
model_previous = model_output
model_output += token
def process_token(
self, token: str | None
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
# None -> end of stream
if token is None:
if self.model_output:
yield _extract_quotes_from_completed_token_stream(
model_output=self.model_output,
context_docs=self.context_docs,
is_json_prompt=self.is_json_prompt,
)
return
if not found_answer_start:
m = answer_pattern.search(model_output)
model_previous = self.model_output
self.model_output += token
if not self.found_answer_start:
m = answer_pattern.search(self.model_output)
if m:
found_answer_start = True
self.found_answer_start = True
# Prevent heavy cases of hallucinations where model is never providing a JSON
# We want to quickly update the user - not stream forever
if is_json_prompt and len(model_output) > 70:
# Prevent heavy cases of hallucinations
if self.is_json_prompt and len(self.model_output) > 70:
logger.warning("LLM did not produce json as prompted")
found_answer_end = True
continue
self.found_answer_end = True
return
remaining = model_output[m.end() :]
remaining = self.model_output[m.end() :]
# Look for an unescaped quote, which means the answer is entirely contained
# in this token e.g. if the token is `{"answer": "blah", "qu`
quote_indices = [i for i, char in enumerate(remaining) if char == '"']
for quote_idx in quote_indices:
# Check if quote is escaped by counting backslashes before it
num_backslashes = 0
pos = quote_idx - 1
while pos >= 0 and remaining[pos] == "\\":
num_backslashes += 1
pos -= 1
# If even number of backslashes, quote is not escaped
if num_backslashes % 2 == 0:
yield DanswerAnswerPiece(answer_piece=remaining[:quote_idx])
return
# If no unescaped quote found, yield the remaining string
if len(remaining) > 0:
yield DanswerAnswerPiece(answer_piece=remaining)
continue
return
if found_answer_start and not found_answer_end:
if is_json_prompt and _stream_json_answer_end(model_previous, token):
found_answer_end = True
if self.found_answer_start and not self.found_answer_end:
if self.is_json_prompt and _stream_json_answer_end(model_previous, token):
self.found_answer_end = True
# return the remaining part of the answer e.g. token might be 'd.", ' and we should yield 'd.'
if token:
try:
answer_token_section = token.index('"')
yield DanswerAnswerPiece(
answer_piece=hold_quote + token[:answer_token_section]
answer_piece=self.hold_quote + token[:answer_token_section]
)
except ValueError:
logger.error("Quotation mark not found in token")
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
yield DanswerAnswerPiece(answer_piece=self.hold_quote + token)
yield DanswerAnswerPiece(answer_piece=None)
continue
elif not is_json_prompt:
if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
found_answer_end = True
return
elif not self.is_json_prompt:
quote_pat = f"\n{QUOTE_PAT}"
quote_loose = f"\n{quote_pat[:-1]}\n"
quote_pat_full = f"\n{quote_pat}"
if (
quote_pat in self.hold_quote + token
or quote_loose in self.hold_quote + token
):
self.found_answer_end = True
yield DanswerAnswerPiece(answer_piece=None)
continue
if hold_quote + token in quote_pat_full:
hold_quote += token
continue
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
hold_quote = ""
return
if self.hold_quote + token in quote_pat_full:
self.hold_quote += token
return
logger.debug(f"Raw Model QnA Output: {model_output}")
yield _extract_quotes_from_completed_token_stream(
model_output=model_output,
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
def build_quotes_processor(
context_docs: list[LlmDoc], is_json_prompt: bool
) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]:
def stream_processor(
tokens: Iterator[str],
) -> AnswerQuestionStreamReturn:
yield from process_model_tokens(
tokens=tokens,
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
return stream_processor
yield DanswerAnswerPiece(answer_piece=self.hold_quote + token)
self.hold_quote = ""
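For orientation, here is a minimal usage sketch (not part of the diff) of the new QuotesProcessor: tokens are fed in one at a time, and a trailing None marks the end of the stream so the final quotes can be extracted. The `tokens` iterator and `context_docs` list are assumed to be supplied by the surrounding answer pipeline.
# Illustrative sketch only -- `tokens` (iterator of LLM output strings) and
# `context_docs` (list[LlmDoc]) are assumed to come from the caller.
from danswer.chat.models import DanswerAnswerPiece, DanswerQuotes
from danswer.llm.answering.stream_processing.quotes_processing import QuotesProcessor

processor = QuotesProcessor(context_docs=context_docs)
for token in tokens:
    for piece in processor.process_token(token):
        if isinstance(piece, DanswerAnswerPiece):
            ...  # stream the answer fragment to the frontend
# None signals end of stream; quotes are parsed from the full model output
for piece in processor.process_token(None):
    if isinstance(piece, DanswerQuotes):
        ...  # surface the extracted quotes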

View File

@@ -0,0 +1,207 @@
from collections.abc import Generator
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.interfaces import LLM
from danswer.tools.force import ForceUseTool
from danswer.tools.message import build_tool_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_runner import (
check_which_tools_should_run_for_non_tool_calling_llm,
)
from danswer.tools.tool_runner import ToolRunner
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from danswer.utils.logger import setup_logger
logger = setup_logger()
class ToolResponseHandler:
def __init__(self, tools: list[Tool]):
self.tools = tools
self.tool_call_chunk: AIMessageChunk | None = None
self.tool_call_requests: list[ToolCall] = []
self.tool_runner: ToolRunner | None = None
self.tool_call_summary: ToolCallSummary | None = None
self.tool_kickoff: ToolCallKickoff | None = None
self.tool_responses: list[ToolResponse] = []
self.tool_final_result: ToolCallFinalResult | None = None
@classmethod
def get_tool_call_for_non_tool_calling_llm(
cls, llm_call: LLMCall, llm: LLM
) -> tuple[Tool, dict] | None:
if llm_call.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
(
t
for t in llm_call.tools
if t.name == llm_call.force_use_tool.tool_name
),
None,
)
if not tool:
raise RuntimeError(
f"Tool '{llm_call.force_use_tool.tool_name}' not found"
)
tool_args = (
llm_call.force_use_tool.args
if llm_call.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=llm_call.prompt_builder.get_user_message_content(),
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
return (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=llm_call.tools,
query=llm_call.prompt_builder.get_user_message_content(),
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
)
available_tools_and_args = [
(llm_call.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=llm_call.prompt_builder.raw_message_history,
query=llm_call.prompt_builder.get_user_message_content(),
llm=llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
return chosen_tool_and_args
def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
return
self.tool_call_requests = self.tool_call_chunk.tool_calls
selected_tool: Tool | None = None
selected_tool_call_request: ToolCall | None = None
for tool_call_request in self.tool_call_requests:
known_tools_by_name = [
tool for tool in self.tools if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}\n"
f"tool_call_request: {tool_call_request}"
)
continue
else:
selected_tool = known_tools_by_name[0]
selected_tool_call_request = tool_call_request
if selected_tool and selected_tool_call_request:
break
if not selected_tool or not selected_tool_call_request:
return
logger.info(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
self.tool_runner = ToolRunner(selected_tool, selected_tool_call_request["args"])
self.tool_kickoff = self.tool_runner.kickoff()
yield self.tool_kickoff
for response in self.tool_runner.tool_responses():
self.tool_responses.append(response)
yield response
self.tool_final_result = self.tool_runner.tool_final_result()
yield self.tool_final_result
self.tool_call_summary = ToolCallSummary(
tool_call_request=self.tool_call_chunk,
tool_call_result=build_tool_message(
selected_tool_call_request, self.tool_runner.tool_message_content()
),
)
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
yield from self._handle_tool_call()
if isinstance(response_item, AIMessageChunk) and (
response_item.tool_call_chunks or response_item.tool_calls
):
if self.tool_call_chunk is None:
self.tool_call_chunk = response_item
else:
self.tool_call_chunk += response_item # type: ignore
return
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
if (
self.tool_runner is None
or self.tool_call_summary is None
or self.tool_kickoff is None
or self.tool_final_result is None
):
return None
tool_runner = self.tool_runner
new_prompt_builder = tool_runner.tool.build_next_prompt(
prompt_builder=current_llm_call.prompt_builder,
tool_call_summary=self.tool_call_summary,
tool_responses=self.tool_responses,
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
)
return LLMCall(
prompt_builder=new_prompt_builder,
tools=[], # for now, only allow one tool call per response
force_use_tool=ForceUseTool(
force_use=False,
tool_name="",
args=None,
),
files=current_llm_call.files,
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
tool_call_info=[
self.tool_kickoff,
*self.tool_responses,
self.tool_final_result,
],
)
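To make the intended flow concrete, a hedged sketch (not part of the diff) of how a streaming loop might drive ToolResponseHandler. The names `llm`, `llm_call`, and `tools` are assumed to be provided by the surrounding Answer object; the real call sites live elsewhere in this PR.
# Illustrative sketch only -- assumes `llm`, `llm_call`, and `tools` exist.
handler = ToolResponseHandler(tools)

pieces: list[ResponsePart] = []
previous_messages: list[BaseMessage] = []
for message in llm.stream(
    prompt=llm_call.prompt_builder.build(),
    tools=[tool.tool_definition() for tool in tools] or None,
):
    pieces.extend(handler.handle_response_part(message, previous_messages))
    previous_messages.append(message)

# a trailing None lets the handler execute any accumulated tool call
pieces.extend(handler.handle_response_part(None, previous_messages))

next_call = handler.next_llm_call(llm_call)  # None if no tool was invoked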

View File

@@ -203,6 +203,28 @@ def build_content_with_imgs(
)
def message_to_prompt_and_imgs(message: BaseMessage) -> tuple[str, list[str]]:
if isinstance(message.content, str):
return message.content, []
imgs = []
texts = []
for part in message.content:
if isinstance(part, dict):
if part.get("type") == "image_url":
img_url = part.get("image_url", {}).get("url")
if img_url:
imgs.append(img_url)
elif part.get("type") == "text":
text = part.get("text")
if text:
texts.append(text)
else:
texts.append(part)
return "".join(texts), imgs
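As a quick illustration (not part of the diff), the new helper splits a multimodal HumanMessage into its prompt text and image URLs; the message content and URL below are made up for the example.
# Illustrative example only -- the message content and URL are hypothetical.
from langchain_core.messages import HumanMessage

msg = HumanMessage(
    content=[
        {"type": "text", "text": "Describe this image"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]
)
text, img_urls = message_to_prompt_and_imgs(msg)
# text == "Describe this image", img_urls == ["https://example.com/cat.png"]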
def dict_based_prompt_to_langchain_prompt(
messages: list[dict[str, str]]
) -> list[BaseMessage]:

View File

@@ -52,12 +52,16 @@ from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephr
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.utils import get_json_line
from danswer.tools.force import ForceUseTool
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import ToolResponse
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@@ -202,30 +206,33 @@ def stream_answer_objects(
max_tokens=max_document_tokens,
)
answer_config = AnswerStyleConfig(
citation_config=CitationConfig() if use_citations else None,
quotes_config=QuotesConfig() if not use_citations else None,
document_pruning_config=document_pruning_config,
)
search_tool = SearchTool(
db_session=db_session,
user=user,
evaluation_type=LLMEvaluationType.SKIP
if DISABLE_LLM_DOC_RELEVANCE
else query_req.evaluation_type,
evaluation_type=(
LLMEvaluationType.SKIP
if DISABLE_LLM_DOC_RELEVANCE
else query_req.evaluation_type
),
persona=persona,
retrieval_options=query_req.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
answer_style_config=answer_config,
bypass_acl=bypass_acl,
chunks_above=query_req.chunks_above,
chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc,
)
answer_config = AnswerStyleConfig(
citation_config=CitationConfig() if use_citations else None,
quotes_config=QuotesConfig() if not use_citations else None,
document_pruning_config=document_pruning_config,
)
answer = Answer(
question=query_msg.message,
answer_style_config=answer_config,

View File

@@ -9,7 +9,7 @@ from danswer.db.models import StarterMessage
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.prompt.models import PromptSnapshot
from danswer.server.features.tool.api import ToolSnapshot
from danswer.server.features.tool.models import ToolSnapshot
from danswer.server.models import MinimalUserSnapshot
from danswer.utils.logger import setup_logger

View File

@@ -18,10 +18,16 @@ from danswer.db.tools import update_tool
from danswer.server.features.tool.models import CustomToolCreate
from danswer.server.features.tool.models import CustomToolUpdate
from danswer.server.features.tool.models import ToolSnapshot
from danswer.tools.custom.openapi_parsing import MethodSpec
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec
from danswer.tools.tool_implementations.custom.openapi_parsing import (
openapi_to_method_specs,
)
from danswer.tools.tool_implementations.custom.openapi_parsing import (
validate_openapi_schema,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.utils import is_image_generation_available
router = APIRouter(prefix="/tool")

View File

@@ -1,5 +1,6 @@
import asyncio
import io
import json
import uuid
from collections.abc import Callable
from collections.abc import Generator
@@ -283,13 +284,14 @@ def delete_chat_session_by_id(
raise HTTPException(status_code=400, detail=str(e))
async def is_disconnected(request: Request) -> Callable[[], bool]:
async def is_connected(request: Request) -> Callable[[], bool]:
main_loop = asyncio.get_event_loop()
def is_disconnected_sync() -> bool:
def is_connected_sync() -> bool:
future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
try:
return not future.result(timeout=0.01)
is_connected = not future.result(timeout=0.01)
return is_connected
except asyncio.TimeoutError:
logger.error("Asyncio timed out")
return True
@@ -300,7 +302,7 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
)
return True
return is_disconnected_sync
return is_connected_sync
@router.post("/send-message")
@@ -309,7 +311,7 @@ def handle_new_chat_message(
request: Request,
user: User | None = Depends(current_user),
_: None = Depends(check_token_rate_limits),
is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),
is_connected_func: Callable[[], bool] = Depends(is_connected),
) -> StreamingResponse:
"""
This endpoint is used for all of the following purposes:
@@ -325,7 +327,7 @@ def handle_new_chat_message(
request (Request): The current HTTP request context.
user (User | None): The current user, obtained via dependency injection.
_ (None): Rate limit check is run if user/group/global rate limits are enabled.
is_disconnected_func (Callable[[], bool]): Function to check client disconnection,
is_connected_func (Callable[[], bool]): Function to check client disconnection,
used to stop the streaming response if the client disconnects.
Returns:
@@ -340,8 +342,6 @@ def handle_new_chat_message(
):
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
import json
def stream_generator() -> Generator[str, None, None]:
try:
for packet in stream_chat_message(
@@ -354,7 +354,7 @@ def handle_new_chat_message(
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
is_connected=is_disconnected_func,
is_connected=is_connected_func,
):
yield json.dumps(packet) if isinstance(packet, dict) else packet
@@ -362,6 +362,9 @@ def handle_new_chat_message(
logger.exception(f"Error in chat message streaming: {e}")
yield json.dumps({"error": str(e)})
finally:
logger.debug("Stream generator finished")
return StreamingResponse(stream_generator(), media_type="text/event-stream")
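For clarity, a minimal sketch (not part of the diff) of how the injected connectivity check is meant to be consumed downstream; the exact logic inside stream_chat_message is not shown in this commit, so the loop below is an assumption. In the unit tests later in this commit, cancellation surfaces as StreamStopInfo(stop_reason=StreamStopReason.CANCELLED).
# Illustrative sketch only -- stream_chat_message's real check may differ.
from collections.abc import Callable, Generator, Iterable

def _stream_until_disconnected(
    packets: Iterable[str], is_connected: Callable[[], bool]
) -> Generator[str, None, None]:
    for packet in packets:
        if not is_connected():  # client went away -> stop generating
            break
        yield packet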

View File

@@ -0,0 +1,59 @@
from typing import cast
from typing import TYPE_CHECKING
from langchain_core.messages import HumanMessage
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.tools.tool import Tool
if TYPE_CHECKING:
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.tool_implementations.custom.custom_tool import (
CustomToolCallSummary,
)
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse
def build_user_message_for_non_tool_calling_llm(
message: HumanMessage,
tool_name: str,
*args: "ToolResponse",
) -> str:
query, _ = message_to_prompt_and_imgs(message)
tool_run_summary = cast("CustomToolCallSummary", args[0].response).tool_result
return f"""
Here's the result from the {tool_name} tool:
{tool_run_summary}
Now respond to the following:
{query}
""".strip()
class BaseTool(Tool):
def build_next_prompt(
self,
prompt_builder: "AnswerPromptBuilder",
tool_call_summary: "ToolCallSummary",
tool_responses: list["ToolResponse"],
using_tool_calling_llm: bool,
) -> "AnswerPromptBuilder":
if using_tool_calling_llm:
prompt_builder.append_message(tool_call_summary.tool_call_request)
prompt_builder.append_message(tool_call_summary.tool_call_result)
else:
prompt_builder.update_user_prompt(
HumanMessage(
content=build_user_message_for_non_tool_calling_llm(
prompt_builder.user_message_and_token_cnt[0],
self.name,
*tool_responses,
)
)
)
return prompt_builder

View File

@@ -9,9 +9,13 @@ from sqlalchemy.orm import Session
from danswer.db.models import Persona
from danswer.db.models import Tool as ToolDBModel
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool import Tool
from danswer.utils.logger import setup_logger

View File

@@ -1,21 +0,0 @@
from typing import cast
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.models import ToolResponse
def build_user_message_for_custom_tool_for_non_tool_calling_llm(
query: str,
tool_name: str,
*args: ToolResponse,
) -> str:
tool_run_summary = cast(CustomToolCallSummary, args[0].response).tool_result
return f"""
Here's the result from the {tool_name} tool:
{tool_run_summary}
Now respond to the following:
{query}
""".strip()

View File

@@ -1,11 +1,17 @@
import abc
from collections.abc import Generator
from typing import Any
from typing import TYPE_CHECKING
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.models import ToolResponse
if TYPE_CHECKING:
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse
class Tool(abc.ABC):
@@ -32,7 +38,7 @@ class Tool(abc.ABC):
@abc.abstractmethod
def build_tool_message_content(
self, *args: ToolResponse
self, *args: "ToolResponse"
) -> str | list[str | dict[str, Any]]:
raise NotImplementedError
@@ -51,13 +57,26 @@ class Tool(abc.ABC):
"""Actual execution of the tool"""
@abc.abstractmethod
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]:
raise NotImplementedError
@abc.abstractmethod
def final_result(self, *args: ToolResponse) -> JSON_ro:
def final_result(self, *args: "ToolResponse") -> JSON_ro:
"""
This is the "final summary" result of the tool.
It is the result that will be stored in the database.
"""
raise NotImplementedError
"""Some tools may want to modify the prompt based on the tool call summary and tool responses.
Default behavior is to continue with just the raw tool call request/result passed to the LLM."""
@abc.abstractmethod
def build_next_prompt(
self,
prompt_builder: "AnswerPromptBuilder",
tool_call_summary: "ToolCallSummary",
tool_responses: list["ToolResponse"],
using_tool_calling_llm: bool,
) -> "AnswerPromptBuilder":
raise NotImplementedError

View File

@@ -11,24 +11,34 @@ from pydantic import BaseModel
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.custom.base_tool_types import ToolResultType
from danswer.tools.custom.custom_tool_prompts import (
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
)
from danswer.tools.custom.custom_tool_prompts import SHOULD_USE_CUSTOM_TOOL_USER_PROMPT
from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_SYSTEM_PROMPT
from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_USER_PROMPT
from danswer.tools.custom.custom_tool_prompts import USE_TOOL
from danswer.tools.custom.openapi_parsing import MethodSpec
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
from danswer.tools.custom.openapi_parsing import openapi_to_url
from danswer.tools.custom.openapi_parsing import REQUEST_BODY
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
from danswer.tools.base_tool import BaseTool
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import MESSAGE_ID_PLACEHOLDER
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
SHOULD_USE_CUSTOM_TOOL_USER_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
TOOL_ARG_SYSTEM_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
TOOL_ARG_USER_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import USE_TOOL
from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec
from danswer.tools.tool_implementations.custom.openapi_parsing import (
openapi_to_method_specs,
)
from danswer.tools.tool_implementations.custom.openapi_parsing import openapi_to_url
from danswer.tools.tool_implementations.custom.openapi_parsing import REQUEST_BODY
from danswer.tools.tool_implementations.custom.openapi_parsing import (
validate_openapi_schema,
)
from danswer.utils.headers import header_list_to_header_dict
from danswer.utils.headers import HeaderItemDict
from danswer.utils.logger import setup_logger
@@ -43,7 +53,7 @@ class CustomToolCallSummary(BaseModel):
tool_result: ToolResultType
class CustomTool(Tool):
class CustomTool(BaseTool):
def __init__(
self,
method_spec: MethodSpec,

View File

@@ -11,12 +11,17 @@ from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_implementations.images.prompt import (
build_image_generation_user_prompt,
)
from danswer.utils.headers import build_llm_extra_headers
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -112,7 +117,10 @@ class ImageGenerationTool(Tool):
},
"shape": {
"type": "string",
"description": "Optional. Image shape: 'square', 'portrait', or 'landscape'",
"description": (
"Optional - only specify if you want a specific shape."
" Image shape: 'square', 'portrait', or 'landscape'."
),
"enum": [shape.value for shape in ImageShape],
},
},
@@ -258,3 +266,34 @@ class ImageGenerationTool(Tool):
image_generation_response.model_dump()
for image_generation_response in image_generation_responses
]
def build_next_prompt(
self,
prompt_builder: AnswerPromptBuilder,
tool_call_summary: ToolCallSummary,
tool_responses: list[ToolResponse],
using_tool_calling_llm: bool,
) -> AnswerPromptBuilder:
img_generation_response = cast(
list[ImageGenerationResponse] | None,
next(
(
response.response
for response in tool_responses
if response.id == IMAGE_GENERATION_RESPONSE_ID
),
None,
),
)
if img_generation_response is None:
raise ValueError("No image generation response found")
img_urls = [img.url for img in img_generation_response]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=prompt_builder.get_user_message_content(),
img_urls=img_urls,
)
)
return prompt_builder

View File

@@ -11,18 +11,31 @@ from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
from danswer.llm.utils import message_to_string
from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.search.models import SearchDoc
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.tools.internet_search.models import InternetSearchResponse
from danswer.tools.internet_search.models import InternetSearchResult
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_implementations.internet_search.models import (
InternetSearchResponse,
)
from danswer.tools.tool_implementations.internet_search.models import (
InternetSearchResult,
)
from danswer.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,
)
from danswer.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -97,8 +110,17 @@ class InternetSearchTool(Tool):
_DISPLAY_NAME = "[Beta] Internet Search Tool"
_DESCRIPTION = "Perform an internet search for up-to-date information."
def __init__(self, api_key: str, num_results: int = 10) -> None:
def __init__(
self,
api_key: str,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
num_results: int = 10,
) -> None:
self.api_key = api_key
self.answer_style_config = answer_style_config
self.prompt_config = prompt_config
self.host = "https://api.bing.microsoft.com/v7.0"
self.headers = {
"Ocp-Apim-Subscription-Key": api_key,
@@ -231,3 +253,19 @@ class InternetSearchTool(Tool):
def final_result(self, *args: ToolResponse) -> JSON_ro:
search_response = cast(InternetSearchResponse, args[0].response)
return search_response.model_dump()
def build_next_prompt(
self,
prompt_builder: AnswerPromptBuilder,
tool_call_summary: ToolCallSummary,
tool_responses: list[ToolResponse],
using_tool_calling_llm: bool,
) -> AnswerPromptBuilder:
return build_next_prompt_for_search_like_tool(
prompt_builder=prompt_builder,
tool_call_summary=tool_call_summary,
tool_responses=tool_responses,
using_tool_calling_llm=using_tool_calling_llm,
answer_style_config=self.answer_style_config,
prompt_config=self.prompt_config,
)

View File

@@ -17,10 +17,13 @@ from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import ContextualPruningConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
from danswer.llm.answering.prune_and_merge import prune_and_merge_sections
from danswer.llm.answering.prune_and_merge import prune_sections
@@ -35,9 +38,16 @@ from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.secondary_llm_flows.choose_search import check_if_need_search
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.tools.search.search_utils import llm_doc_to_dict
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_implementations.search.search_utils import llm_doc_to_dict
from danswer.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,
)
from danswer.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -45,7 +55,6 @@ logger = setup_logger()
SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
SEARCH_EVALUATION_ID = "llm_doc_eval"
@@ -85,6 +94,7 @@ class SearchTool(Tool):
llm: LLM,
fast_llm: LLM,
pruning_config: DocumentPruningConfig,
answer_style_config: AnswerStyleConfig,
evaluation_type: LLMEvaluationType,
# if specified, will not actually run a search and will instead return these
# sections. Used when the user selects specific docs to talk to
@@ -136,6 +146,7 @@ class SearchTool(Tool):
num_chunk_multiple = self.chunks_above + self.chunks_below + 1
self.answer_style_config = answer_style_config
self.contextual_pruning_config = (
ContextualPruningConfig.from_doc_pruning_config(
num_chunk_multiple=num_chunk_multiple, doc_pruning_config=pruning_config
@@ -353,4 +364,36 @@ class SearchTool(Tool):
# NOTE: need to do this json.loads(doc.model_dump_json()) stuff because there are some
# subfields that are not serializable by default (datetime)
# this forces pydantic to make them JSON serializable for us
return [json.loads(doc.json()) for doc in final_docs]
return [json.loads(doc.model_dump_json()) for doc in final_docs]
def build_next_prompt(
self,
prompt_builder: AnswerPromptBuilder,
tool_call_summary: ToolCallSummary,
tool_responses: list[ToolResponse],
using_tool_calling_llm: bool,
) -> AnswerPromptBuilder:
return build_next_prompt_for_search_like_tool(
prompt_builder=prompt_builder,
tool_call_summary=tool_call_summary,
tool_responses=tool_responses,
using_tool_calling_llm=using_tool_calling_llm,
answer_style_config=self.answer_style_config,
prompt_config=self.prompt_config,
)
"""Other utility functions"""
@classmethod
def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None:
if not llm_call.tool_call_info:
return None
for yield_item in llm_call.tool_call_info:
if (
isinstance(yield_item, ToolResponse)
and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID
):
return cast(list[LlmDoc], yield_item.response)
return None

View File

@@ -0,0 +1,71 @@
from typing import cast
from danswer.chat.models import LlmDoc
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.citations_prompt import (
build_citations_system_message,
)
from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
def build_next_prompt_for_search_like_tool(
prompt_builder: AnswerPromptBuilder,
tool_call_summary: ToolCallSummary,
tool_responses: list[ToolResponse],
using_tool_calling_llm: bool,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
) -> AnswerPromptBuilder:
if not using_tool_calling_llm:
final_context_docs_response = next(
response
for response in tool_responses
if response.id == FINAL_CONTEXT_DOCUMENTS_ID
)
final_context_documents = cast(
list[LlmDoc], final_context_docs_response.response
)
else:
# if using a tool-calling LLM, the docs are already carried in the tool call result message, so no separate context docs are built here
final_context_documents = []
if answer_style_config.citation_config:
prompt_builder.update_system_prompt(
build_citations_system_message(prompt_config)
)
prompt_builder.update_user_prompt(
build_citations_user_message(
message=prompt_builder.user_message_and_token_cnt[0],
prompt_config=prompt_config,
context_docs=final_context_documents,
all_doc_useful=(
answer_style_config.citation_config.all_docs_useful
if answer_style_config.citation_config
else False
),
history_message=prompt_builder.single_message_history or "",
)
)
elif answer_style_config.quotes_config:
prompt_builder.update_user_prompt(
build_quotes_user_message(
message=prompt_builder.user_message_and_token_cnt[0],
context_docs=final_context_documents,
history_str=prompt_builder.single_message_history or "",
prompt=prompt_config,
)
)
if using_tool_calling_llm:
prompt_builder.append_message(tool_call_summary.tool_call_request)
prompt_builder.append_message(tool_call_summary.tool_call_result)
return prompt_builder

View File

@@ -6,8 +6,8 @@ from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

View File

@@ -12,7 +12,7 @@ from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.one_shot_answer.models import PersonaConfig
from danswer.tools.custom.custom_tool import (
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)

View File

@@ -142,6 +142,9 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
assert response.status_code == 200
response_json = response.json()
# make sure there is an answer
assert response_json["answer"]
# since we only gave it one search doc, all responses should only contain that doc
assert response_json["final_context_doc_indices"] == [0]
assert response_json["llm_selected_doc_indices"] == [0]

View File

@@ -0,0 +1,113 @@
import json
from datetime import datetime
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import SystemMessage
from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLMConfig
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
QUERY = "Test question"
DEFAULT_SEARCH_ARGS = {"query": "search"}
@pytest.fixture
def answer_style_config() -> AnswerStyleConfig:
return AnswerStyleConfig(citation_config=CitationConfig())
@pytest.fixture
def prompt_config() -> PromptConfig:
return PromptConfig(
system_prompt="System prompt",
task_prompt="Task prompt",
datetime_aware=False,
include_citations=True,
)
@pytest.fixture
def mock_llm() -> MagicMock:
mock_llm_obj = MagicMock()
mock_llm_obj.config = LLMConfig(
model_provider="openai",
model_name="gpt-4o",
temperature=0.0,
api_key=None,
api_base=None,
api_version=None,
)
return mock_llm_obj
@pytest.fixture
def mock_search_results() -> list[LlmDoc]:
return [
LlmDoc(
content="Search result 1",
source_type=DocumentSource.WEB,
metadata={"id": "doc1"},
document_id="doc1",
blurb="Blurb 1",
semantic_identifier="Semantic ID 1",
updated_at=datetime(2023, 1, 1),
link="https://example.com/doc1",
source_links={0: "https://example.com/doc1"},
),
LlmDoc(
content="Search result 2",
source_type=DocumentSource.WEB,
metadata={"id": "doc2"},
document_id="doc2",
blurb="Blurb 2",
semantic_identifier="Semantic ID 2",
updated_at=datetime(2023, 1, 2),
link="https://example.com/doc2",
source_links={0: "https://example.com/doc2"},
),
]
@pytest.fixture
def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
mock_tool = MagicMock(spec=SearchTool)
mock_tool.name = "search"
mock_tool.build_tool_message_content.return_value = "search_response"
mock_tool.get_args_for_non_tool_calling_llm.return_value = DEFAULT_SEARCH_ARGS
mock_tool.final_result.return_value = [
json.loads(doc.model_dump_json()) for doc in mock_search_results
]
mock_tool.run.return_value = [
ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results)
]
mock_tool.tool_definition.return_value = {
"type": "function",
"function": {
"name": "search",
"description": "Search for information",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "The search query"},
},
"required": ["query"],
},
},
}
mock_post_search_tool_prompt_builder = MagicMock(spec=AnswerPromptBuilder)
mock_post_search_tool_prompt_builder.build.return_value = [
SystemMessage(content="Updated system prompt"),
]
mock_tool.build_next_prompt.return_value = mock_post_search_tool_prompt_builder
return mock_tool

View File

@@ -7,7 +7,7 @@ from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.llm.answering.stream_processing.citation_processing import (
extract_citations_from_stream,
CitationProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
@@ -70,14 +70,16 @@ def process_text(
) -> tuple[str, list[CitationInfo]]:
mock_docs, mock_doc_id_to_rank_map = mock_data
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
result = list(
extract_citations_from_stream(
tokens=iter(tokens),
context_docs=mock_docs,
doc_id_to_rank_map=mapping,
stop_stream=None,
)
processor = CitationProcessor(
context_docs=mock_docs,
doc_id_to_rank_map=mapping,
stop_stream=None,
)
result: list[DanswerAnswerPiece | CitationInfo] = []
for token in tokens:
result.extend(processor.process_token(token))
result.extend(processor.process_token(None))
final_answer_text = ""
citations = []
for piece in result:

View File

@@ -6,7 +6,7 @@ from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.llm.answering.stream_processing.quotes_processing import (
process_model_tokens,
QuotesProcessor,
)
mock_docs = [
@@ -25,179 +25,202 @@ mock_docs = [
]
tokens_with_quotes = [
"{",
"\n ",
'"answer": "Yes',
", Danswer allows",
" customized prompts. This",
" feature",
" is currently being",
" developed and implemente",
"d to",
" improve",
" the accuracy",
" of",
" Language",
" Models (",
"LL",
"Ms) for",
" different",
" companies",
".",
" The custom",
"ized prompts feature",
" woul",
"d allow users to ad",
"d person",
"alized prom",
"pts through",
" an",
" interface or",
" metho",
"d,",
" which would then be used to",
" train",
" the LLM.",
" This enhancement",
" aims to make",
" Danswer more",
" adaptable to",
" different",
" business",
" contexts",
" by",
" tail",
"oring it",
" to the specific language",
" an",
"d terminology",
" used within",
" a",
" company.",
" Additionally",
",",
" Danswer already",
" supports creating",
" custom AI",
" Assistants with",
" different",
" prom",
"pts and backing",
" knowledge",
" sets",
",",
" which",
" is",
" a form",
" of prompt",
" customization. However, it",
"'s important to nLogging Details LiteLLM-Success Call: Noneote that some",
" aspects",
" of prompt",
" customization,",
" such as for",
" Sl",
"ack",
"b",
"ots, may",
" still",
" be in",
" development or have",
' limitations.",',
'\n "quotes": [',
'\n "We',
" woul",
"d like to ad",
"d customized prompts for",
" different",
" companies to improve the accuracy of",
" Language",
" Model",
" (LLM)",
'.",\n "A',
" new",
" feature that",
" allows users to add personalize",
"d prompts.",
" This would involve",
" creating",
" an interface or method for",
" users to input",
" their",
" own",
" prom",
"pts,",
" which would then be used to",
' train the LLM.",',
'\n "Create',
" custom AI Assistants with",
" different prompts and backing knowledge",
' sets.",',
'\n "This',
" PR",
" fixes",
" https",
"://github.com/dan",
"swer-ai/dan",
"swer/issues/1",
"584",
" by",
" setting",
" the system",
" default",
" prompt for",
" sl",
"ackbots const",
"rained by",
" ",
"document sets",
".",
" It",
" probably",
" isn",
"'t ideal",
" -",
" it",
" might",
" be pref",
"erable to be",
" able to select",
" a prompt for",
" the",
" slackbot from",
" the",
" admin",
" panel",
" -",
" but it sol",
"ves the immediate problem",
" of",
" the slack",
" listener",
" cr",
"ashing when",
" configure",
"d this",
' way."\n ]',
"\n}",
"",
]
def _process_tokens(
processor: QuotesProcessor, tokens: list[str]
) -> tuple[str, list[str]]:
"""Process a list of tokens and return the answer and quotes.
Args:
processor: QuotesProcessor instance
tokens: List of tokens to process
Returns:
Tuple of (answer_text, list_of_quotes)
"""
answer = ""
quotes: list[str] = []
# need to add a None to the end to simulate the end of the stream
for token in tokens + [None]:
for output in processor.process_token(token):
if isinstance(output, DanswerAnswerPiece):
if output.answer_piece:
answer += output.answer_piece
elif isinstance(output, DanswerQuotes):
quotes.extend(q.quote for q in output.quotes)
return answer, quotes
def test_process_model_tokens_answer() -> None:
gen = process_model_tokens(tokens=iter(tokens_with_quotes), context_docs=mock_docs)
tokens_with_quotes = [
"{",
"\n ",
'"answer": "Yes',
", Danswer allows",
" customized prompts. This",
" feature",
" is currently being",
" developed and implemente",
"d to",
" improve",
" the accuracy",
" of",
" Language",
" Models (",
"LL",
"Ms) for",
" different",
" companies",
".",
" The custom",
"ized prompts feature",
" woul",
"d allow users to ad",
"d person",
"alized prom",
"pts through",
" an",
" interface or",
" metho",
"d,",
" which would then be used to",
" train",
" the LLM.",
" This enhancement",
" aims to make",
" Danswer more",
" adaptable to",
" different",
" business",
" contexts",
" by",
" tail",
"oring it",
" to the specific language",
" an",
"d terminology",
" used within",
" a",
" company.",
" Additionally",
",",
" Danswer already",
" supports creating",
" custom AI",
" Assistants with",
" different",
" prom",
"pts and backing",
" knowledge",
" sets",
",",
" which",
" is",
" a form",
" of prompt",
" customization. However, it",
"'s important to nLogging Details LiteLLM-Success Call: Noneote that some",
" aspects",
" of prompt",
" customization,",
" such as for",
" Sl",
"ack",
"b",
"ots, may",
" still",
" be in",
" development or have",
' limitations.",',
'\n "quotes": [',
'\n "We',
" woul",
"d like to ad",
"d customized prompts for",
" different",
" companies to improve the accuracy of",
" Language",
" Model",
" (LLM)",
'.",\n "A',
" new",
" feature that",
" allows users to add personalize",
"d prompts.",
" This would involve",
" creating",
" an interface or method for",
" users to input",
" their",
" own",
" prom",
"pts,",
" which would then be used to",
' train the LLM.",',
'\n "Create',
" custom AI Assistants with",
" different prompts and backing knowledge",
' sets.",',
'\n "This',
" PR",
" fixes",
" https",
"://github.com/dan",
"swer-ai/dan",
"swer/issues/1",
"584",
" by",
" setting",
" the system",
" default",
" prompt for",
" sl",
"ackbots const",
"rained by",
" ",
"document sets",
".",
" It",
" probably",
" isn",
"'t ideal",
" -",
" it",
" might",
" be pref",
"erable to be",
" able to select",
" a prompt for",
" the",
" slackbot from",
" the",
" admin",
" panel",
" -",
" but it sol",
"ves the immediate problem",
" of",
" the slack",
" listener",
" cr",
"ashing when",
" configure",
"d this",
' way."\n ]',
"\n}",
"",
]
processor = QuotesProcessor(context_docs=mock_docs)
answer, quotes = _process_tokens(processor, tokens_with_quotes)
s_json = "".join(tokens_with_quotes)
j = json.loads(s_json)
expected_answer = j["answer"]
actual = ""
for o in gen:
if isinstance(o, DanswerAnswerPiece):
if o.answer_piece:
actual += o.answer_piece
assert expected_answer == actual
assert expected_answer == answer
# NOTE: no quotes, since the docs don't match the quotes
assert len(quotes) == 0
def test_simple_json_answer() -> None:
@@ -214,16 +237,11 @@ def test_simple_json_answer() -> None:
"\n",
"```",
]
gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs)
processor = QuotesProcessor(context_docs=mock_docs)
answer, quotes = _process_tokens(processor, tokens)
expected_answer = "This is a simple answer."
actual = "".join(
o.answer_piece
for o in gen
if isinstance(o, DanswerAnswerPiece) and o.answer_piece
)
assert expected_answer == actual
assert "This is a simple answer." == answer
assert len(quotes) == 0
def test_json_answer_with_quotes() -> None:
@@ -242,16 +260,21 @@ def test_json_answer_with_quotes() -> None:
"\n",
"```",
]
gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs)
processor = QuotesProcessor(context_docs=mock_docs)
answer, quotes = _process_tokens(processor, tokens)
expected_answer = "This is a split answer."
actual = "".join(
o.answer_piece
for o in gen
if isinstance(o, DanswerAnswerPiece) and o.answer_piece
)
assert "This is a split answer." == answer
assert len(quotes) == 0
assert expected_answer == actual
def test_json_answer_with_quotes_one_chunk() -> None:
tokens = ['```json\n{"answer": "z",\n"quotes": ["Document"]\n}\n```']
processor = QuotesProcessor(context_docs=mock_docs)
answer, quotes = _process_tokens(processor, tokens)
assert "z" == answer
assert len(quotes) == 1
assert quotes[0] == "Document"
def test_json_answer_split_tokens() -> None:
@@ -271,16 +294,11 @@ def test_json_answer_split_tokens() -> None:
"\n",
"```",
]
gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs)
processor = QuotesProcessor(context_docs=mock_docs)
answer, quotes = _process_tokens(processor, tokens)
expected_answer = "This is a split answer."
actual = "".join(
o.answer_piece
for o in gen
if isinstance(o, DanswerAnswerPiece) and o.answer_piece
)
assert expected_answer == actual
assert "This is a split answer." == answer
assert len(quotes) == 0
def test_lengthy_prefixed_json_with_quotes() -> None:
@@ -298,23 +316,12 @@ def test_lengthy_prefixed_json_with_quotes() -> None:
"\n",
"```",
]
processor = QuotesProcessor(context_docs=mock_docs)
answer, quotes = _process_tokens(processor, tokens)
gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs)
actual_answer = ""
actual_count = 0
for o in gen:
if isinstance(o, DanswerAnswerPiece):
if o.answer_piece:
actual_answer += o.answer_piece
continue
if isinstance(o, DanswerQuotes):
for q in o.quotes:
assert q.quote == "Document"
actual_count += 1
assert "This is a simple answer." == actual_answer
assert 1 == actual_count
assert "This is a simple answer." == answer
assert len(quotes) == 1
assert quotes[0] == "Document"
def test_prefixed_json_with_quotes() -> None:
@@ -331,21 +338,9 @@ def test_prefixed_json_with_quotes() -> None:
"\n",
"```",
]
processor = QuotesProcessor(context_docs=mock_docs)
answer, quotes = _process_tokens(processor, tokens)
gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs)
actual_answer = ""
actual_count = 0
for o in gen:
if isinstance(o, DanswerAnswerPiece):
if o.answer_piece:
actual_answer += o.answer_piece
continue
if isinstance(o, DanswerQuotes):
for q in o.quotes:
assert q.quote == "Document"
actual_count += 1
assert "This is a simple answer." == actual_answer
assert 1 == actual_count
assert "This is a simple answer." == answer
assert len(quotes) == 1
assert quotes[0] == "Document"

View File

@@ -0,0 +1,405 @@
import json
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import Mock
import pytest
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolCall
from langchain_core.messages import ToolCallChunk
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LlmDoc
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.interfaces import LLM
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from tests.unit.danswer.llm.answering.conftest import DEFAULT_SEARCH_ARGS
from tests.unit.danswer.llm.answering.conftest import QUERY
@pytest.fixture
def answer_instance(
mock_llm: LLM, answer_style_config: AnswerStyleConfig, prompt_config: PromptConfig
) -> Answer:
return Answer(
question=QUERY,
answer_style_config=answer_style_config,
llm=mock_llm,
prompt_config=prompt_config,
force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
)
def test_basic_answer(answer_instance: Answer) -> None:
mock_llm = cast(Mock, answer_instance.llm)
mock_llm.stream.return_value = [
AIMessageChunk(content="This is a "),
AIMessageChunk(content="mock answer."),
]
output = list(answer_instance.processed_streamed_output)
assert len(output) == 2
assert isinstance(output[0], DanswerAnswerPiece)
assert isinstance(output[1], DanswerAnswerPiece)
full_answer = "".join(
piece.answer_piece
for piece in output
if isinstance(piece, DanswerAnswerPiece) and piece.answer_piece is not None
)
assert full_answer == "This is a mock answer."
assert answer_instance.llm_answer == "This is a mock answer."
assert answer_instance.citations == []
assert mock_llm.stream.call_count == 1
mock_llm.stream.assert_called_once_with(
prompt=[
SystemMessage(content="System prompt"),
HumanMessage(content="Task prompt\n\nQUERY:\nTest question"),
],
tools=None,
tool_choice=None,
structured_response_format=None,
)
@pytest.mark.parametrize(
"force_use_tool, expected_tool_args",
[
(
ForceUseTool(force_use=False, tool_name="", args=None),
DEFAULT_SEARCH_ARGS,
),
(
ForceUseTool(
force_use=True, tool_name="search", args={"query": "forced search"}
),
{"query": "forced search"},
),
],
)
def test_answer_with_search_call(
answer_instance: Answer,
mock_search_results: list[LlmDoc],
mock_search_tool: MagicMock,
force_use_tool: ForceUseTool,
expected_tool_args: dict,
) -> None:
answer_instance.tools = [mock_search_tool]
answer_instance.force_use_tool = force_use_tool
# Set up the LLM mock to return search results and then an answer
mock_llm = cast(Mock, answer_instance.llm)
stream_side_effect: list[list[BaseMessage]] = []
if not force_use_tool.force_use:
tool_call_chunk = AIMessageChunk(content="")
tool_call_chunk.tool_calls = [
ToolCall(
id="search",
name="search",
args=expected_tool_args,
)
]
tool_call_chunk.tool_call_chunks = [
ToolCallChunk(
id="search",
name="search",
args=json.dumps(expected_tool_args),
index=0,
)
]
stream_side_effect.append([tool_call_chunk])
stream_side_effect.append(
[
AIMessageChunk(content="Based on the search results, "),
AIMessageChunk(content="the answer is abc[1]. "),
AIMessageChunk(content="This is some other stuff."),
],
)
mock_llm.stream.side_effect = stream_side_effect
# Process the output
output = list(answer_instance.processed_streamed_output)
print(output)
# Updated assertions
assert len(output) == 7
assert output[0] == ToolCallKickoff(
tool_name="search", tool_args=expected_tool_args
)
assert output[1] == ToolResponse(
id="final_context_documents",
response=mock_search_results,
)
assert output[2] == ToolCallFinalResult(
tool_name="search",
tool_args=expected_tool_args,
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
)
assert output[3] == DanswerAnswerPiece(answer_piece="Based on the search results, ")
expected_citation = CitationInfo(citation_num=1, document_id="doc1")
assert output[4] == expected_citation
assert output[5] == DanswerAnswerPiece(
answer_piece="the answer is abc[[1]](https://example.com/doc1). "
)
assert output[6] == DanswerAnswerPiece(answer_piece="This is some other stuff.")
expected_answer = (
"Based on the search results, "
"the answer is abc[[1]](https://example.com/doc1). "
"This is some other stuff."
)
full_answer = "".join(
piece.answer_piece
for piece in output
if isinstance(piece, DanswerAnswerPiece) and piece.answer_piece is not None
)
assert full_answer == expected_answer
assert answer_instance.llm_answer == expected_answer
assert len(answer_instance.citations) == 1
assert answer_instance.citations[0] == expected_citation
# Verify LLM calls
if not force_use_tool.force_use:
assert mock_llm.stream.call_count == 2
first_call, second_call = mock_llm.stream.call_args_list
# First call should include the search tool definition
assert len(first_call.kwargs["tools"]) == 1
assert (
first_call.kwargs["tools"][0]
== mock_search_tool.tool_definition.return_value
)
# Second call should not include tools (as we're just generating the final answer)
assert "tools" not in second_call.kwargs or not second_call.kwargs["tools"]
# Second call should use the returned prompt from build_next_prompt
assert (
second_call.kwargs["prompt"]
== mock_search_tool.build_next_prompt.return_value.build.return_value
)
# Verify that tool_definition was called on the mock_search_tool
mock_search_tool.tool_definition.assert_called_once()
else:
assert mock_llm.stream.call_count == 1
call = mock_llm.stream.call_args_list[0]
assert (
call.kwargs["prompt"]
== mock_search_tool.build_next_prompt.return_value.build.return_value
)
def test_answer_with_search_no_tool_calling(
answer_instance: Answer,
mock_search_results: list[LlmDoc],
mock_search_tool: MagicMock,
) -> None:
answer_instance.tools = [mock_search_tool]
# Set up the LLM mock to return an answer
mock_llm = cast(Mock, answer_instance.llm)
mock_llm.stream.return_value = [
AIMessageChunk(content="Based on the search results, "),
AIMessageChunk(content="the answer is abc[1]. "),
AIMessageChunk(content="This is some other stuff."),
]
# Force non-tool calling behavior
answer_instance.using_tool_calling_llm = False
# Process the output
output = list(answer_instance.processed_streamed_output)
# Assertions
assert len(output) == 7
assert output[0] == ToolCallKickoff(
tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
)
assert output[1] == ToolResponse(
id="final_context_documents",
response=mock_search_results,
)
assert output[2] == ToolCallFinalResult(
tool_name="search",
tool_args=DEFAULT_SEARCH_ARGS,
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
)
assert output[3] == DanswerAnswerPiece(answer_piece="Based on the search results, ")
expected_citation = CitationInfo(citation_num=1, document_id="doc1")
assert output[4] == expected_citation
assert output[5] == DanswerAnswerPiece(
answer_piece="the answer is abc[[1]](https://example.com/doc1). "
)
assert output[6] == DanswerAnswerPiece(answer_piece="This is some other stuff.")
expected_answer = (
"Based on the search results, "
"the answer is abc[[1]](https://example.com/doc1). "
"This is some other stuff."
)
assert answer_instance.llm_answer == expected_answer
assert len(answer_instance.citations) == 1
assert answer_instance.citations[0] == expected_citation
# Verify LLM calls
assert mock_llm.stream.call_count == 1
call_args = mock_llm.stream.call_args
# Verify that no tools were passed to the LLM
assert "tools" not in call_args.kwargs or not call_args.kwargs["tools"]
# Verify that the prompt was built correctly
assert (
call_args.kwargs["prompt"]
== mock_search_tool.build_next_prompt.return_value.build.return_value
)
# Verify that get_args_for_non_tool_calling_llm was called on the mock_search_tool
mock_search_tool.get_args_for_non_tool_calling_llm.assert_called_once_with(
f"Task prompt\n\nQUERY:\n{QUERY}", [], answer_instance.llm
)
# Verify that the search tool's run method was called
mock_search_tool.run.assert_called_once()
def test_answer_with_search_call_quotes_enabled(
answer_instance: Answer,
mock_search_results: list[LlmDoc],
mock_search_tool: MagicMock,
) -> None:
answer_instance.tools = [mock_search_tool]
answer_instance.force_use_tool = ForceUseTool(
force_use=False, tool_name="", args=None
)
answer_instance.answer_style_config.citation_config = None
answer_instance.answer_style_config.quotes_config = QuotesConfig()
# Set up the LLM mock to return search results and then an answer
mock_llm = cast(Mock, answer_instance.llm)
tool_call_chunk = AIMessageChunk(content="")
tool_call_chunk.tool_calls = [
ToolCall(
id="search",
name="search",
args=DEFAULT_SEARCH_ARGS,
)
]
tool_call_chunk.tool_call_chunks = [
ToolCallChunk(
id="search",
name="search",
args=json.dumps(DEFAULT_SEARCH_ARGS),
index=0,
)
]
# needs to be short due to the "anti-hallucination" check in QuotesProcessor
answer_content = "z"
quote_content = mock_search_results[0].content
mock_llm.stream.side_effect = [
[tool_call_chunk],
[
AIMessageChunk(
content=(
'{"answer": "'
+ answer_content
+ '", "quotes": ["'
+ quote_content
+ '"]}'
)
),
],
]
# Process the output
output = list(answer_instance.processed_streamed_output)
# Assertions
assert len(output) == 5
assert output[0] == ToolCallKickoff(
tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
)
assert output[1] == ToolResponse(
id="final_context_documents",
response=mock_search_results,
)
assert output[2] == ToolCallFinalResult(
tool_name="search",
tool_args=DEFAULT_SEARCH_ARGS,
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
)
assert output[3] == DanswerAnswerPiece(answer_piece=answer_content)
assert output[4] == DanswerQuotes(
quotes=[
DanswerQuote(
quote=quote_content,
document_id=mock_search_results[0].document_id,
link=mock_search_results[0].link,
source_type=mock_search_results[0].source_type,
semantic_identifier=mock_search_results[0].semantic_identifier,
blurb=mock_search_results[0].blurb,
)
]
)
assert answer_instance.llm_answer == answer_content


def test_is_cancelled(answer_instance: Answer) -> None:
# Set up the LLM mock to return multiple chunks
mock_llm = Mock()
answer_instance.llm = mock_llm
mock_llm.stream.return_value = [
AIMessageChunk(content="This is the "),
AIMessageChunk(content="first part."),
AIMessageChunk(content="This should not be seen."),
]
# Create a mutable object to control is_connected behavior
connection_status = {"connected": True}
answer_instance.is_connected = lambda: connection_status["connected"]
# Process the output
output = []
for i, chunk in enumerate(answer_instance.processed_streamed_output):
output.append(chunk)
# Simulate disconnection after the second chunk
if i == 1:
connection_status["connected"] = False
assert len(output) == 3
assert output[0] == DanswerAnswerPiece(answer_piece="This is the ")
assert output[1] == DanswerAnswerPiece(answer_piece="first part.")
assert output[2] == StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
# Verify that the stream was cancelled
assert answer_instance.is_cancelled() is True
# Verify that the final answer only contains the streamed parts
assert answer_instance.llm_answer == "This is the first part."
# Verify LLM calls
mock_llm.stream.assert_called_once()
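The tests above all consume a mock_search_tool fixture that is defined elsewhere in the test suite and is not part of this diff. Below is a minimal, hypothetical sketch of what such a fixture could look like, wiring up only the attributes these tests exercise (name, run, final_result, build_next_prompt, get_args_for_non_tool_calling_llm); the real fixture likely configures more behavior and lives in the tests' conftest.py.

import json
from unittest.mock import MagicMock

import pytest

from danswer.tools.models import ToolResponse


@pytest.fixture
def mock_search_tool(mock_search_results: list) -> MagicMock:
    # Hypothetical stand-in for the search tool; only the pieces the tests read.
    tool = MagicMock()
    tool.name = "search"
    # run() yields tool responses; the tests only assert on the final context docs.
    tool.run.return_value = [
        ToolResponse(id="final_context_documents", response=mock_search_results)
    ]
    tool.final_result.return_value = [
        json.loads(doc.model_dump_json()) for doc in mock_search_results
    ]
    tool.build_next_prompt.return_value.build.return_value = "built prompt"
    tool.get_args_for_non_tool_calling_llm.return_value = {"query": "test query"}
    return tool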

View File

@@ -6,8 +6,11 @@ import pytest
from pytest_mock import MockerFixture
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PromptConfig
from danswer.one_shot_answer.answer_question import AnswerObjectIterator
from danswer.tools.force import ForceUseTool
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from tests.regression.answer_quality.run_qa import _process_and_write_query_results
@@ -24,39 +27,43 @@ from tests.regression.answer_quality.run_qa import _process_and_write_query_results
},
],
)
def test_skip_gen_ai_answer_generation_flag(config: dict[str, Any]) -> None:
search_tool = Mock()
search_tool.name = "search"
search_tool.run = Mock()
search_tool.run.return_value = [Mock()]
def test_skip_gen_ai_answer_generation_flag(
config: dict[str, Any],
mock_search_tool: SearchTool,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
) -> None:
question = config["question"]
skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]
mock_llm = Mock()
mock_llm.config = Mock()
mock_llm.config.model_name = "gpt-4o-mini"
mock_llm.stream = Mock()
mock_llm.stream.return_value = [Mock()]
answer = Answer(
question=config["question"],
answer_style_config=Mock(),
prompt_config=Mock(),
question=question,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
llm=mock_llm,
single_message_history="history",
tools=[search_tool],
tools=[mock_search_tool],
force_use_tool=(
ForceUseTool(
tool_name=search_tool.name,
args={"query": config["question"]},
tool_name=mock_search_tool.name,
args={"query": question},
force_use=True,
)
),
skip_explicit_tool_calling=True,
return_contexts=True,
skip_gen_ai_answer_generation=config["skip_gen_ai_answer_generation"],
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
)
count = 0
for _ in cast(AnswerObjectIterator, answer.processed_streamed_output):
count += 1
assert count == 2
if not config["skip_gen_ai_answer_generation"]:
assert count == (3 if skip_gen_ai_answer_generation else 4)
if not skip_gen_ai_answer_generation:
mock_llm.stream.assert_called_once()
else:
mock_llm.stream.assert_not_called()
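For context, the new count assertion tracks what the stream is expected to yield with the mocked tool: a tool kickoff, a tool response, and a tool final result (three items) when answer generation is skipped, plus one more item (presumably the answer piece produced from the mocked LLM stream) when it is not. A hypothetical shape for the parametrized config consumed above — the real parametrize block sits earlier in the file and is not shown in this hunk:

CONFIGS = [
    {"question": "What is Danswer?", "skip_gen_ai_answer_generation": True},
    {"question": "What is Danswer?", "skip_gen_ai_answer_generation": False},
]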

View File

@@ -5,14 +5,18 @@ from unittest.mock import patch
import pytest
from danswer.tools.custom.custom_tool import (
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.custom.custom_tool import validate_openapi_schema
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from danswer.tools.tool_implementations.custom.custom_tool import (
validate_openapi_schema,
)
from danswer.utils.headers import HeaderItemDict
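The remaining hunks in this file make one repeated change: the @patch target moves to the new module path, since requests must be patched in the namespace where the refactored custom_tool module looks it up, not at the top-level requests package. A minimal sketch of that pattern, with a hypothetical test name and response payload:

from unittest.mock import MagicMock, patch


@patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request")
def test_patched_request_example(mock_request: MagicMock) -> None:
    mock_request.return_value.json.return_value = {"status": "ok"}
    # Any custom tool invocation inside this test now receives the mocked
    # response instead of making a real HTTP call.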
@@ -78,7 +82,7 @@ class TestCustomTool(unittest.TestCase):
chat_session_id=uuid.uuid4(), message_id=20
)
@patch("danswer.tools.custom.custom_tool.requests.request")
@patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request")
def test_custom_tool_run_get(self, mock_request: unittest.mock.MagicMock) -> None:
"""
Test the GET method of a custom tool.
@@ -106,7 +110,7 @@ class TestCustomTool(unittest.TestCase):
"Tool name in response does not match expected value",
)
@patch("danswer.tools.custom.custom_tool.requests.request")
@patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request")
def test_custom_tool_run_post(self, mock_request: unittest.mock.MagicMock) -> None:
"""
Test the POST method of a custom tool.
@@ -136,7 +140,7 @@ class TestCustomTool(unittest.TestCase):
"Tool name in response does not match expected value",
)
@patch("danswer.tools.custom.custom_tool.requests.request")
@patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request")
def test_custom_tool_with_headers(
self, mock_request: unittest.mock.MagicMock
) -> None:
@@ -164,7 +168,7 @@ class TestCustomTool(unittest.TestCase):
"GET", expected_url, json=None, headers=expected_headers
)
@patch("danswer.tools.custom.custom_tool.requests.request")
@patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request")
def test_custom_tool_with_empty_headers(
self, mock_request: unittest.mock.MagicMock
) -> None: