mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-08-02 21:22:51 +02:00
Add user when they interact outside of UI (e.g. Slack bot) (#2369)
* Add user when they interact outside of UI (e.g. Slack bot) * fix mypy errors * don't use user manager to avoid async messiness * fix email is none scenario * fix mypy * make code slightly clearer * PR comments * get slack email in generate button as well * fix alembic migration * update name to be more descriptive --------- Co-authored-by: Hyeong Joon Suh <hyeongjoonsuh@Hyeongs-MacBook-Pro.local>
This commit is contained in:
@@ -0,0 +1,26 @@
|
|||||||
|
"""add has_web_login column to user
|
||||||
|
|
||||||
|
Revision ID: f7e58d357687
|
||||||
|
Revises: bceb1e139447
|
||||||
|
Create Date: 2024-09-07 20:20:54.522620
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "f7e58d357687"
|
||||||
|
down_revision = "ba98eba0f66a"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"user",
|
||||||
|
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("user", "has_web_login")
|
@@ -33,7 +33,9 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
|||||||
|
|
||||||
class UserCreate(schemas.BaseUserCreate):
|
class UserCreate(schemas.BaseUserCreate):
|
||||||
role: UserRole = UserRole.BASIC
|
role: UserRole = UserRole.BASIC
|
||||||
|
has_web_login: bool | None = True
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(schemas.BaseUserUpdate):
|
class UserUpdate(schemas.BaseUserUpdate):
|
||||||
role: UserRole
|
role: UserRole
|
||||||
|
has_web_login: bool | None = True
|
||||||
|
@@ -16,7 +16,9 @@ from fastapi import HTTPException
|
|||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi import Response
|
from fastapi import Response
|
||||||
from fastapi import status
|
from fastapi import status
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from fastapi_users import BaseUserManager
|
from fastapi_users import BaseUserManager
|
||||||
|
from fastapi_users import exceptions
|
||||||
from fastapi_users import FastAPIUsers
|
from fastapi_users import FastAPIUsers
|
||||||
from fastapi_users import models
|
from fastapi_users import models
|
||||||
from fastapi_users import schemas
|
from fastapi_users import schemas
|
||||||
@@ -33,6 +35,7 @@ from sqlalchemy.orm import Session
|
|||||||
from danswer.auth.invited_users import get_invited_users
|
from danswer.auth.invited_users import get_invited_users
|
||||||
from danswer.auth.schemas import UserCreate
|
from danswer.auth.schemas import UserCreate
|
||||||
from danswer.auth.schemas import UserRole
|
from danswer.auth.schemas import UserRole
|
||||||
|
from danswer.auth.schemas import UserUpdate
|
||||||
from danswer.configs.app_configs import AUTH_TYPE
|
from danswer.configs.app_configs import AUTH_TYPE
|
||||||
from danswer.configs.app_configs import DISABLE_AUTH
|
from danswer.configs.app_configs import DISABLE_AUTH
|
||||||
from danswer.configs.app_configs import EMAIL_FROM
|
from danswer.configs.app_configs import EMAIL_FROM
|
||||||
@@ -184,7 +187,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
user_create: schemas.UC | UserCreate,
|
user_create: schemas.UC | UserCreate,
|
||||||
safe: bool = False,
|
safe: bool = False,
|
||||||
request: Optional[Request] = None,
|
request: Optional[Request] = None,
|
||||||
) -> models.UP:
|
) -> User:
|
||||||
verify_email_is_invited(user_create.email)
|
verify_email_is_invited(user_create.email)
|
||||||
verify_email_domain(user_create.email)
|
verify_email_domain(user_create.email)
|
||||||
if hasattr(user_create, "role"):
|
if hasattr(user_create, "role"):
|
||||||
@@ -193,7 +196,27 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
user_create.role = UserRole.ADMIN
|
user_create.role = UserRole.ADMIN
|
||||||
else:
|
else:
|
||||||
user_create.role = UserRole.BASIC
|
user_create.role = UserRole.BASIC
|
||||||
return await super().create(user_create, safe=safe, request=request) # type: ignore
|
user = None
|
||||||
|
try:
|
||||||
|
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||||
|
except exceptions.UserAlreadyExists:
|
||||||
|
user = await self.get_by_email(user_create.email)
|
||||||
|
# Handle case where user has used product outside of web and is now creating an account through web
|
||||||
|
if (
|
||||||
|
not user.has_web_login
|
||||||
|
and hasattr(user_create, "has_web_login")
|
||||||
|
and user_create.has_web_login
|
||||||
|
):
|
||||||
|
user_update = UserUpdate(
|
||||||
|
password=user_create.password,
|
||||||
|
has_web_login=True,
|
||||||
|
role=user_create.role,
|
||||||
|
is_verified=user_create.is_verified,
|
||||||
|
)
|
||||||
|
user = await self.update(user_update, user)
|
||||||
|
else:
|
||||||
|
raise exceptions.UserAlreadyExists()
|
||||||
|
return user
|
||||||
|
|
||||||
async def oauth_callback(
|
async def oauth_callback(
|
||||||
self: "BaseUserManager[models.UOAP, models.ID]",
|
self: "BaseUserManager[models.UOAP, models.ID]",
|
||||||
@@ -234,6 +257,17 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
|
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
|
||||||
await self.user_db.update(user, update_dict={"oidc_expiry": None})
|
await self.user_db.update(user, update_dict={"oidc_expiry": None})
|
||||||
|
|
||||||
|
# Handle case where user has used product outside of web and is now creating an account through web
|
||||||
|
if not user.has_web_login:
|
||||||
|
await self.user_db.update(
|
||||||
|
user,
|
||||||
|
update_dict={
|
||||||
|
"is_verified": is_verified_by_default,
|
||||||
|
"has_web_login": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
user.is_verified = is_verified_by_default
|
||||||
|
user.has_web_login = True
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def on_after_register(
|
async def on_after_register(
|
||||||
@@ -262,6 +296,22 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
|
|
||||||
send_user_verification_email(user.email, token)
|
send_user_verification_email(user.email, token)
|
||||||
|
|
||||||
|
async def authenticate(
|
||||||
|
self, credentials: OAuth2PasswordRequestForm
|
||||||
|
) -> Optional[User]:
|
||||||
|
user = await super().authenticate(credentials)
|
||||||
|
if user is None:
|
||||||
|
try:
|
||||||
|
user = await self.get_by_email(credentials.username)
|
||||||
|
if not user.has_web_login:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||||
|
)
|
||||||
|
except exceptions.UserNotExists:
|
||||||
|
pass
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def get_user_manager(
|
async def get_user_manager(
|
||||||
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
|
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
|
||||||
|
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
|||||||
from danswer.configs.constants import MessageType
|
from danswer.configs.constants import MessageType
|
||||||
from danswer.configs.constants import SearchFeedbackType
|
from danswer.configs.constants import SearchFeedbackType
|
||||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||||
|
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||||
from danswer.danswerbot.slack.blocks import build_follow_up_resolved_blocks
|
from danswer.danswerbot.slack.blocks import build_follow_up_resolved_blocks
|
||||||
from danswer.danswerbot.slack.blocks import get_document_feedback_blocks
|
from danswer.danswerbot.slack.blocks import get_document_feedback_blocks
|
||||||
@@ -87,6 +88,8 @@ def handle_generate_answer_button(
|
|||||||
message_ts = req.payload["message"]["ts"]
|
message_ts = req.payload["message"]["ts"]
|
||||||
thread_ts = req.payload["container"]["thread_ts"]
|
thread_ts = req.payload["container"]["thread_ts"]
|
||||||
user_id = req.payload["user"]["id"]
|
user_id = req.payload["user"]["id"]
|
||||||
|
expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={})
|
||||||
|
email = expert_info.email if expert_info else None
|
||||||
|
|
||||||
if not thread_ts:
|
if not thread_ts:
|
||||||
raise ValueError("Missing thread_ts in the payload")
|
raise ValueError("Missing thread_ts in the payload")
|
||||||
@@ -125,6 +128,7 @@ def handle_generate_answer_button(
|
|||||||
msg_to_respond=cast(str, message_ts or thread_ts),
|
msg_to_respond=cast(str, message_ts or thread_ts),
|
||||||
thread_to_respond=cast(str, thread_ts or message_ts),
|
thread_to_respond=cast(str, thread_ts or message_ts),
|
||||||
sender=user_id or None,
|
sender=user_id or None,
|
||||||
|
email=email or None,
|
||||||
bypass_filters=True,
|
bypass_filters=True,
|
||||||
is_bot_msg=False,
|
is_bot_msg=False,
|
||||||
is_bot_dm=False,
|
is_bot_dm=False,
|
||||||
|
@@ -21,6 +21,7 @@ from danswer.danswerbot.slack.utils import slack_usage_report
|
|||||||
from danswer.danswerbot.slack.utils import update_emote_react
|
from danswer.danswerbot.slack.utils import update_emote_react
|
||||||
from danswer.db.engine import get_sqlalchemy_engine
|
from danswer.db.engine import get_sqlalchemy_engine
|
||||||
from danswer.db.models import SlackBotConfig
|
from danswer.db.models import SlackBotConfig
|
||||||
|
from danswer.db.users import add_non_web_user_if_not_exists
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||||
|
|
||||||
@@ -209,6 +210,9 @@ def handle_message(
|
|||||||
logger.error(f"Was not able to react to user message due to: {e}")
|
logger.error(f"Was not able to react to user message due to: {e}")
|
||||||
|
|
||||||
with Session(get_sqlalchemy_engine()) as db_session:
|
with Session(get_sqlalchemy_engine()) as db_session:
|
||||||
|
if message_info.email:
|
||||||
|
add_non_web_user_if_not_exists(message_info.email, db_session)
|
||||||
|
|
||||||
# first check if we need to respond with a standard answer
|
# first check if we need to respond with a standard answer
|
||||||
used_standard_answer = handle_standard_answers(
|
used_standard_answer = handle_standard_answers(
|
||||||
message_info=message_info,
|
message_info=message_info,
|
||||||
|
@@ -22,7 +22,6 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
|
|||||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||||
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
|
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
|
||||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
|
||||||
from danswer.danswerbot.slack.blocks import build_documents_blocks
|
from danswer.danswerbot.slack.blocks import build_documents_blocks
|
||||||
from danswer.danswerbot.slack.blocks import build_follow_up_block
|
from danswer.danswerbot.slack.blocks import build_follow_up_block
|
||||||
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
|
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
|
||||||
@@ -103,13 +102,10 @@ def handle_regular_answer(
|
|||||||
is_bot_msg = message_info.is_bot_msg
|
is_bot_msg = message_info.is_bot_msg
|
||||||
user = None
|
user = None
|
||||||
if message_info.is_bot_dm:
|
if message_info.is_bot_dm:
|
||||||
slack_user_info = expert_info_from_slack_id(
|
if message_info.email:
|
||||||
message_info.sender, client, user_cache={}
|
|
||||||
)
|
|
||||||
if slack_user_info and slack_user_info.email:
|
|
||||||
engine = get_sqlalchemy_engine()
|
engine = get_sqlalchemy_engine()
|
||||||
with Session(engine) as db_session:
|
with Session(engine) as db_session:
|
||||||
user = get_user_by_email(slack_user_info.email, db_session)
|
user = get_user_by_email(message_info.email, db_session)
|
||||||
|
|
||||||
document_set_names: list[str] | None = None
|
document_set_names: list[str] | None = None
|
||||||
persona = slack_bot_config.persona if slack_bot_config else None
|
persona = slack_bot_config.persona if slack_bot_config else None
|
||||||
|
@@ -13,6 +13,7 @@ from danswer.configs.constants import MessageType
|
|||||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
|
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
|
||||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
|
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
|
||||||
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
|
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
|
||||||
|
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||||
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
|
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
|
||||||
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||||
@@ -256,6 +257,11 @@ def build_request_details(
|
|||||||
tagged = event.get("type") == "app_mention"
|
tagged = event.get("type") == "app_mention"
|
||||||
message_ts = event.get("ts")
|
message_ts = event.get("ts")
|
||||||
thread_ts = event.get("thread_ts")
|
thread_ts = event.get("thread_ts")
|
||||||
|
sender = event.get("user") or None
|
||||||
|
expert_info = expert_info_from_slack_id(
|
||||||
|
sender, client.web_client, user_cache={}
|
||||||
|
)
|
||||||
|
email = expert_info.email if expert_info else None
|
||||||
|
|
||||||
msg = remove_danswer_bot_tag(msg, client=client.web_client)
|
msg = remove_danswer_bot_tag(msg, client=client.web_client)
|
||||||
|
|
||||||
@@ -286,7 +292,8 @@ def build_request_details(
|
|||||||
channel_to_respond=channel,
|
channel_to_respond=channel,
|
||||||
msg_to_respond=cast(str, message_ts or thread_ts),
|
msg_to_respond=cast(str, message_ts or thread_ts),
|
||||||
thread_to_respond=cast(str, thread_ts or message_ts),
|
thread_to_respond=cast(str, thread_ts or message_ts),
|
||||||
sender=event.get("user") or None,
|
sender=sender,
|
||||||
|
email=email,
|
||||||
bypass_filters=tagged,
|
bypass_filters=tagged,
|
||||||
is_bot_msg=False,
|
is_bot_msg=False,
|
||||||
is_bot_dm=event.get("channel_type") == "im",
|
is_bot_dm=event.get("channel_type") == "im",
|
||||||
@@ -296,6 +303,10 @@ def build_request_details(
|
|||||||
channel = req.payload["channel_id"]
|
channel = req.payload["channel_id"]
|
||||||
msg = req.payload["text"]
|
msg = req.payload["text"]
|
||||||
sender = req.payload["user_id"]
|
sender = req.payload["user_id"]
|
||||||
|
expert_info = expert_info_from_slack_id(
|
||||||
|
sender, client.web_client, user_cache={}
|
||||||
|
)
|
||||||
|
email = expert_info.email if expert_info else None
|
||||||
|
|
||||||
single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)
|
single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)
|
||||||
|
|
||||||
@@ -305,6 +316,7 @@ def build_request_details(
|
|||||||
msg_to_respond=None,
|
msg_to_respond=None,
|
||||||
thread_to_respond=None,
|
thread_to_respond=None,
|
||||||
sender=sender,
|
sender=sender,
|
||||||
|
email=email,
|
||||||
bypass_filters=True,
|
bypass_filters=True,
|
||||||
is_bot_msg=True,
|
is_bot_msg=True,
|
||||||
is_bot_dm=False,
|
is_bot_dm=False,
|
||||||
|
@@ -9,6 +9,7 @@ class SlackMessageInfo(BaseModel):
|
|||||||
msg_to_respond: str | None
|
msg_to_respond: str | None
|
||||||
thread_to_respond: str | None
|
thread_to_respond: str | None
|
||||||
sender: str | None
|
sender: str | None
|
||||||
|
email: str | None
|
||||||
bypass_filters: bool # User has tagged @DanswerBot
|
bypass_filters: bool # User has tagged @DanswerBot
|
||||||
is_bot_msg: bool # User is using /DanswerBot
|
is_bot_msg: bool # User is using /DanswerBot
|
||||||
is_bot_dm: bool # User is direct messaging to DanswerBot
|
is_bot_dm: bool # User is direct messaging to DanswerBot
|
||||||
|
@@ -157,6 +157,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
|||||||
notifications: Mapped[list["Notification"]] = relationship(
|
notifications: Mapped[list["Notification"]] = relationship(
|
||||||
"Notification", back_populates="user"
|
"Notification", back_populates="user"
|
||||||
)
|
)
|
||||||
|
# Whether the user has logged in via web. False if user has only used Danswer through Slack bot
|
||||||
|
has_web_login: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
|
|
||||||
|
|
||||||
class InputPrompt(Base):
|
class InputPrompt(Base):
|
||||||
|
@@ -1,9 +1,11 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi_users.password import PasswordHelper
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from danswer.auth.schemas import UserRole
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
|
|
||||||
|
|
||||||
@@ -30,3 +32,22 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
|
|||||||
user = db_session.query(User).filter(User.id == user_id).first() # type: ignore
|
user = db_session.query(User).filter(User.id == user_id).first() # type: ignore
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def add_non_web_user_if_not_exists(email: str, db_session: Session) -> User:
|
||||||
|
user = get_user_by_email(email, db_session)
|
||||||
|
if user is not None:
|
||||||
|
return user
|
||||||
|
|
||||||
|
fastapi_users_pw_helper = PasswordHelper()
|
||||||
|
password = fastapi_users_pw_helper.generate()
|
||||||
|
hashed_pass = fastapi_users_pw_helper.hash(password)
|
||||||
|
user = User(
|
||||||
|
email=email,
|
||||||
|
hashed_password=hashed_pass,
|
||||||
|
has_web_login=False,
|
||||||
|
role=UserRole.BASIC,
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
return user
|
||||||
|
@@ -65,6 +65,7 @@ async def upsert_saml_user(email: str) -> User:
|
|||||||
password=hashed_pass,
|
password=hashed_pass,
|
||||||
is_verified=True,
|
is_verified=True,
|
||||||
role=role,
|
role=role,
|
||||||
|
has_web_login=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -72,6 +72,8 @@ export function EmailPasswordForm({
|
|||||||
let errorMsg = "Unknown error";
|
let errorMsg = "Unknown error";
|
||||||
if (errorDetail === "LOGIN_BAD_CREDENTIALS") {
|
if (errorDetail === "LOGIN_BAD_CREDENTIALS") {
|
||||||
errorMsg = "Invalid email or password";
|
errorMsg = "Invalid email or password";
|
||||||
|
} else if (errorDetail === "NO_WEB_LOGIN_AND_HAS_NO_PASSWORD") {
|
||||||
|
errorMsg = "Create an account to set a password";
|
||||||
}
|
}
|
||||||
setPopup({
|
setPopup({
|
||||||
type: "error",
|
type: "error",
|
||||||
|
Reference in New Issue
Block a user