mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-04 03:31:23 +02:00
Auth on main (#2878)
* add cloud auth type * k * robustified cloud auth type * k * minor typing
This commit is contained in:
@ -233,35 +233,60 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
safe: bool = False,
|
safe: bool = False,
|
||||||
request: Optional[Request] = None,
|
request: Optional[Request] = None,
|
||||||
) -> User:
|
) -> User:
|
||||||
verify_email_is_invited(user_create.email)
|
|
||||||
verify_email_domain(user_create.email)
|
|
||||||
if hasattr(user_create, "role"):
|
|
||||||
user_count = await get_user_count()
|
|
||||||
if user_count == 0 or user_create.email in get_default_admin_user_emails():
|
|
||||||
user_create.role = UserRole.ADMIN
|
|
||||||
else:
|
|
||||||
user_create.role = UserRole.BASIC
|
|
||||||
user = None
|
|
||||||
try:
|
try:
|
||||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
tenant_id = (
|
||||||
except exceptions.UserAlreadyExists:
|
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
|
||||||
user = await self.get_by_email(user_create.email)
|
)
|
||||||
# Handle case where user has used product outside of web and is now creating an account through web
|
except exceptions.UserNotExists:
|
||||||
if (
|
raise HTTPException(status_code=401, detail="User not found")
|
||||||
not user.has_web_login
|
|
||||||
and hasattr(user_create, "has_web_login")
|
if not tenant_id:
|
||||||
and user_create.has_web_login
|
raise HTTPException(
|
||||||
):
|
status_code=401, detail="User does not belong to an organization"
|
||||||
user_update = UserUpdate(
|
)
|
||||||
password=user_create.password,
|
|
||||||
has_web_login=True,
|
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||||
role=user_create.role,
|
token = current_tenant_id.set(tenant_id)
|
||||||
is_verified=user_create.is_verified,
|
|
||||||
)
|
verify_email_is_invited(user_create.email)
|
||||||
user = await self.update(user_update, user)
|
verify_email_domain(user_create.email)
|
||||||
else:
|
if MULTI_TENANT:
|
||||||
raise exceptions.UserAlreadyExists()
|
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||||
return user
|
self.user_db = tenant_user_db
|
||||||
|
self.database = tenant_user_db
|
||||||
|
|
||||||
|
if hasattr(user_create, "role"):
|
||||||
|
user_count = await get_user_count()
|
||||||
|
if (
|
||||||
|
user_count == 0
|
||||||
|
or user_create.email in get_default_admin_user_emails()
|
||||||
|
):
|
||||||
|
user_create.role = UserRole.ADMIN
|
||||||
|
else:
|
||||||
|
user_create.role = UserRole.BASIC
|
||||||
|
user = None
|
||||||
|
try:
|
||||||
|
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||||
|
except exceptions.UserAlreadyExists:
|
||||||
|
user = await self.get_by_email(user_create.email)
|
||||||
|
# Handle case where user has used product outside of web and is now creating an account through web
|
||||||
|
if (
|
||||||
|
not user.has_web_login
|
||||||
|
and hasattr(user_create, "has_web_login")
|
||||||
|
and user_create.has_web_login
|
||||||
|
):
|
||||||
|
user_update = UserUpdate(
|
||||||
|
password=user_create.password,
|
||||||
|
has_web_login=True,
|
||||||
|
role=user_create.role,
|
||||||
|
is_verified=user_create.is_verified,
|
||||||
|
)
|
||||||
|
user = await self.update(user_update, user)
|
||||||
|
else:
|
||||||
|
raise exceptions.UserAlreadyExists()
|
||||||
|
|
||||||
|
current_tenant_id.reset(token)
|
||||||
|
return user
|
||||||
|
|
||||||
async def on_after_login(
|
async def on_after_login(
|
||||||
self,
|
self,
|
||||||
@ -320,7 +345,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
if MULTI_TENANT:
|
if MULTI_TENANT:
|
||||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||||
self.user_db = tenant_user_db
|
self.user_db = tenant_user_db
|
||||||
self.database = tenant_user_db
|
self.database = tenant_user_db # type: ignore
|
||||||
|
|
||||||
oauth_account_dict = {
|
oauth_account_dict = {
|
||||||
"oauth_name": oauth_name,
|
"oauth_name": oauth_name,
|
||||||
|
@ -160,6 +160,9 @@ class AuthType(str, Enum):
|
|||||||
OIDC = "oidc"
|
OIDC = "oidc"
|
||||||
SAML = "saml"
|
SAML = "saml"
|
||||||
|
|
||||||
|
# google auth and basic
|
||||||
|
CLOUD = "cloud"
|
||||||
|
|
||||||
|
|
||||||
class SessionType(str, Enum):
|
class SessionType(str, Enum):
|
||||||
CHAT = "Chat"
|
CHAT = "Chat"
|
||||||
|
@ -269,7 +269,7 @@ def get_application() -> FastAPI:
|
|||||||
# Server logs this during auth setup verification step
|
# Server logs this during auth setup verification step
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif AUTH_TYPE == AuthType.BASIC:
|
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
|
||||||
include_router_with_global_prefix_prepended(
|
include_router_with_global_prefix_prepended(
|
||||||
application,
|
application,
|
||||||
fastapi_users.get_auth_router(auth_backend),
|
fastapi_users.get_auth_router(auth_backend),
|
||||||
@ -301,7 +301,7 @@ def get_application() -> FastAPI:
|
|||||||
tags=["users"],
|
tags=["users"],
|
||||||
)
|
)
|
||||||
|
|
||||||
elif AUTH_TYPE == AuthType.GOOGLE_OAUTH:
|
if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
|
||||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||||
include_router_with_global_prefix_prepended(
|
include_router_with_global_prefix_prepended(
|
||||||
application,
|
application,
|
||||||
|
@ -9,7 +9,7 @@ export function SignInButton({
|
|||||||
authType: AuthType;
|
authType: AuthType;
|
||||||
}) {
|
}) {
|
||||||
let button;
|
let button;
|
||||||
if (authType === "google_oauth") {
|
if (authType === "google_oauth" || authType === "cloud") {
|
||||||
button = (
|
button = (
|
||||||
<div className="mx-auto flex">
|
<div className="mx-auto flex">
|
||||||
<div className="my-auto mr-2">
|
<div className="my-auto mr-2">
|
||||||
@ -42,7 +42,7 @@ export function SignInButton({
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<a
|
<a
|
||||||
className="mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
|
className="mx-auto mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
|
||||||
href={authorizeUrl}
|
href={authorizeUrl}
|
||||||
>
|
>
|
||||||
{button}
|
{button}
|
||||||
|
@ -78,7 +78,7 @@ const Page = async ({
|
|||||||
<HealthCheckBanner />
|
<HealthCheckBanner />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div>
|
<div className="flex flex-col w-full justify-center">
|
||||||
{authUrl && authTypeMetadata && (
|
{authUrl && authTypeMetadata && (
|
||||||
<>
|
<>
|
||||||
<h2 className="text-center text-xl text-strong font-bold">
|
<h2 className="text-center text-xl text-strong font-bold">
|
||||||
@ -92,6 +92,26 @@ const Page = async ({
|
|||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{authTypeMetadata?.authType === "cloud" && (
|
||||||
|
<div className="mt-4 w-full justify-center">
|
||||||
|
<div className="flex items-center w-full my-4">
|
||||||
|
<div className="flex-grow border-t border-gray-300"></div>
|
||||||
|
<span className="px-4 text-gray-500">or</span>
|
||||||
|
<div className="flex-grow border-t border-gray-300"></div>
|
||||||
|
</div>
|
||||||
|
<EmailPasswordForm shouldVerify={true} />
|
||||||
|
|
||||||
|
<div className="flex">
|
||||||
|
<Text className="mt-4 mx-auto">
|
||||||
|
Don't have an account?{" "}
|
||||||
|
<Link href="/auth/signup" className="text-link font-medium">
|
||||||
|
Create an account
|
||||||
|
</Link>
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{authTypeMetadata?.authType === "basic" && (
|
{authTypeMetadata?.authType === "basic" && (
|
||||||
<Card className="mt-4 w-96">
|
<Card className="mt-4 w-96">
|
||||||
<div className="flex">
|
<div className="flex">
|
||||||
|
@ -4,6 +4,7 @@ import {
|
|||||||
getCurrentUserSS,
|
getCurrentUserSS,
|
||||||
getAuthTypeMetadataSS,
|
getAuthTypeMetadataSS,
|
||||||
AuthTypeMetadata,
|
AuthTypeMetadata,
|
||||||
|
getAuthUrlSS,
|
||||||
} from "@/lib/userSS";
|
} from "@/lib/userSS";
|
||||||
import { redirect } from "next/navigation";
|
import { redirect } from "next/navigation";
|
||||||
import { EmailPasswordForm } from "../login/EmailPasswordForm";
|
import { EmailPasswordForm } from "../login/EmailPasswordForm";
|
||||||
@ -11,6 +12,8 @@ import { Card, Title, Text } from "@tremor/react";
|
|||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { Logo } from "@/components/Logo";
|
import { Logo } from "@/components/Logo";
|
||||||
import { CLOUD_ENABLED } from "@/lib/constants";
|
import { CLOUD_ENABLED } from "@/lib/constants";
|
||||||
|
import { SignInButton } from "../login/SignInButton";
|
||||||
|
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
|
||||||
|
|
||||||
const Page = async () => {
|
const Page = async () => {
|
||||||
// catch cases where the backend is completely unreachable here
|
// catch cases where the backend is completely unreachable here
|
||||||
@ -26,9 +29,6 @@ const Page = async () => {
|
|||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.log(`Some fetch failed for the login page - ${e}`);
|
console.log(`Some fetch failed for the login page - ${e}`);
|
||||||
}
|
}
|
||||||
if (CLOUD_ENABLED) {
|
|
||||||
return redirect("/auth/login");
|
|
||||||
}
|
|
||||||
|
|
||||||
// simply take the user to the home page if Auth is disabled
|
// simply take the user to the home page if Auth is disabled
|
||||||
if (authTypeMetadata?.authType === "disabled") {
|
if (authTypeMetadata?.authType === "disabled") {
|
||||||
@ -42,44 +42,56 @@ const Page = async () => {
|
|||||||
}
|
}
|
||||||
return redirect("/auth/waiting-on-verification");
|
return redirect("/auth/waiting-on-verification");
|
||||||
}
|
}
|
||||||
|
const cloud = authTypeMetadata?.authType === "cloud";
|
||||||
|
|
||||||
// only enable this page if basic login is enabled
|
// only enable this page if basic login is enabled
|
||||||
if (authTypeMetadata?.authType !== "basic") {
|
if (authTypeMetadata?.authType !== "basic" && !cloud) {
|
||||||
return redirect("/");
|
return redirect("/");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let authUrl: string | null = null;
|
||||||
|
if (cloud && authTypeMetadata) {
|
||||||
|
authUrl = await getAuthUrlSS(authTypeMetadata.authType, null);
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<main>
|
<AuthFlowContainer>
|
||||||
<div className="absolute top-10x w-full">
|
<HealthCheckBanner />
|
||||||
<HealthCheckBanner />
|
|
||||||
</div>
|
|
||||||
<div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
|
|
||||||
<div>
|
|
||||||
<Logo height={64} width={64} className="mx-auto w-fit" />
|
|
||||||
|
|
||||||
<Card className="mt-4 w-96">
|
<>
|
||||||
<div className="flex">
|
<div className="absolute top-10x w-full"></div>
|
||||||
<Title className="mb-2 mx-auto font-bold">
|
<div className="flex w-full flex-col justify-center">
|
||||||
Sign Up for Danswer
|
<h2 className="text-center text-xl text-strong font-bold">
|
||||||
</Title>
|
{cloud ? "Complete your sign up" : "Sign Up for Danswer"}
|
||||||
</div>
|
</h2>
|
||||||
<EmailPasswordForm
|
|
||||||
isSignup
|
|
||||||
shouldVerify={authTypeMetadata?.requiresVerification}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<div className="flex">
|
{cloud && authUrl && (
|
||||||
<Text className="mt-4 mx-auto">
|
<div className="w-full justify-center">
|
||||||
Already have an account?{" "}
|
<SignInButton authorizeUrl={authUrl} authType="cloud" />
|
||||||
<Link href="/auth/login" className="text-link font-medium">
|
<div className="flex items-center w-full my-4">
|
||||||
Log In
|
<div className="flex-grow border-t border-background-300"></div>
|
||||||
</Link>
|
<span className="px-4 text-gray-500">or</span>
|
||||||
</Text>
|
<div className="flex-grow border-t border-background-300"></div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</Card>
|
)}
|
||||||
|
|
||||||
|
<EmailPasswordForm
|
||||||
|
isSignup
|
||||||
|
shouldVerify={authTypeMetadata?.requiresVerification}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<div className="flex">
|
||||||
|
<Text className="mt-4 mx-auto">
|
||||||
|
Already have an account?{" "}
|
||||||
|
<Link href="/auth/login" className="text-link font-medium">
|
||||||
|
Log In
|
||||||
|
</Link>
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</>
|
||||||
</main>
|
</AuthFlowContainer>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ export default function AuthFlowContainer({
|
|||||||
}) {
|
}) {
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col items-center justify-center min-h-screen bg-background">
|
<div className="flex flex-col items-center justify-center min-h-screen bg-background">
|
||||||
<div className="w-full max-w-md p-8 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
|
<div className="w-full max-w-md bg-black p-8 mx-4 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
|
||||||
<Logo width={70} height={70} />
|
<Logo width={70} height={70} />
|
||||||
{children}
|
{children}
|
||||||
</div>
|
</div>
|
||||||
|
@ -1,4 +1,10 @@
|
|||||||
export type AuthType = "disabled" | "basic" | "google_oauth" | "oidc" | "saml";
|
export type AuthType =
|
||||||
|
| "disabled"
|
||||||
|
| "basic"
|
||||||
|
| "google_oauth"
|
||||||
|
| "oidc"
|
||||||
|
| "saml"
|
||||||
|
| "cloud";
|
||||||
|
|
||||||
export const HOST_URL = process.env.WEB_DOMAIN || "http://127.0.0.1:3000";
|
export const HOST_URL = process.env.WEB_DOMAIN || "http://127.0.0.1:3000";
|
||||||
export const HEADER_HEIGHT = "h-16";
|
export const HEADER_HEIGHT = "h-16";
|
||||||
|
@ -2,7 +2,7 @@ import { cookies } from "next/headers";
|
|||||||
import { User } from "./types";
|
import { User } from "./types";
|
||||||
import { buildUrl } from "./utilsSS";
|
import { buildUrl } from "./utilsSS";
|
||||||
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
|
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
|
||||||
import { AuthType } from "./constants";
|
import { AuthType, SERVER_SIDE_ONLY__CLOUD_ENABLED } from "./constants";
|
||||||
|
|
||||||
export interface AuthTypeMetadata {
|
export interface AuthTypeMetadata {
|
||||||
authType: AuthType;
|
authType: AuthType;
|
||||||
@ -18,7 +18,15 @@ export const getAuthTypeMetadataSS = async (): Promise<AuthTypeMetadata> => {
|
|||||||
|
|
||||||
const data: { auth_type: string; requires_verification: boolean } =
|
const data: { auth_type: string; requires_verification: boolean } =
|
||||||
await res.json();
|
await res.json();
|
||||||
const authType = data.auth_type as AuthType;
|
|
||||||
|
let authType: AuthType;
|
||||||
|
|
||||||
|
// Override fasapi users auth so we can use both
|
||||||
|
if (SERVER_SIDE_ONLY__CLOUD_ENABLED) {
|
||||||
|
authType = "cloud";
|
||||||
|
} else {
|
||||||
|
authType = data.auth_type as AuthType;
|
||||||
|
}
|
||||||
|
|
||||||
// for SAML / OIDC, we auto-redirect the user to the IdP when the user visits
|
// for SAML / OIDC, we auto-redirect the user to the IdP when the user visits
|
||||||
// Danswer in an un-authenticated state
|
// Danswer in an un-authenticated state
|
||||||
@ -87,6 +95,9 @@ export const getAuthUrlSS = async (
|
|||||||
case "google_oauth": {
|
case "google_oauth": {
|
||||||
return await getGoogleOAuthUrlSS();
|
return await getGoogleOAuthUrlSS();
|
||||||
}
|
}
|
||||||
|
case "cloud": {
|
||||||
|
return await getGoogleOAuthUrlSS();
|
||||||
|
}
|
||||||
case "saml": {
|
case "saml": {
|
||||||
return await getSAMLAuthUrlSS();
|
return await getSAMLAuthUrlSS();
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user