diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index df2f999528..28c7804836 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -129,7 +129,6 @@ services: environment: - INTERNAL_URL=http://api_server:8080 - WEB_DOMAIN=${WEB_DOMAIN:-} - - OAUTH_NAME=${OAUTH_NAME:-} relational_db: image: postgres:15.2-alpine restart: always diff --git a/web/src/app/auth/login/SignInButton.tsx b/web/src/app/auth/login/SignInButton.tsx new file mode 100644 index 0000000000..0d4e35fe61 --- /dev/null +++ b/web/src/app/auth/login/SignInButton.tsx @@ -0,0 +1,51 @@ +import { AuthType } from "@/lib/constants"; +import { FaGoogle } from "react-icons/fa"; + +export function SignInButton({ + authorizeUrl, + authType, +}: { + authorizeUrl: string; + authType: AuthType; +}) { + let button; + if (authType === "google_oauth") { + button = ( +
+
+ +
+

Continue with Google

+
+ ); + } else if (authType === "oidc") { + button = ( +
+

+ Continue with OIDC SSO +

+
+ ); + } else if (authType === "saml") { + button = ( +
+

+ Continue with SAML SSO +

+
+ ); + } + + if (!button) { + throw new Error(`Unhandled authType: ${authType}`); + } + + return ( + + {button} + + ); +} diff --git a/web/src/app/auth/login/page.tsx b/web/src/app/auth/login/page.tsx index da9baf018e..b70d49663f 100644 --- a/web/src/app/auth/login/page.tsx +++ b/web/src/app/auth/login/page.tsx @@ -1,25 +1,31 @@ import { HealthCheckBanner } from "@/components/health/healthcheck"; -import { AuthType, OAUTH_NAME } from "@/lib/constants"; import { User } from "@/lib/types"; -import { getCurrentUserSS, getAuthUrlSS, getAuthTypeSS } from "@/lib/userSS"; +import { + getCurrentUserSS, + getAuthUrlSS, + getAuthTypeMetadataSS, + AuthTypeMetadata, +} from "@/lib/userSS"; import { redirect } from "next/navigation"; import { getWebVersion, getBackendVersion } from "@/lib/version"; +import Image from "next/image"; +import { SignInButton } from "./SignInButton"; -const BUTTON_STYLE = - "group relative w-64 flex justify-center " + - "py-2 px-4 border border-transparent text-sm " + - "font-medium rounded-md text-white bg-red-600 " + - " mx-auto"; +const Page = async ({ + searchParams, +}: { + searchParams?: { [key: string]: string | string[] | undefined }; +}) => { + const autoRedirectDisabled = searchParams?.disableAutoRedirect === "true"; -const Page = async () => { // catch cases where the backend is completely unreachable here // without try / catch, will just raise an exception and the page // will not render - let authType: AuthType | null = null; + let authTypeMetadata: AuthTypeMetadata | null = null; let currentUser: User | null = null; try { - [authType, currentUser] = await Promise.all([ - getAuthTypeSS(), + [authTypeMetadata, currentUser] = await Promise.all([ + getAuthTypeMetadataSS(), getCurrentUserSS(), ]); } catch (e) { @@ -38,7 +44,7 @@ const Page = async () => { } // simply take the user to the home page if Auth is disabled - if (authType === "disabled") { + if (authTypeMetadata?.authType === "disabled") { return redirect("/"); } @@ -49,16 +55,15 @@ const Page = async () => { // get where to send the user to authenticate let authUrl: string | null = null; - let autoRedirect: boolean = false; - if (authType) { + if (authTypeMetadata) { try { - [authUrl, autoRedirect] = await getAuthUrlSS(authType); + authUrl = await getAuthUrlSS(authTypeMetadata.authType); } catch (e) { console.log(`Some fetch failed for the login page - ${e}`); } } - if (autoRedirect && authUrl) { + if (authTypeMetadata?.autoRedirect && authUrl && !autoRedirectDisabled) { return redirect(authUrl); } @@ -68,33 +73,23 @@ const Page = async () => {
-
-
-

- danswer 💃 -

-
-
- {authUrl ? ( - - Sign in with {OAUTH_NAME} - - ) : ( - - )} +
+
+ Logo
+

+ Log In to Danswer +

+ {authUrl && authTypeMetadata && ( + + )}
-
- VERSION w{web_version} b{backend_version} -
+
+
+ VERSION w{web_version} b{backend_version}
); diff --git a/web/src/app/auth/logout/route.ts b/web/src/app/auth/logout/route.ts index 3025cb8acc..7de902c7ac 100644 --- a/web/src/app/auth/logout/route.ts +++ b/web/src/app/auth/logout/route.ts @@ -1,14 +1,13 @@ -import { getDomain } from "@/lib/redirectSS"; -import { getAuthTypeSS, logoutSS } from "@/lib/userSS"; -import { NextRequest, NextResponse } from "next/server"; +import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS"; +import { NextRequest } from "next/server"; export const POST = async (request: NextRequest) => { // Directs the logout request to the appropriate FastAPI endpoint. // Needed since env variables don't work well on the client-side - const authType = await getAuthTypeSS(); - const response = await logoutSS(authType, request.headers); - if (response && response.ok) { - return NextResponse.redirect(new URL("/auth/login", getDomain(request))); + const authTypeMetadata = await getAuthTypeMetadataSS(); + const response = await logoutSS(authTypeMetadata.authType, request.headers); + if (!response || response.ok) { + return new Response(null, { status: 204 }); } - return new Response(null, { status: 204 }); + return new Response(response.body, { status: response?.status }); }; diff --git a/web/src/components/Header.tsx b/web/src/components/Header.tsx index 2f6225816e..8e07538503 100644 --- a/web/src/components/Header.tsx +++ b/web/src/components/Header.tsx @@ -17,13 +17,14 @@ export const Header: React.FC = ({ user }) => { const [dropdownOpen, setDropdownOpen] = useState(false); const dropdownRef = useRef(null); - const handleLogout = () => { - logout().then((isSuccess) => { - if (!isSuccess) { - alert("Failed to logout"); - } - router.push("/auth/login"); - }); + const handleLogout = async () => { + const response = await logout(); + if (!response.ok) { + alert("Failed to logout"); + } + // disable auto-redirect immediately after logging out so the user + // is not immediately re-logged in + router.push("/auth/login?disableAutoRedirect=true"); }; // When dropdownOpen state changes, it attaches/removes the click listener diff --git a/web/src/lib/constants.ts b/web/src/lib/constants.ts index 97eb3350ee..645d8b795a 100644 --- a/web/src/lib/constants.ts +++ b/web/src/lib/constants.ts @@ -1,7 +1,5 @@ export type AuthType = "disabled" | "google_oauth" | "oidc" | "saml"; -export const OAUTH_NAME = process.env.OAUTH_NAME || "Google"; - export const INTERNAL_URL = process.env.INTERNAL_URL || "http://127.0.0.1:8080"; export const NEXT_PUBLIC_DISABLE_STREAMING = process.env.NEXT_PUBLIC_DISABLE_STREAMING?.toLowerCase() === "true"; diff --git a/web/src/lib/user.ts b/web/src/lib/user.ts index f4d67acff1..b3a9133331 100644 --- a/web/src/lib/user.ts +++ b/web/src/lib/user.ts @@ -12,10 +12,10 @@ export const getCurrentUser = async (): Promise => { return user; }; -export const logout = async (): Promise => { +export const logout = async (): Promise => { const response = await fetch("/auth/logout", { method: "POST", credentials: "include", }); - return response.ok; + return response; }; diff --git a/web/src/lib/userSS.ts b/web/src/lib/userSS.ts index 300e837acf..0402c68a6a 100644 --- a/web/src/lib/userSS.ts +++ b/web/src/lib/userSS.ts @@ -4,18 +4,30 @@ import { buildUrl } from "./utilsSS"; import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies"; import { AuthType } from "./constants"; -export const getAuthTypeSS = async (): Promise => { +export interface AuthTypeMetadata { + authType: AuthType; + autoRedirect: boolean; +} + +export const getAuthTypeMetadataSS = async (): Promise => { const res = await fetch(buildUrl("/auth/type")); if (!res.ok) { throw new Error("Failed to fetch data"); } const data: { auth_type: string } = await res.json(); - return data.auth_type as AuthType; + const authType = data.auth_type as AuthType; + + // for SAML / OIDC, we auto-redirect the user to the IdP when the user visits + // Danswer in an un-authenticated state + if (authType === "oidc" || authType === "saml") { + return { authType, autoRedirect: true }; + } + return { authType, autoRedirect: false }; }; export const getAuthDisabledSS = async (): Promise => { - return (await getAuthTypeSS()) === "disabled"; + return (await getAuthTypeMetadataSS()).authType === "disabled"; }; const geOIDCAuthUrlSS = async (): Promise => { @@ -48,21 +60,19 @@ const getSAMLAuthUrlSS = async (): Promise => { return data.authorization_url; }; -export const getAuthUrlSS = async ( - authType: AuthType -): Promise<[string, boolean]> => { - // Returns the auth url and whether or not we should auto-redirect +export const getAuthUrlSS = async (authType: AuthType): Promise => { + // Returns the auth url for the given auth type switch (authType) { case "disabled": - return ["", true]; + return ""; case "google_oauth": { - return [await getGoogleOAuthUrlSS(), false]; + return await getGoogleOAuthUrlSS(); } case "saml": { - return [await getSAMLAuthUrlSS(), true]; + return await getSAMLAuthUrlSS(); } case "oidc": { - return [await geOIDCAuthUrlSS(), true]; + return await geOIDCAuthUrlSS(); } } };