mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-19 08:10:13 +02:00
* Add basic passthrough auth * Add server-side validation * Disallow for non-oauth * Fix npm build
158 lines
4.9 KiB
Python
158 lines
4.9 KiB
Python
from typing import Any
|
|
|
|
from fastapi import APIRouter
|
|
from fastapi import Depends
|
|
from fastapi import HTTPException
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.auth.users import current_admin_user
|
|
from onyx.auth.users import current_user
|
|
from onyx.db.engine import get_session
|
|
from onyx.db.models import User
|
|
from onyx.db.tools import create_tool
|
|
from onyx.db.tools import delete_tool
|
|
from onyx.db.tools import get_tool_by_id
|
|
from onyx.db.tools import get_tools
|
|
from onyx.db.tools import update_tool
|
|
from onyx.server.features.tool.models import CustomToolCreate
|
|
from onyx.server.features.tool.models import CustomToolUpdate
|
|
from onyx.server.features.tool.models import ToolSnapshot
|
|
from onyx.tools.tool_implementations.custom.openapi_parsing import MethodSpec
|
|
from onyx.tools.tool_implementations.custom.openapi_parsing import (
|
|
openapi_to_method_specs,
|
|
)
|
|
from onyx.tools.tool_implementations.custom.openapi_parsing import (
|
|
validate_openapi_schema,
|
|
)
|
|
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
|
ImageGenerationTool,
|
|
)
|
|
from onyx.tools.utils import is_image_generation_available
|
|
|
|
# Router for the endpoints in this file that depend on `current_user`.
router = APIRouter(prefix="/tool")
# Router for the endpoints in this file that depend on `current_admin_user`.
admin_router = APIRouter(prefix="/admin/tool")
|
|
|
|
|
|
def _validate_tool_definition(definition: dict[str, Any]) -> None:
    """Reject a tool definition that fails OpenAPI schema validation.

    Args:
        definition: The candidate OpenAPI schema as a plain dict.

    Raises:
        HTTPException: Status 400, with the validator's message as the
            detail, if `validate_openapi_schema` raises anything.
    """
    try:
        validate_openapi_schema(definition)
    except Exception as e:
        # Every validator failure is reported as a client error. Chain the
        # cause explicitly so the original traceback stays attached.
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
|
|
|
|
def _validate_auth_settings(tool_data: CustomToolCreate | CustomToolUpdate) -> None:
    """Ensure passthrough auth is not combined with a custom Authorization header.

    Args:
        tool_data: The create/update payload; only `passthrough_auth` and
            `custom_headers` are read.

    Raises:
        HTTPException: Status 400 if passthrough auth is enabled and any
            custom header key equals "authorization" case-insensitively.
    """
    # Nothing to check unless passthrough auth is on and headers were supplied.
    if not (tool_data.passthrough_auth and tool_data.custom_headers):
        return

    overrides_authorization = any(
        custom_header.key.lower() == "authorization"
        for custom_header in tool_data.custom_headers
    )
    if overrides_authorization:
        raise HTTPException(
            status_code=400,
            detail="Cannot use passthrough auth with custom authorization headers",
        )
|
|
|
|
|
|
@admin_router.post("/custom")
def create_custom_tool(
    tool_data: CustomToolCreate,
    db_session: Session = Depends(get_session),
    user: User | None = Depends(current_admin_user),
) -> ToolSnapshot:
    # Admin endpoint: validate the submitted definition and auth settings,
    # create the tool via `create_tool`, and return it as a ToolSnapshot.
    # Either validator may raise HTTPException(400) before anything is created.
    _validate_tool_definition(tool_data.definition)
    _validate_auth_settings(tool_data)

    # `user` can be None; in that case no owner id is passed along.
    owner_id = user.id if user else None

    created = create_tool(
        name=tool_data.name,
        description=tool_data.description,
        openapi_schema=tool_data.definition,
        custom_headers=tool_data.custom_headers,
        user_id=owner_id,
        db_session=db_session,
        passthrough_auth=tool_data.passthrough_auth,
    )
    return ToolSnapshot.from_model(created)
