From cb0a1e4fdca4d72a5403dd44c52d684c91b58f96 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 29 Aug 2024 17:25:43 -0700 Subject: [PATCH] functional porting over user --- .../alembic/versions/5be3aa848ce8_testing.py | 317 ------------------ ...25c363470f3_add_tenant_id_to_user_model.py | 24 ++ backend/danswer/auth/users.py | 3 +- .../connectors/confluence/connector.py | 1 - backend/danswer/db/models.py | 3 +- backend/danswer/server/settings/api.py | 33 +- web/src/app/auth/sso-callback/page.tsx | 9 +- web/src/app/ee/admin/plan/BillingSettings.tsx | 73 +++- web/src/app/ee/admin/plan/page.tsx | 6 +- 9 files changed, 96 insertions(+), 373 deletions(-) delete mode 100644 backend/alembic/versions/5be3aa848ce8_testing.py create mode 100644 backend/alembic/versions/b25c363470f3_add_tenant_id_to_user_model.py diff --git a/backend/alembic/versions/5be3aa848ce8_testing.py b/backend/alembic/versions/5be3aa848ce8_testing.py deleted file mode 100644 index 4c8c74bc6..000000000 --- a/backend/alembic/versions/5be3aa848ce8_testing.py +++ /dev/null @@ -1,317 +0,0 @@ -"""testing - -Revision ID: 5be3aa848ce8 -Revises: bceb1e139447 -Create Date: 2024-08-28 17:15:06.247199 - -""" -from alembic import op -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - -# revision identifiers, used by Alembic. -revision = "5be3aa848ce8" -down_revision = "bceb1e139447" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "chat_message__standard_answer", - sa.Column("chat_message_id", sa.Integer(), nullable=False), - sa.Column("standard_answer_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["chat_message_id"], - ["chat_message.id"], - ), - sa.ForeignKeyConstraint( - ["standard_answer_id"], - ["standard_answer.id"], - ), - sa.PrimaryKeyConstraint("chat_message_id", "standard_answer_id"), - ) - op.drop_table("kombu_queue") - op.drop_index("ix_kombu_message_timestamp", table_name="kombu_message") - op.drop_index("ix_kombu_message_timestamp_id", table_name="kombu_message") - op.drop_index("ix_kombu_message_visible", table_name="kombu_message") - op.drop_table("kombu_message") - op.create_foreign_key(None, "api_key", "user", ["user_id"], ["id"]) - op.create_foreign_key(None, "api_key", "user", ["owner_id"], ["id"]) - op.alter_column( - "chat_folder", - "display_priority", - existing_type=sa.INTEGER(), - nullable=True, - ) - op.drop_constraint("chat_message_id_key", "chat_message", type_="unique") - op.alter_column( - "credential", - "source", - existing_type=sa.VARCHAR(length=100), - nullable=False, - ) - op.alter_column( - "credential", - "credential_json", - existing_type=postgresql.BYTEA(), - nullable=False, - ) - op.drop_index( - "ix_document_by_connector_credential_pair_pkey__connecto_27dc", - table_name="document_by_connector_credential_pair", - ) - op.alter_column( - "document_set__user", "user_id", existing_type=sa.UUID(), nullable=True - ) - op.add_column( - "email_to_external_user_cache", - sa.Column( - "source_type", - sa.Enum( - "INGESTION_API", - "SLACK", - "WEB", - "GOOGLE_DRIVE", - "GMAIL", - "REQUESTTRACKER", - "GITHUB", - "GITLAB", - "GURU", - "BOOKSTACK", - "CONFLUENCE", - "SLAB", - "JIRA", - "PRODUCTBOARD", - "FILE", - "NOTION", - "ZULIP", - "LINEAR", - "HUBSPOT", - "DOCUMENT360", - "GONG", - "GOOGLE_SITES", - "ZENDESK", - "LOOPIO", - "DROPBOX", - "SHAREPOINT", - "TEAMS", - "SALESFORCE", - "DISCOURSE", - "AXERO", - "CLICKUP", - "MEDIAWIKI", - "WIKIPEDIA", - "S3", - "R2", - "GOOGLE_CLOUD_STORAGE", - "OCI_STORAGE", - "NOT_APPLICABLE", - name="documentsource", - native_enum=False, - ), - nullable=False, - ), - ) - op.alter_column( - "inputprompt__user", - "user_id", - existing_type=sa.INTEGER(), - nullable=True, - ) - op.alter_column( - "llm_provider", "provider", existing_type=sa.VARCHAR(), nullable=False - ) - op.alter_column("persona__user", "user_id", existing_type=sa.UUID(), nullable=True) - op.alter_column( - "saml", - "expires_at", - existing_type=postgresql.TIMESTAMP(timezone=True), - nullable=False, - ) - op.alter_column( - "search_settings", - "query_prefix", - existing_type=sa.VARCHAR(), - nullable=True, - ) - op.alter_column( - "search_settings", - "passage_prefix", - existing_type=sa.VARCHAR(), - nullable=True, - ) - op.alter_column( - "search_settings", "status", existing_type=sa.VARCHAR(), nullable=False - ) - op.create_index( - "ix_embedding_model_future_unique", - "search_settings", - ["status"], - unique=True, - postgresql_where=sa.text("status = 'FUTURE'"), - ) - op.create_index( - "ix_embedding_model_present_unique", - "search_settings", - ["status"], - unique=True, - postgresql_where=sa.text("status = 'PRESENT'"), - ) - op.drop_constraint("standard_answer_keyword_key", "standard_answer", type_="unique") - op.create_index( - "unique_keyword_active", - "standard_answer", - ["keyword", "active"], - unique=True, - postgresql_where=sa.text("active = true"), - ) - op.alter_column( - "tool_call", - "tool_result", - existing_type=postgresql.JSONB(astext_type=sa.Text()), - nullable=True, - ) - op.alter_column( - "user__user_group", "user_id", existing_type=sa.UUID(), nullable=True - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.alter_column( - "user__user_group", "user_id", existing_type=sa.UUID(), nullable=False - ) - op.alter_column( - "tool_call", - "tool_result", - existing_type=postgresql.JSONB(astext_type=sa.Text()), - nullable=False, - ) - op.drop_index( - "unique_keyword_active", - table_name="standard_answer", - postgresql_where=sa.text("active = true"), - ) - op.create_unique_constraint( - "standard_answer_keyword_key", "standard_answer", ["keyword"] - ) - op.drop_index( - "ix_embedding_model_present_unique", - table_name="search_settings", - postgresql_where=sa.text("status = 'PRESENT'"), - ) - op.drop_index( - "ix_embedding_model_future_unique", - table_name="search_settings", - postgresql_where=sa.text("status = 'FUTURE'"), - ) - op.alter_column( - "search_settings", "status", existing_type=sa.VARCHAR(), nullable=True - ) - op.alter_column( - "search_settings", - "passage_prefix", - existing_type=sa.VARCHAR(), - nullable=False, - ) - op.alter_column( - "search_settings", - "query_prefix", - existing_type=sa.VARCHAR(), - nullable=False, - ) - op.alter_column( - "saml", - "expires_at", - existing_type=postgresql.TIMESTAMP(timezone=True), - nullable=True, - ) - op.alter_column("persona__user", "user_id", existing_type=sa.UUID(), nullable=False) - op.alter_column( - "llm_provider", "provider", existing_type=sa.VARCHAR(), nullable=True - ) - op.alter_column( - "inputprompt__user", - "user_id", - existing_type=sa.INTEGER(), - nullable=False, - ) - op.drop_column("email_to_external_user_cache", "source_type") - op.alter_column( - "document_set__user", - "user_id", - existing_type=sa.UUID(), - nullable=False, - ) - op.create_index( - "ix_document_by_connector_credential_pair_pkey__connecto_27dc", - "document_by_connector_credential_pair", - ["connector_id", "credential_id"], - unique=False, - ) - op.alter_column( - "credential", - "credential_json", - existing_type=postgresql.BYTEA(), - nullable=True, - ) - op.alter_column( - "credential", - "source", - existing_type=sa.VARCHAR(length=100), - nullable=True, - ) - op.create_unique_constraint("chat_message_id_key", "chat_message", ["id"]) - op.alter_column( - "chat_folder", - "display_priority", - existing_type=sa.INTEGER(), - nullable=False, - ) - op.drop_constraint(None, "api_key", type_="foreignkey") - op.drop_constraint(None, "api_key", type_="foreignkey") - op.create_table( - "kombu_message", - sa.Column("id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.Column("visible", sa.BOOLEAN(), autoincrement=False, nullable=True), - sa.Column( - "timestamp", - postgresql.TIMESTAMP(), - autoincrement=False, - nullable=True, - ), - sa.Column("payload", sa.TEXT(), autoincrement=False, nullable=False), - sa.Column("version", sa.SMALLINT(), autoincrement=False, nullable=False), - sa.Column("queue_id", sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint( - ["queue_id"], ["kombu_queue.id"], name="FK_kombu_message_queue" - ), - sa.PrimaryKeyConstraint("id", name="kombu_message_pkey"), - ) - op.create_index( - "ix_kombu_message_visible", "kombu_message", ["visible"], unique=False - ) - op.create_index( - "ix_kombu_message_timestamp_id", - "kombu_message", - ["timestamp", "id"], - unique=False, - ) - op.create_index( - "ix_kombu_message_timestamp", - "kombu_message", - ["timestamp"], - unique=False, - ) - op.create_table( - "kombu_queue", - sa.Column("id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.Column("name", sa.VARCHAR(length=200), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint("id", name="kombu_queue_pkey"), - sa.UniqueConstraint("name", name="kombu_queue_name_key"), - ) - op.drop_table("chat_message__standard_answer") - # ### end Alembic commands ### diff --git a/backend/alembic/versions/b25c363470f3_add_tenant_id_to_user_model.py b/backend/alembic/versions/b25c363470f3_add_tenant_id_to_user_model.py new file mode 100644 index 000000000..ce495a0e2 --- /dev/null +++ b/backend/alembic/versions/b25c363470f3_add_tenant_id_to_user_model.py @@ -0,0 +1,24 @@ +"""add tenant id to user model + +Revision ID: b25c363470f3 +Revises: 1f60f60c3401 +Create Date: 2024-08-29 17:03:20.794120 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "b25c363470f3" +down_revision = "1f60f60c3401" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column("user", sa.Column("tenant_id", sa.Text(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("user", "tenant_id") diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 1ef3198d2..dfd51c805 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -276,6 +276,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): ) -> models.UP: verify_email_is_invited(user_create.email) verify_email_domain(user_create.email) + print("CReATING") if hasattr(user_create, "role"): user_count = await get_user_count() if user_count == 0 or user_create.email in get_default_admin_user_emails(): @@ -360,7 +361,7 @@ async def sso_authenticate( ) -> models.UP: user = await self.get_by_email(email) if not user: - user_create = UserCreate(UserRole.BASIC) + user_create = UserCreate(role=UserRole.BASIC) user = await self.create(user_create) # Update user with tenant information if needed diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index b8dc967a3..aad812638 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -239,7 +239,6 @@ def _datetime_from_string(datetime_string: str) -> datetime: else: # If not in UTC, translate it datetime_object = datetime_object.astimezone(timezone.utc) - return datetime_object diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 3cdec3239..9cd5c6f94 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -128,10 +128,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base): oidc_expiry: Mapped[datetime.datetime] = mapped_column( TIMESTAMPAware(timezone=True), nullable=True ) + tenant_id: Mapped[str] = mapped_column(Text, nullable=True) - default_model: Mapped[str] = mapped_column(Text, nullable=True) # organized in typical structured fashion # formatted as `displayName__provider__modelName` + default_model: Mapped[str] = mapped_column(Text, nullable=True) # relationships credentials: Mapped[list["Credential"]] = relationship( diff --git a/backend/danswer/server/settings/api.py b/backend/danswer/server/settings/api.py index 95ece20e0..1f82ed8e6 100644 --- a/backend/danswer/server/settings/api.py +++ b/backend/danswer/server/settings/api.py @@ -59,16 +59,13 @@ async def sso_callback( print("SSO callback reached") payload = verify_sso_token(sso_token) - print("hi") + user = await get_or_create_user( payload["email"], payload["user_id"], payload["tenant_id"] ) - print(user) + session_token = await create_user_session(user, strategy) - print("Session creation attempt completed") - logger.info( - f"Session token created: {session_token[:10]}..." - ) # Log first 10 chars for security + logger.info(f"Session token created: {session_token[:10]}...") logger.info(f"User email: {user.email}") logger.info(f"User ID: {user.id}") logger.info(f"User role: {user.role}") @@ -86,30 +83,6 @@ async def sso_callback( return response -# @basic_router.post("/auth/sso-callback") -# async def sso_callback( -# user = Depends(current_user), -# token: str = Depends(oauth2_scheme), -# strategy: Strategy = Depends(get_database_strategy), -# user_manager: UserManager = Depends(get_user_manager), -# ): -# print('SSO callback reached') - -# payload = verify_sso_token(token) -# user = await get_or_create_user(payload["email"], payload["user_id"], payload["tenant_id"]) -# session_token = await create_user_session(user, strategy) - -# response = RedirectResponse(url="/") -# response.set_cookie( -# key="session", -# value=session_token, -# httponly=True, -# max_age=SESSION_EXPIRE_TIME_SECONDS, -# secure=WEB_DOMAIN.startswith("https"), -# ) -# return response - - @admin_router.put("") def put_settings( settings: Settings, _: User | None = Depends(current_admin_user) diff --git a/web/src/app/auth/sso-callback/page.tsx b/web/src/app/auth/sso-callback/page.tsx index 96ede2ce5..a1c42fffe 100644 --- a/web/src/app/auth/sso-callback/page.tsx +++ b/web/src/app/auth/sso-callback/page.tsx @@ -2,11 +2,7 @@ import { useEffect, useState } from "react"; import { useRouter, useSearchParams } from "next/navigation"; import { Card, Text } from "@tremor/react"; -import { Spinner } from "@/components/Spinner"; -import { SpinnerBall } from "@phosphor-icons/react/dist/ssr"; -import LogoType from "@/components/header/LogoType"; import { Logo } from "@/components/Logo"; -import { HeaderTitle } from "@/components/header/HeaderTitle"; export default function SSOCallback() { const router = useRouter(); @@ -39,7 +35,7 @@ export default function SSOCallback() { setTimeout(() => { setAuthStatus("Redirecting to dashboard..."); setTimeout(() => { - router.push("/admin/plan"); + router.replace("/admin/plan"); }, 1000); }, 1000); } else { @@ -78,9 +74,8 @@ export default function SSOCallback() { ) : (
- +
- {authStatus} diff --git a/web/src/app/ee/admin/plan/BillingSettings.tsx b/web/src/app/ee/admin/plan/BillingSettings.tsx index a6d529d0d..b9f7c5139 100644 --- a/web/src/app/ee/admin/plan/BillingSettings.tsx +++ b/web/src/app/ee/admin/plan/BillingSettings.tsx @@ -1,14 +1,18 @@ "use client"; import { BillingPlanType } from "@/app/admin/settings/interfaces"; -import { useContext, useState } from "react"; +import { useContext, useEffect, useState } from "react"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import { Button, Divider, Text, Card } from "@tremor/react"; import { StripeCheckoutButton } from "./StripeCheckoutButton"; import { CheckmarkIcon, XIcon } from "@/components/icons/icons"; -import { FiAward, FiDollarSign, FiStar } from "react-icons/fi"; +import { FiAward, FiDollarSign, FiHelpCircle, FiStar } from "react-icons/fi"; +import Cookies from "js-cookie"; +import { Modal } from "@/components/Modal"; +import { Logo } from "@/components/Logo"; +import Link from "next/link"; -export function BillingSettings() { +export function BillingSettings({ newUser }: { newUser: boolean }) { const settings = useContext(SettingsContext); const cloudSettings = settings?.cloudSettings; @@ -53,6 +57,7 @@ export function BillingSettings() { const [newSeats, setNewSeats] = useState(seats); const [newPlan, setNewPlan] = useState(currentPlan); const [isOpen, setIsOpen] = useState(false); + const [isNewUserOpen, setIsNewUserOpen] = useState(true); function getBillingPlanIcon(planType: BillingPlanType) { switch (planType) { @@ -67,11 +72,49 @@ export function BillingSettings() { } } + const handleCloseModal = () => { + setIsNewUserOpen(false); + Cookies.set("new_auth_user", "false"); + }; + return (
+ {newUser && isNewUserOpen && ( + + <> +

+ Welcome to Danswer! +

+
+ +

+ We're thrilled to have you on board! Here, you can manage your + billing settings and explore your plan details. +

+
+
+ + +
+ +
+ )}
-

+

Your Plan
-

+

-

+

-