Auth on main (#2878)

* add cloud auth type

* k

* robustified cloud auth type

* k

* minor typing
This commit is contained in:
pablodanswer 2024-10-23 09:46:30 -07:00 committed by GitHub
parent 9105f95d13
commit 5703ea47d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 146 additions and 69 deletions

View File

@ -233,35 +233,60 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0 or user_create.email in get_default_admin_user_emails():
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
return user
tenant_id = (
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
if hasattr(user_create, "role"):
user_count = await get_user_count()
if (
user_count == 0
or user_create.email in get_default_admin_user_emails()
):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
current_tenant_id.reset(token)
return user
async def on_after_login(
self,
@ -320,7 +345,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
self.database = tenant_user_db # type: ignore
oauth_account_dict = {
"oauth_name": oauth_name,

View File

@ -160,6 +160,9 @@ class AuthType(str, Enum):
OIDC = "oidc"
SAML = "saml"
# google auth and basic
CLOUD = "cloud"
class SessionType(str, Enum):
CHAT = "Chat"

View File

@ -269,7 +269,7 @@ def get_application() -> FastAPI:
# Server logs this during auth setup verification step
pass
elif AUTH_TYPE == AuthType.BASIC:
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_auth_router(auth_backend),
@ -301,7 +301,7 @@ def get_application() -> FastAPI:
tags=["users"],
)
elif AUTH_TYPE == AuthType.GOOGLE_OAUTH:
if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,

View File

@ -9,7 +9,7 @@ export function SignInButton({
authType: AuthType;
}) {
let button;
if (authType === "google_oauth") {
if (authType === "google_oauth" || authType === "cloud") {
button = (
<div className="mx-auto flex">
<div className="my-auto mr-2">
@ -42,7 +42,7 @@ export function SignInButton({
return (
<a
className="mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
className="mx-auto mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
href={authorizeUrl}
>
{button}

View File

@ -78,7 +78,7 @@ const Page = async ({
<HealthCheckBanner />
</div>
<div>
<div className="flex flex-col w-full justify-center">
{authUrl && authTypeMetadata && (
<>
<h2 className="text-center text-xl text-strong font-bold">
@ -92,6 +92,26 @@ const Page = async ({
</>
)}
{authTypeMetadata?.authType === "cloud" && (
<div className="mt-4 w-full justify-center">
<div className="flex items-center w-full my-4">
<div className="flex-grow border-t border-gray-300"></div>
<span className="px-4 text-gray-500">or</span>
<div className="flex-grow border-t border-gray-300"></div>
</div>
<EmailPasswordForm shouldVerify={true} />
<div className="flex">
<Text className="mt-4 mx-auto">
Don&apos;t have an account?{" "}
<Link href="/auth/signup" className="text-link font-medium">
Create an account
</Link>
</Text>
</div>
</div>
)}
{authTypeMetadata?.authType === "basic" && (
<Card className="mt-4 w-96">
<div className="flex">

View File

@ -4,6 +4,7 @@ import {
getCurrentUserSS,
getAuthTypeMetadataSS,
AuthTypeMetadata,
getAuthUrlSS,
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { EmailPasswordForm } from "../login/EmailPasswordForm";
@ -11,6 +12,8 @@ import { Card, Title, Text } from "@tremor/react";
import Link from "next/link";
import { Logo } from "@/components/Logo";
import { CLOUD_ENABLED } from "@/lib/constants";
import { SignInButton } from "../login/SignInButton";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
const Page = async () => {
// catch cases where the backend is completely unreachable here
@ -26,9 +29,6 @@ const Page = async () => {
} catch (e) {
console.log(`Some fetch failed for the login page - ${e}`);
}
if (CLOUD_ENABLED) {
return redirect("/auth/login");
}
// simply take the user to the home page if Auth is disabled
if (authTypeMetadata?.authType === "disabled") {
@ -42,44 +42,56 @@ const Page = async () => {
}
return redirect("/auth/waiting-on-verification");
}
const cloud = authTypeMetadata?.authType === "cloud";
// only enable this page if basic login is enabled
if (authTypeMetadata?.authType !== "basic") {
if (authTypeMetadata?.authType !== "basic" && !cloud) {
return redirect("/");
}
let authUrl: string | null = null;
if (cloud && authTypeMetadata) {
authUrl = await getAuthUrlSS(authTypeMetadata.authType, null);
}
return (
<main>
<div className="absolute top-10x w-full">
<HealthCheckBanner />
</div>
<div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
<div>
<Logo height={64} width={64} className="mx-auto w-fit" />
<AuthFlowContainer>
<HealthCheckBanner />
<Card className="mt-4 w-96">
<div className="flex">
<Title className="mb-2 mx-auto font-bold">
Sign Up for Danswer
</Title>
</div>
<EmailPasswordForm
isSignup
shouldVerify={authTypeMetadata?.requiresVerification}
/>
<>
<div className="absolute top-10x w-full"></div>
<div className="flex w-full flex-col justify-center">
<h2 className="text-center text-xl text-strong font-bold">
{cloud ? "Complete your sign up" : "Sign Up for Danswer"}
</h2>
<div className="flex">
<Text className="mt-4 mx-auto">
Already have an account?{" "}
<Link href="/auth/login" className="text-link font-medium">
Log In
</Link>
</Text>
{cloud && authUrl && (
<div className="w-full justify-center">
<SignInButton authorizeUrl={authUrl} authType="cloud" />
<div className="flex items-center w-full my-4">
<div className="flex-grow border-t border-background-300"></div>
<span className="px-4 text-gray-500">or</span>
<div className="flex-grow border-t border-background-300"></div>
</div>
</div>
</Card>
)}
<EmailPasswordForm
isSignup
shouldVerify={authTypeMetadata?.requiresVerification}
/>
<div className="flex">
<Text className="mt-4 mx-auto">
Already have an account?{" "}
<Link href="/auth/login" className="text-link font-medium">
Log In
</Link>
</Text>
</div>
</div>
</div>
</main>
</>
</AuthFlowContainer>
);
};

View File

@ -7,7 +7,7 @@ export default function AuthFlowContainer({
}) {
return (
<div className="flex flex-col items-center justify-center min-h-screen bg-background">
<div className="w-full max-w-md p-8 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
<div className="w-full max-w-md bg-black p-8 mx-4 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
<Logo width={70} height={70} />
{children}
</div>

View File

@ -1,4 +1,10 @@
export type AuthType = "disabled" | "basic" | "google_oauth" | "oidc" | "saml";
export type AuthType =
| "disabled"
| "basic"
| "google_oauth"
| "oidc"
| "saml"
| "cloud";
export const HOST_URL = process.env.WEB_DOMAIN || "http://127.0.0.1:3000";
export const HEADER_HEIGHT = "h-16";

View File

@ -2,7 +2,7 @@ import { cookies } from "next/headers";
import { User } from "./types";
import { buildUrl } from "./utilsSS";
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
import { AuthType } from "./constants";
import { AuthType, SERVER_SIDE_ONLY__CLOUD_ENABLED } from "./constants";
export interface AuthTypeMetadata {
authType: AuthType;
@ -18,7 +18,15 @@ export const getAuthTypeMetadataSS = async (): Promise<AuthTypeMetadata> => {
const data: { auth_type: string; requires_verification: boolean } =
await res.json();
const authType = data.auth_type as AuthType;
let authType: AuthType;
// Override fasapi users auth so we can use both
if (SERVER_SIDE_ONLY__CLOUD_ENABLED) {
authType = "cloud";
} else {
authType = data.auth_type as AuthType;
}
// for SAML / OIDC, we auto-redirect the user to the IdP when the user visits
// Danswer in an un-authenticated state
@ -87,6 +95,9 @@ export const getAuthUrlSS = async (
case "google_oauth": {
return await getGoogleOAuthUrlSS();
}
case "cloud": {
return await getGoogleOAuthUrlSS();
}
case "saml": {
return await getSAMLAuthUrlSS();
}