Internet Search Tool (#1666)

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
This commit is contained in:
rashad-danswer 2024-07-06 18:01:24 -07:00 committed by GitHub
parent e06f8a0a4b
commit 146f85936b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 718 additions and 115 deletions

View File

@ -42,6 +42,11 @@ PYTHONPATH=./backend
PYTHONUNBUFFERED=1
# Internet Search
BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

View File

@ -0,0 +1,23 @@
"""added is_internet to DBDoc
Revision ID: 4505fd7302e1
Revises: c18cdf4b497e
Create Date: 2024-06-18 20:46:09.095034
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4505fd7302e1"
down_revision = "c18cdf4b497e"
def upgrade() -> None:
    """Add ``search_doc.is_internet`` and ``tool.display_name``.

    Both columns are created as nullable, so existing rows need no backfill.
    """
    is_internet_column = sa.Column("is_internet", sa.Boolean(), nullable=True)
    display_name_column = sa.Column("display_name", sa.String(), nullable=True)

    op.add_column("search_doc", is_internet_column)
    op.add_column("tool", display_name_column)
def downgrade() -> None:
    """Drop the columns added by ``upgrade``, in the reverse order."""
    columns_to_drop = (
        ("tool", "display_name"),
        ("search_doc", "is_internet"),
    )
    for table_name, column_name in columns_to_drop:
        op.drop_column(table_name, column_name)

View File

@ -14,6 +14,7 @@ from danswer.chat.models import LlmDoc
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
@ -53,10 +54,13 @@ from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.retrieval.search_runner import inference_documents_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.search.utils import internet_search_response_to_search_docs
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
@ -68,6 +72,11 @@ from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
@ -145,6 +154,37 @@ def _handle_search_tool_response_summary(
)
def _handle_internet_search_tool_response_summary(
    packet: ToolResponse,
    db_session: Session,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
    """Persist the docs from an internet search packet and build the docs response.

    Args:
        packet: Tool response whose ``response`` is an ``InternetSearchResponse``.
        db_session: Session used to create the DB search doc rows.

    Returns:
        A ``(QADocsResponse, db_search_docs)`` pair: the response to stream to
        the client and the DB rows created for the returned documents.
    """
    search_response = cast(InternetSearchResponse, packet.response)

    # Store every result as a DB search doc, preserving result order
    saved_docs: list[DbSearchDoc] = []
    for search_doc in internet_search_response_to_search_docs(search_response):
        saved_docs.append(
            create_db_search_doc(server_search_doc=search_doc, db_session=db_session)
        )

    # Convert the stored rows back into server-side models for the response
    docs_for_client = [
        translate_db_search_doc_to_server_search_doc(saved_doc)
        for saved_doc in saved_docs
    ]

    # Internet search applies no source / time filters and no recency bias
    qa_docs_response = QADocsResponse(
        rephrased_query=search_response.revised_query,
        top_documents=docs_for_client,
        predicted_flow=QueryFlow.QUESTION_ANSWER,
        predicted_search=SearchType.HYBRID,
        applied_source_filters=[],
        applied_time_cutoff=None,
        recency_bias_multiplier=1.0,
    )
    return qa_docs_response, saved_docs
def _check_should_force_search(
new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:
@ -172,7 +212,7 @@ def _check_should_force_search(
args = {"query": new_msg_req.message}
return ForceUseTool(
tool_name=SearchTool.NAME,
tool_name=SearchTool._NAME,
args=args,
)
return None
@ -476,6 +516,15 @@ def stream_chat_message_objects(
additional_headers=litellm_additional_headers,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(api_key=bing_api_key)
]
continue
@ -582,6 +631,15 @@ def stream_chat_message_objects(
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
@ -623,7 +681,7 @@ def stream_chat_message_objects(
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name()] = tool_id
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
message=answer.llm_answer,

View File

@ -78,3 +78,6 @@ STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
# The backend logic for this being True isn't fully supported yet
HARD_DELETE_CHATS = False
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None

View File

@ -104,6 +104,7 @@ class DocumentSource(str, Enum):
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
NOT_APPLICABLE = "not_applicable"
class BlobType(str, Enum):
@ -112,6 +113,9 @@ class BlobType(str, Enum):
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
# Special case, for internet search
NOT_APPLICABLE = "not_applicable"
class DocumentIndexType(str, Enum):
COMBINED = "combined" # Vespa

View File

@ -504,6 +504,7 @@ def create_db_search_doc(
updated_at=server_search_doc.updated_at,
primary_owners=server_search_doc.primary_owners,
secondary_owners=server_search_doc.secondary_owners,
is_internet=server_search_doc.is_internet,
)
db_session.add(db_search_doc)
@ -542,6 +543,7 @@ def translate_db_search_doc_to_server_search_doc(
secondary_owners=(
db_search_doc.secondary_owners if not remove_doc_content else []
),
is_internet=db_search_doc.is_internet,
)

View File

@ -645,6 +645,7 @@ class SearchDoc(Base):
secondary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
is_internet: Mapped[bool] = mapped_column(Boolean, default=False, nullable=True)
chat_messages = relationship(
"ChatMessage",
@ -990,6 +991,7 @@ class Tool(Base):
# ID of the tool in the codebase, only applies for in-code tools.
# tools defined via the UI will have this as None
in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True)
display_name: Mapped[str] = mapped_column(String, nullable=True)
# OpenAPI scheme for the tool. Only applies to tools defined via the UI.
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(

View File

@ -12,6 +12,7 @@ from sqlalchemy import update
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.configs.chat_configs import BING_API_KEY
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet
@ -360,6 +361,10 @@ def upsert_persona(
if not prompts and prompt_ids:
raise ValueError("prompts not found")
# ensure all specified tools are valid
if tools:
validate_persona_tools(tools)
if persona:
if not default_persona and persona.default_persona:
raise ValueError("Cannot update default persona with non-default.")
@ -457,6 +462,14 @@ def update_persona_visibility(
db_session.commit()
def validate_persona_tools(tools: list[Tool]) -> None:
    """Reject tool selections that cannot work with the current configuration.

    Raises:
        ValueError: if the internet search tool is among ``tools`` while no
            Bing API key is configured.
    """
    wants_internet_search = any(
        candidate.name == "InternetSearchTool" for candidate in tools
    )
    if wants_internet_search and not BING_API_KEY:
        raise ValueError(
            "Bing API key not found, please contact your Danswer admin to get it added!"
        )
def check_user_can_edit_persona(user: User | None, persona: Persona) -> None:
# if user is None, assume that no-auth is turned on
if user is None:

View File

@ -43,6 +43,7 @@ from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.images.prompt import build_image_generation_user_prompt
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.message import build_tool_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
@ -58,7 +59,11 @@ from danswer.tools.tool_runner import (
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.tools.tool_runner import ToolRunner
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_answer_stream_processor(
@ -228,7 +233,7 @@ class Answer:
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
tool = [
tool for tool in self.tools if tool.name() == tool_call_request["name"]
tool for tool in self.tools if tool.name == tool_call_request["name"]
][0]
tool_args = (
self.force_use_tool.args
@ -247,15 +252,14 @@ class Answer:
),
)
if tool.name() == SearchTool.NAME:
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name() == ImageGenerationTool.NAME:
elif tool.name == ImageGenerationTool._NAME:
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
)
)
yield tool_runner.tool_final_result()
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
@ -281,7 +285,7 @@ class Answer:
[
tool
for tool in self.tools
if tool.name() == self.force_use_tool.tool_name
if tool.name == self.force_use_tool.tool_name
]
),
None,
@ -301,21 +305,39 @@ class Answer:
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name()}' did not return args")
raise RuntimeError(f"Tool '{tool.name}' did not return args")
chosen_tool_and_args = (tool, tool_args)
else:
all_tool_args = check_which_tools_should_run_for_non_tool_calling_llm(
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=self.tools,
query=self.question,
history=self.message_history,
llm=self.llm,
)
for ind, args in enumerate(all_tool_args):
if args is not None:
chosen_tool_and_args = (self.tools[ind], args)
# for now, just pick the first tool selected
break
available_tools_and_args = [
(self.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=self.message_history,
query=self.question,
llm=self.llm,
)
if available_tools_and_args
else None
)
logger.info(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
prompt_builder.update_system_prompt(
@ -336,7 +358,7 @@ class Answer:
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
if tool.name() == SearchTool.NAME:
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
final_context_documents = None
for response in tool_runner.tool_responses():
if response.id == FINAL_CONTEXT_DOCUMENTS:
@ -344,12 +366,14 @@ class Answer:
yield response
if final_context_documents is None:
raise RuntimeError("SearchTool did not return final context documents")
raise RuntimeError(
f"{tool.name} did not return final context documents"
)
self._update_prompt_builder_for_search_tool(
prompt_builder, final_context_documents
)
elif tool.name() == ImageGenerationTool.NAME:
elif tool.name == ImageGenerationTool._NAME:
img_urls = []
for response in tool_runner.tool_responses():
if response.id == IMAGE_GENERATION_RESPONSE_ID:
@ -371,7 +395,7 @@ class Answer:
HumanMessage(
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
self.question,
tool.name(),
tool.name,
*tool_runner.tool_responses(),
)
)

View File

@ -193,7 +193,7 @@ def stream_answer_objects(
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
tool_name=search_tool.name(),
tool_name=search_tool.name,
args={"query": rephrased_query},
),
# for now, don't use tool calling for this flow, as we haven't

View File

@ -144,6 +144,23 @@ Follow Up Input: {{question}}
Standalone question (Respond with only the short combined query):
""".strip()
INTERNET_SEARCH_QUERY_REPHRASE = f"""
Given the following conversation and a follow up input, rephrase the follow up into a SHORT, \
standalone query suitable for an internet search engine.
IMPORTANT: If a specific query might limit results, keep it broad. \
If a broad query might yield too many results, make it detailed.
If there is a clear change in topic, ensure the query reflects the new topic accurately.
Strip out any information that is not relevant for the internet search.
{GENERAL_SEP_PAT}
Chat History:
{{chat_history}}
{GENERAL_SEP_PAT}
Follow Up Input: {{question}}
Internet Search Query (Respond with a detailed and specific query):
""".strip()
# The below prompts are retired
NO_SEARCH = "No Search"

View File

@ -199,6 +199,7 @@ class SearchDoc(BaseModel):
updated_at: datetime | None
primary_owners: list[str] | None
secondary_owners: list[str] | None
is_internet: bool = False
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore

View File

@ -1,11 +1,13 @@
from collections.abc import Sequence
from typing import TypeVar
from danswer.configs.constants import DocumentSource
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SearchDoc
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
T = TypeVar("T", InferenceSection, InferenceChunk, SearchDoc)
@ -57,6 +59,7 @@ def chunks_or_sections_to_search_docs(
updated_at=chunk.updated_at,
primary_owners=chunk.primary_owners,
secondary_owners=chunk.secondary_owners,
is_internet=False,
)
for chunk in chunks
]
@ -64,3 +67,28 @@ def chunks_or_sections_to_search_docs(
else []
)
return search_docs
def internet_search_response_to_search_docs(
    internet_search_response: InternetSearchResponse,
) -> list[SearchDoc]:
    """Convert internet search results into ``SearchDoc`` objects.

    Each result's link is used as the document id, its title as the semantic
    identifier and its snippet as the blurb. The docs are flagged with
    ``is_internet=True`` and use the ``NOT_APPLICABLE`` source type.
    """
    search_docs: list[SearchDoc] = []
    for result in internet_search_response.internet_results:
        search_docs.append(
            SearchDoc(
                document_id=result.link,
                chunk_ind=-1,
                semantic_identifier=result.title,
                link=result.link,
                blurb=result.snippet,
                source_type=DocumentSource.NOT_APPLICABLE,
                boost=0,
                hidden=False,
                metadata={},
                score=None,
                match_highlights=[],
                updated_at=None,
                primary_owners=[],
                secondary_owners=[],
                is_internet=True,
            )
        )
    return search_docs

