mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-21 18:43:30 +02:00
robustified cloud auth type
This commit is contained in:
@ -227,17 +227,46 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
reset_password_token_secret = USER_AUTH_SECRET
|
reset_password_token_secret = USER_AUTH_SECRET
|
||||||
verification_token_secret = USER_AUTH_SECRET
|
verification_token_secret = USER_AUTH_SECRET
|
||||||
|
|
||||||
|
# async def register(
|
||||||
|
# self,
|
||||||
|
# user_create: schemas.UC | UserCreate,
|
||||||
|
# safe: bool = False,
|
||||||
|
# request: Optional[Request] = None,
|
||||||
|
# ) -> User:
|
||||||
|
# return await super().register(user_create, safe, request)
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self,
|
self,
|
||||||
user_create: schemas.UC | UserCreate,
|
user_create: schemas.UC | UserCreate,
|
||||||
safe: bool = False,
|
safe: bool = False,
|
||||||
request: Optional[Request] = None,
|
request: Optional[Request] = None,
|
||||||
) -> User:
|
) -> User:
|
||||||
|
try:
|
||||||
|
tenant_id = (
|
||||||
|
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
|
||||||
|
)
|
||||||
|
except exceptions.UserNotExists:
|
||||||
|
raise HTTPException(status_code=401, detail="User not found")
|
||||||
|
|
||||||
|
if not tenant_id:
|
||||||
|
raise HTTPException(status_code=401, detail="User not found")
|
||||||
|
|
||||||
|
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||||
|
current_tenant_id.set(tenant_id)
|
||||||
|
|
||||||
verify_email_is_invited(user_create.email)
|
verify_email_is_invited(user_create.email)
|
||||||
verify_email_domain(user_create.email)
|
verify_email_domain(user_create.email)
|
||||||
|
if MULTI_TENANT:
|
||||||
|
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||||
|
self.user_db = tenant_user_db
|
||||||
|
self.database = tenant_user_db
|
||||||
|
|
||||||
if hasattr(user_create, "role"):
|
if hasattr(user_create, "role"):
|
||||||
user_count = await get_user_count()
|
user_count = await get_user_count()
|
||||||
if user_count == 0 or user_create.email in get_default_admin_user_emails():
|
if (
|
||||||
|
user_count == 0
|
||||||
|
or user_create.email in get_default_admin_user_emails()
|
||||||
|
):
|
||||||
user_create.role = UserRole.ADMIN
|
user_create.role = UserRole.ADMIN
|
||||||
else:
|
else:
|
||||||
user_create.role = UserRole.BASIC
|
user_create.role = UserRole.BASIC
|
||||||
|
@ -103,20 +103,13 @@ const Page = async ({
|
|||||||
</div>
|
</div>
|
||||||
<EmailPasswordForm shouldVerify={true} />
|
<EmailPasswordForm shouldVerify={true} />
|
||||||
|
|
||||||
<div className="mt-6 w-full text-center">
|
<div className="flex">
|
||||||
<div className="flex items-center justify-center space-x-4">
|
<Text className="mt-4 mx-auto">
|
||||||
<div className="flex-grow border-t border-gray-300"></div>
|
Don't have an account?{" "}
|
||||||
<span className="px-4 text-sm text-gray-500 font-medium">
|
<Link href="/auth/signup" className="text-link font-medium">
|
||||||
First time?
|
Create an account
|
||||||
</span>
|
|
||||||
<div className="flex-grow border-t border-gray-300"></div>
|
|
||||||
</div>
|
|
||||||
<Link
|
|
||||||
href="/auth/signup"
|
|
||||||
className="inline-block mt-3 text-blue-600 hover:text-blue-800 transition-colors duration-200 font-semibold"
|
|
||||||
>
|
|
||||||
Verify your email
|
|
||||||
</Link>
|
</Link>
|
||||||
|
</Text>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
@ -4,6 +4,7 @@ import {
|
|||||||
getCurrentUserSS,
|
getCurrentUserSS,
|
||||||
getAuthTypeMetadataSS,
|
getAuthTypeMetadataSS,
|
||||||
AuthTypeMetadata,
|
AuthTypeMetadata,
|
||||||
|
getAuthUrlSS,
|
||||||
} from "@/lib/userSS";
|
} from "@/lib/userSS";
|
||||||
import { redirect } from "next/navigation";
|
import { redirect } from "next/navigation";
|
||||||
import { EmailPasswordForm } from "../login/EmailPasswordForm";
|
import { EmailPasswordForm } from "../login/EmailPasswordForm";
|
||||||
@ -11,6 +12,8 @@ import { Card, Title, Text } from "@tremor/react";
|
|||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { Logo } from "@/components/Logo";
|
import { Logo } from "@/components/Logo";
|
||||||
import { CLOUD_ENABLED } from "@/lib/constants";
|
import { CLOUD_ENABLED } from "@/lib/constants";
|
||||||
|
import { SignInButton } from "../login/SignInButton";
|
||||||
|
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
|
||||||
|
|
||||||
const Page = async () => {
|
const Page = async () => {
|
||||||
// catch cases where the backend is completely unreachable here
|
// catch cases where the backend is completely unreachable here
|
||||||
@ -26,9 +29,6 @@ const Page = async () => {
|
|||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.log(`Some fetch failed for the login page - ${e}`);
|
console.log(`Some fetch failed for the login page - ${e}`);
|
||||||
}
|
}
|
||||||
if (CLOUD_ENABLED) {
|
|
||||||
return redirect("/auth/login");
|
|
||||||
}
|
|
||||||
|
|
||||||
// simply take the user to the home page if Auth is disabled
|
// simply take the user to the home page if Auth is disabled
|
||||||
if (authTypeMetadata?.authType === "disabled") {
|
if (authTypeMetadata?.authType === "disabled") {
|
||||||
@ -44,25 +44,44 @@ const Page = async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// only enable this page if basic login is enabled
|
// only enable this page if basic login is enabled
|
||||||
if (authTypeMetadata?.authType !== "basic") {
|
if (
|
||||||
|
authTypeMetadata?.authType !== "basic" &&
|
||||||
|
authTypeMetadata?.authType !== "cloud"
|
||||||
|
) {
|
||||||
return redirect("/");
|
return redirect("/");
|
||||||
}
|
}
|
||||||
|
const cloud = authTypeMetadata?.authType === "cloud";
|
||||||
|
|
||||||
|
let authUrl: string | null = null;
|
||||||
|
if (cloud) {
|
||||||
|
authUrl = await getAuthUrlSS(authTypeMetadata.authType, null);
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<main>
|
<AuthFlowContainer>
|
||||||
<div className="absolute top-10x w-full">
|
|
||||||
<HealthCheckBanner />
|
<HealthCheckBanner />
|
||||||
</div>
|
|
||||||
<div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
|
|
||||||
<div>
|
|
||||||
<Logo height={64} width={64} className="mx-auto w-fit" />
|
|
||||||
|
|
||||||
<Card className="mt-4 w-96">
|
<>
|
||||||
<div className="flex">
|
<div className="absolute top-10x w-full"></div>
|
||||||
<Title className="mb-2 mx-auto font-bold">
|
<div className="flex flex-col justify-center">
|
||||||
Sign Up for Danswer
|
<h2 className="text-center text-xl text-strong font-bold">
|
||||||
</Title>
|
{cloud ? "Complete your sign up" : "Sign Up for Danswer"}
|
||||||
|
</h2>
|
||||||
|
|
||||||
|
{cloud && authUrl && (
|
||||||
|
<div className="w-full justify-center">
|
||||||
|
<SignInButton
|
||||||
|
authorizeUrl={authUrl}
|
||||||
|
authType={authTypeMetadata?.authType}
|
||||||
|
/>
|
||||||
|
<div className="flex items-center w-full my-4">
|
||||||
|
<div className="flex-grow border-t border-gray-300"></div>
|
||||||
|
<span className="px-4 text-gray-500">or</span>
|
||||||
|
<div className="flex-grow border-t border-gray-300"></div>
|
||||||
</div>
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
<EmailPasswordForm
|
<EmailPasswordForm
|
||||||
isSignup
|
isSignup
|
||||||
shouldVerify={authTypeMetadata?.requiresVerification}
|
shouldVerify={authTypeMetadata?.requiresVerification}
|
||||||
@ -76,10 +95,9 @@ const Page = async () => {
|
|||||||
</Link>
|
</Link>
|
||||||
</Text>
|
</Text>
|
||||||
</div>
|
</div>
|
||||||
</Card>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</>
|
||||||
</main>
|
</AuthFlowContainer>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user