mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-18 08:41:04 +02:00
805 lines
27 KiB
Python
805 lines
27 KiB
Python
import re
|
|
from datetime import datetime
|
|
from datetime import timezone
|
|
|
|
import jwt
|
|
from email_validator import EmailNotValidError
|
|
from email_validator import EmailUndeliverableError
|
|
from email_validator import validate_email
|
|
from fastapi import APIRouter
|
|
from fastapi import Body
|
|
from fastapi import Depends
|
|
from fastapi import HTTPException
|
|
from fastapi import Query
|
|
from fastapi import Request
|
|
from psycopg2.errors import UniqueViolation
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import Column
|
|
from sqlalchemy import desc
|
|
from sqlalchemy import select
|
|
from sqlalchemy import update
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import Session
|
|
|
|
from ee.onyx.configs.app_configs import SUPER_USERS
|
|
from onyx.auth.email_utils import send_user_email_invite
|
|
from onyx.auth.invited_users import get_invited_users
|
|
from onyx.auth.invited_users import write_invited_users
|
|
from onyx.auth.noauth_user import fetch_no_auth_user
|
|
from onyx.auth.noauth_user import set_no_auth_user_preferences
|
|
from onyx.auth.schemas import UserRole
|
|
from onyx.auth.users import anonymous_user_enabled
|
|
from onyx.auth.users import current_admin_user
|
|
from onyx.auth.users import current_curator_or_admin_user
|
|
from onyx.auth.users import current_user
|
|
from onyx.auth.users import optional_minimal_user
|
|
from onyx.auth.users import optional_user
|
|
from onyx.configs.app_configs import AUTH_TYPE
|
|
from onyx.configs.app_configs import DEV_MODE
|
|
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
|
|
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
|
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
|
|
from onyx.configs.constants import AuthType
|
|
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
|
from onyx.db.api_key import is_api_key_email_address
|
|
from onyx.db.auth import get_total_users_count
|
|
from onyx.db.engine import get_session
|
|
from onyx.db.models import AccessToken
|
|
from onyx.db.models import MinimalUser
|
|
from onyx.db.models import User
|
|
from onyx.db.users import delete_user_from_db
|
|
from onyx.db.users import get_all_users
|
|
from onyx.db.users import get_page_of_filtered_users
|
|
from onyx.db.users import get_total_filtered_users_count
|
|
from onyx.db.users import get_user_by_email
|
|
from onyx.db.users import validate_user_role_update
|
|
from onyx.key_value_store.factory import get_kv_store
|
|
from onyx.server.documents.models import PaginatedReturn
|
|
from onyx.server.manage.models import AllUsersResponse
|
|
from onyx.server.manage.models import AutoScrollRequest
|
|
from onyx.server.manage.models import MinimalUserInfo
|
|
from onyx.server.manage.models import UserByEmail
|
|
from onyx.server.manage.models import UserInfo
|
|
from onyx.server.manage.models import UserPreferences
|
|
from onyx.server.manage.models import UserRoleResponse
|
|
from onyx.server.manage.models import UserRoleUpdateRequest
|
|
from onyx.server.models import FullUserSnapshot
|
|
from onyx.server.models import InvitedUserSnapshot
|
|
from onyx.server.models import MinimalUserSnapshot
|
|
from onyx.server.utils import BasicAuthenticationError
|
|
from onyx.utils.logger import setup_logger
|
|
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
|
from shared_configs.configs import MULTI_TENANT
|
|
from shared_configs.contextvars import get_current_tenant_id
|
|
|
|
# Module-level logger used by the endpoints in this file.
logger = setup_logger()

# Router object that the endpoints below attach themselves to via decorators.
router = APIRouter()

