Initial EE features (#3)

This commit is contained in:
Yuhong Sun
2023-10-02 19:39:50 -07:00
committed by Chris Weaver
parent 65d5808ea7
commit 92de6acc6f
30 changed files with 525 additions and 26 deletions

36
backend/ee/LICENSE Normal file
View File

@@ -0,0 +1,36 @@
The DanswerAI Enterprise license (the “Enterprise License”)
Copyright (c) 2023 DanswerAI, Inc.
With regard to the Danswer Software:
This software and associated documentation files (the "Software") may only be
used in production, if you (and any entity that you represent) have agreed to,
and are in compliance with, the DanswerAI Subscription Terms of Service, available
at https://danswer.ai/terms (the “Enterprise Terms”), or other
agreement governing the use of the Software, as agreed by you and DanswerAI,
and otherwise have a valid Danswer Enterprise license for the
correct number of user seats. Subject to the foregoing sentence, you are free to
modify this Software and publish patches to the Software. You agree that DanswerAI
and/or its licensors (as applicable) retain all right, title and interest in and
to all such modifications and/or patches, and all such modifications and/or
patches may only be used, copied, modified, displayed, distributed, or otherwise
exploited with a valid Danswer Enterprise license for the correct
number of user seats. Notwithstanding the foregoing, you may copy and modify
the Software for development and testing purposes, without requiring a
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain
all right, title and interest in and to all such modifications. You are not
granted any other rights beyond what is expressly stated herein. Subject to the
foregoing, it is forbidden to copy, merge, publish, distribute, sublicense,
and/or sell the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
For all third party components incorporated into the Danswer Software, those
components are licensed under the original license provided by the owner of the
applicable component.

0
backend/ee/__init__.py Normal file
View File

View File

View File

View File

@@ -0,0 +1,45 @@
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from sqlalchemy.orm import Session
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.constants import AuthType
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from ee.danswer.db.saml import get_saml_account
from ee.danswer.utils.secrets import extract_hashed_cookie
# Module-level logger used by the auth helpers in this module.
logger = setup_logger()
def verify_auth_setting() -> None:
    """Log the configured auth type.

    No validation is performed here: nothing is rejected, the function only
    records ``AUTH_TYPE`` in the log.
    """
    # All the Auth flows are valid for EE version
    logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
async def double_check_user(
    request: Request,
    user: User | None,
    db_session: Session,
    optional: bool = DISABLE_AUTH,
) -> User | None:
    """Resolve the user for a request, rejecting unauthenticated callers.

    Args:
        request: Incoming request; only its session cookie is inspected.
        user: User already resolved by the caller, if any.
        db_session: Session used to look up the SAML account.
        optional: When true, authentication is skipped entirely.

    Returns:
        ``None`` when ``optional`` is true, otherwise the authenticated user.

    Raises:
        HTTPException: 403 when authentication is required and no user
            could be resolved.
    """
    if optional:
        return None

    if AUTH_TYPE == AuthType.SAML:
        hashed_cookie = extract_hashed_cookie(request)
        if hashed_cookie:
            # A SAML session cookie takes precedence over the user passed in:
            # the user becomes the owner of the matching, unexpired account,
            # or None when no such account exists. Without a cookie the
            # passed-in user is left untouched.
            account = get_saml_account(cookie=hashed_cookie, db_session=db_session)
            if account:
                user = account.user
            else:
                user = None

    if user is not None:
        return user

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access denied. User is not authenticated.",
    )

View File

View File

@@ -0,0 +1,10 @@
import os
# OIDC only: URL of the provider's OpenID configuration document.
# Stays an empty string when the variable is not set.
OPENID_CONFIG_URL = os.getenv("OPENID_CONFIG_URL", "")

# SAML only: directory handed to the SAML toolkit as its configuration base path.
# `or` (rather than a getenv default) makes an empty SAML_CONF_DIR fall back to
# the bundled location as well as an unset one.
SAML_CONF_DIR = (
    os.getenv("SAML_CONF_DIR") or "/app/danswer/backend/ee/danswer/configs/saml_config"
)

View File

@@ -0,0 +1,20 @@
{
"strict": true,
"debug": false,
"idp": {
"entityId": "<Provide This from IDP>",
"singleSignOnService": {
"url": "<Replace this with your IDP URL> https://trial-1234567.okta.com/home/trial-1234567_danswer/somevalues/somevalues",
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
},
"x509cert": "<Provide this>"
},
"sp": {
"entityId": "<Provide This from IDP>",
"assertionConsumerService": {
"url": "http://127.0.0.1:3000/auth/saml/callback",
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
},
"x509cert": "<Provide this>"
}
}

