This commit is contained in:
pablodanswer
2024-10-20 14:41:12 -07:00
parent 3c9ccd32d3
commit d2346dc24e
9 changed files with 25 additions and 39 deletions

View File

@ -227,14 +227,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET verification_token_secret = USER_AUTH_SECRET
# async def register(
# self,
# user_create: schemas.UC | UserCreate,
# safe: bool = False,
# request: Optional[Request] = None,
# ) -> User:
# return await super().register(user_create, safe, request)
async def create( async def create(
self, self,
user_create: schemas.UC | UserCreate, user_create: schemas.UC | UserCreate,
@ -249,10 +241,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
raise HTTPException(status_code=401, detail="User not found") raise HTTPException(status_code=401, detail="User not found")
if not tenant_id: if not tenant_id:
raise HTTPException(status_code=401, detail="User not found") raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
async with get_async_session_with_tenant(tenant_id) as db_session: async with get_async_session_with_tenant(tenant_id) as db_session:
current_tenant_id.set(tenant_id) token = current_tenant_id.set(tenant_id)
verify_email_is_invited(user_create.email) verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email) verify_email_domain(user_create.email)
@ -290,6 +284,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user = await self.update(user_update, user) user = await self.update(user_update, user)
else: else:
raise exceptions.UserAlreadyExists() raise exceptions.UserAlreadyExists()
current_tenant_id.reset(token)
return user return user
async def on_after_login( async def on_after_login(

View File

@ -140,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password" os.environ.get("POSTGRES_PASSWORD") or "password"
) )
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost" POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433" POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres" POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
POSTGRES_API_SERVER_POOL_SIZE = int( POSTGRES_API_SERVER_POOL_SIZE = int(

View File

@ -313,7 +313,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports: ports:
- "5433:5432" - "5432:5432"
volumes: volumes:
- db_volume:/var/lib/postgresql/data - db_volume:/var/lib/postgresql/data

View File

@ -313,7 +313,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports: ports:
- "5433:5432" - "5432:5432"
volumes: volumes:
- db_volume:/var/lib/postgresql/data - db_volume:/var/lib/postgresql/data

View File

@ -157,7 +157,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports: ports:
- "5433" - "5432"
volumes: volumes:
- db_volume:/var/lib/postgresql/data - db_volume:/var/lib/postgresql/data

View File

@ -71,8 +71,6 @@ const Page = async ({
if (authTypeMetadata?.autoRedirect && authUrl && !autoRedirectDisabled) { if (authTypeMetadata?.autoRedirect && authUrl && !autoRedirectDisabled) {
return redirect(authUrl); return redirect(authUrl);
} }
console.log("authTypeMetadata");
console.log(authTypeMetadata);
return ( return (
<AuthFlowContainer> <AuthFlowContainer>
@ -80,7 +78,7 @@ const Page = async ({
<HealthCheckBanner /> <HealthCheckBanner />
</div> </div>
<div className="flex flex-col justify-center"> <div className="flex flex-col w-full justify-center">
{authUrl && authTypeMetadata && ( {authUrl && authTypeMetadata && (
<> <>
<h2 className="text-center text-xl text-strong font-bold"> <h2 className="text-center text-xl text-strong font-bold">

View File

@ -42,18 +42,15 @@ const Page = async () => {
} }
return redirect("/auth/waiting-on-verification"); return redirect("/auth/waiting-on-verification");
} }
// only enable this page if basic login is enabled
if (
authTypeMetadata?.authType !== "basic" &&
authTypeMetadata?.authType !== "cloud"
) {
return redirect("/");
}
const cloud = authTypeMetadata?.authType === "cloud"; const cloud = authTypeMetadata?.authType === "cloud";
// only enable this page if basic login is enabled
if (authTypeMetadata?.authType !== "basic" && !cloud) {
return redirect("/");
}
let authUrl: string | null = null; let authUrl: string | null = null;
if (cloud) { if (cloud && authTypeMetadata) {
authUrl = await getAuthUrlSS(authTypeMetadata.authType, null); authUrl = await getAuthUrlSS(authTypeMetadata.authType, null);
} }
@ -63,21 +60,18 @@ const Page = async () => {
<> <>
<div className="absolute top-10x w-full"></div> <div className="absolute top-10x w-full"></div>
<div className="flex flex-col justify-center"> <div className="flex w-full flex-col justify-center">
<h2 className="text-center text-xl text-strong font-bold"> <h2 className="text-center text-xl text-strong font-bold">
{cloud ? "Complete your sign up" : "Sign Up for Danswer"} {cloud ? "Complete your sign up" : "Sign Up for Danswer"}
</h2> </h2>
{cloud && authUrl && ( {cloud && authUrl && (
<div className="w-full justify-center"> <div className="w-full justify-center">
<SignInButton <SignInButton authorizeUrl={authUrl} authType="cloud" />
authorizeUrl={authUrl}
authType={authTypeMetadata?.authType}
/>
<div className="flex items-center w-full my-4"> <div className="flex items-center w-full my-4">
<div className="flex-grow border-t border-gray-300"></div> <div className="flex-grow border-t border-background-300"></div>
<span className="px-4 text-gray-500">or</span> <span className="px-4 text-gray-500">or</span>
<div className="flex-grow border-t border-gray-300"></div> <div className="flex-grow border-t border-background-300"></div>
</div> </div>
</div> </div>
)} )}

View File

@ -7,7 +7,7 @@ export default function AuthFlowContainer({
}) { }) {
return ( return (
<div className="flex flex-col items-center justify-center min-h-screen bg-background"> <div className="flex flex-col items-center justify-center min-h-screen bg-background">
<div className="w-full max-w-md p-8 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100"> <div className="w-full max-w-md bg-black p-8 mx-4 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
<Logo width={70} height={70} /> <Logo width={70} height={70} />
{children} {children}
</div> </div>

View File

@ -20,6 +20,8 @@ export const getAuthTypeMetadataSS = async (): Promise<AuthTypeMetadata> => {
await res.json(); await res.json();
let authType: AuthType; let authType: AuthType;
// Override fasapi users auth so we can use both
if (SERVER_SIDE_ONLY__CLOUD_ENABLED) { if (SERVER_SIDE_ONLY__CLOUD_ENABLED) {
authType = "cloud"; authType = "cloud";
} else { } else {
@ -84,7 +86,6 @@ export const getAuthUrlSS = async (
authType: AuthType, authType: AuthType,
nextUrl: string | null nextUrl: string | null
): Promise<string> => { ): Promise<string> => {
console.log(authType);
// Returns the auth url for the given auth type // Returns the auth url for the given auth type
switch (authType) { switch (authType) {
case "disabled": case "disabled":
@ -95,10 +96,7 @@ export const getAuthUrlSS = async (
return await getGoogleOAuthUrlSS(); return await getGoogleOAuthUrlSS();
} }
case "cloud": { case "cloud": {
console.log("returning cloud auth url"); return await getGoogleOAuthUrlSS();
const value = await getGoogleOAuthUrlSS();
console.log(value);
return value;
} }
case "saml": { case "saml": {
return await getSAMLAuthUrlSS(); return await getSAMLAuthUrlSS();