Add basic passthrough auth (#3731)

* Add basic passthrough auth

* Add server-side validation

* Disallow for non-oauth

* Fix npm build
This commit is contained in:
Chris Weaver
2025-01-20 23:39:23 -08:00
committed by GitHub
parent 4ca7325d1a
commit 420476ad92
11 changed files with 251 additions and 44 deletions

View File

@@ -0,0 +1,33 @@
"""add passthrough auth to tool
Revision ID: f1ca58b2f2ec
Revises: c7bf5721733e
Create Date: 2024-03-19
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f1ca58b2f2ec"
down_revision: Union[str, None] = "c7bf5721733e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add passthrough_auth column to tool table with default value of False
op.add_column(
"tool",
sa.Column(
"passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
def downgrade() -> None:
# Remove passthrough_auth column from tool table
op.drop_column("tool", "passthrough_auth")

View File

@@ -1430,6 +1430,8 @@ class Tool(Base):
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
# whether to pass through the user's OAuth token as Authorization header
passthrough_auth: Mapped[bool] = mapped_column(Boolean, default=False)
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
# Relationship to Persona through the association table

View File

@@ -38,6 +38,7 @@ def create_tool(
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
passthrough_auth: bool,
) -> Tool:
new_tool = Tool(
name=name,
@@ -48,6 +49,7 @@ def create_tool(
if custom_headers
else [],
user_id=user_id,
passthrough_auth=passthrough_auth,
)
db_session.add(new_tool)
db_session.commit()
@@ -62,6 +64,7 @@ def update_tool(
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
passthrough_auth: bool | None,
) -> Tool:
tool = get_tool_by_id(tool_id, db_session)
if tool is None:
@@ -79,6 +82,8 @@ def update_tool(
tool.custom_headers = [
cast(HeaderItemDict, header.model_dump()) for header in custom_headers
]
if passthrough_auth is not None:
tool.passthrough_auth = passthrough_auth
db_session.commit()
return tool

View File

@@ -41,6 +41,16 @@ def _validate_tool_definition(definition: dict[str, Any]) -> None:
raise HTTPException(status_code=400, detail=str(e))
def _validate_auth_settings(tool_data: CustomToolCreate | CustomToolUpdate) -> None:
if tool_data.passthrough_auth and tool_data.custom_headers:
for header in tool_data.custom_headers:
if header.key.lower() == "authorization":
raise HTTPException(
status_code=400,
detail="Cannot use passthrough auth with custom authorization headers",
)
@admin_router.post("/custom")
def create_custom_tool(
tool_data: CustomToolCreate,
@@ -48,6 +58,7 @@ def create_custom_tool(
user: User | None = Depends(current_admin_user),
) -> ToolSnapshot:
_validate_tool_definition(tool_data.definition)
_validate_auth_settings(tool_data)
tool = create_tool(
name=tool_data.name,
description=tool_data.description,
@@ -55,6 +66,7 @@ def create_custom_tool(
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
passthrough_auth=tool_data.passthrough_auth,
)
return ToolSnapshot.from_model(tool)
@@ -68,6 +80,7 @@ def update_custom_tool(
) -> ToolSnapshot:
if tool_data.definition:
_validate_tool_definition(tool_data.definition)
_validate_auth_settings(tool_data)
updated_tool = update_tool(
tool_id=tool_id,
name=tool_data.name,
@@ -76,6 +89,7 @@ def update_custom_tool(
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
passthrough_auth=tool_data.passthrough_auth,
)
return ToolSnapshot.from_model(updated_tool)

View File

@@ -13,6 +13,7 @@ class ToolSnapshot(BaseModel):
display_name: str
in_code_tool_id: str | None
custom_headers: list[Any] | None
passthrough_auth: bool
@classmethod
def from_model(cls, tool: Tool) -> "ToolSnapshot":
@@ -24,6 +25,7 @@ class ToolSnapshot(BaseModel):
display_name=tool.display_name or tool.name,
in_code_tool_id=tool.in_code_tool_id,
custom_headers=tool.custom_headers,
passthrough_auth=tool.passthrough_auth,
)
@@ -37,6 +39,7 @@ class CustomToolCreate(BaseModel):
description: str | None = None
definition: dict[str, Any]
custom_headers: list[Header] | None = None
passthrough_auth: bool
class CustomToolUpdate(BaseModel):
@@ -44,3 +47,4 @@ class CustomToolUpdate(BaseModel):
description: str | None = None
definition: dict[str, Any] | None = None
custom_headers: list[Header] | None = None
passthrough_auth: bool | None = None

View File

@@ -146,6 +146,11 @@ def construct_tools(
"""Constructs tools based on persona configuration and available APIs"""
tool_dict: dict[int, list[Tool]] = {}
# Get user's OAuth token if available
user_oauth_token = None
if user and user.oauth_accounts:
user_oauth_token = user.oauth_accounts[0].access_token
for db_tool_model in persona.tools:
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(
@@ -236,6 +241,9 @@ def construct_tools(
custom_tool_config.additional_headers or {}
)
),
user_oauth_token=(
user_oauth_token if db_tool_model.passthrough_auth else None
),
),
)

View File

@@ -80,10 +80,12 @@ class CustomTool(BaseTool):
method_spec: MethodSpec,
base_url: str,
custom_headers: list[HeaderItemDict] | None = None,
user_oauth_token: str | None = None,
) -> None:
self._base_url = base_url
self._method_spec = method_spec
self._tool_definition = self._method_spec.to_tool_definition()
self._user_oauth_token = user_oauth_token
self._name = self._method_spec.name
self._description = self._method_spec.summary
@@ -91,6 +93,20 @@ class CustomTool(BaseTool):
header_list_to_header_dict(custom_headers) if custom_headers else {}
)
# Check for both Authorization header and OAuth token
has_auth_header = any(
key.lower() == "authorization" for key in self.headers.keys()
)
if has_auth_header and self._user_oauth_token:
logger.warning(
f"Tool '{self._name}' has both an Authorization "
"header and OAuth token set. This is likely a configuration "
"error as the OAuth token will override the custom header."
)
if self._user_oauth_token:
self.headers["Authorization"] = f"Bearer {self._user_oauth_token}"
@property
def name(self) -> str:
return self._name
@@ -348,6 +364,7 @@ def build_custom_tools_from_openapi_schema_and_headers(
openapi_schema: dict[str, Any],
custom_headers: list[HeaderItemDict] | None = None,
dynamic_schema_info: DynamicSchemaInfo | None = None,
user_oauth_token: str | None = None,
) -> list[CustomTool]:
if dynamic_schema_info:
# Process dynamic schema information
@@ -366,7 +383,13 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
return [
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
CustomTool(
method_spec,
url,
custom_headers,
user_oauth_token=user_oauth_token,
)
for method_spec in method_specs
]

View File

@@ -24,6 +24,14 @@ import debounce from "lodash/debounce";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import Link from "next/link";
import { Separator } from "@/components/ui/separator";
import { Checkbox } from "@/components/ui/checkbox";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { useAuthType } from "@/lib/hooks";
function parseJsonWithTrailingCommas(jsonString: string) {
// Regular expression to remove trailing commas before } or ]
@@ -51,7 +59,11 @@ function ToolForm({
}: {
existingTool?: ToolSnapshot;
values: ToolFormValues;
setFieldValue: (field: string, value: string) => void;
setFieldValue: <T = any>(
field: string,
value: T,
shouldValidate?: boolean
) => void;
isSubmitting: boolean;
definitionErrorState: [
string | null,
@@ -65,6 +77,9 @@ function ToolForm({
const [definitionError, setDefinitionError] = definitionErrorState;
const [methodSpecs, setMethodSpecs] = methodSpecsState;
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
const authType = useAuthType();
const isOAuthEnabled = authType === "oidc" || authType === "google_oauth";
const debouncedValidateDefinition = useCallback(
(definition: string) => {
const validateDefinition = async () => {
@@ -218,43 +233,38 @@ function ToolForm({
</p>
<FieldArray
name="customHeaders"
render={(arrayHelpers: ArrayHelpers) => (
<div className="space-y-4">
{values.customHeaders && values.customHeaders.length > 0 && (
<div className="space-y-3">
{values.customHeaders.map(
(
header: { key: string; value: string },
index: number
) => (
<div
key={index}
className="flex items-center space-x-2 bg-gray-50 p-3 rounded-lg shadow-sm"
render={(arrayHelpers) => (
<div>
<div className="space-y-2">
{values.customHeaders.map(
(header: { key: string; value: string }, index: number) => (
<div
key={index}
className="flex items-center space-x-2 bg-gray-50 p-3 rounded-lg shadow-sm"
>
<Field
name={`customHeaders.${index}.key`}
placeholder="Header Key"
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
/>
<Field
name={`customHeaders.${index}.value`}
placeholder="Header Value"
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
/>
<Button
type="button"
onClick={() => arrayHelpers.remove(index)}
variant="destructive"
size="sm"
className="transition-colors duration-200 hover:bg-red-600"
>
<Field
name={`customHeaders.${index}.key`}
placeholder="Header Key"
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
/>
<Field
name={`customHeaders.${index}.value`}
placeholder="Header Value"
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
/>
<Button
type="button"
onClick={() => arrayHelpers.remove(index)}
variant="destructive"
size="sm"
className="transition-colors duration-200 hover:bg-red-600"
>
Remove
</Button>
</div>
)
)}
</div>
)}
Remove
</Button>
</div>
)
)}
</div>
<Button
type="button"
@@ -268,6 +278,75 @@ function ToolForm({
</div>
)}
/>
<div className="mt-6">
<h3 className="text-xl font-bold mb-2 text-primary-600">
Authentication
</h3>
{isOAuthEnabled ? (
<div className="flex flex-col gap-y-2">
<div className="flex items-center space-x-2">
<TooltipProvider>
<Tooltip>
<TooltipTrigger>
<div
className={
values.customHeaders.some(
(header) =>
header.key.toLowerCase() === "authorization"
)
? "opacity-50"
: ""
}
>
<Checkbox
id="passthrough_auth"
size="sm"
checked={values.passthrough_auth}
disabled={values.customHeaders.some(
(header) =>
header.key.toLowerCase() === "authorization" &&
!values.passthrough_auth
)}
onCheckedChange={(checked) => {
setFieldValue("passthrough_auth", checked, true);
}}
/>
</div>
</TooltipTrigger>
{values.customHeaders.some(
(header) => header.key.toLowerCase() === "authorization"
) && (
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
Cannot enable OAuth passthrough when an
Authorization header is already set
</p>
</TooltipContent>
)}
</Tooltip>
</TooltipProvider>
<div className="flex flex-col">
<label
htmlFor="passthrough_auth"
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
>
Pass through user&apos;s OAuth token
</label>
<p className="text-xs text-subtle mt-1">
When enabled, the user&apos;s OAuth token will be passed
as the Authorization header for all API calls
</p>
</div>
</div>
</div>
) : (
<p className="text-sm text-subtle">
OAuth passthrough is only available when OIDC or OAuth
authentication is enabled
</p>
)}
</div>
</div>
)}
@@ -291,6 +370,7 @@ function ToolForm({
interface ToolFormValues {
definition: string;
customHeaders: { key: string; value: string }[];
passthrough_auth: boolean;
}
const ToolSchema = Yup.object().shape({
@@ -303,6 +383,7 @@ const ToolSchema = Yup.object().shape({
})
)
.default([]),
passthrough_auth: Yup.boolean().default(false),
});
export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
@@ -326,9 +407,27 @@ export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
key: header.key,
value: header.value,
})) ?? [],
passthrough_auth: tool?.passthrough_auth ?? false,
}}
validationSchema={ToolSchema}
onSubmit={async (values: ToolFormValues) => {
const hasAuthHeader = values.customHeaders?.some(
(header) => header.key.toLowerCase() === "authorization"
);
if (hasAuthHeader && values.passthrough_auth) {
setPopup({
message:
"Cannot enable passthrough auth when Authorization " +
"headers are present. Please remove any Authorization " +
"headers first.",
type: "error",
});
console.log(
"Cannot enable passthrough auth when Authorization headers are present. Please remove any Authorization headers first."
);
return;
}
let definition: any;
try {
definition = parseJsonWithTrailingCommas(values.definition);
@@ -344,6 +443,7 @@ export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
description: description || "",
definition: definition,
custom_headers: values.customHeaders,
passthrough_auth: values.passthrough_auth,
};
let response;
if (tool) {

View File

@@ -1,7 +1,6 @@
"use client";
import {
ConnectorIndexingStatus,
OAuthSlackCallbackResponse,
DocumentBoostStatus,
Tag,
UserGroup,
@@ -20,13 +19,10 @@ import { AllUsersResponse } from "./types";
import { Credential } from "./connectors/credentials";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { PersonaLabel } from "@/app/admin/assistants/interfaces";
import {
LLMProvider,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
import { getSourceMetadata } from "./sources";
import { buildFilters } from "./search/utils";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
const CREDENTIAL_URL = "/api/manage/admin/credential";
@@ -454,6 +450,23 @@ export function useLlmOverride(
};
}
export function useAuthType(): AuthType | null {
const { data, error } = useSWR<{ auth_type: AuthType }>(
"/api/auth/type",
errorHandlingFetcher
);
if (NEXT_PUBLIC_CLOUD_ENABLED) {
return "cloud";
}
if (error || !data) {
return null;
}
return data.auth_type;
}
/*
EE Only APIs
*/

View File

@@ -11,6 +11,7 @@ export async function createCustomTool(toolData: {
description?: string;
definition: Record<string, any>;
custom_headers: { key: string; value: string }[];
passthrough_auth: boolean;
}): Promise<ApiResponse<ToolSnapshot>> {
try {
const response = await fetch("/api/admin/tool/custom", {
@@ -41,6 +42,7 @@ export async function updateCustomTool(
description?: string;
definition?: Record<string, any>;
custom_headers: { key: string; value: string }[];
passthrough_auth: boolean;
}
): Promise<ApiResponse<ToolSnapshot>> {
try {

View File

@@ -13,6 +13,9 @@ export interface ToolSnapshot {
// only specified for Custom Tools. ID of the tool in the codebase.
in_code_tool_id: string | null;
// whether to pass through the user's OAuth token as Authorization header
passthrough_auth: boolean;
}
export interface MethodSpec {