View File

View File

@@ -0,0 +1,26 @@
import datetime
from sqlalchemy import DateTime
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Text
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from danswer.db.models import Base
from danswer.db.models import User
class SamlAccount(Base):
    """Row in the ``saml`` table tying a user to a SAML session cookie.

    The uniqueness constraints below allow at most one row per user and at
    most one row per stored cookie value.
    """

    __tablename__ = "saml"

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(primary_key=True)
    # Owning user (foreign key to user.id); unique, so one row per user.
    # NOTE(review): annotated as int, but upsert_saml_account passes a UUID
    # for this column -- confirm the intended column type.
    user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), unique=True)
    # Stored form of the session cookie. The SAML callback saves the output of
    # encrypt_string() here, i.e. a SHA-256 hex digest rather than the raw value.
    encrypted_cookie: Mapped[str] = mapped_column(Text, unique=True)
    # Timezone-aware instant after which get_saml_account ignores this row.
    expires_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
    # Set by the database on insert (server_default) and refreshed on update.
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    # ORM relationship to the owning User row.
    user: Mapped[User] = relationship("User")

View File

@@ -0,0 +1,65 @@
import datetime
from typing import cast
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.db.models import User
from ee.danswer.db.models import SamlAccount
def upsert_saml_account(
    user_id: UUID,
    cookie: str,
    db_session: Session,
    expiration_offset: int = SESSION_EXPIRE_TIME_SECONDS,
) -> datetime.datetime:
    """Create or refresh the SAML account row for a user and commit it.

    Args:
        user_id: Id of the user owning the account.
        cookie: Cookie value to store (callers pass the hashed cookie).
        db_session: Session used for the lookup, the write and the commit.
        expiration_offset: Lifetime in seconds, counted from the database's
            current time.

    Returns:
        The ``expires_at`` value of the saved row.
    """
    # Expiry is a SQL expression (database "now" + offset), so it is evaluated
    # by the database; the cast only satisfies the type checker.
    new_expiry = cast(
        datetime.datetime,
        func.now() + datetime.timedelta(seconds=expiration_offset),
    )

    account = (
        db_session.query(SamlAccount)
        .filter(SamlAccount.user_id == user_id)
        .one_or_none()
    )

    if not account:
        # First SAML login for this user: insert a new row.
        account = SamlAccount(
            user_id=user_id,
            encrypted_cookie=cookie,
            expires_at=new_expiry,
        )
        db_session.add(account)
    else:
        # Returning user: rotate the cookie and extend the expiry in place.
        account.encrypted_cookie = cookie
        account.expires_at = new_expiry
        account.updated_at = func.now()

    db_session.commit()

    return account.expires_at
def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None:
    """Return the unexpired SAML account stored under ``cookie``, if any.

    Args:
        cookie: Stored (hashed) cookie value to match exactly.
        db_session: Session the query is executed on.

    Returns:
        The matching ``SamlAccount``, or ``None`` when no row matches or the
        matching row has already expired.
    """
    cookie_matches = SamlAccount.encrypted_cookie == cookie
    # Compared against the database clock, not the application's.
    still_valid = SamlAccount.expires_at > func.now()

    query = (
        select(SamlAccount)
        .join(User, User.id == SamlAccount.user_id)  # type: ignore
        .where(and_(cookie_matches, still_valid))
    )
    return db_session.execute(query).scalar_one_or_none()
def expire_saml_account(saml_account: SamlAccount, db_session: Session) -> None:
    """Expire a SAML account immediately and commit the change.

    The expiry is set to the database's current time, so the strict
    ``expires_at > now`` filter in ``get_saml_account`` stops matching the row.
    """
    expired_as_of = func.now()
    saml_account.expires_at = expired_as_of
    db_session.commit()

View File

@@ -0,0 +1,64 @@
import uvicorn
from fastapi import FastAPI
from httpx_oauth.clients.openid import OpenID
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.main import get_application
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.configs.app_configs import OPENID_CONFIG_URL
from ee.danswer.server.saml import router as saml_router
# Module-level logger; used for the startup message when run as a script.
logger = setup_logger()
def get_ee_application() -> FastAPI:
    """Build the FastAPI application with the enterprise auth routes attached.

    The global version flag is switched to EE before the base application is
    created. Code that ran at import time may therefore not have seen the EE
    flag; anything running after server startup will.

    Returns:
        The base application, extended with OIDC or SAML routes when
        ``AUTH_TYPE`` selects one of them.
    """
    global_version.set_ee()

    application = get_application()

    if AUTH_TYPE == AuthType.OIDC:
        oidc_client = OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL)
        oidc_router = fastapi_users.get_oauth_router(
            oidc_client,
            auth_backend,
            SECRET,
            associate_by_email=True,
            is_verified_by_default=True,
            redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
        )
        application.include_router(oidc_router, prefix="/auth/oidc", tags=["auth"])

        # The basic auth router is still needed because it provides `logout`.
        basic_auth_router = fastapi_users.get_auth_router(auth_backend)
        application.include_router(basic_auth_router, prefix="/auth", tags=["auth"])
    elif AUTH_TYPE == AuthType.SAML:
        application.include_router(saml_router)

    return application
