mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
Initial EE features (#3)
This commit is contained in:
36
backend/ee/LICENSE
Normal file
36
backend/ee/LICENSE
Normal file
@@ -0,0 +1,36 @@
|
||||
The DanswerAI Enterprise license (the “Enterprise License”)
|
||||
Copyright (c) 2023 DanswerAI, Inc.
|
||||
|
||||
With regard to the Danswer Software:
|
||||
|
||||
This software and associated documentation files (the "Software") may only be
|
||||
used in production, if you (and any entity that you represent) have agreed to,
|
||||
and are in compliance with, the DanswerAI Subscription Terms of Service, available
|
||||
at https://danswer.ai/terms (the “Enterprise Terms”), or other
|
||||
agreement governing the use of the Software, as agreed by you and DanswerAI,
|
||||
and otherwise have a valid Danswer Enterprise license for the
|
||||
correct number of user seats. Subject to the foregoing sentence, you are free to
|
||||
modify this Software and publish patches to the Software. You agree that DanswerAI
|
||||
and/or its licensors (as applicable) retain all right, title and interest in and
|
||||
to all such modifications and/or patches, and all such modifications and/or
|
||||
patches may only be used, copied, modified, displayed, distributed, or otherwise
|
||||
exploited with a valid Danswer Enterprise license for the correct
|
||||
number of user seats. Notwithstanding the foregoing, you may copy and modify
|
||||
the Software for development and testing purposes, without requiring a
|
||||
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain
|
||||
all right, title and interest in and to all such modifications. You are not
|
||||
granted any other rights beyond what is expressly stated herein. Subject to the
|
||||
foregoing, it is forbidden to copy, merge, publish, distribute, sublicense,
|
||||
and/or sell the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
For all third party components incorporated into the Danswer Software, those
|
||||
components are licensed under the original license provided by the owner of the
|
||||
applicable component.
|
0
backend/ee/__init__.py
Normal file
0
backend/ee/__init__.py
Normal file
0
backend/ee/danswer/__init__.py
Normal file
0
backend/ee/danswer/__init__.py
Normal file
0
backend/ee/danswer/auth/__init__.py
Normal file
0
backend/ee/danswer/auth/__init__.py
Normal file
45
backend/ee/danswer/auth/users.py
Normal file
45
backend/ee/danswer/auth/users.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.saml import get_saml_account
|
||||
from ee.danswer.utils.secrets import extract_hashed_cookie
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
# All the Auth flows are valid for EE version
|
||||
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
async def double_check_user(
|
||||
request: Request,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return None
|
||||
|
||||
# Check if the user has a session cookie from SAML
|
||||
if AUTH_TYPE == AuthType.SAML:
|
||||
saved_cookie = extract_hashed_cookie(request)
|
||||
|
||||
if saved_cookie:
|
||||
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
|
||||
user = saml_account.user if saml_account else None
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
return user
|
0
backend/ee/danswer/configs/__init__.py
Normal file
0
backend/ee/danswer/configs/__init__.py
Normal file
10
backend/ee/danswer/configs/app_configs.py
Normal file
10
backend/ee/danswer/configs/app_configs.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import os
|
||||
|
||||
# Applicable for OIDC Auth
|
||||
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
|
||||
|
||||
# Applicable for SAML Auth
|
||||
SAML_CONF_DIR = (
|
||||
os.environ.get("SAML_CONF_DIR")
|
||||
or "/app/danswer/backend/ee/danswer/configs/saml_config"
|
||||
)
|
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"strict": true,
|
||||
"debug": false,
|
||||
"idp": {
|
||||
"entityId": "<Provide This from IDP>",
|
||||
"singleSignOnService": {
|
||||
"url": "<Replace this with your IDP URL> https://trial-1234567.okta.com/home/trial-1234567_danswer/somevalues/somevalues",
|
||||
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
|
||||
},
|
||||
"x509cert": "<Provide this>"
|
||||
},
|
||||
"sp": {
|
||||
"entityId": "<Provide This from IDP>",
|
||||
"assertionConsumerService": {
|
||||
"url": "http://127.0.0.1:3000/auth/saml/callback",
|
||||
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
|
||||
},
|
||||
"x509cert": "<Provide this>"
|
||||
}
|
||||
}
|
0
backend/ee/danswer/db/__init__.py
Normal file
0
backend/ee/danswer/db/__init__.py
Normal file
26
backend/ee/danswer/db/models.py
Normal file
26
backend/ee/danswer/db/models.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import datetime
|
||||
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from danswer.db.models import Base
|
||||
from danswer.db.models import User
|
||||
|
||||
|
||||
class SamlAccount(Base):
|
||||
__tablename__ = "saml"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), unique=True)
|
||||
encrypted_cookie: Mapped[str] = mapped_column(Text, unique=True)
|
||||
expires_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship("User")
|
65
backend/ee/danswer/db/saml.py
Normal file
65
backend/ee/danswer/db/saml.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import datetime
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.db.models import User
|
||||
from ee.danswer.db.models import SamlAccount
|
||||
|
||||
|
||||
def upsert_saml_account(
|
||||
user_id: UUID,
|
||||
cookie: str,
|
||||
db_session: Session,
|
||||
expiration_offset: int = SESSION_EXPIRE_TIME_SECONDS,
|
||||
) -> datetime.datetime:
|
||||
expires_at = func.now() + datetime.timedelta(seconds=expiration_offset)
|
||||
|
||||
existing_saml_acc = (
|
||||
db_session.query(SamlAccount)
|
||||
.filter(SamlAccount.user_id == user_id)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if existing_saml_acc:
|
||||
existing_saml_acc.encrypted_cookie = cookie
|
||||
existing_saml_acc.expires_at = cast(datetime.datetime, expires_at)
|
||||
existing_saml_acc.updated_at = func.now()
|
||||
saml_acc = existing_saml_acc
|
||||
else:
|
||||
saml_acc = SamlAccount(
|
||||
user_id=user_id,
|
||||
encrypted_cookie=cookie,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db_session.add(saml_acc)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return saml_acc.expires_at
|
||||
|
||||
|
||||
def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None:
|
||||
stmt = (
|
||||
select(SamlAccount)
|
||||
.join(User, User.id == SamlAccount.user_id) # type: ignore
|
||||
.where(
|
||||
and_(
|
||||
SamlAccount.encrypted_cookie == cookie,
|
||||
SamlAccount.expires_at > func.now(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def expire_saml_account(saml_account: SamlAccount, db_session: Session) -> None:
|
||||
saml_account.expires_at = func.now()
|
||||
db_session.commit()
|
64
backend/ee/danswer/main.py
Normal file
64
backend/ee/danswer/main.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.openid import OpenID
|
||||
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.configs.app_configs import APP_HOST
|
||||
from danswer.configs.app_configs import APP_PORT
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import SECRET
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.main import get_application
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from ee.danswer.configs.app_configs import OPENID_CONFIG_URL
|
||||
from ee.danswer.server.saml import router as saml_router
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_ee_application() -> FastAPI:
|
||||
# Anything that happens at import time is not guaranteed to be running ee-version
|
||||
# Anything after the server startup will be running ee version
|
||||
global_version.set_ee()
|
||||
|
||||
application = get_application()
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
application.include_router(
|
||||
fastapi_users.get_oauth_router(
|
||||
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
|
||||
auth_backend,
|
||||
SECRET,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
|
||||
),
|
||||
prefix="/auth/oidc",
|
||||
tags=["auth"],
|
||||
)
|
||||
# need basic auth router for `logout` endpoint
|
||||
application.include_router(
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
elif AUTH_TYPE == AuthType.SAML:
|
||||
application.include_router(saml_router)
|
||||
|
||||
return application
|
||||
|
||||
|
||||
app = get_ee_application()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info(
|
||||
f"Running Enterprise Danswer API Service on http://{APP_HOST}:{str(APP_PORT)}/"
|
||||
)
|
||||
uvicorn.run(app, host=APP_HOST, port=APP_PORT)
|
0
backend/ee/danswer/server/__init__.py
Normal file
0
backend/ee/danswer/server/__init__.py
Normal file
177
backend/ee/danswer/server/saml.py
Normal file
177
backend/ee/danswer/server/saml.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import contextlib
|
||||
import secrets
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi_users import exceptions
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from pydantic import EmailStr
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.users import get_user_manager
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.configs.app_configs import SAML_CONF_DIR
|
||||
from ee.danswer.db.saml import expire_saml_account
|
||||
from ee.danswer.db.saml import get_saml_account
|
||||
from ee.danswer.db.saml import upsert_saml_account
|
||||
from ee.danswer.utils.secrets import encrypt_string
|
||||
from ee.danswer.utils.secrets import extract_hashed_cookie
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/auth/saml")
|
||||
|
||||
|
||||
async def upsert_saml_user(email: str) -> User:
|
||||
get_async_session_context = contextlib.asynccontextmanager(
|
||||
get_async_session
|
||||
) # type:ignore
|
||||
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
|
||||
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
|
||||
|
||||
async with get_async_session_context() as session:
|
||||
async with get_user_db_context(session) as user_db:
|
||||
async with get_user_manager_context(user_db) as user_manager:
|
||||
try:
|
||||
return await user_manager.get_by_email(email)
|
||||
except exceptions.UserNotExists:
|
||||
logger.info("Creating user from SAML login")
|
||||
|
||||
user_count = await get_user_count()
|
||||
role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC
|
||||
|
||||
fastapi_users_pw_helper = PasswordHelper()
|
||||
password = fastapi_users_pw_helper.generate()
|
||||
hashed_pass = fastapi_users_pw_helper.hash(password)
|
||||
|
||||
user: User = await user_manager.create(
|
||||
UserCreate(
|
||||
email=EmailStr(email),
|
||||
password=hashed_pass,
|
||||
is_verified=True,
|
||||
role=role,
|
||||
)
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
|
||||
form_data = await request.form()
|
||||
if request.client is None:
|
||||
raise ValueError("Invalid request for SAML")
|
||||
|
||||
rv: dict[str, Any] = {
|
||||
"http_host": request.client.host,
|
||||
"server_port": request.url.port,
|
||||
"script_name": request.url.path,
|
||||
"post_data": {},
|
||||
"get_data": {},
|
||||
}
|
||||
if request.query_params:
|
||||
rv["get_data"] = (request.query_params,)
|
||||
if "SAMLResponse" in form_data:
|
||||
SAMLResponse = form_data["SAMLResponse"]
|
||||
rv["post_data"]["SAMLResponse"] = SAMLResponse
|
||||
if "RelayState" in form_data:
|
||||
RelayState = form_data["RelayState"]
|
||||
rv["post_data"]["RelayState"] = RelayState
|
||||
return rv
|
||||
|
||||
|
||||
class SAMLAuthorizeResponse(BaseModel):
|
||||
authorization_url: str
|
||||
|
||||
|
||||
@router.get("/authorize")
|
||||
async def saml_login(request: Request) -> SAMLAuthorizeResponse:
|
||||
req = await prepare_from_fastapi_request(request)
|
||||
auth = OneLogin_Saml2_Auth(req, custom_base_path=SAML_CONF_DIR)
|
||||
callback_url = auth.login()
|
||||
return SAMLAuthorizeResponse(authorization_url=callback_url)
|
||||
|
||||
|
||||
@router.post("/callback")
|
||||
async def saml_login_callback(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
req = await prepare_from_fastapi_request(request)
|
||||
auth = OneLogin_Saml2_Auth(req, custom_base_path=SAML_CONF_DIR)
|
||||
auth.process_response()
|
||||
errors = auth.get_errors()
|
||||
if len(errors) != 0:
|
||||
logger.error(
|
||||
"Error when processing SAML Response: %s %s"
|
||||
% (", ".join(errors), auth.get_last_error_reason())
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. Failed to parse SAML Response.",
|
||||
)
|
||||
|
||||
if not auth.is_authenticated():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User was not Authenticated.",
|
||||
)
|
||||
|
||||
user_email = auth.get_attribute("email")
|
||||
if not user_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="SAML is not set up correctly, email attribute must be provided.",
|
||||
)
|
||||
|
||||
user_email = user_email[0]
|
||||
|
||||
user = await upsert_saml_user(email=user_email)
|
||||
|
||||
# Generate a random session cookie and Sha256 encrypt before saving
|
||||
session_cookie = secrets.token_hex(16)
|
||||
saved_cookie = encrypt_string(session_cookie)
|
||||
|
||||
upsert_saml_account(user_id=user.id, cookie=saved_cookie, db_session=db_session)
|
||||
|
||||
# Redirect to main Danswer search page
|
||||
response = Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
response.set_cookie(
|
||||
key="session",
|
||||
value=session_cookie,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
max_age=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def saml_logout(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
saved_cookie = extract_hashed_cookie(request)
|
||||
|
||||
if saved_cookie:
|
||||
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
|
||||
if saml_account:
|
||||
expire_saml_account(saml_account, db_session)
|
||||
|
||||
return
|
0
backend/ee/danswer/utils/__init__.py
Normal file
0
backend/ee/danswer/utils/__init__.py
Normal file
14
backend/ee/danswer/utils/secrets.py
Normal file
14
backend/ee/danswer/utils/secrets.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import hashlib
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from danswer.configs.constants import SESSION_KEY
|
||||
|
||||
|
||||
def encrypt_string(s: str) -> str:
|
||||
return hashlib.sha256(s.encode()).hexdigest()
|
||||
|
||||
|
||||
def extract_hashed_cookie(request: Request) -> str | None:
|
||||
session_cookie = request.cookies.get(SESSION_KEY)
|
||||
return encrypt_string(session_cookie) if session_cookie else None
|
Reference in New Issue
Block a user