mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 01:31:51 +01:00
* Add basic passthrough auth * Add server-side validation * Disallow for non-oauth * Fix npm build
99 lines
2.7 KiB
Python
99 lines
2.7 KiB
Python
from typing import Any
|
|
from typing import cast
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.db.models import Tool
|
|
from onyx.server.features.tool.models import Header
|
|
from onyx.utils.headers import HeaderItemDict
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
# Module-level logger obtained from the project's setup_logger helper.
# NOTE(review): not referenced by any function visible in this chunk.
logger = setup_logger()
|
|
|
|
|
|
def get_tools(db_session: Session) -> list[Tool]:
    """Return every ``Tool`` row in the database.

    Args:
        db_session: Active SQLAlchemy session used to run the query.

    Returns:
        A new list (possibly empty) containing all ``Tool`` instances.
    """
    statement = select(Tool)
    rows = db_session.scalars(statement).all()
    return list(rows)
|
|
|
|
|
|
def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
    """Look up a single ``Tool`` by its ``id`` column.

    Args:
        tool_id: Value compared against ``Tool.id``.
        db_session: Active SQLAlchemy session used to run the query.

    Returns:
        The matching ``Tool``; never ``None``.

    Raises:
        ValueError: If no row has the given id.
    """
    query = select(Tool).where(Tool.id == tool_id)
    found = db_session.scalar(query)
    if found is None:
        raise ValueError("Tool by specified id does not exist")
    return found
|
|
|
|
|
|
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
    """Look up a single ``Tool`` by its ``name`` column.

    NOTE(review): uniqueness of ``Tool.name`` is not visible here. If
    several rows share a name, ``Session.scalar`` yields the first one.

    Args:
        tool_name: Value compared against ``Tool.name``.
        db_session: Active SQLAlchemy session used to run the query.

    Returns:
        The matching ``Tool``; never ``None``.

    Raises:
        ValueError: If no row has the given name.
    """
    query = select(Tool).where(Tool.name == tool_name)
    found = db_session.scalar(query)
    if found is None:
        raise ValueError("Tool by specified name does not exist")
    return found
|
|
|
|
|
|
def create_tool(
    name: str,
    description: str | None,
    openapi_schema: dict[str, Any] | None,
    custom_headers: list[Header] | None,
    user_id: UUID | None,
    db_session: Session,
    passthrough_auth: bool,
) -> Tool:
    """Insert a new ``Tool`` row and commit the session.

    The new tool is always created with ``in_code_tool_id=None``.

    Args:
        name: Tool name.
        description: Optional human-readable description.
        openapi_schema: Optional OpenAPI schema dict stored as-is.
        custom_headers: Optional headers. Each is serialized with
            ``model_dump()``; ``None`` or an empty list stores ``[]``.
        user_id: Optional id of the owning user.
        db_session: Session the row is added to and committed on.
        passthrough_auth: Stored on the row as given. Presumably enables
            forwarding the caller's auth to the tool; confirm against callers.

    Returns:
        The newly persisted ``Tool``.
    """
    # Normalise "no headers" (None or empty) to an empty list.
    header_dicts = []
    if custom_headers:
        header_dicts = [hdr.model_dump() for hdr in custom_headers]

    tool = Tool(
        name=name,
        description=description,
        in_code_tool_id=None,
        openapi_schema=openapi_schema,
        custom_headers=header_dicts,
        user_id=user_id,
        passthrough_auth=passthrough_auth,
    )
    db_session.add(tool)
    db_session.commit()
    return tool
|
|
|
|
|
|
def update_tool(
    tool_id: int,
    name: str | None,
    description: str | None,
    openapi_schema: dict[str, Any] | None,
    custom_headers: list[Header] | None,
    user_id: UUID | None,
    db_session: Session,
    passthrough_auth: bool | None,
) -> Tool:
    """Apply a partial update to an existing ``Tool`` and commit it.

    Every field argument uses ``None`` to mean "leave unchanged". As a
    consequence, a field cannot be cleared to ``None`` through this function.
    Passing an empty list for ``custom_headers`` does clear the headers.

    Args:
        tool_id: Id of the tool to update.
        name: New name, or ``None`` to keep the current one.
        description: New description, or ``None`` to keep it.
        openapi_schema: New OpenAPI schema dict, or ``None`` to keep it.
        custom_headers: Replacement header list, serialized via
            ``model_dump()``. ``None`` keeps the current headers.
        user_id: New owning user id, or ``None`` to keep it.
        db_session: Session used for the lookup and the commit.
        passthrough_auth: New flag value, or ``None`` to keep it.

    Returns:
        The updated ``Tool``.

    Raises:
        ValueError: From ``get_tool_by_id`` when no tool has ``tool_id``.
    """
    # get_tool_by_id raises ValueError for a missing id and never returns
    # None, so no separate existence check is needed here.
    tool = get_tool_by_id(tool_id, db_session)

    if name is not None:
        tool.name = name
    if description is not None:
        tool.description = description
    if openapi_schema is not None:
        tool.openapi_schema = openapi_schema
    if user_id is not None:
        # NOTE(review): this reassigns ownership to whoever is passed in;
        # confirm callers intend that on every update.
        tool.user_id = user_id
    if custom_headers is not None:
        tool.custom_headers = [
            cast(HeaderItemDict, header.model_dump()) for header in custom_headers
        ]
    if passthrough_auth is not None:
        tool.passthrough_auth = passthrough_auth

    db_session.commit()
    return tool
|
|
|
|
|
|
def delete_tool(tool_id: int, db_session: Session) -> None:
    """Delete the ``Tool`` with the given id and commit the session.

    Args:
        tool_id: Id of the tool to delete.
        db_session: Session used for the lookup, delete and commit.

    Raises:
        ValueError: From ``get_tool_by_id`` when no tool has ``tool_id``.
    """
    # get_tool_by_id raises ValueError for a missing id and never returns
    # None, so no separate existence check is needed here.
    tool = get_tool_by_id(tool_id, db_session)
    db_session.delete(tool)
    db_session.commit()
|