mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 11:58:34 +02:00
Add support for openid connect (#206)
This allow using Danswer in typical (non-google) enterprise environments. * Access Tokens can be very large. A token without claims is already 1100 bytes for me (larger than allowed in danswer by default). With roles I got a 12kB token. For that reason I changed the field to TEXT in the database. * Danswer used to swallow most errors when OIDC would fail. Nodejs forwards a request to the backend and swallows all errors. Even within the backend we catched all ValueErrors and only returned the last exception with the request. Added full stack trace logging to allow debugging issues with userinfo and other endpoints. * Allow changing name of the login provider on the login button. * Changed variables and URLs to generic OAUTH_XX (without google in the name) but kept compatibility with the existing google integration * Tested again Keycloak with OpenID Connect Next steps: * Claim to role mappings * Auto login/SSO (Login button is just an extra click)
This commit is contained in:
parent
878d4e367f
commit
63780113d3
@ -0,0 +1,32 @@
|
||||
"""Larger Access Tokens for OAUTH
|
||||
|
||||
Revision ID: 465f78d9b7f9
|
||||
Revises: 3c5e35aa9af0
|
||||
Create Date: 2023-07-18 17:33:40.365034
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '465f78d9b7f9'
|
||||
down_revision = '3c5e35aa9af0'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"oauth_account",
|
||||
"access_token",
|
||||
type_=sa.Text()
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"oauth_account",
|
||||
"access_token",
|
||||
type_=sa.String(length=1024)
|
||||
)
|
@ -7,11 +7,32 @@ from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import status
|
||||
from fastapi_users import BaseUserManager
|
||||
from fastapi_users import FastAPIUsers
|
||||
from fastapi_users import models
|
||||
from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from httpx_oauth.clients.openid import OpenID
|
||||
from pydantic import EmailStr
|
||||
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import ENABLE_OAUTH
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import OAUTH_TYPE
|
||||
from danswer.configs.app_configs import OPENID_CONFIG_URL
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from danswer.configs.app_configs import SECRET
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
@ -28,22 +49,6 @@ from danswer.db.engine import get_async_session
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import status
|
||||
from fastapi_users import BaseUserManager
|
||||
from fastapi_users import FastAPIUsers
|
||||
from fastapi_users import models
|
||||
from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from pydantic import EmailStr
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -193,7 +198,15 @@ auth_backend = AuthenticationBackend(
|
||||
get_strategy=get_database_strategy,
|
||||
)
|
||||
|
||||
google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET)
|
||||
oauth_client = None # type: GoogleOAuth2 | OpenID | None
|
||||
if ENABLE_OAUTH:
|
||||
if OAUTH_TYPE == "google":
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
elif OAUTH_TYPE == "openid":
|
||||
oauth_client = OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL)
|
||||
else:
|
||||
raise AssertionError(f"Invalid OAUTH type {OAUTH_TYPE}")
|
||||
|
||||
|
||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||
|
||||
|
@ -46,8 +46,10 @@ SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
VALID_EMAIL_DOMAIN = os.environ.get("VALID_EMAIL_DOMAIN", "")
|
||||
# OAuth Login Flow
|
||||
ENABLE_OAUTH = os.environ.get("ENABLE_OAUTH", "").lower() != "false"
|
||||
GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "")
|
||||
GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "")
|
||||
OAUTH_TYPE = os.environ.get("OAUTH_TYPE", "google").lower()
|
||||
OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", os.environ.get("GOOGLE_OAUTH_CLIENT_ID", ""))
|
||||
OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", ""))
|
||||
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
|
||||
MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
|
@ -4,9 +4,6 @@ from typing import Any
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import InputType
|
||||
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
|
||||
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
|
||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
|
||||
@ -17,12 +14,17 @@ from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import InputType
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
NOT_STARTED = "not_started"
|
||||
@ -36,7 +38,8 @@ class Base(DeclarativeBase):
|
||||
|
||||
|
||||
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||
pass
|
||||
# even an almost empty token from keycloak will not fit the default 1024 bytes
|
||||
access_token: Mapped[str] = mapped_column(Text(), nullable=False) # type: ignore
|
||||
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
|
@ -5,14 +5,14 @@ from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.auth.users import google_oauth_client
|
||||
from danswer.configs.app_configs import APP_HOST
|
||||
from danswer.auth.users import oauth_client
|
||||
from danswer.configs.app_configs import APP_HOST, OAUTH_TYPE, OPENID_CONFIG_URL
|
||||
from danswer.configs.app_configs import APP_PORT
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import ENABLE_OAUTH
|
||||
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
|
||||
from danswer.configs.app_configs import SECRET
|
||||
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
|
||||
@ -49,6 +49,11 @@ def validation_exception_handler(
|
||||
|
||||
|
||||
def value_error_handler(_: Request, exc: ValueError) -> JSONResponse:
|
||||
try:
|
||||
raise(exc)
|
||||
except:
|
||||
# log stacktrace
|
||||
logger.exception("ValueError")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"message": str(exc)},
|
||||
@ -88,25 +93,41 @@ def get_application() -> FastAPI:
|
||||
tags=["users"],
|
||||
)
|
||||
if ENABLE_OAUTH:
|
||||
if OAUTH_TYPE == "google":
|
||||
# special case for google
|
||||
application.include_router(
|
||||
fastapi_users.get_oauth_router(
|
||||
oauth_client,
|
||||
auth_backend,
|
||||
SECRET,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
# points the user back to the login page, where we will call the
|
||||
# /auth/google/callback endpoint + redirect them to the main app
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/google/callback",
|
||||
),
|
||||
prefix="/auth/google",
|
||||
tags=["auth"],
|
||||
)
|
||||
application.include_router(
|
||||
fastapi_users.get_oauth_router(
|
||||
google_oauth_client,
|
||||
oauth_client,
|
||||
auth_backend,
|
||||
SECRET,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
# points the user back to the login page, where we will call the
|
||||
# /auth/google/callback endpoint + redirect them to the main app
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/google/callback",
|
||||
# /auth/oauth/callback endpoint + redirect them to the main app
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/google",
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
application.include_router(
|
||||
fastapi_users.get_oauth_associate_router(
|
||||
google_oauth_client, UserRead, SECRET
|
||||
oauth_client, UserRead, SECRET
|
||||
),
|
||||
prefix="/auth/associate/google",
|
||||
prefix="/auth/associate/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
@ -135,10 +156,14 @@ def get_application() -> FastAPI:
|
||||
|
||||
if not DISABLE_AUTH:
|
||||
if not ENABLE_OAUTH:
|
||||
logger.warning("OAuth is turned off")
|
||||
logger.debug("OAuth is turned off")
|
||||
else:
|
||||
if not GOOGLE_OAUTH_CLIENT_ID or not GOOGLE_OAUTH_CLIENT_SECRET:
|
||||
logger.warning("OAuth is turned on but incorrectly configured")
|
||||
if not OAUTH_CLIENT_ID:
|
||||
logger.warning("OAuth is turned on but OAUTH_CLIENT_ID is empty")
|
||||
if not OAUTH_CLIENT_SECRET:
|
||||
logger.warning("OAuth is turned on but OAUTH_CLIENT_SECRET is empty")
|
||||
if OAUTH_TYPE == "openid" and not OPENID_CONFIG_URL:
|
||||
logger.warning("OpenID is turned on but OPENID_CONFIG_URL is emtpy")
|
||||
else:
|
||||
logger.debug("OAuth is turned on")
|
||||
|
||||
|
@ -25,6 +25,8 @@ services:
|
||||
- TYPESENSE_API_KEY=${TYPESENSE_API_KEY:-typesense_api_key}
|
||||
- LOG_LEVEL=${LOG_LEVEL:-info}
|
||||
- DISABLE_AUTH=${DISABLE_AUTH:-True}
|
||||
- OAUTH_TYPE=${OAUTH_TYPE:-}
|
||||
- OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
|
||||
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
|
||||
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
|
||||
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
|
||||
@ -72,6 +74,7 @@ services:
|
||||
environment:
|
||||
- INTERNAL_URL=http://api_server:8080
|
||||
- DISABLE_AUTH=${DISABLE_AUTH:-True}
|
||||
- OAUTH_NAME=${OAUTH_NAME:-}
|
||||
relational_db:
|
||||
image: postgres:15.2-alpine
|
||||
restart: always
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { HealthCheckBanner } from "@/components/health/healthcheck";
|
||||
import { DISABLE_AUTH } from "@/lib/constants";
|
||||
import {DISABLE_AUTH, OAUTH_NAME} from "@/lib/constants";
|
||||
import { User } from "@/lib/types";
|
||||
import { getGoogleOAuthUrlSS, getCurrentUserSS } from "@/lib/userSS";
|
||||
import { redirect } from "next/navigation";
|
||||
@ -56,11 +56,11 @@ const Page = async () => {
|
||||
" focus:outline-none focus:ring-2 hover:bg-red-700 focus:ring-offset-2 focus:ring-red-500"
|
||||
}
|
||||
>
|
||||
Sign in with Google
|
||||
Sign in with {OAUTH_NAME}
|
||||
</a>
|
||||
) : (
|
||||
<button className={BUTTON_STYLE + " cursor-default"}>
|
||||
Sign in with Google
|
||||
Sign in with {OAUTH_NAME}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
23
web/src/app/auth/oauth/callback/route.ts
Normal file
23
web/src/app/auth/oauth/callback/route.ts
Normal file
@ -0,0 +1,23 @@
|
||||
import { getDomain } from "@/lib/redirectSS";
|
||||
import { buildUrl } from "@/lib/utilsSS";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
|
||||
export const GET = async (request: NextRequest) => {
|
||||
// Wrapper around the FastAPI endpoint /auth/oauth/callback,
|
||||
// which adds back a redirect to the main app.
|
||||
const url = new URL(buildUrl("/auth/oauth/callback"));
|
||||
url.search = request.nextUrl.search;
|
||||
|
||||
const response = await fetch(url.toString());
|
||||
const setCookieHeader = response.headers.get("set-cookie");
|
||||
|
||||
if (!setCookieHeader) {
|
||||
return NextResponse.redirect(new URL("/auth/error", getDomain(request)));
|
||||
}
|
||||
|
||||
const redirectResponse = NextResponse.redirect(
|
||||
new URL("/", getDomain(request))
|
||||
);
|
||||
redirectResponse.headers.set("set-cookie", setCookieHeader);
|
||||
return redirectResponse;
|
||||
};
|
@ -1,4 +1,7 @@
|
||||
export const DISABLE_AUTH = process.env.DISABLE_AUTH?.toLowerCase() === "true";
|
||||
|
||||
export const OAUTH_NAME = process.env.OAUTH_NAME || "Google";
|
||||
|
||||
export const INTERNAL_URL = process.env.INTERNAL_URL || "http://127.0.0.1:8080";
|
||||
export const NEXT_PUBLIC_DISABLE_STREAMING =
|
||||
process.env.NEXT_PUBLIC_DISABLE_STREAMING?.toLowerCase() === "true";
|
||||
|
@ -4,7 +4,7 @@ import { buildUrl } from "./utilsSS";
|
||||
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
|
||||
|
||||
export const getGoogleOAuthUrlSS = async (): Promise<string> => {
|
||||
const res = await fetch(buildUrl("/auth/google/authorize"));
|
||||
const res = await fetch(buildUrl("/auth/oauth/authorize"));
|
||||
if (!res.ok) {
|
||||
throw new Error("Failed to fetch data");
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user