mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-09 04:18:32 +02:00
Add custom tool chat session / message ID dynamic prompting (#2404)
* add custom tool chat session / message ID dynamic prompting * update some formatting * code organization + remove unnecessary card * remove log * update for clarity
This commit is contained in:
parent
fc98c560a4
commit
648c2531f9
@ -88,6 +88,7 @@ from danswer.tools.internet_search.internet_search_tool import (
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
@ -605,7 +606,11 @@ def stream_chat_message_objects(
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema(
|
||||
db_tool_model.openapi_schema
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -24,6 +24,9 @@ from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
|
||||
from danswer.tools.custom.openapi_parsing import openapi_to_url
|
||||
from danswer.tools.custom.openapi_parsing import REQUEST_BODY
|
||||
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
|
||||
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.models import MESSAGE_ID_PLACEHOLDER
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
@ -39,7 +42,11 @@ class CustomToolCallSummary(BaseModel):
|
||||
|
||||
|
||||
class CustomTool(Tool):
|
||||
def __init__(self, method_spec: MethodSpec, base_url: str) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
method_spec: MethodSpec,
|
||||
base_url: str,
|
||||
) -> None:
|
||||
self._base_url = base_url
|
||||
self._method_spec = method_spec
|
||||
self._tool_definition = self._method_spec.to_tool_definition()
|
||||
@ -141,6 +148,7 @@ class CustomTool(Tool):
|
||||
request_body = kwargs.get(REQUEST_BODY)
|
||||
|
||||
path_params = {}
|
||||
|
||||
for path_param_schema in self._method_spec.get_path_param_schemas():
|
||||
path_params[path_param_schema["name"]] = kwargs[path_param_schema["name"]]
|
||||
|
||||
@ -168,8 +176,23 @@ class CustomTool(Tool):
|
||||
|
||||
|
||||
def build_custom_tools_from_openapi_schema(
|
||||
openapi_schema: dict[str, Any]
|
||||
openapi_schema: dict[str, Any],
|
||||
dynamic_schema_info: DynamicSchemaInfo | None = None,
|
||||
) -> list[CustomTool]:
|
||||
if dynamic_schema_info:
|
||||
# Process dynamic schema information
|
||||
schema_str = json.dumps(openapi_schema)
|
||||
placeholders = {
|
||||
CHAT_SESSION_ID_PLACEHOLDER: dynamic_schema_info.chat_session_id,
|
||||
MESSAGE_ID_PLACEHOLDER: dynamic_schema_info.message_id,
|
||||
}
|
||||
|
||||
for placeholder, value in placeholders.items():
|
||||
if value:
|
||||
schema_str = schema_str.replace(placeholder, str(value))
|
||||
|
||||
openapi_schema = json.loads(schema_str)
|
||||
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
return [CustomTool(method_spec, url) for method_spec in method_specs]
|
||||
@ -223,7 +246,9 @@ if __name__ == "__main__":
|
||||
}
|
||||
validate_openapi_schema(openapi_schema)
|
||||
|
||||
tools = build_custom_tools_from_openapi_schema(openapi_schema)
|
||||
tools = build_custom_tools_from_openapi_schema(
|
||||
openapi_schema, dynamic_schema_info=None
|
||||
)
|
||||
|
||||
openai_client = openai.OpenAI()
|
||||
response = openai_client.chat.completions.create(
|
||||
|
@ -37,3 +37,12 @@ class ToolCallFinalResult(ToolCallKickoff):
|
||||
tool_result: Any = (
|
||||
None # we would like to use JSON_ro, but can't due to its recursive nature
|
||||
)
|
||||
|
||||
|
||||
class DynamicSchemaInfo(BaseModel):
|
||||
chat_session_id: int | None
|
||||
message_id: int | None
|
||||
|
||||
|
||||
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
|
||||
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"
|
||||
|
@ -6,7 +6,7 @@ import { Formik, Form, Field, ErrorMessage } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
import { Button, Divider } from "@tremor/react";
|
||||
import { Button, Divider, Text } from "@tremor/react";
|
||||
import {
|
||||
createCustomTool,
|
||||
updateCustomTool,
|
||||
@ -14,6 +14,7 @@ import {
|
||||
} from "@/lib/tools/edit";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import debounce from "lodash/debounce";
|
||||
import Link from "next/link";
|
||||
|
||||
function parseJsonWithTrailingCommas(jsonString: string) {
|
||||
// Regular expression to remove trailing commas before } or ]
|
||||
@ -85,8 +86,8 @@ function ToolForm({
|
||||
}, [values.definition, debouncedValidateDefinition]);
|
||||
|
||||
return (
|
||||
<Form>
|
||||
<div className="relative">
|
||||
<Form className="max-w-4xl">
|
||||
<div className="relative w-full">
|
||||
<TextFormField
|
||||
name="definition"
|
||||
label="Definition"
|
||||
@ -138,6 +139,28 @@ function ToolForm({
|
||||
component="div"
|
||||
className="text-error text-sm"
|
||||
/>
|
||||
<div className="mt-4 text-sm bg-blue-50 p-4 rounded-md border border-blue-200">
|
||||
<Link
|
||||
href="https://docs.danswer.dev/tools/custom"
|
||||
className="text-link hover:underline flex items-center"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-5 w-5 mr-2"
|
||||
viewBox="0 0 20 20"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path
|
||||
fillRule="evenodd"
|
||||
d="M18 10a8 8 0 11-16 0 8 8 0 0116 0zm-7-4a1 1 0 11-2 0 1 1 0 012 0zM9 9a1 1 0 000 2v3a1 1 0 001 1h1a1 1 0 100-2v-3a1 1 0 00-1-1H9z"
|
||||
clipRule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
Learn more about tool calling in our documentation
|
||||
</Link>
|
||||
</div>
|
||||
|
||||
{methodSpecs && methodSpecs.length > 0 && (
|
||||
<div className="mt-4">
|
||||
|
Loading…
x
Reference in New Issue
Block a user