View File

@ -74,11 +74,12 @@ def multilingual_query_expansion(
def get_contextual_rephrase_messages(
question: str,
history_str: str,
prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": HISTORY_QUERY_REPHRASE.format(
"content": prompt_template.format(
question=question, chat_history=history_str
),
},
@ -94,6 +95,7 @@ def history_based_query_rephrase(
size_heuristic: int = 200,
punctuation_heuristic: int = 10,
skip_first_rephrase: bool = False,
prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> str:
# Globally disabled, just use the exact user query
if DISABLE_LLM_QUERY_REPHRASE:
@ -119,7 +121,7 @@ def history_based_query_rephrase(
)
prompt_msgs = get_contextual_rephrase_messages(
question=query, history_str=history_str
question=query, history_str=history_str, prompt_template=prompt_template
)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)

View File

@ -10,6 +10,7 @@ class ToolSnapshot(BaseModel):
name: str
description: str
definition: dict[str, Any] | None
display_name: str
in_code_tool_id: str | None
@classmethod
@ -19,5 +20,6 @@ class ToolSnapshot(BaseModel):
name=tool.name,
description=tool.description,
definition=tool.openapi_schema,
display_name=tool.display_name or tool.name,
in_code_tool_id=tool.in_code_tool_id,
)

View File

