Add referral source to cloud on data plane (#3096)

* cloud auth referral source

* minor clarity

* k

* minor modification to be best practice

* typing

* Update ReferralSourceSelector.tsx

* Update ReferralSourceSelector.tsx

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
This commit is contained in:
pablodanswer
2024-11-12 16:42:25 -08:00
committed by GitHub
parent fdc4811fce
commit 22189f02c6
14 changed files with 176 additions and 21 deletions

View File

@@ -228,12 +228,17 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
referral_source = None
if request is not None:
referral_source = request.cookies.get("referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
referral_source=referral_source,
)
async with get_async_session_with_tenant(tenant_id) as db_session:
@@ -294,12 +299,17 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
referral_source=referral_source,
)
if not tenant_id:
@@ -711,8 +721,6 @@ def generate_state_token(
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
@@ -762,15 +770,22 @@ def get_oauth_router(
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request, scopes: List[str] = Query(None)
request: Request,
scopes: List[str] = Query(None),
) -> OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {"next_url": next_url}
state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
@@ -829,8 +844,11 @@ def get_oauth_router(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
# Authenticate user
request.state.referral_source = referral_source
# Proceed to authenticate or create the user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
@@ -872,7 +890,6 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -315,7 +315,7 @@ def get_application() -> FastAPI:
tags=["users"],
)
if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,

View File

@@ -1,4 +1,5 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import OpenID
from danswer.auth.users import auth_backend
@@ -59,6 +60,31 @@ def get_application() -> FastAPI:
if MULTI_TENANT:
add_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
create_danswer_oauth_router(
oauth_client,
auth_backend,
USER_AUTH_SECRET,
associate_by_email=True,
is_verified_by_default=True,
# Points the user back to the login page
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)
# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
if AUTH_TYPE == AuthType.OIDC:
include_router_with_global_prefix_prepended(
application,
@@ -73,6 +99,7 @@ def get_application() -> FastAPI:
prefix="/auth/oidc",
tags=["auth"],
)
# need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
application,

View File

@@ -38,3 +38,4 @@ class ImpersonateRequest(BaseModel):
class TenantCreationPayload(BaseModel):
tenant_id: str
email: str
referral_source: str | None = None

View File

@@ -41,7 +41,9 @@ from shared_configs.enums import EmbeddingProvider
logger = logging.getLogger(__name__)
async def get_or_create_tenant_id(email: str) -> str:
async def get_or_create_tenant_id(
email: str, referral_source: str | None = None
) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
@@ -51,7 +53,7 @@ async def get_or_create_tenant_id(email: str) -> str:
except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email)
tenant_id = await create_tenant(email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
@@ -64,13 +66,13 @@ async def get_or_create_tenant_id(email: str) -> str:
return tenant_id
async def create_tenant(email: str) -> str:
async def create_tenant(email: str, referral_source: str | None = None) -> str:
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
await notify_control_plane(tenant_id, email)
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
@@ -117,14 +119,18 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def notify_control_plane(tenant_id: str, email: str) -> None:
async def notify_control_plane(
tenant_id: str, email: str, referral_source: str | None = None
) -> None:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = TenantCreationPayload(tenant_id=tenant_id, email=email)
payload = TenantCreationPayload(
tenant_id=tenant_id, email=email, referral_source=referral_source
)
async with aiohttp.ClientSession() as session:
async with session.post(

View File

@@ -28,10 +28,12 @@ class TenantManager:
def create(
tenant_id: str | None = None,
initial_admin_email: str | None = None,
referral_source: str | None = None,
) -> dict[str, str]:
body = {
"tenant_id": tenant_id,
"initial_admin_email": initial_admin_email,
"referral_source": referral_source,
}
token = generate_auth_token()

View File

@@ -14,12 +14,12 @@ from tests.integration.common_utils.test_models import DATestUser
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Create Tenant 1 and its Admin User
TenantManager.create("tenant_dev1", "test1@test.com")
TenantManager.create("tenant_dev1", "test1@test.com", "Data Plane Registration")
test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com")
assert UserManager.verify_role(test_user1, UserRole.ADMIN)
# Create Tenant 2 and its Admin User
TenantManager.create("tenant_dev2", "test2@test.com")
TenantManager.create("tenant_dev2", "test2@test.com", "Data Plane Registration")
test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com")
assert UserManager.verify_role(test_user2, UserRole.ADMIN)

View File

@@ -11,7 +11,7 @@ from tests.integration.common_utils.test_models import DATestUser
# Test flow from creating tenant to registering as a user
def test_tenant_creation(reset_multitenant: None) -> None:
TenantManager.create("tenant_dev", "test@test.com")
TenantManager.create("tenant_dev", "test@test.com", "Data Plane Registration")
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
assert UserManager.verify_role(test_user, UserRole.ADMIN)

View File

@@ -14,9 +14,11 @@ import { Spinner } from "@/components/Spinner";
export function EmailPasswordForm({
isSignup = false,
shouldVerify,
referralSource,
}: {
isSignup?: boolean;
shouldVerify?: boolean;
referralSource?: string;
}) {
const router = useRouter();
const { popup, setPopup } = usePopup();
@@ -39,7 +41,11 @@ export function EmailPasswordForm({
if (isSignup) {
// login is fast, no need to show a spinner
setIsWorking(true);
const response = await basicSignup(values.email, values.password);
const response = await basicSignup(
values.email,
values.password,
referralSource
);
if (!response.ok) {
const errorDetail = (await response.json()).detail;

View File

@@ -36,14 +36,18 @@ export function SignInButton({
);
}
const url = new URL(authorizeUrl);
const finalAuthorizeUrl = url.toString();
if (!button) {
throw new Error(`Unhandled authType: ${authType}`);
}
return (
<a
className="mx-auto mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
href={authorizeUrl}
className="mx-auto mt-6 py-3 w-full text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
href={finalAuthorizeUrl}
>
{button}
</a>

View File

@@ -0,0 +1,74 @@
"use client";
import { useState } from "react";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Label } from "@/components/admin/connectors/Field";
interface ReferralSourceSelectorProps {
defaultValue?: string;
}
const ReferralSourceSelector: React.FC<ReferralSourceSelectorProps> = ({
defaultValue,
}) => {
const [referralSource, setReferralSource] = useState(defaultValue);
const referralOptions = [
{ value: "search", label: "Search Engine (Google/Bing)" },
{ value: "friend", label: "Friend/Colleague" },
{ value: "linkedin", label: "LinkedIn" },
{ value: "twitter", label: "Twitter" },
{ value: "hackernews", label: "HackerNews" },
{ value: "reddit", label: "Reddit" },
{ value: "youtube", label: "YouTube" },
{ value: "podcast", label: "Podcast" },
{ value: "blog", label: "Article/Blog" },
{ value: "ads", label: "Advertisements" },
{ value: "other", label: "Other" },
];
const handleChange = (value: string) => {
setReferralSource(value);
const cookies = require("js-cookie");
cookies.set("referral_source", value, {
expires: 365,
path: "/",
sameSite: "strict",
});
};
return (
<div className="w-full max-w-sm gap-y-2 flex flex-col mx-auto">
<Label className="text-text-950" small={false}>
How did you hear about us?
</Label>
<Select value={referralSource} onValueChange={handleChange}>
<SelectTrigger
id="referral-source"
className="w-full border-gray-300 rounded-md shadow-sm focus:border-indigo-500 focus:ring-indigo-500"
>
<SelectValue placeholder="Select an option" />
</SelectTrigger>
<SelectContent className="max-h-60 overflow-y-auto">
{referralOptions.map((option) => (
<SelectItem
key={option.value}
value={option.value}
className="py-2 px-3 hover:bg-indigo-100 cursor-pointer"
>
{option.label}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
);
};
export default ReferralSourceSelector;

View File

@@ -12,6 +12,8 @@ import Text from "@/components/ui/text";
import Link from "next/link";
import { SignInButton } from "../login/SignInButton";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import ReferralSourceSelector from "./ReferralSourceSelector";
import { Separator } from "@/components/ui/separator";
const Page = async () => {
// catch cases where the backend is completely unreachable here
@@ -62,6 +64,13 @@ const Page = async () => {
<h2 className="text-center text-xl text-strong font-bold">
{cloud ? "Complete your sign up" : "Sign Up for Danswer"}
</h2>
{cloud && (
<>
<div className="w-full flex flex-col items-center space-y-4 mb-4 mt-4">
<ReferralSourceSelector />
</div>
</>
)}
{cloud && authUrl && (
<div className="w-full justify-center">

View File

@@ -43,7 +43,11 @@ export const basicLogin = async (
return response;
};
export const basicSignup = async (email: string, password: string) => {
export const basicSignup = async (
email: string,
password: string,
referralSource?: string
) => {
const response = await fetch("/api/auth/register", {
method: "POST",
credentials: "include",
@@ -54,6 +58,7 @@ export const basicSignup = async (email: string, password: string) => {
email,
username: email,
password,
referral_source: referralSource,
}),
});
return response;

View File

@@ -63,7 +63,11 @@ const getOIDCAuthUrlSS = async (nextUrl: string | null): Promise<string> => {
};
const getGoogleOAuthUrlSS = async (): Promise<string> => {
const res = await fetch(buildUrl(`/auth/oauth/authorize`));
const res = await fetch(buildUrl(`/auth/oauth/authorize`), {
headers: {
cookie: processCookies(await cookies()),
},
});
if (!res.ok) {
throw new Error("Failed to fetch data");
}