mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-26 07:50:56 +02:00
Internet Search Tool (#1666)
--------- Co-authored-by: Weves <chrisweaver101@gmail.com>
This commit is contained in:
parent
e06f8a0a4b
commit
146f85936b
7
.vscode/env_template.txt
vendored
7
.vscode/env_template.txt
vendored
@ -42,6 +42,11 @@ PYTHONPATH=./backend
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY=<REPLACE THIS>
|
||||
|
||||
|
||||
# Enable the full set of Danswer Enterprise Edition features
|
||||
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
|
||||
|
||||
|
@ -0,0 +1,23 @@
|
||||
"""added is_internet to DBDoc
|
||||
|
||||
Revision ID: 4505fd7302e1
|
||||
Revises: c18cdf4b497e
|
||||
Create Date: 2024-06-18 20:46:09.095034
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4505fd7302e1"
|
||||
down_revision = "c18cdf4b497e"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("search_doc", sa.Column("is_internet", sa.Boolean(), nullable=True))
|
||||
op.add_column("tool", sa.Column("display_name", sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("tool", "display_name")
|
||||
op.drop_column("search_doc", "is_internet")
|
@ -14,6 +14,7 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
@ -53,10 +54,13 @@ from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.retrieval.search_runner import inference_documents_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
from danswer.search.utils import drop_llm_indices
|
||||
from danswer.search.utils import internet_search_response_to_search_docs
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
@ -68,6 +72,11 @@ from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
@ -145,6 +154,37 @@ def _handle_search_tool_response_summary(
|
||||
)
|
||||
|
||||
|
||||
def _handle_internet_search_tool_response_summary(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
|
||||
internet_search_response = cast(InternetSearchResponse, packet.response)
|
||||
server_search_docs = internet_search_response_to_search_docs(
|
||||
internet_search_response
|
||||
)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in server_search_docs
|
||||
]
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
return (
|
||||
QADocsResponse(
|
||||
rephrased_query=internet_search_response.revised_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=QueryFlow.QUESTION_ANSWER,
|
||||
predicted_search=SearchType.HYBRID,
|
||||
applied_source_filters=[],
|
||||
applied_time_cutoff=None,
|
||||
recency_bias_multiplier=1.0,
|
||||
),
|
||||
reference_db_search_docs,
|
||||
)
|
||||
|
||||
|
||||
def _check_should_force_search(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
) -> ForceUseTool | None:
|
||||
@ -172,7 +212,7 @@ def _check_should_force_search(
|
||||
args = {"query": new_msg_req.message}
|
||||
|
||||
return ForceUseTool(
|
||||
tool_name=SearchTool.NAME,
|
||||
tool_name=SearchTool._NAME,
|
||||
args=args,
|
||||
)
|
||||
return None
|
||||
@ -476,6 +516,15 @@ def stream_chat_message_objects(
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
]
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
bing_api_key = BING_API_KEY
|
||||
if not bing_api_key:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(api_key=bing_api_key)
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
@ -582,6 +631,15 @@ def stream_chat_message_objects(
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
@ -623,7 +681,7 @@ def stream_chat_message_objects(
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name()] = tool_id
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
message=answer.llm_answer,
|
||||
|
@ -78,3 +78,6 @@ STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
|
||||
|
||||
# The backend logic for this being True isn't fully supported yet
|
||||
HARD_DELETE_CHATS = False
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY = os.environ.get("BING_API_KEY") or None
|
||||
|
@ -104,6 +104,7 @@ class DocumentSource(str, Enum):
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@ -112,6 +113,9 @@ class BlobType(str, Enum):
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
|
||||
# Special case, for internet search
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
|
||||
|
||||
class DocumentIndexType(str, Enum):
|
||||
COMBINED = "combined" # Vespa
|
||||
|
@ -504,6 +504,7 @@ def create_db_search_doc(
|
||||
updated_at=server_search_doc.updated_at,
|
||||
primary_owners=server_search_doc.primary_owners,
|
||||
secondary_owners=server_search_doc.secondary_owners,
|
||||
is_internet=server_search_doc.is_internet,
|
||||
)
|
||||
|
||||
db_session.add(db_search_doc)
|
||||
@ -542,6 +543,7 @@ def translate_db_search_doc_to_server_search_doc(
|
||||
secondary_owners=(
|
||||
db_search_doc.secondary_owners if not remove_doc_content else []
|
||||
),
|
||||
is_internet=db_search_doc.is_internet,
|
||||
)
|
||||
|
||||
|
||||
|
@ -645,6 +645,7 @@ class SearchDoc(Base):
|
||||
secondary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
is_internet: Mapped[bool] = mapped_column(Boolean, default=False, nullable=True)
|
||||
|
||||
chat_messages = relationship(
|
||||
"ChatMessage",
|
||||
@ -990,6 +991,7 @@ class Tool(Base):
|
||||
# ID of the tool in the codebase, only applies for in-code tools.
|
||||
# tools defined via the UI will have this as None
|
||||
in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
display_name: Mapped[str] = mapped_column(String, nullable=True)
|
||||
|
||||
# OpenAPI scheme for the tool. Only applies to tools defined via the UI.
|
||||
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
|
@ -12,6 +12,7 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import DocumentSet
|
||||
@ -360,6 +361,10 @@ def upsert_persona(
|
||||
if not prompts and prompt_ids:
|
||||
raise ValueError("prompts not found")
|
||||
|
||||
# ensure all specified tools are valid
|
||||
if tools:
|
||||
validate_persona_tools(tools)
|
||||
|
||||
if persona:
|
||||
if not default_persona and persona.default_persona:
|
||||
raise ValueError("Cannot update default persona with non-default.")
|
||||
@ -457,6 +462,14 @@ def update_persona_visibility(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def validate_persona_tools(tools: list[Tool]) -> None:
|
||||
for tool in tools:
|
||||
if tool.name == "InternetSearchTool" and not BING_API_KEY:
|
||||
raise ValueError(
|
||||
"Bing API key not found, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
|
||||
|
||||
def check_user_can_edit_persona(user: User | None, persona: Persona) -> None:
|
||||
# if user is None, assume that no-auth is turned on
|
||||
if user is None:
|
||||
|
@ -43,6 +43,7 @@ from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.images.prompt import build_image_generation_user_prompt
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.message import build_tool_message
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
|
||||
@ -58,7 +59,11 @@ from danswer.tools.tool_runner import (
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_answer_stream_processor(
|
||||
@ -228,7 +233,7 @@ class Answer:
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
for tool_call_request in tool_call_requests:
|
||||
tool = [
|
||||
tool for tool in self.tools if tool.name() == tool_call_request["name"]
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
][0]
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
@ -247,15 +252,14 @@ class Answer:
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name() == SearchTool.NAME:
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
elif tool.name() == ImageGenerationTool.NAME:
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
)
|
||||
)
|
||||
|
||||
yield tool_runner.tool_final_result()
|
||||
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
@ -281,7 +285,7 @@ class Answer:
|
||||
[
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name() == self.force_use_tool.tool_name
|
||||
if tool.name == self.force_use_tool.tool_name
|
||||
]
|
||||
),
|
||||
None,
|
||||
@ -301,21 +305,39 @@ class Answer:
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name()}' did not return args")
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
chosen_tool_and_args = (tool, tool_args)
|
||||
else:
|
||||
all_tool_args = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=self.tools,
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
)
|
||||
for ind, args in enumerate(all_tool_args):
|
||||
if args is not None:
|
||||
chosen_tool_and_args = (self.tools[ind], args)
|
||||
# for now, just pick the first tool selected
|
||||
break
|
||||
|
||||
available_tools_and_args = [
|
||||
(self.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=self.message_history,
|
||||
query=self.question,
|
||||
llm=self.llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.info(f"Chosen tool: {chosen_tool_and_args}")
|
||||
|
||||
if not chosen_tool_and_args:
|
||||
prompt_builder.update_system_prompt(
|
||||
@ -336,7 +358,7 @@ class Answer:
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
|
||||
if tool.name() == SearchTool.NAME:
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
final_context_documents = None
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS:
|
||||
@ -344,12 +366,14 @@ class Answer:
|
||||
yield response
|
||||
|
||||
if final_context_documents is None:
|
||||
raise RuntimeError("SearchTool did not return final context documents")
|
||||
raise RuntimeError(
|
||||
f"{tool.name} did not return final context documents"
|
||||
)
|
||||
|
||||
self._update_prompt_builder_for_search_tool(
|
||||
prompt_builder, final_context_documents
|
||||
)
|
||||
elif tool.name() == ImageGenerationTool.NAME:
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = []
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
@ -371,7 +395,7 @@ class Answer:
|
||||
HumanMessage(
|
||||
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
self.question,
|
||||
tool.name(),
|
||||
tool.name,
|
||||
*tool_runner.tool_responses(),
|
||||
)
|
||||
)
|
||||
|
@ -193,7 +193,7 @@ def stream_answer_objects(
|
||||
single_message_history=history_str,
|
||||
tools=[search_tool],
|
||||
force_use_tool=ForceUseTool(
|
||||
tool_name=search_tool.name(),
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
),
|
||||
# for now, don't use tool calling for this flow, as we haven't
|
||||
|
@ -144,6 +144,23 @@ Follow Up Input: {{question}}
|
||||
Standalone question (Respond with only the short combined query):
|
||||
""".strip()
|
||||
|
||||
INTERNET_SEARCH_QUERY_REPHRASE = f"""
|
||||
Given the following conversation and a follow up input, rephrase the follow up into a SHORT, \
|
||||
standalone query suitable for an internet search engine.
|
||||
IMPORTANT: If a specific query might limit results, keep it broad. \
|
||||
If a broad query might yield too many results, make it detailed.
|
||||
If there is a clear change in topic, ensure the query reflects the new topic accurately.
|
||||
Strip out any information that is not relevant for the internet search.
|
||||
|
||||
{GENERAL_SEP_PAT}
|
||||
Chat History:
|
||||
{{chat_history}}
|
||||
{GENERAL_SEP_PAT}
|
||||
|
||||
Follow Up Input: {{question}}
|
||||
Internet Search Query (Respond with a detailed and specific query):
|
||||
""".strip()
|
||||
|
||||
|
||||
# The below prompts are retired
|
||||
NO_SEARCH = "No Search"
|
||||
|
@ -199,6 +199,7 @@ class SearchDoc(BaseModel):
|
||||
updated_at: datetime | None
|
||||
primary_owners: list[str] | None
|
||||
secondary_owners: list[str] | None
|
||||
is_internet: bool = False
|
||||
|
||||
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().dict(*args, **kwargs) # type: ignore
|
||||
|
@ -1,11 +1,13 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.models import SearchDoc as DBSearchDoc
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
from danswer.search.models import SearchDoc
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
|
||||
|
||||
T = TypeVar("T", InferenceSection, InferenceChunk, SearchDoc)
|
||||
@ -57,6 +59,7 @@ def chunks_or_sections_to_search_docs(
|
||||
updated_at=chunk.updated_at,
|
||||
primary_owners=chunk.primary_owners,
|
||||
secondary_owners=chunk.secondary_owners,
|
||||
is_internet=False,
|
||||
)
|
||||
for chunk in chunks
|
||||
]
|
||||
@ -64,3 +67,28 @@ def chunks_or_sections_to_search_docs(
|
||||
else []
|
||||
)
|
||||
return search_docs
|
||||
|
||||
|
||||
def internet_search_response_to_search_docs(
|
||||
internet_search_response: InternetSearchResponse,
|
||||
) -> list[SearchDoc]:
|
||||
return [
|
||||
SearchDoc(
|
||||
document_id=doc.link,
|
||||
chunk_ind=-1,
|
||||
semantic_identifier=doc.title,
|
||||
link=doc.link,
|
||||
blurb=doc.snippet,
|
||||
source_type=DocumentSource.NOT_APPLICABLE,
|
||||
boost=0,
|
||||
hidden=False,
|
||||
metadata={},
|
||||
score=None,
|
||||
match_highlights=[],
|
||||
updated_at=None,
|
||||
primary_owners=[],
|
||||
secondary_owners=[],
|
||||
is_internet=True,
|
||||
)
|
||||
for doc in internet_search_response.internet_results
|
||||
]
|
||||
|
@ -74,11 +74,12 @@ def multilingual_query_expansion(
|
||||
def get_contextual_rephrase_messages(
|
||||
question: str,
|
||||
history_str: str,
|
||||
prompt_template: str = HISTORY_QUERY_REPHRASE,
|
||||
) -> list[dict[str, str]]:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": HISTORY_QUERY_REPHRASE.format(
|
||||
"content": prompt_template.format(
|
||||
question=question, chat_history=history_str
|
||||
),
|
||||
},
|
||||
@ -94,6 +95,7 @@ def history_based_query_rephrase(
|
||||
size_heuristic: int = 200,
|
||||
punctuation_heuristic: int = 10,
|
||||
skip_first_rephrase: bool = False,
|
||||
prompt_template: str = HISTORY_QUERY_REPHRASE,
|
||||
) -> str:
|
||||
# Globally disabled, just use the exact user query
|
||||
if DISABLE_LLM_QUERY_REPHRASE:
|
||||
@ -119,7 +121,7 @@ def history_based_query_rephrase(
|
||||
)
|
||||
|
||||
prompt_msgs = get_contextual_rephrase_messages(
|
||||
question=query, history_str=history_str
|
||||
question=query, history_str=history_str, prompt_template=prompt_template
|
||||
)
|
||||
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
|
||||
|
@ -10,6 +10,7 @@ class ToolSnapshot(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
definition: dict[str, Any] | None
|
||||
display_name: str
|
||||
in_code_tool_id: str | None
|
||||
|
||||
@classmethod
|
||||
@ -19,5 +20,6 @@ class ToolSnapshot(BaseModel):
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
definition=tool.openapi_schema,
|
||||
display_name=tool.display_name or tool.name,
|
||||
in_code_tool_id=tool.in_code_tool_id,
|
||||
)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import Type
|
||||
from typing import TypedDict
|
||||
|
||||
@ -9,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Tool as ToolDBModel
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.utils.logger import setup_logger
|
||||
@ -20,22 +22,41 @@ class InCodeToolInfo(TypedDict):
|
||||
cls: Type[Tool]
|
||||
description: str
|
||||
in_code_tool_id: str
|
||||
display_name: str
|
||||
|
||||
|
||||
BUILT_IN_TOOLS: list[InCodeToolInfo] = [
|
||||
{
|
||||
"cls": SearchTool,
|
||||
"description": "The Search Tool allows the Assistant to search through connected knowledge to help build an answer.",
|
||||
"in_code_tool_id": SearchTool.__name__,
|
||||
},
|
||||
{
|
||||
"cls": ImageGenerationTool,
|
||||
"description": (
|
||||
InCodeToolInfo(
|
||||
cls=SearchTool,
|
||||
description="The Search Tool allows the Assistant to search through connected knowledge to help build an answer.",
|
||||
in_code_tool_id=SearchTool.__name__,
|
||||
display_name=SearchTool._DISPLAY_NAME,
|
||||
),
|
||||
InCodeToolInfo(
|
||||
cls=ImageGenerationTool,
|
||||
description=(
|
||||
"The Image Generation Tool allows the assistant to use DALL-E 3 to generate images. "
|
||||
"The tool will be used when the user asks the assistant to generate an image."
|
||||
),
|
||||
"in_code_tool_id": ImageGenerationTool.__name__,
|
||||
},
|
||||
in_code_tool_id=ImageGenerationTool.__name__,
|
||||
display_name=ImageGenerationTool._DISPLAY_NAME,
|
||||
),
|
||||
# don't show the InternetSearchTool as an option if BING_API_KEY is not available
|
||||
*(
|
||||
[
|
||||
InCodeToolInfo(
|
||||
cls=InternetSearchTool,
|
||||
description=(
|
||||
"The Internet Search Tool allows the assistant "
|
||||
"to perform internet searches for up-to-date information."
|
||||
),
|
||||
in_code_tool_id=InternetSearchTool.__name__,
|
||||
display_name=InternetSearchTool._DISPLAY_NAME,
|
||||
)
|
||||
]
|
||||
if os.environ.get("BING_API_KEY")
|
||||
else []
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@ -55,12 +76,14 @@ def load_builtin_tools(db_session: Session) -> None:
|
||||
# Update existing tool
|
||||
tool.name = tool_name
|
||||
tool.description = tool_info["description"]
|
||||
tool.display_name = tool_info["display_name"]
|
||||
logger.info(f"Updated tool: {tool_name}")
|
||||
else:
|
||||
# Add new tool
|
||||
new_tool = ToolDBModel(
|
||||
name=tool_name,
|
||||
description=tool_info["description"],
|
||||
display_name=tool_info["display_name"],
|
||||
in_code_tool_id=tool_info["in_code_tool_id"],
|
||||
)
|
||||
db_session.add(new_tool)
|
||||
|
@ -44,11 +44,20 @@ class CustomTool(Tool):
|
||||
self._tool_definition = self._method_spec.to_tool_definition()
|
||||
|
||||
self._name = self._method_spec.name
|
||||
self.description = self._method_spec.summary
|
||||
self._description = self._method_spec.summary
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._name
|
||||
|
||||
"""For LLMs which support explicit tool calling"""
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
@ -77,7 +86,7 @@ class CustomTool(Tool):
|
||||
content=SHOULD_USE_CUSTOM_TOOL_USER_PROMPT.format(
|
||||
history=history,
|
||||
query=query,
|
||||
tool_name=self.name(),
|
||||
tool_name=self.name,
|
||||
tool_description=self.description,
|
||||
)
|
||||
),
|
||||
@ -93,7 +102,7 @@ class CustomTool(Tool):
|
||||
content=TOOL_ARG_USER_PROMPT.format(
|
||||
history=history,
|
||||
query=query,
|
||||
tool_name=self.name(),
|
||||
tool_name=self.name,
|
||||
tool_description=self.description,
|
||||
tool_args=self.tool_definition()["function"]["parameters"],
|
||||
)
|
||||
@ -121,7 +130,7 @@ class CustomTool(Tool):
|
||||
|
||||
# pretend like nothing happened if not parse-able
|
||||
logger.error(
|
||||
f"Failed to parse args for '{self.name()}' tool. Recieved: {args_result_str}"
|
||||
f"Failed to parse args for '{self.name}' tool. Recieved: {args_result_str}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
@ -37,4 +37,4 @@ def filter_tools_for_force_tool_use(
|
||||
if not force_use_tool:
|
||||
return tools
|
||||
|
||||
return [tool for tool in tools if tool.name() == force_use_tool.tool_name]
|
||||
return [tool for tool in tools if tool.name == force_use_tool.tool_name]
|
||||
|
@ -55,7 +55,9 @@ class ImageGenerationResponse(BaseModel):
|
||||
|
||||
|
||||
class ImageGenerationTool(Tool):
|
||||
NAME = "run_image_generation"
|
||||
_NAME = "run_image_generation"
|
||||
_DESCRIPTION = "Generate an image from a prompt."
|
||||
_DISPLAY_NAME = "Image Generation Tool"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -75,15 +77,24 @@ class ImageGenerationTool(Tool):
|
||||
|
||||
self.additional_headers = additional_headers
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.NAME
|
||||
return self._NAME
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTION
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._DISPLAY_NAME
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name(),
|
||||
"description": "Generate an image from a prompt",
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
217
backend/danswer/tools/internet_search/internet_search_tool.py
Normal file
217
backend/danswer/tools/internet_search/internet_search_tool.py
Normal file
@ -0,0 +1,217 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE
|
||||
from danswer.prompts.constants import GENERAL_SEP_PAT
|
||||
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
INTERNET_SEARCH_RESPONSE_ID = "internet_search_response"
|
||||
|
||||
YES_INTERNET_SEARCH = "Yes Internet Search"
|
||||
SKIP_INTERNET_SEARCH = "Skip Internet Search"
|
||||
|
||||
INTERNET_SEARCH_TEMPLATE = f"""
|
||||
Given the conversation history and a follow up query, determine if the system should call \
|
||||
an external internet search tool to better answer the latest user input.
|
||||
Your default response is {SKIP_INTERNET_SEARCH}.
|
||||
|
||||
Respond "{YES_INTERNET_SEARCH}" if:
|
||||
- The user is asking for information that requires an internet search.
|
||||
|
||||
Conversation History:
|
||||
{GENERAL_SEP_PAT}
|
||||
{{chat_history}}
|
||||
{GENERAL_SEP_PAT}
|
||||
|
||||
If you are at all unsure, respond with {SKIP_INTERNET_SEARCH}.
|
||||
Respond with EXACTLY and ONLY "{YES_INTERNET_SEARCH}" or "{SKIP_INTERNET_SEARCH}"
|
||||
|
||||
Follow Up Input:
|
||||
{{final_query}}
|
||||
""".strip()
|
||||
|
||||
|
||||
class InternetSearchResult(BaseModel):
|
||||
title: str
|
||||
link: str
|
||||
snippet: str
|
||||
|
||||
|
||||
class InternetSearchResponse(BaseModel):
|
||||
revised_query: str
|
||||
internet_results: list[InternetSearchResult]
|
||||
|
||||
|
||||
def llm_doc_from_internet_search_result(result: InternetSearchResult) -> LlmDoc:
|
||||
return LlmDoc(
|
||||
document_id=result.link,
|
||||
content=result.snippet,
|
||||
blurb=result.snippet,
|
||||
semantic_identifier=result.link,
|
||||
source_type=DocumentSource.WEB,
|
||||
metadata={},
|
||||
updated_at=datetime.now(),
|
||||
link=result.link,
|
||||
source_links={0: result.link},
|
||||
)
|
||||
|
||||
|
||||
class InternetSearchTool(Tool):
|
||||
_NAME = "run_internet_search"
|
||||
_DISPLAY_NAME = "[Beta] Internet Search Tool"
|
||||
_DESCRIPTION = "Perform an internet search for up-to-date information."
|
||||
|
||||
def __init__(self, api_key: str, num_results: int = 10) -> None:
|
||||
self.api_key = api_key
|
||||
self.host = "https://api.bing.microsoft.com/v7.0"
|
||||
self.headers = {
|
||||
"Ocp-Apim-Subscription-Key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self.num_results = num_results
|
||||
self.client = httpx.Client()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTION
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._DISPLAY_NAME
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"internet_search_query": {
|
||||
"type": "string",
|
||||
"description": "Query to search on the internet",
|
||||
},
|
||||
},
|
||||
"required": ["internet_search_query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def check_if_needs_internet_search(
|
||||
self,
|
||||
query: str,
|
||||
history: list[PreviousMessage],
|
||||
llm: LLM,
|
||||
) -> bool:
|
||||
history_str = combine_message_chain(
|
||||
messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
|
||||
)
|
||||
prompt = INTERNET_SEARCH_TEMPLATE.format(
|
||||
chat_history=history_str,
|
||||
final_query=query,
|
||||
)
|
||||
use_internet_search_output = message_to_string(llm.invoke(prompt))
|
||||
|
||||
logger.debug(
|
||||
f"Evaluated if should use internet search: {use_internet_search_output}"
|
||||
)
|
||||
|
||||
return (
|
||||
YES_INTERNET_SEARCH.split()[0]
|
||||
).lower() in use_internet_search_output.lower()
|
||||
|
||||
def get_args_for_non_tool_calling_llm(
|
||||
self,
|
||||
query: str,
|
||||
history: list[PreviousMessage],
|
||||
llm: LLM,
|
||||
force_run: bool = False,
|
||||
) -> dict[str, Any] | None:
|
||||
if not force_run and not self.check_if_needs_internet_search(
|
||||
query, history, llm
|
||||
):
|
||||
return None
|
||||
|
||||
rephrased_query = history_based_query_rephrase(
|
||||
query=query,
|
||||
history=history,
|
||||
llm=llm,
|
||||
prompt_template=INTERNET_SEARCH_QUERY_REPHRASE,
|
||||
)
|
||||
return {
|
||||
"internet_search_query": rephrased_query,
|
||||
}
|
||||
|
||||
def build_tool_message_content(
|
||||
self, *args: ToolResponse
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
search_response = cast(InternetSearchResponse, args[0].response)
|
||||
return json.dumps(search_response.dict())
|
||||
|
||||
def _perform_search(self, query: str) -> InternetSearchResponse:
|
||||
response = self.client.get(
|
||||
f"{self.host}/search",
|
||||
headers=self.headers,
|
||||
params={"q": query, "count": self.num_results},
|
||||
)
|
||||
results = response.json()
|
||||
|
||||
return InternetSearchResponse(
|
||||
revised_query=query,
|
||||
internet_results=[
|
||||
InternetSearchResult(
|
||||
title=result["name"],
|
||||
link=result["url"],
|
||||
snippet=result["snippet"],
|
||||
)
|
||||
for result in results["webPages"]["value"][: self.num_results]
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
|
||||
query = cast(str, kwargs["internet_search_query"])
|
||||
|
||||
results = self._perform_search(query)
|
||||
yield ToolResponse(
|
||||
id=INTERNET_SEARCH_RESPONSE_ID,
|
||||
response=results,
|
||||
)
|
||||
|
||||
llm_docs = [
|
||||
llm_doc_from_internet_search_result(result)
|
||||
for result in results.internet_results
|
||||
]
|
||||
|
||||
yield ToolResponse(
|
||||
id=FINAL_CONTEXT_DOCUMENTS,
|
||||
response=llm_docs,
|
||||
)
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
search_response = cast(InternetSearchResponse, args[0].response)
|
||||
return search_response.dict()
|
@ -46,7 +46,7 @@ class SearchResponseSummary(BaseModel):
|
||||
recency_bias_multiplier: float
|
||||
|
||||
|
||||
search_tool_description = """
|
||||
SEARCH_TOOL_DESCRIPTION = """
|
||||
Runs a semantic search over the user's knowledge base. The default behavior is to use this tool. \
|
||||
The only scenario where you should not use this tool is if:
|
||||
|
||||
@ -59,7 +59,9 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
|
||||
|
||||
|
||||
class SearchTool(Tool):
|
||||
NAME = "run_search"
|
||||
_NAME = "run_search"
|
||||
_DISPLAY_NAME = "Search Tool"
|
||||
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -95,8 +97,17 @@ class SearchTool(Tool):
|
||||
self.bypass_acl = bypass_acl
|
||||
self.db_session = db_session
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.NAME
|
||||
return self._NAME
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTION
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._DISPLAY_NAME
|
||||
|
||||
"""For explicit tool calling"""
|
||||
|
||||
@ -104,8 +115,8 @@ class SearchTool(Tool):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name(),
|
||||
"description": search_tool_description,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -9,10 +9,21 @@ from danswer.tools.models import ToolResponse
|
||||
|
||||
|
||||
class Tool(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def description(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def display_name(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
"""For LLMs which support explicit tool calling"""
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -18,7 +18,7 @@ class ToolRunner:
|
||||
self._tool_responses: list[ToolResponse] | None = None
|
||||
|
||||
def kickoff(self) -> ToolCallKickoff:
|
||||
return ToolCallKickoff(tool_name=self.tool.name(), tool_args=self.args)
|
||||
return ToolCallKickoff(tool_name=self.tool.name, tool_args=self.args)
|
||||
|
||||
def tool_responses(self) -> Generator[ToolResponse, None, None]:
|
||||
if self._tool_responses is not None:
|
||||
@ -37,7 +37,7 @@ class ToolRunner:
|
||||
|
||||
def tool_final_result(self) -> ToolCallFinalResult:
|
||||
return ToolCallFinalResult(
|
||||
tool_name=self.tool.name(),
|
||||
tool_name=self.tool.name,
|
||||
tool_args=self.args,
|
||||
tool_result=self.tool.final_result(*self.tool_responses()),
|
||||
)
|
||||
|
78
backend/danswer/tools/tool_selection.py
Normal file
78
backend/danswer/tools/tool_selection.py
Normal file
@ -0,0 +1,78 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.prompts.constants import GENERAL_SEP_PAT
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
SINGLE_TOOL_SELECTION_PROMPT = f"""
|
||||
You are an expert at selecting the most useful tool to run for answering the query.
|
||||
You will be given a numbered list of tools and their arguments, a message history, and a query.
|
||||
You will select a single tool that will be most useful for answering the query.
|
||||
Respond with only the number corresponding to the tool you want to use.
|
||||
|
||||
Conversation History:
|
||||
{GENERAL_SEP_PAT}
|
||||
{{chat_history}}
|
||||
{GENERAL_SEP_PAT}
|
||||
|
||||
Query:
|
||||
{{query}}
|
||||
|
||||
Tools:
|
||||
{{tool_list}}
|
||||
|
||||
Respond with EXACTLY and ONLY the number corresponding to the tool you want to use.
|
||||
|
||||
Your selection:
|
||||
"""
|
||||
|
||||
|
||||
def select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args: list[tuple[Tool, dict[str, Any]]],
|
||||
history: list[PreviousMessage],
|
||||
query: str,
|
||||
llm: LLM,
|
||||
) -> tuple[Tool, dict[str, Any]] | None:
|
||||
if len(tools_and_args) == 1:
|
||||
return tools_and_args[0]
|
||||
|
||||
tool_list_str = "\n".join(
|
||||
f"""```{ind}: {tool.name} ({args}) - {tool.description}```"""
|
||||
for ind, (tool, args) in enumerate(tools_and_args)
|
||||
).lstrip()
|
||||
|
||||
history_str = combine_message_chain(
|
||||
messages=history,
|
||||
token_limit=GEN_AI_HISTORY_CUTOFF,
|
||||
)
|
||||
prompt = SINGLE_TOOL_SELECTION_PROMPT.format(
|
||||
tool_list=tool_list_str, chat_history=history_str, query=query
|
||||
)
|
||||
output = message_to_string(llm.invoke(prompt))
|
||||
try:
|
||||
# First try to match the number
|
||||
number_match = re.search(r"\d+", output)
|
||||
if number_match:
|
||||
tool_ind = int(number_match.group())
|
||||
return tools_and_args[tool_ind]
|
||||
|
||||
# If that fails, try to match the tool name
|
||||
for tool, args in tools_and_args:
|
||||
if tool.name.lower() in output.lower():
|
||||
return tool, args
|
||||
|
||||
# If that fails, return the first tool
|
||||
return tools_and_args[0]
|
||||
|
||||
except Exception:
|
||||
logger.error(f"Failed to select single tool for non-tool-calling LLM: {output}")
|
||||
return None
|
1
backend/throttle.ctrl
Normal file
1
backend/throttle.ctrl
Normal file
@ -0,0 +1 @@
|
||||
f1f2 1 1718910083.03085 wikipedia:en
|
@ -53,6 +53,10 @@ function findImageGenerationTool(tools: ToolSnapshot[]) {
|
||||
return tools.find((tool) => tool.in_code_tool_id === "ImageGenerationTool");
|
||||
}
|
||||
|
||||
function findInternetSearchTool(tools: ToolSnapshot[]) {
|
||||
return tools.find((tool) => tool.in_code_tool_id === "InternetSearchTool");
|
||||
}
|
||||
|
||||
function SubLabel({ children }: { children: string | JSX.Element }) {
|
||||
return <div className="text-sm text-subtle mb-2">{children}</div>;
|
||||
}
|
||||
@ -150,16 +154,20 @@ export function AssistantEditor({
|
||||
const imageGenerationTool = providerSupportingImageGenerationExists
|
||||
? findImageGenerationTool(tools)
|
||||
: undefined;
|
||||
const internetSearchTool = findInternetSearchTool(tools);
|
||||
|
||||
const customTools = tools.filter(
|
||||
(tool) =>
|
||||
tool.in_code_tool_id !== searchTool?.in_code_tool_id &&
|
||||
tool.in_code_tool_id !== imageGenerationTool?.in_code_tool_id
|
||||
tool.in_code_tool_id !== imageGenerationTool?.in_code_tool_id &&
|
||||
tool.in_code_tool_id !== internetSearchTool?.in_code_tool_id
|
||||
);
|
||||
|
||||
const availableTools = [
|
||||
...customTools,
|
||||
...(searchTool ? [searchTool] : []),
|
||||
...(imageGenerationTool ? [imageGenerationTool] : []),
|
||||
...(internetSearchTool ? [internetSearchTool] : []),
|
||||
];
|
||||
const enabledToolsMap: { [key: number]: boolean } = {};
|
||||
availableTools.forEach((tool) => {
|
||||
@ -666,6 +674,17 @@ export function AssistantEditor({
|
||||
</>
|
||||
)}
|
||||
|
||||
{internetSearchTool && (
|
||||
<BooleanFormField
|
||||
noPadding
|
||||
name={`enabled_tools_map.${internetSearchTool.id}`}
|
||||
label={internetSearchTool.display_name}
|
||||
onChange={() => {
|
||||
toggleToolInValues(internetSearchTool.id);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
{customTools.length > 0 && (
|
||||
<>
|
||||
{customTools.map((tool) => (
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { Bubble } from "@/components/Bubble";
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
import { FiImage, FiSearch } from "react-icons/fi";
|
||||
import { FiImage, FiSearch, FiGlobe } from "react-icons/fi";
|
||||
|
||||
export function ToolsDisplay({ tools }: { tools: ToolSnapshot[] }) {
|
||||
return (
|
||||
@ -15,6 +15,9 @@ export function ToolsDisplay({ tools }: { tools: ToolSnapshot[] }) {
|
||||
} else if (tool.name === "ImageGenerationTool") {
|
||||
toolName = "Image Generation";
|
||||
toolIcon = <FiImage className="mr-1 my-auto" />;
|
||||
} else if (tool.name === "InternetSearchTool") {
|
||||
toolName = "Internet Search";
|
||||
toolIcon = <FiGlobe className="mr-1 my-auto" />;
|
||||
}
|
||||
|
||||
return (
|
||||
|
@ -10,6 +10,7 @@ import {
|
||||
DocumentMetadataBlock,
|
||||
buildDocumentSummaryDisplay,
|
||||
} from "@/components/search/DocumentDisplay";
|
||||
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
|
||||
|
||||
interface DocumentDisplayProps {
|
||||
document: DanswerDocument;
|
||||
@ -30,24 +31,25 @@ export function ChatDocumentDisplay({
|
||||
setPopup,
|
||||
tokenLimitReached,
|
||||
}: DocumentDisplayProps) {
|
||||
// Consider reintroducing null scored docs in the future
|
||||
if (document.score === null) {
|
||||
return null;
|
||||
}
|
||||
const isInternet = document.is_internet;
|
||||
|
||||
return (
|
||||
<div key={document.semantic_identifier} className="text-sm px-3">
|
||||
<div className="flex relative w-full overflow-y-visible">
|
||||
<a
|
||||
className={
|
||||
"rounded-lg flex font-bold flex-shrink truncate " +
|
||||
"rounded-lg flex font-bold flex-shrink truncate items-center " +
|
||||
(document.link ? "" : "pointer-events-none")
|
||||
}
|
||||
href={document.link}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
<SourceIcon sourceType={document.source_type} iconSize={18} />
|
||||
{isInternet ? (
|
||||
<InternetSearchIcon url={document.link} />
|
||||
) : (
|
||||
<SourceIcon sourceType={document.source_type} iconSize={18} />
|
||||
)}
|
||||
<p className="overflow-hidden text-ellipsis mx-2 my-auto text-sm ">
|
||||
{document.semantic_identifier || document.document_id}
|
||||
</p>
|
||||
@ -73,29 +75,16 @@ export function ChatDocumentDisplay({
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<div
|
||||
className={`
|
||||
text-xs
|
||||
text-emphasis
|
||||
bg-hover
|
||||
rounded
|
||||
p-0.5
|
||||
w-fit
|
||||
my-auto
|
||||
select-none
|
||||
my-auto
|
||||
mr-2`}
|
||||
>
|
||||
{Math.abs(document.score).toFixed(2)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<DocumentSelector
|
||||
isSelected={isSelected}
|
||||
handleSelect={() => handleSelect(document.document_id)}
|
||||
isDisabled={tokenLimitReached && !isSelected}
|
||||
/>
|
||||
{!isInternet && (
|
||||
<DocumentSelector
|
||||
isSelected={isSelected}
|
||||
handleSelect={() => handleSelect(document.document_id)}
|
||||
isDisabled={tokenLimitReached && !isSelected}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
<div className="mt-1">
|
||||
|
@ -530,7 +530,11 @@ export function checkAnyAssistantHasSearch(
|
||||
}
|
||||
|
||||
export function personaIncludesRetrieval(selectedPersona: Persona) {
|
||||
return selectedPersona.num_chunks !== 0;
|
||||
return selectedPersona.tools.some(
|
||||
(tool) =>
|
||||
tool.in_code_tool_id &&
|
||||
["SearchTool", "InternetSearchTool"].includes(tool.in_code_tool_id)
|
||||
);
|
||||
}
|
||||
|
||||
const PARAMS_TO_SKIP = [
|
||||
|
@ -10,6 +10,7 @@ import {
|
||||
FiChevronRight,
|
||||
FiChevronLeft,
|
||||
FiTool,
|
||||
FiGlobe,
|
||||
} from "react-icons/fi";
|
||||
import { FeedbackType } from "../types";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
@ -25,6 +26,7 @@ import { ChatFileType, FileDescriptor, ToolCallMetadata } from "../interfaces";
|
||||
import {
|
||||
IMAGE_GENERATION_TOOL_NAME,
|
||||
SEARCH_TOOL_NAME,
|
||||
INTERNET_SEARCH_TOOL_NAME,
|
||||
} from "../tools/constants";
|
||||
import { ToolRunDisplay } from "../tools/ToolRunningAnimation";
|
||||
import { Hoverable } from "@/components/Hoverable";
|
||||
@ -39,11 +41,12 @@ import Prism from "prismjs";
|
||||
import "prismjs/themes/prism-tomorrow.css";
|
||||
import "./custom-code-styles.css";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { Button } from "@tremor/react";
|
||||
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
|
||||
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
|
||||
|
||||
const TOOLS_WITH_CUSTOM_HANDLING = [
|
||||
SEARCH_TOOL_NAME,
|
||||
INTERNET_SEARCH_TOOL_NAME,
|
||||
IMAGE_GENERATION_TOOL_NAME,
|
||||
];
|
||||
|
||||
@ -149,6 +152,9 @@ export const AIMessage = ({
|
||||
content = trimIncompleteCodeSection(content);
|
||||
}
|
||||
|
||||
const danswerSearchToolEnabledForPersona = currentPersona.tools.some(
|
||||
(tool) => tool.in_code_tool_id === SEARCH_TOOL_NAME
|
||||
);
|
||||
const shouldShowLoader =
|
||||
!toolCall || (toolCall.tool_name === SEARCH_TOOL_NAME && !content);
|
||||
const defaultLoader = shouldShowLoader ? (
|
||||
@ -200,36 +206,37 @@ export const AIMessage = ({
|
||||
</div>
|
||||
|
||||
<div className="w-message-xs 2xl:w-message-sm 3xl:w-message-default break-words mt-1 ml-8">
|
||||
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && (
|
||||
<>
|
||||
{query !== undefined &&
|
||||
handleShowRetrieved !== undefined &&
|
||||
isCurrentlyShowingRetrieved !== undefined &&
|
||||
!retrievalDisabled && (
|
||||
<div className="my-1">
|
||||
<SearchSummary
|
||||
query={query}
|
||||
hasDocs={hasDocs || false}
|
||||
messageId={messageId}
|
||||
isCurrentlyShowingRetrieved={
|
||||
isCurrentlyShowingRetrieved
|
||||
}
|
||||
handleShowRetrieved={handleShowRetrieved}
|
||||
handleSearchQueryEdit={handleSearchQueryEdit}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{handleForceSearch &&
|
||||
content &&
|
||||
query === undefined &&
|
||||
!hasDocs &&
|
||||
!retrievalDisabled && (
|
||||
<div className="my-1">
|
||||
<SkippedSearch handleForceSearch={handleForceSearch} />
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) &&
|
||||
danswerSearchToolEnabledForPersona && (
|
||||
<>
|
||||
{query !== undefined &&
|
||||
handleShowRetrieved !== undefined &&
|
||||
isCurrentlyShowingRetrieved !== undefined &&
|
||||
!retrievalDisabled && (
|
||||
<div className="my-1">
|
||||
<SearchSummary
|
||||
query={query}
|
||||
hasDocs={hasDocs || false}
|
||||
messageId={messageId}
|
||||
isCurrentlyShowingRetrieved={
|
||||
isCurrentlyShowingRetrieved
|
||||
}
|
||||
handleShowRetrieved={handleShowRetrieved}
|
||||
handleSearchQueryEdit={handleSearchQueryEdit}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{handleForceSearch &&
|
||||
content &&
|
||||
query === undefined &&
|
||||
!hasDocs &&
|
||||
!retrievalDisabled && (
|
||||
<div className="my-1">
|
||||
<SkippedSearch handleForceSearch={handleForceSearch} />
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{toolCall &&
|
||||
!TOOLS_WITH_CUSTOM_HANDLING.includes(toolCall.tool_name) && (
|
||||
@ -258,6 +265,20 @@ export const AIMessage = ({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{toolCall && toolCall.tool_name === INTERNET_SEARCH_TOOL_NAME && (
|
||||
<div className="my-2">
|
||||
<ToolRunDisplay
|
||||
toolName={
|
||||
toolCall.tool_result
|
||||
? `Searched the internet`
|
||||
: `Searching the internet`
|
||||
}
|
||||
toolLogo={<FiGlobe size={15} className="my-auto mr-1" />}
|
||||
isRunning={!toolCall.tool_result}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{content ? (
|
||||
<>
|
||||
<FileDisplay files={files || []} />
|
||||
@ -317,12 +338,16 @@ export const AIMessage = ({
|
||||
.filter(([_, document]) => document.semantic_identifier)
|
||||
.map(([citationKey, document], ind) => {
|
||||
const display = (
|
||||
<div className="max-w-350 text-ellipsis flex text-sm border border-border py-1 px-2 rounded flex">
|
||||
<div className="max-w-350 text-ellipsis text-sm border border-border py-1 px-2 rounded flex">
|
||||
<div className="mr-1 my-auto">
|
||||
<SourceIcon
|
||||
sourceType={document.source_type}
|
||||
iconSize={16}
|
||||
/>
|
||||
{document.is_internet ? (
|
||||
<InternetSearchIcon url={document.link} />
|
||||
) : (
|
||||
<SourceIcon
|
||||
sourceType={document.source_type}
|
||||
iconSize={16}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
[{citationKey}] {document!.semantic_identifier}
|
||||
</div>
|
||||
|
@ -1,2 +1,3 @@
|
||||
export const SEARCH_TOOL_NAME = "run_search";
|
||||
export const INTERNET_SEARCH_TOOL_NAME = "run_internet_search";
|
||||
export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation";
|
||||
|
9
web/src/components/InternetSearchIcon.tsx
Normal file
9
web/src/components/InternetSearchIcon.tsx
Normal file
@ -0,0 +1,9 @@
|
||||
export function InternetSearchIcon({ url }: { url: string }) {
|
||||
return (
|
||||
<img
|
||||
className="rounded-full w-[18px] h-[18px]"
|
||||
src={`https://www.google.com/s2/favicons?sz=128&domain=${url}`}
|
||||
alt="favicon"
|
||||
/>
|
||||
);
|
||||
}
|
@ -11,6 +11,7 @@ export const SearchType = {
|
||||
SEMANTIC: "semantic",
|
||||
KEYWORD: "keyword",
|
||||
AUTOMATIC: "automatic",
|
||||
INTERNET: "internet",
|
||||
};
|
||||
export type SearchType = (typeof SearchType)[keyof typeof SearchType];
|
||||
|
||||
@ -48,6 +49,7 @@ export interface DanswerDocument {
|
||||
metadata: { [key: string]: string };
|
||||
updated_at: string | null;
|
||||
db_doc_id?: number;
|
||||
is_internet: boolean;
|
||||
}
|
||||
|
||||
export interface DocumentInfoPacket {
|
||||
|
@ -39,7 +39,6 @@ import {
|
||||
import { ValidSources } from "./types";
|
||||
import { SourceCategory, SourceMetadata } from "./search/interfaces";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import internal from "stream";
|
||||
|
||||
interface PartialSourceMetadata {
|
||||
icon: React.FC<{ size?: number; className?: string }>;
|
||||
@ -232,6 +231,11 @@ const SOURCE_METADATA_MAP: SourceMap = {
|
||||
displayName: "Google Storage",
|
||||
category: SourceCategory.AppConnection,
|
||||
},
|
||||
not_applicable: {
|
||||
icon: GlobeIcon,
|
||||
displayName: "Internet",
|
||||
category: SourceCategory.ImportedKnowledge,
|
||||
},
|
||||
};
|
||||
|
||||
function fillSourceMetadata(
|
||||
|
@ -1,6 +1,7 @@
|
||||
export interface ToolSnapshot {
|
||||
id: number;
|
||||
name: string;
|
||||
display_name: string;
|
||||
description: string;
|
||||
|
||||
// only specified for Custom Tools. OpenAPI schema which represents
|
||||
|
@ -63,7 +63,8 @@ export type ValidSources =
|
||||
| "s3"
|
||||
| "r2"
|
||||
| "google_cloud_storage"
|
||||
| "oci_storage";
|
||||
| "oci_storage"
|
||||
| "not_applicable";
|
||||
|
||||
export type ValidInputTypes = "load_state" | "poll" | "event";
|
||||
export type ValidStatuses =
|
||||
|
Loading…
x
Reference in New Issue
Block a user