Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-28 13:53:28 +02:00)
Fix answer with specified doc ids (#2703)
* Fix / refactor: fix circular imports, refactor, move tests around
* Add quote support
* Testing
* More testing
* Fix image generation slowness
* Remove unused exception
* Fix UT
* Fix stop generating
* Minor typo
* Minor logging updates for clarity

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
@@ -10,7 +10,7 @@ from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.custom.base_tool_types import ToolResultType
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType


class LlmDoc(BaseModel):
@@ -18,6 +18,7 @@ from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
@@ -77,31 +78,49 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import (
from danswer.tools.force import ForceUseTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.custom.custom_tool import (
    build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import (
from danswer.tools.tool_implementations.custom.custom_tool import (
    CUSTOM_TOOL_RESPONSE_ID,
)
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from danswer.tools.tool_implementations.images.image_generation_tool import (
    IMAGE_GENERATION_RESPONSE_ID,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
    ImageGenerationResponse,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
    ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
    INTERNET_SEARCH_RESPONSE_ID,
)
from danswer.tools.internet_search.internet_search_tool import (
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
    internet_search_response_to_search_docs,
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
    InternetSearchResponse,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
    InternetSearchTool,
)
from danswer.tools.tool_implementations.search.search_tool import (
    FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search.search_tool import (
    SEARCH_RESPONSE_SUMMARY_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool_implementations.search.search_tool import (
    SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
@@ -260,6 +279,7 @@ ChatPacket = (
    | CustomToolResponse
    | MessageSpecificCitations
    | MessageResponseIDInfo
    | StreamStopInfo
)
ChatPacketStream = Iterator[ChatPacket]
@@ -532,6 +552,13 @@ def stream_chat_message_objects(
    if not persona
    else PromptConfig.from_model(persona.prompts[0])
)
answer_style_config = AnswerStyleConfig(
    citation_config=CitationConfig(
        all_docs_useful=selected_db_search_docs is not None
    ),
    document_pruning_config=document_pruning_config,
    structured_response_format=new_msg_req.structured_response_format,
)

# find out what tools to use
search_tool: SearchTool | None = None
@@ -550,13 +577,16 @@ def stream_chat_message_objects(
    llm=llm,
    fast_llm=fast_llm,
    pruning_config=document_pruning_config,
    answer_style_config=answer_style_config,
    selected_sections=selected_sections,
    chunks_above=new_msg_req.chunks_above,
    chunks_below=new_msg_req.chunks_below,
    full_doc=new_msg_req.full_doc,
    evaluation_type=LLMEvaluationType.BASIC
    if persona.llm_relevance_filter
    else LLMEvaluationType.SKIP,
    evaluation_type=(
        LLMEvaluationType.BASIC
        if persona.llm_relevance_filter
        else LLMEvaluationType.SKIP
    ),
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
@@ -626,7 +656,11 @@ def stream_chat_message_objects(
    "Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
    InternetSearchTool(api_key=bing_api_key)
    InternetSearchTool(
        api_key=bing_api_key,
        answer_style_config=answer_style_config,
        prompt_config=prompt_config,
    )
]

continue
@@ -667,13 +701,7 @@ def stream_chat_message_objects(
    is_connected=is_connected,
    question=final_msg.message,
    latest_query_files=latest_query_files,
    answer_style_config=AnswerStyleConfig(
        citation_config=CitationConfig(
            all_docs_useful=selected_db_search_docs is not None
        ),
        document_pruning_config=document_pruning_config,
        structured_response_format=new_msg_req.structured_response_format,
    ),
    answer_style_config=answer_style_config,
    prompt_config=prompt_config,
    llm=(
        llm
@@ -777,7 +805,8 @@ def stream_chat_message_objects(
    response=custom_tool_response.tool_result,
    tool_name=custom_tool_response.tool_name,
)

elif isinstance(packet, StreamStopInfo):
    pass
else:
    if isinstance(packet, ToolCallFinalResult):
        tool_result = packet
@@ -807,6 +836,7 @@ def stream_chat_message_objects(

# Post-LLM answer processing
try:
    logger.debug("Post-LLM answer processing")
    message_specific_citations: MessageSpecificCitations | None = None
    if reference_db_search_docs:
        message_specific_citations = _translate_citations(
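Both the SearchTool and the Answer object above now receive the same answer_style_config, whose citation config records whether the request pinned specific documents (selected_db_search_docs is not None). The snippet below is a standalone, simplified sketch of that flag decision only; the dataclass is an assumed stand-in, not danswer's actual CitationConfig.

from dataclasses import dataclass


@dataclass
class CitationConfig:  # assumed stand-in for danswer's CitationConfig
    all_docs_useful: bool = False


def build_citation_config(selected_db_search_docs: list[str] | None) -> CitationConfig:
    # when the user pinned specific documents, mark every context doc as useful
    # so the citation prompt treats them all as relevant
    return CitationConfig(all_docs_useful=selected_db_search_docs is not None)


print(build_citation_config(None))         # CitationConfig(all_docs_useful=False)
print(build_citation_config(["doc-123"]))  # CitationConfig(all_docs_useful=True)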
@@ -1,72 +1,44 @@
|
||||
import itertools
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
from danswer.chat.chat_utils import llm_doc_from_inference_section
|
||||
from danswer.chat.models import AnswerQuestionPossibleReturn
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.models import StreamProcessor
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.build import default_build_system_message
|
||||
from danswer.llm.answering.prompts.build import default_build_user_message
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
build_citations_system_message,
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
|
||||
from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
|
||||
from danswer.llm.answering.stream_processing.citation_processing import (
|
||||
build_citation_processor,
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
build_quotes_processor,
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
QuotesResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import ToolChoiceOptions
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.tools.custom.custom_tool_prompt_builder import (
|
||||
build_user_message_for_custom_tool_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.force import filter_tools_for_force_tool_use
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.images.prompt import build_image_generation_user_prompt
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.message import build_tool_message
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import (
|
||||
check_which_tools_should_run_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -74,29 +46,9 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_answer_stream_processor(
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
answer_style_configs: AnswerStyleConfig,
|
||||
) -> StreamProcessor:
|
||||
if answer_style_configs.citation_config:
|
||||
return build_citation_processor(
|
||||
context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map
|
||||
)
|
||||
if answer_style_configs.quotes_config:
|
||||
return build_quotes_processor(
|
||||
context_docs=context_docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak")
|
||||
)
|
||||
|
||||
raise RuntimeError("Not implemented yet")
|
||||
|
||||
|
||||
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class Answer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -136,8 +88,6 @@ class Answer:
|
||||
self.tools = tools or []
|
||||
self.force_use_tool = force_use_tool
|
||||
|
||||
self.skip_explicit_tool_calling = skip_explicit_tool_calling
|
||||
|
||||
self.message_history = message_history or []
|
||||
# used for QA flow where we only want to send a single message
|
||||
self.single_message_history = single_message_history
|
||||
@@ -162,335 +112,141 @@ class Answer:
|
||||
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
self._is_cancelled = False
|
||||
|
||||
def _update_prompt_builder_for_search_tool(
|
||||
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
|
||||
) -> None:
|
||||
if self.answer_style_config.citation_config:
|
||||
prompt_builder.update_system_prompt(
|
||||
build_citations_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
build_citations_user_message(
|
||||
question=self.question,
|
||||
prompt_config=self.prompt_config,
|
||||
context_docs=final_context_documents,
|
||||
files=self.latest_query_files,
|
||||
all_doc_useful=(
|
||||
self.answer_style_config.citation_config.all_docs_useful
|
||||
if self.answer_style_config.citation_config
|
||||
else False
|
||||
),
|
||||
history_message=self.single_message_history or "",
|
||||
)
|
||||
)
|
||||
elif self.answer_style_config.quotes_config:
|
||||
prompt_builder.update_user_prompt(
|
||||
build_quotes_user_message(
|
||||
question=self.question,
|
||||
context_docs=final_context_documents,
|
||||
history_str=self.single_message_history or "",
|
||||
prompt=self.prompt_config,
|
||||
)
|
||||
self.using_tool_calling_llm = (
|
||||
explicit_tool_calling_supported(
|
||||
self.llm.config.model_provider, self.llm.config.model_name
|
||||
)
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
|
||||
def _raw_output_for_explicit_tool_calling_llms(
|
||||
self,
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
def _get_tools_list(self) -> list[Tool]:
|
||||
if not self.force_use_tool.force_use:
|
||||
return self.tools
|
||||
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
|
||||
# / need to generate the args
|
||||
tool_call_chunk = AIMessageChunk(
|
||||
content="",
|
||||
tool = next(
|
||||
(t for t in self.tools if t.name == self.force_use_tool.tool_name), None
|
||||
)
|
||||
if tool is None:
|
||||
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
|
||||
|
||||
logger.info(
|
||||
f"Forcefully using tool='{tool.name}'"
|
||||
+ (
|
||||
f" with args='{self.force_use_tool.args}'"
|
||||
if self.force_use_tool.args is not None
|
||||
else ""
|
||||
)
|
||||
tool_call_chunk.tool_calls = [
|
||||
{
|
||||
"name": self.force_use_tool.tool_name,
|
||||
"args": self.force_use_tool.args,
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
]
|
||||
)
|
||||
return [tool]
|
||||
|
||||
def _handle_specified_tool_call(
|
||||
self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
|
||||
) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
# make a dummy tool handler
|
||||
tool_handler = ToolResponseHandler([tool])
|
||||
|
||||
dummy_tool_call_chunk = AIMessageChunk(content="")
|
||||
dummy_tool_call_chunk.tool_calls = [
|
||||
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
|
||||
]
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(
|
||||
iter([dummy_tool_call_chunk])
|
||||
)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
else:
|
||||
# if tool calling is supported, first try the raw message
|
||||
# to see if we don't need to use any tools
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
final_tool_definitions = [
|
||||
tool.tool_definition()
|
||||
for tool in filter_tools_for_force_tool_use(
|
||||
self.tools, self.force_use_tool
|
||||
)
|
||||
]
|
||||
raise RuntimeError("Tool call handler did not return a new LLM call")
|
||||
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=final_tool_definitions if final_tool_definitions else None,
|
||||
tool_choice="required" if self.force_use_tool.force_use else None,
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
):
|
||||
if tool_call_chunk is None:
|
||||
tool_call_chunk = message
|
||||
else:
|
||||
tool_call_chunk += message # type: ignore
|
||||
else:
|
||||
if message.content:
|
||||
if self.is_cancelled:
|
||||
return
|
||||
yield cast(str, message.content)
|
||||
if (
|
||||
message.additional_kwargs.get("usage_metadata", {}).get("stop")
|
||||
== "length"
|
||||
):
|
||||
yield StreamStopInfo(
|
||||
stop_reason=StreamStopReason.CONTEXT_LENGTH
|
||||
)
|
||||
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
if not tool_call_chunk:
|
||||
return # no tool call needed
|
||||
|
||||
# if we have a tool call, we need to call the tool
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
for tool_call_request in tool_call_requests:
|
||||
known_tools_by_name = [
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
if self.tools:
|
||||
tool = self.tools[0]
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
tool = known_tools_by_name[0]
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.tool_name == tool.name
|
||||
and self.force_use_tool.args
|
||||
else tool_call_request["args"]
|
||||
)
|
||||
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
yield from tool_runner.tool_responses()
|
||||
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call_request, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = [
|
||||
img_generation_result["url"]
|
||||
for img_generation_result in tool_runner.tool_final_result().tool_result
|
||||
]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question, img_urls=img_urls
|
||||
)
|
||||
)
|
||||
yield tool_runner.tool_final_result()
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
# as of now, we don't support multiple tool calls in sequence, which is why
|
||||
# we don't need to pass this in here
|
||||
# tools=[tool.tool_definition() for tool in self.tools],
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# This method processes the LLM stream and yields the content or stop information
|
||||
def _process_llm_stream(
|
||||
self,
|
||||
prompt: Any,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
) -> Iterator[str | StreamStopInfo]:
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
# handle the case where no decision has to be made; we simply run the tool
|
||||
if (
|
||||
current_llm_call.force_use_tool.force_use
|
||||
and current_llm_call.force_use_tool.args is not None
|
||||
):
|
||||
if isinstance(message, AIMessageChunk):
|
||||
if message.content:
|
||||
if self.is_cancelled:
|
||||
return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
yield cast(str, message.content)
|
||||
|
||||
if (
|
||||
message.additional_kwargs.get("usage_metadata", {}).get("stop")
|
||||
== "length"
|
||||
):
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH)
|
||||
|
||||
def _raw_output_for_non_explicit_tool_calling_llms(
|
||||
self,
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
chosen_tool_and_args: tuple[Tool, dict] | None = None
|
||||
|
||||
if self.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool_name, tool_args = (
|
||||
current_llm_call.force_use_tool.tool_name,
|
||||
current_llm_call.force_use_tool.args,
|
||||
)
|
||||
tool = next(
|
||||
iter(
|
||||
[
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == self.force_use_tool.tool_name
|
||||
]
|
||||
),
|
||||
None,
|
||||
(t for t in current_llm_call.tools if t.name == tool_name), None
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
|
||||
raise RuntimeError(f"Tool '{tool_name}' not found")
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
chosen_tool_and_args = (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=self.tools,
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
)
|
||||
|
||||
available_tools_and_args = [
|
||||
(self.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=self.message_history,
|
||||
query=self.question,
|
||||
llm=self.llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
|
||||
if not chosen_tool_and_args:
|
||||
if self.skip_gen_ai_answer_generation:
|
||||
raise ValueError(
|
||||
"skip_gen_ai_answer_generation is True, but no tool was chosen; no answer will be generated"
|
||||
)
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=None,
|
||||
)
|
||||
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
|
||||
return
|
||||
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
final_context_documents = None
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_context_documents = cast(list[LlmDoc], response.response)
|
||||
yield response
|
||||
|
||||
if final_context_documents is None:
|
||||
raise RuntimeError(
|
||||
f"{tool.name} did not return final context documents"
|
||||
# special pre-logic for non-tool calling LLM case
|
||||
if not self.using_tool_calling_llm and current_llm_call.tools:
|
||||
chosen_tool_and_args = (
|
||||
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
|
||||
current_llm_call, self.llm
|
||||
)
|
||||
|
||||
self._update_prompt_builder_for_search_tool(
|
||||
prompt_builder, final_context_documents
|
||||
)
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = []
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], response.response
|
||||
)
|
||||
img_urls = [img.url for img in img_generation_response]
|
||||
if chosen_tool_and_args:
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
|
||||
return
|
||||
|
||||
yield response
|
||||
# if we're skipping gen ai answer generation, we should break
|
||||
# out unless we're forcing a tool call. If we don't, we might generate an
|
||||
# answer, which is a no-no!
|
||||
if (
|
||||
self.skip_gen_ai_answer_generation
|
||||
and not current_llm_call.force_use_tool.force_use
|
||||
):
|
||||
return
|
||||
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
img_urls=img_urls,
|
||||
)
|
||||
# set up "handlers" to listen to the LLM response stream and
|
||||
# feed back the processed results + handle tool call requests
|
||||
# + figure out what the next LLM call should be
|
||||
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
||||
|
||||
search_result = SearchTool.get_search_result(current_llm_call) or []
|
||||
|
||||
answer_handler: AnswerResponseHandler
|
||||
if self.answer_style_config.citation_config:
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
)
|
||||
elif self.answer_style_config.quotes_config:
|
||||
answer_handler = QuotesResponseHandler(
|
||||
context_docs=search_result,
|
||||
)
|
||||
else:
|
||||
prompt_builder.update_user_prompt(
|
||||
HumanMessage(
|
||||
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
self.question,
|
||||
tool.name,
|
||||
*tool_runner.tool_responses(),
|
||||
)
|
||||
)
|
||||
)
|
||||
final = tool_runner.tool_final_result()
|
||||
raise ValueError("No answer style config provided")
|
||||
|
||||
yield final
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
prompt = prompt_builder.build()
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_call_handler, answer_handler, self.is_cancelled
|
||||
)
|
||||
|
||||
yield from self._process_llm_stream(prompt=prompt, tools=None)
|
||||
# DEBUG: good breakpoint
|
||||
stream = self.llm.stream(
|
||||
prompt=current_llm_call.prompt_builder.build(),
|
||||
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
|
||||
tool_choice=(
|
||||
"required"
|
||||
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
|
||||
else None
|
||||
),
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(stream)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
@@ -498,94 +254,30 @@ class Answer:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
|
||||
output_generator = (
|
||||
self._raw_output_for_explicit_tool_calling_llms()
|
||||
if explicit_tool_calling_supported(
|
||||
self.llm.config.model_provider, self.llm.config.model_name
|
||||
)
|
||||
and not self.skip_explicit_tool_calling
|
||||
else self._raw_output_for_non_explicit_tool_calling_llms()
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
user_query=self.question,
|
||||
prompt_config=self.prompt_config,
|
||||
files=self.latest_query_files,
|
||||
),
|
||||
message_history=self.message_history,
|
||||
llm_config=self.llm.config,
|
||||
single_message_history=self.single_message_history,
|
||||
)
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
llm_call = LLMCall(
|
||||
prompt_builder=prompt_builder,
|
||||
tools=self._get_tools_list(),
|
||||
force_use_tool=self.force_use_tool,
|
||||
files=self.latest_query_files,
|
||||
tool_call_info=[],
|
||||
using_tool_calling_llm=self.using_tool_calling_llm,
|
||||
)
|
||||
|
||||
def _process_stream(
|
||||
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
|
||||
) -> AnswerStream:
|
||||
message = None
|
||||
|
||||
# special things we need to keep track of for the SearchTool
|
||||
# raw results that will be displayed to the user
|
||||
search_results: list[LlmDoc] | None = None
|
||||
# processed docs to feed into the LLM
|
||||
final_context_docs: list[LlmDoc] | None = None
|
||||
|
||||
for message in stream:
|
||||
if isinstance(message, ToolCallKickoff) or isinstance(
|
||||
message, ToolCallFinalResult
|
||||
):
|
||||
yield message
|
||||
elif isinstance(message, ToolResponse):
|
||||
if message.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
# We don't need to run section merging in this flow, this variable is only used
|
||||
# below to specify the ordering of the documents for the purpose of matching
|
||||
# citations to the right search documents. The deduplication logic is more lightweight
|
||||
# there and we don't need to do it twice
|
||||
search_results = [
|
||||
llm_doc_from_inference_section(section)
|
||||
for section in cast(
|
||||
SearchResponseSummary, message.response
|
||||
).top_sections
|
||||
]
|
||||
elif message.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_context_docs = cast(list[LlmDoc], message.response)
|
||||
yield message
|
||||
|
||||
elif (
|
||||
message.id == SEARCH_DOC_CONTENT_ID
|
||||
and not self._return_contexts
|
||||
):
|
||||
continue
|
||||
|
||||
yield message
|
||||
else:
|
||||
# assumes all tool responses will come first, then the final answer
|
||||
break
|
||||
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
doc_id_to_rank_map=map_document_id_order(
|
||||
search_results or final_context_docs or []
|
||||
),
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
|
||||
stream_stop_info = None
|
||||
|
||||
def _stream() -> Iterator[str]:
|
||||
nonlocal stream_stop_info
|
||||
for item in itertools.chain([message], stream):
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
|
||||
# this should never happen, but we're seeing weird behavior here so handling for now
|
||||
if not isinstance(item, str):
|
||||
logger.error(
|
||||
f"Received non-string item in answer stream: {item}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
yield item
|
||||
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
|
||||
if stream_stop_info:
|
||||
yield stream_stop_info
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in _process_stream(output_generator):
|
||||
for processed_packet in self._get_response([llm_call]):
|
||||
processed_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
|
||||
@@ -609,7 +301,6 @@ class Answer:
|
||||
|
||||
return citations
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
if self._is_cancelled:
|
||||
return True
|
||||
|
backend/danswer/llm/answering/llm_response_handler.py (new file, 84 lines)
@@ -0,0 +1,84 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import TYPE_CHECKING

from langchain_core.messages import BaseMessage
from pydantic.v1 import BaseModel as BaseModel__v1

from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool


if TYPE_CHECKING:
    from danswer.llm.answering.stream_processing.answer_response_handler import (
        AnswerResponseHandler,
    )
    from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler


ResponsePart = (
    DanswerAnswerPiece
    | CitationInfo
    | DanswerQuotes
    | ToolCallKickoff
    | ToolResponse
    | ToolCallFinalResult
    | StreamStopInfo
)


class LLMCall(BaseModel__v1):
    prompt_builder: AnswerPromptBuilder
    tools: list[Tool]
    force_use_tool: ForceUseTool
    files: list[InMemoryChatFile]
    tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
    using_tool_calling_llm: bool

    class Config:
        arbitrary_types_allowed = True


class LLMResponseHandlerManager:
    def __init__(
        self,
        tool_handler: "ToolResponseHandler",
        answer_handler: "AnswerResponseHandler",
        is_cancelled: Callable[[], bool],
    ):
        self.tool_handler = tool_handler
        self.answer_handler = answer_handler
        self.is_cancelled = is_cancelled

    def handle_llm_response(
        self,
        stream: Iterator[BaseMessage],
    ) -> Generator[ResponsePart, None, None]:
        all_messages: list[BaseMessage] = []
        for message in stream:
            if self.is_cancelled():
                yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
                return
            # tool handler doesn't do anything until the full message is received
            # NOTE: still need to run list() to get this to run
            list(self.tool_handler.handle_response_part(message, all_messages))
            yield from self.answer_handler.handle_response_part(message, all_messages)
            all_messages.append(message)

        # potentially give back all info on the selected tool call + its result
        yield from self.tool_handler.handle_response_part(None, all_messages)
        yield from self.answer_handler.handle_response_part(None, all_messages)

    def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
        return self.tool_handler.next_llm_call(llm_call)
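A minimal, self-contained sketch of the dispatch pattern that LLMResponseHandlerManager.handle_llm_response implements: every raw message is fanned out to a tool handler (which only acts once the end-of-stream None sentinel arrives) and an answer handler (which yields processed output incrementally). EchoAnswerHandler and BufferingToolHandler below are illustrative stand-ins, not the real danswer handlers, and plain strings stand in for langchain BaseMessage chunks.

from collections.abc import Generator, Iterator


class EchoAnswerHandler:
    """Stand-in for an AnswerResponseHandler: emits each chunk as it arrives."""

    def handle_response_part(
        self, item: str | None, previous: list[str]
    ) -> Generator[str, None, None]:
        if item is not None:
            yield f"answer-piece: {item}"


class BufferingToolHandler:
    """Stand-in for ToolResponseHandler: acts only once the full message is known."""

    def __init__(self) -> None:
        self.buffer: list[str] = []

    def handle_response_part(
        self, item: str | None, previous: list[str]
    ) -> Generator[str, None, None]:
        if item is None:  # end of stream -> report what was accumulated
            yield f"tool-call-summary: {''.join(self.buffer)}"
        else:
            self.buffer.append(item)


def handle_llm_response(stream: Iterator[str]) -> Generator[str, None, None]:
    # mirrors LLMResponseHandlerManager.handle_llm_response, minus cancellation
    tool_handler, answer_handler = BufferingToolHandler(), EchoAnswerHandler()
    all_messages: list[str] = []
    for message in stream:
        list(tool_handler.handle_response_part(message, all_messages))
        yield from answer_handler.handle_response_part(message, all_messages)
        all_messages.append(message)
    yield from tool_handler.handle_response_part(None, all_messages)
    yield from answer_handler.handle_response_part(None, all_messages)


if __name__ == "__main__":
    for out in handle_llm_response(iter(["Hello", " ", "world"])):
        print(out)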
@@ -12,12 +12,12 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import check_message_tokens
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.llm.utils import translate_history_to_basemessages
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
|
||||
|
||||
def default_build_system_message(
|
||||
@@ -54,18 +54,14 @@ def default_build_user_message(
|
||||
|
||||
class AnswerPromptBuilder:
|
||||
def __init__(
|
||||
self, message_history: list[PreviousMessage], llm_config: LLMConfig
|
||||
self,
|
||||
user_message: HumanMessage,
|
||||
message_history: list[PreviousMessage],
|
||||
llm_config: LLMConfig,
|
||||
single_message_history: str | None = None,
|
||||
) -> None:
|
||||
self.max_tokens = compute_max_llm_input_tokens(llm_config)
|
||||
|
||||
(
|
||||
self.message_history,
|
||||
self.history_token_cnts,
|
||||
) = translate_history_to_basemessages(message_history)
|
||||
|
||||
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
|
||||
self.user_message_and_token_cnt: tuple[HumanMessage, int] | None = None
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
provider_type=llm_config.model_provider,
|
||||
model_name=llm_config.model_name,
|
||||
@@ -74,6 +70,24 @@ class AnswerPromptBuilder:
|
||||
Callable[[str], list[int]], llm_tokenizer.encode
|
||||
)
|
||||
|
||||
self.raw_message_history = message_history
|
||||
(
|
||||
self.message_history,
|
||||
self.history_token_cnts,
|
||||
) = translate_history_to_basemessages(message_history)
|
||||
|
||||
# for cases where like the QA flow where we want to condense the chat history
|
||||
# into a single message rather than a sequence of User / Assistant messages
|
||||
self.single_message_history = single_message_history
|
||||
|
||||
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
|
||||
self.user_message_and_token_cnt = (
|
||||
user_message,
|
||||
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
|
||||
)
|
||||
|
||||
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
|
||||
|
||||
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
|
||||
if not system_message:
|
||||
self.system_message_and_token_cnt = None
|
||||
@@ -85,18 +99,21 @@ class AnswerPromptBuilder:
|
||||
)
|
||||
|
||||
def update_user_prompt(self, user_message: HumanMessage) -> None:
|
||||
if not user_message:
|
||||
self.user_message_and_token_cnt = None
|
||||
return
|
||||
|
||||
self.user_message_and_token_cnt = (
|
||||
user_message,
|
||||
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
|
||||
)
|
||||
|
||||
def build(
|
||||
self, tool_call_summary: ToolCallSummary | None = None
|
||||
) -> list[BaseMessage]:
|
||||
def append_message(self, message: BaseMessage) -> None:
|
||||
"""Append a new message to the message history."""
|
||||
token_count = check_message_tokens(message, self.llm_tokenizer_encode_func)
|
||||
self.new_messages_and_token_cnts.append((message, token_count))
|
||||
|
||||
def get_user_message_content(self) -> str:
|
||||
query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
|
||||
return query
|
||||
|
||||
def build(self) -> list[BaseMessage]:
|
||||
if not self.user_message_and_token_cnt:
|
||||
raise ValueError("User message must be set before building prompt")
|
||||
|
||||
@@ -113,25 +130,8 @@ class AnswerPromptBuilder:
|
||||
|
||||
final_messages_with_tokens.append(self.user_message_and_token_cnt)
|
||||
|
||||
if tool_call_summary:
|
||||
final_messages_with_tokens.append(
|
||||
(
|
||||
tool_call_summary.tool_call_request,
|
||||
check_message_tokens(
|
||||
tool_call_summary.tool_call_request,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
)
|
||||
)
|
||||
final_messages_with_tokens.append(
|
||||
(
|
||||
tool_call_summary.tool_call_result,
|
||||
check_message_tokens(
|
||||
tool_call_summary.tool_call_result,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
)
|
||||
)
|
||||
if self.new_messages_and_token_cnts:
|
||||
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
|
||||
|
||||
return drop_messages_history_overflow(
|
||||
final_messages_with_tokens, self.max_tokens
|
||||
|
@@ -6,7 +6,6 @@ from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MA
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.persona import get_default_prompt__read_only
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
@@ -14,6 +13,7 @@ from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
|
||||
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
|
||||
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT
|
||||
@@ -132,10 +132,9 @@ def build_citations_system_message(
|
||||
|
||||
|
||||
def build_citations_user_message(
|
||||
question: str,
|
||||
message: HumanMessage,
|
||||
prompt_config: PromptConfig,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
files: list[InMemoryChatFile],
|
||||
all_doc_useful: bool,
|
||||
history_message: str = "",
|
||||
) -> HumanMessage:
|
||||
@@ -149,6 +148,7 @@ def build_citations_user_message(
|
||||
if history_message
|
||||
else ""
|
||||
)
|
||||
query, img_urls = message_to_prompt_and_imgs(message)
|
||||
|
||||
if context_docs:
|
||||
context_docs_str = build_complete_context_str(context_docs)
|
||||
@@ -158,20 +158,22 @@ def build_citations_user_message(
|
||||
optional_ignore_statement=optional_ignore,
|
||||
context_docs_str=context_docs_str,
|
||||
task_prompt=task_prompt_with_reminder,
|
||||
user_query=question,
|
||||
user_query=query,
|
||||
history_block=history_block,
|
||||
)
|
||||
else:
|
||||
# if no context docs provided, assume we're in the tool calling flow
|
||||
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
|
||||
task_prompt=task_prompt_with_reminder,
|
||||
user_query=question,
|
||||
user_query=query,
|
||||
history_block=history_block,
|
||||
)
|
||||
|
||||
user_prompt = user_prompt.strip()
|
||||
user_msg = HumanMessage(
|
||||
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
|
||||
content=build_content_with_imgs(user_prompt, img_urls=img_urls)
|
||||
if img_urls
|
||||
else user_prompt
|
||||
)
|
||||
|
||||
return user_msg
|
||||
|
@@ -5,6 +5,7 @@ from danswer.configs.chat_configs import LANGUAGE_HINT
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
@@ -75,7 +76,7 @@ def _build_strong_llm_quotes_prompt(
|
||||
|
||||
|
||||
def build_quotes_user_message(
|
||||
question: str,
|
||||
message: HumanMessage,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
@@ -86,28 +87,10 @@ def build_quotes_user_message(
|
||||
else _build_strong_llm_quotes_prompt
|
||||
)
|
||||
|
||||
query, _ = message_to_prompt_and_imgs(message)
|
||||
|
||||
return prompt_builder(
|
||||
question=question,
|
||||
context_docs=context_docs,
|
||||
history_str=history_str,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
def build_quotes_prompt(
|
||||
question: str,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
) -> HumanMessage:
|
||||
prompt_builder = (
|
||||
_build_weak_llm_quotes_prompt
|
||||
if QA_PROMPT_OVERRIDE == "weak"
|
||||
else _build_strong_llm_quotes_prompt
|
||||
)
|
||||
|
||||
return prompt_builder(
|
||||
question=question,
|
||||
question=query,
|
||||
context_docs=context_docs,
|
||||
history_str=history_str,
|
||||
prompt=prompt,
|
||||
|
@@ -19,7 +19,7 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.tools.search.search_utils import section_to_dict
from danswer.tools.tool_implementations.search.search_utils import section_to_dict
from danswer.utils.logger import setup_logger
@@ -0,0 +1,91 @@
import abc
from collections.abc import Generator

from langchain_core.messages import BaseMessage

from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.answering.stream_processing.citation_processing import (
    CitationProcessor,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
    QuotesProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping


class AnswerResponseHandler(abc.ABC):
    @abc.abstractmethod
    def handle_response_part(
        self,
        response_item: BaseMessage | None,
        previous_response_items: list[BaseMessage],
    ) -> Generator[ResponsePart, None, None]:
        raise NotImplementedError


class DummyAnswerResponseHandler(AnswerResponseHandler):
    def handle_response_part(
        self,
        response_item: BaseMessage | None,
        previous_response_items: list[BaseMessage],
    ) -> Generator[ResponsePart, None, None]:
        # This is a dummy handler that returns nothing
        yield from []


class CitationResponseHandler(AnswerResponseHandler):
    def __init__(
        self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
    ):
        self.context_docs = context_docs
        self.doc_id_to_rank_map = doc_id_to_rank_map
        self.citation_processor = CitationProcessor(
            context_docs=self.context_docs,
            doc_id_to_rank_map=self.doc_id_to_rank_map,
        )
        self.processed_text = ""
        self.citations: list[CitationInfo] = []

    def handle_response_part(
        self,
        response_item: BaseMessage | None,
        previous_response_items: list[BaseMessage],
    ) -> Generator[ResponsePart, None, None]:
        if response_item is None:
            return

        content = (
            response_item.content if isinstance(response_item.content, str) else ""
        )

        # Process the new content through the citation processor
        yield from self.citation_processor.process_token(content)


class QuotesResponseHandler(AnswerResponseHandler):
    def __init__(
        self,
        context_docs: list[LlmDoc],
        is_json_prompt: bool = True,
    ):
        self.quotes_processor = QuotesProcessor(
            context_docs=context_docs,
            is_json_prompt=is_json_prompt,
        )

    def handle_response_part(
        self,
        response_item: BaseMessage | None,
        previous_response_items: list[BaseMessage],
    ) -> Generator[ResponsePart, None, None]:
        if response_item is None:
            yield from self.quotes_processor.process_token(None)
            return

        content = (
            response_item.content if isinstance(response_item.content, str) else ""
        )

        yield from self.quotes_processor.process_token(content)
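CitationResponseHandler delegates to CitationProcessor, which scans the LLM output for [n] citations and rewrites them so the numbers follow the order in which documents are first cited, using the document-id-to-rank mapping. Below is a toy, non-streaming version of that remapping; context_doc_ids and order_mapping are assumed example data, and the real processor additionally handles partial tokens, code blocks, and duplicate citations.

import re

# assumed example data: the LLM cites docs by their 1-based position in the prompt,
# while order_mapping gives each document's display rank
context_doc_ids = ["doc-a", "doc-b", "doc-c"]
order_mapping = {"doc-a": 2, "doc-b": 1, "doc-c": 3}


def remap_citations(text: str) -> str:
    citation_order: list[int] = []

    def _replace(match: re.Match) -> str:
        prompt_num = int(match.group(1))
        if not 1 <= prompt_num <= len(context_doc_ids):
            return match.group(0)  # out-of-range citation: leave untouched
        real_num = order_mapping[context_doc_ids[prompt_num - 1]]
        if real_num not in citation_order:
            citation_order.append(real_num)
        return f"[{citation_order.index(real_num) + 1}]"

    return re.sub(r"\[(\d+)\]", _replace, text)


print(remap_citations("See [2] and [1]."))  # -> "See [1] and [2]."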
@@ -1,12 +1,10 @@
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Generator
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import STOP_STREAM_PAT
|
||||
from danswer.llm.answering.models import StreamProcessor
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.prompts.constants import TRIPLE_BACKTICK
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -19,128 +17,104 @@ def in_code_block(llm_text: str) -> bool:
|
||||
return count % 2 != 0
|
||||
|
||||
|
||||
def extract_citations_from_stream(
|
||||
tokens: Iterator[str],
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
|
||||
"""
|
||||
Key aspects:
|
||||
class CitationProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.stop_stream = stop_stream
|
||||
self.order_mapping = doc_id_to_rank_map.order_mapping
|
||||
self.llm_out = ""
|
||||
self.max_citation_num = len(context_docs)
|
||||
self.citation_order: list[int] = []
|
||||
self.curr_segment = ""
|
||||
self.cited_inds: set[int] = set()
|
||||
self.hold = ""
|
||||
self.current_citations: list[int] = []
|
||||
self.past_cite_count = 0
|
||||
|
||||
1. Stream Processing:
|
||||
- Processes tokens one by one, allowing for real-time handling of large texts.
|
||||
def process_token(
|
||||
self, token: str | None
|
||||
) -> Generator[DanswerAnswerPiece | CitationInfo, None, None]:
|
||||
# None -> end of stream
|
||||
if token is None:
|
||||
yield DanswerAnswerPiece(answer_piece=self.curr_segment)
|
||||
return
|
||||
|
||||
2. Citation Detection:
|
||||
- Uses regex to find citations in the format [number].
|
||||
- Example: [1], [2], etc.
|
||||
|
||||
3. Citation Mapping:
|
||||
- Maps detected citation numbers to actual document ranks using doc_id_to_rank_map.
|
||||
- Example: [1] might become [3] if doc_id_to_rank_map maps it to 3.
|
||||
|
||||
4. Citation Formatting:
|
||||
- Replaces citations with properly formatted versions.
|
||||
- Adds links if available: [[1]](https://example.com)
|
||||
- Handles cases where links are not available: [[1]]()
|
||||
|
||||
5. Duplicate Handling:
|
||||
- Skips consecutive citations of the same document to avoid redundancy.
|
||||
|
||||
6. Output Generation:
|
||||
- Yields DanswerAnswerPiece objects for regular text.
|
||||
- Yields CitationInfo objects for each unique citation encountered.
|
||||
|
||||
7. Context Awareness:
|
||||
- Uses context_docs to access document information for citations.
|
||||
|
||||
This function effectively processes a stream of text, identifies and reformats citations,
|
||||
and provides both the processed text and citation information as output.
|
||||
"""
|
||||
order_mapping = doc_id_to_rank_map.order_mapping
|
||||
llm_out = ""
|
||||
max_citation_num = len(context_docs)
|
||||
citation_order = []
|
||||
curr_segment = ""
|
||||
cited_inds = set()
|
||||
hold = ""
|
||||
|
||||
raw_out = ""
|
||||
current_citations: list[int] = []
|
||||
past_cite_count = 0
|
||||
for raw_token in tokens:
|
||||
raw_out += raw_token
|
||||
if stop_stream:
|
||||
next_hold = hold + raw_token
|
||||
if stop_stream in next_hold:
|
||||
break
|
||||
if next_hold == stop_stream[: len(next_hold)]:
|
||||
hold = next_hold
|
||||
continue
|
||||
if self.stop_stream:
|
||||
next_hold = self.hold + token
|
||||
if self.stop_stream in next_hold:
|
||||
return
|
||||
if next_hold == self.stop_stream[: len(next_hold)]:
|
||||
self.hold = next_hold
|
||||
return
|
||||
token = next_hold
|
||||
hold = ""
|
||||
else:
|
||||
token = raw_token
|
||||
self.hold = ""
|
||||
|
||||
curr_segment += token
|
||||
llm_out += token
|
||||
self.curr_segment += token
|
||||
self.llm_out += token
|
||||
|
||||
# Handle code blocks without language tags
|
||||
if "`" in curr_segment:
|
||||
if curr_segment.endswith("`"):
|
||||
continue
|
||||
elif "```" in curr_segment:
|
||||
piece_that_comes_after = curr_segment.split("```")[1][0]
|
||||
if piece_that_comes_after == "\n" and in_code_block(llm_out):
|
||||
curr_segment = curr_segment.replace("```", "```plaintext")
|
||||
if "`" in self.curr_segment:
|
||||
if self.curr_segment.endswith("`"):
|
||||
return
|
||||
elif "```" in self.curr_segment:
|
||||
piece_that_comes_after = self.curr_segment.split("```")[1][0]
|
||||
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
|
||||
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
|
||||
|
||||
citation_pattern = r"\[(\d+)\]"
|
||||
|
||||
citations_found = list(re.finditer(citation_pattern, curr_segment))
|
||||
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
|
||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
|
||||
possible_citation_found = re.search(
|
||||
possible_citation_pattern, self.curr_segment
|
||||
)
|
||||
|
||||
# `past_cite_count`: number of characters since past citation
|
||||
# 5 to ensure a citation hasn't occured
|
||||
if len(citations_found) == 0 and len(llm_out) - past_cite_count > 5:
|
||||
current_citations = []
|
||||
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
|
||||
self.current_citations = []
|
||||
|
||||
if citations_found and not in_code_block(llm_out):
|
||||
result = "" # Initialize result here
|
||||
if citations_found and not in_code_block(self.llm_out):
|
||||
last_citation_end = 0
|
||||
length_to_add = 0
|
||||
while len(citations_found) > 0:
|
||||
citation = citations_found.pop(0)
|
||||
numerical_value = int(citation.group(1))
|
||||
|
||||
if 1 <= numerical_value <= max_citation_num:
|
||||
context_llm_doc = context_docs[numerical_value - 1]
|
||||
real_citation_num = order_mapping[context_llm_doc.document_id]
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
real_citation_num = self.order_mapping[context_llm_doc.document_id]
|
||||
|
||||
if real_citation_num not in citation_order:
|
||||
citation_order.append(real_citation_num)
|
||||
if real_citation_num not in self.citation_order:
|
||||
self.citation_order.append(real_citation_num)
|
||||
|
||||
target_citation_num = citation_order.index(real_citation_num) + 1
|
||||
target_citation_num = (
|
||||
self.citation_order.index(real_citation_num) + 1
|
||||
)
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if target_citation_num in current_citations:
|
||||
if target_citation_num in self.current_citations:
|
||||
start, end = citation.span()
|
||||
real_start = length_to_add + start
|
||||
diff = end - start
|
||||
curr_segment = (
|
||||
curr_segment[: length_to_add + start]
|
||||
+ curr_segment[real_start + diff :]
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: length_to_add + start]
|
||||
+ self.curr_segment[real_start + diff :]
|
||||
)
|
||||
length_to_add -= diff
|
||||
continue
|
||||
|
||||
# Handle edge case where LLM outputs citation itself
|
||||
# by allowing it to generate citations on its own.
|
||||
if curr_segment.startswith("[["):
|
||||
match = re.match(r"\[\[(\d+)\]\]", curr_segment)
|
||||
if self.curr_segment.startswith("[["):
|
||||
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
|
||||
if match:
|
||||
try:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = context_docs[doc_id - 1]
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
@@ -150,75 +124,57 @@ def extract_citations_from_stream(
|
||||
f"Manual LLM citation didn't properly cite documents {e}"
|
||||
)
|
||||
else:
|
||||
# Will continue attempt on next loops
|
||||
logger.warning(
|
||||
"Manual LLM citation wasn't able to close brackets"
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
# Replace the citation in the current segment
|
||||
start, end = citation.span()
|
||||
curr_segment = (
|
||||
curr_segment[: start + length_to_add]
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[{target_citation_num}]"
|
||||
+ curr_segment[end + length_to_add :]
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
|
||||
past_cite_count = len(llm_out)
|
||||
current_citations.append(target_citation_num)
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(target_citation_num)
|
||||
|
||||
if target_citation_num not in cited_inds:
|
||||
cited_inds.add(target_citation_num)
|
||||
if target_citation_num not in self.cited_inds:
|
||||
self.cited_inds.add(target_citation_num)
|
||||
yield CitationInfo(
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
if link:
|
||||
prev_length = len(curr_segment)
|
||||
curr_segment = (
|
||||
curr_segment[: start + length_to_add]
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]({link})"
|
||||
+ curr_segment[end + length_to_add :]
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(curr_segment) - prev_length
|
||||
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
else:
|
||||
prev_length = len(curr_segment)
|
||||
curr_segment = (
|
||||
curr_segment[: start + length_to_add]
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]()"
|
||||
+ curr_segment[end + length_to_add :]
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(curr_segment) - prev_length
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
|
||||
last_citation_end = end + length_to_add
|
||||
|
||||
if last_citation_end > 0:
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end])
|
||||
curr_segment = curr_segment[last_citation_end:]
|
||||
if possible_citation_found:
|
||||
continue
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
curr_segment = ""
|
||||
result += self.curr_segment[:last_citation_end]
|
||||
self.curr_segment = self.curr_segment[last_citation_end:]
|
||||
|
||||
if curr_segment:
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
if not possible_citation_found:
|
||||
result += self.curr_segment
|
||||
self.curr_segment = ""
|
||||
|
||||
|
||||
def build_citation_processor(
|
||||
context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
|
||||
) -> StreamProcessor:
|
||||
def stream_processor(
|
||||
tokens: Iterator[str],
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
yield from extract_citations_from_stream(
|
||||
tokens=tokens,
|
||||
context_docs=context_docs,
|
||||
doc_id_to_rank_map=doc_id_to_rank_map,
|
||||
)
|
||||
|
||||
return stream_processor
|
||||
if result:
|
||||
yield DanswerAnswerPiece(answer_piece=result)
|
||||
|
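The net effect of this hunk is that the streaming generator becomes a stateful CitationProcessor that is fed one token at a time. A rough driver loop for the new interface, inferred from the class-based lines above and the updated unit test near the end of this diff (the llm.stream and handle calls are placeholders, not part of the change):

processor = CitationProcessor(
    context_docs=context_docs,
    doc_id_to_rank_map=doc_id_to_rank_map,
    stop_stream=None,
)

for token in llm.stream(prompt):             # placeholder token source
    for piece in processor.process_token(token):
        handle(piece)                         # DanswerAnswerPiece or CitationInfo

for piece in processor.process_token(None):   # end of stream: flush buffered text
    handle(piece)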
@@ -1,14 +1,11 @@
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
import regex
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.chat.models import DanswerAnswer
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuote
|
||||
@@ -157,7 +154,7 @@ def separate_answer_quotes(
|
||||
return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))
|
||||
|
||||
|
||||
def process_answer(
|
||||
def _process_answer(
|
||||
answer_raw: str,
|
||||
docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
@@ -195,7 +192,7 @@ def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
|
||||
def _extract_quotes_from_completed_token_stream(
|
||||
model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True
|
||||
) -> DanswerQuotes:
|
||||
answer, quotes = process_answer(model_output, context_docs, is_json_prompt)
|
||||
answer, quotes = _process_answer(model_output, context_docs, is_json_prompt)
|
||||
if answer:
|
||||
logger.notice(answer)
|
||||
elif model_output:
|
||||
@@ -204,94 +201,101 @@ def _extract_quotes_from_completed_token_stream(
|
||||
return quotes
|
||||
|
||||
|
||||
def process_model_tokens(
|
||||
tokens: Iterator[str],
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
|
||||
"""Used in the streaming case to process the model output
|
||||
into an Answer and Quotes
|
||||
class QuotesProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.is_json_prompt = is_json_prompt
|
||||
|
||||
Yields Answer tokens back out in a dict for streaming to frontend
|
||||
When Answer section ends, yields dict with answer_finished key
|
||||
Collects all the tokens at the end to form the complete model output"""
|
||||
quote_pat = f"\n{QUOTE_PAT}"
# Sometimes a weaker model outputs a newline instead of ":"
quote_loose = f"\n{quote_pat[:-1]}\n"
# Sometimes the model outputs two newlines before the quote section
quote_pat_full = f"\n{quote_pat}"
|
||||
model_output: str = ""
|
||||
found_answer_start = False if is_json_prompt else True
|
||||
found_answer_end = False
|
||||
hold_quote = ""
|
||||
self.found_answer_start = False if is_json_prompt else True
|
||||
self.found_answer_end = False
|
||||
self.hold_quote = ""
|
||||
self.model_output = ""
|
||||
self.hold = ""
|
||||
|
||||
for token in tokens:
|
||||
model_previous = model_output
|
||||
model_output += token
|
||||
def process_token(
|
||||
self, token: str | None
|
||||
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
|
||||
# None -> end of stream
|
||||
if token is None:
|
||||
if self.model_output:
|
||||
yield _extract_quotes_from_completed_token_stream(
|
||||
model_output=self.model_output,
|
||||
context_docs=self.context_docs,
|
||||
is_json_prompt=self.is_json_prompt,
|
||||
)
|
||||
return
|
||||
|
||||
if not found_answer_start:
|
||||
m = answer_pattern.search(model_output)
|
||||
model_previous = self.model_output
|
||||
self.model_output += token
|
||||
|
||||
if not self.found_answer_start:
|
||||
m = answer_pattern.search(self.model_output)
|
||||
if m:
|
||||
found_answer_start = True
|
||||
self.found_answer_start = True
|
||||
|
||||
# Prevent heavy cases of hallucinations where model is never providing a JSON
|
||||
# We want to quickly update the user - not stream forever
|
||||
if is_json_prompt and len(model_output) > 70:
|
||||
# Prevent heavy cases of hallucinations
|
||||
if self.is_json_prompt and len(self.model_output) > 70:
|
||||
logger.warning("LLM did not produce json as prompted")
|
||||
found_answer_end = True
|
||||
continue
|
||||
self.found_answer_end = True
|
||||
return
|
||||
|
||||
remaining = model_output[m.end() :]
|
||||
remaining = self.model_output[m.end() :]
|
||||
|
||||
# Look for an unescaped quote, which means the answer is entirely contained
|
||||
# in this token e.g. if the token is `{"answer": "blah", "qu`
|
||||
quote_indices = [i for i, char in enumerate(remaining) if char == '"']
|
||||
for quote_idx in quote_indices:
|
||||
# Check if quote is escaped by counting backslashes before it
|
||||
num_backslashes = 0
|
||||
pos = quote_idx - 1
|
||||
while pos >= 0 and remaining[pos] == "\\":
|
||||
num_backslashes += 1
|
||||
pos -= 1
|
||||
# If even number of backslashes, quote is not escaped
|
||||
if num_backslashes % 2 == 0:
|
||||
yield DanswerAnswerPiece(answer_piece=remaining[:quote_idx])
|
||||
return
|
||||
|
||||
# If no unescaped quote found, yield the remaining string
|
||||
if len(remaining) > 0:
|
||||
yield DanswerAnswerPiece(answer_piece=remaining)
|
||||
continue
|
||||
return
|
||||
|
||||
if found_answer_start and not found_answer_end:
|
||||
if is_json_prompt and _stream_json_answer_end(model_previous, token):
|
||||
found_answer_end = True
|
||||
if self.found_answer_start and not self.found_answer_end:
|
||||
if self.is_json_prompt and _stream_json_answer_end(model_previous, token):
|
||||
self.found_answer_end = True
|
||||
|
||||
# return the remaining part of the answer e.g. token might be 'd.", ' and we should yield 'd.'
|
||||
if token:
|
||||
try:
|
||||
answer_token_section = token.index('"')
|
||||
yield DanswerAnswerPiece(
|
||||
answer_piece=hold_quote + token[:answer_token_section]
|
||||
answer_piece=self.hold_quote + token[:answer_token_section]
|
||||
)
|
||||
except ValueError:
|
||||
logger.error("Quotation mark not found in token")
|
||||
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
|
||||
yield DanswerAnswerPiece(answer_piece=self.hold_quote + token)
|
||||
yield DanswerAnswerPiece(answer_piece=None)
|
||||
continue
|
||||
elif not is_json_prompt:
|
||||
if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
|
||||
found_answer_end = True
|
||||
return
|
||||
|
||||
elif not self.is_json_prompt:
|
||||
quote_pat = f"\n{QUOTE_PAT}"
|
||||
quote_loose = f"\n{quote_pat[:-1]}\n"
|
||||
quote_pat_full = f"\n{quote_pat}"
|
||||
|
||||
if (
|
||||
quote_pat in self.hold_quote + token
|
||||
or quote_loose in self.hold_quote + token
|
||||
):
|
||||
self.found_answer_end = True
|
||||
yield DanswerAnswerPiece(answer_piece=None)
|
||||
continue
|
||||
if hold_quote + token in quote_pat_full:
|
||||
hold_quote += token
|
||||
continue
|
||||
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
|
||||
hold_quote = ""
|
||||
return
|
||||
if self.hold_quote + token in quote_pat_full:
|
||||
self.hold_quote += token
|
||||
return
|
||||
|
||||
logger.debug(f"Raw Model QnA Output: {model_output}")
|
||||
|
||||
yield _extract_quotes_from_completed_token_stream(
|
||||
model_output=model_output,
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
|
||||
def build_quotes_processor(
|
||||
context_docs: list[LlmDoc], is_json_prompt: bool
|
||||
) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]:
|
||||
def stream_processor(
|
||||
tokens: Iterator[str],
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
yield from process_model_tokens(
|
||||
tokens=tokens,
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
return stream_processor
|
||||
yield DanswerAnswerPiece(answer_piece=self.hold_quote + token)
|
||||
self.hold_quote = ""
|
||||
|
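The quotes flow gets the same treatment: QuotesProcessor consumes tokens one at a time and only parses the accumulated output for quotes when the stream ends. A usage sketch under the same assumptions as the citation example above:

processor = QuotesProcessor(context_docs=context_docs, is_json_prompt=True)

# With a JSON prompt, the model is expected to emit something like
# {"answer": "...", "quotes": ["..."]}; answer text streams out as
# DanswerAnswerPiece objects and the quotes are extracted at the end.
for token in llm.stream(prompt):             # placeholder token source
    for piece in processor.process_token(token):
        handle(piece)

for piece in processor.process_token(None):   # end of stream -> DanswerQuotes
    handle(piece)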
207 backend/danswer/llm/answering/tool/tool_response_handler.py Normal file
@@ -0,0 +1,207 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import ResponsePart
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.message import build_tool_message
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_runner import (
|
||||
check_which_tools_should_run_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ToolResponseHandler:
|
||||
def __init__(self, tools: list[Tool]):
|
||||
self.tools = tools
|
||||
|
||||
self.tool_call_chunk: AIMessageChunk | None = None
|
||||
self.tool_call_requests: list[ToolCall] = []
|
||||
|
||||
self.tool_runner: ToolRunner | None = None
|
||||
self.tool_call_summary: ToolCallSummary | None = None
|
||||
|
||||
self.tool_kickoff: ToolCallKickoff | None = None
|
||||
self.tool_responses: list[ToolResponse] = []
|
||||
self.tool_final_result: ToolCallFinalResult | None = None
|
||||
|
||||
@classmethod
|
||||
def get_tool_call_for_non_tool_calling_llm(
|
||||
cls, llm_call: LLMCall, llm: LLM
|
||||
) -> tuple[Tool, dict] | None:
|
||||
if llm_call.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
(
|
||||
t
|
||||
for t in llm_call.tools
|
||||
if t.name == llm_call.force_use_tool.tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(
|
||||
f"Tool '{llm_call.force_use_tool.tool_name}' not found"
|
||||
)
|
||||
|
||||
tool_args = (
|
||||
llm_call.force_use_tool.args
|
||||
if llm_call.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=llm_call.prompt_builder.get_user_message_content(),
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
return (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=llm_call.tools,
|
||||
query=llm_call.prompt_builder.get_user_message_content(),
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
available_tools_and_args = [
|
||||
(llm_call.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
query=llm_call.prompt_builder.get_user_message_content(),
|
||||
llm=llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
return chosen_tool_and_args
|
||||
|
||||
def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
|
||||
if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
|
||||
return
|
||||
|
||||
self.tool_call_requests = self.tool_call_chunk.tool_calls
|
||||
|
||||
selected_tool: Tool | None = None
|
||||
selected_tool_call_request: ToolCall | None = None
|
||||
for tool_call_request in self.tool_call_requests:
|
||||
known_tools_by_name = [
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
selected_tool = known_tools_by_name[0]
|
||||
selected_tool_call_request = tool_call_request
|
||||
|
||||
if selected_tool and selected_tool_call_request:
|
||||
break
|
||||
|
||||
if not selected_tool or not selected_tool_call_request:
|
||||
return
|
||||
|
||||
logger.info(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
self.tool_runner = ToolRunner(selected_tool, selected_tool_call_request["args"])
|
||||
self.tool_kickoff = self.tool_runner.kickoff()
|
||||
yield self.tool_kickoff
|
||||
|
||||
for response in self.tool_runner.tool_responses():
|
||||
self.tool_responses.append(response)
|
||||
yield response
|
||||
|
||||
self.tool_final_result = self.tool_runner.tool_final_result()
|
||||
yield self.tool_final_result
|
||||
|
||||
self.tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=self.tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
selected_tool_call_request, self.tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
yield from self._handle_tool_call()
|
||||
|
||||
if isinstance(response_item, AIMessageChunk) and (
|
||||
response_item.tool_call_chunks or response_item.tool_calls
|
||||
):
|
||||
if self.tool_call_chunk is None:
|
||||
self.tool_call_chunk = response_item
|
||||
else:
|
||||
self.tool_call_chunk += response_item # type: ignore
|
||||
|
||||
return
|
||||
|
||||
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
|
||||
if (
|
||||
self.tool_runner is None
|
||||
or self.tool_call_summary is None
|
||||
or self.tool_kickoff is None
|
||||
or self.tool_final_result is None
|
||||
):
|
||||
return None
|
||||
|
||||
tool_runner = self.tool_runner
|
||||
new_prompt_builder = tool_runner.tool.build_next_prompt(
|
||||
prompt_builder=current_llm_call.prompt_builder,
|
||||
tool_call_summary=self.tool_call_summary,
|
||||
tool_responses=self.tool_responses,
|
||||
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
|
||||
)
|
||||
return LLMCall(
|
||||
prompt_builder=new_prompt_builder,
|
||||
tools=[], # for now, only allow one tool call per response
|
||||
force_use_tool=ForceUseTool(
|
||||
force_use=False,
|
||||
tool_name="",
|
||||
args=None,
|
||||
),
|
||||
files=current_llm_call.files,
|
||||
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
|
||||
tool_call_info=[
|
||||
self.tool_kickoff,
|
||||
*self.tool_responses,
|
||||
self.tool_final_result,
|
||||
],
|
||||
)
|
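ToolResponseHandler is meant to be driven by the outer answer loop: stream message chunks into it, then ask whether the accumulated tool call requires a follow-up LLM call. A sketch of that loop (the llm.stream call and prompt_builder.build() usage are illustrative; the real wiring lives in the Answer / LLM response handler code):

handler = ToolResponseHandler(tools=llm_call.tools)

previous_chunks: list[BaseMessage] = []
for chunk in llm.stream(llm_call.prompt_builder.build()):   # placeholder stream
    yield from handler.handle_response_part(chunk, previous_chunks)
    previous_chunks.append(chunk)

# None marks end-of-stream; any buffered tool call is executed here and its
# kickoff, intermediate responses, and final result are yielded.
yield from handler.handle_response_part(None, previous_chunks)

next_call = handler.next_llm_call(llm_call)   # None if no tool ran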
@@ -203,6 +203,28 @@ def build_content_with_imgs(
|
||||
)
|
||||
|
||||
|
||||
def message_to_prompt_and_imgs(message: BaseMessage) -> tuple[str, list[str]]:
|
||||
if isinstance(message.content, str):
|
||||
return message.content, []
|
||||
|
||||
imgs = []
|
||||
texts = []
|
||||
for part in message.content:
|
||||
if isinstance(part, dict):
|
||||
if part.get("type") == "image_url":
|
||||
img_url = part.get("image_url", {}).get("url")
|
||||
if img_url:
|
||||
imgs.append(img_url)
|
||||
elif part.get("type") == "text":
|
||||
text = part.get("text")
|
||||
if text:
|
||||
texts.append(text)
|
||||
else:
|
||||
texts.append(part)
|
||||
|
||||
return "".join(texts), imgs
|
||||
|
||||
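A quick illustration of the new helper on a multimodal message (the message content is invented for the example):

from langchain_core.messages import HumanMessage

msg = HumanMessage(
    content=[
        {"type": "text", "text": "Describe this chart"},
        {"type": "image_url", "image_url": {"url": "https://example.com/chart.png"}},
    ]
)

text, imgs = message_to_prompt_and_imgs(msg)
# text == "Describe this chart"
# imgs == ["https://example.com/chart.png"]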
|
||||
def dict_based_prompt_to_langchain_prompt(
|
||||
messages: list[dict[str, str]]
|
||||
) -> list[BaseMessage]:
|
||||
|
@@ -52,12 +52,16 @@ from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephr
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
@@ -202,30 +206,33 @@ def stream_answer_objects(
|
||||
max_tokens=max_document_tokens,
|
||||
)
|
||||
|
||||
answer_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig() if use_citations else None,
|
||||
quotes_config=QuotesConfig() if not use_citations else None,
|
||||
document_pruning_config=document_pruning_config,
|
||||
)
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
evaluation_type=LLMEvaluationType.SKIP
|
||||
if DISABLE_LLM_DOC_RELEVANCE
|
||||
else query_req.evaluation_type,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.SKIP
|
||||
if DISABLE_LLM_DOC_RELEVANCE
|
||||
else query_req.evaluation_type
|
||||
),
|
||||
persona=persona,
|
||||
retrieval_options=query_req.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
answer_style_config=answer_config,
|
||||
bypass_acl=bypass_acl,
|
||||
chunks_above=query_req.chunks_above,
|
||||
chunks_below=query_req.chunks_below,
|
||||
full_doc=query_req.full_doc,
|
||||
)
|
||||
|
||||
answer_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig() if use_citations else None,
|
||||
quotes_config=QuotesConfig() if not use_citations else None,
|
||||
document_pruning_config=document_pruning_config,
|
||||
)
|
||||
|
||||
answer = Answer(
|
||||
question=query_msg.message,
|
||||
answer_style_config=answer_config,
|
||||
|
@@ -9,7 +9,7 @@ from danswer.db.models import StarterMessage
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.server.features.document_set.models import DocumentSet
|
||||
from danswer.server.features.prompt.models import PromptSnapshot
|
||||
from danswer.server.features.tool.api import ToolSnapshot
|
||||
from danswer.server.features.tool.models import ToolSnapshot
|
||||
from danswer.server.models import MinimalUserSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
@@ -18,10 +18,16 @@ from danswer.db.tools import update_tool
|
||||
from danswer.server.features.tool.models import CustomToolCreate
|
||||
from danswer.server.features.tool.models import CustomToolUpdate
|
||||
from danswer.server.features.tool.models import ToolSnapshot
|
||||
from danswer.tools.custom.openapi_parsing import MethodSpec
|
||||
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
|
||||
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import (
|
||||
openapi_to_method_specs,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import (
|
||||
validate_openapi_schema,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from danswer.tools.utils import is_image_generation_available
|
||||
|
||||
router = APIRouter(prefix="/tool")
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
@@ -283,13 +284,14 @@ def delete_chat_session_by_id(
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
async def is_disconnected(request: Request) -> Callable[[], bool]:
|
||||
async def is_connected(request: Request) -> Callable[[], bool]:
|
||||
main_loop = asyncio.get_event_loop()
|
||||
|
||||
def is_disconnected_sync() -> bool:
|
||||
def is_connected_sync() -> bool:
|
||||
future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
|
||||
try:
|
||||
return not future.result(timeout=0.01)
|
||||
is_connected = not future.result(timeout=0.01)
|
||||
return is_connected
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Asyncio timed out")
|
||||
return True
|
||||
@@ -300,7 +302,7 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
|
||||
)
|
||||
return True
|
||||
|
||||
return is_disconnected_sync
|
||||
return is_connected_sync
|
||||
|
||||
|
||||
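Downstream, the injected callable is polled between packets so the server stops doing work once the client has gone away; conceptually (a simplified stand-in for what stream_chat_message does with is_connected):

from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator

def stream_while_connected(
    packets: Iterator[str], is_connected: Callable[[], bool]
) -> Generator[str, None, None]:
    for packet in packets:
        if not is_connected():
            # Client disconnected; stop generating the rest of the response.
            break
        yield packet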
@router.post("/send-message")
|
||||
@@ -309,7 +311,7 @@ def handle_new_chat_message(
|
||||
request: Request,
|
||||
user: User | None = Depends(current_user),
|
||||
_: None = Depends(check_token_rate_limits),
|
||||
is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),
|
||||
is_connected_func: Callable[[], bool] = Depends(is_connected),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
This endpoint is both used for all the following purposes:
|
||||
@@ -325,7 +327,7 @@ def handle_new_chat_message(
|
||||
request (Request): The current HTTP request context.
|
||||
user (User | None): The current user, obtained via dependency injection.
|
||||
_ (None): Rate limit check is run if user/group/global rate limits are enabled.
|
||||
is_disconnected_func (Callable[[], bool]): Function to check client disconnection,
|
||||
is_connected_func (Callable[[], bool]): Function to check client disconnection,
|
||||
used to stop the streaming response if the client disconnects.
|
||||
|
||||
Returns:
|
||||
@@ -340,8 +342,6 @@ def handle_new_chat_message(
|
||||
):
|
||||
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
|
||||
|
||||
import json
|
||||
|
||||
def stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
for packet in stream_chat_message(
|
||||
@@ -354,7 +354,7 @@ def handle_new_chat_message(
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
is_connected=is_disconnected_func,
|
||||
is_connected=is_connected_func,
|
||||
):
|
||||
yield json.dumps(packet) if isinstance(packet, dict) else packet
|
||||
|
||||
@@ -362,6 +362,9 @@ def handle_new_chat_message(
|
||||
logger.exception(f"Error in chat message streaming: {e}")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
finally:
|
||||
logger.debug("Stream generator finished")
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
|
59 backend/danswer/tools/base_tool.py Normal file
@@ -0,0 +1,59 @@
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CustomToolCallSummary,
|
||||
)
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolResponse
|
||||
|
||||
|
||||
def build_user_message_for_non_tool_calling_llm(
|
||||
message: HumanMessage,
|
||||
tool_name: str,
|
||||
*args: "ToolResponse",
|
||||
) -> str:
|
||||
query, _ = message_to_prompt_and_imgs(message)
|
||||
|
||||
tool_run_summary = cast("CustomToolCallSummary", args[0].response).tool_result
|
||||
return f"""
|
||||
Here's the result from the {tool_name} tool:
|
||||
|
||||
{tool_run_summary}
|
||||
|
||||
Now respond to the following:
|
||||
|
||||
{query}
|
||||
""".strip()
|
||||
|
||||
|
||||
class BaseTool(Tool):
|
||||
def build_next_prompt(
|
||||
self,
|
||||
prompt_builder: "AnswerPromptBuilder",
|
||||
tool_call_summary: "ToolCallSummary",
|
||||
tool_responses: list["ToolResponse"],
|
||||
using_tool_calling_llm: bool,
|
||||
) -> "AnswerPromptBuilder":
|
||||
if using_tool_calling_llm:
|
||||
prompt_builder.append_message(tool_call_summary.tool_call_request)
|
||||
prompt_builder.append_message(tool_call_summary.tool_call_result)
|
||||
else:
|
||||
prompt_builder.update_user_prompt(
|
||||
HumanMessage(
|
||||
content=build_user_message_for_non_tool_calling_llm(
|
||||
prompt_builder.user_message_and_token_cnt[0],
|
||||
self.name,
|
||||
*tool_responses,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return prompt_builder
|
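For non-tool-calling LLMs, the helper above simply folds the tool output back into the user turn. A small illustration with invented stand-ins for the tool response objects:

from langchain_core.messages import HumanMessage

class _FakeSummary:          # stands in for CustomToolCallSummary
    tool_result = {"temp_c": 18, "conditions": "cloudy"}

class _FakeToolResponse:     # stands in for ToolResponse
    response = _FakeSummary()

prompt = build_user_message_for_non_tool_calling_llm(
    HumanMessage(content="What's the weather in Paris?"),
    "weather_lookup",
    _FakeToolResponse(),
)
# prompt now embeds the tool result above the original question:
# "Here's the result from the weather_lookup tool: ... Now respond to ..."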
@@ -9,9 +9,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Tool as ToolDBModel
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
@@ -1,21 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.models import ToolResponse
|
||||
|
||||
|
||||
def build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
query: str,
|
||||
tool_name: str,
|
||||
*args: ToolResponse,
|
||||
) -> str:
|
||||
tool_run_summary = cast(CustomToolCallSummary, args[0].response).tool_result
|
||||
return f"""
|
||||
Here's the result from the {tool_name} tool:
|
||||
|
||||
{tool_run_summary}
|
||||
|
||||
Now respond to the following:
|
||||
|
||||
{query}
|
||||
""".strip()
|
@@ -1,11 +1,17 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from danswer.key_value_store.interface import JSON_ro
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.models import ToolResponse
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolResponse
|
||||
|
||||
|
||||
class Tool(abc.ABC):
|
||||
@@ -32,7 +38,7 @@ class Tool(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_tool_message_content(
|
||||
self, *args: ToolResponse
|
||||
self, *args: "ToolResponse"
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -51,13 +57,26 @@ class Tool(abc.ABC):
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
||||
def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
def final_result(self, *args: "ToolResponse") -> JSON_ro:
|
||||
"""
|
||||
This is the "final summary" result of the tool.
|
||||
It is the result that will be stored in the database.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
"""Some tools may want to modify the prompt based on the tool call summary and tool responses.
|
||||
Default behavior is to continue with just the raw tool call request/result passed to the LLM."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_next_prompt(
|
||||
self,
|
||||
prompt_builder: "AnswerPromptBuilder",
|
||||
tool_call_summary: "ToolCallSummary",
|
||||
tool_responses: list["ToolResponse"],
|
||||
using_tool_calling_llm: bool,
|
||||
) -> "AnswerPromptBuilder":
|
||||
raise NotImplementedError
|
||||
|
@@ -11,24 +11,34 @@ from pydantic import BaseModel
|
||||
from danswer.key_value_store.interface import JSON_ro
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.custom.base_tool_types import ToolResultType
|
||||
from danswer.tools.custom.custom_tool_prompts import (
|
||||
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
|
||||
)
|
||||
from danswer.tools.custom.custom_tool_prompts import SHOULD_USE_CUSTOM_TOOL_USER_PROMPT
|
||||
from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_SYSTEM_PROMPT
|
||||
from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_USER_PROMPT
|
||||
from danswer.tools.custom.custom_tool_prompts import USE_TOOL
|
||||
from danswer.tools.custom.openapi_parsing import MethodSpec
|
||||
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
|
||||
from danswer.tools.custom.openapi_parsing import openapi_to_url
|
||||
from danswer.tools.custom.openapi_parsing import REQUEST_BODY
|
||||
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
|
||||
from danswer.tools.base_tool import BaseTool
|
||||
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.models import MESSAGE_ID_PLACEHOLDER
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
|
||||
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
|
||||
SHOULD_USE_CUSTOM_TOOL_USER_PROMPT,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
|
||||
TOOL_ARG_SYSTEM_PROMPT,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
|
||||
TOOL_ARG_USER_PROMPT,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool_prompts import USE_TOOL
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import (
|
||||
openapi_to_method_specs,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import openapi_to_url
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import REQUEST_BODY
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import (
|
||||
validate_openapi_schema,
|
||||
)
|
||||
from danswer.utils.headers import header_list_to_header_dict
|
||||
from danswer.utils.headers import HeaderItemDict
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -43,7 +53,7 @@ class CustomToolCallSummary(BaseModel):
|
||||
tool_result: ToolResultType
|
||||
|
||||
|
||||
class CustomTool(Tool):
|
||||
class CustomTool(BaseTool):
|
||||
def __init__(
|
||||
self,
|
||||
method_spec: MethodSpec,
|
@@ -11,12 +11,17 @@ from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.key_value_store.interface import JSON_ro
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.prompts.constants import GENERAL_SEP_PAT
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_implementations.images.prompt import (
|
||||
build_image_generation_user_prompt,
|
||||
)
|
||||
from danswer.utils.headers import build_llm_extra_headers
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -112,7 +117,10 @@ class ImageGenerationTool(Tool):
|
||||
},
|
||||
"shape": {
|
||||
"type": "string",
|
||||
"description": "Optional. Image shape: 'square', 'portrait', or 'landscape'",
|
||||
"description": (
|
||||
"Optional - only specify if you want a specific shape."
|
||||
" Image shape: 'square', 'portrait', or 'landscape'."
|
||||
),
|
||||
"enum": [shape.value for shape in ImageShape],
|
||||
},
|
||||
},
|
||||
@@ -258,3 +266,34 @@ class ImageGenerationTool(Tool):
|
||||
image_generation_response.model_dump()
|
||||
for image_generation_response in image_generation_responses
|
||||
]
|
||||
|
||||
def build_next_prompt(
|
||||
self,
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
tool_call_summary: ToolCallSummary,
|
||||
tool_responses: list[ToolResponse],
|
||||
using_tool_calling_llm: bool,
|
||||
) -> AnswerPromptBuilder:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse] | None,
|
||||
next(
|
||||
(
|
||||
response.response
|
||||
for response in tool_responses
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID
|
||||
),
|
||||
None,
|
||||
),
|
||||
)
|
||||
if img_generation_response is None:
|
||||
raise ValueError("No image generation response found")
|
||||
|
||||
img_urls = [img.url for img in img_generation_response]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=prompt_builder.get_user_message_content(),
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
|
||||
return prompt_builder
|
@@ -11,18 +11,31 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.key_value_store.interface import JSON_ro
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE
|
||||
from danswer.prompts.constants import GENERAL_SEP_PAT
|
||||
from danswer.search.models import SearchDoc
|
||||
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
|
||||
from danswer.tools.internet_search.models import InternetSearchResponse
|
||||
from danswer.tools.internet_search.models import InternetSearchResult
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_implementations.internet_search.models import (
|
||||
InternetSearchResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.models import (
|
||||
InternetSearchResult,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search_like_tool_utils import (
|
||||
build_next_prompt_for_search_like_tool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -97,8 +110,17 @@ class InternetSearchTool(Tool):
|
||||
_DISPLAY_NAME = "[Beta] Internet Search Tool"
|
||||
_DESCRIPTION = "Perform an internet search for up-to-date information."
|
||||
|
||||
def __init__(self, api_key: str, num_results: int = 10) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
prompt_config: PromptConfig,
|
||||
num_results: int = 10,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.answer_style_config = answer_style_config
|
||||
self.prompt_config = prompt_config
|
||||
|
||||
self.host = "https://api.bing.microsoft.com/v7.0"
|
||||
self.headers = {
|
||||
"Ocp-Apim-Subscription-Key": api_key,
|
||||
@@ -231,3 +253,19 @@ class InternetSearchTool(Tool):
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
search_response = cast(InternetSearchResponse, args[0].response)
|
||||
return search_response.model_dump()
|
||||
|
||||
def build_next_prompt(
|
||||
self,
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
tool_call_summary: ToolCallSummary,
|
||||
tool_responses: list[ToolResponse],
|
||||
using_tool_calling_llm: bool,
|
||||
) -> AnswerPromptBuilder:
|
||||
return build_next_prompt_for_search_like_tool(
|
||||
prompt_builder=prompt_builder,
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_responses=tool_responses,
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
answer_style_config=self.answer_style_config,
|
||||
prompt_config=self.prompt_config,
|
||||
)
|
@@ -17,10 +17,13 @@ from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.key_value_store.interface import JSON_ro
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import ContextualPruningConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
|
||||
from danswer.llm.answering.prune_and_merge import prune_and_merge_sections
|
||||
from danswer.llm.answering.prune_and_merge import prune_sections
|
||||
@@ -35,9 +38,16 @@ from danswer.search.models import SearchRequest
|
||||
from danswer.search.pipeline import SearchPipeline
|
||||
from danswer.secondary_llm_flows.choose_search import check_if_need_search
|
||||
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
|
||||
from danswer.tools.search.search_utils import llm_doc_to_dict
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_implementations.search.search_utils import llm_doc_to_dict
|
||||
from danswer.tools.tool_implementations.search_like_tool_utils import (
|
||||
build_next_prompt_for_search_like_tool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -45,7 +55,6 @@ logger = setup_logger()
|
||||
SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
|
||||
SEARCH_DOC_CONTENT_ID = "search_doc_content"
|
||||
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
|
||||
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
|
||||
SEARCH_EVALUATION_ID = "llm_doc_eval"
|
||||
|
||||
|
||||
@@ -85,6 +94,7 @@ class SearchTool(Tool):
|
||||
llm: LLM,
|
||||
fast_llm: LLM,
|
||||
pruning_config: DocumentPruningConfig,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
evaluation_type: LLMEvaluationType,
|
||||
# if specified, will not actually run a search and will instead return these
|
||||
# sections. Used when the user selects specific docs to talk to
|
||||
@@ -136,6 +146,7 @@ class SearchTool(Tool):
|
||||
|
||||
num_chunk_multiple = self.chunks_above + self.chunks_below + 1
|
||||
|
||||
self.answer_style_config = answer_style_config
|
||||
self.contextual_pruning_config = (
|
||||
ContextualPruningConfig.from_doc_pruning_config(
|
||||
num_chunk_multiple=num_chunk_multiple, doc_pruning_config=pruning_config
|
||||
@@ -353,4 +364,36 @@ class SearchTool(Tool):
|
||||
# NOTE: need to do this json.loads(doc.json()) stuff because there are some
|
||||
# subfields that are not serializable by default (datetime)
|
||||
# this forces pydantic to make them JSON serializable for us
|
||||
return [json.loads(doc.json()) for doc in final_docs]
|
||||
return [json.loads(doc.model_dump_json()) for doc in final_docs]
|
||||
|
||||
def build_next_prompt(
|
||||
self,
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
tool_call_summary: ToolCallSummary,
|
||||
tool_responses: list[ToolResponse],
|
||||
using_tool_calling_llm: bool,
|
||||
) -> AnswerPromptBuilder:
|
||||
return build_next_prompt_for_search_like_tool(
|
||||
prompt_builder=prompt_builder,
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_responses=tool_responses,
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
answer_style_config=self.answer_style_config,
|
||||
prompt_config=self.prompt_config,
|
||||
)
|
||||
|
||||
"""Other utility functions"""
|
||||
|
||||
@classmethod
|
||||
def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None:
|
||||
if not llm_call.tool_call_info:
|
||||
return None
|
||||
|
||||
for yield_item in llm_call.tool_call_info:
|
||||
if (
|
||||
isinstance(yield_item, ToolResponse)
|
||||
and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID
|
||||
):
|
||||
return cast(list[LlmDoc], yield_item.response)
|
||||
|
||||
return None
|
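The new get_search_result classmethod lets later steps of the answer flow look up the documents produced by an earlier search tool call; roughly (llm_call being whatever LLMCall the loop is currently handling):

docs = SearchTool.get_search_result(llm_call)
if docs is not None:
    # e.g. reuse the previously retrieved documents for citation processing
    # instead of running a second search.
    logger.info(f"Reusing {len(docs)} docs from the prior search tool call")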
@@ -0,0 +1,71 @@
|
||||
from typing import cast
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
build_citations_system_message,
|
||||
)
|
||||
from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
|
||||
from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolResponse
|
||||
|
||||
|
||||
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
|
||||
|
||||
|
||||
def build_next_prompt_for_search_like_tool(
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
tool_call_summary: ToolCallSummary,
|
||||
tool_responses: list[ToolResponse],
|
||||
using_tool_calling_llm: bool,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
prompt_config: PromptConfig,
|
||||
) -> AnswerPromptBuilder:
|
||||
if not using_tool_calling_llm:
|
||||
final_context_docs_response = next(
|
||||
response
|
||||
for response in tool_responses
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS_ID
|
||||
)
|
||||
final_context_documents = cast(
|
||||
list[LlmDoc], final_context_docs_response.response
|
||||
)
|
||||
else:
|
||||
# if using tool calling llm, then the final context documents are the tool responses
|
||||
final_context_documents = []
|
||||
|
||||
if answer_style_config.citation_config:
|
||||
prompt_builder.update_system_prompt(
|
||||
build_citations_system_message(prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
build_citations_user_message(
|
||||
message=prompt_builder.user_message_and_token_cnt[0],
|
||||
prompt_config=prompt_config,
|
||||
context_docs=final_context_documents,
|
||||
all_doc_useful=(
|
||||
answer_style_config.citation_config.all_docs_useful
|
||||
if answer_style_config.citation_config
|
||||
else False
|
||||
),
|
||||
history_message=prompt_builder.single_message_history or "",
|
||||
)
|
||||
)
|
||||
elif answer_style_config.quotes_config:
|
||||
prompt_builder.update_user_prompt(
|
||||
build_quotes_user_message(
|
||||
message=prompt_builder.user_message_and_token_cnt[0],
|
||||
context_docs=final_context_documents,
|
||||
history_str=prompt_builder.single_message_history or "",
|
||||
prompt=prompt_config,
|
||||
)
|
||||
)
|
||||
|
||||
if using_tool_calling_llm:
|
||||
prompt_builder.append_message(tool_call_summary.tool_call_request)
|
||||
prompt_builder.append_message(tool_call_summary.tool_call_result)
|
||||
|
||||
return prompt_builder
|
@@ -6,8 +6,8 @@ from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
|
||||
|
@@ -12,7 +12,7 @@ from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompts_by_ids
|
||||
from danswer.one_shot_answer.models import PersonaConfig
|
||||
from danswer.tools.custom.custom_tool import (
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
|
||||
|
@@ -142,6 +142,9 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
|
||||
# make sure there is an answer
|
||||
assert response_json["answer"]
|
||||
|
||||
# since we only gave it one search doc, all responses should only contain that doc
|
||||
assert response_json["final_context_doc_indices"] == [0]
|
||||
assert response_json["llm_selected_doc_indices"] == [0]
|
||||
|
113 backend/tests/unit/danswer/llm/answering/conftest.py Normal file
@@ -0,0 +1,113 @@
import json
from datetime import datetime
from unittest.mock import MagicMock

import pytest
from langchain_core.messages import SystemMessage

from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLMConfig
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool_implementations.search_like_tool_utils import (
    FINAL_CONTEXT_DOCUMENTS_ID,
)

QUERY = "Test question"
DEFAULT_SEARCH_ARGS = {"query": "search"}


@pytest.fixture
def answer_style_config() -> AnswerStyleConfig:
    return AnswerStyleConfig(citation_config=CitationConfig())


@pytest.fixture
def prompt_config() -> PromptConfig:
    return PromptConfig(
        system_prompt="System prompt",
        task_prompt="Task prompt",
        datetime_aware=False,
        include_citations=True,
    )


@pytest.fixture
def mock_llm() -> MagicMock:
    mock_llm_obj = MagicMock()
    mock_llm_obj.config = LLMConfig(
        model_provider="openai",
        model_name="gpt-4o",
        temperature=0.0,
        api_key=None,
        api_base=None,
        api_version=None,
    )
    return mock_llm_obj


@pytest.fixture
def mock_search_results() -> list[LlmDoc]:
    return [
        LlmDoc(
            content="Search result 1",
            source_type=DocumentSource.WEB,
            metadata={"id": "doc1"},
            document_id="doc1",
            blurb="Blurb 1",
            semantic_identifier="Semantic ID 1",
            updated_at=datetime(2023, 1, 1),
            link="https://example.com/doc1",
            source_links={0: "https://example.com/doc1"},
        ),
        LlmDoc(
            content="Search result 2",
            source_type=DocumentSource.WEB,
            metadata={"id": "doc2"},
            document_id="doc2",
            blurb="Blurb 2",
            semantic_identifier="Semantic ID 2",
            updated_at=datetime(2023, 1, 2),
            link="https://example.com/doc2",
            source_links={0: "https://example.com/doc2"},
        ),
    ]


@pytest.fixture
def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
    mock_tool = MagicMock(spec=SearchTool)
    mock_tool.name = "search"
    mock_tool.build_tool_message_content.return_value = "search_response"
    mock_tool.get_args_for_non_tool_calling_llm.return_value = DEFAULT_SEARCH_ARGS
    mock_tool.final_result.return_value = [
        json.loads(doc.model_dump_json()) for doc in mock_search_results
    ]
    mock_tool.run.return_value = [
        ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results)
    ]
    mock_tool.tool_definition.return_value = {
        "type": "function",
        "function": {
            "name": "search",
            "description": "Search for information",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "The search query"},
                },
                "required": ["query"],
            },
        },
    }
    mock_post_search_tool_prompt_builder = MagicMock(spec=AnswerPromptBuilder)
    mock_post_search_tool_prompt_builder.build.return_value = [
        SystemMessage(content="Updated system prompt"),
    ]
    mock_tool.build_next_prompt.return_value = mock_post_search_tool_prompt_builder
    return mock_tool
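For orientation, a minimal sketch of how these fixtures compose in a test (illustrative only; it assumes the fixtures above and the Answer constructor arguments used elsewhere in this PR, and is not part of the changed files):

# Illustrative sketch -- assumes the conftest fixtures above and the Answer API used in this PR.
from unittest.mock import MagicMock

from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PromptConfig
from danswer.tools.force import ForceUseTool


def test_fixture_composition_sketch(
    mock_llm: MagicMock,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
    mock_search_tool: MagicMock,
) -> None:
    answer = Answer(
        question="Test question",
        answer_style_config=answer_style_config,
        llm=mock_llm,
        prompt_config=prompt_config,
        force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
        tools=[mock_search_tool],
    )
    # Consuming the stream drives the mocked LLM (and the search tool, if the
    # LLM emits a tool call).
    _ = list(answer.processed_streamed_output)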
@@ -7,7 +7,7 @@ from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.llm.answering.stream_processing.citation_processing import (
    extract_citations_from_stream,
    CitationProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping

@@ -70,14 +70,16 @@ def process_text(
) -> tuple[str, list[CitationInfo]]:
    mock_docs, mock_doc_id_to_rank_map = mock_data
    mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
    result = list(
        extract_citations_from_stream(
            tokens=iter(tokens),
            context_docs=mock_docs,
            doc_id_to_rank_map=mapping,
            stop_stream=None,
        )
    processor = CitationProcessor(
        context_docs=mock_docs,
        doc_id_to_rank_map=mapping,
        stop_stream=None,
    )
    result: list[DanswerAnswerPiece | CitationInfo] = []
    for token in tokens:
        result.extend(processor.process_token(token))
    result.extend(processor.process_token(None))

    final_answer_text = ""
    citations = []
    for piece in result:
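The change above swaps the one-shot extract_citations_from_stream generator for the incremental CitationProcessor; a hedged sketch of the new calling pattern, reusing the names from the hunk above with illustrative tokens:

# Illustrative sketch of the CitationProcessor calling convention shown above;
# constructor arguments are taken from the diff, tokens are made up.
processor = CitationProcessor(
    context_docs=mock_docs,
    doc_id_to_rank_map=mapping,
    stop_stream=None,
)
pieces: list[DanswerAnswerPiece | CitationInfo] = []
for token in ["The answer is abc", "[1]", "."]:
    pieces.extend(processor.process_token(token))
# A trailing None flushes anything the processor is still buffering.
pieces.extend(processor.process_token(None))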
@@ -6,7 +6,7 @@ from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.llm.answering.stream_processing.quotes_processing import (
    process_model_tokens,
    QuotesProcessor,
)

mock_docs = [
@@ -25,179 +25,202 @@ mock_docs = [
|
||||
]
|
||||
|
||||
|
||||
tokens_with_quotes = [
|
||||
"{",
|
||||
"\n ",
|
||||
'"answer": "Yes',
|
||||
", Danswer allows",
|
||||
" customized prompts. This",
|
||||
" feature",
|
||||
" is currently being",
|
||||
" developed and implemente",
|
||||
"d to",
|
||||
" improve",
|
||||
" the accuracy",
|
||||
" of",
|
||||
" Language",
|
||||
" Models (",
|
||||
"LL",
|
||||
"Ms) for",
|
||||
" different",
|
||||
" companies",
|
||||
".",
|
||||
" The custom",
|
||||
"ized prompts feature",
|
||||
" woul",
|
||||
"d allow users to ad",
|
||||
"d person",
|
||||
"alized prom",
|
||||
"pts through",
|
||||
" an",
|
||||
" interface or",
|
||||
" metho",
|
||||
"d,",
|
||||
" which would then be used to",
|
||||
" train",
|
||||
" the LLM.",
|
||||
" This enhancement",
|
||||
" aims to make",
|
||||
" Danswer more",
|
||||
" adaptable to",
|
||||
" different",
|
||||
" business",
|
||||
" contexts",
|
||||
" by",
|
||||
" tail",
|
||||
"oring it",
|
||||
" to the specific language",
|
||||
" an",
|
||||
"d terminology",
|
||||
" used within",
|
||||
" a",
|
||||
" company.",
|
||||
" Additionally",
|
||||
",",
|
||||
" Danswer already",
|
||||
" supports creating",
|
||||
" custom AI",
|
||||
" Assistants with",
|
||||
" different",
|
||||
" prom",
|
||||
"pts and backing",
|
||||
" knowledge",
|
||||
" sets",
|
||||
",",
|
||||
" which",
|
||||
" is",
|
||||
" a form",
|
||||
" of prompt",
|
||||
" customization. However, it",
|
||||
"'s important to nLogging Details LiteLLM-Success Call: Noneote that some",
|
||||
" aspects",
|
||||
" of prompt",
|
||||
" customization,",
|
||||
" such as for",
|
||||
" Sl",
|
||||
"ack",
|
||||
"b",
|
||||
"ots, may",
|
||||
" still",
|
||||
" be in",
|
||||
" development or have",
|
||||
' limitations.",',
|
||||
'\n "quotes": [',
|
||||
'\n "We',
|
||||
" woul",
|
||||
"d like to ad",
|
||||
"d customized prompts for",
|
||||
" different",
|
||||
" companies to improve the accuracy of",
|
||||
" Language",
|
||||
" Model",
|
||||
" (LLM)",
|
||||
'.",\n "A',
|
||||
" new",
|
||||
" feature that",
|
||||
" allows users to add personalize",
|
||||
"d prompts.",
|
||||
" This would involve",
|
||||
" creating",
|
||||
" an interface or method for",
|
||||
" users to input",
|
||||
" their",
|
||||
" own",
|
||||
" prom",
|
||||
"pts,",
|
||||
" which would then be used to",
|
||||
' train the LLM.",',
|
||||
'\n "Create',
|
||||
" custom AI Assistants with",
|
||||
" different prompts and backing knowledge",
|
||||
' sets.",',
|
||||
'\n "This',
|
||||
" PR",
|
||||
" fixes",
|
||||
" https",
|
||||
"://github.com/dan",
|
||||
"swer-ai/dan",
|
||||
"swer/issues/1",
|
||||
"584",
|
||||
" by",
|
||||
" setting",
|
||||
" the system",
|
||||
" default",
|
||||
" prompt for",
|
||||
" sl",
|
||||
"ackbots const",
|
||||
"rained by",
|
||||
" ",
|
||||
"document sets",
|
||||
".",
|
||||
" It",
|
||||
" probably",
|
||||
" isn",
|
||||
"'t ideal",
|
||||
" -",
|
||||
" it",
|
||||
" might",
|
||||
" be pref",
|
||||
"erable to be",
|
||||
" able to select",
|
||||
" a prompt for",
|
||||
" the",
|
||||
" slackbot from",
|
||||
" the",
|
||||
" admin",
|
||||
" panel",
|
||||
" -",
|
||||
" but it sol",
|
||||
"ves the immediate problem",
|
||||
" of",
|
||||
" the slack",
|
||||
" listener",
|
||||
" cr",
|
||||
"ashing when",
|
||||
" configure",
|
||||
"d this",
|
||||
' way."\n ]',
|
||||
"\n}",
|
||||
"",
|
||||
]
|
||||
def _process_tokens(
|
||||
processor: QuotesProcessor, tokens: list[str]
|
||||
) -> tuple[str, list[str]]:
|
||||
"""Process a list of tokens and return the answer and quotes.
|
||||
|
||||
Args:
|
||||
processor: QuotesProcessor instance
|
||||
tokens: List of tokens to process
|
||||
|
||||
Returns:
|
||||
Tuple of (answer_text, list_of_quotes)
|
||||
"""
|
||||
answer = ""
|
||||
quotes: list[str] = []
|
||||
|
||||
# need to add a None to the end to simulate the end of the stream
|
||||
for token in tokens + [None]:
|
||||
for output in processor.process_token(token):
|
||||
if isinstance(output, DanswerAnswerPiece):
|
||||
if output.answer_piece:
|
||||
answer += output.answer_piece
|
||||
elif isinstance(output, DanswerQuotes):
|
||||
quotes.extend(q.quote for q in output.quotes)
|
||||
|
||||
return answer, quotes
|
||||
|
||||
|
||||
def test_process_model_tokens_answer() -> None:
|
||||
gen = process_model_tokens(tokens=iter(tokens_with_quotes), context_docs=mock_docs)
|
||||
tokens_with_quotes = [
|
||||
"{",
|
||||
"\n ",
|
||||
'"answer": "Yes',
|
||||
", Danswer allows",
|
||||
" customized prompts. This",
|
||||
" feature",
|
||||
" is currently being",
|
||||
" developed and implemente",
|
||||
"d to",
|
||||
" improve",
|
||||
" the accuracy",
|
||||
" of",
|
||||
" Language",
|
||||
" Models (",
|
||||
"LL",
|
||||
"Ms) for",
|
||||
" different",
|
||||
" companies",
|
||||
".",
|
||||
" The custom",
|
||||
"ized prompts feature",
|
||||
" woul",
|
||||
"d allow users to ad",
|
||||
"d person",
|
||||
"alized prom",
|
||||
"pts through",
|
||||
" an",
|
||||
" interface or",
|
||||
" metho",
|
||||
"d,",
|
||||
" which would then be used to",
|
||||
" train",
|
||||
" the LLM.",
|
||||
" This enhancement",
|
||||
" aims to make",
|
||||
" Danswer more",
|
||||
" adaptable to",
|
||||
" different",
|
||||
" business",
|
||||
" contexts",
|
||||
" by",
|
||||
" tail",
|
||||
"oring it",
|
||||
" to the specific language",
|
||||
" an",
|
||||
"d terminology",
|
||||
" used within",
|
||||
" a",
|
||||
" company.",
|
||||
" Additionally",
|
||||
",",
|
||||
" Danswer already",
|
||||
" supports creating",
|
||||
" custom AI",
|
||||
" Assistants with",
|
||||
" different",
|
||||
" prom",
|
||||
"pts and backing",
|
||||
" knowledge",
|
||||
" sets",
|
||||
",",
|
||||
" which",
|
||||
" is",
|
||||
" a form",
|
||||
" of prompt",
|
||||
" customization. However, it",
|
||||
"'s important to nLogging Details LiteLLM-Success Call: Noneote that some",
|
||||
" aspects",
|
||||
" of prompt",
|
||||
" customization,",
|
||||
" such as for",
|
||||
" Sl",
|
||||
"ack",
|
||||
"b",
|
||||
"ots, may",
|
||||
" still",
|
||||
" be in",
|
||||
" development or have",
|
||||
' limitations.",',
|
||||
'\n "quotes": [',
|
||||
'\n "We',
|
||||
" woul",
|
||||
"d like to ad",
|
||||
"d customized prompts for",
|
||||
" different",
|
||||
" companies to improve the accuracy of",
|
||||
" Language",
|
||||
" Model",
|
||||
" (LLM)",
|
||||
'.",\n "A',
|
||||
" new",
|
||||
" feature that",
|
||||
" allows users to add personalize",
|
||||
"d prompts.",
|
||||
" This would involve",
|
||||
" creating",
|
||||
" an interface or method for",
|
||||
" users to input",
|
||||
" their",
|
||||
" own",
|
||||
" prom",
|
||||
"pts,",
|
||||
" which would then be used to",
|
||||
' train the LLM.",',
|
||||
'\n "Create',
|
||||
" custom AI Assistants with",
|
||||
" different prompts and backing knowledge",
|
||||
' sets.",',
|
||||
'\n "This',
|
||||
" PR",
|
||||
" fixes",
|
||||
" https",
|
||||
"://github.com/dan",
|
||||
"swer-ai/dan",
|
||||
"swer/issues/1",
|
||||
"584",
|
||||
" by",
|
||||
" setting",
|
||||
" the system",
|
||||
" default",
|
||||
" prompt for",
|
||||
" sl",
|
||||
"ackbots const",
|
||||
"rained by",
|
||||
" ",
|
||||
"document sets",
|
||||
".",
|
||||
" It",
|
||||
" probably",
|
||||
" isn",
|
||||
"'t ideal",
|
||||
" -",
|
||||
" it",
|
||||
" might",
|
||||
" be pref",
|
||||
"erable to be",
|
||||
" able to select",
|
||||
" a prompt for",
|
||||
" the",
|
||||
" slackbot from",
|
||||
" the",
|
||||
" admin",
|
||||
" panel",
|
||||
" -",
|
||||
" but it sol",
|
||||
"ves the immediate problem",
|
||||
" of",
|
||||
" the slack",
|
||||
" listener",
|
||||
" cr",
|
||||
"ashing when",
|
||||
" configure",
|
||||
"d this",
|
||||
' way."\n ]',
|
||||
"\n}",
|
||||
"",
|
||||
]
|
||||
|
||||
processor = QuotesProcessor(context_docs=mock_docs)
|
||||
answer, quotes = _process_tokens(processor, tokens_with_quotes)
|
||||
|
||||
s_json = "".join(tokens_with_quotes)
|
||||
j = json.loads(s_json)
|
||||
expected_answer = j["answer"]
|
||||
actual = ""
|
||||
for o in gen:
|
||||
if isinstance(o, DanswerAnswerPiece):
|
||||
if o.answer_piece:
|
||||
actual += o.answer_piece
|
||||
|
||||
assert expected_answer == actual
|
||||
assert expected_answer == answer
|
||||
# NOTE: no quotes, since the docs don't match the quotes
|
||||
assert len(quotes) == 0
|
||||
|
||||
|
||||
def test_simple_json_answer() -> None:
|
||||
@@ -214,16 +237,11 @@ def test_simple_json_answer() -> None:
        "\n",
        "```",
    ]
    gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs)
    processor = QuotesProcessor(context_docs=mock_docs)
    answer, quotes = _process_tokens(processor, tokens)

    expected_answer = "This is a simple answer."
    actual = "".join(
        o.answer_piece
        for o in gen
        if isinstance(o, DanswerAnswerPiece) and o.answer_piece
    )

    assert expected_answer == actual
    assert "This is a simple answer." == answer
    assert len(quotes) == 0


def test_json_answer_with_quotes() -> None:
@@ -242,16 +260,21 @@ def test_json_answer_with_quotes() -> None:
        "\n",
        "```",
    ]
    gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs)
    processor = QuotesProcessor(context_docs=mock_docs)
    answer, quotes = _process_tokens(processor, tokens)

    expected_answer = "This is a split answer."
    actual = "".join(
        o.answer_piece
        for o in gen
        if isinstance(o, DanswerAnswerPiece) and o.answer_piece
    )
    assert "This is a split answer." == answer
    assert len(quotes) == 0

    assert expected_answer == actual


def test_json_answer_with_quotes_one_chunk() -> None:
    tokens = ['```json\n{"answer": "z",\n"quotes": ["Document"]\n}\n```']
    processor = QuotesProcessor(context_docs=mock_docs)
    answer, quotes = _process_tokens(processor, tokens)

    assert "z" == answer
    assert len(quotes) == 1
    assert quotes[0] == "Document"


def test_json_answer_split_tokens() -> None:
@@ -271,16 +294,11 @@ def test_json_answer_split_tokens() -> None:
        "\n",
        "```",
    ]
    gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs)
    processor = QuotesProcessor(context_docs=mock_docs)
    answer, quotes = _process_tokens(processor, tokens)

    expected_answer = "This is a split answer."
    actual = "".join(
        o.answer_piece
        for o in gen
        if isinstance(o, DanswerAnswerPiece) and o.answer_piece
    )

    assert expected_answer == actual
    assert "This is a split answer." == answer
    assert len(quotes) == 0


def test_lengthy_prefixed_json_with_quotes() -> None:
@@ -298,23 +316,12 @@ def test_lengthy_prefixed_json_with_quotes() -> None:
        "\n",
        "```",
    ]
    processor = QuotesProcessor(context_docs=mock_docs)
    answer, quotes = _process_tokens(processor, tokens)

    gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs)

    actual_answer = ""
    actual_count = 0
    for o in gen:
        if isinstance(o, DanswerAnswerPiece):
            if o.answer_piece:
                actual_answer += o.answer_piece
            continue

        if isinstance(o, DanswerQuotes):
            for q in o.quotes:
                assert q.quote == "Document"
                actual_count += 1
    assert "This is a simple answer." == actual_answer
    assert 1 == actual_count
    assert "This is a simple answer." == answer
    assert len(quotes) == 1
    assert quotes[0] == "Document"


def test_prefixed_json_with_quotes() -> None:
@@ -331,21 +338,9 @@ def test_prefixed_json_with_quotes() -> None:
        "\n",
        "```",
    ]
    processor = QuotesProcessor(context_docs=mock_docs)
    answer, quotes = _process_tokens(processor, tokens)

    gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs)

    actual_answer = ""
    actual_count = 0
    for o in gen:
        if isinstance(o, DanswerAnswerPiece):
            if o.answer_piece:
                actual_answer += o.answer_piece
            continue

        if isinstance(o, DanswerQuotes):
            for q in o.quotes:
                assert q.quote == "Document"
                actual_count += 1

    assert "This is a simple answer." == actual_answer
    assert 1 == actual_count
    assert "This is a simple answer." == answer
    assert len(quotes) == 1
    assert quotes[0] == "Document"
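The rewritten tests above all funnel through the _process_tokens helper introduced earlier in this file; sketched with illustrative tokens, and assuming the process_token(None) end-of-stream convention shown in this diff, the pattern is:

# Illustrative sketch of the QuotesProcessor driving pattern used by _process_tokens;
# the tokens here are made up.
processor = QuotesProcessor(context_docs=mock_docs)
outputs = []
for token in ['{"answer": "z", ', '"quotes": ["Document"]}', None]:
    outputs.extend(processor.process_token(token))
# outputs holds DanswerAnswerPiece chunks and, once the JSON closes, a DanswerQuotes.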
405
backend/tests/unit/danswer/llm/answering/test_answer.py
Normal file
@@ -0,0 +1,405 @@
|
||||
import json
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.messages import ToolCallChunk
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.llm.answering.answer import Answer
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.models import QuotesConfig
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from tests.unit.danswer.llm.answering.conftest import DEFAULT_SEARCH_ARGS
|
||||
from tests.unit.danswer.llm.answering.conftest import QUERY
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def answer_instance(
|
||||
mock_llm: LLM, answer_style_config: AnswerStyleConfig, prompt_config: PromptConfig
|
||||
) -> Answer:
|
||||
return Answer(
|
||||
question=QUERY,
|
||||
answer_style_config=answer_style_config,
|
||||
llm=mock_llm,
|
||||
prompt_config=prompt_config,
|
||||
force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
|
||||
)
|
||||
|
||||
|
||||
def test_basic_answer(answer_instance: Answer) -> None:
|
||||
mock_llm = cast(Mock, answer_instance.llm)
|
||||
mock_llm.stream.return_value = [
|
||||
AIMessageChunk(content="This is a "),
|
||||
AIMessageChunk(content="mock answer."),
|
||||
]
|
||||
|
||||
output = list(answer_instance.processed_streamed_output)
|
||||
assert len(output) == 2
|
||||
assert isinstance(output[0], DanswerAnswerPiece)
|
||||
assert isinstance(output[1], DanswerAnswerPiece)
|
||||
|
||||
full_answer = "".join(
|
||||
piece.answer_piece
|
||||
for piece in output
|
||||
if isinstance(piece, DanswerAnswerPiece) and piece.answer_piece is not None
|
||||
)
|
||||
assert full_answer == "This is a mock answer."
|
||||
|
||||
assert answer_instance.llm_answer == "This is a mock answer."
|
||||
assert answer_instance.citations == []
|
||||
|
||||
assert mock_llm.stream.call_count == 1
|
||||
mock_llm.stream.assert_called_once_with(
|
||||
prompt=[
|
||||
SystemMessage(content="System prompt"),
|
||||
HumanMessage(content="Task prompt\n\nQUERY:\nTest question"),
|
||||
],
|
||||
tools=None,
|
||||
tool_choice=None,
|
||||
structured_response_format=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"force_use_tool, expected_tool_args",
|
||||
[
|
||||
(
|
||||
ForceUseTool(force_use=False, tool_name="", args=None),
|
||||
DEFAULT_SEARCH_ARGS,
|
||||
),
|
||||
(
|
||||
ForceUseTool(
|
||||
force_use=True, tool_name="search", args={"query": "forced search"}
|
||||
),
|
||||
{"query": "forced search"},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_answer_with_search_call(
|
||||
answer_instance: Answer,
|
||||
mock_search_results: list[LlmDoc],
|
||||
mock_search_tool: MagicMock,
|
||||
force_use_tool: ForceUseTool,
|
||||
expected_tool_args: dict,
|
||||
) -> None:
|
||||
answer_instance.tools = [mock_search_tool]
|
||||
answer_instance.force_use_tool = force_use_tool
|
||||
|
||||
# Set up the LLM mock to return search results and then an answer
|
||||
mock_llm = cast(Mock, answer_instance.llm)
|
||||
|
||||
stream_side_effect: list[list[BaseMessage]] = []
|
||||
|
||||
if not force_use_tool.force_use:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
tool_call_chunk.tool_calls = [
|
||||
ToolCall(
|
||||
id="search",
|
||||
name="search",
|
||||
args=expected_tool_args,
|
||||
)
|
||||
]
|
||||
tool_call_chunk.tool_call_chunks = [
|
||||
ToolCallChunk(
|
||||
id="search",
|
||||
name="search",
|
||||
args=json.dumps(expected_tool_args),
|
||||
index=0,
|
||||
)
|
||||
]
|
||||
stream_side_effect.append([tool_call_chunk])
|
||||
|
||||
stream_side_effect.append(
|
||||
[
|
||||
AIMessageChunk(content="Based on the search results, "),
|
||||
AIMessageChunk(content="the answer is abc[1]. "),
|
||||
AIMessageChunk(content="This is some other stuff."),
|
||||
],
|
||||
)
|
||||
mock_llm.stream.side_effect = stream_side_effect
|
||||
|
||||
# Process the output
|
||||
output = list(answer_instance.processed_streamed_output)
|
||||
print(output)
|
||||
|
||||
# Updated assertions
|
||||
assert len(output) == 7
|
||||
assert output[0] == ToolCallKickoff(
|
||||
tool_name="search", tool_args=expected_tool_args
|
||||
)
|
||||
assert output[1] == ToolResponse(
|
||||
id="final_context_documents",
|
||||
response=mock_search_results,
|
||||
)
|
||||
assert output[2] == ToolCallFinalResult(
|
||||
tool_name="search",
|
||||
tool_args=expected_tool_args,
|
||||
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
|
||||
)
|
||||
assert output[3] == DanswerAnswerPiece(answer_piece="Based on the search results, ")
|
||||
expected_citation = CitationInfo(citation_num=1, document_id="doc1")
|
||||
assert output[4] == expected_citation
|
||||
assert output[5] == DanswerAnswerPiece(
|
||||
answer_piece="the answer is abc[[1]](https://example.com/doc1). "
|
||||
)
|
||||
assert output[6] == DanswerAnswerPiece(answer_piece="This is some other stuff.")
|
||||
|
||||
expected_answer = (
|
||||
"Based on the search results, "
|
||||
"the answer is abc[[1]](https://example.com/doc1). "
|
||||
"This is some other stuff."
|
||||
)
|
||||
full_answer = "".join(
|
||||
piece.answer_piece
|
||||
for piece in output
|
||||
if isinstance(piece, DanswerAnswerPiece) and piece.answer_piece is not None
|
||||
)
|
||||
assert full_answer == expected_answer
|
||||
|
||||
assert answer_instance.llm_answer == expected_answer
|
||||
assert len(answer_instance.citations) == 1
|
||||
assert answer_instance.citations[0] == expected_citation
|
||||
|
||||
# Verify LLM calls
|
||||
if not force_use_tool.force_use:
|
||||
assert mock_llm.stream.call_count == 2
|
||||
first_call, second_call = mock_llm.stream.call_args_list
|
||||
|
||||
# First call should include the search tool definition
|
||||
assert len(first_call.kwargs["tools"]) == 1
|
||||
assert (
|
||||
first_call.kwargs["tools"][0]
|
||||
== mock_search_tool.tool_definition.return_value
|
||||
)
|
||||
|
||||
# Second call should not include tools (as we're just generating the final answer)
|
||||
assert "tools" not in second_call.kwargs or not second_call.kwargs["tools"]
|
||||
# Second call should use the returned prompt from build_next_prompt
|
||||
assert (
|
||||
second_call.kwargs["prompt"]
|
||||
== mock_search_tool.build_next_prompt.return_value.build.return_value
|
||||
)
|
||||
|
||||
# Verify that tool_definition was called on the mock_search_tool
|
||||
mock_search_tool.tool_definition.assert_called_once()
|
||||
else:
|
||||
assert mock_llm.stream.call_count == 1
|
||||
|
||||
call = mock_llm.stream.call_args_list[0]
|
||||
assert (
|
||||
call.kwargs["prompt"]
|
||||
== mock_search_tool.build_next_prompt.return_value.build.return_value
|
||||
)
|
||||
|
||||
|
||||
def test_answer_with_search_no_tool_calling(
|
||||
answer_instance: Answer,
|
||||
mock_search_results: list[LlmDoc],
|
||||
mock_search_tool: MagicMock,
|
||||
) -> None:
|
||||
answer_instance.tools = [mock_search_tool]
|
||||
|
||||
# Set up the LLM mock to return an answer
|
||||
mock_llm = cast(Mock, answer_instance.llm)
|
||||
mock_llm.stream.return_value = [
|
||||
AIMessageChunk(content="Based on the search results, "),
|
||||
AIMessageChunk(content="the answer is abc[1]. "),
|
||||
AIMessageChunk(content="This is some other stuff."),
|
||||
]
|
||||
|
||||
# Force non-tool calling behavior
|
||||
answer_instance.using_tool_calling_llm = False
|
||||
|
||||
# Process the output
|
||||
output = list(answer_instance.processed_streamed_output)
|
||||
|
||||
# Assertions
|
||||
assert len(output) == 7
|
||||
assert output[0] == ToolCallKickoff(
|
||||
tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
|
||||
)
|
||||
assert output[1] == ToolResponse(
|
||||
id="final_context_documents",
|
||||
response=mock_search_results,
|
||||
)
|
||||
assert output[2] == ToolCallFinalResult(
|
||||
tool_name="search",
|
||||
tool_args=DEFAULT_SEARCH_ARGS,
|
||||
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
|
||||
)
|
||||
assert output[3] == DanswerAnswerPiece(answer_piece="Based on the search results, ")
|
||||
expected_citation = CitationInfo(citation_num=1, document_id="doc1")
|
||||
assert output[4] == expected_citation
|
||||
assert output[5] == DanswerAnswerPiece(
|
||||
answer_piece="the answer is abc[[1]](https://example.com/doc1). "
|
||||
)
|
||||
assert output[6] == DanswerAnswerPiece(answer_piece="This is some other stuff.")
|
||||
|
||||
expected_answer = (
|
||||
"Based on the search results, "
|
||||
"the answer is abc[[1]](https://example.com/doc1). "
|
||||
"This is some other stuff."
|
||||
)
|
||||
assert answer_instance.llm_answer == expected_answer
|
||||
assert len(answer_instance.citations) == 1
|
||||
assert answer_instance.citations[0] == expected_citation
|
||||
|
||||
# Verify LLM calls
|
||||
assert mock_llm.stream.call_count == 1
|
||||
call_args = mock_llm.stream.call_args
|
||||
|
||||
# Verify that no tools were passed to the LLM
|
||||
assert "tools" not in call_args.kwargs or not call_args.kwargs["tools"]
|
||||
|
||||
# Verify that the prompt was built correctly
|
||||
assert (
|
||||
call_args.kwargs["prompt"]
|
||||
== mock_search_tool.build_next_prompt.return_value.build.return_value
|
||||
)
|
||||
|
||||
# Verify that get_args_for_non_tool_calling_llm was called on the mock_search_tool
|
||||
mock_search_tool.get_args_for_non_tool_calling_llm.assert_called_once_with(
|
||||
f"Task prompt\n\nQUERY:\n{QUERY}", [], answer_instance.llm
|
||||
)
|
||||
|
||||
# Verify that the search tool's run method was called
|
||||
mock_search_tool.run.assert_called_once()
|
||||
|
||||
|
||||
def test_answer_with_search_call_quotes_enabled(
|
||||
answer_instance: Answer,
|
||||
mock_search_results: list[LlmDoc],
|
||||
mock_search_tool: MagicMock,
|
||||
) -> None:
|
||||
answer_instance.tools = [mock_search_tool]
|
||||
answer_instance.force_use_tool = ForceUseTool(
|
||||
force_use=False, tool_name="", args=None
|
||||
)
|
||||
answer_instance.answer_style_config.citation_config = None
|
||||
answer_instance.answer_style_config.quotes_config = QuotesConfig()
|
||||
|
||||
# Set up the LLM mock to return search results and then an answer
|
||||
mock_llm = cast(Mock, answer_instance.llm)
|
||||
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
tool_call_chunk.tool_calls = [
|
||||
ToolCall(
|
||||
id="search",
|
||||
name="search",
|
||||
args=DEFAULT_SEARCH_ARGS,
|
||||
)
|
||||
]
|
||||
tool_call_chunk.tool_call_chunks = [
|
||||
ToolCallChunk(
|
||||
id="search",
|
||||
name="search",
|
||||
args=json.dumps(DEFAULT_SEARCH_ARGS),
|
||||
index=0,
|
||||
)
|
||||
]
|
||||
|
||||
# needs to be short due to the "anti-hallucination" check in QuotesProcessor
|
||||
answer_content = "z"
|
||||
quote_content = mock_search_results[0].content
|
||||
mock_llm.stream.side_effect = [
|
||||
[tool_call_chunk],
|
||||
[
|
||||
AIMessageChunk(
|
||||
content=(
|
||||
'{"answer": "'
|
||||
+ answer_content
|
||||
+ '", "quotes": ["'
|
||||
+ quote_content
|
||||
+ '"]}'
|
||||
)
|
||||
),
|
||||
],
|
||||
]
|
||||
|
||||
# Process the output
|
||||
output = list(answer_instance.processed_streamed_output)
|
||||
|
||||
# Assertions
|
||||
assert len(output) == 5
|
||||
assert output[0] == ToolCallKickoff(
|
||||
tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
|
||||
)
|
||||
assert output[1] == ToolResponse(
|
||||
id="final_context_documents",
|
||||
response=mock_search_results,
|
||||
)
|
||||
assert output[2] == ToolCallFinalResult(
|
||||
tool_name="search",
|
||||
tool_args=DEFAULT_SEARCH_ARGS,
|
||||
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
|
||||
)
|
||||
assert output[3] == DanswerAnswerPiece(answer_piece=answer_content)
|
||||
assert output[4] == DanswerQuotes(
|
||||
quotes=[
|
||||
DanswerQuote(
|
||||
quote=quote_content,
|
||||
document_id=mock_search_results[0].document_id,
|
||||
link=mock_search_results[0].link,
|
||||
source_type=mock_search_results[0].source_type,
|
||||
semantic_identifier=mock_search_results[0].semantic_identifier,
|
||||
blurb=mock_search_results[0].blurb,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert answer_instance.llm_answer == answer_content
|
||||
|
||||
|
||||
def test_is_cancelled(answer_instance: Answer) -> None:
|
||||
# Set up the LLM mock to return multiple chunks
|
||||
mock_llm = Mock()
|
||||
answer_instance.llm = mock_llm
|
||||
mock_llm.stream.return_value = [
|
||||
AIMessageChunk(content="This is the "),
|
||||
AIMessageChunk(content="first part."),
|
||||
AIMessageChunk(content="This should not be seen."),
|
||||
]
|
||||
|
||||
# Create a mutable object to control is_connected behavior
|
||||
connection_status = {"connected": True}
|
||||
answer_instance.is_connected = lambda: connection_status["connected"]
|
||||
|
||||
# Process the output
|
||||
output = []
|
||||
for i, chunk in enumerate(answer_instance.processed_streamed_output):
|
||||
output.append(chunk)
|
||||
# Simulate disconnection after the second chunk
|
||||
if i == 1:
|
||||
connection_status["connected"] = False
|
||||
|
||||
assert len(output) == 3
|
||||
assert output[0] == DanswerAnswerPiece(answer_piece="This is the ")
|
||||
assert output[1] == DanswerAnswerPiece(answer_piece="first part.")
|
||||
assert output[2] == StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
|
||||
# Verify that the stream was cancelled
|
||||
assert answer_instance.is_cancelled() is True
|
||||
|
||||
# Verify that the final answer only contains the streamed parts
|
||||
assert answer_instance.llm_answer == "This is the first part."
|
||||
|
||||
# Verify LLM calls
|
||||
mock_llm.stream.assert_called_once()
|
@@ -6,8 +6,11 @@ import pytest
from pytest_mock import MockerFixture

from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PromptConfig
from danswer.one_shot_answer.answer_question import AnswerObjectIterator
from danswer.tools.force import ForceUseTool
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from tests.regression.answer_quality.run_qa import _process_and_write_query_results


@@ -24,39 +27,43 @@ from tests.regression.answer_quality.run_qa import _process_and_write_query_resu
        },
    ],
)
def test_skip_gen_ai_answer_generation_flag(config: dict[str, Any]) -> None:
    search_tool = Mock()
    search_tool.name = "search"
    search_tool.run = Mock()
    search_tool.run.return_value = [Mock()]
def test_skip_gen_ai_answer_generation_flag(
    config: dict[str, Any],
    mock_search_tool: SearchTool,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
) -> None:
    question = config["question"]
    skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]

    mock_llm = Mock()
    mock_llm.config = Mock()
    mock_llm.config.model_name = "gpt-4o-mini"
    mock_llm.stream = Mock()
    mock_llm.stream.return_value = [Mock()]
    answer = Answer(
        question=config["question"],
        answer_style_config=Mock(),
        prompt_config=Mock(),
        question=question,
        answer_style_config=answer_style_config,
        prompt_config=prompt_config,
        llm=mock_llm,
        single_message_history="history",
        tools=[search_tool],
        tools=[mock_search_tool],
        force_use_tool=(
            ForceUseTool(
                tool_name=search_tool.name,
                args={"query": config["question"]},
                tool_name=mock_search_tool.name,
                args={"query": question},
                force_use=True,
            )
        ),
        skip_explicit_tool_calling=True,
        return_contexts=True,
        skip_gen_ai_answer_generation=config["skip_gen_ai_answer_generation"],
        skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
    )
    count = 0
    for _ in cast(AnswerObjectIterator, answer.processed_streamed_output):
        count += 1
    assert count == 2
    if not config["skip_gen_ai_answer_generation"]:
    assert count == 3 if skip_gen_ai_answer_generation else 4
    if not skip_gen_ai_answer_generation:
        mock_llm.stream.assert_called_once()
    else:
        mock_llm.stream.assert_not_called()
@@ -5,14 +5,18 @@ from unittest.mock import patch

import pytest

from danswer.tools.custom.custom_tool import (
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.custom.custom_tool import (
    build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.custom.custom_tool import validate_openapi_schema
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_implementations.custom.custom_tool import (
    CUSTOM_TOOL_RESPONSE_ID,
)
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from danswer.tools.tool_implementations.custom.custom_tool import (
    validate_openapi_schema,
)
from danswer.utils.headers import HeaderItemDict


@@ -78,7 +82,7 @@ class TestCustomTool(unittest.TestCase):
            chat_session_id=uuid.uuid4(), message_id=20
        )

    @patch("danswer.tools.custom.custom_tool.requests.request")
    @patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request")
    def test_custom_tool_run_get(self, mock_request: unittest.mock.MagicMock) -> None:
        """
        Test the GET method of a custom tool.
@@ -106,7 +110,7 @@ class TestCustomTool(unittest.TestCase):
            "Tool name in response does not match expected value",
        )

    @patch("danswer.tools.custom.custom_tool.requests.request")
    @patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request")
    def test_custom_tool_run_post(self, mock_request: unittest.mock.MagicMock) -> None:
        """
        Test the POST method of a custom tool.
@@ -136,7 +140,7 @@ class TestCustomTool(unittest.TestCase):
            "Tool name in response does not match expected value",
        )

    @patch("danswer.tools.custom.custom_tool.requests.request")
    @patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request")
    def test_custom_tool_with_headers(
        self, mock_request: unittest.mock.MagicMock
    ) -> None:
@@ -164,7 +168,7 @@ class TestCustomTool(unittest.TestCase):
            "GET", expected_url, json=None, headers=expected_headers
        )

    @patch("danswer.tools.custom.custom_tool.requests.request")
    @patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request")
    def test_custom_tool_with_empty_headers(
        self, mock_request: unittest.mock.MagicMock
    ) -> None:
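The decorator updates above only track the module move; for reference, an equivalent inline form of the new patch target (illustrative values, same requests-based call the decorators patch):

# Illustrative inline equivalent of the updated @patch target above.
import unittest.mock

with unittest.mock.patch(
    "danswer.tools.tool_implementations.custom.custom_tool.requests.request"
) as mock_request:
    mock_request.return_value.json.return_value = {"ok": True}
    # ... exercise the custom tool's run() here, as the tests above do ...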