mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 13:53:28 +02:00
Add custom tool chat session / message ID dynamic prompting (#2404)
* add custom tool chat session / message ID dynamic prompting * update some formatting * code organization + remove unnecessary card * remove log * update for clarity
This commit is contained in:
@@ -88,6 +88,7 @@ from danswer.tools.internet_search.internet_search_tool import (
|
|||||||
)
|
)
|
||||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||||
|
from danswer.tools.models import DynamicSchemaInfo
|
||||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||||
@@ -605,7 +606,11 @@ def stream_chat_message_objects(
|
|||||||
tool_dict[db_tool_model.id] = cast(
|
tool_dict[db_tool_model.id] = cast(
|
||||||
list[Tool],
|
list[Tool],
|
||||||
build_custom_tools_from_openapi_schema(
|
build_custom_tools_from_openapi_schema(
|
||||||
db_tool_model.openapi_schema
|
db_tool_model.openapi_schema,
|
||||||
|
dynamic_schema_info=DynamicSchemaInfo(
|
||||||
|
chat_session_id=chat_session_id,
|
||||||
|
message_id=user_message.id if user_message else None,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -24,6 +24,9 @@ from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
|
|||||||
from danswer.tools.custom.openapi_parsing import openapi_to_url
|
from danswer.tools.custom.openapi_parsing import openapi_to_url
|
||||||
from danswer.tools.custom.openapi_parsing import REQUEST_BODY
|
from danswer.tools.custom.openapi_parsing import REQUEST_BODY
|
||||||
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
|
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
|
||||||
|
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
|
||||||
|
from danswer.tools.models import DynamicSchemaInfo
|
||||||
|
from danswer.tools.models import MESSAGE_ID_PLACEHOLDER
|
||||||
from danswer.tools.tool import Tool
|
from danswer.tools.tool import Tool
|
||||||
from danswer.tools.tool import ToolResponse
|
from danswer.tools.tool import ToolResponse
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
@@ -39,7 +42,11 @@ class CustomToolCallSummary(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class CustomTool(Tool):
|
class CustomTool(Tool):
|
||||||
def __init__(self, method_spec: MethodSpec, base_url: str) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
method_spec: MethodSpec,
|
||||||
|
base_url: str,
|
||||||
|
) -> None:
|
||||||
self._base_url = base_url
|
self._base_url = base_url
|
||||||
self._method_spec = method_spec
|
self._method_spec = method_spec
|
||||||
self._tool_definition = self._method_spec.to_tool_definition()
|
self._tool_definition = self._method_spec.to_tool_definition()
|
||||||
@@ -141,6 +148,7 @@ class CustomTool(Tool):
|
|||||||
request_body = kwargs.get(REQUEST_BODY)
|
request_body = kwargs.get(REQUEST_BODY)
|
||||||
|
|
||||||
path_params = {}
|
path_params = {}
|
||||||
|
|
||||||
for path_param_schema in self._method_spec.get_path_param_schemas():
|
for path_param_schema in self._method_spec.get_path_param_schemas():
|
||||||
path_params[path_param_schema["name"]] = kwargs[path_param_schema["name"]]
|
path_params[path_param_schema["name"]] = kwargs[path_param_schema["name"]]
|
||||||
|
|
||||||
@@ -168,8 +176,23 @@ class CustomTool(Tool):
|
|||||||
|
|
||||||
|
|
||||||
def build_custom_tools_from_openapi_schema(
|
def build_custom_tools_from_openapi_schema(
|
||||||
openapi_schema: dict[str, Any]
|
openapi_schema: dict[str, Any],
|
||||||
|
dynamic_schema_info: DynamicSchemaInfo | None = None,
|
||||||
) -> list[CustomTool]:
|
) -> list[CustomTool]:
|
||||||
|
if dynamic_schema_info:
|
||||||
|
# Process dynamic schema information
|
||||||
|
schema_str = json.dumps(openapi_schema)
|
||||||
|
placeholders = {
|
||||||
|
CHAT_SESSION_ID_PLACEHOLDER: dynamic_schema_info.chat_session_id,
|
||||||
|
MESSAGE_ID_PLACEHOLDER: dynamic_schema_info.message_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
for placeholder, value in placeholders.items():
|
||||||
|
if value:
|
||||||
|
schema_str = schema_str.replace(placeholder, str(value))
|
||||||
|
|
||||||
|
openapi_schema = json.loads(schema_str)
|
||||||
|
|
||||||
url = openapi_to_url(openapi_schema)
|
url = openapi_to_url(openapi_schema)
|
||||||
method_specs = openapi_to_method_specs(openapi_schema)
|
method_specs = openapi_to_method_specs(openapi_schema)
|
||||||
return [CustomTool(method_spec, url) for method_spec in method_specs]
|
return [CustomTool(method_spec, url) for method_spec in method_specs]
|
||||||
@@ -223,7 +246,9 @@ if __name__ == "__main__":
|
|||||||
}
|
}
|
||||||
validate_openapi_schema(openapi_schema)
|
validate_openapi_schema(openapi_schema)
|
||||||
|
|
||||||
tools = build_custom_tools_from_openapi_schema(openapi_schema)
|
tools = build_custom_tools_from_openapi_schema(
|
||||||
|
openapi_schema, dynamic_schema_info=None
|
||||||
|
)
|
||||||
|
|
||||||
openai_client = openai.OpenAI()
|
openai_client = openai.OpenAI()
|
||||||
response = openai_client.chat.completions.create(
|
response = openai_client.chat.completions.create(
|
||||||
|
@@ -37,3 +37,12 @@ class ToolCallFinalResult(ToolCallKickoff):
|
|||||||
tool_result: Any = (
|
tool_result: Any = (
|
||||||
None # we would like to use JSON_ro, but can't due to its recursive nature
|
None # we would like to use JSON_ro, but can't due to its recursive nature
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicSchemaInfo(BaseModel):
|
||||||
|
chat_session_id: int | None
|
||||||
|
message_id: int | None
|
||||||
|
|
||||||
|
|
||||||
|
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
|
||||||
|
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"
|
||||||
|
@@ -6,7 +6,7 @@ import { Formik, Form, Field, ErrorMessage } from "formik";
|
|||||||
import * as Yup from "yup";
|
import * as Yup from "yup";
|
||||||
import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces";
|
import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces";
|
||||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||||
import { Button, Divider } from "@tremor/react";
|
import { Button, Divider, Text } from "@tremor/react";
|
||||||
import {
|
import {
|
||||||
createCustomTool,
|
createCustomTool,
|
||||||
updateCustomTool,
|
updateCustomTool,
|
||||||
@@ -14,6 +14,7 @@ import {
|
|||||||
} from "@/lib/tools/edit";
|
} from "@/lib/tools/edit";
|
||||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||||
import debounce from "lodash/debounce";
|
import debounce from "lodash/debounce";
|
||||||
|
import Link from "next/link";
|
||||||
|
|
||||||
function parseJsonWithTrailingCommas(jsonString: string) {
|
function parseJsonWithTrailingCommas(jsonString: string) {
|
||||||
// Regular expression to remove trailing commas before } or ]
|
// Regular expression to remove trailing commas before } or ]
|
||||||
@@ -85,8 +86,8 @@ function ToolForm({
|
|||||||
}, [values.definition, debouncedValidateDefinition]);
|
}, [values.definition, debouncedValidateDefinition]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Form>
|
<Form className="max-w-4xl">
|
||||||
<div className="relative">
|
<div className="relative w-full">
|
||||||
<TextFormField
|
<TextFormField
|
||||||
name="definition"
|
name="definition"
|
||||||
label="Definition"
|
label="Definition"
|
||||||
@@ -138,6 +139,28 @@ function ToolForm({
|
|||||||
component="div"
|
component="div"
|
||||||
className="text-error text-sm"
|
className="text-error text-sm"
|
||||||
/>
|
/>
|
||||||
|
<div className="mt-4 text-sm bg-blue-50 p-4 rounded-md border border-blue-200">
|
||||||
|
<Link
|
||||||
|
href="https://docs.danswer.dev/tools/custom"
|
||||||
|
className="text-link hover:underline flex items-center"
|
||||||
|
target="_blank"
|
||||||
|
rel="noopener noreferrer"
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
className="h-5 w-5 mr-2"
|
||||||
|
viewBox="0 0 20 20"
|
||||||
|
fill="currentColor"
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
fillRule="evenodd"
|
||||||
|
d="M18 10a8 8 0 11-16 0 8 8 0 0116 0zm-7-4a1 1 0 11-2 0 1 1 0 012 0zM9 9a1 1 0 000 2v3a1 1 0 001 1h1a1 1 0 100-2v-3a1 1 0 00-1-1H9z"
|
||||||
|
clipRule="evenodd"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
Learn more about tool calling in our documentation
|
||||||
|
</Link>
|
||||||
|
</div>
|
||||||
|
|
||||||
{methodSpecs && methodSpecs.length > 0 && (
|
{methodSpecs && methodSpecs.length > 0 && (
|
||||||
<div className="mt-4">
|
<div className="mt-4">
|
||||||
|
Reference in New Issue
Block a user