diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 7851f291f..f12554c94 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -1,4 +1,5 @@ import contextlib +import os import smtplib import uuid from collections.abc import AsyncGenerator @@ -49,6 +50,27 @@ logger = setup_logger() FAKE_USER_EMAIL = "fakeuser@fakedanswermail.com" FAKE_USER_PASS = "foobar" +USER_WHITELIST_FILE = "/home/danswer_whitelist.txt" +_user_whitelist: list[str] | None = None + + +def get_user_whitelist() -> list[str]: + global _user_whitelist + if _user_whitelist is None: + if os.path.exists(USER_WHITELIST_FILE): + with open(USER_WHITELIST_FILE, "r") as file: + _user_whitelist = [line.strip() for line in file] + else: + _user_whitelist = [] + + return _user_whitelist + + +def verify_email_in_whitelist(email: str) -> None: + whitelist = get_user_whitelist() + if (whitelist and email not in whitelist) or not email: + raise PermissionError("User not on allowed user whitelist") + def send_user_verification_email(user_email: str, token: str) -> None: msg = MIMEMultipart() @@ -79,6 +101,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): safe: bool = False, request: Optional[Request] = None, ) -> models.UP: + verify_email_in_whitelist(user_create.email) if hasattr(user_create, "role"): user_count = await get_user_count() if user_count == 0: @@ -87,6 +110,33 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): user_create.role = UserRole.BASIC return await super().create(user_create, safe=safe, request=request) # type: ignore + async def oauth_callback( + self: "BaseUserManager[models.UOAP, models.ID]", + oauth_name: str, + access_token: str, + account_id: str, + account_email: str, + expires_at: Optional[int] = None, + refresh_token: Optional[str] = None, + request: Optional[Request] = None, + *, + associate_by_email: bool = False, + is_verified_by_default: bool = False, + ) -> models.UOAP: + verify_email_in_whitelist(account_email) + + return await super().oauth_callback( # type: ignore + oauth_name=oauth_name, + access_token=access_token, + account_id=account_id, + account_email=account_email, + expires_at=expires_at, + refresh_token=refresh_token, + request=request, + associate_by_email=associate_by_email, + is_verified_by_default=is_verified_by_default, + ) + async def on_after_register( self, user: User, request: Optional[Request] = None ) -> None: