diff --git a/backend/alembic/versions/465f78d9b7f9_larger_access_tokens_for_oauth.py b/backend/alembic/versions/465f78d9b7f9_larger_access_tokens_for_oauth.py new file mode 100644 index 000000000..61d797e52 --- /dev/null +++ b/backend/alembic/versions/465f78d9b7f9_larger_access_tokens_for_oauth.py @@ -0,0 +1,32 @@ +"""Larger Access Tokens for OAUTH + +Revision ID: 465f78d9b7f9 +Revises: 3c5e35aa9af0 +Create Date: 2023-07-18 17:33:40.365034 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '465f78d9b7f9' +down_revision = '3c5e35aa9af0' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.alter_column( + "oauth_account", + "access_token", + type_=sa.Text() + ) + + +def downgrade() -> None: + op.alter_column( + "oauth_account", + "access_token", + type_=sa.String(length=1024) + ) diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index c9cf8d58c..8f113979f 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -7,11 +7,32 @@ from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import Optional +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Request +from fastapi import status +from fastapi_users import BaseUserManager +from fastapi_users import FastAPIUsers +from fastapi_users import models +from fastapi_users import schemas +from fastapi_users import UUIDIDMixin +from fastapi_users.authentication import AuthenticationBackend +from fastapi_users.authentication import CookieTransport +from fastapi_users.authentication.strategy.db import AccessTokenDatabase +from fastapi_users.authentication.strategy.db import DatabaseStrategy +from fastapi_users.db import SQLAlchemyUserDatabase +from httpx_oauth.clients.google import GoogleOAuth2 +from httpx_oauth.clients.openid import OpenID +from pydantic import EmailStr + from danswer.auth.schemas import UserCreate from danswer.auth.schemas import UserRole from danswer.configs.app_configs import DISABLE_AUTH -from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID -from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET +from danswer.configs.app_configs import ENABLE_OAUTH +from danswer.configs.app_configs import OAUTH_CLIENT_ID +from danswer.configs.app_configs import OAUTH_CLIENT_SECRET +from danswer.configs.app_configs import OAUTH_TYPE +from danswer.configs.app_configs import OPENID_CONFIG_URL from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION from danswer.configs.app_configs import SECRET from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS @@ -28,22 +49,6 @@ from danswer.db.engine import get_async_session from danswer.db.models import AccessToken from danswer.db.models import User from danswer.utils.logger import setup_logger -from fastapi import Depends -from fastapi import HTTPException -from fastapi import Request -from fastapi import status -from fastapi_users import BaseUserManager -from fastapi_users import FastAPIUsers -from fastapi_users import models -from fastapi_users import schemas -from fastapi_users import UUIDIDMixin -from fastapi_users.authentication import AuthenticationBackend -from fastapi_users.authentication import CookieTransport -from fastapi_users.authentication.strategy.db import AccessTokenDatabase -from fastapi_users.authentication.strategy.db import DatabaseStrategy -from fastapi_users.db import SQLAlchemyUserDatabase -from httpx_oauth.clients.google import GoogleOAuth2 -from pydantic import EmailStr logger = setup_logger() @@ -193,7 +198,15 @@ auth_backend = AuthenticationBackend( get_strategy=get_database_strategy, ) -google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET) +oauth_client = None # type: GoogleOAuth2 | OpenID | None +if ENABLE_OAUTH: + if OAUTH_TYPE == "google": + oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) + elif OAUTH_TYPE == "openid": + oauth_client = OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL) + else: + raise AssertionError(f"Invalid OAUTH type {OAUTH_TYPE}") + fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index d099aa2cf..8d25858c1 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -46,8 +46,10 @@ SESSION_EXPIRE_TIME_SECONDS = int( VALID_EMAIL_DOMAIN = os.environ.get("VALID_EMAIL_DOMAIN", "") # OAuth Login Flow ENABLE_OAUTH = os.environ.get("ENABLE_OAUTH", "").lower() != "false" -GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "") -GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "") +OAUTH_TYPE = os.environ.get("OAUTH_TYPE", "google").lower() +OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "")) +OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "")) +OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "") MASK_CREDENTIAL_PREFIX = ( os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false" ) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 5a467e2a9..a619cd7a5 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -4,9 +4,6 @@ from typing import Any from typing import List from uuid import UUID -from danswer.auth.schemas import UserRole -from danswer.configs.constants import DocumentSource -from danswer.connectors.models import InputType from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID from fastapi_users.db import SQLAlchemyBaseUserTableUUID from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID @@ -17,12 +14,17 @@ from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import String +from sqlalchemy import Text from sqlalchemy.dialects import postgresql from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship +from danswer.auth.schemas import UserRole +from danswer.configs.constants import DocumentSource +from danswer.connectors.models import InputType + class IndexingStatus(str, PyEnum): NOT_STARTED = "not_started" @@ -36,7 +38,8 @@ class Base(DeclarativeBase): class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): - pass + # even an almost empty token from keycloak will not fit the default 1024 bytes + access_token: Mapped[str] = mapped_column(Text(), nullable=False) # type: ignore class User(SQLAlchemyBaseUserTableUUID, Base): diff --git a/backend/danswer/main.py b/backend/danswer/main.py index dc003a6b9..32509af79 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -5,14 +5,14 @@ from danswer.auth.schemas import UserRead from danswer.auth.schemas import UserUpdate from danswer.auth.users import auth_backend from danswer.auth.users import fastapi_users -from danswer.auth.users import google_oauth_client -from danswer.configs.app_configs import APP_HOST +from danswer.auth.users import oauth_client +from danswer.configs.app_configs import APP_HOST, OAUTH_TYPE, OPENID_CONFIG_URL from danswer.configs.app_configs import APP_PORT from danswer.configs.app_configs import DISABLE_AUTH from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import ENABLE_OAUTH -from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID -from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET +from danswer.configs.app_configs import OAUTH_CLIENT_ID +from danswer.configs.app_configs import OAUTH_CLIENT_SECRET from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION from danswer.configs.app_configs import SECRET from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION @@ -49,6 +49,11 @@ def validation_exception_handler( def value_error_handler(_: Request, exc: ValueError) -> JSONResponse: + try: + raise(exc) + except: + # log stacktrace + logger.exception("ValueError") return JSONResponse( status_code=400, content={"message": str(exc)}, @@ -88,25 +93,41 @@ def get_application() -> FastAPI: tags=["users"], ) if ENABLE_OAUTH: + if OAUTH_TYPE == "google": + # special case for google + application.include_router( + fastapi_users.get_oauth_router( + oauth_client, + auth_backend, + SECRET, + associate_by_email=True, + is_verified_by_default=True, + # points the user back to the login page, where we will call the + # /auth/google/callback endpoint + redirect them to the main app + redirect_url=f"{WEB_DOMAIN}/auth/google/callback", + ), + prefix="/auth/google", + tags=["auth"], + ) application.include_router( fastapi_users.get_oauth_router( - google_oauth_client, + oauth_client, auth_backend, SECRET, associate_by_email=True, is_verified_by_default=True, # points the user back to the login page, where we will call the - # /auth/google/callback endpoint + redirect them to the main app - redirect_url=f"{WEB_DOMAIN}/auth/google/callback", + # /auth/oauth/callback endpoint + redirect them to the main app + redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback", ), - prefix="/auth/google", + prefix="/auth/oauth", tags=["auth"], ) application.include_router( fastapi_users.get_oauth_associate_router( - google_oauth_client, UserRead, SECRET + oauth_client, UserRead, SECRET ), - prefix="/auth/associate/google", + prefix="/auth/associate/oauth", tags=["auth"], ) @@ -135,10 +156,14 @@ def get_application() -> FastAPI: if not DISABLE_AUTH: if not ENABLE_OAUTH: - logger.warning("OAuth is turned off") + logger.debug("OAuth is turned off") else: - if not GOOGLE_OAUTH_CLIENT_ID or not GOOGLE_OAUTH_CLIENT_SECRET: - logger.warning("OAuth is turned on but incorrectly configured") + if not OAUTH_CLIENT_ID: + logger.warning("OAuth is turned on but OAUTH_CLIENT_ID is empty") + if not OAUTH_CLIENT_SECRET: + logger.warning("OAuth is turned on but OAUTH_CLIENT_SECRET is empty") + if OAUTH_TYPE == "openid" and not OPENID_CONFIG_URL: + logger.warning("OpenID is turned on but OPENID_CONFIG_URL is emtpy") else: logger.debug("OAuth is turned on") diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index f14ffe1e3..c05fc3e5b 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -25,6 +25,8 @@ services: - TYPESENSE_API_KEY=${TYPESENSE_API_KEY:-typesense_api_key} - LOG_LEVEL=${LOG_LEVEL:-info} - DISABLE_AUTH=${DISABLE_AUTH:-True} + - OAUTH_TYPE=${OAUTH_TYPE:-} + - OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-} - GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-} - GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-} - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-} @@ -72,6 +74,7 @@ services: environment: - INTERNAL_URL=http://api_server:8080 - DISABLE_AUTH=${DISABLE_AUTH:-True} + - OAUTH_NAME=${OAUTH_NAME:-} relational_db: image: postgres:15.2-alpine restart: always diff --git a/web/src/app/auth/login/page.tsx b/web/src/app/auth/login/page.tsx index 71e8c9780..1fe3e3343 100644 --- a/web/src/app/auth/login/page.tsx +++ b/web/src/app/auth/login/page.tsx @@ -1,5 +1,5 @@ import { HealthCheckBanner } from "@/components/health/healthcheck"; -import { DISABLE_AUTH } from "@/lib/constants"; +import {DISABLE_AUTH, OAUTH_NAME} from "@/lib/constants"; import { User } from "@/lib/types"; import { getGoogleOAuthUrlSS, getCurrentUserSS } from "@/lib/userSS"; import { redirect } from "next/navigation"; @@ -56,11 +56,11 @@ const Page = async () => { " focus:outline-none focus:ring-2 hover:bg-red-700 focus:ring-offset-2 focus:ring-red-500" } > - Sign in with Google + Sign in with {OAUTH_NAME} ) : ( )} diff --git a/web/src/app/auth/oauth/callback/route.ts b/web/src/app/auth/oauth/callback/route.ts new file mode 100644 index 000000000..0b4157731 --- /dev/null +++ b/web/src/app/auth/oauth/callback/route.ts @@ -0,0 +1,23 @@ +import { getDomain } from "@/lib/redirectSS"; +import { buildUrl } from "@/lib/utilsSS"; +import { NextRequest, NextResponse } from "next/server"; + +export const GET = async (request: NextRequest) => { + // Wrapper around the FastAPI endpoint /auth/oauth/callback, + // which adds back a redirect to the main app. + const url = new URL(buildUrl("/auth/oauth/callback")); + url.search = request.nextUrl.search; + + const response = await fetch(url.toString()); + const setCookieHeader = response.headers.get("set-cookie"); + + if (!setCookieHeader) { + return NextResponse.redirect(new URL("/auth/error", getDomain(request))); + } + + const redirectResponse = NextResponse.redirect( + new URL("/", getDomain(request)) + ); + redirectResponse.headers.set("set-cookie", setCookieHeader); + return redirectResponse; +}; diff --git a/web/src/lib/constants.ts b/web/src/lib/constants.ts index d466f4b0e..1343f07ad 100644 --- a/web/src/lib/constants.ts +++ b/web/src/lib/constants.ts @@ -1,4 +1,7 @@ export const DISABLE_AUTH = process.env.DISABLE_AUTH?.toLowerCase() === "true"; + +export const OAUTH_NAME = process.env.OAUTH_NAME || "Google"; + export const INTERNAL_URL = process.env.INTERNAL_URL || "http://127.0.0.1:8080"; export const NEXT_PUBLIC_DISABLE_STREAMING = process.env.NEXT_PUBLIC_DISABLE_STREAMING?.toLowerCase() === "true"; diff --git a/web/src/lib/userSS.ts b/web/src/lib/userSS.ts index fec47551e..bedaee71e 100644 --- a/web/src/lib/userSS.ts +++ b/web/src/lib/userSS.ts @@ -4,7 +4,7 @@ import { buildUrl } from "./utilsSS"; import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies"; export const getGoogleOAuthUrlSS = async (): Promise => { - const res = await fetch(buildUrl("/auth/google/authorize")); + const res = await fetch(buildUrl("/auth/oauth/authorize")); if (!res.ok) { throw new Error("Failed to fetch data"); }