mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-25 07:21:00 +02:00
Improve multi tenant anonymous user interaction (#3857)
* cleaner handling * k * k * address nits * fix typing
This commit is contained in:
parent
a1cef389aa
commit
dc18d53133
@ -44,7 +44,7 @@ async def _get_tenant_id_from_request(
|
|||||||
Attempt to extract tenant_id from:
|
Attempt to extract tenant_id from:
|
||||||
1) The API key header
|
1) The API key header
|
||||||
2) The Redis-based token (stored in Cookie: fastapiusersauth)
|
2) The Redis-based token (stored in Cookie: fastapiusersauth)
|
||||||
3) Reset token cookie
|
3) The anonymous user cookie
|
||||||
Fallback: POSTGRES_DEFAULT_SCHEMA
|
Fallback: POSTGRES_DEFAULT_SCHEMA
|
||||||
"""
|
"""
|
||||||
# Check for API key
|
# Check for API key
|
||||||
@ -52,42 +52,56 @@ async def _get_tenant_id_from_request(
|
|||||||
if tenant_id is not None:
|
if tenant_id is not None:
|
||||||
return tenant_id
|
return tenant_id
|
||||||
|
|
||||||
# Check for anonymous user cookie
|
|
||||||
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
|
|
||||||
if anonymous_user_cookie:
|
|
||||||
try:
|
|
||||||
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
|
|
||||||
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
|
|
||||||
# Continue and attempt to authenticate
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Look up token data in Redis
|
# Look up token data in Redis
|
||||||
|
|
||||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||||
|
|
||||||
if not token_data:
|
if token_data:
|
||||||
|
tenant_id_from_payload = token_data.get(
|
||||||
|
"tenant_id", POSTGRES_DEFAULT_SCHEMA
|
||||||
|
)
|
||||||
|
|
||||||
|
tenant_id = (
|
||||||
|
str(tenant_id_from_payload)
|
||||||
|
if tenant_id_from_payload is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if tenant_id and not is_valid_schema_name(tenant_id):
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||||
|
|
||||||
|
# Check for anonymous user cookie
|
||||||
|
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
|
||||||
|
if anonymous_user_cookie:
|
||||||
|
try:
|
||||||
|
anonymous_user_data = decode_anonymous_user_jwt_token(
|
||||||
|
anonymous_user_cookie
|
||||||
|
)
|
||||||
|
tenant_id = anonymous_user_data.get(
|
||||||
|
"tenant_id", POSTGRES_DEFAULT_SCHEMA
|
||||||
|
)
|
||||||
|
|
||||||
|
if not tenant_id or not is_valid_schema_name(tenant_id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Invalid tenant ID format"
|
||||||
|
)
|
||||||
|
|
||||||
|
return tenant_id
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
|
||||||
|
# Continue and attempt to authenticate
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||||
return POSTGRES_DEFAULT_SCHEMA
|
return POSTGRES_DEFAULT_SCHEMA
|
||||||
|
|
||||||
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
|
||||||
|
|
||||||
# Since token_data.get() can return None, ensure we have a string
|
|
||||||
tenant_id = (
|
|
||||||
str(tenant_id_from_payload)
|
|
||||||
if tenant_id_from_payload is not None
|
|
||||||
else POSTGRES_DEFAULT_SCHEMA
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_valid_schema_name(tenant_id):
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
@ -56,6 +56,7 @@ from httpx_oauth.oauth2 import OAuth2Token
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||||
from onyx.auth.email_utils import send_forgot_password_email
|
from onyx.auth.email_utils import send_forgot_password_email
|
||||||
from onyx.auth.email_utils import send_user_verification_email
|
from onyx.auth.email_utils import send_user_verification_email
|
||||||
@ -363,6 +364,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
async def on_after_login(
|
||||||
|
self,
|
||||||
|
user: User,
|
||||||
|
request: Optional[Request] = None,
|
||||||
|
response: Optional[Response] = None,
|
||||||
|
) -> None:
|
||||||
|
if response:
|
||||||
|
response.delete_cookie(ANONYMOUS_USER_COOKIE_NAME)
|
||||||
|
|
||||||
async def oauth_callback(
|
async def oauth_callback(
|
||||||
self,
|
self,
|
||||||
oauth_name: str,
|
oauth_name: str,
|
||||||
|
@ -14,7 +14,7 @@ export default function LoginPage({
|
|||||||
authTypeMetadata,
|
authTypeMetadata,
|
||||||
nextUrl,
|
nextUrl,
|
||||||
searchParams,
|
searchParams,
|
||||||
showPageRedirect,
|
hidePageRedirect,
|
||||||
}: {
|
}: {
|
||||||
authUrl: string | null;
|
authUrl: string | null;
|
||||||
authTypeMetadata: AuthTypeMetadata | null;
|
authTypeMetadata: AuthTypeMetadata | null;
|
||||||
@ -24,7 +24,7 @@ export default function LoginPage({
|
|||||||
[key: string]: string | string[] | undefined;
|
[key: string]: string | string[] | undefined;
|
||||||
}
|
}
|
||||||
| undefined;
|
| undefined;
|
||||||
showPageRedirect?: boolean;
|
hidePageRedirect?: boolean;
|
||||||
}) {
|
}) {
|
||||||
useSendAuthRequiredMessage();
|
useSendAuthRequiredMessage();
|
||||||
return (
|
return (
|
||||||
@ -75,7 +75,7 @@ export default function LoginPage({
|
|||||||
<div className="flex flex-col gap-y-2 items-center"></div>
|
<div className="flex flex-col gap-y-2 items-center"></div>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
{showPageRedirect && (
|
{!hidePageRedirect && (
|
||||||
<p className="text-center mt-4">
|
<p className="text-center mt-4">
|
||||||
Don't have an account?{" "}
|
Don't have an account?{" "}
|
||||||
<span
|
<span
|
||||||
|
@ -72,6 +72,7 @@ const Page = async (props: {
|
|||||||
authTypeMetadata={authTypeMetadata}
|
authTypeMetadata={authTypeMetadata}
|
||||||
nextUrl={nextUrl!}
|
nextUrl={nextUrl!}
|
||||||
searchParams={searchParams}
|
searchParams={searchParams}
|
||||||
|
hidePageRedirect={true}
|
||||||
/>
|
/>
|
||||||
</AuthFlowContainer>
|
</AuthFlowContainer>
|
||||||
</div>
|
</div>
|
||||||
|
@ -347,7 +347,6 @@ export default function NRFPage({
|
|||||||
<p className="p-4">Loading login info…</p>
|
<p className="p-4">Loading login info…</p>
|
||||||
) : authType == "basic" ? (
|
) : authType == "basic" ? (
|
||||||
<LoginPage
|
<LoginPage
|
||||||
showPageRedirect
|
|
||||||
authUrl={null}
|
authUrl={null}
|
||||||
authTypeMetadata={{
|
authTypeMetadata={{
|
||||||
authType: authType as AuthType,
|
authType: authType as AuthType,
|
||||||
|
@ -55,7 +55,7 @@ export async function generateMetadata(): Promise<Metadata> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
title: enterpriseSettings?.application_name ?? "Onyx",
|
title: enterpriseSettings?.application_name || "Onyx",
|
||||||
description: "Question answering for your documents",
|
description: "Question answering for your documents",
|
||||||
icons: {
|
icons: {
|
||||||
icon: logoLocation,
|
icon: logoLocation,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user