# Application instance, created when this module is imported; also handed to
# uvicorn below when the module is executed as a script.
app = get_ee_application()


if __name__ == "__main__":
    # Direct execution: log the bind address, then serve the app with uvicorn.
    logger.info(
        f"Running Enterprise Danswer API Service on http://{APP_HOST}:{str(APP_PORT)}/"
    )
    uvicorn.run(app, host=APP_HOST, port=APP_PORT)

View File

View File

@@ -0,0 +1,177 @@
import contextlib
import secrets
from typing import Any
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi_users import exceptions
from fastapi_users.password import PasswordHelper
from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore
from pydantic import BaseModel
from pydantic import EmailStr
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.auth.users import get_user_manager
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_async_session
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import SAML_CONF_DIR
from ee.danswer.db.saml import expire_saml_account
from ee.danswer.db.saml import get_saml_account
from ee.danswer.db.saml import upsert_saml_account
from ee.danswer.utils.secrets import encrypt_string
from ee.danswer.utils.secrets import extract_hashed_cookie
# Module-level logger for the SAML endpoints.
logger = setup_logger()
# Every route declared on this router is served under the /auth/saml prefix.
router = APIRouter(prefix="/auth/saml")
async def upsert_saml_user(email: str) -> User:
    """Return the user registered under ``email``, creating it if missing.

    A newly created user is marked verified and receives a randomly generated
    password. It becomes an admin when the user count is zero at creation
    time, and a basic user otherwise.

    Args:
        email: Email address asserted by the SAML response.

    Returns:
        The existing or newly created user.
    """
    # Wrap the dependency providers so they can be entered with `async with`
    # here, outside of a FastAPI request.
    open_session = contextlib.asynccontextmanager(get_async_session)  # type:ignore
    open_user_db = contextlib.asynccontextmanager(get_user_db)
    open_user_manager = contextlib.asynccontextmanager(get_user_manager)

    async with (
        open_session() as session,
        open_user_db(session) as user_db,
        open_user_manager(user_db) as user_manager,
    ):
        try:
            return await user_manager.get_by_email(email)
        except exceptions.UserNotExists:
            logger.info("Creating user from SAML login")

        is_first_user = (await get_user_count()) == 0
        role = UserRole.ADMIN if is_first_user else UserRole.BASIC

        # The password is random and never shown to anyone; what is passed on
        # is the hash of that random value.
        # NOTE(review): presumably user_manager.create hashes its input again;
        # harmless for a throwaway credential, but confirm.
        pw_helper = PasswordHelper()
        hashed_pass = pw_helper.hash(pw_helper.generate())

        new_user: User = await user_manager.create(
            UserCreate(
                email=EmailStr(email),
                password=hashed_pass,
                is_verified=True,
                role=role,
            )
        )
        return new_user
async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
    """Translate a FastAPI request into the request dict given to the SAML toolkit.

    Args:
        request: Incoming request; its form body, client address, URL and
            query string are read.

    Returns:
        A dict with the keys ``http_host``, ``server_port``, ``script_name``,
        ``post_data`` and ``get_data``. ``post_data`` contains ``SAMLResponse``
        and ``RelayState`` only when the form supplied them. ``get_data`` is a
        plain dict of the query parameters (empty when there are none).

    Raises:
        ValueError: If the request carries no client information.
    """
    form_data = await request.form()
    if request.client is None:
        raise ValueError("Invalid request for SAML")

    # Forward only the two SAML form fields, and only when they are present.
    post_data = {
        field: form_data[field]
        for field in ("SAMLResponse", "RelayState")
        if field in form_data
    }

    return {
        # NOTE(review): this is the *client's* address; the SAML toolkit
        # presumably expects the host the request was addressed to -- confirm
        # before relying on URL/destination validation.
        "http_host": request.client.host,
        "server_port": request.url.port,
        "script_name": request.url.path,
        "post_data": post_data,
        # Must be a mapping: wrapping the query params in a tuple would make
        # key lookups such as `"SAMLResponse" in get_data` always fail.
        "get_data": dict(request.query_params),
    }
