diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index d674f560f..74de6e4fe 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -54,6 +54,7 @@ from danswer.db.auth import get_user_db from danswer.db.engine import get_session from danswer.db.models import AccessToken from danswer.db.models import User +from danswer.db.users import get_user_by_email from danswer.utils.logger import setup_logger from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType @@ -94,12 +95,20 @@ def user_needs_to_be_verified() -> bool: return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION -def verify_email_in_whitelist(email: str) -> None: +def verify_email_is_invited(email: str) -> None: whitelist = get_invited_users() if (whitelist and email not in whitelist) or not email: raise PermissionError("User not on allowed user whitelist") +def verify_email_in_whitelist( + email: str, + db_session: Session = Depends(get_session), +) -> None: + if not get_user_by_email(email, db_session): + verify_email_is_invited(email) + + def verify_email_domain(email: str) -> None: if VALID_EMAIL_DOMAINS: if email.count("@") != 1: @@ -149,7 +158,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): safe: bool = False, request: Optional[Request] = None, ) -> models.UP: - verify_email_in_whitelist(user_create.email) + verify_email_is_invited(user_create.email) verify_email_domain(user_create.email) if hasattr(user_create, "role"): user_count = await get_user_count()