mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-28 08:51:00 +02:00
parent
088551a4ef
commit
1261d859ac
@ -62,7 +62,6 @@ from danswer.configs.app_configs import DISABLE_VERIFICATION
|
|||||||
from danswer.configs.app_configs import EMAIL_FROM
|
from danswer.configs.app_configs import EMAIL_FROM
|
||||||
from danswer.configs.app_configs import MULTI_TENANT
|
from danswer.configs.app_configs import MULTI_TENANT
|
||||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
|
||||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||||
from danswer.configs.app_configs import SMTP_PASS
|
from danswer.configs.app_configs import SMTP_PASS
|
||||||
from danswer.configs.app_configs import SMTP_PORT
|
from danswer.configs.app_configs import SMTP_PORT
|
||||||
@ -295,29 +294,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def on_after_login(
|
|
||||||
self,
|
|
||||||
user: User,
|
|
||||||
request: Request | None = None,
|
|
||||||
response: Response | None = None,
|
|
||||||
) -> None:
|
|
||||||
if response is None or not MULTI_TENANT:
|
|
||||||
return
|
|
||||||
|
|
||||||
tenant_id = get_tenant_id_for_email(user.email)
|
|
||||||
|
|
||||||
tenant_token = jwt.encode(
|
|
||||||
{"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256"
|
|
||||||
)
|
|
||||||
|
|
||||||
response.set_cookie(
|
|
||||||
key="tenant_details",
|
|
||||||
value=tenant_token,
|
|
||||||
httponly=True,
|
|
||||||
secure=WEB_DOMAIN.startswith("https"),
|
|
||||||
samesite="lax",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def oauth_callback(
|
async def oauth_callback(
|
||||||
self: "BaseUserManager[models.UOAP, models.ID]",
|
self: "BaseUserManager[models.UOAP, models.ID]",
|
||||||
oauth_name: str,
|
oauth_name: str,
|
||||||
@ -527,8 +503,22 @@ cookie_transport = CookieTransport(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# This strategy is used to add tenant_id to the JWT token
|
||||||
|
class TenantAwareJWTStrategy(JWTStrategy):
|
||||||
|
async def write_token(self, user: User) -> str:
|
||||||
|
tenant_id = get_tenant_id_for_email(user.email)
|
||||||
|
data = {
|
||||||
|
"sub": str(user.id),
|
||||||
|
"aud": self.token_audience,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
}
|
||||||
|
return generate_jwt(
|
||||||
|
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_jwt_strategy() -> JWTStrategy:
|
def get_jwt_strategy() -> JWTStrategy:
|
||||||
return JWTStrategy(
|
return TenantAwareJWTStrategy(
|
||||||
secret=USER_AUTH_SECRET,
|
secret=USER_AUTH_SECRET,
|
||||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
||||||
)
|
)
|
||||||
|
@ -472,9 +472,6 @@ MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
|
|||||||
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
|
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
|
||||||
|
|
||||||
# Security and authentication
|
# Security and authentication
|
||||||
SECRET_JWT_KEY = os.environ.get(
|
|
||||||
"SECRET_JWT_KEY", ""
|
|
||||||
) # Used for encryption of the JWT token for user's tenant context
|
|
||||||
DATA_PLANE_SECRET = os.environ.get(
|
DATA_PLANE_SECRET = os.environ.get(
|
||||||
"DATA_PLANE_SECRET", ""
|
"DATA_PLANE_SECRET", ""
|
||||||
) # Used for secure communication between the control and data plane
|
) # Used for secure communication between the control and data plane
|
||||||
|
@ -35,7 +35,7 @@ from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
|
|||||||
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
|
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||||
from danswer.configs.app_configs import POSTGRES_PORT
|
from danswer.configs.app_configs import POSTGRES_PORT
|
||||||
from danswer.configs.app_configs import POSTGRES_USER
|
from danswer.configs.app_configs import POSTGRES_USER
|
||||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||||
from danswer.configs.constants import TENANT_ID_PREFIX
|
from danswer.configs.constants import TENANT_ID_PREFIX
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
@ -263,17 +263,20 @@ def get_current_tenant_id(request: Request) -> str:
|
|||||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||||
return tenant_id
|
return tenant_id
|
||||||
|
|
||||||
token = request.cookies.get("tenant_details")
|
token = request.cookies.get("fastapiusersauth")
|
||||||
if not token:
|
if not token:
|
||||||
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||||
# If no token is present, use the default schema or handle accordingly
|
# If no token is present, use the default schema or handle accordingly
|
||||||
return current_value
|
return current_value
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
|
payload = jwt.decode(
|
||||||
tenant_id = payload.get("tenant_id")
|
token,
|
||||||
if not tenant_id:
|
USER_AUTH_SECRET,
|
||||||
return CURRENT_TENANT_ID_CONTEXTVAR.get()
|
audience=["fastapi-users:auth"],
|
||||||
|
algorithms=["HS256"],
|
||||||
|
)
|
||||||
|
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||||
if not is_valid_schema_name(tenant_id):
|
if not is_valid_schema_name(tenant_id):
|
||||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||||
|
@ -9,7 +9,7 @@ from fastapi import Request
|
|||||||
from fastapi import Response
|
from fastapi import Response
|
||||||
|
|
||||||
from danswer.configs.app_configs import MULTI_TENANT
|
from danswer.configs.app_configs import MULTI_TENANT
|
||||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||||
from danswer.db.engine import is_valid_schema_name
|
from danswer.db.engine import is_valid_schema_name
|
||||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||||
@ -25,11 +25,15 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
|||||||
if not MULTI_TENANT:
|
if not MULTI_TENANT:
|
||||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||||
else:
|
else:
|
||||||
token = request.cookies.get("tenant_details")
|
token = request.cookies.get("fastapiusersauth")
|
||||||
|
|
||||||
if token:
|
if token:
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token, SECRET_JWT_KEY, algorithms=["HS256"]
|
token,
|
||||||
|
USER_AUTH_SECRET,
|
||||||
|
audience=["fastapi-users:auth"],
|
||||||
|
algorithms=["HS256"],
|
||||||
)
|
)
|
||||||
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||||
if not is_valid_schema_name(tenant_id):
|
if not is_valid_schema_name(tenant_id):
|
||||||
|
@ -70,18 +70,14 @@ class UserManager:
|
|||||||
|
|
||||||
cookies = response.cookies.get_dict()
|
cookies = response.cookies.get_dict()
|
||||||
session_cookie = cookies.get("fastapiusersauth")
|
session_cookie = cookies.get("fastapiusersauth")
|
||||||
tenant_details_cookie = cookies.get("tenant_details")
|
|
||||||
|
|
||||||
if not session_cookie:
|
if not session_cookie:
|
||||||
raise Exception("Failed to login")
|
raise Exception("Failed to login")
|
||||||
|
|
||||||
print(f"Logged in as {test_user.email}")
|
print(f"Logged in as {test_user.email}")
|
||||||
|
|
||||||
# Set both cookies in the headers
|
# Set cookies in the headers
|
||||||
test_user.headers["Cookie"] = (
|
test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
|
||||||
f"fastapiusersauth={session_cookie}; "
|
|
||||||
f"tenant_details={tenant_details_cookie}"
|
|
||||||
)
|
|
||||||
return test_user
|
return test_user
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -14,7 +14,7 @@ export const POST = async (request: NextRequest) => {
|
|||||||
|
|
||||||
// Delete cookies only if cloud is enabled (jwt auth)
|
// Delete cookies only if cloud is enabled (jwt auth)
|
||||||
if (NEXT_PUBLIC_CLOUD_ENABLED) {
|
if (NEXT_PUBLIC_CLOUD_ENABLED) {
|
||||||
const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
|
const cookiesToDelete = ["fastapiusersauth"];
|
||||||
const cookieOptions = {
|
const cookieOptions = {
|
||||||
path: "/",
|
path: "/",
|
||||||
secure: process.env.NODE_ENV === "production",
|
secure: process.env.NODE_ENV === "production",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user