diff --git a/backend/alembic/versions/465f78d9b7f9_larger_access_tokens_for_oauth.py b/backend/alembic/versions/465f78d9b7f9_larger_access_tokens_for_oauth.py
new file mode 100644
index 000000000..61d797e52
--- /dev/null
+++ b/backend/alembic/versions/465f78d9b7f9_larger_access_tokens_for_oauth.py
@@ -0,0 +1,32 @@
+"""Larger Access Tokens for OAUTH
+
+Revision ID: 465f78d9b7f9
+Revises: 3c5e35aa9af0
+Create Date: 2023-07-18 17:33:40.365034
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '465f78d9b7f9'
+down_revision = '3c5e35aa9af0'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ op.alter_column(
+ "oauth_account",
+ "access_token",
+ type_=sa.Text()
+ )
+
+
+def downgrade() -> None:
+ op.alter_column(
+ "oauth_account",
+ "access_token",
+ type_=sa.String(length=1024)
+ )
diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py
index c9cf8d58c..8f113979f 100644
--- a/backend/danswer/auth/users.py
+++ b/backend/danswer/auth/users.py
@@ -7,11 +7,32 @@ from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Optional
+from fastapi import Depends
+from fastapi import HTTPException
+from fastapi import Request
+from fastapi import status
+from fastapi_users import BaseUserManager
+from fastapi_users import FastAPIUsers
+from fastapi_users import models
+from fastapi_users import schemas
+from fastapi_users import UUIDIDMixin
+from fastapi_users.authentication import AuthenticationBackend
+from fastapi_users.authentication import CookieTransport
+from fastapi_users.authentication.strategy.db import AccessTokenDatabase
+from fastapi_users.authentication.strategy.db import DatabaseStrategy
+from fastapi_users.db import SQLAlchemyUserDatabase
+from httpx_oauth.clients.google import GoogleOAuth2
+from httpx_oauth.clients.openid import OpenID
+from pydantic import EmailStr
+
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import DISABLE_AUTH
-from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID
-from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET
+from danswer.configs.app_configs import ENABLE_OAUTH
+from danswer.configs.app_configs import OAUTH_CLIENT_ID
+from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
+from danswer.configs.app_configs import OAUTH_TYPE
+from danswer.configs.app_configs import OPENID_CONFIG_URL
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
@@ -28,22 +49,6 @@ from danswer.db.engine import get_async_session
from danswer.db.models import AccessToken
from danswer.db.models import User
from danswer.utils.logger import setup_logger
-from fastapi import Depends
-from fastapi import HTTPException
-from fastapi import Request
-from fastapi import status
-from fastapi_users import BaseUserManager
-from fastapi_users import FastAPIUsers
-from fastapi_users import models
-from fastapi_users import schemas
-from fastapi_users import UUIDIDMixin
-from fastapi_users.authentication import AuthenticationBackend
-from fastapi_users.authentication import CookieTransport
-from fastapi_users.authentication.strategy.db import AccessTokenDatabase
-from fastapi_users.authentication.strategy.db import DatabaseStrategy
-from fastapi_users.db import SQLAlchemyUserDatabase
-from httpx_oauth.clients.google import GoogleOAuth2
-from pydantic import EmailStr
logger = setup_logger()
@@ -193,7 +198,15 @@ auth_backend = AuthenticationBackend(
get_strategy=get_database_strategy,
)
-google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET)
+oauth_client = None # type: GoogleOAuth2 | OpenID | None
+if ENABLE_OAUTH:
+ if OAUTH_TYPE == "google":
+ oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
+ elif OAUTH_TYPE == "openid":
+ oauth_client = OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL)
+ else:
+ raise AssertionError(f"Invalid OAUTH type {OAUTH_TYPE}")
+
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py
index d099aa2cf..8d25858c1 100644
--- a/backend/danswer/configs/app_configs.py
+++ b/backend/danswer/configs/app_configs.py
@@ -46,8 +46,10 @@ SESSION_EXPIRE_TIME_SECONDS = int(
VALID_EMAIL_DOMAIN = os.environ.get("VALID_EMAIL_DOMAIN", "")
# OAuth Login Flow
ENABLE_OAUTH = os.environ.get("ENABLE_OAUTH", "").lower() != "false"
-GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "")
-GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "")
+OAUTH_TYPE = os.environ.get("OAUTH_TYPE", "google").lower()
+OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", os.environ.get("GOOGLE_OAUTH_CLIENT_ID", ""))
+OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", ""))
+OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index 5a467e2a9..a619cd7a5 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -4,9 +4,6 @@ from typing import Any
from typing import List
from uuid import UUID
-from danswer.auth.schemas import UserRole
-from danswer.configs.constants import DocumentSource
-from danswer.connectors.models import InputType
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
@@ -17,12 +14,17 @@ from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import String
+from sqlalchemy import Text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
+from danswer.auth.schemas import UserRole
+from danswer.configs.constants import DocumentSource
+from danswer.connectors.models import InputType
+
class IndexingStatus(str, PyEnum):
NOT_STARTED = "not_started"
@@ -36,7 +38,8 @@ class Base(DeclarativeBase):
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
- pass
+ # even an almost empty token from keycloak will not fit the default 1024 bytes
+ access_token: Mapped[str] = mapped_column(Text(), nullable=False) # type: ignore
class User(SQLAlchemyBaseUserTableUUID, Base):
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index dc003a6b9..32509af79 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -5,14 +5,14 @@ from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
-from danswer.auth.users import google_oauth_client
-from danswer.configs.app_configs import APP_HOST
+from danswer.auth.users import oauth_client
+from danswer.configs.app_configs import APP_HOST, OAUTH_TYPE, OPENID_CONFIG_URL
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import ENABLE_OAUTH
-from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID
-from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET
+from danswer.configs.app_configs import OAUTH_CLIENT_ID
+from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
@@ -49,6 +49,11 @@ def validation_exception_handler(
def value_error_handler(_: Request, exc: ValueError) -> JSONResponse:
+ try:
+ raise(exc)
+ except:
+ # log stacktrace
+ logger.exception("ValueError")
return JSONResponse(
status_code=400,
content={"message": str(exc)},
@@ -88,25 +93,41 @@ def get_application() -> FastAPI:
tags=["users"],
)
if ENABLE_OAUTH:
+ if OAUTH_TYPE == "google":
+ # special case for google
+ application.include_router(
+ fastapi_users.get_oauth_router(
+ oauth_client,
+ auth_backend,
+ SECRET,
+ associate_by_email=True,
+ is_verified_by_default=True,
+ # points the user back to the login page, where we will call the
+ # /auth/google/callback endpoint + redirect them to the main app
+ redirect_url=f"{WEB_DOMAIN}/auth/google/callback",
+ ),
+ prefix="/auth/google",
+ tags=["auth"],
+ )
application.include_router(
fastapi_users.get_oauth_router(
- google_oauth_client,
+ oauth_client,
auth_backend,
SECRET,
associate_by_email=True,
is_verified_by_default=True,
# points the user back to the login page, where we will call the
- # /auth/google/callback endpoint + redirect them to the main app
- redirect_url=f"{WEB_DOMAIN}/auth/google/callback",
+ # /auth/oauth/callback endpoint + redirect them to the main app
+ redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
- prefix="/auth/google",
+ prefix="/auth/oauth",
tags=["auth"],
)
application.include_router(
fastapi_users.get_oauth_associate_router(
- google_oauth_client, UserRead, SECRET
+ oauth_client, UserRead, SECRET
),
- prefix="/auth/associate/google",
+ prefix="/auth/associate/oauth",
tags=["auth"],
)
@@ -135,10 +156,14 @@ def get_application() -> FastAPI:
if not DISABLE_AUTH:
if not ENABLE_OAUTH:
- logger.warning("OAuth is turned off")
+ logger.debug("OAuth is turned off")
else:
- if not GOOGLE_OAUTH_CLIENT_ID or not GOOGLE_OAUTH_CLIENT_SECRET:
- logger.warning("OAuth is turned on but incorrectly configured")
+ if not OAUTH_CLIENT_ID:
+ logger.warning("OAuth is turned on but OAUTH_CLIENT_ID is empty")
+ if not OAUTH_CLIENT_SECRET:
+ logger.warning("OAuth is turned on but OAUTH_CLIENT_SECRET is empty")
+ if OAUTH_TYPE == "openid" and not OPENID_CONFIG_URL:
+ logger.warning("OpenID is turned on but OPENID_CONFIG_URL is emtpy")
else:
logger.debug("OAuth is turned on")
diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml
index f14ffe1e3..c05fc3e5b 100644
--- a/deployment/docker_compose/docker-compose.dev.yml
+++ b/deployment/docker_compose/docker-compose.dev.yml
@@ -25,6 +25,8 @@ services:
- TYPESENSE_API_KEY=${TYPESENSE_API_KEY:-typesense_api_key}
- LOG_LEVEL=${LOG_LEVEL:-info}
- DISABLE_AUTH=${DISABLE_AUTH:-True}
+ - OAUTH_TYPE=${OAUTH_TYPE:-}
+ - OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
@@ -72,6 +74,7 @@ services:
environment:
- INTERNAL_URL=http://api_server:8080
- DISABLE_AUTH=${DISABLE_AUTH:-True}
+ - OAUTH_NAME=${OAUTH_NAME:-}
relational_db:
image: postgres:15.2-alpine
restart: always
diff --git a/web/src/app/auth/login/page.tsx b/web/src/app/auth/login/page.tsx
index 71e8c9780..1fe3e3343 100644
--- a/web/src/app/auth/login/page.tsx
+++ b/web/src/app/auth/login/page.tsx
@@ -1,5 +1,5 @@
import { HealthCheckBanner } from "@/components/health/healthcheck";
-import { DISABLE_AUTH } from "@/lib/constants";
+import {DISABLE_AUTH, OAUTH_NAME} from "@/lib/constants";
import { User } from "@/lib/types";
import { getGoogleOAuthUrlSS, getCurrentUserSS } from "@/lib/userSS";
import { redirect } from "next/navigation";
@@ -56,11 +56,11 @@ const Page = async () => {
" focus:outline-none focus:ring-2 hover:bg-red-700 focus:ring-offset-2 focus:ring-red-500"
}
>
- Sign in with Google
+ Sign in with {OAUTH_NAME}
) : (
)}
diff --git a/web/src/app/auth/oauth/callback/route.ts b/web/src/app/auth/oauth/callback/route.ts
new file mode 100644
index 000000000..0b4157731
--- /dev/null
+++ b/web/src/app/auth/oauth/callback/route.ts
@@ -0,0 +1,23 @@
+import { getDomain } from "@/lib/redirectSS";
+import { buildUrl } from "@/lib/utilsSS";
+import { NextRequest, NextResponse } from "next/server";
+
+export const GET = async (request: NextRequest) => {
+ // Wrapper around the FastAPI endpoint /auth/oauth/callback,
+ // which adds back a redirect to the main app.
+ const url = new URL(buildUrl("/auth/oauth/callback"));
+ url.search = request.nextUrl.search;
+
+ const response = await fetch(url.toString());
+ const setCookieHeader = response.headers.get("set-cookie");
+
+ if (!setCookieHeader) {
+ return NextResponse.redirect(new URL("/auth/error", getDomain(request)));
+ }
+
+ const redirectResponse = NextResponse.redirect(
+ new URL("/", getDomain(request))
+ );
+ redirectResponse.headers.set("set-cookie", setCookieHeader);
+ return redirectResponse;
+};
diff --git a/web/src/lib/constants.ts b/web/src/lib/constants.ts
index d466f4b0e..1343f07ad 100644
--- a/web/src/lib/constants.ts
+++ b/web/src/lib/constants.ts
@@ -1,4 +1,7 @@
export const DISABLE_AUTH = process.env.DISABLE_AUTH?.toLowerCase() === "true";
+
+export const OAUTH_NAME = process.env.OAUTH_NAME || "Google";
+
export const INTERNAL_URL = process.env.INTERNAL_URL || "http://127.0.0.1:8080";
export const NEXT_PUBLIC_DISABLE_STREAMING =
process.env.NEXT_PUBLIC_DISABLE_STREAMING?.toLowerCase() === "true";
diff --git a/web/src/lib/userSS.ts b/web/src/lib/userSS.ts
index fec47551e..bedaee71e 100644
--- a/web/src/lib/userSS.ts
+++ b/web/src/lib/userSS.ts
@@ -4,7 +4,7 @@ import { buildUrl } from "./utilsSS";
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
export const getGoogleOAuthUrlSS = async (): Promise => {
- const res = await fetch(buildUrl("/auth/google/authorize"));
+ const res = await fetch(buildUrl("/auth/oauth/authorize"));
if (!res.ok) {
throw new Error("Failed to fetch data");
}