diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index a540cbaa4c..66659d7c6e 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -62,7 +62,6 @@ from danswer.configs.app_configs import DISABLE_VERIFICATION from danswer.configs.app_configs import EMAIL_FROM from danswer.configs.app_configs import MULTI_TENANT from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION -from danswer.configs.app_configs import SECRET_JWT_KEY from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS from danswer.configs.app_configs import SMTP_PASS from danswer.configs.app_configs import SMTP_PORT @@ -295,29 +294,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): CURRENT_TENANT_ID_CONTEXTVAR.reset(token) return user - async def on_after_login( - self, - user: User, - request: Request | None = None, - response: Response | None = None, - ) -> None: - if response is None or not MULTI_TENANT: - return - - tenant_id = get_tenant_id_for_email(user.email) - - tenant_token = jwt.encode( - {"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256" - ) - - response.set_cookie( - key="tenant_details", - value=tenant_token, - httponly=True, - secure=WEB_DOMAIN.startswith("https"), - samesite="lax", - ) - async def oauth_callback( self: "BaseUserManager[models.UOAP, models.ID]", oauth_name: str, @@ -527,8 +503,22 @@ cookie_transport = CookieTransport( ) +# This strategy is used to add tenant_id to the JWT token +class TenantAwareJWTStrategy(JWTStrategy): + async def write_token(self, user: User) -> str: + tenant_id = get_tenant_id_for_email(user.email) + data = { + "sub": str(user.id), + "aud": self.token_audience, + "tenant_id": tenant_id, + } + return generate_jwt( + data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm + ) + + def get_jwt_strategy() -> JWTStrategy: - return JWTStrategy( + return TenantAwareJWTStrategy( secret=USER_AUTH_SECRET, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS, ) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 8710569ea3..453a18cd59 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -472,9 +472,6 @@ MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true" ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true" # Security and authentication -SECRET_JWT_KEY = os.environ.get( - "SECRET_JWT_KEY", "" -) # Used for encryption of the JWT token for user's tenant context DATA_PLANE_SECRET = os.environ.get( "DATA_PLANE_SECRET", "" ) # Used for secure communication between the control and data plane diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index b071a41e8c..d133675f91 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -35,7 +35,7 @@ from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE from danswer.configs.app_configs import POSTGRES_PORT from danswer.configs.app_configs import POSTGRES_USER -from danswer.configs.app_configs import SECRET_JWT_KEY +from danswer.configs.app_configs import USER_AUTH_SECRET from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME from danswer.configs.constants import TENANT_ID_PREFIX from danswer.utils.logger import setup_logger @@ -263,17 +263,20 @@ def get_current_tenant_id(request: Request) -> str: CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) return tenant_id - token = request.cookies.get("tenant_details") + token = request.cookies.get("fastapiusersauth") if not token: current_value = CURRENT_TENANT_ID_CONTEXTVAR.get() # If no token is present, use the default schema or handle accordingly return current_value try: - payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"]) - tenant_id = payload.get("tenant_id") - if not tenant_id: - return CURRENT_TENANT_ID_CONTEXTVAR.get() + payload = jwt.decode( + token, + USER_AUTH_SECRET, + audience=["fastapi-users:auth"], + algorithms=["HS256"], + ) + tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA) if not is_valid_schema_name(tenant_id): raise HTTPException(status_code=400, detail="Invalid tenant ID format") CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) diff --git a/backend/ee/danswer/server/middleware/tenant_tracking.py b/backend/ee/danswer/server/middleware/tenant_tracking.py index f7a4ab0b6a..42291712a4 100644 --- a/backend/ee/danswer/server/middleware/tenant_tracking.py +++ b/backend/ee/danswer/server/middleware/tenant_tracking.py @@ -9,7 +9,7 @@ from fastapi import Request from fastapi import Response from danswer.configs.app_configs import MULTI_TENANT -from danswer.configs.app_configs import SECRET_JWT_KEY +from danswer.configs.app_configs import USER_AUTH_SECRET from danswer.db.engine import is_valid_schema_name from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA @@ -25,11 +25,15 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non if not MULTI_TENANT: tenant_id = POSTGRES_DEFAULT_SCHEMA else: - token = request.cookies.get("tenant_details") + token = request.cookies.get("fastapiusersauth") + if token: try: payload = jwt.decode( - token, SECRET_JWT_KEY, algorithms=["HS256"] + token, + USER_AUTH_SECRET, + audience=["fastapi-users:auth"], + algorithms=["HS256"], ) tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA) if not is_valid_schema_name(tenant_id): diff --git a/backend/tests/integration/common_utils/managers/user.py b/backend/tests/integration/common_utils/managers/user.py index ecc6d3206b..2b9aa6e189 100644 --- a/backend/tests/integration/common_utils/managers/user.py +++ b/backend/tests/integration/common_utils/managers/user.py @@ -70,18 +70,14 @@ class UserManager: cookies = response.cookies.get_dict() session_cookie = cookies.get("fastapiusersauth") - tenant_details_cookie = cookies.get("tenant_details") if not session_cookie: raise Exception("Failed to login") print(f"Logged in as {test_user.email}") - # Set both cookies in the headers - test_user.headers["Cookie"] = ( - f"fastapiusersauth={session_cookie}; " - f"tenant_details={tenant_details_cookie}" - ) + # Set cookies in the headers + test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; " return test_user @staticmethod diff --git a/web/src/app/auth/logout/route.ts b/web/src/app/auth/logout/route.ts index cd731810ca..9a7d22ae0d 100644 --- a/web/src/app/auth/logout/route.ts +++ b/web/src/app/auth/logout/route.ts @@ -14,7 +14,7 @@ export const POST = async (request: NextRequest) => { // Delete cookies only if cloud is enabled (jwt auth) if (NEXT_PUBLIC_CLOUD_ENABLED) { - const cookiesToDelete = ["fastapiusersauth", "tenant_details"]; + const cookiesToDelete = ["fastapiusersauth"]; const cookieOptions = { path: "/", secure: process.env.NODE_ENV === "production",