# Page length used when slicing results in the paginated `/manage/users` listing.
USERS_PAGE_SIZE = 10
|
|
|
|
|
|
@router.patch("/manage/set-user-role")
def set_user_role(
    user_role_update_request: UserRoleUpdateRequest,
    current_user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Assign a new role to the user identified by the request's email.

    Args:
        user_role_update_request: Carries `user_email` and `new_role`.
        current_user: The caller, as resolved by the `current_admin_user`
            dependency.
        db_session: Database session used to look up and persist the user.

    Raises:
        HTTPException: 404 when no user has the given email; 400 when the
            target of the change is the caller themselves.
        Exception: Whatever `validate_user_role_update` raises for a role
            transition it rejects.
    """
    target_user = get_user_by_email(
        email=user_role_update_request.user_email, db_session=db_session
    )
    if not target_user:
        raise HTTPException(status_code=404, detail="User not found")

    previous_role = target_user.role
    new_role = user_role_update_request.new_role

    # Nothing to persist when the role is not actually changing.
    if new_role == previous_role:
        return

    # This will raise an exception if the role update is invalid
    validate_user_role_update(
        requested_role=new_role,
        current_role=previous_role,
    )

    # Reached only for an actual role change, so a caller targeting their own
    # account is always rejected here.
    if target_user.id == current_user.id:
        raise HTTPException(
            status_code=400,
            detail="An admin cannot demote themselves from admin role!",
        )

    if new_role == UserRole.CURATOR:
        # Remove all curator db relationships before changing role
        remove_curator_status = fetch_ee_implementation_or_noop(
            "onyx.db.user_group",
            "remove_curator_status__no_commit",
        )
        remove_curator_status(db_session, target_user)

    target_user.role = new_role
    db_session.commit()
|
|
|
|
|
|
@router.get("/manage/users/accepted")
def list_accepted_users(
    q: str | None = Query(default=None),
    page_num: int = Query(0, ge=0),
    page_size: int = Query(10, ge=1, le=1000),
    roles: list[UserRole] = Query(default=[]),
    is_active: bool | None = Query(default=None),
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> PaginatedReturn[FullUserSnapshot]:
    """Return one page of users matching the given filters.

    Args:
        q: Optional email filter string passed to the user queries.
        page_num: Zero-based page index (must be >= 0).
        page_size: Number of users per page (1 to 1000).
        roles: Roles to filter on.
        is_active: Optional active/inactive filter.
        _: Result of the `current_admin_user` dependency (unused).
        db_session: Database session used for both queries.

    Returns:
        The page of user snapshots together with the total number of users
        matching the filters. When the requested page is empty, `items` is
        empty and `total_items` is reported as 0.
    """
    page_of_users = get_page_of_filtered_users(
        db_session=db_session,
        page_size=page_size,
        page_num=page_num,
        email_filter_string=q,
        is_active_filter=is_active,
        roles_filter=roles,
    )

    # The total is always computed, even if the page turns out to be empty.
    matching_user_count = get_total_filtered_users_count(
        db_session=db_session,
        email_filter_string=q,
        is_active_filter=is_active,
        roles_filter=roles,
    )

    if page_of_users:
        snapshots = [
            FullUserSnapshot.from_user_model(db_user) for db_user in page_of_users
        ]
        return PaginatedReturn(
            items=snapshots,
            total_items=matching_user_count,
        )

    logger.info("No users found")
    return PaginatedReturn(
        items=[],
        total_items=0,
    )
|
|
|
|
|
|
@router.get("/manage/users/invited")
def list_invited_users(
    _: User | None = Depends(current_admin_user),
) -> list[InvitedUserSnapshot]:
    """Return a snapshot for every email yielded by `get_invited_users()`."""
    return [
        InvitedUserSnapshot(email=invited_email)
        for invited_email in get_invited_users()
    ]
|
|
|
|
|
|
def _full_user_snapshot(user: User) -> FullUserSnapshot:
    """Build the API snapshot returned for a single user row."""
    return FullUserSnapshot(
        id=user.id,
        email=user.email,
        role=user.role,
        is_active=user.is_active,
        password_configured=user.password_configured,
    )


def _users_page(items: list, page: int) -> list:
    """Return the `page`-th slice of `items`, `USERS_PAGE_SIZE` entries long."""
    start = page * USERS_PAGE_SIZE
    return items[start : start + USERS_PAGE_SIZE]


@router.get("/manage/users")
def list_all_users(
    q: str | None = None,
    accepted_page: int | None = None,
    slack_users_page: int | None = None,
    invited_page: int | None = None,
    include_api_keys: bool = False,
    _: User | None = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> AllUsersResponse:
    """List accepted users, Slack users and invited emails.

    Args:
        q: Optional search string. It is forwarded to `get_all_users` as
            `email_filter_string` and is also matched, case-insensitively and
            as literal text, against the invited emails.
        accepted_page: Zero-based page index for the accepted users.
        slack_users_page: Zero-based page index for the Slack users.
        invited_page: Zero-based page index for the invited emails.
        include_api_keys: When False, users whose email satisfies
            `is_api_key_email_address` are left out.
        _: Result of the `current_curator_or_admin_user` dependency (unused).
        db_session: Database session used to load the users.

    Returns:
        An `AllUsersResponse`. If any of the three page parameters is omitted
        the complete lists are returned and every page count is 1; otherwise
        each list is cut to its requested page of `USERS_PAGE_SIZE` entries
        and the page counts are computed from the list sizes.
    """
    users = [
        user
        for user in get_all_users(db_session, email_filter_string=q)
        if (include_api_keys or not is_api_key_email_address(user.email))
    ]

    slack_users = [user for user in users if user.role == UserRole.SLACK_USER]
    accepted_users = [user for user in users if user.role != UserRole.SLACK_USER]

    accepted_emails = {user.email for user in accepted_users}
    slack_users_emails = {user.email for user in slack_users}

    invited_emails = get_invited_users()
    if q:
        # `q` comes straight from the query string, so it is escaped and
        # matched as literal text rather than interpreted as a regular
        # expression. This prevents invalid patterns (e.g. "(") from raising
        # `re.error` and prevents pathological patterns from being executed.
        email_pattern = re.compile(re.escape(q), re.IGNORECASE)
        invited_emails = [
            email for email in invited_emails if email_pattern.search(email)
        ]

    # Counts are taken over distinct emails.
    accepted_count = len(accepted_emails)
    slack_users_count = len(slack_users_emails)
    invited_count = len(invited_emails)

    # If any of accepted_page, invited_page or slack_users_page is None,
    # return all users without pagination.
    if accepted_page is None or invited_page is None or slack_users_page is None:
        return AllUsersResponse(
            accepted=[_full_user_snapshot(user) for user in accepted_users],
            slack_users=[_full_user_snapshot(user) for user in slack_users],
            invited=[InvitedUserSnapshot(email=email) for email in invited_emails],
            accepted_pages=1,
            invited_pages=1,
            slack_users_pages=1,
        )

    # Otherwise, return paginated results. Each list is sliced first so that
    # snapshots are only built for the entries that are actually returned.
    return AllUsersResponse(
        accepted=[
            _full_user_snapshot(user)
            for user in _users_page(accepted_users, accepted_page)
        ],
        slack_users=[
            _full_user_snapshot(user)
            for user in _users_page(slack_users, slack_users_page)
        ],
        invited=[
            InvitedUserSnapshot(email=email)
            for email in _users_page(list(invited_emails), invited_page)
        ],
        # Ceiling division: a partially filled last page still counts.
        accepted_pages=(accepted_count + USERS_PAGE_SIZE - 1) // USERS_PAGE_SIZE,
        invited_pages=(invited_count + USERS_PAGE_SIZE - 1) // USERS_PAGE_SIZE,
        slack_users_pages=(slack_users_count + USERS_PAGE_SIZE - 1) // USERS_PAGE_SIZE,
    )
|
|
|
|
|
|
@router.put("/manage/admin/users")
def bulk_invite_users(
    emails: list[str] = Body(..., embed=True),
    current_user: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> int:
    """Invite a batch of email addresses.

    Every address is validated with `validate_email` first. If any address
    fails validation, no emails are invited and an exception is raised.

    Args:
        emails: Addresses to invite (sent as an embedded JSON body field).
        current_user: The caller, as resolved by `current_admin_user`; the
            request is rejected when it is None.
        db_session: Database session used for the total user count that is
            reported for billing.

    Returns:
        The value returned by `write_invited_users` for the merged invite list.

    Raises:
        HTTPException: 400 when `current_user` is None, when an address is
            invalid, or when adding the users to the tenant hits a unique
            violation.
        Exception: Any failure while registering tenant users is re-raised
            after the invite list and tenant mapping have been reverted.
    """
    tenant_id = get_current_tenant_id()

    if current_user is None:
        raise HTTPException(
            status_code=400, detail="Auth is disabled, cannot invite users"
        )

    new_invited_emails: list[str] = []
    email: str

    try:
        for email in emails:
            email_info = validate_email(email)
            new_invited_emails.append(email_info.normalized)

    except (EmailUndeliverableError, EmailNotValidError) as e:
        # `email` still holds the address that failed validation.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid email address: {email} - {str(e)}",
        ) from e

    # Drop repeated addresses (order preserved) so that nobody is added to the
    # tenant or emailed twice when the request lists the same address twice.
    new_invited_emails = list(dict.fromkeys(new_invited_emails))

    if MULTI_TENANT and not DEV_MODE:
        try:
            fetch_ee_implementation_or_noop(
                "onyx.server.tenants.provisioning", "add_users_to_tenant", None
            )(new_invited_emails, tenant_id)

        except IntegrityError as e:
            if isinstance(e.orig, UniqueViolation):
                raise HTTPException(
                    status_code=400,
                    detail="User has already been invited to a Onyx organization",
                ) from e
            raise
        except Exception as e:
            # Best effort: other failures are logged and the invite continues.
            logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}")

    # Kept so the invite list can be restored if billing registration fails.
    initial_invited_users = get_invited_users()

    all_emails = list(set(new_invited_emails) | set(initial_invited_users))
    number_of_invited_users = write_invited_users(all_emails)

    # send out email invitations if enabled
    if ENABLE_EMAIL_INVITES:
        try:
            for email in new_invited_emails:
                send_user_email_invite(email, current_user, AUTH_TYPE)
        except Exception as e:
            # A mail failure is logged only; the invite itself stays in place.
            logger.error(f"Error sending email invite to invited users: {e}")

    if not MULTI_TENANT:
        return number_of_invited_users

    # for billing purposes, write to the control plane about the number of new users
    try:
        logger.info("Registering tenant users")
        fetch_ee_implementation_or_noop(
            "onyx.server.tenants.billing", "register_tenant_users", None
        )(tenant_id, get_total_users_count(db_session))

        return number_of_invited_users
    except Exception as e:
        logger.error(f"Failed to register tenant users: {str(e)}")
        logger.info(
            "Reverting changes: removing users from tenant and resetting invited users"
        )
        write_invited_users(initial_invited_users)  # Reset to original state
        fetch_ee_implementation_or_noop(
            "onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
        )(new_invited_emails, tenant_id)
        # Bare raise keeps the original traceback intact.
        raise
|
|
|
|
|
|
@router.patch("/manage/admin/remove-invited-user")
def remove_invited_user(
    user_email: UserByEmail,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> int:
    """Remove one email from the invited-users list.

    Args:
        user_email: Wrapper whose `user_email` field is the address to remove.
        _: Result of the `current_admin_user` dependency (unused).
        db_session: Database session used for the total user count reported
            when `MULTI_TENANT` is set.

    Returns:
        The value returned by `write_invited_users` for the remaining emails.

    Raises:
        Exception: Re-raised, after logging, when registering tenant users
            fails. The invite list has already been rewritten at that point.
    """
    tenant_id = get_current_tenant_id()
    email_to_remove = user_email.user_email

    still_invited = [
        invited for invited in get_invited_users() if invited != email_to_remove
    ]

    remove_from_tenant = fetch_ee_implementation_or_noop(
        "onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
    )
    remove_from_tenant([email_to_remove], tenant_id)

    number_of_invited_users = write_invited_users(still_invited)

    if MULTI_TENANT:
        try:
            fetch_ee_implementation_or_noop(
                "onyx.server.tenants.billing", "register_tenant_users", None
            )(tenant_id, get_total_users_count(db_session))
        except Exception:
            logger.error(
                "Request to update number of seats taken in control plane failed. "
                "This may cause synchronization issues/out of date enforcement of seat limits."
            )
            raise

    return number_of_invited_users
|
|
|
|
|
|
@router.patch("/manage/admin/deactivate-user")
def deactivate_user(
    user_email: UserByEmail,
    current_user: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Mark the user with the given email as inactive.

    Args:
        user_email: Wrapper whose `user_email` field identifies the target.
        current_user: The caller, as resolved by `current_admin_user`.
        db_session: Database session used to look up and persist the user.

    Raises:
        HTTPException: 400 when `current_user` is None or when the caller
            targets their own email; 404 when no such user exists.
    """
    if current_user is None:
        raise HTTPException(
            status_code=400, detail="Auth is disabled, cannot deactivate user"
        )

    target_email = user_email.user_email
    if current_user.email == target_email:
        raise HTTPException(status_code=400, detail="You cannot deactivate yourself")

    target_user = get_user_by_email(email=target_email, db_session=db_session)
    if not target_user:
        raise HTTPException(status_code=404, detail="User not found")

    # An already-inactive user is only logged; the write below still happens.
    if target_user.is_active is False:
        logger.warning("{} is already deactivated".format(target_user.email))

    target_user.is_active = False
    db_session.add(target_user)
    db_session.commit()
|
|
|
|
|
|
@router.delete("/manage/admin/delete-user")
async def delete_user(
    user_email: UserByEmail,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Delete a user that has already been deactivated.

    Args:
        user_email: Wrapper whose `user_email` field identifies the target.
        _: Result of the `current_admin_user` dependency (unused).
        db_session: Database session used for the lookup and the deletion.

    Raises:
        HTTPException: 404 when no such user exists; 400 when the user is
            still active; 500 when `delete_user_from_db` raises (the session
            is rolled back first).
    """
    user_to_delete = get_user_by_email(
        email=user_email.user_email, db_session=db_session
    )
    if not user_to_delete:
        raise HTTPException(status_code=404, detail="User not found")

    still_active = user_to_delete.is_active is True
    if still_active:
        logger.warning(
            "{} must be deactivated before deleting".format(user_to_delete.email)
        )
        raise HTTPException(
            status_code=400, detail="User must be deactivated before deleting"
        )

    # Detach the user from the current session
    db_session.expunge(user_to_delete)

    try:
        delete_user_from_db(user_to_delete, db_session)
        logger.info(f"Deleted user {user_to_delete.email}")
    except Exception as e:
        db_session.rollback()
        logger.error(f"Error deleting user {user_to_delete.email}: {str(e)}")
        raise HTTPException(status_code=500, detail="Error deleting user")
|
|
|
|
|
|
@router.patch("/manage/admin/activate-user")
def activate_user(
    user_email: UserByEmail,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Mark the user with the given email as active.

    Args:
        user_email: Wrapper whose `user_email` field identifies the target.
        _: Result of the `current_admin_user` dependency (unused).
        db_session: Database session used to look up and persist the user.

    Raises:
        HTTPException: 404 when no user has the given email.
    """
    target_user = get_user_by_email(
        email=user_email.user_email, db_session=db_session
    )
    if not target_user:
        raise HTTPException(status_code=404, detail="User not found")

    # An already-active user is only logged; the write below still happens.
    if target_user.is_active is True:
        logger.warning("{} is already activated".format(target_user.email))

    target_user.is_active = True
    db_session.add(target_user)
    db_session.commit()
|
|
|
|
|
|
@router.get("/manage/admin/valid-domains")
def get_valid_domains(
    _: User | None = Depends(current_admin_user),
) -> list[str]:
    """Return the configured `VALID_EMAIL_DOMAINS` value unchanged."""
    configured_domains = VALID_EMAIL_DOMAINS
    return configured_domains
|
|
|
|
|
|
"""Endpoints for all"""
|
|
|
|
|
|
@router.get("/users")
def list_all_users_basic_info(
    _: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[MinimalUserSnapshot]:
    """Return the id and email of every user yielded by `get_all_users`."""
    return [
        MinimalUserSnapshot(id=db_user.id, email=db_user.email)
        for db_user in get_all_users(db_session)
    ]
|
|
|
|
|
|
@router.get("/get-user-role")
async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
    """Return the role of the user resolved by the `current_user` dependency.

    Raises:
        ValueError: When the dependency resolved to None.
    """
    if user is not None:
        return UserRoleResponse(role=user.role)
    raise ValueError("Invalid or missing user.")
|
|
|
|
|
|
def get_current_token_expiration_jwt(
    user: User | None, request: Request
) -> datetime | None:
    """Read the expiration time from the JWT stored in the auth cookie.

    Args:
        user: The current user; when None the cookie is not inspected.
        request: Incoming request whose cookies are searched for
            `FASTAPI_USERS_AUTH_COOKIE_NAME`.

    Returns:
        The token's `exp` claim as a timezone-aware UTC datetime, or None when
        there is no user, no cookie, no `exp` claim, or the token cannot be
        decoded. Failures are logged rather than raised.
    """
    if user is None:
        return None

    try:
        # Get the JWT from the cookie
        jwt_token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
        if not jwt_token:
            logger.error("No JWT token found in cookies")
            return None

        # Decode the JWT.
        # NOTE(review): the signature is not verified here, so the decoded
        # claims are untrusted and should not drive authorization decisions.
        decoded_token = jwt.decode(jwt_token, options={"verify_signature": False})

        # Get the 'exp' (expiration) claim from the token
        exp = decoded_token.get("exp")
        if exp is None:
            logger.error("No 'exp' claim found in JWT")
            return None

        # Build an aware UTC datetime; without `tz` the result would be a naive
        # value in the server's local timezone, which cannot be compared with
        # the aware UTC datetimes used elsewhere in this file.
        return datetime.fromtimestamp(exp, tz=timezone.utc)

    except Exception as e:
        logger.error(f"Error decoding JWT: {e}")
        return None
|
|
|
|
|
|
def get_current_token_creation(
    user: User | None, db_session: Session
) -> datetime | None:
    """Return `created_at` of the user's most recently created `AccessToken`.

    Args:
        user: The user whose tokens are searched; None short-circuits.
        db_session: Database session used to run the query.

    Returns:
        The `created_at` value of the newest token row, or None when there is
        no user, no token row, or the query raises. Failures are logged rather
        than raised.
    """
    if user is None:
        return None

    try:
        # Newest token first, at most one row.
        newest_token_query = (
            select(AccessToken)
            .where(AccessToken.user_id == user.id)  # type: ignore
            .order_by(desc(Column("created_at")))
            .limit(1)
        )
        access_token = db_session.execute(newest_token_query).scalar_one_or_none()

        if not access_token:
            logger.error("No AccessToken found for user")
            return None

        return access_token.created_at

    except Exception as e:
        logger.error(f"Error fetching AccessToken: {e}")
        return None
|
|
|
|
|
|
@router.get("/me-info")
def verify_user_attempting_to_login(
    request: Request,
    user: MinimalUser | None = Depends(optional_minimal_user),
) -> MinimalUserInfo:
    """Check that the request carries the authentication cookie.

    NOTE(review): the resolved `user` is never inspected; only the presence of
    a non-empty `FASTAPI_USERS_AUTH_COOKIE_NAME` cookie is checked, and the
    returned object is a placeholder with empty id and email.

    Args:
        request: Incoming request whose cookies are examined.
        user: Result of the `optional_minimal_user` dependency (unused).

    Returns:
        A `MinimalUserInfo` with empty `id` and `email` and `is_active=True`.

    Raises:
        HTTPException: 401 when the auth cookie is missing or empty.
    """
    # Log the names (not the values) of the cookies that were sent.
    cookie_names = list(request.cookies.keys())
    logger.info(f"Available cookies: {cookie_names}")

    auth_cookie = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
    if not auth_cookie:
        raise HTTPException(status_code=401, detail="User not found")

    return MinimalUserInfo(id="", email="", is_active=True)
|
|
|
|
|
|
@router.get("/me")
def verify_user_logged_in(
    user: User | None = Depends(optional_user),
    db_session: Session = Depends(get_session),
) -> UserInfo:
    """Return information about the current user.

    Args:
        user: Result of the `optional_user` dependency; may be None.
        db_session: Database session used to look up the token creation time.

    Returns:
        The `UserInfo` for the logged-in user, or the value of
        `fetch_no_auth_user` when there is no user but auth is disabled or
        anonymous access is enabled for the tenant.

    Raises:
        BasicAuthenticationError: When there is no user and neither fallback
            applies, or when the user's `oidc_expiry` lies in the past.
    """
    tenant_id = get_current_tenant_id()

    # NOTE: this does not use `current_user` / `current_admin_user` because we don't want
    # to enforce user verification here - the frontend always wants to get the info about
    # the current user regardless of if they are currently verified
    if user is None:
        # if auth type is disabled, return a dummy user with preferences from
        # the key-value store
        if AUTH_TYPE == AuthType.DISABLED:
            return fetch_no_auth_user(get_kv_store())

        if anonymous_user_enabled(tenant_id=tenant_id):
            return fetch_no_auth_user(get_kv_store(), anonymous_user_enabled=True)

        raise BasicAuthenticationError(detail="User Not Authenticated")

    oidc_expiry = user.oidc_expiry
    if oidc_expiry and oidc_expiry < datetime.now(timezone.utc):
        raise BasicAuthenticationError(
            detail="Access denied. User's OIDC token has expired.",
        )

    # The token creation time is only looked up when MULTI_TENANT is off.
    if MULTI_TENANT:
        token_created_at = None
    else:
        token_created_at = get_current_token_creation(user, db_session)

    lookup_tenant_for_email = fetch_ee_implementation_or_noop(
        "onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
    )
    organization_name = lookup_tenant_for_email(user.email)

    return UserInfo.from_model(
        user,
        current_token_created_at=token_created_at,
        expiry_length=SESSION_EXPIRE_TIME_SECONDS,
        is_cloud_superuser=user.email in SUPER_USERS,
        organization_name=organization_name,
    )
|
|
|
|
|
|
"""APIs to adjust user preferences"""
|
|
|
|
|
|
@router.patch("/temperature-override-enabled")
def update_user_temperature_override_enabled(
    temperature_override_enabled: bool,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Persist the `temperature_override_enabled` preference.

    With a user, the value is written to that user's row. Without a user it is
    written to the no-auth user's preferences in the key-value store, which is
    only allowed when `AUTH_TYPE` is `AuthType.DISABLED`.

    Raises:
        RuntimeError: When there is no user although auth is not disabled.
    """
    if user is not None:
        db_session.execute(
            update(User)
            .where(User.id == user.id)  # type: ignore
            .values(temperature_override_enabled=temperature_override_enabled)
        )
        db_session.commit()
        return

    if AUTH_TYPE != AuthType.DISABLED:
        raise RuntimeError("This should never happen")

    store = get_kv_store()
    no_auth_user = fetch_no_auth_user(store)
    no_auth_user.preferences.temperature_override_enabled = (
        temperature_override_enabled
    )
    set_no_auth_user_preferences(store, no_auth_user.preferences)
|
|
|
|
|
|
# Request body for `PATCH /user/default-model`.
class ChosenDefaultModelRequest(BaseModel):
    # Value stored as the user's `default_model`; may be None.
    default_model: str | None = None
|
|
|
|
|
|
@router.patch("/shortcut-enabled")
def update_user_shortcut_enabled(
    shortcut_enabled: bool,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Persist the `shortcut_enabled` preference.

    With a user, the value is written to that user's row. Without a user it is
    written to the no-auth user's preferences in the key-value store, which is
    only allowed when `AUTH_TYPE` is `AuthType.DISABLED`.

    Raises:
        RuntimeError: When there is no user although auth is not disabled.
    """
    if user is not None:
        db_session.execute(
            update(User)
            .where(User.id == user.id)  # type: ignore
            .values(shortcut_enabled=shortcut_enabled)
        )
        db_session.commit()
        return

    if AUTH_TYPE != AuthType.DISABLED:
        raise RuntimeError("This should never happen")

    store = get_kv_store()
    no_auth_user = fetch_no_auth_user(store)
    no_auth_user.preferences.shortcut_enabled = shortcut_enabled
    set_no_auth_user_preferences(store, no_auth_user.preferences)
|
|
|
|
|
|
@router.patch("/auto-scroll")
def update_user_auto_scroll(
    request: AutoScrollRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Persist the `auto_scroll` preference carried by the request.

    With a user, the value is written to that user's row. Without a user it is
    written to the no-auth user's preferences in the key-value store, which is
    only allowed when `AUTH_TYPE` is `AuthType.DISABLED`.

    Raises:
        RuntimeError: When there is no user although auth is not disabled.
    """
    if user is not None:
        db_session.execute(
            update(User)
            .where(User.id == user.id)  # type: ignore
            .values(auto_scroll=request.auto_scroll)
        )
        db_session.commit()
        return

    if AUTH_TYPE != AuthType.DISABLED:
        raise RuntimeError("This should never happen")

    store = get_kv_store()
    no_auth_user = fetch_no_auth_user(store)
    no_auth_user.preferences.auto_scroll = request.auto_scroll
    set_no_auth_user_preferences(store, no_auth_user.preferences)
|
|
|
|
|
|
@router.patch("/user/default-model")
def update_user_default_model(
    request: ChosenDefaultModelRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Persist the `default_model` preference carried by the request.

    With a user, the value is written to that user's row. Without a user it is
    written to the no-auth user's preferences in the key-value store, which is
    only allowed when `AUTH_TYPE` is `AuthType.DISABLED`.

    Raises:
        RuntimeError: When there is no user although auth is not disabled.
    """
    if user is not None:
        db_session.execute(
            update(User)
            .where(User.id == user.id)  # type: ignore
            .values(default_model=request.default_model)
        )
        db_session.commit()
        return

    if AUTH_TYPE != AuthType.DISABLED:
        raise RuntimeError("This should never happen")

    store = get_kv_store()
    no_auth_user = fetch_no_auth_user(store)
    no_auth_user.preferences.default_model = request.default_model
    set_no_auth_user_preferences(store, no_auth_user.preferences)
|
|
|
|
|
|
# Request body for `PATCH /user/pinned-assistants`.
class ReorderPinnedAssistantsRequest(BaseModel):
    # Assistant ids stored as the user's `pinned_assistants`, in the given order.
    ordered_assistant_ids: list[int]
|
|
|
|
|
|
@router.patch("/user/pinned-assistants")
def update_user_pinned_assistants(
    request: ReorderPinnedAssistantsRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Replace the stored `pinned_assistants` list with the request's ids.

    With a user, the list is written to that user's row. Without a user it is
    written to the no-auth user's preferences in the key-value store, which is
    only allowed when `AUTH_TYPE` is `AuthType.DISABLED`.

    Raises:
        RuntimeError: When there is no user although auth is not disabled.
    """
    ordered_assistant_ids = request.ordered_assistant_ids

    if user is not None:
        db_session.execute(
            update(User)
            .where(User.id == user.id)  # type: ignore
            .values(pinned_assistants=ordered_assistant_ids)
        )
        db_session.commit()
        return

    if AUTH_TYPE != AuthType.DISABLED:
        raise RuntimeError("This should never happen")

    store = get_kv_store()
    no_auth_user = fetch_no_auth_user(store)
    no_auth_user.preferences.pinned_assistants = ordered_assistant_ids
    set_no_auth_user_preferences(store, no_auth_user.preferences)
|
|
|
|
|
|
# Request model carrying a list of assistant ids.
# NOTE(review): not referenced by any endpoint visible in this file - confirm
# whether it is still used elsewhere.
class ChosenAssistantsRequest(BaseModel):
    chosen_assistants: list[int]
|
|
|
|
|
|
def update_assistant_visibility(
    preferences: UserPreferences, assistant_id: int, show: bool
) -> UserPreferences:
    """Move an assistant between the visible and hidden lists of `preferences`.

    Args:
        preferences: Preferences object to update; it is modified in place.
        assistant_id: Id of the assistant being shown or hidden.
        show: True to list the assistant as visible, False to list it as hidden.

    Returns:
        The same `preferences` object, with both lists set (a missing or empty
        list is replaced by a list reflecting the update).
    """
    shown = preferences.visible_assistants or []
    concealed = preferences.hidden_assistants or []

    if not show:
        # Hiding: drop from the visible list, then record as hidden once.
        if assistant_id in shown:
            shown.remove(assistant_id)
        if assistant_id not in concealed:
            concealed.append(assistant_id)
    else:
        # Showing: record as visible once, then drop from the hidden list.
        if assistant_id not in shown:
            shown.append(assistant_id)
        if assistant_id in concealed:
            concealed.remove(assistant_id)

    preferences.visible_assistants = shown
    preferences.hidden_assistants = concealed
    return preferences
|
|
|
|
|
|
def _record_chosen_assistant(
    preferences: UserPreferences, assistant_id: int
) -> None:
    """Add `assistant_id` to `chosen_assistants` unless it is unset or present."""
    chosen = preferences.chosen_assistants
    # Appending only when absent keeps repeated visibility toggles from
    # growing the list with duplicate ids.
    if chosen is not None and assistant_id not in chosen:
        chosen.append(assistant_id)


@router.patch("/user/assistant-list/update/{assistant_id}")
def update_user_assistant_visibility(
    assistant_id: int,
    show: bool,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Show or hide an assistant for the current user and persist the result.

    With a user, the updated hidden, visible and chosen lists are written to
    that user's row. Without a user they are written to the no-auth user's
    preferences in the key-value store, which is only allowed when `AUTH_TYPE`
    is `AuthType.DISABLED`.

    NOTE(review): the assistant id is added to `chosen_assistants` for both
    `show=True` and `show=False`, as in the original code - confirm that adding
    it when hiding is intended.

    Args:
        assistant_id: Id of the assistant, taken from the URL path.
        show: True to make the assistant visible, False to hide it.
        user: Result of the `current_user` dependency; may be None.
        db_session: Database session used to persist the user's lists.

    Raises:
        RuntimeError: When there is no user although auth is not disabled.
    """
    if user is None:
        if AUTH_TYPE != AuthType.DISABLED:
            raise RuntimeError("This should never happen")

        store = get_kv_store()
        no_auth_user = fetch_no_auth_user(store)
        updated_preferences = update_assistant_visibility(
            no_auth_user.preferences, assistant_id, show
        )
        _record_chosen_assistant(updated_preferences, assistant_id)
        set_no_auth_user_preferences(store, updated_preferences)
        return

    user_preferences = UserInfo.from_model(user).preferences
    updated_preferences = update_assistant_visibility(
        user_preferences, assistant_id, show
    )
    _record_chosen_assistant(updated_preferences, assistant_id)

    db_session.execute(
        update(User)
        .where(User.id == user.id)  # type: ignore
        .values(
            hidden_assistants=updated_preferences.hidden_assistants,
            visible_assistants=updated_preferences.visible_assistants,
            chosen_assistants=updated_preferences.chosen_assistants,
        )
    )
    db_session.commit()
|