mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-07 10:20:32 +02:00
Auth fix + Registration Clarity (#3590)
* clarify auth flow * k * nit * k * fix typing
This commit is contained in:
parent
e100a5e965
commit
c8090ab75b
@ -46,6 +46,7 @@ def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscr
|
|||||||
"""
|
"""
|
||||||
Send a request to the control service to register the number of users for a tenant.
|
Send a request to the control service to register the number of users for a tenant.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not STRIPE_PRICE_ID:
|
if not STRIPE_PRICE_ID:
|
||||||
raise Exception("STRIPE_PRICE_ID is not set")
|
raise Exception("STRIPE_PRICE_ID is not set")
|
||||||
|
|
||||||
|
@ -40,21 +40,24 @@ def send_email(
|
|||||||
|
|
||||||
|
|
||||||
def send_user_email_invite(user_email: str, current_user: User) -> None:
|
def send_user_email_invite(user_email: str, current_user: User) -> None:
|
||||||
subject = "Invitation to Join Onyx Workspace"
|
subject = "Invitation to Join Onyx Organization"
|
||||||
body = dedent(
|
body = dedent(
|
||||||
f"""\
|
f"""\
|
||||||
Hello,
|
Hello,
|
||||||
|
|
||||||
You have been invited to join a workspace on Onyx.
|
You have been invited to join an organization on Onyx.
|
||||||
|
|
||||||
To join the workspace, please visit the following link:
|
To join the organization, please visit the following link:
|
||||||
|
|
||||||
{WEB_DOMAIN}/auth/login
|
{WEB_DOMAIN}/auth/signup?email={user_email}
|
||||||
|
|
||||||
|
You'll be asked to set a password or login with Google to complete your registration.
|
||||||
|
|
||||||
Best regards,
|
Best regards,
|
||||||
The Onyx Team
|
The Onyx Team
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
send_email(user_email, subject, body, current_user.email)
|
send_email(user_email, subject, body, current_user.email)
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,7 +46,6 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
|||||||
from httpx_oauth.oauth2 import BaseOAuth2
|
from httpx_oauth.oauth2 import BaseOAuth2
|
||||||
from httpx_oauth.oauth2 import OAuth2Token
|
from httpx_oauth.oauth2 import OAuth2Token
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||||
@ -396,11 +395,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
|
|
||||||
# Explicitly set the Postgres schema for this session to ensure
|
# Explicitly set the Postgres schema for this session to ensure
|
||||||
# OAuth account creation happens in the correct tenant schema
|
# OAuth account creation happens in the correct tenant schema
|
||||||
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
|
||||||
|
|
||||||
# Add OAuth account
|
# Add OAuth account
|
||||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||||
|
|
||||||
await self.on_after_register(user, request)
|
await self.on_after_register(user, request)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -419,7 +416,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
|
|
||||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||||
# re-authenticate that frequently, so by default this is disabled
|
# re-authenticate that frequently, so by default this is disabled
|
||||||
|
|
||||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||||
await self.user_db.update(
|
await self.user_db.update(
|
||||||
|
@ -370,9 +370,23 @@ async def get_async_session_with_tenant(
|
|||||||
bind=engine, expire_on_commit=False, class_=AsyncSession
|
bind=engine, expire_on_commit=False, class_=AsyncSession
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
async def _set_search_path(session: AsyncSession, tenant_id: str) -> None:
|
||||||
|
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||||
|
|
||||||
async with async_session_factory() as session:
|
async with async_session_factory() as session:
|
||||||
|
# Register an event listener that is called whenever a new transaction starts
|
||||||
|
@event.listens_for(session.sync_session, "after_begin")
|
||||||
|
def after_begin(session_: Any, transaction: Any, connection: Any) -> None:
|
||||||
|
# Because the event is sync, we can't directly await here.
|
||||||
|
# Instead we queue up an asyncio task to ensures
|
||||||
|
# the next statement sets the search_path
|
||||||
|
session_.do_orm_execute = lambda state: connection.exec_driver_sql(
|
||||||
|
f'SET search_path = "{tenant_id}"'
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
await _set_search_path(session, tenant_id)
|
||||||
|
|
||||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||||
await session.execute(
|
await session.execute(
|
||||||
text(
|
text(
|
||||||
|
@ -19,11 +19,13 @@ export function EmailPasswordForm({
|
|||||||
shouldVerify,
|
shouldVerify,
|
||||||
referralSource,
|
referralSource,
|
||||||
nextUrl,
|
nextUrl,
|
||||||
|
defaultEmail,
|
||||||
}: {
|
}: {
|
||||||
isSignup?: boolean;
|
isSignup?: boolean;
|
||||||
shouldVerify?: boolean;
|
shouldVerify?: boolean;
|
||||||
referralSource?: string;
|
referralSource?: string;
|
||||||
nextUrl?: string | null;
|
nextUrl?: string | null;
|
||||||
|
defaultEmail?: string | null;
|
||||||
}) {
|
}) {
|
||||||
const { user } = useUser();
|
const { user } = useUser();
|
||||||
const { popup, setPopup } = usePopup();
|
const { popup, setPopup } = usePopup();
|
||||||
@ -34,7 +36,7 @@ export function EmailPasswordForm({
|
|||||||
{popup}
|
{popup}
|
||||||
<Formik
|
<Formik
|
||||||
initialValues={{
|
initialValues={{
|
||||||
email: "",
|
email: defaultEmail || "",
|
||||||
password: "",
|
password: "",
|
||||||
}}
|
}}
|
||||||
validationSchema={Yup.object().shape({
|
validationSchema={Yup.object().shape({
|
||||||
|
@ -22,6 +22,10 @@ const Page = async (props: {
|
|||||||
? searchParams?.next[0]
|
? searchParams?.next[0]
|
||||||
: searchParams?.next || null;
|
: searchParams?.next || null;
|
||||||
|
|
||||||
|
const defaultEmail = Array.isArray(searchParams?.email)
|
||||||
|
? searchParams?.email[0]
|
||||||
|
: searchParams?.email || null;
|
||||||
|
|
||||||
// catch cases where the backend is completely unreachable here
|
// catch cases where the backend is completely unreachable here
|
||||||
// without try / catch, will just raise an exception and the page
|
// without try / catch, will just raise an exception and the page
|
||||||
// will not render
|
// will not render
|
||||||
@ -93,6 +97,7 @@ const Page = async (props: {
|
|||||||
isSignup
|
isSignup
|
||||||
shouldVerify={authTypeMetadata?.requiresVerification}
|
shouldVerify={authTypeMetadata?.requiresVerification}
|
||||||
nextUrl={nextUrl}
|
nextUrl={nextUrl}
|
||||||
|
defaultEmail={defaultEmail}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user