Auth on main (#2878)

* add cloud auth type

* k

* robustified cloud auth type

* k

* minor typing
This commit is contained in:
pablodanswer
2024-10-23 09:46:30 -07:00
committed by GitHub
parent 9105f95d13
commit 5703ea47d2
9 changed files with 146 additions and 69 deletions

View File

@ -233,35 +233,60 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False, safe: bool = False,
request: Optional[Request] = None, request: Optional[Request] = None,
) -> User: ) -> User:
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0 or user_create.email in get_default_admin_user_emails():
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try: try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore tenant_id = (
except exceptions.UserAlreadyExists: get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
user = await self.get_by_email(user_create.email) )
# Handle case where user has used product outside of web and is now creating an account through web except exceptions.UserNotExists:
if ( raise HTTPException(status_code=401, detail="User not found")
not user.has_web_login
and hasattr(user_create, "has_web_login") if not tenant_id:
and user_create.has_web_login raise HTTPException(
): status_code=401, detail="User does not belong to an organization"
user_update = UserUpdate( )
password=user_create.password,
has_web_login=True, async with get_async_session_with_tenant(tenant_id) as db_session:
role=user_create.role, token = current_tenant_id.set(tenant_id)
is_verified=user_create.is_verified,
) verify_email_is_invited(user_create.email)
user = await self.update(user_update, user) verify_email_domain(user_create.email)
else: if MULTI_TENANT:
raise exceptions.UserAlreadyExists() tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
return user self.user_db = tenant_user_db
self.database = tenant_user_db
if hasattr(user_create, "role"):
user_count = await get_user_count()
if (
user_count == 0
or user_create.email in get_default_admin_user_emails()
):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
current_tenant_id.reset(token)
return user
async def on_after_login( async def on_after_login(
self, self,
@ -320,7 +345,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
if MULTI_TENANT: if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount) tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db self.user_db = tenant_user_db
self.database = tenant_user_db self.database = tenant_user_db # type: ignore
oauth_account_dict = { oauth_account_dict = {
"oauth_name": oauth_name, "oauth_name": oauth_name,

View File

@ -160,6 +160,9 @@ class AuthType(str, Enum):
OIDC = "oidc" OIDC = "oidc"
SAML = "saml" SAML = "saml"
# google auth and basic
CLOUD = "cloud"
class SessionType(str, Enum): class SessionType(str, Enum):
CHAT = "Chat" CHAT = "Chat"

View File

@ -269,7 +269,7 @@ def get_application() -> FastAPI:
# Server logs this during auth setup verification step # Server logs this during auth setup verification step
pass pass
elif AUTH_TYPE == AuthType.BASIC: if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
include_router_with_global_prefix_prepended( include_router_with_global_prefix_prepended(
application, application,
fastapi_users.get_auth_router(auth_backend), fastapi_users.get_auth_router(auth_backend),
@ -301,7 +301,7 @@ def get_application() -> FastAPI:
tags=["users"], tags=["users"],
) )
elif AUTH_TYPE == AuthType.GOOGLE_OAUTH: if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended( include_router_with_global_prefix_prepended(
application, application,

View File

@ -9,7 +9,7 @@ export function SignInButton({
authType: AuthType; authType: AuthType;
}) { }) {
let button; let button;
if (authType === "google_oauth") { if (authType === "google_oauth" || authType === "cloud") {
button = ( button = (
<div className="mx-auto flex"> <div className="mx-auto flex">
<div className="my-auto mr-2"> <div className="my-auto mr-2">
@ -42,7 +42,7 @@ export function SignInButton({
return ( return (
<a <a
className="mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800" className="mx-auto mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
href={authorizeUrl} href={authorizeUrl}
> >
{button} {button}

View File

@ -78,7 +78,7 @@ const Page = async ({
<HealthCheckBanner /> <HealthCheckBanner />
</div> </div>
<div> <div className="flex flex-col w-full justify-center">
{authUrl && authTypeMetadata && ( {authUrl && authTypeMetadata && (
<> <>
<h2 className="text-center text-xl text-strong font-bold"> <h2 className="text-center text-xl text-strong font-bold">
@ -92,6 +92,26 @@ const Page = async ({
</> </>
)} )}
{authTypeMetadata?.authType === "cloud" && (
<div className="mt-4 w-full justify-center">
<div className="flex items-center w-full my-4">
<div className="flex-grow border-t border-gray-300"></div>
<span className="px-4 text-gray-500">or</span>
<div className="flex-grow border-t border-gray-300"></div>
</div>
<EmailPasswordForm shouldVerify={true} />
<div className="flex">
<Text className="mt-4 mx-auto">
Don&apos;t have an account?{" "}
<Link href="/auth/signup" className="text-link font-medium">
Create an account
</Link>
</Text>
</div>
</div>
)}
{authTypeMetadata?.authType === "basic" && ( {authTypeMetadata?.authType === "basic" && (
<Card className="mt-4 w-96"> <Card className="mt-4 w-96">
<div className="flex"> <div className="flex">

View File

@ -4,6 +4,7 @@ import {
getCurrentUserSS, getCurrentUserSS,
getAuthTypeMetadataSS, getAuthTypeMetadataSS,
AuthTypeMetadata, AuthTypeMetadata,
getAuthUrlSS,
} from "@/lib/userSS"; } from "@/lib/userSS";
import { redirect } from "next/navigation"; import { redirect } from "next/navigation";
import { EmailPasswordForm } from "../login/EmailPasswordForm"; import { EmailPasswordForm } from "../login/EmailPasswordForm";
@ -11,6 +12,8 @@ import { Card, Title, Text } from "@tremor/react";
import Link from "next/link"; import Link from "next/link";
import { Logo } from "@/components/Logo"; import { Logo } from "@/components/Logo";
import { CLOUD_ENABLED } from "@/lib/constants"; import { CLOUD_ENABLED } from "@/lib/constants";
import { SignInButton } from "../login/SignInButton";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
const Page = async () => { const Page = async () => {
// catch cases where the backend is completely unreachable here // catch cases where the backend is completely unreachable here
@ -26,9 +29,6 @@ const Page = async () => {
} catch (e) { } catch (e) {
console.log(`Some fetch failed for the login page - ${e}`); console.log(`Some fetch failed for the login page - ${e}`);
} }
if (CLOUD_ENABLED) {
return redirect("/auth/login");
}
// simply take the user to the home page if Auth is disabled // simply take the user to the home page if Auth is disabled
if (authTypeMetadata?.authType === "disabled") { if (authTypeMetadata?.authType === "disabled") {
@ -42,44 +42,56 @@ const Page = async () => {
} }
return redirect("/auth/waiting-on-verification"); return redirect("/auth/waiting-on-verification");
} }
const cloud = authTypeMetadata?.authType === "cloud";
// only enable this page if basic login is enabled // only enable this page if basic login is enabled
if (authTypeMetadata?.authType !== "basic") { if (authTypeMetadata?.authType !== "basic" && !cloud) {
return redirect("/"); return redirect("/");
} }
let authUrl: string | null = null;
if (cloud && authTypeMetadata) {
authUrl = await getAuthUrlSS(authTypeMetadata.authType, null);
}
return ( return (
<main> <AuthFlowContainer>
<div className="absolute top-10x w-full"> <HealthCheckBanner />
<HealthCheckBanner />
</div>
<div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
<div>
<Logo height={64} width={64} className="mx-auto w-fit" />
<Card className="mt-4 w-96"> <>
<div className="flex"> <div className="absolute top-10x w-full"></div>
<Title className="mb-2 mx-auto font-bold"> <div className="flex w-full flex-col justify-center">
Sign Up for Danswer <h2 className="text-center text-xl text-strong font-bold">
</Title> {cloud ? "Complete your sign up" : "Sign Up for Danswer"}
</div> </h2>
<EmailPasswordForm
isSignup
shouldVerify={authTypeMetadata?.requiresVerification}
/>
<div className="flex"> {cloud && authUrl && (
<Text className="mt-4 mx-auto"> <div className="w-full justify-center">
Already have an account?{" "} <SignInButton authorizeUrl={authUrl} authType="cloud" />
<Link href="/auth/login" className="text-link font-medium"> <div className="flex items-center w-full my-4">
Log In <div className="flex-grow border-t border-background-300"></div>
</Link> <span className="px-4 text-gray-500">or</span>
</Text> <div className="flex-grow border-t border-background-300"></div>
</div>
</div> </div>
</Card> )}
<EmailPasswordForm
isSignup
shouldVerify={authTypeMetadata?.requiresVerification}
/>
<div className="flex">
<Text className="mt-4 mx-auto">
Already have an account?{" "}
<Link href="/auth/login" className="text-link font-medium">
Log In
</Link>
</Text>
</div>
</div> </div>
</div> </>
</main> </AuthFlowContainer>
); );
}; };

View File

@ -7,7 +7,7 @@ export default function AuthFlowContainer({
}) { }) {
return ( return (
<div className="flex flex-col items-center justify-center min-h-screen bg-background"> <div className="flex flex-col items-center justify-center min-h-screen bg-background">
<div className="w-full max-w-md p-8 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100"> <div className="w-full max-w-md bg-black p-8 mx-4 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
<Logo width={70} height={70} /> <Logo width={70} height={70} />
{children} {children}
</div> </div>

View File

@ -1,4 +1,10 @@
export type AuthType = "disabled" | "basic" | "google_oauth" | "oidc" | "saml"; export type AuthType =
| "disabled"
| "basic"
| "google_oauth"
| "oidc"
| "saml"
| "cloud";
export const HOST_URL = process.env.WEB_DOMAIN || "http://127.0.0.1:3000"; export const HOST_URL = process.env.WEB_DOMAIN || "http://127.0.0.1:3000";
export const HEADER_HEIGHT = "h-16"; export const HEADER_HEIGHT = "h-16";

View File

@ -2,7 +2,7 @@ import { cookies } from "next/headers";
import { User } from "./types"; import { User } from "./types";
import { buildUrl } from "./utilsSS"; import { buildUrl } from "./utilsSS";
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies"; import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
import { AuthType } from "./constants"; import { AuthType, SERVER_SIDE_ONLY__CLOUD_ENABLED } from "./constants";
export interface AuthTypeMetadata { export interface AuthTypeMetadata {
authType: AuthType; authType: AuthType;
@ -18,7 +18,15 @@ export const getAuthTypeMetadataSS = async (): Promise<AuthTypeMetadata> => {
const data: { auth_type: string; requires_verification: boolean } = const data: { auth_type: string; requires_verification: boolean } =
await res.json(); await res.json();
const authType = data.auth_type as AuthType;
let authType: AuthType;
// Override fasapi users auth so we can use both
if (SERVER_SIDE_ONLY__CLOUD_ENABLED) {
authType = "cloud";
} else {
authType = data.auth_type as AuthType;
}
// for SAML / OIDC, we auto-redirect the user to the IdP when the user visits // for SAML / OIDC, we auto-redirect the user to the IdP when the user visits
// Danswer in an un-authenticated state // Danswer in an un-authenticated state
@ -87,6 +95,9 @@ export const getAuthUrlSS = async (
case "google_oauth": { case "google_oauth": {
return await getGoogleOAuthUrlSS(); return await getGoogleOAuthUrlSS();
} }
case "cloud": {
return await getGoogleOAuthUrlSS();
}
case "saml": { case "saml": {
return await getSAMLAuthUrlSS(); return await getSAMLAuthUrlSS();
} }