# Response body of the /authorize endpoint.
class SAMLAuthorizeResponse(BaseModel):
    # URL produced by the SAML toolkit's login() call for this request.
    authorization_url: str
@router.get("/authorize")
async def saml_login(request: Request) -> SAMLAuthorizeResponse:
    # Starts a SAML login: builds the toolkit request from the incoming HTTP
    # request and returns the URL produced by the toolkit's login() call.
    # (Documented with comments rather than a docstring so the endpoint's
    # published OpenAPI description stays as it was.)
    saml_request = await prepare_from_fastapi_request(request)
    auth = OneLogin_Saml2_Auth(saml_request, custom_base_path=SAML_CONF_DIR)
    return SAMLAuthorizeResponse(authorization_url=auth.login())
@router.post("/callback")
async def saml_login_callback(
    request: Request,
    db_session: Session = Depends(get_session),
) -> Response:
    # Completes a SAML login: validates the posted SAML response, finds or
    # creates the user for the asserted email, stores a hashed session cookie
    # for that user and hands the raw cookie back on an empty 204 response.
    # Every rejection is reported as HTTP 403.
    # (Documented with comments rather than a docstring so the endpoint's
    # published OpenAPI description stays as it was.)
    saml_request = await prepare_from_fastapi_request(request)
    auth = OneLogin_Saml2_Auth(saml_request, custom_base_path=SAML_CONF_DIR)
    auth.process_response()

    processing_errors = auth.get_errors()
    if processing_errors:
        failure_summary = "Error when processing SAML Response: %s %s" % (
            ", ".join(processing_errors),
            auth.get_last_error_reason(),
        )
        logger.error(failure_summary)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied. Failed to parse SAML Response.",
        )

    if not auth.is_authenticated():
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied. User was not Authenticated.",
        )

    # The attribute lookup yields a sequence of values; the first one is used.
    email_values = auth.get_attribute("email")
    if not email_values:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="SAML is not set up correctly, email attribute must be provided.",
        )

    user = await upsert_saml_user(email=email_values[0])

    # Only the SHA-256 digest of the random token is persisted; the raw token
    # goes to the browser and is re-hashed on each request for lookup.
    session_cookie = secrets.token_hex(16)
    hashed_cookie = encrypt_string(session_cookie)
    upsert_saml_account(user_id=user.id, cookie=hashed_cookie, db_session=db_session)

    # No redirect is issued here: the reply is an empty 204 carrying the cookie.
    # NOTE(review): the cookie name is hard-coded, while extract_hashed_cookie
    # reads the cookie named by SESSION_KEY -- the two must stay equal.
    response = Response(status_code=status.HTTP_204_NO_CONTENT)
    response.set_cookie(
        key="session",
        value=session_cookie,
        httponly=True,
        secure=True,
        max_age=SESSION_EXPIRE_TIME_SECONDS,
    )
    return response
@router.post("/logout")
def saml_logout(
    request: Request,
    db_session: Session = Depends(get_session),
) -> None:
    # Ends a SAML session: if the request carries a session cookie that maps
    # to an unexpired SAML account, that account is expired. Requests without
    # a cookie or without a matching account are a silent no-op.
    # (Documented with comments rather than a docstring so the endpoint's
    # published OpenAPI description stays as it was.)
    hashed_cookie = extract_hashed_cookie(request)
    if not hashed_cookie:
        return

    account = get_saml_account(cookie=hashed_cookie, db_session=db_session)
    if account:
        expire_saml_account(account, db_session)

    return

View File

View File

@@ -0,0 +1,14 @@
import hashlib
from fastapi import Request
from danswer.configs.constants import SESSION_KEY
def encrypt_string(s: str) -> str:
    """Return the SHA-256 hex digest of ``s``.

    Despite the name this is a one-way hash, not reversible encryption. The
    string is encoded with ``str.encode()``'s default encoding before hashing.
    """
    hasher = hashlib.sha256()
    hasher.update(s.encode())
    return hasher.hexdigest()
def extract_hashed_cookie(request: Request) -> str | None:
    """Return the hashed session cookie of ``request``, or ``None``.

    The cookie named by ``SESSION_KEY`` is read and passed through
    ``encrypt_string``. A missing or empty cookie yields ``None``.
    """
    raw_cookie = request.cookies.get(SESSION_KEY)
    if not raw_cookie:
        return None
    return encrypt_string(raw_cookie)