Fix logout redirect

This commit is contained in:
Weves 2023-11-17 15:35:36 -08:00 committed by Chris Weaver
parent ae72cd56f8
commit 37c5f24d91
8 changed files with 125 additions and 72 deletions

View File

@ -129,7 +129,6 @@ services:
environment: environment:
- INTERNAL_URL=http://api_server:8080 - INTERNAL_URL=http://api_server:8080
- WEB_DOMAIN=${WEB_DOMAIN:-} - WEB_DOMAIN=${WEB_DOMAIN:-}
- OAUTH_NAME=${OAUTH_NAME:-}
relational_db: relational_db:
image: postgres:15.2-alpine image: postgres:15.2-alpine
restart: always restart: always

View File

@ -0,0 +1,51 @@
import { AuthType } from "@/lib/constants";
import { FaGoogle } from "react-icons/fa";
export function SignInButton({
authorizeUrl,
authType,
}: {
authorizeUrl: string;
authType: AuthType;
}) {
let button;
if (authType === "google_oauth") {
button = (
<div className="mx-auto flex">
<div className="my-auto mr-2">
<FaGoogle />
</div>
<p className="text-sm font-medium select-none">Continue with Google</p>
</div>
);
} else if (authType === "oidc") {
button = (
<div className="mx-auto flex">
<p className="text-sm font-medium select-none">
Continue with OIDC SSO
</p>
</div>
);
} else if (authType === "saml") {
button = (
<div className="mx-auto flex">
<p className="text-sm font-medium select-none">
Continue with SAML SSO
</p>
</div>
);
}
if (!button) {
throw new Error(`Unhandled authType: ${authType}`);
}
return (
<a
className="mt-6 py-3 w-72 bg-blue-900 flex rounded cursor-pointer hover:bg-blue-950"
href={authorizeUrl}
>
{button}
</a>
);
}

View File

