danswer/backend/onyx/db/tools.py
rkuo-danswer 24184024bb
Bugfix/dependency updates (#4482)
* bump fastapi and starlette

* bumping llama index and nltk and associated deps

* bump to fix python-multipart

* bump aiohttp

* update package lock for examples/widget

* bump black

* sentencesplitter has changed namespaces

* fix reorder import check, fix missing passlib

* update package-lock.json

* black formatter updated

* reformatted again

* change to black compatible reorder

* change to black compatible reorder-python-imports fork

* fix pytest dependency

* black format again

* we don't need cdk.txt. update packages to be consistent across all packages

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-04-10 08:23:02 +00:00

99 lines
2.7 KiB
Python

from typing import Any
from typing import cast
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import Tool
from onyx.server.features.tool.models import Header
from onyx.utils.headers import HeaderItemDict
from onyx.utils.logger import setup_logger
# Module-level logger created via the project's setup_logger helper.
logger = setup_logger()
def get_tools(db_session: Session) -> list[Tool]:
    """Return all Tool rows as a plain list.

    Args:
        db_session: Session used to run the query.

    Returns:
        A new list holding every Tool produced by an unfiltered select.
    """
    all_tools_stmt = select(Tool)
    rows = db_session.scalars(all_tools_stmt).all()
    return list(rows)
def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
    """Look up a single Tool by its id.

    Args:
        tool_id: Value matched against ``Tool.id``.
        db_session: Session used to run the query.

    Returns:
        The Tool found by the query.

    Raises:
        ValueError: If the query yields no (truthy) result.
    """
    lookup_stmt = select(Tool).where(Tool.id == tool_id)
    found = db_session.scalar(lookup_stmt)
    if found:
        return found
    raise ValueError("Tool by specified id does not exist")
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
    """Look up a single Tool by its name.

    Args:
        tool_name: Value matched against ``Tool.name``.
        db_session: Session used to run the query.

    Returns:
        The Tool found by the query.

    Raises:
        ValueError: If the query yields no (truthy) result.
    """
    lookup_stmt = select(Tool).where(Tool.name == tool_name)
    found = db_session.scalar(lookup_stmt)
    if found:
        return found
    raise ValueError("Tool by specified name does not exist")
def create_tool(
    name: str,
    description: str | None,
    openapi_schema: dict[str, Any] | None,
    custom_headers: list[Header] | None,
    user_id: UUID | None,
    db_session: Session,
    passthrough_auth: bool,
) -> Tool:
    """Create a Tool row, commit it, and return the new object.

    Args:
        name: Name stored on the tool.
        description: Optional description stored on the tool.
        openapi_schema: Optional schema dict stored on the tool.
        custom_headers: Optional headers; each one is stored as the result of
            its ``model_dump()``. ``None`` or an empty list is stored as ``[]``.
        user_id: Optional user id stored on the tool.
        db_session: Session the tool is added to and committed on.
        passthrough_auth: Flag stored on the tool.

    Returns:
        The newly added Tool. Its ``in_code_tool_id`` is always set to ``None``.
    """
    # A missing or empty header list collapses to an empty list.
    serialized_headers = [item.model_dump() for item in custom_headers or []]

    created = Tool(
        name=name,
        description=description,
        in_code_tool_id=None,
        openapi_schema=openapi_schema,
        custom_headers=serialized_headers,
        user_id=user_id,
        passthrough_auth=passthrough_auth,
    )

    db_session.add(created)
    db_session.commit()
    return created
def update_tool(
    tool_id: int,
    name: str | None,
    description: str | None,
    openapi_schema: dict[str, Any] | None,
    custom_headers: list[Header] | None,
    user_id: UUID | None,
    db_session: Session,
    passthrough_auth: bool | None,
) -> Tool:
    """Apply a partial update to an existing Tool and commit it.

    Every optional argument follows the same rule: ``None`` means "leave the
    current value untouched"; any other value overwrites the stored one. As a
    consequence a field cannot be reset to ``None`` through this function.

    Args:
        tool_id: Id of the tool to update.
        name: New name, or ``None`` to keep the current one.
        description: New description, or ``None`` to keep the current one.
        openapi_schema: New schema dict, or ``None`` to keep the current one.
        custom_headers: New headers, or ``None`` to keep the current ones.
            Each header is stored as the result of its ``model_dump()``; an
            empty list replaces the stored headers with ``[]``.
        user_id: New user id, or ``None`` to keep the current one.
        db_session: Session used for the lookup and the commit.
        passthrough_auth: New flag value, or ``None`` to keep the current one.

    Returns:
        The updated Tool.

    Raises:
        ValueError: If no tool with ``tool_id`` exists (raised by
            ``get_tool_by_id``).
    """
    # get_tool_by_id raises ValueError itself when the id is unknown and never
    # returns None, so no extra existence check is needed here.
    tool = get_tool_by_id(tool_id, db_session)

    if name is not None:
        tool.name = name
    if description is not None:
        tool.description = description
    if openapi_schema is not None:
        tool.openapi_schema = openapi_schema
    if user_id is not None:
        tool.user_id = user_id
    if custom_headers is not None:
        tool.custom_headers = [
            cast(HeaderItemDict, header.model_dump()) for header in custom_headers
        ]
    if passthrough_auth is not None:
        tool.passthrough_auth = passthrough_auth

    db_session.commit()
    return tool
def delete_tool(tool_id: int, db_session: Session) -> None:
    """Delete the Tool with the given id and commit the change.

    Args:
        tool_id: Id of the tool to delete.
        db_session: Session used for the lookup, delete and commit.

    Raises:
        ValueError: If no tool with ``tool_id`` exists (raised by
            ``get_tool_by_id``).
    """
    # get_tool_by_id raises ValueError itself when the id is unknown and never
    # returns None, so no extra existence check is needed here.
    tool = get_tool_by_id(tool_id, db_session)
    db_session.delete(tool)
    db_session.commit()