Add additional custom tooling configuration (#2426)

* add custom headers

* add tool seeding

* squash

* tmep

* validated

* rm

* update typing

* update alembic

* update import name

* reformat

* alembic
This commit is contained in:
pablodanswer 2024-09-20 16:12:52 -07:00 committed by GitHub
parent 33f555922c
commit 18c62a0c24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 242 additions and 25 deletions

View File

@ -0,0 +1,26 @@
"""add custom headers to tools
Revision ID: f32615f71aeb
Revises: bd2921608c3a
Create Date: 2024-09-12 20:26:38.932377
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "f32615f71aeb"
down_revision = "bd2921608c3a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"tool", sa.Column("custom_headers", postgresql.JSONB(), nullable=True)
)
def downgrade() -> None:
op.drop_column("tool", "custom_headers")

View File

@ -73,7 +73,9 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
@ -607,12 +609,13 @@ def stream_chat_message_objects(
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema(
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=db_tool_model.custom_headers,
),
)

View File

@ -1255,7 +1255,9 @@ class Tool(Base):
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
custom_headers: Mapped[list[dict[str, str]] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# user who created / owns the tool. Will be None for built-in tools.
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True

View File

@ -5,6 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import Tool
from danswer.server.features.tool.models import Header
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -25,6 +26,7 @@ def create_tool(
name: str,
description: str | None,
openapi_schema: dict[str, Any] | None,
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
) -> Tool:
@ -33,6 +35,9 @@ def create_tool(
description=description,
in_code_tool_id=None,
openapi_schema=openapi_schema,
custom_headers=[header.dict() for header in custom_headers]
if custom_headers
else [],
user_id=user_id,
)
db_session.add(new_tool)
@ -45,6 +50,7 @@ def update_tool(
name: str | None,
description: str | None,
openapi_schema: dict[str, Any] | None,
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
) -> Tool:
@ -60,6 +66,8 @@ def update_tool(
tool.openapi_schema = openapi_schema
if user_id is not None:
tool.user_id = user_id
if custom_headers is not None:
tool.custom_headers = [header.dict() for header in custom_headers]
db_session.commit()
return tool

View File

@ -15,6 +15,8 @@ from danswer.db.tools import delete_tool
from danswer.db.tools import get_tool_by_id
from danswer.db.tools import get_tools
from danswer.db.tools import update_tool
from danswer.server.features.tool.models import CustomToolCreate
from danswer.server.features.tool.models import CustomToolUpdate
from danswer.server.features.tool.models import ToolSnapshot
from danswer.tools.custom.openapi_parsing import MethodSpec
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
@ -24,18 +26,6 @@ router = APIRouter(prefix="/tool")
admin_router = APIRouter(prefix="/admin/tool")
class CustomToolCreate(BaseModel):
name: str
description: str | None = None
definition: dict[str, Any]
class CustomToolUpdate(BaseModel):
name: str | None = None
description: str | None = None
definition: dict[str, Any] | None = None
def _validate_tool_definition(definition: dict[str, Any]) -> None:
try:
validate_openapi_schema(definition)
@ -54,6 +44,7 @@ def create_custom_tool(
name=tool_data.name,
description=tool_data.description,
openapi_schema=tool_data.definition,
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
)
@ -74,6 +65,7 @@ def update_custom_tool(
name=tool_data.name,
description=tool_data.description,
openapi_schema=tool_data.definition,
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
)

View File

@ -12,6 +12,7 @@ class ToolSnapshot(BaseModel):
definition: dict[str, Any] | None
display_name: str
in_code_tool_id: str | None
custom_headers: list[Any] | None
@classmethod
def from_model(cls, tool: Tool) -> "ToolSnapshot":
@ -22,4 +23,24 @@ class ToolSnapshot(BaseModel):
definition=tool.openapi_schema,
display_name=tool.display_name or tool.name,
in_code_tool_id=tool.in_code_tool_id,
custom_headers=tool.custom_headers,
)
class Header(BaseModel):
key: str
value: str
class CustomToolCreate(BaseModel):
name: str
description: str | None = None
definition: dict[str, Any]
custom_headers: list[Header] | None = None
class CustomToolUpdate(BaseModel):
name: str | None = None
description: str | None = None
definition: dict[str, Any] | None = None
custom_headers: list[Header] | None = None

View File

@ -46,6 +46,7 @@ class CustomTool(Tool):
self,
method_spec: MethodSpec,
base_url: str,
custom_headers: list[dict[str, str]] | None = [],
) -> None:
self._base_url = base_url
self._method_spec = method_spec
@ -53,6 +54,11 @@ class CustomTool(Tool):
self._name = self._method_spec.name
self._description = self._method_spec.summary
self.headers = (
{header["key"]: header["value"] for header in custom_headers}
if custom_headers
else {}
)
@property
def name(self) -> str:
@ -161,8 +167,10 @@ class CustomTool(Tool):
url = self._method_spec.build_url(self._base_url, path_params, query_params)
method = self._method_spec.method
response = requests.request(method, url, json=request_body)
# Log request details
response = requests.request(
method, url, json=request_body, headers=self.headers
)
yield ToolResponse(
id=CUSTOM_TOOL_RESPONSE_ID,
@ -175,8 +183,9 @@ class CustomTool(Tool):
return cast(CustomToolCallSummary, args[0].response).tool_result
def build_custom_tools_from_openapi_schema(
def build_custom_tools_from_openapi_schema_and_headers(
openapi_schema: dict[str, Any],
custom_headers: list[dict[str, str]] | None = [],
dynamic_schema_info: DynamicSchemaInfo | None = None,
) -> list[CustomTool]:
if dynamic_schema_info:
@ -195,7 +204,9 @@ def build_custom_tools_from_openapi_schema(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
return [CustomTool(method_spec, url) for method_spec in method_specs]
return [
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
]
if __name__ == "__main__":
@ -246,7 +257,7 @@ if __name__ == "__main__":
}
validate_openapi_schema(openapi_schema)
tools = build_custom_tools_from_openapi_schema(
tools = build_custom_tools_from_openapi_schema_and_headers(
openapi_schema, dynamic_schema_info=None
)

View File

@ -12,7 +12,9 @@ from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.one_shot_answer.models import PersonaConfig
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
def create_temporary_persona(
@ -58,7 +60,7 @@ def create_temporary_persona(
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema(schema),
build_custom_tools_from_openapi_schema_and_headers(schema),
)
persona.tools.extend(tools)

View File

@ -1,4 +1,7 @@
import json
import os
from typing import List
from typing import Optional
from pydantic import BaseModel
from sqlalchemy.orm import Session
@ -6,6 +9,7 @@ from sqlalchemy.orm import Session
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import Tool
from danswer.db.persona import upsert_persona
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.persona.models import CreatePersonaRequest
@ -25,6 +29,16 @@ from ee.danswer.server.enterprise_settings.store import (
from ee.danswer.server.enterprise_settings.store import upload_logo
class CustomToolSeed(BaseModel):
name: str
description: str
definition_path: str
custom_headers: Optional[List[dict]] = None
display_name: Optional[str] = None
in_code_tool_id: Optional[str] = None
user_id: Optional[str] = None
logger = setup_logger()
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
@ -39,6 +53,7 @@ class SeedConfiguration(BaseModel):
enterprise_settings: EnterpriseSettings | None = None
# Use existing `CUSTOM_ANALYTICS_SECRET_KEY` for reference
analytics_script_path: str | None = None
custom_tools: List[CustomToolSeed] | None = None
def _parse_env() -> SeedConfiguration | None:
@ -49,6 +64,43 @@ def _parse_env() -> SeedConfiguration | None:
return seed_config
def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None:
if tools:
logger.notice("Seeding Custom Tools")
for tool in tools:
try:
logger.debug(f"Attempting to seed tool: {tool.name}")
logger.debug(f"Reading definition from: {tool.definition_path}")
with open(tool.definition_path, "r") as file:
file_content = file.read()
if not file_content.strip():
raise ValueError("File is empty")
openapi_schema = json.loads(file_content)
db_tool = Tool(
name=tool.name,
description=tool.description,
openapi_schema=openapi_schema,
custom_headers=tool.custom_headers,
display_name=tool.display_name,
in_code_tool_id=tool.in_code_tool_id,
user_id=tool.user_id,
)
db_session.add(db_tool)
logger.debug(f"Successfully added tool: {tool.name}")
except FileNotFoundError:
logger.error(
f"Definition file not found for tool {tool.name}: {tool.definition_path}"
)
except json.JSONDecodeError as e:
logger.error(
f"Invalid JSON in definition file for tool {tool.name}: {str(e)}"
)
except Exception as e:
logger.error(f"Failed to seed tool {tool.name}: {str(e)}")
db_session.commit()
logger.notice(f"Successfully seeded {len(tools)} Custom Tools")
def _seed_llms(
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
@ -147,6 +199,8 @@ def seed_db() -> None:
_seed_personas(db_session, seed_config.personas)
if seed_config.settings is not None:
_seed_settings(seed_config.settings)
if seed_config.custom_tools is not None:
_seed_custom_tools(db_session, seed_config.custom_tools)
_seed_logo(db_session, seed_config.seeded_logo_path)
_seed_enterprise_settings(seed_config)

View File

@ -2,7 +2,14 @@
import { useState, useEffect, useCallback } from "react";
import { useRouter } from "next/navigation";
import { Formik, Form, Field, ErrorMessage } from "formik";
import {
Formik,
Form,
Field,
ErrorMessage,
FieldArray,
ArrayHelpers,
} from "formik";
import * as Yup from "yup";
import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces";
import { TextFormField } from "@/components/admin/connectors/Field";
@ -14,6 +21,7 @@ import {
} from "@/lib/tools/edit";
import { usePopup } from "@/components/admin/connectors/Popup";
import debounce from "lodash/debounce";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import Link from "next/link";
function parseJsonWithTrailingCommas(jsonString: string) {
@ -55,6 +63,7 @@ function ToolForm({
}) {
const [definitionError, setDefinitionError] = definitionErrorState;
const [methodSpecs, setMethodSpecs] = methodSpecsState;
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
const debouncedValidateDefinition = useCallback(
debounce(async (definition: string) => {
@ -137,7 +146,7 @@ function ToolForm({
<ErrorMessage
name="definition"
component="div"
className="text-error text-sm"
className="mb-4 text-error text-sm"
/>
<div className="mt-4 text-sm bg-blue-50 p-4 rounded-md border border-blue-200">
<Link
@ -163,7 +172,7 @@ function ToolForm({
</div>
{methodSpecs && methodSpecs.length > 0 && (
<div className="mt-4">
<div className="my-4">
<h3 className="text-base font-semibold mb-2">Available methods</h3>
<div className="overflow-x-auto">
<table className="min-w-full bg-white border border-gray-200">
@ -192,7 +201,75 @@ function ToolForm({
</div>
)}
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}
setShowAdvancedOptions={setShowAdvancedOptions}
/>
{showAdvancedOptions && (
<div>
<h3 className="text-xl font-bold mb-2 text-primary-600">
Custom Headers
</h3>
<p className="text-sm mb-6 text-gray-600 italic">
Specify custom headers for each request to this tool&apos;s API.
</p>
<FieldArray
name="customHeaders"
render={(arrayHelpers: ArrayHelpers) => (
<div className="space-y-4">
{values.customHeaders && values.customHeaders.length > 0 && (
<div className="space-y-3">
{values.customHeaders.map(
(
header: { key: string; value: string },
index: number
) => (
<div
key={index}
className="flex items-center space-x-2 bg-gray-50 p-3 rounded-lg shadow-sm"
>
<Field
name={`customHeaders.${index}.key`}
placeholder="Header Key"
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
/>
<Field
name={`customHeaders.${index}.value`}
placeholder="Header Value"
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
/>
<Button
type="button"
onClick={() => arrayHelpers.remove(index)}
color="red"
size="sm"
className="transition-colors duration-200 hover:bg-red-600"
>
Remove
</Button>
</div>
)
)}
</div>
)}
<Button
type="button"
onClick={() => arrayHelpers.push({ key: "", value: "" })}
color="blue"
size="md"
className="transition-colors duration-200"
>
Add New Header
</Button>
</div>
)}
/>
</div>
)}
<Divider />
<div className="flex">
<Button
className="mx-auto"
@ -210,10 +287,19 @@ function ToolForm({
interface ToolFormValues {
definition: string;
customHeaders: { key: string; value: string }[];
}
const ToolSchema = Yup.object().shape({
definition: Yup.string().required("Tool definition is required"),
customHeaders: Yup.array()
.of(
Yup.object().shape({
key: Yup.string().required("Header key is required"),
value: Yup.string().required("Header value is required"),
})
)
.default([]),
});
export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
@ -232,6 +318,10 @@ export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
<Formik
initialValues={{
definition: prettifiedDefinition,
customHeaders: tool?.custom_headers?.map((header) => ({
key: header.key,
value: header.value,
})) ?? [{ key: "test", value: "value" }],
}}
validationSchema={ToolSchema}
onSubmit={async (values: ToolFormValues) => {
@ -249,6 +339,7 @@ export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
name: name,
description: description || "",
definition: definition,
custom_headers: values.customHeaders,
};
let response;
if (tool) {

View File

@ -1,3 +1,4 @@
import { Header } from "next/dist/lib/load-custom-routes";
import { MethodSpec, ToolSnapshot } from "./interfaces";
interface ApiResponse<T> {
@ -9,6 +10,7 @@ export async function createCustomTool(toolData: {
name: string;
description?: string;
definition: Record<string, any>;
custom_headers: { key: string; value: string }[];
}): Promise<ApiResponse<ToolSnapshot>> {
try {
const response = await fetch("/api/admin/tool/custom", {
@ -38,6 +40,7 @@ export async function updateCustomTool(
name?: string;
description?: string;
definition?: Record<string, any>;
custom_headers: { key: string; value: string }[];
}
): Promise<ApiResponse<ToolSnapshot>> {
try {

View File

@ -8,6 +8,9 @@ export interface ToolSnapshot {
// the tool's API.
definition: Record<string, any> | null;
// only specified for Custom Tools. Custom headers to add to the tool's API requests.
custom_headers: { key: string; value: string }[];
// only specified for Custom Tools. ID of the tool in the codebase.
in_code_tool_id: string | null;
}
@ -20,4 +23,5 @@ export interface MethodSpec {
path: string;
method: string;
spec: Record<string, any>;
custom_headers: { key: string; value: string }[];
}