@ -1,25 +1,31 @@
import { HealthCheckBanner } from "@/components/health/healthcheck"; import { HealthCheckBanner } from "@/components/health/healthcheck";
import { AuthType, OAUTH_NAME } from "@/lib/constants";
import { User } from "@/lib/types"; import { User } from "@/lib/types";
import { getCurrentUserSS, getAuthUrlSS, getAuthTypeSS } from "@/lib/userSS"; import {
getCurrentUserSS,
getAuthUrlSS,
getAuthTypeMetadataSS,
AuthTypeMetadata,
} from "@/lib/userSS";
import { redirect } from "next/navigation"; import { redirect } from "next/navigation";
import { getWebVersion, getBackendVersion } from "@/lib/version"; import { getWebVersion, getBackendVersion } from "@/lib/version";
import Image from "next/image";
import { SignInButton } from "./SignInButton";
const BUTTON_STYLE = const Page = async ({
"group relative w-64 flex justify-center " + searchParams,
"py-2 px-4 border border-transparent text-sm " + }: {
"font-medium rounded-md text-white bg-red-600 " + searchParams?: { [key: string]: string | string[] | undefined };
" mx-auto"; }) => {
const autoRedirectDisabled = searchParams?.disableAutoRedirect === "true";
const Page = async () => {
// catch cases where the backend is completely unreachable here // catch cases where the backend is completely unreachable here
// without try / catch, will just raise an exception and the page // without try / catch, will just raise an exception and the page
// will not render // will not render
let authType: AuthType | null = null; let authTypeMetadata: AuthTypeMetadata | null = null;
let currentUser: User | null = null; let currentUser: User | null = null;
try { try {
[authType, currentUser] = await Promise.all([ [authTypeMetadata, currentUser] = await Promise.all([
getAuthTypeSS(), getAuthTypeMetadataSS(),
getCurrentUserSS(), getCurrentUserSS(),
]); ]);
} catch (e) { } catch (e) {
@ -38,7 +44,7 @@ const Page = async () => {
} }
// simply take the user to the home page if Auth is disabled // simply take the user to the home page if Auth is disabled
if (authType === "disabled") { if (authTypeMetadata?.authType === "disabled") {
return redirect("/"); return redirect("/");
} }
@ -49,16 +55,15 @@ const Page = async () => {
// get where to send the user to authenticate // get where to send the user to authenticate
let authUrl: string | null = null; let authUrl: string | null = null;
let autoRedirect: boolean = false; if (authTypeMetadata) {
if (authType) {
try { try {
[authUrl, autoRedirect] = await getAuthUrlSS(authType); authUrl = await getAuthUrlSS(authTypeMetadata.authType);
} catch (e) { } catch (e) {
console.log(`Some fetch failed for the login page - ${e}`); console.log(`Some fetch failed for the login page - ${e}`);
} }
} }
if (autoRedirect && authUrl) { if (authTypeMetadata?.autoRedirect && authUrl && !autoRedirectDisabled) {
return redirect(authUrl); return redirect(authUrl);
} }
@ -68,33 +73,23 @@ const Page = async () => {
<HealthCheckBanner /> <HealthCheckBanner />
</div> </div>
<div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8"> <div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
<div className="max-w-md w-full space-y-8"> <div>
<div> <div className="h-16 w-16 mx-auto">
<h2 className="mt-6 text-center text-3xl font-extrabold text-gray-200"> <Image src="/logo.png" alt="Logo" width="1419" height="1520" />
danswer 💃
</h2>
</div>
<div className="flex">
{authUrl ? (
<a
href={authUrl || ""}
className={
BUTTON_STYLE +
" focus:outline-none focus:ring-2 hover:bg-red-700 focus:ring-offset-2 focus:ring-red-500"
}
>
Sign in with {OAUTH_NAME}
</a>
) : (
<button className={BUTTON_STYLE + " cursor-default"}>
Sign in with {OAUTH_NAME}
</button>
)}
</div> </div>
<h2 className="text-center text-xl font-bold mt-4">
Log In to Danswer
</h2>
{authUrl && authTypeMetadata && (
<SignInButton
authorizeUrl={authUrl}
authType={authTypeMetadata?.authType}
/>
)}
</div> </div>
<div className="fixed bottom-4 right-4 z-50 text-slate-400 p-2"> </div>
VERSION w{web_version} b{backend_version} <div className="fixed bottom-4 right-4 z-50 text-slate-400 p-2">
</div> VERSION w{web_version} b{backend_version}
</div> </div>
</main> </main>
); );

View File

@ -1,14 +1,13 @@
import { getDomain } from "@/lib/redirectSS"; import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS";
import { getAuthTypeSS, logoutSS } from "@/lib/userSS"; import { NextRequest } from "next/server";
import { NextRequest, NextResponse } from "next/server";
export const POST = async (request: NextRequest) => { export const POST = async (request: NextRequest) => {
// Directs the logout request to the appropriate FastAPI endpoint. // Directs the logout request to the appropriate FastAPI endpoint.
// Needed since env variables don't work well on the client-side // Needed since env variables don't work well on the client-side
const authType = await getAuthTypeSS(); const authTypeMetadata = await getAuthTypeMetadataSS();
const response = await logoutSS(authType, request.headers); const response = await logoutSS(authTypeMetadata.authType, request.headers);
if (response && response.ok) { if (!response || response.ok) {
return NextResponse.redirect(new URL("/auth/login", getDomain(request))); return new Response(null, { status: 204 });
} }
return new Response(null, { status: 204 }); return new Response(response.body, { status: response?.status });
}; };

View File

@ -17,13 +17,14 @@ export const Header: React.FC<HeaderProps> = ({ user }) => {
const [dropdownOpen, setDropdownOpen] = useState(false); const [dropdownOpen, setDropdownOpen] = useState(false);
const dropdownRef = useRef<HTMLDivElement>(null); const dropdownRef = useRef<HTMLDivElement>(null);
const handleLogout = () => { const handleLogout = async () => {
logout().then((isSuccess) => { const response = await logout();
if (!isSuccess) { if (!response.ok) {
alert("Failed to logout"); alert("Failed to logout");
} }
router.push("/auth/login"); // disable auto-redirect immediately after logging out so the user
}); // is not immediately re-logged in
router.push("/auth/login?disableAutoRedirect=true");
}; };
// When dropdownOpen state changes, it attaches/removes the click listener // When dropdownOpen state changes, it attaches/removes the click listener

View File

@ -1,7 +1,5 @@
export type AuthType = "disabled" | "google_oauth" | "oidc" | "saml"; export type AuthType = "disabled" | "google_oauth" | "oidc" | "saml";
export const OAUTH_NAME = process.env.OAUTH_NAME || "Google";
export const INTERNAL_URL = process.env.INTERNAL_URL || "http://127.0.0.1:8080"; export const INTERNAL_URL = process.env.INTERNAL_URL || "http://127.0.0.1:8080";
export const NEXT_PUBLIC_DISABLE_STREAMING = export const NEXT_PUBLIC_DISABLE_STREAMING =
process.env.NEXT_PUBLIC_DISABLE_STREAMING?.toLowerCase() === "true"; process.env.NEXT_PUBLIC_DISABLE_STREAMING?.toLowerCase() === "true";

View File

@ -12,10 +12,10 @@ export const getCurrentUser = async (): Promise<User | null> => {
return user; return user;
}; };
export const logout = async (): Promise<boolean> => { export const logout = async (): Promise<Response> => {
const response = await fetch("/auth/logout", { const response = await fetch("/auth/logout", {
method: "POST", method: "POST",
credentials: "include", credentials: "include",
}); });
return response.ok; return response;
}; };

View File

@ -4,18 +4,30 @@ import { buildUrl } from "./utilsSS";
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies"; import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
import { AuthType } from "./constants"; import { AuthType } from "./constants";
export const getAuthTypeSS = async (): Promise<AuthType> => { export interface AuthTypeMetadata {
authType: AuthType;
autoRedirect: boolean;
}
export const getAuthTypeMetadataSS = async (): Promise<AuthTypeMetadata> => {
const res = await fetch(buildUrl("/auth/type")); const res = await fetch(buildUrl("/auth/type"));
if (!res.ok) { if (!res.ok) {
throw new Error("Failed to fetch data"); throw new Error("Failed to fetch data");
} }
const data: { auth_type: string } = await res.json(); const data: { auth_type: string } = await res.json();
return data.auth_type as AuthType; const authType = data.auth_type as AuthType;
// for SAML / OIDC, we auto-redirect the user to the IdP when the user visits
// Danswer in an un-authenticated state
if (authType === "oidc" || authType === "saml") {
return { authType, autoRedirect: true };
}
return { authType, autoRedirect: false };
}; };
export const getAuthDisabledSS = async (): Promise<boolean> => { export const getAuthDisabledSS = async (): Promise<boolean> => {
return (await getAuthTypeSS()) === "disabled"; return (await getAuthTypeMetadataSS()).authType === "disabled";
}; };
const geOIDCAuthUrlSS = async (): Promise<string> => { const geOIDCAuthUrlSS = async (): Promise<string> => {
@ -48,21 +60,19 @@ const getSAMLAuthUrlSS = async (): Promise<string> => {
return data.authorization_url; return data.authorization_url;
}; };
export const getAuthUrlSS = async ( export const getAuthUrlSS = async (authType: AuthType): Promise<string> => {
authType: AuthType // Returns the auth url for the given auth type
): Promise<[string, boolean]> => {
// Returns the auth url and whether or not we should auto-redirect
switch (authType) { switch (authType) {
case "disabled": case "disabled":
return ["", true]; return "";
case "google_oauth": { case "google_oauth": {
return [await getGoogleOAuthUrlSS(), false]; return await getGoogleOAuthUrlSS();
} }
case "saml": { case "saml": {
return [await getSAMLAuthUrlSS(), true]; return await getSAMLAuthUrlSS();
} }
case "oidc": { case "oidc": {
return [await geOIDCAuthUrlSS(), true]; return await geOIDCAuthUrlSS();
} }
} }
}; };