diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 975bb030582..dcc4d54886e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -115,11 +115,11 @@ mkdir dynamic_config_storage To start the frontend, navigate to `danswer/web` and run: ```bash -DISABLE_AUTH=true npm run dev +AUTH_TYPE=disabled npm run dev ``` _for Windows, run:_ ```bash -(SET "DISABLE_AUTH=true" && npm run dev) +(SET "AUTH_TYPE=disabled" && npm run dev) ``` @@ -138,7 +138,7 @@ zip -r ../vespa-app.zip . To run the backend API server, navigate back to `danswer/backend` and run: ```bash -DISABLE_AUTH=True \ +AUTH_TYPE=disabled \ DYNAMIC_CONFIG_DIR_PATH=./dynamic_config_storage \ VESPA_DEPLOYMENT_ZIP=./danswer/datastores/vespa/vespa-app.zip \ uvicorn danswer.main:app --reload --port 8080 @@ -146,7 +146,7 @@ uvicorn danswer.main:app --reload --port 8080 _For Windows (for compatibility with both PowerShell and Command Prompt):_ ```bash powershell -Command " - $env:DISABLE_AUTH='True' + $env:AUTH_TYPE='disabled' $env:DYNAMIC_CONFIG_DIR_PATH='./dynamic_config_storage' $env:VESPA_DEPLOYMENT_ZIP='./danswer/datastores/vespa/vespa-app.zip' uvicorn danswer.main:app --reload --port 8080 diff --git a/backend/alembic/versions/ae62505e3acc_add_saml_accounts.py b/backend/alembic/versions/ae62505e3acc_add_saml_accounts.py new file mode 100644 index 00000000000..db67d0274c1 --- /dev/null +++ b/backend/alembic/versions/ae62505e3acc_add_saml_accounts.py @@ -0,0 +1,47 @@ +"""Add SAML Accounts + +Revision ID: ae62505e3acc +Revises: 7da543f5672f +Create Date: 2023-09-26 16:19:30.933183 + +""" +import fastapi_users_db_sqlalchemy +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "ae62505e3acc" +down_revision = "7da543f5672f" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "saml", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=False, + ), + sa.Column("encrypted_cookie", sa.Text(), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("encrypted_cookie"), + sa.UniqueConstraint("user_id"), + ) + + +def downgrade() -> None: + op.drop_table("saml") diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index c6be29175db..c02ee7a46a7 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -1,4 +1,3 @@ -import contextlib import os import smtplib import uuid @@ -6,10 +5,13 @@ from collections.abc import AsyncGenerator from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import Optional +from typing import Tuple +from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Request +from fastapi import Response from fastapi import status from fastapi_users import BaseUserManager from fastapi_users import FastAPIUsers @@ -18,21 +20,17 @@ from fastapi_users import schemas from fastapi_users import UUIDIDMixin from fastapi_users.authentication import AuthenticationBackend from fastapi_users.authentication import CookieTransport +from fastapi_users.authentication import Strategy from fastapi_users.authentication.strategy.db import AccessTokenDatabase from fastapi_users.authentication.strategy.db import DatabaseStrategy from fastapi_users.db import SQLAlchemyUserDatabase -from httpx_oauth.clients.google import GoogleOAuth2 -from httpx_oauth.clients.openid import OpenID -from pydantic import EmailStr +from fastapi_users.openapi import OpenAPIResponseType +from sqlalchemy.orm import Session from danswer.auth.schemas import UserCreate from danswer.auth.schemas import UserRole +from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import DISABLE_AUTH -from danswer.configs.app_configs import ENABLE_OAUTH -from danswer.configs.app_configs import OAUTH_CLIENT_ID -from danswer.configs.app_configs import OAUTH_CLIENT_SECRET -from danswer.configs.app_configs import OAUTH_TYPE -from danswer.configs.app_configs import OPENID_CONFIG_URL from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION from danswer.configs.app_configs import SECRET from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS @@ -42,23 +40,32 @@ from danswer.configs.app_configs import SMTP_SERVER from danswer.configs.app_configs import SMTP_USER from danswer.configs.app_configs import VALID_EMAIL_DOMAINS from danswer.configs.app_configs import WEB_DOMAIN +from danswer.configs.constants import AuthType from danswer.db.auth import get_access_token_db from danswer.db.auth import get_user_count from danswer.db.auth import get_user_db -from danswer.db.engine import get_async_session +from danswer.db.engine import get_session from danswer.db.models import AccessToken from danswer.db.models import User from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import fetch_versioned_implementation + logger = setup_logger() -FAKE_USER_EMAIL = "fakeuser@fakedanswermail.com" -FAKE_USER_PASS = "foobar" - USER_WHITELIST_FILE = "/home/danswer_whitelist.txt" _user_whitelist: list[str] | None = None +def verify_auth_setting() -> None: + if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]: + raise ValueError( + "User must choose a valid user authentication method: " + "disabled, basic, or google_oauth" + ) + logger.info(f"Using Auth Type: {AUTH_TYPE.value}") + + def get_user_whitelist() -> list[str]: global _user_whitelist if _user_whitelist is None: @@ -204,53 +211,84 @@ auth_backend = AuthenticationBackend( get_strategy=get_database_strategy, ) -oauth_client = None # type: GoogleOAuth2 | OpenID | None -if ENABLE_OAUTH: - if OAUTH_TYPE == "google": - oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) - elif OAUTH_TYPE == "openid": - oauth_client = OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL) - else: - raise AssertionError(f"Invalid OAUTH type {OAUTH_TYPE}") + +class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]): + def get_logout_router( + self, + backend: AuthenticationBackend, + requires_verification: bool = REQUIRE_EMAIL_VERIFICATION, + ) -> APIRouter: + """ + Provide a router for logout only for OAuth/OIDC Flows. + This way the login router does not need to be included + """ + router = APIRouter() + get_current_user_token = self.authenticator.current_user_token( + active=True, verified=requires_verification + ) + logout_responses: OpenAPIResponseType = { + **{ + status.HTTP_401_UNAUTHORIZED: { + "description": "Missing token or inactive user." + } + }, + **backend.transport.get_openapi_logout_responses_success(), + } + + @router.post( + "/logout", name=f"auth:{backend.name}.logout", responses=logout_responses + ) + async def logout( + user_token: Tuple[models.UP, str] = Depends(get_current_user_token), + strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), + ) -> Response: + user, token = user_token + return await backend.logout(strategy, user, token) + + return router -fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) - - -# Currently unused, maybe useful later -async def create_get_fake_user() -> User: - get_async_session_context = contextlib.asynccontextmanager( - get_async_session - ) # type:ignore - get_user_db_context = contextlib.asynccontextmanager(get_user_db) - get_user_manager_context = contextlib.asynccontextmanager(get_user_manager) - - logger.info("Creating fake user due to Auth being turned off") - async with get_async_session_context() as session: - async with get_user_db_context(session) as user_db: - async with get_user_manager_context(user_db) as user_manager: - user = await user_manager.get_by_email(FAKE_USER_EMAIL) - if user: - return user - user = await user_manager.create( - UserCreate(email=EmailStr(FAKE_USER_EMAIL), password=FAKE_USER_PASS) - ) - logger.info("Created fake user.") - return user - - -current_active_user = fastapi_users.current_user( - active=True, verified=REQUIRE_EMAIL_VERIFICATION, optional=DISABLE_AUTH +fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID]( + get_user_manager, [auth_backend] ) -async def current_user(user: User = Depends(current_active_user)) -> User | None: - if DISABLE_AUTH: +optional_valid_user = fastapi_users.current_user( + active=True, verified=REQUIRE_EMAIL_VERIFICATION, optional=True +) + + +async def double_check_user( + request: Request, + user: User | None, + db_session: Session, + optional: bool = DISABLE_AUTH, +) -> User | None: + if optional: return None + + if user is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied. User is not authenticated.", + ) + return user -async def current_admin_user(user: User = Depends(current_user)) -> User | None: +async def current_user( + request: Request, + user: User | None = Depends(optional_valid_user), + db_session: Session = Depends(get_session), +) -> User | None: + double_check_user = fetch_versioned_implementation( + "danswer.auth.users", "double_check_user" + ) + user = await double_check_user(request, user, db_session) + return user + + +async def current_admin_user(user: User | None = Depends(current_user)) -> User | None: if DISABLE_AUTH: return None diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 9dae38c60f0..86a38e0bddb 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -1,5 +1,6 @@ import os +from danswer.configs.constants import AuthType from danswer.configs.constants import DocumentIndexType ##### @@ -14,32 +15,27 @@ APP_PORT = 8080 ##### BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day -# DISABLE_GENERATIVE_AI will turn of the question answering part of Danswer. Use this -# if you want to use Danswer as a search engine only and/or you are not comfortable sending -# anything to OpenAI. TODO: update this message once we support Azure / open source generative models. +# DISABLE_GENERATIVE_AI will turn of the question answering part of Danswer. +# Use this if you want to use Danswer as a search engine only without the LLM capabilities DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true" ##### # Web Configs ##### -# WEB_DOMAIN is used to set the redirect_uri when doing OAuth with Google -# TODO: investigate if this can be done cleaner by overwriting the redirect_uri -# on the frontend and just sending a dummy value (or completely generating the URL) -# on the frontend -WEB_DOMAIN = os.environ.get("WEB_DOMAIN", "http://localhost:3000") +# WEB_DOMAIN is used to set the redirect_uri after login flows +WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000" ##### # Auth Configs ##### -DISABLE_AUTH = os.environ.get("DISABLE_AUTH", "").lower() == "true" -REQUIRE_EMAIL_VERIFICATION = ( - os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true" +AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower()) +DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED + +# Turn off mask if admin users should see full credentials for data connectors. +MASK_CREDENTIAL_PREFIX = ( + os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false" ) -SMTP_SERVER = os.environ.get("SMTP_SERVER", "smtp.gmail.com") -SMTP_PORT = int(os.environ.get("SMTP_PORT", "587")) -SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com") -SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password") SECRET = os.environ.get("SECRET", "") SESSION_EXPIRE_TIME_SECONDS = int( @@ -62,18 +58,17 @@ VALID_EMAIL_DOMAINS = ( ) # OAuth Login Flow -ENABLE_OAUTH = os.environ.get("ENABLE_OAUTH", "").lower() != "false" -OAUTH_TYPE = os.environ.get("OAUTH_TYPE", "google").lower() -OAUTH_CLIENT_ID = os.environ.get( - "OAUTH_CLIENT_ID", os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "") -) -OAUTH_CLIENT_SECRET = os.environ.get( - "OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "") -) -OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "") -MASK_CREDENTIAL_PREFIX = ( - os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false" +GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID") or "" +GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET") or "" + +# The following Basic Auth configs are not supported by the frontend UI +REQUIRE_EMAIL_VERIFICATION = ( + os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true" ) +SMTP_SERVER = os.environ.get("SMTP_SERVER", "smtp.gmail.com") +SMTP_PORT = int(os.environ.get("SMTP_PORT", "587")) +SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com") +SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password") ##### @@ -105,7 +100,7 @@ TYPESENSE_API_KEY = os.environ.get("TYPESENSE_API_KEY", "") # Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder) INDEX_BATCH_SIZE = 16 -# below are intended to match the env variables names used by the official postgres docker image +# Below are intended to match the env variables names used by the official postgres docker image # https://hub.docker.com/_/postgres POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres" POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD") or "password" diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 4ad6fbec832..058b62fbe08 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -27,6 +27,7 @@ BOOST = "boost" SCORE = "score" ID_SEPARATOR = ":;:" DEFAULT_BOOST = 0 +SESSION_KEY = "session" # Prompt building constants: GENERAL_SEP_PAT = "\n-----\n" @@ -80,6 +81,14 @@ class DanswerGenAIModel(str, Enum): TRANSFORMERS = "transformers" +class AuthType(str, Enum): + DISABLED = "disabled" + BASIC = "basic" + GOOGLE_OAUTH = "google_oauth" + OIDC = "oidc" + SAML = "saml" + + class ModelHostType(str, Enum): """For GenAI models interfaced via requests, different services have different expectations for what fields are included in the request""" diff --git a/backend/danswer/datastores/vespa/app_config/services.xml b/backend/danswer/datastores/vespa/app_config/services.xml index 189b81f204d..492eb225ea5 100644 --- a/backend/danswer/datastores/vespa/app_config/services.xml +++ b/backend/danswer/datastores/vespa/app_config/services.xml @@ -20,9 +20,9 @@ - + - 0.98 + 0.75 diff --git a/backend/danswer/main.py b/backend/danswer/main.py index f15b75b5c03..5b2d1a06ebc 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -5,25 +5,24 @@ from fastapi import Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from httpx_oauth.clients.google import GoogleOAuth2 from danswer.auth.schemas import UserCreate from danswer.auth.schemas import UserRead from danswer.auth.schemas import UserUpdate from danswer.auth.users import auth_backend from danswer.auth.users import fastapi_users -from danswer.auth.users import oauth_client from danswer.chat.personas import load_personas_from_yaml from danswer.configs.app_configs import APP_HOST from danswer.configs.app_configs import APP_PORT +from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import DISABLE_AUTH from danswer.configs.app_configs import DISABLE_GENERATIVE_AI -from danswer.configs.app_configs import ENABLE_OAUTH -from danswer.configs.app_configs import OAUTH_CLIENT_ID -from danswer.configs.app_configs import OAUTH_CLIENT_SECRET -from danswer.configs.app_configs import OAUTH_TYPE -from danswer.configs.app_configs import OPENID_CONFIG_URL +from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID +from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET from danswer.configs.app_configs import SECRET from danswer.configs.app_configs import WEB_DOMAIN +from danswer.configs.constants import AuthType from danswer.configs.model_configs import API_BASE_OPENAI from danswer.configs.model_configs import API_TYPE_OPENAI from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX @@ -39,13 +38,14 @@ from danswer.server.chat_backend import router as chat_router from danswer.server.credential import router as credential_router from danswer.server.document_set import router as document_set_router from danswer.server.event_loading import router as event_processing_router -from danswer.server.health import router as health_router from danswer.server.manage import router as admin_router from danswer.server.search_backend import router as backend_router from danswer.server.slack_bot_management import router as slack_bot_management_router +from danswer.server.state import router as state_router from danswer.server.users import router as user_router from danswer.utils.acl import set_acl_for_vespa from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import fetch_versioned_implementation logger = setup_logger() @@ -82,53 +82,41 @@ def get_application() -> FastAPI: application.include_router(credential_router) application.include_router(document_set_router) application.include_router(slack_bot_management_router) - application.include_router(health_router) + application.include_router(state_router) - application.include_router( - fastapi_users.get_auth_router(auth_backend), - prefix="/auth/database", - tags=["auth"], - ) - application.include_router( - fastapi_users.get_register_router(UserRead, UserCreate), - prefix="/auth", - tags=["auth"], - ) - application.include_router( - fastapi_users.get_reset_password_router(), - prefix="/auth", - tags=["auth"], - ) - application.include_router( - fastapi_users.get_verify_router(UserRead), - prefix="/auth", - tags=["auth"], - ) - application.include_router( - fastapi_users.get_users_router(UserRead, UserUpdate), - prefix="/users", - tags=["users"], - ) - if ENABLE_OAUTH: - if oauth_client is None: - raise RuntimeError("OAuth is enabled but no OAuth client is configured") + if AUTH_TYPE == AuthType.DISABLED: + # Server logs this during auth setup verification step + pass - if OAUTH_TYPE == "google": - # special case for google - application.include_router( - fastapi_users.get_oauth_router( - oauth_client, - auth_backend, - SECRET, - associate_by_email=True, - is_verified_by_default=True, - # points the user back to the login page, where we will call the - # /auth/google/callback endpoint + redirect them to the main app - redirect_url=f"{WEB_DOMAIN}/auth/google/callback", - ), - prefix="/auth/google", - tags=["auth"], - ) + elif AUTH_TYPE == AuthType.BASIC: + application.include_router( + fastapi_users.get_auth_router(auth_backend), + prefix="/auth", + tags=["auth"], + ) + application.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], + ) + application.include_router( + fastapi_users.get_reset_password_router(), + prefix="/auth", + tags=["auth"], + ) + application.include_router( + fastapi_users.get_verify_router(UserRead), + prefix="/auth", + tags=["auth"], + ) + application.include_router( + fastapi_users.get_users_router(UserRead, UserUpdate), + prefix="/users", + tags=["users"], + ) + + elif AUTH_TYPE == AuthType.GOOGLE_OAUTH: + oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET) application.include_router( fastapi_users.get_oauth_router( oauth_client, @@ -136,16 +124,16 @@ def get_application() -> FastAPI: SECRET, associate_by_email=True, is_verified_by_default=True, - # points the user back to the login page, where we will call the - # /auth/oauth/callback endpoint + redirect them to the main app + # points the user back to the login page redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback", ), prefix="/auth/oauth", tags=["auth"], ) + # need basic auth router for `logout` endpoint application.include_router( - fastapi_users.get_oauth_associate_router(oauth_client, UserRead, SECRET), - prefix="/auth/associate/oauth", + fastapi_users.get_logout_router(auth_backend), + prefix="/auth", tags=["auth"], ) @@ -170,30 +158,25 @@ def get_application() -> FastAPI: if API_TYPE_OPENAI == "azure": logger.info(f"Using Azure OpenAI with Endpoint: {API_BASE_OPENAI}") - auth_status = "off" if DISABLE_AUTH else "on" - logger.info(f"User Authentication is turned {auth_status}") + verify_auth = fetch_versioned_implementation( + "danswer.auth.users", "verify_auth_setting" + ) + # Will throw exception if an issue is found + verify_auth() - if not DISABLE_AUTH: - if not ENABLE_OAUTH: - logger.debug("OAuth is turned off") - else: - if not OAUTH_CLIENT_ID: - logger.warning("OAuth is turned on but OAUTH_CLIENT_ID is empty") - if not OAUTH_CLIENT_SECRET: - logger.warning( - "OAuth is turned on but OAUTH_CLIENT_SECRET is empty" - ) - if OAUTH_TYPE == "openid" and not OPENID_CONFIG_URL: - logger.warning("OpenID is turned on but OPENID_CONFIG_URL is emtpy") - else: - logger.debug("OAuth is turned on") + if DISABLE_AUTH: + logger.info("User Authentication is turned off.") + + if GOOGLE_OAUTH_CLIENT_ID and GOOGLE_OAUTH_CLIENT_SECRET: + logger.info("Found both OAuth Client ID and secret configured.") if SKIP_RERANKING: logger.info("Reranking step of search flow is disabled") logger.info(f'Using Embedding model: "{DOCUMENT_ENCODER_MODEL}"') - logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"') - logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"') + if ASYM_QUERY_PREFIX or ASYM_PASSAGE_PREFIX: + logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"') + logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"') logger.info("Warming up local NLP models.") warm_up_models() @@ -201,9 +184,9 @@ def get_application() -> FastAPI: qa_model.warm_up_model() logger.info("Verifying query preprocessing (NLTK) data is downloaded") - nltk.download("stopwords") - nltk.download("wordnet") - nltk.download("punkt") + nltk.download("stopwords", quiet=True) + nltk.download("wordnet", quiet=True) + nltk.download("punkt", quiet=True) logger.info("Verifying public credential exists.") create_initial_public_credential() @@ -219,17 +202,18 @@ def get_application() -> FastAPI: # does nothing if this has been successfully run before set_acl_for_vespa(should_check_if_already_done=True) + application.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Change this to the list of allowed origins if needed + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + return application app = get_application() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # Change this to the list of allowed origins if needed - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) if __name__ == "__main__": diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py index 8c85cf512e1..d6ee7f9ab6a 100644 --- a/backend/danswer/secondary_llm_flows/query_validation.py +++ b/backend/danswer/secondary_llm_flows/query_validation.py @@ -9,6 +9,10 @@ from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt from danswer.llm.build import get_default_llm from danswer.server.models import QueryValidationResponse from danswer.server.utils import get_json_line +from danswer.utils.logger import setup_logger + +logger = setup_logger() + QUERY_PAT = "QUERY: " REASONING_PAT = "THOUGHT: " @@ -94,37 +98,44 @@ def get_query_answerability(user_query: str) -> tuple[str, bool]: def stream_query_answerability(user_query: str) -> Iterator[str]: messages = get_query_validation_messages(user_query) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) - tokens = get_default_llm().stream(filled_llm_prompt) - reasoning_pat_found = False - model_output = "" - hold_answerable = "" - for token in tokens: - model_output = model_output + token + try: + tokens = get_default_llm().stream(filled_llm_prompt) + reasoning_pat_found = False + model_output = "" + hold_answerable = "" + for token in tokens: + model_output = model_output + token - if ANSWERABLE_PAT in model_output: - continue - - if not reasoning_pat_found and REASONING_PAT in model_output: - reasoning_pat_found = True - reason_ind = model_output.find(REASONING_PAT) - remaining = model_output[reason_ind + len(REASONING_PAT) :] - if remaining: - yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining))) - continue - - if reasoning_pat_found: - hold_answerable = hold_answerable + token - if hold_answerable == ANSWERABLE_PAT[: len(hold_answerable)]: + if ANSWERABLE_PAT in model_output: continue - yield get_json_line( - asdict(DanswerAnswerPiece(answer_piece=hold_answerable)) - ) - hold_answerable = "" - reasoning = extract_answerability_reasoning(model_output) - answerable = extract_answerability_bool(model_output) + if not reasoning_pat_found and REASONING_PAT in model_output: + reasoning_pat_found = True + reason_ind = model_output.find(REASONING_PAT) + remaining = model_output[reason_ind + len(REASONING_PAT) :] + if remaining: + yield get_json_line( + asdict(DanswerAnswerPiece(answer_piece=remaining)) + ) + continue - yield get_json_line( - QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict() - ) + if reasoning_pat_found: + hold_answerable = hold_answerable + token + if hold_answerable == ANSWERABLE_PAT[: len(hold_answerable)]: + continue + yield get_json_line( + asdict(DanswerAnswerPiece(answer_piece=hold_answerable)) + ) + hold_answerable = "" + + reasoning = extract_answerability_reasoning(model_output) + answerable = extract_answerability_bool(model_output) + + yield get_json_line( + QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict() + ) + except Exception as e: + # exception is logged in the answer_question method, no need to re-log + yield get_json_line({"error": str(e)}) + logger.exception("Failed to validate Query") return diff --git a/backend/danswer/server/health.py b/backend/danswer/server/health.py deleted file mode 100644 index a287f7e7134..00000000000 --- a/backend/danswer/server/health.py +++ /dev/null @@ -1,11 +0,0 @@ -from fastapi import APIRouter - -from danswer.server.models import StatusResponse - - -router = APIRouter() - - -@router.get("/health", response_model=StatusResponse) -def healthcheck() -> StatusResponse: - return StatusResponse(success=True, message="ok") diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 3f78b17d51a..8c45d55f8e4 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -1,7 +1,6 @@ from datetime import datetime from typing import Any from typing import Generic -from typing import Literal from typing import Optional from typing import TypeVar from uuid import UUID @@ -9,7 +8,9 @@ from uuid import UUID from pydantic import BaseModel from pydantic.generics import GenericModel +from danswer.auth.schemas import UserRole from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX +from danswer.configs.constants import AuthType from danswer.configs.constants import DocumentSource from danswer.configs.constants import MessageType from danswer.configs.constants import QAFeedbackType @@ -38,6 +39,10 @@ class StatusResponse(GenericModel, Generic[DataT]): data: Optional[DataT] = None +class AuthTypeResponse(BaseModel): + auth_type: AuthType + + class DataRequest(BaseModel): data: str @@ -47,6 +52,15 @@ class HelperResponse(BaseModel): details: list[str] | None = None +class UserInfo(BaseModel): + id: str + email: str + is_active: bool + is_superuser: bool + is_verified: bool + role: UserRole + + class GoogleAppWebCredentials(BaseModel): client_id: str project_id: str @@ -84,10 +98,6 @@ class FileUploadResponse(BaseModel): file_paths: list[str] -class HealthCheckResponse(BaseModel): - status: Literal["ok"] - - class ObjectCreationIdResponse(BaseModel): id: int | str diff --git a/backend/danswer/server/state.py b/backend/danswer/server/state.py new file mode 100644 index 00000000000..cbc9500e085 --- /dev/null +++ b/backend/danswer/server/state.py @@ -0,0 +1,18 @@ +from fastapi import APIRouter + +from danswer.configs.app_configs import AUTH_TYPE +from danswer.server.models import AuthTypeResponse +from danswer.server.models import StatusResponse + + +router = APIRouter() + + +@router.get("/health") +def healthcheck() -> StatusResponse: + return StatusResponse(success=True, message="ok") + + +@router.get("/auth/type") +def get_auth_type() -> AuthTypeResponse: + return AuthTypeResponse(auth_type=AUTH_TYPE) diff --git a/backend/danswer/server/users.py b/backend/danswer/server/users.py index b9e3f7e6dde..777bcc4c988 100644 --- a/backend/danswer/server/users.py +++ b/backend/danswer/server/users.py @@ -9,12 +9,13 @@ from sqlalchemy.orm import Session from danswer.auth.schemas import UserRead from danswer.auth.schemas import UserRole from danswer.auth.users import current_admin_user +from danswer.auth.users import current_user from danswer.db.engine import get_session from danswer.db.engine import get_sqlalchemy_async_engine from danswer.db.models import User from danswer.db.users import list_users from danswer.server.models import UserByEmail - +from danswer.server.models import UserInfo router = APIRouter(prefix="/manage") @@ -43,3 +44,18 @@ def list_all_users( ) -> list[UserRead]: users = list_users(db_session) return [UserRead.from_orm(user) for user in users] + + +@router.get("/me") +def verify_user_logged_in(user: User | None = Depends(current_user)) -> UserInfo: + if user is None: + raise HTTPException(status_code=401, detail="User Not Authenticated") + + return UserInfo( + id=str(user.id), + email=user.email, + is_active=user.is_active, + is_superuser=user.is_superuser, + is_verified=user.is_verified, + role=user.role, + ) diff --git a/backend/danswer/utils/variable_functionality.py b/backend/danswer/utils/variable_functionality.py new file mode 100644 index 00000000000..934b526a52b --- /dev/null +++ b/backend/danswer/utils/variable_functionality.py @@ -0,0 +1,21 @@ +import importlib +from typing import Any + + +class DanswerVersion: + def __init__(self) -> None: + self._is_ee = False + + def set_ee(self) -> None: + self._is_ee = True + + def get_is_ee_version(self) -> bool: + return self._is_ee + + +global_version = DanswerVersion() + + +def fetch_versioned_implementation(module: str, attribute: str) -> Any: + module_full = f"ee.{module}" if global_version.get_is_ee_version() else module + return getattr(importlib.import_module(module_full), attribute) diff --git a/deployment/README.md b/deployment/README.md index bbd1d3bc48a..297bd12b266 100644 --- a/deployment/README.md +++ b/deployment/README.md @@ -46,7 +46,6 @@ Additional steps for user auth and https if you do want to use Docker Compose fo 1. Set up a `.env` file in this directory with relevant environment variables. - Refer to `env.prod.template` - To turn on user auth, set: - - ENABLE_OAUTH=True - GOOGLE_OAUTH_CLIENT_ID=\ - GOOGLE_OAUTH_CLIENT_SECRET=\ - Refer to https://developers.google.com/identity/gsi/web/guides/get-google-api-clientid diff --git a/deployment/docker_compose/docker-compose.dev.legacy.yml b/deployment/docker_compose/docker-compose.dev.legacy.yml index 989f9ad5955..006ea0b9327 100644 --- a/deployment/docker_compose/docker-compose.dev.legacy.yml +++ b/deployment/docker_compose/docker-compose.dev.legacy.yml @@ -30,10 +30,8 @@ services: - TYPESENSE_HOST=search_engine - TYPESENSE_API_KEY=${TYPESENSE_API_KEY:-typesense_api_key} - LOG_LEVEL=${LOG_LEVEL:-info} - - DISABLE_AUTH=${DISABLE_AUTH:-True} + - AUTH_TYPE=${AUTH_TYPE:-disabled} - QA_TIMEOUT=${QA_TIMEOUT:-} - - OAUTH_TYPE=${OAUTH_TYPE:-google} - - OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-} - GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-} - GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-} - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-} @@ -91,13 +89,14 @@ services: dockerfile: Dockerfile args: - NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING:-false} + - NEXT_PUBLIC_AUTH_TYPE=${AUTH_TYPE:-disabled} depends_on: - api_server restart: always environment: - INTERNAL_URL=http://api_server:8080 - WEB_DOMAIN=${WEB_DOMAIN:-} - - DISABLE_AUTH=${DISABLE_AUTH:-True} + - AUTH_TYPE=${AUTH_TYPE:-disabled} - OAUTH_NAME=${OAUTH_NAME:-} relational_db: image: postgres:15.2-alpine diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 635c880faa9..183e79f6dee 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -24,11 +24,9 @@ services: - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-} - POSTGRES_HOST=relational_db - VESPA_HOST=index - - DISABLE_AUTH=${DISABLE_AUTH:-True} + - AUTH_TYPE=${AUTH_TYPE:-disabled} - QA_TIMEOUT=${QA_TIMEOUT:-} - VALID_EMAIL_DOMAINS=${VALID_EMAIL_DOMAINS:-} - - OAUTH_TYPE=${OAUTH_TYPE:-google} - - OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-} - GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-} - GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-} - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-} @@ -77,6 +75,7 @@ services: - AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-} - CONTINUE_ON_CONNECTOR_FAILURE=${CONTINUE_ON_CONNECTOR_FAILURE:-} - NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-} + # Danswer SlackBot Configs - DANSWER_BOT_SLACK_APP_TOKEN=${DANSWER_BOT_SLACK_APP_TOKEN:-} - DANSWER_BOT_SLACK_BOT_TOKEN=${DANSWER_BOT_SLACK_BOT_TOKEN:-} - DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER=${DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER:-} @@ -104,13 +103,14 @@ services: dockerfile: Dockerfile args: - NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING:-false} + - NEXT_PUBLIC_AUTH_TYPE=${AUTH_TYPE:-disabled} depends_on: - api_server restart: always environment: - INTERNAL_URL=http://api_server:8080 - WEB_DOMAIN=${WEB_DOMAIN:-} - - DISABLE_AUTH=${DISABLE_AUTH:-True} + - AUTH_TYPE=${AUTH_TYPE:-disabled} - OAUTH_NAME=${OAUTH_NAME:-} relational_db: image: postgres:15.2-alpine diff --git a/deployment/docker_compose/docker-compose.prod.legacy.yml b/deployment/docker_compose/docker-compose.prod.legacy.yml index 38a528df625..929f56ac9d5 100644 --- a/deployment/docker_compose/docker-compose.prod.legacy.yml +++ b/deployment/docker_compose/docker-compose.prod.legacy.yml @@ -17,6 +17,7 @@ services: env_file: - .env environment: + - AUTH_TYPE=${AUTH_TYPE:-google_oauth} - DOCUMENT_INDEX_TYPE=split - POSTGRES_HOST=relational_db - QDRANT_HOST=vector_db @@ -40,6 +41,7 @@ services: env_file: - .env environment: + - AUTH_TYPE=${AUTH_TYPE:-google_oauth} - DOCUMENT_INDEX_TYPE=split - POSTGRES_HOST=relational_db - QDRANT_HOST=vector_db @@ -55,12 +57,16 @@ services: build: context: ../../web dockerfile: Dockerfile + args: + - NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING:-false} + - NEXT_PUBLIC_AUTH_TYPE=${AUTH_TYPE:-google_oauth} depends_on: - api_server restart: always env_file: - .env environment: + - AUTH_TYPE=${AUTH_TYPE:-google_oauth} - INTERNAL_URL=http://api_server:8080 relational_db: image: postgres:15.2-alpine diff --git a/deployment/docker_compose/docker-compose.prod.yml b/deployment/docker_compose/docker-compose.prod.yml index 8833375920c..8850c18bc77 100644 --- a/deployment/docker_compose/docker-compose.prod.yml +++ b/deployment/docker_compose/docker-compose.prod.yml @@ -16,6 +16,7 @@ services: env_file: - .env environment: + - AUTH_TYPE=${AUTH_TYPE:-google_oauth} - POSTGRES_HOST=relational_db - VESPA_HOST=index volumes: @@ -37,6 +38,7 @@ services: env_file: - .env environment: + - AUTH_TYPE=${AUTH_TYPE:-google_oauth} - POSTGRES_HOST=relational_db - VESPA_HOST=index volumes: @@ -50,12 +52,16 @@ services: build: context: ../../web dockerfile: Dockerfile + args: + - NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING:-false} + - NEXT_PUBLIC_AUTH_TYPE=${AUTH_TYPE:-google_oauth} depends_on: - api_server restart: always env_file: - .env environment: + - AUTH_TYPE=${AUTH_TYPE:-google_oauth} - INTERNAL_URL=http://api_server:8080 relational_db: image: postgres:15.2-alpine diff --git a/deployment/docker_compose/env.prod.template b/deployment/docker_compose/env.prod.template index 27648c17aeb..0631aa5f414 100644 --- a/deployment/docker_compose/env.prod.template +++ b/deployment/docker_compose/env.prod.template @@ -1,10 +1,10 @@ # Fill in the values and copy the contents of this file to .env in the deployment directory. # Some valid default values are provided where applicable, delete the variables which you don't set values for. -# Only applicable when using the docker-compose.prod.yml compose file. +# This is only necessary when using the docker-compose.prod.yml compose file. -# Insert your OpenAI API key here, currently the only Generative AI endpoint for QA that we support is OpenAI -# If not provided here, UI will prompt on setup +# Insert your OpenAI API key here If not provided here, UI will prompt on setup. +# This env variable takes precedence over UI settings. GEN_AI_API_KEY= # Choose between "openai-chat-completion" and "openai-completion" INTERNAL_MODEL_VERSION=openai-chat-completion @@ -17,21 +17,12 @@ API_TYPE_OPENAI= API_VERSION_OPENAI= AZURE_DEPLOYMENT_ID= -# Could be something like danswer.companyname.com. Requires additional setup if not localhost +# Could be something like danswer.companyname.com WEB_DOMAIN=http://localhost:3000 -# If you want to make the postgres / typesense instances a little more secure, modify the below -# Note that the postgres / typesense / qdrant containers do not expose any ports to the outside world, -# so they are already unaccessible unless someone has ssh access to the machine that Danswer is running on +# Default values here are what Postgres uses by default, feel free to change. POSTGRES_USER=postgres POSTGRES_PASSWORD=password -TYPESENSE_API_KEY=typesense_api_key - -# Currently frontend page doesn't have basic auth, use OAuth if user auth is enabled. -ENABLE_OAUTH=True -# The two settings below are only required if ENABLE_OAUTH is true -GOOGLE_OAUTH_CLIENT_ID= -GOOGLE_OAUTH_CLIENT_SECRET= # If you want to setup a slack bot to answer questions automatically in Slack # channels it is added to, you must specify the below. @@ -45,15 +36,25 @@ SECRET= # How long before user needs to reauthenticate, default to 1 day. (cookie expiration time) SESSION_EXPIRE_TIME_SECONDS=86400 -# used to specify a list of allowed user domains +# The following are for configuring User Authentication, supported flows are: +# disabled +# simple (email/password + user account creation in Danswer) +# google_oauth (login with google/gmail account) +# oidc (only in Danswer enterprise edition) +# saml (only in Danswer enterprise edition) +AUTH_TYPE= + +# Set the two below to use with Google OAuth +GOOGLE_OAUTH_CLIENT_ID= +GOOGLE_OAUTH_CLIENT_SECRET= + +# OpenID Connect (OIDC) +OPENID_CONFIG_URL= + +# SAML config directory for OneLogin compatible setups +SAML_CONF_DIR= + +# used to specify a list of allowed user domains, only checked if user Auth is turned on # e.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will only allow users # with an @example.com or an @example.org email VALID_EMAIL_DOMAINS= - -# Only relevant if using basic auth (not supported on frontend yet) -REQUIRE_EMAIL_VERIFICATION=True -# The four settings below are only required if REQUIRE_EMAIL_VERIFICATION is True -SMTP_SERVER= -SMTP_PORT= -SMTP_USER= -SMTP_PASS= diff --git a/deployment/kubernetes/api_server-service-deployment.yaml b/deployment/kubernetes/api_server-service-deployment.yaml index 8c7ec43373b..be45bc3808c 100644 --- a/deployment/kubernetes/api_server-service-deployment.yaml +++ b/deployment/kubernetes/api_server-service-deployment.yaml @@ -40,6 +40,8 @@ spec: ports: - containerPort: 8080 env: + - name: AUTH_TYPE + value: google_oauth - name: POSTGRES_HOST value: relational-db-service - name: VESPA_HOST diff --git a/deployment/kubernetes/web_server-service-deployment.yaml b/deployment/kubernetes/web_server-service-deployment.yaml index f1329c638f2..e1f04fb06e1 100644 --- a/deployment/kubernetes/web_server-service-deployment.yaml +++ b/deployment/kubernetes/web_server-service-deployment.yaml @@ -31,6 +31,10 @@ spec: imagePullPolicy: IfNotPresent ports: - containerPort: 3000 + args: + - "NEXT_PUBLIC_AUTH_TYPE=google_oauth" env: + - name: AUTH_TYPE + value: google_oauth - name: INTERNAL_URL value: "http://api-server-service:80" diff --git a/web/Dockerfile b/web/Dockerfile index 0f571fc87c4..f2e0fbefada 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -30,6 +30,9 @@ ENV NEXT_TELEMETRY_DISABLED 1 ARG NEXT_PUBLIC_DISABLE_STREAMING ENV NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING} +ARG NEXT_PUBLIC_AUTH_TYPE +ENV NEXT_PUBLIC_AUTH_TYPE=${NEXT_PUBLIC_AUTH_TYPE} + RUN npm run build # Step 3. Production image, copy all the files and run next @@ -65,6 +68,9 @@ COPY --from=builder --chown=nextjs:nodejs /app/.next/static ./.next/static ARG NEXT_PUBLIC_DISABLE_STREAMING ENV NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING} +ARG NEXT_PUBLIC_AUTH_TYPE +ENV NEXT_PUBLIC_AUTH_TYPE=${NEXT_PUBLIC_AUTH_TYPE} + # Note: Don't expose ports here, Compose will handle that for us if necessary. # If you want to run this without compose, specify the ports to # expose via cli diff --git a/web/src/app/auth/google/callback/route.ts b/web/src/app/auth/google/callback/route.ts deleted file mode 100644 index a554d89e62c..00000000000 --- a/web/src/app/auth/google/callback/route.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { getDomain } from "@/lib/redirectSS"; -import { buildUrl } from "@/lib/utilsSS"; -import { NextRequest, NextResponse } from "next/server"; - -export const GET = async (request: NextRequest) => { - // Wrapper around the FastAPI endpoint /auth/google/callback, - // which adds back a redirect to the main app. - const url = new URL(buildUrl("/auth/google/callback")); - url.search = request.nextUrl.search; - - const response = await fetch(url.toString()); - const setCookieHeader = response.headers.get("set-cookie"); - - if (!setCookieHeader) { - return NextResponse.redirect(new URL("/auth/error", getDomain(request))); - } - - const redirectResponse = NextResponse.redirect( - new URL("/", getDomain(request)) - ); - redirectResponse.headers.set("set-cookie", setCookieHeader); - return redirectResponse; -}; diff --git a/web/src/app/auth/login/page.tsx b/web/src/app/auth/login/page.tsx index 52b2e9408c2..87eeaaa8e5e 100644 --- a/web/src/app/auth/login/page.tsx +++ b/web/src/app/auth/login/page.tsx @@ -1,7 +1,7 @@ import { HealthCheckBanner } from "@/components/health/healthcheck"; import { DISABLE_AUTH, OAUTH_NAME } from "@/lib/constants"; import { User } from "@/lib/types"; -import { getGoogleOAuthUrlSS, getCurrentUserSS } from "@/lib/userSS"; +import { getCurrentUserSS, getAuthUrlSS } from "@/lib/userSS"; import { redirect } from "next/navigation"; const BUTTON_STYLE = @@ -21,10 +21,11 @@ const Page = async () => { // will not render let currentUser: User | null = null; let authorizationUrl: string | null = null; + let autoRedirect: boolean = false; try { - [currentUser, authorizationUrl] = await Promise.all([ + [currentUser, [authorizationUrl, autoRedirect]] = await Promise.all([ getCurrentUserSS(), - getGoogleOAuthUrlSS(), + getAuthUrlSS(), ]); } catch (e) { console.log(`Some fetch failed for the login page - ${e}`); @@ -35,6 +36,10 @@ const Page = async () => { return redirect("/"); } + if (autoRedirect && authorizationUrl) { + return redirect(authorizationUrl); + } + return (
diff --git a/web/src/app/auth/logout/route.ts b/web/src/app/auth/logout/route.ts new file mode 100644 index 00000000000..297094a9ec6 --- /dev/null +++ b/web/src/app/auth/logout/route.ts @@ -0,0 +1,8 @@ +import { logoutSS } from "@/lib/userSS"; +import { NextRequest } from "next/server"; + +export const POST = async (request: NextRequest) => { + // Directs the logout request to the appropriate FastAPI endpoint. + // Needed since env variables don't work well on the client-side + return await logoutSS(request.headers); +}; diff --git a/web/src/app/auth/saml/callback/route.ts b/web/src/app/auth/saml/callback/route.ts new file mode 100644 index 00000000000..fe9db7be126 --- /dev/null +++ b/web/src/app/auth/saml/callback/route.ts @@ -0,0 +1,33 @@ +import { getDomain } from "@/lib/redirectSS"; +import { buildUrl } from "@/lib/utilsSS"; +import { NextRequest, NextResponse } from "next/server"; + +// have to use this so we don't hit the redirect URL with a `POST` request +const SEE_OTHER_REDIRECT_STATUS = 303; + +export const POST = async (request: NextRequest) => { + // Wrapper around the FastAPI endpoint /auth/saml/callback, + // which adds back a redirect to the main app. + const url = new URL(buildUrl("/auth/saml/callback")); + url.search = request.nextUrl.search; + + const response = await fetch(url.toString(), { + method: "POST", + body: await request.formData(), + }); + const setCookieHeader = response.headers.get("set-cookie"); + + if (!setCookieHeader) { + return NextResponse.redirect( + new URL("/auth/error", getDomain(request)), + SEE_OTHER_REDIRECT_STATUS + ); + } + + const redirectResponse = NextResponse.redirect( + new URL("/", getDomain(request)), + SEE_OTHER_REDIRECT_STATUS + ); + redirectResponse.headers.set("set-cookie", setCookieHeader); + return redirectResponse; +}; diff --git a/web/src/lib/constants.ts b/web/src/lib/constants.ts index 1343f07ad21..ef487a8010f 100644 --- a/web/src/lib/constants.ts +++ b/web/src/lib/constants.ts @@ -1,5 +1,9 @@ -export const DISABLE_AUTH = process.env.DISABLE_AUTH?.toLowerCase() === "true"; +export type AuthType = "disabled" | "google_oauth" | "oidc" | "saml"; +export const AUTH_TYPE = (process.env.AUTH_TYPE || + process.env.NEXT_PUBLIC_AUTH_TYPE || + "disabled") as AuthType; +export const DISABLE_AUTH = AUTH_TYPE === "disabled"; export const OAUTH_NAME = process.env.OAUTH_NAME || "Google"; export const INTERNAL_URL = process.env.INTERNAL_URL || "http://127.0.0.1:8080"; diff --git a/web/src/lib/user.ts b/web/src/lib/user.ts index cccd7c72ec4..f4d67acff11 100644 --- a/web/src/lib/user.ts +++ b/web/src/lib/user.ts @@ -2,7 +2,7 @@ import { User } from "./types"; // should be used client-side only export const getCurrentUser = async (): Promise => { - const response = await fetch("/api/users/me", { + const response = await fetch("/api/manage/me", { credentials: "include", }); if (!response.ok) { @@ -13,7 +13,7 @@ export const getCurrentUser = async (): Promise => { }; export const logout = async (): Promise => { - const response = await fetch("/api/auth/database/logout", { + const response = await fetch("/auth/logout", { method: "POST", credentials: "include", }); diff --git a/web/src/lib/userSS.ts b/web/src/lib/userSS.ts index bedaee71ece..306177effc7 100644 --- a/web/src/lib/userSS.ts +++ b/web/src/lib/userSS.ts @@ -2,8 +2,19 @@ import { cookies } from "next/headers"; import { User } from "./types"; import { buildUrl } from "./utilsSS"; import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies"; +import { AUTH_TYPE } from "./constants"; -export const getGoogleOAuthUrlSS = async (): Promise => { +const geOIDCAuthUrlSS = async (): Promise => { + const res = await fetch(buildUrl("/auth/oidc/authorize")); + if (!res.ok) { + throw new Error("Failed to fetch data"); + } + + const data: { authorization_url: string } = await res.json(); + return data.authorization_url; +}; + +const getGoogleOAuthUrlSS = async (): Promise => { const res = await fetch(buildUrl("/auth/oauth/authorize")); if (!res.ok) { throw new Error("Failed to fetch data"); @@ -13,10 +24,63 @@ export const getGoogleOAuthUrlSS = async (): Promise => { return data.authorization_url; }; -// should be used server-side only +const getSAMLAuthUrlSS = async (): Promise => { + const res = await fetch(buildUrl("/auth/saml/authorize")); + if (!res.ok) { + throw new Error("Failed to fetch data"); + } + + const data: { authorization_url: string } = await res.json(); + return data.authorization_url; +}; + +export const getAuthUrlSS = async (): Promise<[string, boolean]> => { + // Returns the auth url and whether or not we should auto-redirect + switch (AUTH_TYPE) { + case "disabled": + return ["", true]; + case "google_oauth": { + return [await getGoogleOAuthUrlSS(), false]; + } + case "saml": { + return [await getSAMLAuthUrlSS(), true]; + } + case "oidc": { + return [await geOIDCAuthUrlSS(), true]; + } + } +}; + +const logoutStandardSS = async (headers: Headers): Promise => { + return await fetch(buildUrl("/auth/logout"), { + method: "POST", + headers: headers, + }); +}; + +const logoutSAMLSS = async (headers: Headers): Promise => { + return await fetch(buildUrl("/auth/saml/logout"), { + method: "POST", + headers: headers, + }); +}; + +export const logoutSS = async (headers: Headers): Promise => { + switch (AUTH_TYPE) { + case "disabled": + return null; + case "saml": { + return await logoutSAMLSS(headers); + } + default: { + return await logoutStandardSS(headers); + } + } +}; + export const getCurrentUserSS = async (): Promise => { try { - const response = await fetch(buildUrl("/users/me"), { + const response = await fetch(buildUrl("/manage/me"), { credentials: "include", next: { revalidate: 0 }, headers: {