mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-12 14:12:53 +02:00
robustified cloud auth type
This commit is contained in:
@ -227,41 +227,70 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
|
||||
# async def register(
|
||||
# self,
|
||||
# user_create: schemas.UC | UserCreate,
|
||||
# safe: bool = False,
|
||||
# request: Optional[Request] = None,
|
||||
# ) -> User:
|
||||
# return await super().register(user_create, safe, request)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
user_create: schemas.UC | UserCreate,
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0 or user_create.email in get_default_admin_user_emails():
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
return user
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
current_tenant_id.set(tenant_id)
|
||||
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||
self.user_db = tenant_user_db
|
||||
self.database = tenant_user_db
|
||||
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if (
|
||||
user_count == 0
|
||||
or user_create.email in get_default_admin_user_emails()
|
||||
):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
return user
|
||||
|
||||
async def on_after_login(
|
||||
self,
|
||||
|
@ -103,20 +103,13 @@ const Page = async ({
|
||||
</div>
|
||||
<EmailPasswordForm shouldVerify={true} />
|
||||
|
||||
<div className="mt-6 w-full text-center">
|
||||
<div className="flex items-center justify-center space-x-4">
|
||||
<div className="flex-grow border-t border-gray-300"></div>
|
||||
<span className="px-4 text-sm text-gray-500 font-medium">
|
||||
First time?
|
||||
</span>
|
||||
<div className="flex-grow border-t border-gray-300"></div>
|
||||
</div>
|
||||
<Link
|
||||
href="/auth/signup"
|
||||
className="inline-block mt-3 text-blue-600 hover:text-blue-800 transition-colors duration-200 font-semibold"
|
||||
>
|
||||
Verify your email
|
||||
</Link>
|
||||
<div className="flex">
|
||||
<Text className="mt-4 mx-auto">
|
||||
Don't have an account?{" "}
|
||||
<Link href="/auth/signup" className="text-link font-medium">
|
||||
Create an account
|
||||
</Link>
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
@ -4,6 +4,7 @@ import {
|
||||
getCurrentUserSS,
|
||||
getAuthTypeMetadataSS,
|
||||
AuthTypeMetadata,
|
||||
getAuthUrlSS,
|
||||
} from "@/lib/userSS";
|
||||
import { redirect } from "next/navigation";
|
||||
import { EmailPasswordForm } from "../login/EmailPasswordForm";
|
||||
@ -11,6 +12,8 @@ import { Card, Title, Text } from "@tremor/react";
|
||||
import Link from "next/link";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { SignInButton } from "../login/SignInButton";
|
||||
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
|
||||
|
||||
const Page = async () => {
|
||||
// catch cases where the backend is completely unreachable here
|
||||
@ -26,9 +29,6 @@ const Page = async () => {
|
||||
} catch (e) {
|
||||
console.log(`Some fetch failed for the login page - ${e}`);
|
||||
}
|
||||
if (CLOUD_ENABLED) {
|
||||
return redirect("/auth/login");
|
||||
}
|
||||
|
||||
// simply take the user to the home page if Auth is disabled
|
||||
if (authTypeMetadata?.authType === "disabled") {
|
||||
@ -44,42 +44,60 @@ const Page = async () => {
|
||||
}
|
||||
|
||||
// only enable this page if basic login is enabled
|
||||
if (authTypeMetadata?.authType !== "basic") {
|
||||
if (
|
||||
authTypeMetadata?.authType !== "basic" &&
|
||||
authTypeMetadata?.authType !== "cloud"
|
||||
) {
|
||||
return redirect("/");
|
||||
}
|
||||
const cloud = authTypeMetadata?.authType === "cloud";
|
||||
|
||||
let authUrl: string | null = null;
|
||||
if (cloud) {
|
||||
authUrl = await getAuthUrlSS(authTypeMetadata.authType, null);
|
||||
}
|
||||
|
||||
return (
|
||||
<main>
|
||||
<div className="absolute top-10x w-full">
|
||||
<HealthCheckBanner />
|
||||
</div>
|
||||
<div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
|
||||
<div>
|
||||
<Logo height={64} width={64} className="mx-auto w-fit" />
|
||||
<AuthFlowContainer>
|
||||
<HealthCheckBanner />
|
||||
|
||||
<Card className="mt-4 w-96">
|
||||
<div className="flex">
|
||||
<Title className="mb-2 mx-auto font-bold">
|
||||
Sign Up for Danswer
|
||||
</Title>
|
||||
</div>
|
||||
<EmailPasswordForm
|
||||
isSignup
|
||||
shouldVerify={authTypeMetadata?.requiresVerification}
|
||||
/>
|
||||
<>
|
||||
<div className="absolute top-10x w-full"></div>
|
||||
<div className="flex flex-col justify-center">
|
||||
<h2 className="text-center text-xl text-strong font-bold">
|
||||
{cloud ? "Complete your sign up" : "Sign Up for Danswer"}
|
||||
</h2>
|
||||
|
||||
<div className="flex">
|
||||
<Text className="mt-4 mx-auto">
|
||||
Already have an account?{" "}
|
||||
<Link href="/auth/login" className="text-link font-medium">
|
||||
Log In
|
||||
</Link>
|
||||
</Text>
|
||||
{cloud && authUrl && (
|
||||
<div className="w-full justify-center">
|
||||
<SignInButton
|
||||
authorizeUrl={authUrl}
|
||||
authType={authTypeMetadata?.authType}
|
||||
/>
|
||||
<div className="flex items-center w-full my-4">
|
||||
<div className="flex-grow border-t border-gray-300"></div>
|
||||
<span className="px-4 text-gray-500">or</span>
|
||||
<div className="flex-grow border-t border-gray-300"></div>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
<EmailPasswordForm
|
||||
isSignup
|
||||
shouldVerify={authTypeMetadata?.requiresVerification}
|
||||
/>
|
||||
|
||||
<div className="flex">
|
||||
<Text className="mt-4 mx-auto">
|
||||
Already have an account?{" "}
|
||||
<Link href="/auth/login" className="text-link font-medium">
|
||||
Log In
|
||||
</Link>
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
</>
|
||||
</AuthFlowContainer>
|
||||
);
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user