Fix/remove ee from mit (#4682)

* Remove some ee imports

* more

* Remove all ee imports

* Fix

* Autodiscover

* fix

* Fix typing

* More celery task stuff

* Fix import
This commit is contained in:
Chris Weaver
2025-05-11 15:09:50 -07:00
committed by GitHub
parent 84566debab
commit 913f7cc7d4
43 changed files with 499 additions and 434 deletions

View File

@@ -12,12 +12,6 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.celery_utils import get_deletion_attempt_snapshot
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
try_creating_permissions_sync_task,
)
from onyx.background.celery.tasks.external_group_syncing.tasks import (
try_creating_external_group_sync_task,
)
from onyx.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
@@ -392,154 +386,6 @@ def prune_cc_pair(
)
@router.get("/admin/cc-pair/{cc_pair_id}/sync-permissions")
def get_cc_pair_latest_sync(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> datetime | None:
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="cc_pair not found for current user's permissions",
)
return cc_pair.last_time_perm_sync
@router.post("/admin/cc-pair/{cc_pair_id}/sync-permissions")
def sync_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
"""Triggers permissions sync on a particular cc_pair immediately"""
tenant_id = get_current_tenant_id()
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
r = get_redis_client()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if redis_connector.permissions.fenced:
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Permissions sync task already in progress.",
)
logger.info(
f"Permissions sync cc_pair={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_permissions_sync_task(
client_app, cc_pair_id, r, tenant_id
)
if not payload_id:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Permissions sync task creation failed.",
)
logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}")
return StatusResponse(
success=True,
message="Successfully created the permissions sync task.",
)
@router.get("/admin/cc-pair/{cc_pair_id}/sync-groups")
def get_cc_pair_latest_group_sync(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> datetime | None:
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="cc_pair not found for current user's permissions",
)
return cc_pair.last_time_external_group_sync
@router.post("/admin/cc-pair/{cc_pair_id}/sync-groups")
def sync_cc_pair_groups(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
"""Triggers group sync on a particular cc_pair immediately"""
tenant_id = get_current_tenant_id()
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
r = get_redis_client()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if redis_connector.external_group_sync.fenced:
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="External group sync task already in progress.",
)
logger.info(
f"External group sync cc_pair={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_external_group_sync_task(
client_app, cc_pair_id, r, tenant_id
)
if not payload_id:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="External group sync task creation failed.",
)
logger.info(f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}")
return StatusResponse(
success=True,
message="Successfully created the external group sync task.",
)
@router.get("/admin/cc-pair/{cc_pair_id}/get-docs-sync-status")
def get_docs_sync_status(
cc_pair_id: int,

View File

@@ -9,8 +9,6 @@ from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from ee.onyx.server.query_history.models import ChatSessionMinimal
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.configs.app_configs import MASK_CREDENTIAL_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
@@ -23,8 +21,6 @@ from onyx.db.models import Document as DbDocument
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import TaskStatus
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from onyx.server.utils import mask_credential_dict
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -193,14 +189,7 @@ class IndexAttemptSnapshot(BaseModel):
# These are the types currently supported by the pagination hook
# More api endpoints can be refactored and be added here for use with the pagination hook
PaginatedType = TypeVar(
"PaginatedType",
IndexAttemptSnapshot,
FullUserSnapshot,
InvitedUserSnapshot,
ChatSessionMinimal,
IndexAttemptErrorPydantic,
)
PaginatedType = TypeVar("PaginatedType", bound=BaseModel)
class PaginatedReturn(BaseModel, Generic[PaginatedType]):

View File

@@ -1,5 +1,7 @@
import re
from datetime import datetime
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING
from pydantic import BaseModel
@@ -8,7 +10,6 @@ from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from ee.onyx.server.manage.models import StandardAnswerCategory
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.configs.constants import AuthType
@@ -17,6 +18,8 @@ from onyx.db.models import AllowedAnswerFilters
from onyx.db.models import ChannelConfig
from onyx.db.models import SlackBot as SlackAppModel
from onyx.db.models import SlackChannelConfig as SlackChannelConfigModel
from onyx.db.models import StandardAnswer as StandardAnswerModel
from onyx.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
from onyx.db.models import User
from onyx.onyxbot.slack.config import VALID_SLACK_FILTERS
from onyx.server.features.persona.models import FullPersonaSnapshot
@@ -234,7 +237,7 @@ class SlackChannelConfig(BaseModel):
persona: PersonaSnapshot | None
channel_config: ChannelConfig
# XXX this is going away soon
standard_answer_categories: list[StandardAnswerCategory]
standard_answer_categories: list["StandardAnswerCategory"]
enable_auto_filters: bool
is_default: bool
@@ -307,3 +310,99 @@ class AllUsersResponse(BaseModel):
class SlackChannel(BaseModel):
id: str
name: str
"""
Standard Answer Models
ee only, but needs to be here since it's imported by non-ee models.
"""
class StandardAnswerCategoryCreationRequest(BaseModel):
name: str
class StandardAnswerCategory(BaseModel):
id: int
name: str
@classmethod
def from_model(
cls, standard_answer_category: StandardAnswerCategoryModel
) -> "StandardAnswerCategory":
return cls(
id=standard_answer_category.id,
name=standard_answer_category.name,
)
class StandardAnswer(BaseModel):
id: int
keyword: str
answer: str
categories: list[StandardAnswerCategory]
match_regex: bool
match_any_keywords: bool
@classmethod
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
return cls(
id=standard_answer_model.id,
keyword=standard_answer_model.keyword,
answer=standard_answer_model.answer,
match_regex=standard_answer_model.match_regex,
match_any_keywords=standard_answer_model.match_any_keywords,
categories=[
StandardAnswerCategory.from_model(standard_answer_category_model)
for standard_answer_category_model in standard_answer_model.categories
],
)
class StandardAnswerCreationRequest(BaseModel):
keyword: str
answer: str
categories: list[int]
match_regex: bool
match_any_keywords: bool
@field_validator("categories", mode="before")
@classmethod
def validate_categories(cls, value: list[int]) -> list[int]:
if len(value) < 1:
raise ValueError(
"At least one category must be attached to a standard answer"
)
return value
@model_validator(mode="after")
def validate_only_match_any_if_not_regex(self) -> Any:
if self.match_regex and self.match_any_keywords:
raise ValueError(
"Can only match any keywords in keyword mode, not regex mode"
)
return self
@model_validator(mode="after")
def validate_keyword_if_regex(self) -> Any:
if not self.match_regex:
# no validation for keywords
return self
try:
re.compile(self.keyword)
return self
except re.error as err:
if isinstance(err.pattern, bytes):
raise ValueError(
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
)
else:
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
raise ValueError(
" ".join(
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
)
)

View File

@@ -21,7 +21,6 @@ from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import SUPER_USERS
from onyx.auth.email_utils import send_user_email_invite
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import write_invited_users
@@ -72,6 +71,9 @@ from onyx.server.models import MinimalUserSnapshot
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -649,11 +651,19 @@ def verify_user_logged_in(
"onyx.server.tenants.user_mapping", "get_tenant_invitation", None
)(user.email)
super_users_list = cast(
list[str],
fetch_versioned_implementation_with_fallback(
"onyx.configs.app_configs",
"SUPER_USERS",
[],
),
)
user_info = UserInfo.from_model(
user,
current_token_created_at=token_created_at,
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
is_cloud_superuser=user.email in SUPER_USERS,
is_cloud_superuser=user.email in super_users_list,
team_name=team_name,
tenant_info=TenantInfo(
new_tenant=new_tenant,