add cloud auth type

This commit is contained in:
pablodanswer 2024-10-20 12:34:58 -07:00
parent b4e975013c
commit 1cad9c7b3d
10 changed files with 47 additions and 12 deletions

View File

@ -140,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password" os.environ.get("POSTGRES_PASSWORD") or "password"
) )
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost" POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432" POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres" POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
POSTGRES_API_SERVER_POOL_SIZE = int( POSTGRES_API_SERVER_POOL_SIZE = int(

View File

@ -160,6 +160,9 @@ class AuthType(str, Enum):
OIDC = "oidc" OIDC = "oidc"
SAML = "saml" SAML = "saml"
# google auth and basic
CLOUD = "cloud"
class SessionType(str, Enum): class SessionType(str, Enum):
CHAT = "Chat" CHAT = "Chat"

View File

@ -269,7 +269,7 @@ def get_application() -> FastAPI:
# Server logs this during auth setup verification step # Server logs this during auth setup verification step
pass pass
elif AUTH_TYPE == AuthType.BASIC: if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
include_router_with_global_prefix_prepended( include_router_with_global_prefix_prepended(
application, application,
fastapi_users.get_auth_router(auth_backend), fastapi_users.get_auth_router(auth_backend),
@ -301,7 +301,7 @@ def get_application() -> FastAPI:
tags=["users"], tags=["users"],
) )
elif AUTH_TYPE == AuthType.GOOGLE_OAUTH: if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended( include_router_with_global_prefix_prepended(
application, application,

View File

@ -313,7 +313,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports: ports:
- "5432:5432" - "5433:5432"
volumes: volumes:
- db_volume:/var/lib/postgresql/data - db_volume:/var/lib/postgresql/data

View File

@ -313,7 +313,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports: ports:
- "5432:5432" - "5433:5432"
volumes: volumes:
- db_volume:/var/lib/postgresql/data - db_volume:/var/lib/postgresql/data

View File

@ -157,7 +157,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports: ports:
- "5432" - "5433"
volumes: volumes:
- db_volume:/var/lib/postgresql/data - db_volume:/var/lib/postgresql/data

View File

@ -9,7 +9,7 @@ export function SignInButton({
authType: AuthType; authType: AuthType;
}) { }) {
let button; let button;
if (authType === "google_oauth") { if (authType === "google_oauth" || authType === "cloud") {
button = ( button = (
<div className="mx-auto flex"> <div className="mx-auto flex">
<div className="my-auto mr-2"> <div className="my-auto mr-2">
@ -42,7 +42,7 @@ export function SignInButton({
return ( return (
<a <a
className="mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800" className="mx-auto mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
href={authorizeUrl} href={authorizeUrl}
> >
{button} {button}

View File

@ -71,6 +71,8 @@ const Page = async ({
if (authTypeMetadata?.autoRedirect && authUrl && !autoRedirectDisabled) { if (authTypeMetadata?.autoRedirect && authUrl && !autoRedirectDisabled) {
return redirect(authUrl); return redirect(authUrl);
} }
console.log("authTypeMetadata");
console.log(authTypeMetadata);
return ( return (
<AuthFlowContainer> <AuthFlowContainer>
@ -78,7 +80,7 @@ const Page = async ({
<HealthCheckBanner /> <HealthCheckBanner />
</div> </div>
<div> <div className="flex flex-col justify-center">
{authUrl && authTypeMetadata && ( {authUrl && authTypeMetadata && (
<> <>
<h2 className="text-center text-xl text-strong font-bold"> <h2 className="text-center text-xl text-strong font-bold">
@ -92,6 +94,17 @@ const Page = async ({
</> </>
)} )}
{authTypeMetadata?.authType === "cloud" && (
<div className="mt-4 w-full justify-center">
<div className="flex items-center w-full my-4">
<div className="flex-grow border-t border-gray-300"></div>
<span className="px-4 text-gray-500">or</span>
<div className="flex-grow border-t border-gray-300"></div>
</div>
<EmailPasswordForm shouldVerify={true} />
</div>
)}
{authTypeMetadata?.authType === "basic" && ( {authTypeMetadata?.authType === "basic" && (
<Card className="mt-4 w-96"> <Card className="mt-4 w-96">
<div className="flex"> <div className="flex">

View File

@ -1,4 +1,10 @@
export type AuthType = "disabled" | "basic" | "google_oauth" | "oidc" | "saml"; export type AuthType =
| "disabled"
| "basic"
| "google_oauth"
| "oidc"
| "saml"
| "cloud";
export const HOST_URL = process.env.WEB_DOMAIN || "http://127.0.0.1:3000"; export const HOST_URL = process.env.WEB_DOMAIN || "http://127.0.0.1:3000";
export const HEADER_HEIGHT = "h-16"; export const HEADER_HEIGHT = "h-16";

View File

@ -2,7 +2,7 @@ import { cookies } from "next/headers";
import { User } from "./types"; import { User } from "./types";
import { buildUrl } from "./utilsSS"; import { buildUrl } from "./utilsSS";
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies"; import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
import { AuthType } from "./constants"; import { AuthType, SERVER_SIDE_ONLY__CLOUD_ENABLED } from "./constants";
export interface AuthTypeMetadata { export interface AuthTypeMetadata {
authType: AuthType; authType: AuthType;
@ -18,7 +18,13 @@ export const getAuthTypeMetadataSS = async (): Promise<AuthTypeMetadata> => {
const data: { auth_type: string; requires_verification: boolean } = const data: { auth_type: string; requires_verification: boolean } =
await res.json(); await res.json();
const authType = data.auth_type as AuthType;
let authType: AuthType;
if (SERVER_SIDE_ONLY__CLOUD_ENABLED) {
authType = "cloud";
} else {
authType = data.auth_type as AuthType;
}
// for SAML / OIDC, we auto-redirect the user to the IdP when the user visits // for SAML / OIDC, we auto-redirect the user to the IdP when the user visits
// Danswer in an un-authenticated state // Danswer in an un-authenticated state
@ -78,6 +84,7 @@ export const getAuthUrlSS = async (
authType: AuthType, authType: AuthType,
nextUrl: string | null nextUrl: string | null
): Promise<string> => { ): Promise<string> => {
console.log(authType);
// Returns the auth url for the given auth type // Returns the auth url for the given auth type
switch (authType) { switch (authType) {
case "disabled": case "disabled":
@ -87,6 +94,12 @@ export const getAuthUrlSS = async (
case "google_oauth": { case "google_oauth": {
return await getGoogleOAuthUrlSS(); return await getGoogleOAuthUrlSS();
} }
case "cloud": {
console.log("returning cloud auth url");
const value = await getGoogleOAuthUrlSS();
console.log(value);
return value;
}
case "saml": { case "saml": {
return await getSAMLAuthUrlSS(); return await getSAMLAuthUrlSS();
} }