mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 12:29:41 +02:00
Add basic passthrough auth (#3731)
* Add basic passthrough auth * Add server-side validation * Disallow for non-oauth * Fix npm build
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
"""add passthrough auth to tool
|
||||
|
||||
Revision ID: f1ca58b2f2ec
|
||||
Revises: c7bf5721733e
|
||||
Create Date: 2024-03-19
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f1ca58b2f2ec"
|
||||
down_revision: Union[str, None] = "c7bf5721733e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add passthrough_auth column to tool table with default value of False
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column(
|
||||
"passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove passthrough_auth column from tool table
|
||||
op.drop_column("tool", "passthrough_auth")
|
@@ -1430,6 +1430,8 @@ class Tool(Base):
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
# whether to pass through the user's OAuth token as Authorization header
|
||||
passthrough_auth: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
|
||||
# Relationship to Persona through the association table
|
||||
|
@@ -38,6 +38,7 @@ def create_tool(
|
||||
custom_headers: list[Header] | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
passthrough_auth: bool,
|
||||
) -> Tool:
|
||||
new_tool = Tool(
|
||||
name=name,
|
||||
@@ -48,6 +49,7 @@ def create_tool(
|
||||
if custom_headers
|
||||
else [],
|
||||
user_id=user_id,
|
||||
passthrough_auth=passthrough_auth,
|
||||
)
|
||||
db_session.add(new_tool)
|
||||
db_session.commit()
|
||||
@@ -62,6 +64,7 @@ def update_tool(
|
||||
custom_headers: list[Header] | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
passthrough_auth: bool | None,
|
||||
) -> Tool:
|
||||
tool = get_tool_by_id(tool_id, db_session)
|
||||
if tool is None:
|
||||
@@ -79,6 +82,8 @@ def update_tool(
|
||||
tool.custom_headers = [
|
||||
cast(HeaderItemDict, header.model_dump()) for header in custom_headers
|
||||
]
|
||||
if passthrough_auth is not None:
|
||||
tool.passthrough_auth = passthrough_auth
|
||||
db_session.commit()
|
||||
|
||||
return tool
|
||||
|
@@ -41,6 +41,16 @@ def _validate_tool_definition(definition: dict[str, Any]) -> None:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
def _validate_auth_settings(tool_data: CustomToolCreate | CustomToolUpdate) -> None:
|
||||
if tool_data.passthrough_auth and tool_data.custom_headers:
|
||||
for header in tool_data.custom_headers:
|
||||
if header.key.lower() == "authorization":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot use passthrough auth with custom authorization headers",
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/custom")
|
||||
def create_custom_tool(
|
||||
tool_data: CustomToolCreate,
|
||||
@@ -48,6 +58,7 @@ def create_custom_tool(
|
||||
user: User | None = Depends(current_admin_user),
|
||||
) -> ToolSnapshot:
|
||||
_validate_tool_definition(tool_data.definition)
|
||||
_validate_auth_settings(tool_data)
|
||||
tool = create_tool(
|
||||
name=tool_data.name,
|
||||
description=tool_data.description,
|
||||
@@ -55,6 +66,7 @@ def create_custom_tool(
|
||||
custom_headers=tool_data.custom_headers,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
passthrough_auth=tool_data.passthrough_auth,
|
||||
)
|
||||
return ToolSnapshot.from_model(tool)
|
||||
|
||||
@@ -68,6 +80,7 @@ def update_custom_tool(
|
||||
) -> ToolSnapshot:
|
||||
if tool_data.definition:
|
||||
_validate_tool_definition(tool_data.definition)
|
||||
_validate_auth_settings(tool_data)
|
||||
updated_tool = update_tool(
|
||||
tool_id=tool_id,
|
||||
name=tool_data.name,
|
||||
@@ -76,6 +89,7 @@ def update_custom_tool(
|
||||
custom_headers=tool_data.custom_headers,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
passthrough_auth=tool_data.passthrough_auth,
|
||||
)
|
||||
return ToolSnapshot.from_model(updated_tool)
|
||||
|
||||
|
@@ -13,6 +13,7 @@ class ToolSnapshot(BaseModel):
|
||||
display_name: str
|
||||
in_code_tool_id: str | None
|
||||
custom_headers: list[Any] | None
|
||||
passthrough_auth: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, tool: Tool) -> "ToolSnapshot":
|
||||
@@ -24,6 +25,7 @@ class ToolSnapshot(BaseModel):
|
||||
display_name=tool.display_name or tool.name,
|
||||
in_code_tool_id=tool.in_code_tool_id,
|
||||
custom_headers=tool.custom_headers,
|
||||
passthrough_auth=tool.passthrough_auth,
|
||||
)
|
||||
|
||||
|
||||
@@ -37,6 +39,7 @@ class CustomToolCreate(BaseModel):
|
||||
description: str | None = None
|
||||
definition: dict[str, Any]
|
||||
custom_headers: list[Header] | None = None
|
||||
passthrough_auth: bool
|
||||
|
||||
|
||||
class CustomToolUpdate(BaseModel):
|
||||
@@ -44,3 +47,4 @@ class CustomToolUpdate(BaseModel):
|
||||
description: str | None = None
|
||||
definition: dict[str, Any] | None = None
|
||||
custom_headers: list[Header] | None = None
|
||||
passthrough_auth: bool | None = None
|
||||
|
@@ -146,6 +146,11 @@ def construct_tools(
|
||||
"""Constructs tools based on persona configuration and available APIs"""
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Get user's OAuth token if available
|
||||
user_oauth_token = None
|
||||
if user and user.oauth_accounts:
|
||||
user_oauth_token = user.oauth_accounts[0].access_token
|
||||
|
||||
for db_tool_model in persona.tools:
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(
|
||||
@@ -236,6 +241,9 @@ def construct_tools(
|
||||
custom_tool_config.additional_headers or {}
|
||||
)
|
||||
),
|
||||
user_oauth_token=(
|
||||
user_oauth_token if db_tool_model.passthrough_auth else None
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
@@ -80,10 +80,12 @@ class CustomTool(BaseTool):
|
||||
method_spec: MethodSpec,
|
||||
base_url: str,
|
||||
custom_headers: list[HeaderItemDict] | None = None,
|
||||
user_oauth_token: str | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url
|
||||
self._method_spec = method_spec
|
||||
self._tool_definition = self._method_spec.to_tool_definition()
|
||||
self._user_oauth_token = user_oauth_token
|
||||
|
||||
self._name = self._method_spec.name
|
||||
self._description = self._method_spec.summary
|
||||
@@ -91,6 +93,20 @@ class CustomTool(BaseTool):
|
||||
header_list_to_header_dict(custom_headers) if custom_headers else {}
|
||||
)
|
||||
|
||||
# Check for both Authorization header and OAuth token
|
||||
has_auth_header = any(
|
||||
key.lower() == "authorization" for key in self.headers.keys()
|
||||
)
|
||||
if has_auth_header and self._user_oauth_token:
|
||||
logger.warning(
|
||||
f"Tool '{self._name}' has both an Authorization "
|
||||
"header and OAuth token set. This is likely a configuration "
|
||||
"error as the OAuth token will override the custom header."
|
||||
)
|
||||
|
||||
if self._user_oauth_token:
|
||||
self.headers["Authorization"] = f"Bearer {self._user_oauth_token}"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
@@ -348,6 +364,7 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
openapi_schema: dict[str, Any],
|
||||
custom_headers: list[HeaderItemDict] | None = None,
|
||||
dynamic_schema_info: DynamicSchemaInfo | None = None,
|
||||
user_oauth_token: str | None = None,
|
||||
) -> list[CustomTool]:
|
||||
if dynamic_schema_info:
|
||||
# Process dynamic schema information
|
||||
@@ -366,7 +383,13 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
return [
|
||||
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
|
||||
CustomTool(
|
||||
method_spec,
|
||||
url,
|
||||
custom_headers,
|
||||
user_oauth_token=user_oauth_token,
|
||||
)
|
||||
for method_spec in method_specs
|
||||
]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user