Add custom tool chat session / message ID dynamic prompting (#2404)

* add custom tool chat session / message ID dynamic prompting

* update some formatting

* code organization + remove unnecessary card

* remove log

* update for clarity
This commit is contained in:
pablodanswer 2024-09-13 11:42:21 -07:00 committed by GitHub
parent fc98c560a4
commit 648c2531f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 69 additions and 7 deletions

View File

@ -88,6 +88,7 @@ from danswer.tools.internet_search.internet_search_tool import (
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
@ -605,7 +606,11 @@ def stream_chat_message_objects(
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema(
db_tool_model.openapi_schema
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
),
)

View File

@ -24,6 +24,9 @@ from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
from danswer.tools.custom.openapi_parsing import openapi_to_url
from danswer.tools.custom.openapi_parsing import REQUEST_BODY
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import MESSAGE_ID_PLACEHOLDER
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.logger import setup_logger
@ -39,7 +42,11 @@ class CustomToolCallSummary(BaseModel):
class CustomTool(Tool):
def __init__(self, method_spec: MethodSpec, base_url: str) -> None:
def __init__(
self,
method_spec: MethodSpec,
base_url: str,
) -> None:
self._base_url = base_url
self._method_spec = method_spec
self._tool_definition = self._method_spec.to_tool_definition()
@ -141,6 +148,7 @@ class CustomTool(Tool):
request_body = kwargs.get(REQUEST_BODY)
path_params = {}
for path_param_schema in self._method_spec.get_path_param_schemas():
path_params[path_param_schema["name"]] = kwargs[path_param_schema["name"]]
@ -168,8 +176,23 @@ class CustomTool(Tool):
def build_custom_tools_from_openapi_schema(
openapi_schema: dict[str, Any]
openapi_schema: dict[str, Any],
dynamic_schema_info: DynamicSchemaInfo | None = None,
) -> list[CustomTool]:
if dynamic_schema_info:
# Process dynamic schema information
schema_str = json.dumps(openapi_schema)
placeholders = {
CHAT_SESSION_ID_PLACEHOLDER: dynamic_schema_info.chat_session_id,
MESSAGE_ID_PLACEHOLDER: dynamic_schema_info.message_id,
}
for placeholder, value in placeholders.items():
if value:
schema_str = schema_str.replace(placeholder, str(value))
openapi_schema = json.loads(schema_str)
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
return [CustomTool(method_spec, url) for method_spec in method_specs]
@ -223,7 +246,9 @@ if __name__ == "__main__":
}
validate_openapi_schema(openapi_schema)
tools = build_custom_tools_from_openapi_schema(openapi_schema)
tools = build_custom_tools_from_openapi_schema(
openapi_schema, dynamic_schema_info=None
)
openai_client = openai.OpenAI()
response = openai_client.chat.completions.create(

View File

@ -37,3 +37,12 @@ class ToolCallFinalResult(ToolCallKickoff):
tool_result: Any = (
None # we would like to use JSON_ro, but can't due to its recursive nature
)
class DynamicSchemaInfo(BaseModel):
chat_session_id: int | None
message_id: int | None
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"

View File

@ -6,7 +6,7 @@ import { Formik, Form, Field, ErrorMessage } from "formik";
import * as Yup from "yup";
import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces";
import { TextFormField } from "@/components/admin/connectors/Field";
import { Button, Divider } from "@tremor/react";
import { Button, Divider, Text } from "@tremor/react";
import {
createCustomTool,
updateCustomTool,
@ -14,6 +14,7 @@ import {
} from "@/lib/tools/edit";
import { usePopup } from "@/components/admin/connectors/Popup";
import debounce from "lodash/debounce";
import Link from "next/link";
function parseJsonWithTrailingCommas(jsonString: string) {
// Regular expression to remove trailing commas before } or ]
@ -85,8 +86,8 @@ function ToolForm({
}, [values.definition, debouncedValidateDefinition]);
return (
<Form>
<div className="relative">
<Form className="max-w-4xl">
<div className="relative w-full">
<TextFormField
name="definition"
label="Definition"
@ -138,6 +139,28 @@ function ToolForm({
component="div"
className="text-error text-sm"
/>
<div className="mt-4 text-sm bg-blue-50 p-4 rounded-md border border-blue-200">
<Link
href="https://docs.danswer.dev/tools/custom"
className="text-link hover:underline flex items-center"
target="_blank"
rel="noopener noreferrer"
>
<svg
xmlns="http://www.w3.org/2000/svg"
className="h-5 w-5 mr-2"
viewBox="0 0 20 20"
fill="currentColor"
>
<path
fillRule="evenodd"
d="M18 10a8 8 0 11-16 0 8 8 0 0116 0zm-7-4a1 1 0 11-2 0 1 1 0 012 0zM9 9a1 1 0 000 2v3a1 1 0 001 1h1a1 1 0 100-2v-3a1 1 0 00-1-1H9z"
clipRule="evenodd"
/>
</svg>
Learn more about tool calling in our documentation
</Link>
</div>
{methodSpecs && methodSpecs.length > 0 && (
<div className="mt-4">