@ -1,3 +1,4 @@
import os
from typing import Type
from typing import TypedDict
@ -9,6 +10,7 @@ from sqlalchemy.orm import Session
from danswer.db.models import Persona
from danswer.db.models import Tool as ToolDBModel
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.tool import Tool
from danswer.utils.logger import setup_logger
@ -20,22 +22,41 @@ class InCodeToolInfo(TypedDict):
cls: Type[Tool]
description: str
in_code_tool_id: str
display_name: str
BUILT_IN_TOOLS: list[InCodeToolInfo] = [
{
"cls": SearchTool,
"description": "The Search Tool allows the Assistant to search through connected knowledge to help build an answer.",
"in_code_tool_id": SearchTool.__name__,
},
{
"cls": ImageGenerationTool,
"description": (
InCodeToolInfo(
cls=SearchTool,
description="The Search Tool allows the Assistant to search through connected knowledge to help build an answer.",
in_code_tool_id=SearchTool.__name__,
display_name=SearchTool._DISPLAY_NAME,
),
InCodeToolInfo(
cls=ImageGenerationTool,
description=(
"The Image Generation Tool allows the assistant to use DALL-E 3 to generate images. "
"The tool will be used when the user asks the assistant to generate an image."
),
"in_code_tool_id": ImageGenerationTool.__name__,
},
in_code_tool_id=ImageGenerationTool.__name__,
display_name=ImageGenerationTool._DISPLAY_NAME,
),
# don't show the InternetSearchTool as an option if BING_API_KEY is not available
*(
[
InCodeToolInfo(
cls=InternetSearchTool,
description=(
"The Internet Search Tool allows the assistant "
"to perform internet searches for up-to-date information."
),
in_code_tool_id=InternetSearchTool.__name__,
display_name=InternetSearchTool._DISPLAY_NAME,
)
]
if os.environ.get("BING_API_KEY")
else []
),
]
@ -55,12 +76,14 @@ def load_builtin_tools(db_session: Session) -> None:
# Update existing tool
tool.name = tool_name
tool.description = tool_info["description"]
tool.display_name = tool_info["display_name"]
logger.info(f"Updated tool: {tool_name}")
else:
# Add new tool
new_tool = ToolDBModel(
name=tool_name,
description=tool_info["description"],
display_name=tool_info["display_name"],
in_code_tool_id=tool_info["in_code_tool_id"],
)
db_session.add(new_tool)

View File

@ -44,11 +44,20 @@ class CustomTool(Tool):
self._tool_definition = self._method_spec.to_tool_definition()
self._name = self._method_spec.name
self.description = self._method_spec.summary
self._description = self._method_spec.summary
@property
def name(self) -> str:
return self._name
@property
def description(self) -> str:
return self._description
@property
def display_name(self) -> str:
return self._name
"""For LLMs which support explicit tool calling"""
def tool_definition(self) -> dict:
@ -77,7 +86,7 @@ class CustomTool(Tool):
content=SHOULD_USE_CUSTOM_TOOL_USER_PROMPT.format(
history=history,
query=query,
tool_name=self.name(),
tool_name=self.name,
tool_description=self.description,
)
),
@ -93,7 +102,7 @@ class CustomTool(Tool):
content=TOOL_ARG_USER_PROMPT.format(
history=history,
query=query,
tool_name=self.name(),
tool_name=self.name,
tool_description=self.description,
tool_args=self.tool_definition()["function"]["parameters"],
)
@ -121,7 +130,7 @@ class CustomTool(Tool):
# pretend like nothing happened if not parse-able
logger.error(
f"Failed to parse args for '{self.name()}' tool. Recieved: {args_result_str}"
f"Failed to parse args for '{self.name}' tool. Recieved: {args_result_str}"
)
return None

