DAN-169 Users whitelist (#153)

This commit is contained in:
Yuhong Sun 2023-07-11 21:23:35 -07:00 committed by GitHub
parent c2fa3d5074
commit d53ec8a905
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,5 @@
import contextlib
import os
import smtplib
import uuid
from collections.abc import AsyncGenerator
@ -49,6 +50,27 @@ logger = setup_logger()
FAKE_USER_EMAIL = "fakeuser@fakedanswermail.com"
FAKE_USER_PASS = "foobar"
USER_WHITELIST_FILE = "/home/danswer_whitelist.txt"
_user_whitelist: list[str] | None = None
def get_user_whitelist() -> list[str]:
global _user_whitelist
if _user_whitelist is None:
if os.path.exists(USER_WHITELIST_FILE):
with open(USER_WHITELIST_FILE, "r") as file:
_user_whitelist = [line.strip() for line in file]
else:
_user_whitelist = []
return _user_whitelist
def verify_email_in_whitelist(email: str) -> None:
whitelist = get_user_whitelist()
if (whitelist and email not in whitelist) or not email:
raise PermissionError("User not on allowed user whitelist")
def send_user_verification_email(user_email: str, token: str) -> None:
msg = MIMEMultipart()
@ -79,6 +101,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> models.UP:
verify_email_in_whitelist(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0:
@ -87,6 +110,33 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.BASIC
return await super().create(user_create, safe=safe, request=request) # type: ignore
async def oauth_callback(
self: "BaseUserManager[models.UOAP, models.ID]",
oauth_name: str,
access_token: str,
account_id: str,
account_email: str,
expires_at: Optional[int] = None,
refresh_token: Optional[str] = None,
request: Optional[Request] = None,
*,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
verify_email_in_whitelist(account_email)
return await super().oauth_callback( # type: ignore
oauth_name=oauth_name,
access_token=access_token,
account_id=account_id,
account_email=account_email,
expires_at=expires_at,
refresh_token=refresh_token,
request=request,
associate_by_email=associate_by_email,
is_verified_by_default=is_verified_by_default,
)
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None: