Add support for openid connect (#206)

This allow using Danswer in typical (non-google) enterprise environments.

* Access Tokens can be very large. A token without claims is already 1100 bytes for me (larger than allowed in danswer by default). With roles I got a 12kB token. For that reason I changed the field to TEXT in the database.
* Danswer used to swallow most errors when OIDC would fail. Nodejs forwards a request to the backend and swallows all errors. Even within the backend we catched all ValueErrors and only returned the last exception with the request. Added full stack trace logging to allow debugging issues with userinfo and other endpoints.
* Allow changing name of the login provider on the login button.
* Changed variables and URLs to generic OAUTH_XX (without google in the name) but kept compatibility with the existing google integration
* Tested again Keycloak with OpenID Connect

Next steps:
* Claim to role mappings
* Auto login/SSO (Login button is just an extra click)
This commit is contained in:
jabdoa2 2023-07-29 23:04:32 +02:00 committed by GitHub
parent 878d4e367f
commit 63780113d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 146 additions and 42 deletions

View File

@ -0,0 +1,32 @@
"""Larger Access Tokens for OAUTH
Revision ID: 465f78d9b7f9
Revises: 3c5e35aa9af0
Create Date: 2023-07-18 17:33:40.365034
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '465f78d9b7f9'
down_revision = '3c5e35aa9af0'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"oauth_account",
"access_token",
type_=sa.Text()
)
def downgrade() -> None:
op.alter_column(
"oauth_account",
"access_token",
type_=sa.String(length=1024)
)

View File

@ -7,11 +7,32 @@ from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Optional
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from fastapi_users import BaseUserManager
from fastapi_users import FastAPIUsers
from fastapi_users import models
from fastapi_users import schemas
from fastapi_users import UUIDIDMixin
from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.db import SQLAlchemyUserDatabase
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import OpenID
from pydantic import EmailStr
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import ENABLE_OAUTH
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import OAUTH_TYPE
from danswer.configs.app_configs import OPENID_CONFIG_URL
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
@ -28,22 +49,6 @@ from danswer.db.engine import get_async_session
from danswer.db.models import AccessToken
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from fastapi_users import BaseUserManager
from fastapi_users import FastAPIUsers
from fastapi_users import models
from fastapi_users import schemas
from fastapi_users import UUIDIDMixin
from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.db import SQLAlchemyUserDatabase
from httpx_oauth.clients.google import GoogleOAuth2
from pydantic import EmailStr
logger = setup_logger()
@ -193,7 +198,15 @@ auth_backend = AuthenticationBackend(
get_strategy=get_database_strategy,
)
google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET)
oauth_client = None # type: GoogleOAuth2 | OpenID | None
if ENABLE_OAUTH:
if OAUTH_TYPE == "google":
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
elif OAUTH_TYPE == "openid":
oauth_client = OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL)
else:
raise AssertionError(f"Invalid OAUTH type {OAUTH_TYPE}")
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])

View File

@ -46,8 +46,10 @@ SESSION_EXPIRE_TIME_SECONDS = int(
VALID_EMAIL_DOMAIN = os.environ.get("VALID_EMAIL_DOMAIN", "")
# OAuth Login Flow
ENABLE_OAUTH = os.environ.get("ENABLE_OAUTH", "").lower() != "false"
GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "")
GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "")
OAUTH_TYPE = os.environ.get("OAUTH_TYPE", "google").lower()
OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", os.environ.get("GOOGLE_OAUTH_CLIENT_ID", ""))
OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", ""))
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)

View File

@ -4,9 +4,6 @@ from typing import Any
from typing import List
from uuid import UUID
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
@ -17,12 +14,17 @@ from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import Text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
class IndexingStatus(str, PyEnum):
NOT_STARTED = "not_started"
@ -36,7 +38,8 @@ class Base(DeclarativeBase):
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
pass
# even an almost empty token from keycloak will not fit the default 1024 bytes
access_token: Mapped[str] = mapped_column(Text(), nullable=False) # type: ignore
class User(SQLAlchemyBaseUserTableUUID, Base):

View File

@ -5,14 +5,14 @@ from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
from danswer.auth.users import google_oauth_client
from danswer.configs.app_configs import APP_HOST
from danswer.auth.users import oauth_client
from danswer.configs.app_configs import APP_HOST, OAUTH_TYPE, OPENID_CONFIG_URL
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import ENABLE_OAUTH
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
@ -49,6 +49,11 @@ def validation_exception_handler(
def value_error_handler(_: Request, exc: ValueError) -> JSONResponse:
try:
raise(exc)
except:
# log stacktrace
logger.exception("ValueError")
return JSONResponse(
status_code=400,
content={"message": str(exc)},
@ -88,25 +93,41 @@ def get_application() -> FastAPI:
tags=["users"],
)
if ENABLE_OAUTH:
if OAUTH_TYPE == "google":
# special case for google
application.include_router(
fastapi_users.get_oauth_router(
oauth_client,
auth_backend,
SECRET,
associate_by_email=True,
is_verified_by_default=True,
# points the user back to the login page, where we will call the
# /auth/google/callback endpoint + redirect them to the main app
redirect_url=f"{WEB_DOMAIN}/auth/google/callback",
),
prefix="/auth/google",
tags=["auth"],
)
application.include_router(
fastapi_users.get_oauth_router(
google_oauth_client,
oauth_client,
auth_backend,
SECRET,
associate_by_email=True,
is_verified_by_default=True,
# points the user back to the login page, where we will call the
# /auth/google/callback endpoint + redirect them to the main app
redirect_url=f"{WEB_DOMAIN}/auth/google/callback",
# /auth/oauth/callback endpoint + redirect them to the main app
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/google",
prefix="/auth/oauth",
tags=["auth"],
)
application.include_router(
fastapi_users.get_oauth_associate_router(
google_oauth_client, UserRead, SECRET
oauth_client, UserRead, SECRET
),
prefix="/auth/associate/google",
prefix="/auth/associate/oauth",
tags=["auth"],
)
@ -135,10 +156,14 @@ def get_application() -> FastAPI:
if not DISABLE_AUTH:
if not ENABLE_OAUTH:
logger.warning("OAuth is turned off")
logger.debug("OAuth is turned off")
else:
if not GOOGLE_OAUTH_CLIENT_ID or not GOOGLE_OAUTH_CLIENT_SECRET:
logger.warning("OAuth is turned on but incorrectly configured")
if not OAUTH_CLIENT_ID:
logger.warning("OAuth is turned on but OAUTH_CLIENT_ID is empty")
if not OAUTH_CLIENT_SECRET:
logger.warning("OAuth is turned on but OAUTH_CLIENT_SECRET is empty")
if OAUTH_TYPE == "openid" and not OPENID_CONFIG_URL:
logger.warning("OpenID is turned on but OPENID_CONFIG_URL is emtpy")
else:
logger.debug("OAuth is turned on")

View File

@ -25,6 +25,8 @@ services:
- TYPESENSE_API_KEY=${TYPESENSE_API_KEY:-typesense_api_key}
- LOG_LEVEL=${LOG_LEVEL:-info}
- DISABLE_AUTH=${DISABLE_AUTH:-True}
- OAUTH_TYPE=${OAUTH_TYPE:-}
- OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
@ -72,6 +74,7 @@ services:
environment:
- INTERNAL_URL=http://api_server:8080
- DISABLE_AUTH=${DISABLE_AUTH:-True}
- OAUTH_NAME=${OAUTH_NAME:-}
relational_db:
image: postgres:15.2-alpine
restart: always

View File

@ -1,5 +1,5 @@
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { DISABLE_AUTH } from "@/lib/constants";
import {DISABLE_AUTH, OAUTH_NAME} from "@/lib/constants";
import { User } from "@/lib/types";
import { getGoogleOAuthUrlSS, getCurrentUserSS } from "@/lib/userSS";
import { redirect } from "next/navigation";
@ -56,11 +56,11 @@ const Page = async () => {
" focus:outline-none focus:ring-2 hover:bg-red-700 focus:ring-offset-2 focus:ring-red-500"
}
>
Sign in with Google
Sign in with {OAUTH_NAME}
</a>
) : (
<button className={BUTTON_STYLE + " cursor-default"}>
Sign in with Google
Sign in with {OAUTH_NAME}
</button>
)}
</div>

View File

@ -0,0 +1,23 @@
import { getDomain } from "@/lib/redirectSS";
import { buildUrl } from "@/lib/utilsSS";
import { NextRequest, NextResponse } from "next/server";
export const GET = async (request: NextRequest) => {
// Wrapper around the FastAPI endpoint /auth/oauth/callback,
// which adds back a redirect to the main app.
const url = new URL(buildUrl("/auth/oauth/callback"));
url.search = request.nextUrl.search;
const response = await fetch(url.toString());
const setCookieHeader = response.headers.get("set-cookie");
if (!setCookieHeader) {
return NextResponse.redirect(new URL("/auth/error", getDomain(request)));
}
const redirectResponse = NextResponse.redirect(
new URL("/", getDomain(request))
);
redirectResponse.headers.set("set-cookie", setCookieHeader);
return redirectResponse;
};

View File

@ -1,4 +1,7 @@
export const DISABLE_AUTH = process.env.DISABLE_AUTH?.toLowerCase() === "true";
export const OAUTH_NAME = process.env.OAUTH_NAME || "Google";
export const INTERNAL_URL = process.env.INTERNAL_URL || "http://127.0.0.1:8080";
export const NEXT_PUBLIC_DISABLE_STREAMING =
process.env.NEXT_PUBLIC_DISABLE_STREAMING?.toLowerCase() === "true";

View File

@ -4,7 +4,7 @@ import { buildUrl } from "./utilsSS";
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
export const getGoogleOAuthUrlSS = async (): Promise<string> => {
const res = await fetch(buildUrl("/auth/google/authorize"));
const res = await fetch(buildUrl("/auth/oauth/authorize"));
if (!res.ok) {
throw new Error("Failed to fetch data");
}