Add basic passthrough auth (#3731)

* Add basic passthrough auth

* Add server-side validation

* Disallow for non-oauth

* Fix npm build
This commit is contained in:
Chris Weaver
2025-01-20 23:39:23 -08:00
committed by GitHub
parent 4ca7325d1a
commit 420476ad92
11 changed files with 251 additions and 44 deletions

View File

@@ -0,0 +1,33 @@
"""add passthrough auth to tool
Revision ID: f1ca58b2f2ec
Revises: c7bf5721733e
Create Date: 2024-03-19
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f1ca58b2f2ec"
down_revision: Union[str, None] = "c7bf5721733e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add passthrough_auth column to tool table with default value of False
op.add_column(
"tool",
sa.Column(
"passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
def downgrade() -> None:
# Remove passthrough_auth column from tool table
op.drop_column("tool", "passthrough_auth")

View File

@@ -1430,6 +1430,8 @@ class Tool(Base):
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
# whether to pass through the user's OAuth token as Authorization header
passthrough_auth: Mapped[bool] = mapped_column(Boolean, default=False)
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
# Relationship to Persona through the association table

View File

@@ -38,6 +38,7 @@ def create_tool(
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
passthrough_auth: bool,
) -> Tool:
new_tool = Tool(
name=name,
@@ -48,6 +49,7 @@ def create_tool(
if custom_headers
else [],
user_id=user_id,
passthrough_auth=passthrough_auth,
)
db_session.add(new_tool)
db_session.commit()
@@ -62,6 +64,7 @@ def update_tool(
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
passthrough_auth: bool | None,
) -> Tool:
tool = get_tool_by_id(tool_id, db_session)
if tool is None:
@@ -79,6 +82,8 @@ def update_tool(
tool.custom_headers = [
cast(HeaderItemDict, header.model_dump()) for header in custom_headers
]
if passthrough_auth is not None:
tool.passthrough_auth = passthrough_auth
db_session.commit()
return tool

View File

@@ -41,6 +41,16 @@ def _validate_tool_definition(definition: dict[str, Any]) -> None:
raise HTTPException(status_code=400, detail=str(e))
def _validate_auth_settings(tool_data: CustomToolCreate | CustomToolUpdate) -> None:
if tool_data.passthrough_auth and tool_data.custom_headers:
for header in tool_data.custom_headers:
if header.key.lower() == "authorization":
raise HTTPException(
status_code=400,
detail="Cannot use passthrough auth with custom authorization headers",
)
@admin_router.post("/custom")
def create_custom_tool(
tool_data: CustomToolCreate,
@@ -48,6 +58,7 @@ def create_custom_tool(
user: User | None = Depends(current_admin_user),
) -> ToolSnapshot:
_validate_tool_definition(tool_data.definition)
_validate_auth_settings(tool_data)
tool = create_tool(
name=tool_data.name,
description=tool_data.description,
@@ -55,6 +66,7 @@ def create_custom_tool(
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
passthrough_auth=tool_data.passthrough_auth,
)
return ToolSnapshot.from_model(tool)
@@ -68,6 +80,7 @@ def update_custom_tool(
) -> ToolSnapshot:
if tool_data.definition:
_validate_tool_definition(tool_data.definition)
_validate_auth_settings(tool_data)
updated_tool = update_tool(
tool_id=tool_id,
name=tool_data.name,
@@ -76,6 +89,7 @@ def update_custom_tool(
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
passthrough_auth=tool_data.passthrough_auth,
)
return ToolSnapshot.from_model(updated_tool)

View File

@@ -13,6 +13,7 @@ class ToolSnapshot(BaseModel):
display_name: str
in_code_tool_id: str | None
custom_headers: list[Any] | None
passthrough_auth: bool
@classmethod
def from_model(cls, tool: Tool) -> "ToolSnapshot":
@@ -24,6 +25,7 @@ class ToolSnapshot(BaseModel):
display_name=tool.display_name or tool.name,
in_code_tool_id=tool.in_code_tool_id,
custom_headers=tool.custom_headers,
passthrough_auth=tool.passthrough_auth,
)
@@ -37,6 +39,7 @@ class CustomToolCreate(BaseModel):
description: str | None = None
definition: dict[str, Any]
custom_headers: list[Header] | None = None
passthrough_auth: bool
class CustomToolUpdate(BaseModel):
@@ -44,3 +47,4 @@ class CustomToolUpdate(BaseModel):
description: str | None = None
definition: dict[str, Any] | None = None
custom_headers: list[Header] | None = None
passthrough_auth: bool | None = None

View File

@@ -146,6 +146,11 @@ def construct_tools(
"""Constructs tools based on persona configuration and available APIs"""
tool_dict: dict[int, list[Tool]] = {}
# Get user's OAuth token if available
user_oauth_token = None
if user and user.oauth_accounts:
user_oauth_token = user.oauth_accounts[0].access_token
for db_tool_model in persona.tools:
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(
@@ -236,6 +241,9 @@ def construct_tools(
custom_tool_config.additional_headers or {}
)
),
user_oauth_token=(
user_oauth_token if db_tool_model.passthrough_auth else None
),
),
)

View File

@@ -80,10 +80,12 @@ class CustomTool(BaseTool):
method_spec: MethodSpec,
base_url: str,
custom_headers: list[HeaderItemDict] | None = None,
user_oauth_token: str | None = None,
) -> None:
self._base_url = base_url
self._method_spec = method_spec
self._tool_definition = self._method_spec.to_tool_definition()
self._user_oauth_token = user_oauth_token
self._name = self._method_spec.name
self._description = self._method_spec.summary
@@ -91,6 +93,20 @@ class CustomTool(BaseTool):
header_list_to_header_dict(custom_headers) if custom_headers else {}
)
# Check for both Authorization header and OAuth token
has_auth_header = any(
key.lower() == "authorization" for key in self.headers.keys()
)
if has_auth_header and self._user_oauth_token:
logger.warning(
f"Tool '{self._name}' has both an Authorization "
"header and OAuth token set. This is likely a configuration "
"error as the OAuth token will override the custom header."
)
if self._user_oauth_token:
self.headers["Authorization"] = f"Bearer {self._user_oauth_token}"
@property
def name(self) -> str:
return self._name
@@ -348,6 +364,7 @@ def build_custom_tools_from_openapi_schema_and_headers(
openapi_schema: dict[str, Any],
custom_headers: list[HeaderItemDict] | None = None,
dynamic_schema_info: DynamicSchemaInfo | None = None,
user_oauth_token: str | None = None,
) -> list[CustomTool]:
if dynamic_schema_info:
# Process dynamic schema information
@@ -366,7 +383,13 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
return [
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
CustomTool(
method_spec,
url,
custom_headers,
user_oauth_token=user_oauth_token,
)
for method_spec in method_specs
]