View File

@ -37,4 +37,4 @@ def filter_tools_for_force_tool_use(
if not force_use_tool:
return tools
return [tool for tool in tools if tool.name() == force_use_tool.tool_name]
return [tool for tool in tools if tool.name == force_use_tool.tool_name]

View File

@ -55,7 +55,9 @@ class ImageGenerationResponse(BaseModel):
class ImageGenerationTool(Tool):
NAME = "run_image_generation"
_NAME = "run_image_generation"
_DESCRIPTION = "Generate an image from a prompt."
_DISPLAY_NAME = "Image Generation Tool"
def __init__(
self,
@ -75,15 +77,24 @@ class ImageGenerationTool(Tool):
self.additional_headers = additional_headers
@property
def name(self) -> str:
return self.NAME
return self._NAME
@property
def description(self) -> str:
return self._DESCRIPTION
@property
def display_name(self) -> str:
return self._DISPLAY_NAME
def tool_definition(self) -> dict:
return {
"type": "function",
"function": {
"name": self.name(),
"description": "Generate an image from a prompt",
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {

View File

@ -0,0 +1,217 @@
import json
from collections.abc import Generator
from datetime import datetime
from typing import Any
from typing import cast
import httpx
from pydantic import BaseModel
from danswer.chat.chat_utils import combine_message_chain
from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.dynamic_configs.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.utils import message_to_string
from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.logger import setup_logger
logger = setup_logger()
INTERNET_SEARCH_RESPONSE_ID = "internet_search_response"
YES_INTERNET_SEARCH = "Yes Internet Search"
SKIP_INTERNET_SEARCH = "Skip Internet Search"
INTERNET_SEARCH_TEMPLATE = f"""
Given the conversation history and a follow up query, determine if the system should call \
an external internet search tool to better answer the latest user input.
Your default response is {SKIP_INTERNET_SEARCH}.
Respond "{YES_INTERNET_SEARCH}" if:
- The user is asking for information that requires an internet search.
Conversation History:
{GENERAL_SEP_PAT}
{{chat_history}}
{GENERAL_SEP_PAT}
If you are at all unsure, respond with {SKIP_INTERNET_SEARCH}.
Respond with EXACTLY and ONLY "{YES_INTERNET_SEARCH}" or "{SKIP_INTERNET_SEARCH}"
Follow Up Input:
{{final_query}}
""".strip()
class InternetSearchResult(BaseModel):
    """A single web page result returned by an internet search."""

    # Title of the result page
    title: str
    # URL of the result page
    link: str
    # Short text excerpt describing the result
    snippet: str
class InternetSearchResponse(BaseModel):
    """The query that was searched together with the results it produced."""

    # Query string that was sent to the search API
    revised_query: str
    # Results in the order they were returned
    internet_results: list[InternetSearchResult]
def llm_doc_from_internet_search_result(result: InternetSearchResult) -> LlmDoc:
    """Wrap one internet search result as an ``LlmDoc`` for the LLM context.

    The result's link doubles as document id and semantic identifier, and its
    snippet is used for both the content and the blurb.
    """
    url = result.link
    excerpt = result.snippet

    # NOTE(review): the title is not used here, while the SearchDoc conversion
    # uses it as the semantic identifier - confirm this is intentional.
    return LlmDoc(
        document_id=url,
        content=excerpt,
        blurb=excerpt,
        semantic_identifier=url,
        source_type=DocumentSource.WEB,
        metadata={},
        # Results carry no timestamp, so the time of retrieval is used
        updated_at=datetime.now(),
        link=url,
        source_links={0: url},
    )
class InternetSearchTool(Tool):
    """Tool that runs a Bing web search and exposes the results to the LLM.

    ``run`` yields two responses: the raw ``InternetSearchResponse`` (id
    ``INTERNET_SEARCH_RESPONSE_ID``) followed by the results converted to
    ``LlmDoc`` objects (id ``FINAL_CONTEXT_DOCUMENTS``).
    """

    _NAME = "run_internet_search"
    _DISPLAY_NAME = "[Beta] Internet Search Tool"
    _DESCRIPTION = "Perform an internet search for up-to-date information."

    def __init__(self, api_key: str, num_results: int = 10) -> None:
        """Create the tool.

        Args:
            api_key: Bing subscription key, sent with every request.
            num_results: Maximum number of results to request and return.
        """
        self.api_key = api_key
        self.host = "https://api.bing.microsoft.com/v7.0"
        self.headers = {
            "Ocp-Apim-Subscription-Key": api_key,
            "Content-Type": "application/json",
        }
        self.num_results = num_results
        self.client = httpx.Client()

    @property
    def name(self) -> str:
        """Name the LLM uses to call this tool."""
        return self._NAME

    @property
    def description(self) -> str:
        """Description of the tool given to the LLM."""
        return self._DESCRIPTION

    @property
    def display_name(self) -> str:
        """Human readable name of the tool."""
        return self._DISPLAY_NAME

    def tool_definition(self) -> dict:
        """Return the function-calling schema for this tool."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "internet_search_query": {
                            "type": "string",
                            "description": "Query to search on the internet",
                        },
                    },
                    "required": ["internet_search_query"],
                },
            },
        }

    def check_if_needs_internet_search(
        self,
        query: str,
        history: list[PreviousMessage],
        llm: LLM,
    ) -> bool:
        """Ask the LLM whether an internet search is needed for ``query``.

        Returns:
            True when the first word of ``YES_INTERNET_SEARCH`` appears
            (case-insensitively) anywhere in the LLM output.
        """
        history_str = combine_message_chain(
            messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
        )
        prompt = INTERNET_SEARCH_TEMPLATE.format(
            chat_history=history_str,
            final_query=query,
        )
        use_internet_search_output = message_to_string(llm.invoke(prompt))

        logger.debug(
            f"Evaluated if should use internet search: {use_internet_search_output}"
        )

        return (
            YES_INTERNET_SEARCH.split()[0]
        ).lower() in use_internet_search_output.lower()

    def get_args_for_non_tool_calling_llm(
        self,
        query: str,
        history: list[PreviousMessage],
        llm: LLM,
        force_run: bool = False,
    ) -> dict[str, Any] | None:
        """Build the tool arguments for LLMs without explicit tool calling.

        Returns:
            ``None`` when the tool should not run (not forced and the LLM
            decided no search is needed), otherwise a dict holding the
            rephrased ``internet_search_query``.
        """
        if not force_run and not self.check_if_needs_internet_search(
            query, history, llm
        ):
            return None

        rephrased_query = history_based_query_rephrase(
            query=query,
            history=history,
            llm=llm,
            prompt_template=INTERNET_SEARCH_QUERY_REPHRASE,
        )
        return {
            "internet_search_query": rephrased_query,
        }

    def build_tool_message_content(
        self, *args: ToolResponse
    ) -> str | list[str | dict[str, Any]]:
        """Serialize the search response (the first tool response) to JSON."""
        search_response = cast(InternetSearchResponse, args[0].response)
        return json.dumps(search_response.dict())

    def _perform_search(self, query: str) -> InternetSearchResponse:
        """Run ``query`` against the search endpoint and parse the web results.

        Raises:
            httpx.HTTPStatusError: if the API answers with a 4xx / 5xx status.
        """
        response = self.client.get(
            f"{self.host}/search",
            headers=self.headers,
            params={"q": query, "count": self.num_results},
        )
        # Surface auth / quota / server failures as an HTTP error rather than
        # as a confusing KeyError while parsing the error payload below.
        response.raise_for_status()

        results = response.json()
        # The payload may carry no "webPages" answer (e.g. nothing matched the
        # query); treat that as an empty result list instead of crashing.
        web_pages = results.get("webPages") or {}
        web_results = web_pages.get("value") or []

        return InternetSearchResponse(
            revised_query=query,
            internet_results=[
                InternetSearchResult(
                    title=result["name"],
                    link=result["url"],
                    snippet=result["snippet"],
                )
                for result in web_results[: self.num_results]
            ],
        )

    def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
        """Execute the search given by the ``internet_search_query`` kwarg.

        Yields:
            The raw search response, then the same results as ``LlmDoc``s.
        """
        query = cast(str, kwargs["internet_search_query"])

        results = self._perform_search(query)
        yield ToolResponse(
            id=INTERNET_SEARCH_RESPONSE_ID,
            response=results,
        )

        llm_docs = [
            llm_doc_from_internet_search_result(result)
            for result in results.internet_results
        ]

        yield ToolResponse(
            id=FINAL_CONTEXT_DOCUMENTS,
            response=llm_docs,
        )

    def final_result(self, *args: ToolResponse) -> JSON_ro:
        """Return the search response (the first tool response) as a dict."""
        search_response = cast(InternetSearchResponse, args[0].response)
        return search_response.dict()

View File

@ -46,7 +46,7 @@ class SearchResponseSummary(BaseModel):
recency_bias_multiplier: float
search_tool_description = """
SEARCH_TOOL_DESCRIPTION = """
Runs a semantic search over the user's knowledge base. The default behavior is to use this tool. \
The only scenario where you should not use this tool is if:
@ -59,7 +59,9 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
class SearchTool(Tool):
NAME = "run_search"
_NAME = "run_search"
_DISPLAY_NAME = "Search Tool"
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION
def __init__(
self,
@ -95,8 +97,17 @@ class SearchTool(Tool):
self.bypass_acl = bypass_acl
self.db_session = db_session
@property
def name(self) -> str:
return self.NAME
return self._NAME
@property
def description(self) -> str:
return self._DESCRIPTION
@property
def display_name(self) -> str:
return self._DISPLAY_NAME
"""For explicit tool calling"""
@ -104,8 +115,8 @@ class SearchTool(Tool):
return {
"type": "function",
"function": {
"name": self.name(),
"description": search_tool_description,
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {

View File

@ -9,10 +9,21 @@ from danswer.tools.models import ToolResponse
class Tool(abc.ABC):
@property
@abc.abstractmethod
def name(self) -> str:
raise NotImplementedError
@property
@abc.abstractmethod
def description(self) -> str:
raise NotImplementedError
@property
@abc.abstractmethod
def display_name(self) -> str:
raise NotImplementedError
"""For LLMs which support explicit tool calling"""
@abc.abstractmethod

View File

@ -18,7 +18,7 @@ class ToolRunner:
self._tool_responses: list[ToolResponse] | None = None
def kickoff(self) -> ToolCallKickoff:
return ToolCallKickoff(tool_name=self.tool.name(), tool_args=self.args)
return ToolCallKickoff(tool_name=self.tool.name, tool_args=self.args)
def tool_responses(self) -> Generator[ToolResponse, None, None]:
if self._tool_responses is not None:
@ -37,7 +37,7 @@ class ToolRunner:
def tool_final_result(self) -> ToolCallFinalResult:
return ToolCallFinalResult(
tool_name=self.tool.name(),
tool_name=self.tool.name,
tool_args=self.args,
tool_result=self.tool.final_result(*self.tool_responses()),
)

View File

@ -0,0 +1,78 @@
import re
from typing import Any
from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.tools.tool import Tool
from danswer.utils.logger import setup_logger
logger = setup_logger()
SINGLE_TOOL_SELECTION_PROMPT = f"""
You are an expert at selecting the most useful tool to run for answering the query.
You will be given a numbered list of tools and their arguments, a message history, and a query.
You will select a single tool that will be most useful for answering the query.
Respond with only the number corresponding to the tool you want to use.
Conversation History:
{GENERAL_SEP_PAT}
{{chat_history}}
{GENERAL_SEP_PAT}
Query:
{{query}}
Tools:
{{tool_list}}
Respond with EXACTLY and ONLY the number corresponding to the tool you want to use.
Your selection:
"""
def select_single_tool_for_non_tool_calling_llm(
    tools_and_args: list[tuple[Tool, dict[str, Any]]],
    history: list[PreviousMessage],
    query: str,
    llm: LLM,
) -> tuple[Tool, dict[str, Any]] | None:
    """Pick the single most useful tool out of several candidate tools.

    Args:
        tools_and_args: Candidate tools paired with the args they would run with.
        history: Previous messages, included in the selection prompt.
        query: The user query the tool should help answer.
        llm: LLM asked to choose when there is more than one candidate.

    Returns:
        The chosen ``(tool, args)`` pair, or ``None`` when there are no
        candidates or the selection fails unexpectedly.
    """
    # Nothing to choose from - avoid a pointless LLM call
    if not tools_and_args:
        return None

    if len(tools_and_args) == 1:
        return tools_and_args[0]

    tool_list_str = "\n".join(
        f"""```{ind}: {tool.name} ({args}) - {tool.description}```"""
        for ind, (tool, args) in enumerate(tools_and_args)
    ).lstrip()

    history_str = combine_message_chain(
        messages=history,
        token_limit=GEN_AI_HISTORY_CUTOFF,
    )
    prompt = SINGLE_TOOL_SELECTION_PROMPT.format(
        tool_list=tool_list_str, chat_history=history_str, query=query
    )
    output = message_to_string(llm.invoke(prompt))
    try:
        # First try to match the number
        number_match = re.search(r"\d+", output)
        if number_match:
            tool_ind = int(number_match.group())
            # Only trust the number when it is a valid index; otherwise fall
            # through to the name match instead of failing with an IndexError.
            if tool_ind < len(tools_and_args):
                return tools_and_args[tool_ind]

        # If that fails, try to match the tool name
        lowered_output = output.lower()
        for tool, args in tools_and_args:
            if tool.name.lower() in lowered_output:
                return tool, args

        # If that fails, return the first tool
        return tools_and_args[0]

    except Exception:
        logger.error(f"Failed to select single tool for non-tool-calling LLM: {output}")
        return None

1
backend/throttle.ctrl Normal file
View File

@ -0,0 +1 @@
f1f2 1 1718910083.03085 wikipedia:en

View File

@ -53,6 +53,10 @@ function findImageGenerationTool(tools: ToolSnapshot[]) {
return tools.find((tool) => tool.in_code_tool_id === "ImageGenerationTool");
}
/**
 * Returns the first tool whose `in_code_tool_id` is "InternetSearchTool",
 * or `undefined` when the list contains no such tool.
 */
function findInternetSearchTool(tools: ToolSnapshot[]) {
  const isInternetSearch = (candidate: ToolSnapshot) =>
    candidate.in_code_tool_id === "InternetSearchTool";
  return tools.find(isInternetSearch);
}
/** Wraps `children` in a div styled with `text-sm text-subtle mb-2`. */
function SubLabel(props: { children: string | JSX.Element }) {
  const { children } = props;
  return <div className="text-sm text-subtle mb-2">{children}</div>;
}
@ -150,16 +154,20 @@ export function AssistantEditor({
const imageGenerationTool = providerSupportingImageGenerationExists
? findImageGenerationTool(tools)
: undefined;
const internetSearchTool = findInternetSearchTool(tools);
const customTools = tools.filter(
(tool) =>
tool.in_code_tool_id !== searchTool?.in_code_tool_id &&
tool.in_code_tool_id !== imageGenerationTool?.in_code_tool_id
tool.in_code_tool_id !== imageGenerationTool?.in_code_tool_id &&
tool.in_code_tool_id !== internetSearchTool?.in_code_tool_id
);
const availableTools = [
...customTools,
...(searchTool ? [searchTool] : []),
...(imageGenerationTool ? [imageGenerationTool] : []),
...(internetSearchTool ? [internetSearchTool] : []),
];
const enabledToolsMap: { [key: number]: boolean } = {};
availableTools.forEach((tool) => {
@ -666,6 +674,17 @@ export function AssistantEditor({
</>
)}
{internetSearchTool && (
<BooleanFormField
noPadding
name={`enabled_tools_map.${internetSearchTool.id}`}
label={internetSearchTool.display_name}
onChange={() => {
toggleToolInValues(internetSearchTool.id);
}}
/>
)}
{customTools.length > 0 && (
<>
{customTools.map((tool) => (

View File

@ -1,6 +1,6 @@
import { Bubble } from "@/components/Bubble";
import { ToolSnapshot } from "@/lib/tools/interfaces";
import { FiImage, FiSearch } from "react-icons/fi";
import { FiImage, FiSearch, FiGlobe } from "react-icons/fi";
export function ToolsDisplay({ tools }: { tools: ToolSnapshot[] }) {
return (
@ -15,6 +15,9 @@ export function ToolsDisplay({ tools }: { tools: ToolSnapshot[] }) {
} else if (tool.name === "ImageGenerationTool") {
toolName = "Image Generation";
toolIcon = <FiImage className="mr-1 my-auto" />;
} else if (tool.name === "InternetSearchTool") {
toolName = "Internet Search";
toolIcon = <FiGlobe className="mr-1 my-auto" />;
}
return (

View File

@ -10,6 +10,7 @@ import {
DocumentMetadataBlock,
buildDocumentSummaryDisplay,
} from "@/components/search/DocumentDisplay";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
interface DocumentDisplayProps {
document: DanswerDocument;
@ -30,24 +31,25 @@ export function ChatDocumentDisplay({
setPopup,
tokenLimitReached,
}: DocumentDisplayProps) {
// Consider reintroducing null scored docs in the future
if (document.score === null) {
return null;
}
const isInternet = document.is_internet;
return (
<div key={document.semantic_identifier} className="text-sm px-3">
<div className="flex relative w-full overflow-y-visible">
<a
className={
"rounded-lg flex font-bold flex-shrink truncate " +
"rounded-lg flex font-bold flex-shrink truncate items-center " +
(document.link ? "" : "pointer-events-none")
}
href={document.link}
target="_blank"
rel="noopener noreferrer"
>
<SourceIcon sourceType={document.source_type} iconSize={18} />
{isInternet ? (
<InternetSearchIcon url={document.link} />
) : (
<SourceIcon sourceType={document.source_type} iconSize={18} />
)}
<p className="overflow-hidden text-ellipsis mx-2 my-auto text-sm ">
{document.semantic_identifier || document.document_id}
</p>
@ -73,29 +75,16 @@ export function ChatDocumentDisplay({
/>
</div>
)}
<div
className={`
text-xs
text-emphasis
bg-hover
rounded
p-0.5
w-fit
my-auto
select-none
my-auto
mr-2`}
>
{Math.abs(document.score).toFixed(2)}
</div>
</div>
)}
<DocumentSelector
isSelected={isSelected}
handleSelect={() => handleSelect(document.document_id)}
isDisabled={tokenLimitReached && !isSelected}
/>
{!isInternet && (
<DocumentSelector
isSelected={isSelected}
handleSelect={() => handleSelect(document.document_id)}
isDisabled={tokenLimitReached && !isSelected}
/>
)}
</div>
<div>
<div className="mt-1">

View File

@ -530,7 +530,11 @@ export function checkAnyAssistantHasSearch(
}
/**
 * Reports whether the persona has a retrieval tool attached.
 *
 * A persona includes retrieval when at least one of its tools has an
 * `in_code_tool_id` of "SearchTool" or "InternetSearchTool".
 */
export function personaIncludesRetrieval(selectedPersona: Persona) {
  // Decided purely by the attached tools; `num_chunks` is not consulted, so a
  // persona that only has internet search still counts as retrieval-enabled.
  return selectedPersona.tools.some(
    (tool) =>
      tool.in_code_tool_id &&
      ["SearchTool", "InternetSearchTool"].includes(tool.in_code_tool_id)
  );
}
const PARAMS_TO_SKIP = [

View File

@ -10,6 +10,7 @@ import {
FiChevronRight,
FiChevronLeft,
FiTool,
FiGlobe,
} from "react-icons/fi";
import { FeedbackType } from "../types";
import { useEffect, useRef, useState } from "react";
@ -25,6 +26,7 @@ import { ChatFileType, FileDescriptor, ToolCallMetadata } from "../interfaces";
import {
IMAGE_GENERATION_TOOL_NAME,
SEARCH_TOOL_NAME,
INTERNET_SEARCH_TOOL_NAME,
} from "../tools/constants";
import { ToolRunDisplay } from "../tools/ToolRunningAnimation";
import { Hoverable } from "@/components/Hoverable";
@ -39,11 +41,12 @@ import Prism from "prismjs";
import "prismjs/themes/prism-tomorrow.css";
import "./custom-code-styles.css";
import { Persona } from "@/app/admin/assistants/interfaces";
import { Button } from "@tremor/react";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
INTERNET_SEARCH_TOOL_NAME,
IMAGE_GENERATION_TOOL_NAME,
];
@ -149,6 +152,9 @@ export const AIMessage = ({
content = trimIncompleteCodeSection(content);
}
const danswerSearchToolEnabledForPersona = currentPersona.tools.some(
(tool) => tool.in_code_tool_id === SEARCH_TOOL_NAME
);
const shouldShowLoader =
!toolCall || (toolCall.tool_name === SEARCH_TOOL_NAME && !content);
const defaultLoader = shouldShowLoader ? (
@ -200,36 +206,37 @@ export const AIMessage = ({
</div>
<div className="w-message-xs 2xl:w-message-sm 3xl:w-message-default break-words mt-1 ml-8">
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && (
<>
{query !== undefined &&
handleShowRetrieved !== undefined &&
isCurrentlyShowingRetrieved !== undefined &&
!retrievalDisabled && (
<div className="my-1">
<SearchSummary
query={query}
hasDocs={hasDocs || false}
messageId={messageId}
isCurrentlyShowingRetrieved={
isCurrentlyShowingRetrieved
}
handleShowRetrieved={handleShowRetrieved}
handleSearchQueryEdit={handleSearchQueryEdit}
/>
</div>
)}
{handleForceSearch &&
content &&
query === undefined &&
!hasDocs &&
!retrievalDisabled && (
<div className="my-1">
<SkippedSearch handleForceSearch={handleForceSearch} />
</div>
)}
</>
)}
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) &&
danswerSearchToolEnabledForPersona && (
<>
{query !== undefined &&
handleShowRetrieved !== undefined &&
isCurrentlyShowingRetrieved !== undefined &&
!retrievalDisabled && (
<div className="my-1">
<SearchSummary
query={query}
hasDocs={hasDocs || false}
messageId={messageId}
isCurrentlyShowingRetrieved={
isCurrentlyShowingRetrieved
}
handleShowRetrieved={handleShowRetrieved}
handleSearchQueryEdit={handleSearchQueryEdit}
/>
</div>
)}
{handleForceSearch &&
content &&
query === undefined &&
!hasDocs &&
!retrievalDisabled && (
<div className="my-1">
<SkippedSearch handleForceSearch={handleForceSearch} />
</div>
)}
</>
)}
{toolCall &&
!TOOLS_WITH_CUSTOM_HANDLING.includes(toolCall.tool_name) && (
@ -258,6 +265,20 @@ export const AIMessage = ({
</div>
)}
{toolCall && toolCall.tool_name === INTERNET_SEARCH_TOOL_NAME && (
<div className="my-2">
<ToolRunDisplay
toolName={
toolCall.tool_result
? `Searched the internet`
: `Searching the internet`
}
toolLogo={<FiGlobe size={15} className="my-auto mr-1" />}
isRunning={!toolCall.tool_result}
/>
</div>
)}
{content ? (
<>
<FileDisplay files={files || []} />
@ -317,12 +338,16 @@ export const AIMessage = ({
.filter(([_, document]) => document.semantic_identifier)
.map(([citationKey, document], ind) => {
const display = (
<div className="max-w-350 text-ellipsis flex text-sm border border-border py-1 px-2 rounded flex">
<div className="max-w-350 text-ellipsis text-sm border border-border py-1 px-2 rounded flex">
<div className="mr-1 my-auto">
<SourceIcon
sourceType={document.source_type}
iconSize={16}
/>
{document.is_internet ? (
<InternetSearchIcon url={document.link} />
) : (
<SourceIcon
sourceType={document.source_type}
iconSize={16}
/>
)}
</div>
[{citationKey}] {document!.semantic_identifier}
</div>

View File

@ -1,2 +1,3 @@
// Tool names that are compared against a tool call's `tool_name` to decide
// which tool-specific rendering to use.
export const SEARCH_TOOL_NAME = "run_search";
export const INTERNET_SEARCH_TOOL_NAME = "run_internet_search";
export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation";

View File

@ -0,0 +1,9 @@
/**
 * Renders the favicon of the site that `url` points to, fetched through
 * Google's s2 favicon endpoint.
 */
export function InternetSearchIcon({ url }: { url: string }) {
  // Prefer the bare hostname: a full link may carry `?`, `&` or `#`, which
  // would otherwise be read as extra parameters of the favicon request.
  let domain = url;
  try {
    domain = new URL(url).hostname;
  } catch {
    // Not an absolute URL (e.g. a bare domain); use the value as given.
  }

  return (
    <img
      className="rounded-full w-[18px] h-[18px]"
      src={`https://www.google.com/s2/favicons?sz=128&domain=${encodeURIComponent(
        domain
      )}`}
      alt="favicon"
    />
  );
}

View File

@ -11,6 +11,7 @@ export const SearchType = {
SEMANTIC: "semantic",
KEYWORD: "keyword",
AUTOMATIC: "automatic",
INTERNET: "internet",
};
export type SearchType = (typeof SearchType)[keyof typeof SearchType];
@ -48,6 +49,7 @@ export interface DanswerDocument {
metadata: { [key: string]: string };
updated_at: string | null;
db_doc_id?: number;
is_internet: boolean;
}
export interface DocumentInfoPacket {

View File

@ -39,7 +39,6 @@ import {
import { ValidSources } from "./types";
import { SourceCategory, SourceMetadata } from "./search/interfaces";
import { Persona } from "@/app/admin/assistants/interfaces";
import internal from "stream";
interface PartialSourceMetadata {
icon: React.FC<{ size?: number; className?: string }>;
@ -232,6 +231,11 @@ const SOURCE_METADATA_MAP: SourceMap = {
displayName: "Google Storage",
category: SourceCategory.AppConnection,
},
not_applicable: {
icon: GlobeIcon,
displayName: "Internet",
category: SourceCategory.ImportedKnowledge,
},
};
function fillSourceMetadata(

View File

@ -1,6 +1,7 @@
export interface ToolSnapshot {
id: number;
name: string;
display_name: string;
description: string;
// only specified for Custom Tools. OpenAPI schema which represents

View File

@ -63,7 +63,8 @@ export type ValidSources =
| "s3"
| "r2"
| "google_cloud_storage"
| "oci_storage";
| "oci_storage"
| "not_applicable";
export type ValidInputTypes = "load_state" | "poll" | "event";
export type ValidStatuses =