Tenant aware JWT strategy (#2943)

* add tenantJWTSrategy

* nit
This commit is contained in:
pablodanswer 2024-10-26 16:27:40 -07:00 committed by GitHub
parent 088551a4ef
commit 1261d859ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 34 additions and 44 deletions

View File

@ -62,7 +62,6 @@ from danswer.configs.app_configs import DISABLE_VERIFICATION
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SMTP_PASS
from danswer.configs.app_configs import SMTP_PORT
@ -295,29 +294,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def on_after_login(
self,
user: User,
request: Request | None = None,
response: Response | None = None,
) -> None:
if response is None or not MULTI_TENANT:
return
tenant_id = get_tenant_id_for_email(user.email)
tenant_token = jwt.encode(
{"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256"
)
response.set_cookie(
key="tenant_details",
value=tenant_token,
httponly=True,
secure=WEB_DOMAIN.startswith("https"),
samesite="lax",
)
async def oauth_callback(
self: "BaseUserManager[models.UOAP, models.ID]",
oauth_name: str,
@ -527,8 +503,22 @@ cookie_transport = CookieTransport(
)
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def write_token(self, user: User) -> str:
tenant_id = get_tenant_id_for_email(user.email)
data = {
"sub": str(user.id),
"aud": self.token_audience,
"tenant_id": tenant_id,
}
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
)
def get_jwt_strategy() -> JWTStrategy:
return JWTStrategy(
return TenantAwareJWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
)

View File

@ -472,9 +472,6 @@ MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
# Security and authentication
SECRET_JWT_KEY = os.environ.get(
"SECRET_JWT_KEY", ""
) # Used for encryption of the JWT token for user's tenant context
DATA_PLANE_SECRET = os.environ.get(
"DATA_PLANE_SECRET", ""
) # Used for secure communication between the control and data plane

View File

@ -35,7 +35,7 @@ from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.configs.constants import TENANT_ID_PREFIX
from danswer.utils.logger import setup_logger
@ -263,17 +263,20 @@ def get_current_tenant_id(request: Request) -> str:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
token = request.cookies.get("tenant_details")
token = request.cookies.get("fastapiusersauth")
if not token:
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
# If no token is present, use the default schema or handle accordingly
return current_value
try:
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
tenant_id = payload.get("tenant_id")
if not tenant_id:
return CURRENT_TENANT_ID_CONTEXTVAR.get()
payload = jwt.decode(
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

View File

@ -9,7 +9,7 @@ from fastapi import Request
from fastapi import Response
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.db.engine import is_valid_schema_name
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@ -25,11 +25,15 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
else:
token = request.cookies.get("tenant_details")
token = request.cookies.get("fastapiusersauth")
if token:
try:
payload = jwt.decode(
token, SECRET_JWT_KEY, algorithms=["HS256"]
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if not is_valid_schema_name(tenant_id):

View File

@ -70,18 +70,14 @@ class UserManager:
cookies = response.cookies.get_dict()
session_cookie = cookies.get("fastapiusersauth")
tenant_details_cookie = cookies.get("tenant_details")
if not session_cookie:
raise Exception("Failed to login")
print(f"Logged in as {test_user.email}")
# Set both cookies in the headers
test_user.headers["Cookie"] = (
f"fastapiusersauth={session_cookie}; "
f"tenant_details={tenant_details_cookie}"
)
# Set cookies in the headers
test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
return test_user
@staticmethod

View File

@ -14,7 +14,7 @@ export const POST = async (request: NextRequest) => {
// Delete cookies only if cloud is enabled (jwt auth)
if (NEXT_PUBLIC_CLOUD_ENABLED) {
const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
const cookiesToDelete = ["fastapiusersauth"];
const cookieOptions = {
path: "/",
secure: process.env.NODE_ENV === "production",