mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-08 16:05:32 +02:00
Fix/remove ee from mit (#4682)
* Remove some ee imports * more * Remove all ee imports * Fix * Autodiscover * fix * Fix typing * More celery task stuff * Fix import
This commit is contained in:
@@ -12,12 +12,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.celery_utils import get_deletion_attempt_snapshot
|
||||
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
|
||||
try_creating_permissions_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.external_group_syncing.tasks import (
|
||||
try_creating_external_group_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.pruning.tasks import (
|
||||
try_creating_prune_generator_task,
|
||||
)
|
||||
@@ -392,154 +386,6 @@ def prune_cc_pair(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/sync-permissions")
|
||||
def get_cc_pair_latest_sync(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> datetime | None:
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="cc_pair not found for current user's permissions",
|
||||
)
|
||||
|
||||
return cc_pair.last_time_perm_sync
|
||||
|
||||
|
||||
@router.post("/admin/cc-pair/{cc_pair_id}/sync-permissions")
|
||||
def sync_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse[list[int]]:
|
||||
"""Triggers permissions sync on a particular cc_pair immediately"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if redis_connector.permissions.fenced:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
detail="Permissions sync task already in progress.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Permissions sync cc_pair={cc_pair_id} "
|
||||
f"connector_id={cc_pair.connector_id} "
|
||||
f"credential_id={cc_pair.credential_id} "
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
payload_id = try_creating_permissions_sync_task(
|
||||
client_app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
detail="Permissions sync task creation failed.",
|
||||
)
|
||||
|
||||
logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}")
|
||||
|
||||
return StatusResponse(
|
||||
success=True,
|
||||
message="Successfully created the permissions sync task.",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/sync-groups")
|
||||
def get_cc_pair_latest_group_sync(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> datetime | None:
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="cc_pair not found for current user's permissions",
|
||||
)
|
||||
|
||||
return cc_pair.last_time_external_group_sync
|
||||
|
||||
|
||||
@router.post("/admin/cc-pair/{cc_pair_id}/sync-groups")
|
||||
def sync_cc_pair_groups(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse[list[int]]:
|
||||
"""Triggers group sync on a particular cc_pair immediately"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if redis_connector.external_group_sync.fenced:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
detail="External group sync task already in progress.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"External group sync cc_pair={cc_pair_id} "
|
||||
f"connector_id={cc_pair.connector_id} "
|
||||
f"credential_id={cc_pair.credential_id} "
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
payload_id = try_creating_external_group_sync_task(
|
||||
client_app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
detail="External group sync task creation failed.",
|
||||
)
|
||||
|
||||
logger.info(f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}")
|
||||
|
||||
return StatusResponse(
|
||||
success=True,
|
||||
message="Successfully created the external group sync task.",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/get-docs-sync-status")
|
||||
def get_docs_sync_status(
|
||||
cc_pair_id: int,
|
||||
|
@@ -9,8 +9,6 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from ee.onyx.server.query_history.models import ChatSessionMinimal
|
||||
from onyx.background.indexing.models import IndexAttemptErrorPydantic
|
||||
from onyx.configs.app_configs import MASK_CREDENTIAL_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import InputType
|
||||
@@ -23,8 +21,6 @@ from onyx.db.models import Document as DbDocument
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import TaskStatus
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import InvitedUserSnapshot
|
||||
from onyx.server.utils import mask_credential_dict
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
@@ -193,14 +189,7 @@ class IndexAttemptSnapshot(BaseModel):
|
||||
|
||||
# These are the types currently supported by the pagination hook
|
||||
# More api endpoints can be refactored and be added here for use with the pagination hook
|
||||
PaginatedType = TypeVar(
|
||||
"PaginatedType",
|
||||
IndexAttemptSnapshot,
|
||||
FullUserSnapshot,
|
||||
InvitedUserSnapshot,
|
||||
ChatSessionMinimal,
|
||||
IndexAttemptErrorPydantic,
|
||||
)
|
||||
PaginatedType = TypeVar("PaginatedType", bound=BaseModel)
|
||||
|
||||
|
||||
class PaginatedReturn(BaseModel, Generic[PaginatedType]):
|
||||
|
@@ -1,5 +1,7 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -8,7 +10,6 @@ from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from ee.onyx.server.manage.models import StandardAnswerCategory
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from onyx.configs.constants import AuthType
|
||||
@@ -17,6 +18,8 @@ from onyx.db.models import AllowedAnswerFilters
|
||||
from onyx.db.models import ChannelConfig
|
||||
from onyx.db.models import SlackBot as SlackAppModel
|
||||
from onyx.db.models import SlackChannelConfig as SlackChannelConfigModel
|
||||
from onyx.db.models import StandardAnswer as StandardAnswerModel
|
||||
from onyx.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
|
||||
from onyx.db.models import User
|
||||
from onyx.onyxbot.slack.config import VALID_SLACK_FILTERS
|
||||
from onyx.server.features.persona.models import FullPersonaSnapshot
|
||||
@@ -234,7 +237,7 @@ class SlackChannelConfig(BaseModel):
|
||||
persona: PersonaSnapshot | None
|
||||
channel_config: ChannelConfig
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories: list[StandardAnswerCategory]
|
||||
standard_answer_categories: list["StandardAnswerCategory"]
|
||||
enable_auto_filters: bool
|
||||
is_default: bool
|
||||
|
||||
@@ -307,3 +310,99 @@ class AllUsersResponse(BaseModel):
|
||||
class SlackChannel(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
"""
|
||||
Standard Answer Models
|
||||
|
||||
ee only, but needs to be here since it's imported by non-ee models.
|
||||
"""
|
||||
|
||||
|
||||
class StandardAnswerCategoryCreationRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class StandardAnswerCategory(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, standard_answer_category: StandardAnswerCategoryModel
|
||||
) -> "StandardAnswerCategory":
|
||||
return cls(
|
||||
id=standard_answer_category.id,
|
||||
name=standard_answer_category.name,
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(BaseModel):
|
||||
id: int
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[StandardAnswerCategory]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
|
||||
return cls(
|
||||
id=standard_answer_model.id,
|
||||
keyword=standard_answer_model.keyword,
|
||||
answer=standard_answer_model.answer,
|
||||
match_regex=standard_answer_model.match_regex,
|
||||
match_any_keywords=standard_answer_model.match_any_keywords,
|
||||
categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in standard_answer_model.categories
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswerCreationRequest(BaseModel):
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[int]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@field_validator("categories", mode="before")
|
||||
@classmethod
|
||||
def validate_categories(cls, value: list[int]) -> list[int]:
|
||||
if len(value) < 1:
|
||||
raise ValueError(
|
||||
"At least one category must be attached to a standard answer"
|
||||
)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_only_match_any_if_not_regex(self) -> Any:
|
||||
if self.match_regex and self.match_any_keywords:
|
||||
raise ValueError(
|
||||
"Can only match any keywords in keyword mode, not regex mode"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_keyword_if_regex(self) -> Any:
|
||||
if not self.match_regex:
|
||||
# no validation for keywords
|
||||
return self
|
||||
|
||||
try:
|
||||
re.compile(self.keyword)
|
||||
return self
|
||||
except re.error as err:
|
||||
if isinstance(err.pattern, bytes):
|
||||
raise ValueError(
|
||||
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
|
||||
)
|
||||
else:
|
||||
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
|
||||
raise ValueError(
|
||||
" ".join(
|
||||
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
|
||||
)
|
||||
)
|
||||
|
@@ -21,7 +21,6 @@ from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import SUPER_USERS
|
||||
from onyx.auth.email_utils import send_user_email_invite
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
@@ -72,6 +71,9 @@ from onyx.server.models import MinimalUserSnapshot
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -649,11 +651,19 @@ def verify_user_logged_in(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_invitation", None
|
||||
)(user.email)
|
||||
|
||||
super_users_list = cast(
|
||||
list[str],
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"onyx.configs.app_configs",
|
||||
"SUPER_USERS",
|
||||
[],
|
||||
),
|
||||
)
|
||||
user_info = UserInfo.from_model(
|
||||
user,
|
||||
current_token_created_at=token_created_at,
|
||||
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
|
||||
is_cloud_superuser=user.email in SUPER_USERS,
|
||||
is_cloud_superuser=user.email in super_users_list,
|
||||
team_name=team_name,
|
||||
tenant_info=TenantInfo(
|
||||
new_tenant=new_tenant,
|
||||
|
Reference in New Issue
Block a user