mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 12:29:41 +02:00
Add referral source to cloud on data plane (#3096)
* cloud auth referral source * minor clarity * k * minor modification to be best practice * typing * Update ReferralSourceSelector.tsx * Update ReferralSourceSelector.tsx --------- Co-authored-by: hagen-danswer <hagen@danswer.ai>
This commit is contained in:
@@ -228,12 +228,17 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
referral_source = None
|
||||
if request is not None:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user_create.email,
|
||||
referral_source=referral_source,
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -294,12 +299,17 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> models.UOAP:
|
||||
referral_source = None
|
||||
if request:
|
||||
referral_source = getattr(request.state, "referral_source", None)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=account_email,
|
||||
referral_source=referral_source,
|
||||
)
|
||||
|
||||
if not tenant_id:
|
||||
@@ -711,8 +721,6 @@ def generate_state_token(
|
||||
|
||||
|
||||
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
|
||||
|
||||
|
||||
def create_danswer_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend,
|
||||
@@ -762,15 +770,22 @@ def get_oauth_router(
|
||||
response_model=OAuth2AuthorizeResponse,
|
||||
)
|
||||
async def authorize(
|
||||
request: Request, scopes: List[str] = Query(None)
|
||||
request: Request,
|
||||
scopes: List[str] = Query(None),
|
||||
) -> OAuth2AuthorizeResponse:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
|
||||
if redirect_url is not None:
|
||||
authorize_redirect_url = redirect_url
|
||||
else:
|
||||
authorize_redirect_url = str(request.url_for(callback_route_name))
|
||||
|
||||
next_url = request.query_params.get("next", "/")
|
||||
state_data: Dict[str, str] = {"next_url": next_url}
|
||||
|
||||
state_data: Dict[str, str] = {
|
||||
"next_url": next_url,
|
||||
"referral_source": referral_source or "default_referral",
|
||||
}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
@@ -829,8 +844,11 @@ def get_oauth_router(
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
|
||||
# Authenticate user
|
||||
request.state.referral_source = referral_source
|
||||
|
||||
# Proceed to authenticate or create the user
|
||||
try:
|
||||
user = await user_manager.oauth_callback(
|
||||
oauth_client.name,
|
||||
@@ -872,7 +890,6 @@ def get_oauth_router(
|
||||
redirect_response.status_code = response.status_code
|
||||
if hasattr(response, "media_type"):
|
||||
redirect_response.media_type = response.media_type
|
||||
|
||||
return redirect_response
|
||||
|
||||
return router
|
||||
|
@@ -315,7 +315,7 @@ def get_application() -> FastAPI:
|
||||
tags=["users"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
|
||||
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from httpx_oauth.clients.openid import OpenID
|
||||
|
||||
from danswer.auth.users import auth_backend
|
||||
@@ -59,6 +60,31 @@ def get_application() -> FastAPI:
|
||||
if MULTI_TENANT:
|
||||
add_tenant_id_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
create_danswer_oauth_router(
|
||||
oauth_client,
|
||||
auth_backend,
|
||||
USER_AUTH_SECRET,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
# Points the user back to the login page
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# Need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_logout_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
@@ -73,6 +99,7 @@ def get_application() -> FastAPI:
|
||||
prefix="/auth/oidc",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
|
@@ -38,3 +38,4 @@ class ImpersonateRequest(BaseModel):
|
||||
class TenantCreationPayload(BaseModel):
|
||||
tenant_id: str
|
||||
email: str
|
||||
referral_source: str | None = None
|
||||
|
@@ -41,7 +41,9 @@ from shared_configs.enums import EmbeddingProvider
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_or_create_tenant_id(email: str) -> str:
|
||||
async def get_or_create_tenant_id(
|
||||
email: str, referral_source: str | None = None
|
||||
) -> str:
|
||||
"""Get existing tenant ID for an email or create a new tenant if none exists."""
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -51,7 +53,7 @@ async def get_or_create_tenant_id(email: str) -> str:
|
||||
except exceptions.UserNotExists:
|
||||
# If tenant does not exist and in Multi tenant mode, provision a new tenant
|
||||
try:
|
||||
tenant_id = await create_tenant(email)
|
||||
tenant_id = await create_tenant(email, referral_source)
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
@@ -64,13 +66,13 @@ async def get_or_create_tenant_id(email: str) -> str:
|
||||
return tenant_id
|
||||
|
||||
|
||||
async def create_tenant(email: str) -> str:
|
||||
async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
try:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
# Notify control plane
|
||||
await notify_control_plane(tenant_id, email)
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
@@ -117,14 +119,18 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def notify_control_plane(tenant_id: str, email: str) -> None:
|
||||
async def notify_control_plane(
|
||||
tenant_id: str, email: str, referral_source: str | None = None
|
||||
) -> None:
|
||||
logger.info("Fetching billing information")
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = TenantCreationPayload(tenant_id=tenant_id, email=email)
|
||||
payload = TenantCreationPayload(
|
||||
tenant_id=tenant_id, email=email, referral_source=referral_source
|
||||
)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
|
@@ -28,10 +28,12 @@ class TenantManager:
|
||||
def create(
|
||||
tenant_id: str | None = None,
|
||||
initial_admin_email: str | None = None,
|
||||
referral_source: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
body = {
|
||||
"tenant_id": tenant_id,
|
||||
"initial_admin_email": initial_admin_email,
|
||||
"referral_source": referral_source,
|
||||
}
|
||||
|
||||
token = generate_auth_token()
|
||||
|
@@ -14,12 +14,12 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
|
||||
# Create Tenant 1 and its Admin User
|
||||
TenantManager.create("tenant_dev1", "test1@test.com")
|
||||
TenantManager.create("tenant_dev1", "test1@test.com", "Data Plane Registration")
|
||||
test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com")
|
||||
assert UserManager.verify_role(test_user1, UserRole.ADMIN)
|
||||
|
||||
# Create Tenant 2 and its Admin User
|
||||
TenantManager.create("tenant_dev2", "test2@test.com")
|
||||
TenantManager.create("tenant_dev2", "test2@test.com", "Data Plane Registration")
|
||||
test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com")
|
||||
assert UserManager.verify_role(test_user2, UserRole.ADMIN)
|
||||
|
||||
|
@@ -11,7 +11,7 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
# Test flow from creating tenant to registering as a user
|
||||
def test_tenant_creation(reset_multitenant: None) -> None:
|
||||
TenantManager.create("tenant_dev", "test@test.com")
|
||||
TenantManager.create("tenant_dev", "test@test.com", "Data Plane Registration")
|
||||
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
|
||||
|
||||
assert UserManager.verify_role(test_user, UserRole.ADMIN)
|
||||
|
@@ -14,9 +14,11 @@ import { Spinner } from "@/components/Spinner";
|
||||
export function EmailPasswordForm({
|
||||
isSignup = false,
|
||||
shouldVerify,
|
||||
referralSource,
|
||||
}: {
|
||||
isSignup?: boolean;
|
||||
shouldVerify?: boolean;
|
||||
referralSource?: string;
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const { popup, setPopup } = usePopup();
|
||||
@@ -39,7 +41,11 @@ export function EmailPasswordForm({
|
||||
if (isSignup) {
|
||||
// login is fast, no need to show a spinner
|
||||
setIsWorking(true);
|
||||
const response = await basicSignup(values.email, values.password);
|
||||
const response = await basicSignup(
|
||||
values.email,
|
||||
values.password,
|
||||
referralSource
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorDetail = (await response.json()).detail;
|
||||
|
@@ -36,14 +36,18 @@ export function SignInButton({
|
||||
);
|
||||
}
|
||||
|
||||
const url = new URL(authorizeUrl);
|
||||
|
||||
const finalAuthorizeUrl = url.toString();
|
||||
|
||||
if (!button) {
|
||||
throw new Error(`Unhandled authType: ${authType}`);
|
||||
}
|
||||
|
||||
return (
|
||||
<a
|
||||
className="mx-auto mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
|
||||
href={authorizeUrl}
|
||||
className="mx-auto mt-6 py-3 w-full text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
|
||||
href={finalAuthorizeUrl}
|
||||
>
|
||||
{button}
|
||||
</a>
|
||||
|
74
web/src/app/auth/signup/ReferralSourceSelector.tsx
Normal file
74
web/src/app/auth/signup/ReferralSourceSelector.tsx
Normal file
@@ -0,0 +1,74 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Label } from "@/components/admin/connectors/Field";
|
||||
|
||||
interface ReferralSourceSelectorProps {
|
||||
defaultValue?: string;
|
||||
}
|
||||
|
||||
const ReferralSourceSelector: React.FC<ReferralSourceSelectorProps> = ({
|
||||
defaultValue,
|
||||
}) => {
|
||||
const [referralSource, setReferralSource] = useState(defaultValue);
|
||||
|
||||
const referralOptions = [
|
||||
{ value: "search", label: "Search Engine (Google/Bing)" },
|
||||
{ value: "friend", label: "Friend/Colleague" },
|
||||
{ value: "linkedin", label: "LinkedIn" },
|
||||
{ value: "twitter", label: "Twitter" },
|
||||
{ value: "hackernews", label: "HackerNews" },
|
||||
{ value: "reddit", label: "Reddit" },
|
||||
{ value: "youtube", label: "YouTube" },
|
||||
{ value: "podcast", label: "Podcast" },
|
||||
{ value: "blog", label: "Article/Blog" },
|
||||
{ value: "ads", label: "Advertisements" },
|
||||
{ value: "other", label: "Other" },
|
||||
];
|
||||
|
||||
const handleChange = (value: string) => {
|
||||
setReferralSource(value);
|
||||
const cookies = require("js-cookie");
|
||||
cookies.set("referral_source", value, {
|
||||
expires: 365,
|
||||
path: "/",
|
||||
sameSite: "strict",
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-full max-w-sm gap-y-2 flex flex-col mx-auto">
|
||||
<Label className="text-text-950" small={false}>
|
||||
How did you hear about us?
|
||||
</Label>
|
||||
<Select value={referralSource} onValueChange={handleChange}>
|
||||
<SelectTrigger
|
||||
id="referral-source"
|
||||
className="w-full border-gray-300 rounded-md shadow-sm focus:border-indigo-500 focus:ring-indigo-500"
|
||||
>
|
||||
<SelectValue placeholder="Select an option" />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="max-h-60 overflow-y-auto">
|
||||
{referralOptions.map((option) => (
|
||||
<SelectItem
|
||||
key={option.value}
|
||||
value={option.value}
|
||||
className="py-2 px-3 hover:bg-indigo-100 cursor-pointer"
|
||||
>
|
||||
{option.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ReferralSourceSelector;
|
@@ -12,6 +12,8 @@ import Text from "@/components/ui/text";
|
||||
import Link from "next/link";
|
||||
import { SignInButton } from "../login/SignInButton";
|
||||
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
|
||||
import ReferralSourceSelector from "./ReferralSourceSelector";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
|
||||
const Page = async () => {
|
||||
// catch cases where the backend is completely unreachable here
|
||||
@@ -62,6 +64,13 @@ const Page = async () => {
|
||||
<h2 className="text-center text-xl text-strong font-bold">
|
||||
{cloud ? "Complete your sign up" : "Sign Up for Danswer"}
|
||||
</h2>
|
||||
{cloud && (
|
||||
<>
|
||||
<div className="w-full flex flex-col items-center space-y-4 mb-4 mt-4">
|
||||
<ReferralSourceSelector />
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{cloud && authUrl && (
|
||||
<div className="w-full justify-center">
|
||||
|
@@ -43,7 +43,11 @@ export const basicLogin = async (
|
||||
return response;
|
||||
};
|
||||
|
||||
export const basicSignup = async (email: string, password: string) => {
|
||||
export const basicSignup = async (
|
||||
email: string,
|
||||
password: string,
|
||||
referralSource?: string
|
||||
) => {
|
||||
const response = await fetch("/api/auth/register", {
|
||||
method: "POST",
|
||||
credentials: "include",
|
||||
@@ -54,6 +58,7 @@ export const basicSignup = async (email: string, password: string) => {
|
||||
email,
|
||||
username: email,
|
||||
password,
|
||||
referral_source: referralSource,
|
||||
}),
|
||||
});
|
||||
return response;
|
||||
|
@@ -63,7 +63,11 @@ const getOIDCAuthUrlSS = async (nextUrl: string | null): Promise<string> => {
|
||||
};
|
||||
|
||||
const getGoogleOAuthUrlSS = async (): Promise<string> => {
|
||||
const res = await fetch(buildUrl(`/auth/oauth/authorize`));
|
||||
const res = await fetch(buildUrl(`/auth/oauth/authorize`), {
|
||||
headers: {
|
||||
cookie: processCookies(await cookies()),
|
||||
},
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error("Failed to fetch data");
|
||||
}
|
||||
|
Reference in New Issue
Block a user