mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-20 13:05:49 +02:00
Consolidate versions for easier extension (#495)
This commit is contained in:
@@ -115,11 +115,11 @@ mkdir dynamic_config_storage
|
||||
|
||||
To start the frontend, navigate to `danswer/web` and run:
|
||||
```bash
|
||||
DISABLE_AUTH=true npm run dev
|
||||
AUTH_TYPE=disabled npm run dev
|
||||
```
|
||||
_for Windows, run:_
|
||||
```bash
|
||||
(SET "DISABLE_AUTH=true" && npm run dev)
|
||||
(SET "AUTH_TYPE=disabled" && npm run dev)
|
||||
```
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ zip -r ../vespa-app.zip .
|
||||
|
||||
To run the backend API server, navigate back to `danswer/backend` and run:
|
||||
```bash
|
||||
DISABLE_AUTH=True \
|
||||
AUTH_TYPE=disabled \
|
||||
DYNAMIC_CONFIG_DIR_PATH=./dynamic_config_storage \
|
||||
VESPA_DEPLOYMENT_ZIP=./danswer/datastores/vespa/vespa-app.zip \
|
||||
uvicorn danswer.main:app --reload --port 8080
|
||||
@@ -146,7 +146,7 @@ uvicorn danswer.main:app --reload --port 8080
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
```bash
|
||||
powershell -Command "
|
||||
$env:DISABLE_AUTH='True'
|
||||
$env:AUTH_TYPE='disabled'
|
||||
$env:DYNAMIC_CONFIG_DIR_PATH='./dynamic_config_storage'
|
||||
$env:VESPA_DEPLOYMENT_ZIP='./danswer/datastores/vespa/vespa-app.zip'
|
||||
uvicorn danswer.main:app --reload --port 8080
|
||||
|
47
backend/alembic/versions/ae62505e3acc_add_saml_accounts.py
Normal file
47
backend/alembic/versions/ae62505e3acc_add_saml_accounts.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Add SAML Accounts
|
||||
|
||||
Revision ID: ae62505e3acc
|
||||
Revises: 7da543f5672f
|
||||
Create Date: 2023-09-26 16:19:30.933183
|
||||
|
||||
"""
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ae62505e3acc"
|
||||
down_revision = "7da543f5672f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"saml",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("encrypted_cookie", sa.Text(), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("encrypted_cookie"),
|
||||
sa.UniqueConstraint("user_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("saml")
|
@@ -1,4 +1,3 @@
|
||||
import contextlib
|
||||
import os
|
||||
import smtplib
|
||||
import uuid
|
||||
@@ -6,10 +5,13 @@ from collections.abc import AsyncGenerator
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi_users import BaseUserManager
|
||||
from fastapi_users import FastAPIUsers
|
||||
@@ -18,21 +20,17 @@ from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from httpx_oauth.clients.openid import OpenID
|
||||
from pydantic import EmailStr
|
||||
from fastapi_users.openapi import OpenAPIResponseType
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import ENABLE_OAUTH
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import OAUTH_TYPE
|
||||
from danswer.configs.app_configs import OPENID_CONFIG_URL
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from danswer.configs.app_configs import SECRET
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
@@ -42,23 +40,32 @@ from danswer.configs.app_configs import SMTP_SERVER
|
||||
from danswer.configs.app_configs import SMTP_USER
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
FAKE_USER_EMAIL = "fakeuser@fakedanswermail.com"
|
||||
FAKE_USER_PASS = "foobar"
|
||||
|
||||
USER_WHITELIST_FILE = "/home/danswer_whitelist.txt"
|
||||
_user_whitelist: list[str] | None = None
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
|
||||
raise ValueError(
|
||||
"User must choose a valid user authentication method: "
|
||||
"disabled, basic, or google_oauth"
|
||||
)
|
||||
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
def get_user_whitelist() -> list[str]:
|
||||
global _user_whitelist
|
||||
if _user_whitelist is None:
|
||||
@@ -204,53 +211,84 @@ auth_backend = AuthenticationBackend(
|
||||
get_strategy=get_database_strategy,
|
||||
)
|
||||
|
||||
oauth_client = None # type: GoogleOAuth2 | OpenID | None
|
||||
if ENABLE_OAUTH:
|
||||
if OAUTH_TYPE == "google":
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
elif OAUTH_TYPE == "openid":
|
||||
oauth_client = OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL)
|
||||
else:
|
||||
raise AssertionError(f"Invalid OAUTH type {OAUTH_TYPE}")
|
||||
|
||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
def get_logout_router(
|
||||
self,
|
||||
backend: AuthenticationBackend,
|
||||
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Provide a router for logout only for OAuth/OIDC Flows.
|
||||
This way the login router does not need to be included
|
||||
"""
|
||||
router = APIRouter()
|
||||
get_current_user_token = self.authenticator.current_user_token(
|
||||
active=True, verified=requires_verification
|
||||
)
|
||||
logout_responses: OpenAPIResponseType = {
|
||||
**{
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
"description": "Missing token or inactive user."
|
||||
}
|
||||
},
|
||||
**backend.transport.get_openapi_logout_responses_success(),
|
||||
}
|
||||
|
||||
@router.post(
|
||||
"/logout", name=f"auth:{backend.name}.logout", responses=logout_responses
|
||||
)
|
||||
async def logout(
|
||||
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
) -> Response:
|
||||
user, token = user_token
|
||||
return await backend.logout(strategy, user, token)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||
|
||||
|
||||
# Currently unused, maybe useful later
|
||||
async def create_get_fake_user() -> User:
|
||||
get_async_session_context = contextlib.asynccontextmanager(
|
||||
get_async_session
|
||||
) # type:ignore
|
||||
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
|
||||
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
|
||||
|
||||
logger.info("Creating fake user due to Auth being turned off")
|
||||
async with get_async_session_context() as session:
|
||||
async with get_user_db_context(session) as user_db:
|
||||
async with get_user_manager_context(user_db) as user_manager:
|
||||
user = await user_manager.get_by_email(FAKE_USER_EMAIL)
|
||||
if user:
|
||||
return user
|
||||
user = await user_manager.create(
|
||||
UserCreate(email=EmailStr(FAKE_USER_EMAIL), password=FAKE_USER_PASS)
|
||||
)
|
||||
logger.info("Created fake user.")
|
||||
return user
|
||||
|
||||
|
||||
current_active_user = fastapi_users.current_user(
|
||||
active=True, verified=REQUIRE_EMAIL_VERIFICATION, optional=DISABLE_AUTH
|
||||
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
|
||||
get_user_manager, [auth_backend]
|
||||
)
|
||||
|
||||
|
||||
async def current_user(user: User = Depends(current_active_user)) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
optional_valid_user = fastapi_users.current_user(
|
||||
active=True, verified=REQUIRE_EMAIL_VERIFICATION, optional=True
|
||||
)
|
||||
|
||||
|
||||
async def double_check_user(
|
||||
request: Request,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return None
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def current_admin_user(user: User = Depends(current_user)) -> User | None:
|
||||
async def current_user(
|
||||
request: Request,
|
||||
user: User | None = Depends(optional_valid_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> User | None:
|
||||
double_check_user = fetch_versioned_implementation(
|
||||
"danswer.auth.users", "double_check_user"
|
||||
)
|
||||
user = await double_check_user(request, user, db_session)
|
||||
return user
|
||||
|
||||
|
||||
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DocumentIndexType
|
||||
|
||||
#####
|
||||
@@ -14,32 +15,27 @@ APP_PORT = 8080
|
||||
#####
|
||||
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
|
||||
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day
|
||||
# DISABLE_GENERATIVE_AI will turn of the question answering part of Danswer. Use this
|
||||
# if you want to use Danswer as a search engine only and/or you are not comfortable sending
|
||||
# anything to OpenAI. TODO: update this message once we support Azure / open source generative models.
|
||||
# DISABLE_GENERATIVE_AI will turn of the question answering part of Danswer.
|
||||
# Use this if you want to use Danswer as a search engine only without the LLM capabilities
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
#####
|
||||
# Web Configs
|
||||
#####
|
||||
# WEB_DOMAIN is used to set the redirect_uri when doing OAuth with Google
|
||||
# TODO: investigate if this can be done cleaner by overwriting the redirect_uri
|
||||
# on the frontend and just sending a dummy value (or completely generating the URL)
|
||||
# on the frontend
|
||||
WEB_DOMAIN = os.environ.get("WEB_DOMAIN", "http://localhost:3000")
|
||||
# WEB_DOMAIN is used to set the redirect_uri after login flows
|
||||
WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
|
||||
|
||||
#####
|
||||
# Auth Configs
|
||||
#####
|
||||
DISABLE_AUTH = os.environ.get("DISABLE_AUTH", "").lower() == "true"
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
|
||||
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
|
||||
|
||||
# Turn off mask if admin users should see full credentials for data connectors.
|
||||
MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER", "smtp.gmail.com")
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT", "587"))
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
|
||||
SECRET = os.environ.get("SECRET", "")
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
@@ -62,18 +58,17 @@ VALID_EMAIL_DOMAINS = (
|
||||
)
|
||||
|
||||
# OAuth Login Flow
|
||||
ENABLE_OAUTH = os.environ.get("ENABLE_OAUTH", "").lower() != "false"
|
||||
OAUTH_TYPE = os.environ.get("OAUTH_TYPE", "google").lower()
|
||||
OAUTH_CLIENT_ID = os.environ.get(
|
||||
"OAUTH_CLIENT_ID", os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "")
|
||||
)
|
||||
OAUTH_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "")
|
||||
)
|
||||
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
|
||||
MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID") or ""
|
||||
GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET") or ""
|
||||
|
||||
# The following Basic Auth configs are not supported by the frontend UI
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
)
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER", "smtp.gmail.com")
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT", "587"))
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
|
||||
|
||||
#####
|
||||
@@ -105,7 +100,7 @@ TYPESENSE_API_KEY = os.environ.get("TYPESENSE_API_KEY", "")
|
||||
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
|
||||
INDEX_BATCH_SIZE = 16
|
||||
|
||||
# below are intended to match the env variables names used by the official postgres docker image
|
||||
# Below are intended to match the env variables names used by the official postgres docker image
|
||||
# https://hub.docker.com/_/postgres
|
||||
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
|
||||
POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||
|
@@ -27,6 +27,7 @@ BOOST = "boost"
|
||||
SCORE = "score"
|
||||
ID_SEPARATOR = ":;:"
|
||||
DEFAULT_BOOST = 0
|
||||
SESSION_KEY = "session"
|
||||
|
||||
# Prompt building constants:
|
||||
GENERAL_SEP_PAT = "\n-----\n"
|
||||
@@ -80,6 +81,14 @@ class DanswerGenAIModel(str, Enum):
|
||||
TRANSFORMERS = "transformers"
|
||||
|
||||
|
||||
class AuthType(str, Enum):
|
||||
DISABLED = "disabled"
|
||||
BASIC = "basic"
|
||||
GOOGLE_OAUTH = "google_oauth"
|
||||
OIDC = "oidc"
|
||||
SAML = "saml"
|
||||
|
||||
|
||||
class ModelHostType(str, Enum):
|
||||
"""For GenAI models interfaced via requests, different services have different
|
||||
expectations for what fields are included in the request"""
|
||||
|
@@ -20,9 +20,9 @@
|
||||
</nodes>
|
||||
<tuning>
|
||||
<resource-limits>
|
||||
<!-- Default is 75% but this should be increased for Dockerized deployments -->
|
||||
<!-- Default is 75% but this can be increased for Dockerized deployments -->
|
||||
<!-- https://docs.vespa.ai/en/operations/feed-block.html -->
|
||||
<disk>0.98</disk>
|
||||
<disk>0.75</disk>
|
||||
</resource-limits>
|
||||
</tuning>
|
||||
<config name="vespa.config.search.summary.juniperrc">
|
||||
|
@@ -5,25 +5,24 @@ from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.auth.users import oauth_client
|
||||
from danswer.chat.personas import load_personas_from_yaml
|
||||
from danswer.configs.app_configs import APP_HOST
|
||||
from danswer.configs.app_configs import APP_PORT
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import ENABLE_OAUTH
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import OAUTH_TYPE
|
||||
from danswer.configs.app_configs import OPENID_CONFIG_URL
|
||||
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import SECRET
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.configs.model_configs import API_BASE_OPENAI
|
||||
from danswer.configs.model_configs import API_TYPE_OPENAI
|
||||
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
|
||||
@@ -39,13 +38,14 @@ from danswer.server.chat_backend import router as chat_router
|
||||
from danswer.server.credential import router as credential_router
|
||||
from danswer.server.document_set import router as document_set_router
|
||||
from danswer.server.event_loading import router as event_processing_router
|
||||
from danswer.server.health import router as health_router
|
||||
from danswer.server.manage import router as admin_router
|
||||
from danswer.server.search_backend import router as backend_router
|
||||
from danswer.server.slack_bot_management import router as slack_bot_management_router
|
||||
from danswer.server.state import router as state_router
|
||||
from danswer.server.users import router as user_router
|
||||
from danswer.utils.acl import set_acl_for_vespa
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -82,53 +82,41 @@ def get_application() -> FastAPI:
|
||||
application.include_router(credential_router)
|
||||
application.include_router(document_set_router)
|
||||
application.include_router(slack_bot_management_router)
|
||||
application.include_router(health_router)
|
||||
application.include_router(state_router)
|
||||
|
||||
application.include_router(
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth/database",
|
||||
tags=["auth"],
|
||||
)
|
||||
application.include_router(
|
||||
fastapi_users.get_register_router(UserRead, UserCreate),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
application.include_router(
|
||||
fastapi_users.get_reset_password_router(),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
application.include_router(
|
||||
fastapi_users.get_verify_router(UserRead),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
application.include_router(
|
||||
fastapi_users.get_users_router(UserRead, UserUpdate),
|
||||
prefix="/users",
|
||||
tags=["users"],
|
||||
)
|
||||
if ENABLE_OAUTH:
|
||||
if oauth_client is None:
|
||||
raise RuntimeError("OAuth is enabled but no OAuth client is configured")
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# Server logs this during auth setup verification step
|
||||
pass
|
||||
|
||||
if OAUTH_TYPE == "google":
|
||||
# special case for google
|
||||
application.include_router(
|
||||
fastapi_users.get_oauth_router(
|
||||
oauth_client,
|
||||
auth_backend,
|
||||
SECRET,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
# points the user back to the login page, where we will call the
|
||||
# /auth/google/callback endpoint + redirect them to the main app
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/google/callback",
|
||||
),
|
||||
prefix="/auth/google",
|
||||
tags=["auth"],
|
||||
)
|
||||
elif AUTH_TYPE == AuthType.BASIC:
|
||||
application.include_router(
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
application.include_router(
|
||||
fastapi_users.get_register_router(UserRead, UserCreate),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
application.include_router(
|
||||
fastapi_users.get_reset_password_router(),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
application.include_router(
|
||||
fastapi_users.get_verify_router(UserRead),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
application.include_router(
|
||||
fastapi_users.get_users_router(UserRead, UserUpdate),
|
||||
prefix="/users",
|
||||
tags=["users"],
|
||||
)
|
||||
|
||||
elif AUTH_TYPE == AuthType.GOOGLE_OAUTH:
|
||||
oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET)
|
||||
application.include_router(
|
||||
fastapi_users.get_oauth_router(
|
||||
oauth_client,
|
||||
@@ -136,16 +124,16 @@ def get_application() -> FastAPI:
|
||||
SECRET,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
# points the user back to the login page, where we will call the
|
||||
# /auth/oauth/callback endpoint + redirect them to the main app
|
||||
# points the user back to the login page
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
# need basic auth router for `logout` endpoint
|
||||
application.include_router(
|
||||
fastapi_users.get_oauth_associate_router(oauth_client, UserRead, SECRET),
|
||||
prefix="/auth/associate/oauth",
|
||||
fastapi_users.get_logout_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
@@ -170,30 +158,25 @@ def get_application() -> FastAPI:
|
||||
if API_TYPE_OPENAI == "azure":
|
||||
logger.info(f"Using Azure OpenAI with Endpoint: {API_BASE_OPENAI}")
|
||||
|
||||
auth_status = "off" if DISABLE_AUTH else "on"
|
||||
logger.info(f"User Authentication is turned {auth_status}")
|
||||
verify_auth = fetch_versioned_implementation(
|
||||
"danswer.auth.users", "verify_auth_setting"
|
||||
)
|
||||
# Will throw exception if an issue is found
|
||||
verify_auth()
|
||||
|
||||
if not DISABLE_AUTH:
|
||||
if not ENABLE_OAUTH:
|
||||
logger.debug("OAuth is turned off")
|
||||
else:
|
||||
if not OAUTH_CLIENT_ID:
|
||||
logger.warning("OAuth is turned on but OAUTH_CLIENT_ID is empty")
|
||||
if not OAUTH_CLIENT_SECRET:
|
||||
logger.warning(
|
||||
"OAuth is turned on but OAUTH_CLIENT_SECRET is empty"
|
||||
)
|
||||
if OAUTH_TYPE == "openid" and not OPENID_CONFIG_URL:
|
||||
logger.warning("OpenID is turned on but OPENID_CONFIG_URL is emtpy")
|
||||
else:
|
||||
logger.debug("OAuth is turned on")
|
||||
if DISABLE_AUTH:
|
||||
logger.info("User Authentication is turned off.")
|
||||
|
||||
if GOOGLE_OAUTH_CLIENT_ID and GOOGLE_OAUTH_CLIENT_SECRET:
|
||||
logger.info("Found both OAuth Client ID and secret configured.")
|
||||
|
||||
if SKIP_RERANKING:
|
||||
logger.info("Reranking step of search flow is disabled")
|
||||
|
||||
logger.info(f'Using Embedding model: "{DOCUMENT_ENCODER_MODEL}"')
|
||||
logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"')
|
||||
logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"')
|
||||
if ASYM_QUERY_PREFIX or ASYM_PASSAGE_PREFIX:
|
||||
logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"')
|
||||
logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"')
|
||||
|
||||
logger.info("Warming up local NLP models.")
|
||||
warm_up_models()
|
||||
@@ -201,9 +184,9 @@ def get_application() -> FastAPI:
|
||||
qa_model.warm_up_model()
|
||||
|
||||
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
nltk.download("stopwords")
|
||||
nltk.download("wordnet")
|
||||
nltk.download("punkt")
|
||||
nltk.download("stopwords", quiet=True)
|
||||
nltk.download("wordnet", quiet=True)
|
||||
nltk.download("punkt", quiet=True)
|
||||
|
||||
logger.info("Verifying public credential exists.")
|
||||
create_initial_public_credential()
|
||||
@@ -219,17 +202,18 @@ def get_application() -> FastAPI:
|
||||
# does nothing if this has been successfully run before
|
||||
set_acl_for_vespa(should_check_if_already_done=True)
|
||||
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Change this to the list of allowed origins if needed
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
return application
|
||||
|
||||
|
||||
app = get_application()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Change this to the list of allowed origins if needed
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -9,6 +9,10 @@ from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.build import get_default_llm
|
||||
from danswer.server.models import QueryValidationResponse
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
QUERY_PAT = "QUERY: "
|
||||
REASONING_PAT = "THOUGHT: "
|
||||
@@ -94,37 +98,44 @@ def get_query_answerability(user_query: str) -> tuple[str, bool]:
|
||||
def stream_query_answerability(user_query: str) -> Iterator[str]:
|
||||
messages = get_query_validation_messages(user_query)
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
tokens = get_default_llm().stream(filled_llm_prompt)
|
||||
reasoning_pat_found = False
|
||||
model_output = ""
|
||||
hold_answerable = ""
|
||||
for token in tokens:
|
||||
model_output = model_output + token
|
||||
try:
|
||||
tokens = get_default_llm().stream(filled_llm_prompt)
|
||||
reasoning_pat_found = False
|
||||
model_output = ""
|
||||
hold_answerable = ""
|
||||
for token in tokens:
|
||||
model_output = model_output + token
|
||||
|
||||
if ANSWERABLE_PAT in model_output:
|
||||
continue
|
||||
|
||||
if not reasoning_pat_found and REASONING_PAT in model_output:
|
||||
reasoning_pat_found = True
|
||||
reason_ind = model_output.find(REASONING_PAT)
|
||||
remaining = model_output[reason_ind + len(REASONING_PAT) :]
|
||||
if remaining:
|
||||
yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining)))
|
||||
continue
|
||||
|
||||
if reasoning_pat_found:
|
||||
hold_answerable = hold_answerable + token
|
||||
if hold_answerable == ANSWERABLE_PAT[: len(hold_answerable)]:
|
||||
if ANSWERABLE_PAT in model_output:
|
||||
continue
|
||||
yield get_json_line(
|
||||
asdict(DanswerAnswerPiece(answer_piece=hold_answerable))
|
||||
)
|
||||
hold_answerable = ""
|
||||
|
||||
reasoning = extract_answerability_reasoning(model_output)
|
||||
answerable = extract_answerability_bool(model_output)
|
||||
if not reasoning_pat_found and REASONING_PAT in model_output:
|
||||
reasoning_pat_found = True
|
||||
reason_ind = model_output.find(REASONING_PAT)
|
||||
remaining = model_output[reason_ind + len(REASONING_PAT) :]
|
||||
if remaining:
|
||||
yield get_json_line(
|
||||
asdict(DanswerAnswerPiece(answer_piece=remaining))
|
||||
)
|
||||
continue
|
||||
|
||||
yield get_json_line(
|
||||
QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
|
||||
)
|
||||
if reasoning_pat_found:
|
||||
hold_answerable = hold_answerable + token
|
||||
if hold_answerable == ANSWERABLE_PAT[: len(hold_answerable)]:
|
||||
continue
|
||||
yield get_json_line(
|
||||
asdict(DanswerAnswerPiece(answer_piece=hold_answerable))
|
||||
)
|
||||
hold_answerable = ""
|
||||
|
||||
reasoning = extract_answerability_reasoning(model_output)
|
||||
answerable = extract_answerability_bool(model_output)
|
||||
|
||||
yield get_json_line(
|
||||
QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
|
||||
)
|
||||
except Exception as e:
|
||||
# exception is logged in the answer_question method, no need to re-log
|
||||
yield get_json_line({"error": str(e)})
|
||||
logger.exception("Failed to validate Query")
|
||||
return
|
||||
|
@@ -1,11 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from danswer.server.models import StatusResponse
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/health", response_model=StatusResponse)
|
||||
def healthcheck() -> StatusResponse:
|
||||
return StatusResponse(success=True, message="ok")
|
@@ -1,7 +1,6 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
from uuid import UUID
|
||||
@@ -9,7 +8,9 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import QAFeedbackType
|
||||
@@ -38,6 +39,10 @@ class StatusResponse(GenericModel, Generic[DataT]):
|
||||
data: Optional[DataT] = None
|
||||
|
||||
|
||||
class AuthTypeResponse(BaseModel):
|
||||
auth_type: AuthType
|
||||
|
||||
|
||||
class DataRequest(BaseModel):
|
||||
data: str
|
||||
|
||||
@@ -47,6 +52,15 @@ class HelperResponse(BaseModel):
|
||||
details: list[str] | None = None
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
is_verified: bool
|
||||
role: UserRole
|
||||
|
||||
|
||||
class GoogleAppWebCredentials(BaseModel):
|
||||
client_id: str
|
||||
project_id: str
|
||||
@@ -84,10 +98,6 @@ class FileUploadResponse(BaseModel):
|
||||
file_paths: list[str]
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
|
||||
status: Literal["ok"]
|
||||
|
||||
|
||||
class ObjectCreationIdResponse(BaseModel):
|
||||
id: int | str
|
||||
|
||||
|
18
backend/danswer/server/state.py
Normal file
18
backend/danswer/server/state.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.server.models import AuthTypeResponse
|
||||
from danswer.server.models import StatusResponse
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
def healthcheck() -> StatusResponse:
|
||||
return StatusResponse(success=True, message="ok")
|
||||
|
||||
|
||||
@router.get("/auth/type")
|
||||
def get_auth_type() -> AuthTypeResponse:
|
||||
return AuthTypeResponse(auth_type=AUTH_TYPE)
|
@@ -9,12 +9,13 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_sqlalchemy_async_engine
|
||||
from danswer.db.models import User
|
||||
from danswer.db.users import list_users
|
||||
from danswer.server.models import UserByEmail
|
||||
|
||||
from danswer.server.models import UserInfo
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
|
||||
@@ -43,3 +44,18 @@ def list_all_users(
|
||||
) -> list[UserRead]:
|
||||
users = list_users(db_session)
|
||||
return [UserRead.from_orm(user) for user in users]
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
def verify_user_logged_in(user: User | None = Depends(current_user)) -> UserInfo:
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="User Not Authenticated")
|
||||
|
||||
return UserInfo(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
is_active=user.is_active,
|
||||
is_superuser=user.is_superuser,
|
||||
is_verified=user.is_verified,
|
||||
role=user.role,
|
||||
)
|
||||
|
21
backend/danswer/utils/variable_functionality.py
Normal file
21
backend/danswer/utils/variable_functionality.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
|
||||
class DanswerVersion:
|
||||
def __init__(self) -> None:
|
||||
self._is_ee = False
|
||||
|
||||
def set_ee(self) -> None:
|
||||
self._is_ee = True
|
||||
|
||||
def get_is_ee_version(self) -> bool:
|
||||
return self._is_ee
|
||||
|
||||
|
||||
global_version = DanswerVersion()
|
||||
|
||||
|
||||
def fetch_versioned_implementation(module: str, attribute: str) -> Any:
|
||||
module_full = f"ee.{module}" if global_version.get_is_ee_version() else module
|
||||
return getattr(importlib.import_module(module_full), attribute)
|
@@ -46,7 +46,6 @@ Additional steps for user auth and https if you do want to use Docker Compose fo
|
||||
1. Set up a `.env` file in this directory with relevant environment variables.
|
||||
- Refer to `env.prod.template`
|
||||
- To turn on user auth, set:
|
||||
- ENABLE_OAUTH=True
|
||||
- GOOGLE_OAUTH_CLIENT_ID=\<your GCP API client ID\>
|
||||
- GOOGLE_OAUTH_CLIENT_SECRET=\<associated client secret\>
|
||||
- Refer to https://developers.google.com/identity/gsi/web/guides/get-google-api-clientid
|
||||
|
@@ -30,10 +30,8 @@ services:
|
||||
- TYPESENSE_HOST=search_engine
|
||||
- TYPESENSE_API_KEY=${TYPESENSE_API_KEY:-typesense_api_key}
|
||||
- LOG_LEVEL=${LOG_LEVEL:-info}
|
||||
- DISABLE_AUTH=${DISABLE_AUTH:-True}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-disabled}
|
||||
- QA_TIMEOUT=${QA_TIMEOUT:-}
|
||||
- OAUTH_TYPE=${OAUTH_TYPE:-google}
|
||||
- OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
|
||||
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
|
||||
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
|
||||
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
|
||||
@@ -91,13 +89,14 @@ services:
|
||||
dockerfile: Dockerfile
|
||||
args:
|
||||
- NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING:-false}
|
||||
- NEXT_PUBLIC_AUTH_TYPE=${AUTH_TYPE:-disabled}
|
||||
depends_on:
|
||||
- api_server
|
||||
restart: always
|
||||
environment:
|
||||
- INTERNAL_URL=http://api_server:8080
|
||||
- WEB_DOMAIN=${WEB_DOMAIN:-}
|
||||
- DISABLE_AUTH=${DISABLE_AUTH:-True}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-disabled}
|
||||
- OAUTH_NAME=${OAUTH_NAME:-}
|
||||
relational_db:
|
||||
image: postgres:15.2-alpine
|
||||
|
@@ -24,11 +24,9 @@ services:
|
||||
- NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
- DISABLE_AUTH=${DISABLE_AUTH:-True}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-disabled}
|
||||
- QA_TIMEOUT=${QA_TIMEOUT:-}
|
||||
- VALID_EMAIL_DOMAINS=${VALID_EMAIL_DOMAINS:-}
|
||||
- OAUTH_TYPE=${OAUTH_TYPE:-google}
|
||||
- OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
|
||||
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
|
||||
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
|
||||
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
|
||||
@@ -77,6 +75,7 @@ services:
|
||||
- AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-}
|
||||
- CONTINUE_ON_CONNECTOR_FAILURE=${CONTINUE_ON_CONNECTOR_FAILURE:-}
|
||||
- NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
|
||||
# Danswer SlackBot Configs
|
||||
- DANSWER_BOT_SLACK_APP_TOKEN=${DANSWER_BOT_SLACK_APP_TOKEN:-}
|
||||
- DANSWER_BOT_SLACK_BOT_TOKEN=${DANSWER_BOT_SLACK_BOT_TOKEN:-}
|
||||
- DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER=${DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER:-}
|
||||
@@ -104,13 +103,14 @@ services:
|
||||
dockerfile: Dockerfile
|
||||
args:
|
||||
- NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING:-false}
|
||||
- NEXT_PUBLIC_AUTH_TYPE=${AUTH_TYPE:-disabled}
|
||||
depends_on:
|
||||
- api_server
|
||||
restart: always
|
||||
environment:
|
||||
- INTERNAL_URL=http://api_server:8080
|
||||
- WEB_DOMAIN=${WEB_DOMAIN:-}
|
||||
- DISABLE_AUTH=${DISABLE_AUTH:-True}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-disabled}
|
||||
- OAUTH_NAME=${OAUTH_NAME:-}
|
||||
relational_db:
|
||||
image: postgres:15.2-alpine
|
||||
|
@@ -17,6 +17,7 @@ services:
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- AUTH_TYPE=${AUTH_TYPE:-google_oauth}
|
||||
- DOCUMENT_INDEX_TYPE=split
|
||||
- POSTGRES_HOST=relational_db
|
||||
- QDRANT_HOST=vector_db
|
||||
@@ -40,6 +41,7 @@ services:
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- AUTH_TYPE=${AUTH_TYPE:-google_oauth}
|
||||
- DOCUMENT_INDEX_TYPE=split
|
||||
- POSTGRES_HOST=relational_db
|
||||
- QDRANT_HOST=vector_db
|
||||
@@ -55,12 +57,16 @@ services:
|
||||
build:
|
||||
context: ../../web
|
||||
dockerfile: Dockerfile
|
||||
args:
|
||||
- NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING:-false}
|
||||
- NEXT_PUBLIC_AUTH_TYPE=${AUTH_TYPE:-google_oauth}
|
||||
depends_on:
|
||||
- api_server
|
||||
restart: always
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- AUTH_TYPE=${AUTH_TYPE:-google_oauth}
|
||||
- INTERNAL_URL=http://api_server:8080
|
||||
relational_db:
|
||||
image: postgres:15.2-alpine
|
||||
|
@@ -16,6 +16,7 @@ services:
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- AUTH_TYPE=${AUTH_TYPE:-google_oauth}
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
volumes:
|
||||
@@ -37,6 +38,7 @@ services:
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- AUTH_TYPE=${AUTH_TYPE:-google_oauth}
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
volumes:
|
||||
@@ -50,12 +52,16 @@ services:
|
||||
build:
|
||||
context: ../../web
|
||||
dockerfile: Dockerfile
|
||||
args:
|
||||
- NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING:-false}
|
||||
- NEXT_PUBLIC_AUTH_TYPE=${AUTH_TYPE:-google_oauth}
|
||||
depends_on:
|
||||
- api_server
|
||||
restart: always
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- AUTH_TYPE=${AUTH_TYPE:-google_oauth}
|
||||
- INTERNAL_URL=http://api_server:8080
|
||||
relational_db:
|
||||
image: postgres:15.2-alpine
|
||||
|
@@ -1,10 +1,10 @@
|
||||
# Fill in the values and copy the contents of this file to .env in the deployment directory.
|
||||
# Some valid default values are provided where applicable, delete the variables which you don't set values for.
|
||||
# Only applicable when using the docker-compose.prod.yml compose file.
|
||||
# This is only necessary when using the docker-compose.prod.yml compose file.
|
||||
|
||||
|
||||
# Insert your OpenAI API key here, currently the only Generative AI endpoint for QA that we support is OpenAI
|
||||
# If not provided here, UI will prompt on setup
|
||||
# Insert your OpenAI API key here If not provided here, UI will prompt on setup.
|
||||
# This env variable takes precedence over UI settings.
|
||||
GEN_AI_API_KEY=
|
||||
# Choose between "openai-chat-completion" and "openai-completion"
|
||||
INTERNAL_MODEL_VERSION=openai-chat-completion
|
||||
@@ -17,21 +17,12 @@ API_TYPE_OPENAI=
|
||||
API_VERSION_OPENAI=
|
||||
AZURE_DEPLOYMENT_ID=
|
||||
|
||||
# Could be something like danswer.companyname.com. Requires additional setup if not localhost
|
||||
# Could be something like danswer.companyname.com
|
||||
WEB_DOMAIN=http://localhost:3000
|
||||
|
||||
# If you want to make the postgres / typesense instances a little more secure, modify the below
|
||||
# Note that the postgres / typesense / qdrant containers do not expose any ports to the outside world,
|
||||
# so they are already unaccessible unless someone has ssh access to the machine that Danswer is running on
|
||||
# Default values here are what Postgres uses by default, feel free to change.
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=password
|
||||
TYPESENSE_API_KEY=typesense_api_key
|
||||
|
||||
# Currently frontend page doesn't have basic auth, use OAuth if user auth is enabled.
|
||||
ENABLE_OAUTH=True
|
||||
# The two settings below are only required if ENABLE_OAUTH is true
|
||||
GOOGLE_OAUTH_CLIENT_ID=
|
||||
GOOGLE_OAUTH_CLIENT_SECRET=
|
||||
|
||||
# If you want to setup a slack bot to answer questions automatically in Slack
|
||||
# channels it is added to, you must specify the below.
|
||||
@@ -45,15 +36,25 @@ SECRET=
|
||||
# How long before user needs to reauthenticate, default to 1 day. (cookie expiration time)
|
||||
SESSION_EXPIRE_TIME_SECONDS=86400
|
||||
|
||||
# used to specify a list of allowed user domains
|
||||
# The following are for configuring User Authentication, supported flows are:
|
||||
# disabled
|
||||
# simple (email/password + user account creation in Danswer)
|
||||
# google_oauth (login with google/gmail account)
|
||||
# oidc (only in Danswer enterprise edition)
|
||||
# saml (only in Danswer enterprise edition)
|
||||
AUTH_TYPE=
|
||||
|
||||
# Set the two below to use with Google OAuth
|
||||
GOOGLE_OAUTH_CLIENT_ID=
|
||||
GOOGLE_OAUTH_CLIENT_SECRET=
|
||||
|
||||
# OpenID Connect (OIDC)
|
||||
OPENID_CONFIG_URL=
|
||||
|
||||
# SAML config directory for OneLogin compatible setups
|
||||
SAML_CONF_DIR=
|
||||
|
||||
# used to specify a list of allowed user domains, only checked if user Auth is turned on
|
||||
# e.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will only allow users
|
||||
# with an @example.com or an @example.org email
|
||||
VALID_EMAIL_DOMAINS=
|
||||
|
||||
# Only relevant if using basic auth (not supported on frontend yet)
|
||||
REQUIRE_EMAIL_VERIFICATION=True
|
||||
# The four settings below are only required if REQUIRE_EMAIL_VERIFICATION is True
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USER=
|
||||
SMTP_PASS=
|
||||
|
@@ -40,6 +40,8 @@ spec:
|
||||
ports:
|
||||
- containerPort: 8080
|
||||
env:
|
||||
- name: AUTH_TYPE
|
||||
value: google_oauth
|
||||
- name: POSTGRES_HOST
|
||||
value: relational-db-service
|
||||
- name: VESPA_HOST
|
||||
|
@@ -31,6 +31,10 @@ spec:
|
||||
imagePullPolicy: IfNotPresent
|
||||
ports:
|
||||
- containerPort: 3000
|
||||
args:
|
||||
- "NEXT_PUBLIC_AUTH_TYPE=google_oauth"
|
||||
env:
|
||||
- name: AUTH_TYPE
|
||||
value: google_oauth
|
||||
- name: INTERNAL_URL
|
||||
value: "http://api-server-service:80"
|
||||
|
@@ -30,6 +30,9 @@ ENV NEXT_TELEMETRY_DISABLED 1
|
||||
ARG NEXT_PUBLIC_DISABLE_STREAMING
|
||||
ENV NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING}
|
||||
|
||||
ARG NEXT_PUBLIC_AUTH_TYPE
|
||||
ENV NEXT_PUBLIC_AUTH_TYPE=${NEXT_PUBLIC_AUTH_TYPE}
|
||||
|
||||
RUN npm run build
|
||||
|
||||
# Step 3. Production image, copy all the files and run next
|
||||
@@ -65,6 +68,9 @@ COPY --from=builder --chown=nextjs:nodejs /app/.next/static ./.next/static
|
||||
ARG NEXT_PUBLIC_DISABLE_STREAMING
|
||||
ENV NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING}
|
||||
|
||||
ARG NEXT_PUBLIC_AUTH_TYPE
|
||||
ENV NEXT_PUBLIC_AUTH_TYPE=${NEXT_PUBLIC_AUTH_TYPE}
|
||||
|
||||
# Note: Don't expose ports here, Compose will handle that for us if necessary.
|
||||
# If you want to run this without compose, specify the ports to
|
||||
# expose via cli
|
||||
|
@@ -1,23 +0,0 @@
|
||||
import { getDomain } from "@/lib/redirectSS";
|
||||
import { buildUrl } from "@/lib/utilsSS";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
|
||||
export const GET = async (request: NextRequest) => {
|
||||
// Wrapper around the FastAPI endpoint /auth/google/callback,
|
||||
// which adds back a redirect to the main app.
|
||||
const url = new URL(buildUrl("/auth/google/callback"));
|
||||
url.search = request.nextUrl.search;
|
||||
|
||||
const response = await fetch(url.toString());
|
||||
const setCookieHeader = response.headers.get("set-cookie");
|
||||
|
||||
if (!setCookieHeader) {
|
||||
return NextResponse.redirect(new URL("/auth/error", getDomain(request)));
|
||||
}
|
||||
|
||||
const redirectResponse = NextResponse.redirect(
|
||||
new URL("/", getDomain(request))
|
||||
);
|
||||
redirectResponse.headers.set("set-cookie", setCookieHeader);
|
||||
return redirectResponse;
|
||||
};
|
@@ -1,7 +1,7 @@
|
||||
import { HealthCheckBanner } from "@/components/health/healthcheck";
|
||||
import { DISABLE_AUTH, OAUTH_NAME } from "@/lib/constants";
|
||||
import { User } from "@/lib/types";
|
||||
import { getGoogleOAuthUrlSS, getCurrentUserSS } from "@/lib/userSS";
|
||||
import { getCurrentUserSS, getAuthUrlSS } from "@/lib/userSS";
|
||||
import { redirect } from "next/navigation";
|
||||
|
||||
const BUTTON_STYLE =
|
||||
@@ -21,10 +21,11 @@ const Page = async () => {
|
||||
// will not render
|
||||
let currentUser: User | null = null;
|
||||
let authorizationUrl: string | null = null;
|
||||
let autoRedirect: boolean = false;
|
||||
try {
|
||||
[currentUser, authorizationUrl] = await Promise.all([
|
||||
[currentUser, [authorizationUrl, autoRedirect]] = await Promise.all([
|
||||
getCurrentUserSS(),
|
||||
getGoogleOAuthUrlSS(),
|
||||
getAuthUrlSS(),
|
||||
]);
|
||||
} catch (e) {
|
||||
console.log(`Some fetch failed for the login page - ${e}`);
|
||||
@@ -35,6 +36,10 @@ const Page = async () => {
|
||||
return redirect("/");
|
||||
}
|
||||
|
||||
if (autoRedirect && authorizationUrl) {
|
||||
return redirect(authorizationUrl);
|
||||
}
|
||||
|
||||
return (
|
||||
<main>
|
||||
<div className="absolute top-10x w-full">
|
||||
|
8
web/src/app/auth/logout/route.ts
Normal file
8
web/src/app/auth/logout/route.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
import { logoutSS } from "@/lib/userSS";
|
||||
import { NextRequest } from "next/server";
|
||||
|
||||
export const POST = async (request: NextRequest) => {
|
||||
// Directs the logout request to the appropriate FastAPI endpoint.
|
||||
// Needed since env variables don't work well on the client-side
|
||||
return await logoutSS(request.headers);
|
||||
};
|
33
web/src/app/auth/saml/callback/route.ts
Normal file
33
web/src/app/auth/saml/callback/route.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import { getDomain } from "@/lib/redirectSS";
|
||||
import { buildUrl } from "@/lib/utilsSS";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
|
||||
// have to use this so we don't hit the redirect URL with a `POST` request
|
||||
const SEE_OTHER_REDIRECT_STATUS = 303;
|
||||
|
||||
export const POST = async (request: NextRequest) => {
|
||||
// Wrapper around the FastAPI endpoint /auth/saml/callback,
|
||||
// which adds back a redirect to the main app.
|
||||
const url = new URL(buildUrl("/auth/saml/callback"));
|
||||
url.search = request.nextUrl.search;
|
||||
|
||||
const response = await fetch(url.toString(), {
|
||||
method: "POST",
|
||||
body: await request.formData(),
|
||||
});
|
||||
const setCookieHeader = response.headers.get("set-cookie");
|
||||
|
||||
if (!setCookieHeader) {
|
||||
return NextResponse.redirect(
|
||||
new URL("/auth/error", getDomain(request)),
|
||||
SEE_OTHER_REDIRECT_STATUS
|
||||
);
|
||||
}
|
||||
|
||||
const redirectResponse = NextResponse.redirect(
|
||||
new URL("/", getDomain(request)),
|
||||
SEE_OTHER_REDIRECT_STATUS
|
||||
);
|
||||
redirectResponse.headers.set("set-cookie", setCookieHeader);
|
||||
return redirectResponse;
|
||||
};
|
@@ -1,5 +1,9 @@
|
||||
export const DISABLE_AUTH = process.env.DISABLE_AUTH?.toLowerCase() === "true";
|
||||
export type AuthType = "disabled" | "google_oauth" | "oidc" | "saml";
|
||||
|
||||
export const AUTH_TYPE = (process.env.AUTH_TYPE ||
|
||||
process.env.NEXT_PUBLIC_AUTH_TYPE ||
|
||||
"disabled") as AuthType;
|
||||
export const DISABLE_AUTH = AUTH_TYPE === "disabled";
|
||||
export const OAUTH_NAME = process.env.OAUTH_NAME || "Google";
|
||||
|
||||
export const INTERNAL_URL = process.env.INTERNAL_URL || "http://127.0.0.1:8080";
|
||||
|
@@ -2,7 +2,7 @@ import { User } from "./types";
|
||||
|
||||
// should be used client-side only
|
||||
export const getCurrentUser = async (): Promise<User | null> => {
|
||||
const response = await fetch("/api/users/me", {
|
||||
const response = await fetch("/api/manage/me", {
|
||||
credentials: "include",
|
||||
});
|
||||
if (!response.ok) {
|
||||
@@ -13,7 +13,7 @@ export const getCurrentUser = async (): Promise<User | null> => {
|
||||
};
|
||||
|
||||
export const logout = async (): Promise<boolean> => {
|
||||
const response = await fetch("/api/auth/database/logout", {
|
||||
const response = await fetch("/auth/logout", {
|
||||
method: "POST",
|
||||
credentials: "include",
|
||||
});
|
||||
|
@@ -2,8 +2,19 @@ import { cookies } from "next/headers";
|
||||
import { User } from "./types";
|
||||
import { buildUrl } from "./utilsSS";
|
||||
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
|
||||
import { AUTH_TYPE } from "./constants";
|
||||
|
||||
export const getGoogleOAuthUrlSS = async (): Promise<string> => {
|
||||
const geOIDCAuthUrlSS = async (): Promise<string> => {
|
||||
const res = await fetch(buildUrl("/auth/oidc/authorize"));
|
||||
if (!res.ok) {
|
||||
throw new Error("Failed to fetch data");
|
||||
}
|
||||
|
||||
const data: { authorization_url: string } = await res.json();
|
||||
return data.authorization_url;
|
||||
};
|
||||
|
||||
const getGoogleOAuthUrlSS = async (): Promise<string> => {
|
||||
const res = await fetch(buildUrl("/auth/oauth/authorize"));
|
||||
if (!res.ok) {
|
||||
throw new Error("Failed to fetch data");
|
||||
@@ -13,10 +24,63 @@ export const getGoogleOAuthUrlSS = async (): Promise<string> => {
|
||||
return data.authorization_url;
|
||||
};
|
||||
|
||||
// should be used server-side only
|
||||
const getSAMLAuthUrlSS = async (): Promise<string> => {
|
||||
const res = await fetch(buildUrl("/auth/saml/authorize"));
|
||||
if (!res.ok) {
|
||||
throw new Error("Failed to fetch data");
|
||||
}
|
||||
|
||||
const data: { authorization_url: string } = await res.json();
|
||||
return data.authorization_url;
|
||||
};
|
||||
|
||||
export const getAuthUrlSS = async (): Promise<[string, boolean]> => {
|
||||
// Returns the auth url and whether or not we should auto-redirect
|
||||
switch (AUTH_TYPE) {
|
||||
case "disabled":
|
||||
return ["", true];
|
||||
case "google_oauth": {
|
||||
return [await getGoogleOAuthUrlSS(), false];
|
||||
}
|
||||
case "saml": {
|
||||
return [await getSAMLAuthUrlSS(), true];
|
||||
}
|
||||
case "oidc": {
|
||||
return [await geOIDCAuthUrlSS(), true];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const logoutStandardSS = async (headers: Headers): Promise<Response> => {
|
||||
return await fetch(buildUrl("/auth/logout"), {
|
||||
method: "POST",
|
||||
headers: headers,
|
||||
});
|
||||
};
|
||||
|
||||
const logoutSAMLSS = async (headers: Headers): Promise<Response> => {
|
||||
return await fetch(buildUrl("/auth/saml/logout"), {
|
||||
method: "POST",
|
||||
headers: headers,
|
||||
});
|
||||
};
|
||||
|
||||
export const logoutSS = async (headers: Headers): Promise<Response | null> => {
|
||||
switch (AUTH_TYPE) {
|
||||
case "disabled":
|
||||
return null;
|
||||
case "saml": {
|
||||
return await logoutSAMLSS(headers);
|
||||
}
|
||||
default: {
|
||||
return await logoutStandardSS(headers);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const getCurrentUserSS = async (): Promise<User | null> => {
|
||||
try {
|
||||
const response = await fetch(buildUrl("/users/me"), {
|
||||
const response = await fetch(buildUrl("/manage/me"), {
|
||||
credentials: "include",
|
||||
next: { revalidate: 0 },
|
||||
headers: {
|
||||
|
Reference in New Issue
Block a user