|
|
|
|
|
|
@admin_router.put("/custom/{tool_id}")
def update_custom_tool(
    tool_id: int,
    tool_data: CustomToolUpdate,
    db_session: Session = Depends(get_session),
    user: User | None = Depends(current_admin_user),
) -> ToolSnapshot:
    # Admin endpoint: validate the supplied fields, apply them to the tool
    # identified by `tool_id` via `update_tool`, and return the result as a
    # ToolSnapshot. Either validator may raise HTTPException(400).

    # Validate whenever a definition was supplied at all. A truthiness test
    # would let an explicit empty dict skip validation while still being
    # handed to `update_tool` below.
    if tool_data.definition is not None:
        _validate_tool_definition(tool_data.definition)
    _validate_auth_settings(tool_data)

    # NOTE(review): the delete/get endpoints in this file map ValueError to a
    # 404; here an error raised by `update_tool` propagates unchanged.
    # Confirm what `update_tool` raises for an unknown tool_id.
    updated_tool = update_tool(
        tool_id=tool_id,
        name=tool_data.name,
        description=tool_data.description,
        openapi_schema=tool_data.definition,
        custom_headers=tool_data.custom_headers,
        # `user` can be None; in that case no user id is passed along.
        user_id=user.id if user else None,
        db_session=db_session,
        passthrough_auth=tool_data.passthrough_auth,
    )
    return ToolSnapshot.from_model(updated_tool)
|
|
|
|
|
|
@admin_router.delete("/custom/{tool_id}")
def delete_custom_tool(
    tool_id: int,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> None:
    # Admin endpoint: delete the tool identified by `tool_id`.
    # A ValueError from `delete_tool` becomes a 404; any other exception
    # becomes a 400. Both carry the exception message as the detail.
    try:
        delete_tool(tool_id, db_session)
    except ValueError as e:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # handles case where tool is still used by an Assistant
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
|
|
|
|
# Request body for the `/custom/validate` endpoint.
class ValidateToolRequest(BaseModel):
    # Candidate tool definition; the endpoint passes it to
    # `_validate_tool_definition` and `openapi_to_method_specs`.
    definition: dict[str, Any]
|
|
|
|
|
|
# Response body for the `/custom/validate` endpoint.
class ValidateToolResponse(BaseModel):
    # Method specs produced by `openapi_to_method_specs` for the definition.
    methods: list[MethodSpec]
|
|
|
|
|
|
@admin_router.post("/custom/validate")
def validate_tool(
    tool_data: ValidateToolRequest,
    _: User | None = Depends(current_admin_user),
) -> ValidateToolResponse:
    # Admin endpoint: check a definition without saving anything, then
    # return the method specs parsed from it. An invalid definition raises
    # HTTPException(400) from `_validate_tool_definition`.
    schema = tool_data.definition
    _validate_tool_definition(schema)
    return ValidateToolResponse(methods=openapi_to_method_specs(schema))
|
|
|
|
|
|
"""Endpoints for all"""
|
|
|
|
|
|
@router.get("/{tool_id}")
def get_custom_tool(
    tool_id: int,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_user),
) -> ToolSnapshot:
    # Look up one tool by id and return it as a ToolSnapshot.
    # A ValueError from `get_tool_by_id` becomes a 404 carrying its message.
    try:
        tool = get_tool_by_id(tool_id, db_session)
    except ValueError as e:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=404, detail=str(e)) from e
    return ToolSnapshot.from_model(tool)
|
|
|
|
|
|
@router.get("")
def list_tools(
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_user),
) -> list[ToolSnapshot]:
    # Return a snapshot for every tool from `get_tools`, leaving out the
    # image generation tool when `is_image_generation_available` is false.

    def _should_list(candidate: Any) -> bool:
        # Tools other than the image generation tool are always listed.
        # NOTE(review): this relies on `ImageGenerationTool.name` being
        # comparable to `in_code_tool_id` at class level; confirm.
        if candidate.in_code_tool_id != ImageGenerationTool.name:
            return True
        # Availability is only queried for the image generation tool itself.
        return is_image_generation_available(db_session=db_session)

    return [
        ToolSnapshot.from_model(candidate)
        for candidate in get_tools(db_session)
        if _should_list(candidate)
    ]
|