mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-20 04:37:09 +02:00
Add basic passthrough auth (#3731)
* Add basic passthrough auth * Add server-side validation * Disallow for non-oauth * Fix npm build
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
"""add passthrough auth to tool
|
||||
|
||||
Revision ID: f1ca58b2f2ec
|
||||
Revises: c7bf5721733e
|
||||
Create Date: 2024-03-19
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f1ca58b2f2ec"
|
||||
down_revision: Union[str, None] = "c7bf5721733e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add passthrough_auth column to tool table with default value of False
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column(
|
||||
"passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove passthrough_auth column from tool table
|
||||
op.drop_column("tool", "passthrough_auth")
|
@@ -1430,6 +1430,8 @@ class Tool(Base):
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
# whether to pass through the user's OAuth token as Authorization header
|
||||
passthrough_auth: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
|
||||
# Relationship to Persona through the association table
|
||||
|
@@ -38,6 +38,7 @@ def create_tool(
|
||||
custom_headers: list[Header] | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
passthrough_auth: bool,
|
||||
) -> Tool:
|
||||
new_tool = Tool(
|
||||
name=name,
|
||||
@@ -48,6 +49,7 @@ def create_tool(
|
||||
if custom_headers
|
||||
else [],
|
||||
user_id=user_id,
|
||||
passthrough_auth=passthrough_auth,
|
||||
)
|
||||
db_session.add(new_tool)
|
||||
db_session.commit()
|
||||
@@ -62,6 +64,7 @@ def update_tool(
|
||||
custom_headers: list[Header] | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
passthrough_auth: bool | None,
|
||||
) -> Tool:
|
||||
tool = get_tool_by_id(tool_id, db_session)
|
||||
if tool is None:
|
||||
@@ -79,6 +82,8 @@ def update_tool(
|
||||
tool.custom_headers = [
|
||||
cast(HeaderItemDict, header.model_dump()) for header in custom_headers
|
||||
]
|
||||
if passthrough_auth is not None:
|
||||
tool.passthrough_auth = passthrough_auth
|
||||
db_session.commit()
|
||||
|
||||
return tool
|
||||
|
@@ -41,6 +41,16 @@ def _validate_tool_definition(definition: dict[str, Any]) -> None:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
def _validate_auth_settings(tool_data: CustomToolCreate | CustomToolUpdate) -> None:
|
||||
if tool_data.passthrough_auth and tool_data.custom_headers:
|
||||
for header in tool_data.custom_headers:
|
||||
if header.key.lower() == "authorization":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot use passthrough auth with custom authorization headers",
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/custom")
|
||||
def create_custom_tool(
|
||||
tool_data: CustomToolCreate,
|
||||
@@ -48,6 +58,7 @@ def create_custom_tool(
|
||||
user: User | None = Depends(current_admin_user),
|
||||
) -> ToolSnapshot:
|
||||
_validate_tool_definition(tool_data.definition)
|
||||
_validate_auth_settings(tool_data)
|
||||
tool = create_tool(
|
||||
name=tool_data.name,
|
||||
description=tool_data.description,
|
||||
@@ -55,6 +66,7 @@ def create_custom_tool(
|
||||
custom_headers=tool_data.custom_headers,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
passthrough_auth=tool_data.passthrough_auth,
|
||||
)
|
||||
return ToolSnapshot.from_model(tool)
|
||||
|
||||
@@ -68,6 +80,7 @@ def update_custom_tool(
|
||||
) -> ToolSnapshot:
|
||||
if tool_data.definition:
|
||||
_validate_tool_definition(tool_data.definition)
|
||||
_validate_auth_settings(tool_data)
|
||||
updated_tool = update_tool(
|
||||
tool_id=tool_id,
|
||||
name=tool_data.name,
|
||||
@@ -76,6 +89,7 @@ def update_custom_tool(
|
||||
custom_headers=tool_data.custom_headers,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
passthrough_auth=tool_data.passthrough_auth,
|
||||
)
|
||||
return ToolSnapshot.from_model(updated_tool)
|
||||
|
||||
|
@@ -13,6 +13,7 @@ class ToolSnapshot(BaseModel):
|
||||
display_name: str
|
||||
in_code_tool_id: str | None
|
||||
custom_headers: list[Any] | None
|
||||
passthrough_auth: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, tool: Tool) -> "ToolSnapshot":
|
||||
@@ -24,6 +25,7 @@ class ToolSnapshot(BaseModel):
|
||||
display_name=tool.display_name or tool.name,
|
||||
in_code_tool_id=tool.in_code_tool_id,
|
||||
custom_headers=tool.custom_headers,
|
||||
passthrough_auth=tool.passthrough_auth,
|
||||
)
|
||||
|
||||
|
||||
@@ -37,6 +39,7 @@ class CustomToolCreate(BaseModel):
|
||||
description: str | None = None
|
||||
definition: dict[str, Any]
|
||||
custom_headers: list[Header] | None = None
|
||||
passthrough_auth: bool
|
||||
|
||||
|
||||
class CustomToolUpdate(BaseModel):
|
||||
@@ -44,3 +47,4 @@ class CustomToolUpdate(BaseModel):
|
||||
description: str | None = None
|
||||
definition: dict[str, Any] | None = None
|
||||
custom_headers: list[Header] | None = None
|
||||
passthrough_auth: bool | None = None
|
||||
|
@@ -146,6 +146,11 @@ def construct_tools(
|
||||
"""Constructs tools based on persona configuration and available APIs"""
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Get user's OAuth token if available
|
||||
user_oauth_token = None
|
||||
if user and user.oauth_accounts:
|
||||
user_oauth_token = user.oauth_accounts[0].access_token
|
||||
|
||||
for db_tool_model in persona.tools:
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(
|
||||
@@ -236,6 +241,9 @@ def construct_tools(
|
||||
custom_tool_config.additional_headers or {}
|
||||
)
|
||||
),
|
||||
user_oauth_token=(
|
||||
user_oauth_token if db_tool_model.passthrough_auth else None
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
@@ -80,10 +80,12 @@ class CustomTool(BaseTool):
|
||||
method_spec: MethodSpec,
|
||||
base_url: str,
|
||||
custom_headers: list[HeaderItemDict] | None = None,
|
||||
user_oauth_token: str | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url
|
||||
self._method_spec = method_spec
|
||||
self._tool_definition = self._method_spec.to_tool_definition()
|
||||
self._user_oauth_token = user_oauth_token
|
||||
|
||||
self._name = self._method_spec.name
|
||||
self._description = self._method_spec.summary
|
||||
@@ -91,6 +93,20 @@ class CustomTool(BaseTool):
|
||||
header_list_to_header_dict(custom_headers) if custom_headers else {}
|
||||
)
|
||||
|
||||
# Check for both Authorization header and OAuth token
|
||||
has_auth_header = any(
|
||||
key.lower() == "authorization" for key in self.headers.keys()
|
||||
)
|
||||
if has_auth_header and self._user_oauth_token:
|
||||
logger.warning(
|
||||
f"Tool '{self._name}' has both an Authorization "
|
||||
"header and OAuth token set. This is likely a configuration "
|
||||
"error as the OAuth token will override the custom header."
|
||||
)
|
||||
|
||||
if self._user_oauth_token:
|
||||
self.headers["Authorization"] = f"Bearer {self._user_oauth_token}"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
@@ -348,6 +364,7 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
openapi_schema: dict[str, Any],
|
||||
custom_headers: list[HeaderItemDict] | None = None,
|
||||
dynamic_schema_info: DynamicSchemaInfo | None = None,
|
||||
user_oauth_token: str | None = None,
|
||||
) -> list[CustomTool]:
|
||||
if dynamic_schema_info:
|
||||
# Process dynamic schema information
|
||||
@@ -366,7 +383,13 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
return [
|
||||
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
|
||||
CustomTool(
|
||||
method_spec,
|
||||
url,
|
||||
custom_headers,
|
||||
user_oauth_token=user_oauth_token,
|
||||
)
|
||||
for method_spec in method_specs
|
||||
]
|
||||
|
||||
|
||||
|
@@ -24,6 +24,14 @@ import debounce from "lodash/debounce";
|
||||
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||
import Link from "next/link";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { useAuthType } from "@/lib/hooks";
|
||||
|
||||
function parseJsonWithTrailingCommas(jsonString: string) {
|
||||
// Regular expression to remove trailing commas before } or ]
|
||||
@@ -51,7 +59,11 @@ function ToolForm({
|
||||
}: {
|
||||
existingTool?: ToolSnapshot;
|
||||
values: ToolFormValues;
|
||||
setFieldValue: (field: string, value: string) => void;
|
||||
setFieldValue: <T = any>(
|
||||
field: string,
|
||||
value: T,
|
||||
shouldValidate?: boolean
|
||||
) => void;
|
||||
isSubmitting: boolean;
|
||||
definitionErrorState: [
|
||||
string | null,
|
||||
@@ -65,6 +77,9 @@ function ToolForm({
|
||||
const [definitionError, setDefinitionError] = definitionErrorState;
|
||||
const [methodSpecs, setMethodSpecs] = methodSpecsState;
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
const authType = useAuthType();
|
||||
const isOAuthEnabled = authType === "oidc" || authType === "google_oauth";
|
||||
|
||||
const debouncedValidateDefinition = useCallback(
|
||||
(definition: string) => {
|
||||
const validateDefinition = async () => {
|
||||
@@ -218,43 +233,38 @@ function ToolForm({
|
||||
</p>
|
||||
<FieldArray
|
||||
name="customHeaders"
|
||||
render={(arrayHelpers: ArrayHelpers) => (
|
||||
<div className="space-y-4">
|
||||
{values.customHeaders && values.customHeaders.length > 0 && (
|
||||
<div className="space-y-3">
|
||||
{values.customHeaders.map(
|
||||
(
|
||||
header: { key: string; value: string },
|
||||
index: number
|
||||
) => (
|
||||
<div
|
||||
key={index}
|
||||
className="flex items-center space-x-2 bg-gray-50 p-3 rounded-lg shadow-sm"
|
||||
render={(arrayHelpers) => (
|
||||
<div>
|
||||
<div className="space-y-2">
|
||||
{values.customHeaders.map(
|
||||
(header: { key: string; value: string }, index: number) => (
|
||||
<div
|
||||
key={index}
|
||||
className="flex items-center space-x-2 bg-gray-50 p-3 rounded-lg shadow-sm"
|
||||
>
|
||||
<Field
|
||||
name={`customHeaders.${index}.key`}
|
||||
placeholder="Header Key"
|
||||
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
|
||||
/>
|
||||
<Field
|
||||
name={`customHeaders.${index}.value`}
|
||||
placeholder="Header Value"
|
||||
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
onClick={() => arrayHelpers.remove(index)}
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
className="transition-colors duration-200 hover:bg-red-600"
|
||||
>
|
||||
<Field
|
||||
name={`customHeaders.${index}.key`}
|
||||
placeholder="Header Key"
|
||||
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
|
||||
/>
|
||||
<Field
|
||||
name={`customHeaders.${index}.value`}
|
||||
placeholder="Header Value"
|
||||
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
onClick={() => arrayHelpers.remove(index)}
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
className="transition-colors duration-200 hover:bg-red-600"
|
||||
>
|
||||
Remove
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
Remove
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
@@ -268,6 +278,75 @@ function ToolForm({
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
|
||||
<div className="mt-6">
|
||||
<h3 className="text-xl font-bold mb-2 text-primary-600">
|
||||
Authentication
|
||||
</h3>
|
||||
{isOAuthEnabled ? (
|
||||
<div className="flex flex-col gap-y-2">
|
||||
<div className="flex items-center space-x-2">
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger>
|
||||
<div
|
||||
className={
|
||||
values.customHeaders.some(
|
||||
(header) =>
|
||||
header.key.toLowerCase() === "authorization"
|
||||
)
|
||||
? "opacity-50"
|
||||
: ""
|
||||
}
|
||||
>
|
||||
<Checkbox
|
||||
id="passthrough_auth"
|
||||
size="sm"
|
||||
checked={values.passthrough_auth}
|
||||
disabled={values.customHeaders.some(
|
||||
(header) =>
|
||||
header.key.toLowerCase() === "authorization" &&
|
||||
!values.passthrough_auth
|
||||
)}
|
||||
onCheckedChange={(checked) => {
|
||||
setFieldValue("passthrough_auth", checked, true);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{values.customHeaders.some(
|
||||
(header) => header.key.toLowerCase() === "authorization"
|
||||
) && (
|
||||
<TooltipContent side="top" align="center">
|
||||
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
|
||||
Cannot enable OAuth passthrough when an
|
||||
Authorization header is already set
|
||||
</p>
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
<div className="flex flex-col">
|
||||
<label
|
||||
htmlFor="passthrough_auth"
|
||||
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
|
||||
>
|
||||
Pass through user's OAuth token
|
||||
</label>
|
||||
<p className="text-xs text-subtle mt-1">
|
||||
When enabled, the user's OAuth token will be passed
|
||||
as the Authorization header for all API calls
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-sm text-subtle">
|
||||
OAuth passthrough is only available when OIDC or OAuth
|
||||
authentication is enabled
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -291,6 +370,7 @@ function ToolForm({
|
||||
interface ToolFormValues {
|
||||
definition: string;
|
||||
customHeaders: { key: string; value: string }[];
|
||||
passthrough_auth: boolean;
|
||||
}
|
||||
|
||||
const ToolSchema = Yup.object().shape({
|
||||
@@ -303,6 +383,7 @@ const ToolSchema = Yup.object().shape({
|
||||
})
|
||||
)
|
||||
.default([]),
|
||||
passthrough_auth: Yup.boolean().default(false),
|
||||
});
|
||||
|
||||
export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
|
||||
@@ -326,9 +407,27 @@ export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
|
||||
key: header.key,
|
||||
value: header.value,
|
||||
})) ?? [],
|
||||
passthrough_auth: tool?.passthrough_auth ?? false,
|
||||
}}
|
||||
validationSchema={ToolSchema}
|
||||
onSubmit={async (values: ToolFormValues) => {
|
||||
const hasAuthHeader = values.customHeaders?.some(
|
||||
(header) => header.key.toLowerCase() === "authorization"
|
||||
);
|
||||
if (hasAuthHeader && values.passthrough_auth) {
|
||||
setPopup({
|
||||
message:
|
||||
"Cannot enable passthrough auth when Authorization " +
|
||||
"headers are present. Please remove any Authorization " +
|
||||
"headers first.",
|
||||
type: "error",
|
||||
});
|
||||
console.log(
|
||||
"Cannot enable passthrough auth when Authorization headers are present. Please remove any Authorization headers first."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let definition: any;
|
||||
try {
|
||||
definition = parseJsonWithTrailingCommas(values.definition);
|
||||
@@ -344,6 +443,7 @@ export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
|
||||
description: description || "",
|
||||
definition: definition,
|
||||
custom_headers: values.customHeaders,
|
||||
passthrough_auth: values.passthrough_auth,
|
||||
};
|
||||
let response;
|
||||
if (tool) {
|
||||
|
@@ -1,7 +1,6 @@
|
||||
"use client";
|
||||
import {
|
||||
ConnectorIndexingStatus,
|
||||
OAuthSlackCallbackResponse,
|
||||
DocumentBoostStatus,
|
||||
Tag,
|
||||
UserGroup,
|
||||
@@ -20,13 +19,10 @@ import { AllUsersResponse } from "./types";
|
||||
import { Credential } from "./connectors/credentials";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { PersonaLabel } from "@/app/admin/assistants/interfaces";
|
||||
import {
|
||||
LLMProvider,
|
||||
LLMProviderDescriptor,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { getSourceMetadata } from "./sources";
|
||||
import { buildFilters } from "./search/utils";
|
||||
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
|
||||
|
||||
const CREDENTIAL_URL = "/api/manage/admin/credential";
|
||||
|
||||
@@ -454,6 +450,23 @@ export function useLlmOverride(
|
||||
};
|
||||
}
|
||||
|
||||
export function useAuthType(): AuthType | null {
|
||||
const { data, error } = useSWR<{ auth_type: AuthType }>(
|
||||
"/api/auth/type",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
if (NEXT_PUBLIC_CLOUD_ENABLED) {
|
||||
return "cloud";
|
||||
}
|
||||
|
||||
if (error || !data) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return data.auth_type;
|
||||
}
|
||||
|
||||
/*
|
||||
EE Only APIs
|
||||
*/
|
||||
|
@@ -11,6 +11,7 @@ export async function createCustomTool(toolData: {
|
||||
description?: string;
|
||||
definition: Record<string, any>;
|
||||
custom_headers: { key: string; value: string }[];
|
||||
passthrough_auth: boolean;
|
||||
}): Promise<ApiResponse<ToolSnapshot>> {
|
||||
try {
|
||||
const response = await fetch("/api/admin/tool/custom", {
|
||||
@@ -41,6 +42,7 @@ export async function updateCustomTool(
|
||||
description?: string;
|
||||
definition?: Record<string, any>;
|
||||
custom_headers: { key: string; value: string }[];
|
||||
passthrough_auth: boolean;
|
||||
}
|
||||
): Promise<ApiResponse<ToolSnapshot>> {
|
||||
try {
|
||||
|
@@ -13,6 +13,9 @@ export interface ToolSnapshot {
|
||||
|
||||
// only specified for Custom Tools. ID of the tool in the codebase.
|
||||
in_code_tool_id: string | null;
|
||||
|
||||
// whether to pass through the user's OAuth token as Authorization header
|
||||
passthrough_auth: boolean;
|
||||
}
|
||||
|
||||
export interface MethodSpec {
|
||||
|
Reference in New Issue
Block a user