Cleaned up foreign key cleanup for user group deletion (#2559)

* cleaned up fk cleanup for user group deletion

* added test for user group deletion
This commit is contained in:
hagen-danswer
2024-09-25 20:38:01 -07:00
committed by GitHub
parent c5a61f4820
commit b73d66c84a
13 changed files with 346 additions and 139 deletions

View File

@@ -18,6 +18,7 @@ from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import DocumentSet__UserGroup
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import Persona__UserGroup
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
@@ -33,6 +34,93 @@ from ee.danswer.server.user_group.models import UserGroupUpdate
logger = setup_logger()
def _cleanup_user__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
user_ids: list[UUID] | None = None,
) -> None:
"""NOTE: does not commit the transaction."""
where_clause = User__UserGroup.user_group_id == user_group_id
if user_ids:
where_clause &= User__UserGroup.user_id.in_(user_ids)
user__user_group_relationships = db_session.scalars(
select(User__UserGroup).where(where_clause)
).all()
for user__user_group_relationship in user__user_group_relationships:
db_session.delete(user__user_group_relationship)
def _cleanup_credential__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Credential__UserGroup).filter(
Credential__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_persona__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
token_rate_limit__user_group_relationships = db_session.scalars(
select(TokenRateLimit__UserGroup).where(
TokenRateLimit__UserGroup.user_group_id == user_group_id
)
).all()
for (
token_rate_limit__user_group_relationship
) in token_rate_limit__user_group_relationships:
db_session.delete(token_rate_limit__user_group_relationship)
def _cleanup_user_group__cc_pair_relationships__no_commit(
db_session: Session, user_group_id: int, outdated_only: bool
) -> None:
"""NOTE: does not commit the transaction."""
stmt = select(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
)
if outdated_only:
stmt = stmt.where(
UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712
)
user_group__cc_pair_relationships = db_session.scalars(stmt)
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
db_session.delete(user_group__cc_pair_relationship)
def _cleanup_document_set__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.execute(
delete(DocumentSet__UserGroup).where(
DocumentSet__UserGroup.user_group_id == user_group_id
)
)
def validate_user_creation_permissions(
db_session: Session,
user: User | None,
@@ -286,42 +374,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
return db_user_group
def _cleanup_user__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
user_ids: list[UUID] | None = None,
) -> None:
"""NOTE: does not commit the transaction."""
where_clause = User__UserGroup.user_group_id == user_group_id
if user_ids:
where_clause &= User__UserGroup.user_id.in_(user_ids)
user__user_group_relationships = db_session.scalars(
select(User__UserGroup).where(where_clause)
).all()
for user__user_group_relationship in user__user_group_relationships:
db_session.delete(user__user_group_relationship)
def _cleanup_credential__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Credential__UserGroup).filter(
Credential__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session: Session, user_group_id: int
) -> None:
@@ -476,21 +528,6 @@ def update_user_group(
return db_user_group
def _cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
token_rate_limit__user_group_relationships = db_session.scalars(
select(TokenRateLimit__UserGroup).where(
TokenRateLimit__UserGroup.user_group_id == user_group_id
)
).all()
for (
token_rate_limit__user_group_relationship
) in token_rate_limit__user_group_relationships:
db_session.delete(token_rate_limit__user_group_relationship)
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
@@ -499,16 +536,31 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
_check_user_group_is_modifiable(db_user_group)
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_credential__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_mark_user_group__cc_pair_relationships_outdated__no_commit(
_cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_token_rate_limit__user_group_relationships__no_commit(
_cleanup_document_set__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_persona__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_user_group__cc_pair_relationships__no_commit(
db_session=db_session,
user_group_id=user_group_id,
outdated_only=False,
)
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
@@ -517,31 +569,12 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
db_session.commit()
def _cleanup_user_group__cc_pair_relationships__no_commit(
db_session: Session, user_group_id: int, outdated_only: bool
) -> None:
"""NOTE: does not commit the transaction."""
stmt = select(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
)
if outdated_only:
stmt = stmt.where(
UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712
)
user_group__cc_pair_relationships = db_session.scalars(stmt)
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
db_session.delete(user_group__cc_pair_relationship)
def _cleanup_document_set__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.execute(
delete(DocumentSet__UserGroup).where(
DocumentSet__UserGroup.user_group_id == user_group_id
)
)
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
"""
This assumes that all the fk cleanup has already been done.
"""
db_session.delete(user_group)
db_session.commit()
def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None:
@@ -553,29 +586,6 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
db_session.commit()
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user_group__cc_pair_relationships__no_commit(
db_session=db_session,
user_group_id=user_group.id,
outdated_only=False,
)
_cleanup_document_set__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
# need to flush so that we don't get a foreign key error when deleting the user group row
db_session.flush()
db_session.delete(user_group)
db_session.commit()
def delete_user_group_cc_pair_relationship__no_commit(
cc_pair_id: int, db_session: Session
) -> None:

View File

@@ -3,6 +3,7 @@ from uuid import uuid4
import requests
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
@@ -49,6 +50,9 @@ class LLMProviderManager:
)
llm_response.raise_for_status()
response_data = llm_response.json()
import json
print(json.dumps(response_data, indent=4))
result_llm = DATestLLMProvider(
id=response_data["id"],
name=response_data["name"],
@@ -76,8 +80,6 @@ class LLMProviderManager:
llm_provider: DATestLLMProvider,
user_performing_action: DATestUser | None = None,
) -> bool:
if not llm_provider.id:
raise ValueError("LLM Provider ID is required to delete a provider")
response = requests.delete(
f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}",
headers=user_performing_action.headers
@@ -86,3 +88,43 @@ class LLMProviderManager:
)
response.raise_for_status()
return True
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
) -> list[FullLLMProvider]:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return [FullLLMProvider(**ug) for ug in response.json()]
@staticmethod
def verify(
llm_provider: DATestLLMProvider,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
for fetched_llm_provider in all_llm_providers:
if llm_provider.id == fetched_llm_provider.id:
if verify_deleted:
raise ValueError(
f"User group {llm_provider.id} found but should be deleted"
)
fetched_llm_groups = set(fetched_llm_provider.groups)
llm_provider_groups = set(llm_provider.groups)
if (
fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider
and llm_provider.api_key == fetched_llm_provider.api_key
and llm_provider.default_model_name
== fetched_llm_provider.default_model_name
and llm_provider.is_public == fetched_llm_provider.is_public
):
return
if not verify_deleted:
raise ValueError(f"User group {llm_provider.id} not found")

View File

@@ -164,31 +164,39 @@ class PersonaManager:
@staticmethod
def verify(
test_persona: DATestPersona,
persona: DATestPersona,
user_performing_action: DATestUser | None = None,
) -> bool:
all_personas = PersonaManager.get_all(user_performing_action)
for persona in all_personas:
if persona.id == test_persona.id:
for fetched_persona in all_personas:
if fetched_persona.id == persona.id:
return (
persona.name == test_persona.name
and persona.description == test_persona.description
and persona.num_chunks == test_persona.num_chunks
and persona.llm_relevance_filter
== test_persona.llm_relevance_filter
and persona.is_public == test_persona.is_public
and persona.llm_filter_extraction
== test_persona.llm_filter_extraction
and persona.llm_model_provider_override
== test_persona.llm_model_provider_override
and persona.llm_model_version_override
== test_persona.llm_model_version_override
and set(persona.prompts) == set(test_persona.prompt_ids)
and set(persona.document_sets) == set(test_persona.document_set_ids)
and set(persona.tools) == set(test_persona.tool_ids)
and set(user.email for user in persona.users)
== set(test_persona.users)
and set(persona.groups) == set(test_persona.groups)
fetched_persona.name == persona.name
and fetched_persona.description == persona.description
and fetched_persona.num_chunks == persona.num_chunks
and fetched_persona.llm_relevance_filter
== persona.llm_relevance_filter
and fetched_persona.is_public == persona.is_public
and fetched_persona.llm_filter_extraction
== persona.llm_filter_extraction
and fetched_persona.llm_model_provider_override
== persona.llm_model_provider_override
and fetched_persona.llm_model_version_override
== persona.llm_model_version_override
and set([prompt.id for prompt in fetched_persona.prompts])
== set(persona.prompt_ids)
and set(
[
document_set.id
for document_set in fetched_persona.document_sets
]
)
== set(persona.document_set_ids)
and set([tool.id for tool in fetched_persona.tools])
== set(persona.tool_ids)
and set(user.email for user in fetched_persona.users)
== set(persona.users)
and set(fetched_persona.groups) == set(persona.groups)
)
return False

View File

@@ -47,8 +47,6 @@ class UserGroupManager:
user_group: DATestUserGroup,
user_performing_action: DATestUser | None = None,
) -> None:
if not user_group.id:
raise ValueError("User group has no ID")
response = requests.patch(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
json=user_group.model_dump(),
@@ -58,6 +56,19 @@ class UserGroupManager:
)
response.raise_for_status()
@staticmethod
def delete(
user_group: DATestUserGroup,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.delete(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
@staticmethod
def set_curator_status(
test_user_group: DATestUserGroup,
@@ -65,8 +76,6 @@ class UserGroupManager:
is_curator: bool = True,
user_performing_action: DATestUser | None = None,
) -> None:
if not user_to_set_as_curator.id:
raise ValueError("User has no ID")
set_curator_request = {
"user_id": user_to_set_as_curator.id,
"is_curator": is_curator,
@@ -130,7 +139,7 @@ class UserGroupManager:
check_ids = {user_group.id for user_group in user_groups_to_check}
user_group_ids = {user_group.id for user_group in user_groups}
if not check_ids.issubset(user_group_ids):
raise RuntimeError("Document set not found")
raise RuntimeError("User group not found")
user_groups = [
user_group
for user_group in user_groups
@@ -146,3 +155,26 @@ class UserGroupManager:
else:
print("User groups were not synced yet, waiting...")
time.sleep(2)
@staticmethod
def wait_for_deletion_completion(
user_groups_to_check: list[DATestUserGroup],
user_performing_action: DATestUser | None = None,
) -> None:
start = time.time()
user_group_ids_to_check = {user_group.id for user_group in user_groups_to_check}
while True:
fetched_user_groups = UserGroupManager.get_all(user_performing_action)
fetched_user_group_ids = {
user_group.id for user_group in fetched_user_groups
}
if not user_group_ids_to_check.intersection(fetched_user_group_ids):
return
if time.time() - start > MAX_DELAY:
raise TimeoutError(
f"User groups deletion was not completed within the {MAX_DELAY} seconds"
)
else:
print("Some user groups are still being deleted, waiting...")
time.sleep(2)

View File

@@ -87,7 +87,7 @@ class DATestLLMProvider(BaseModel):
api_key: str
default_model_name: str
is_public: bool
groups: list[DATestUserGroup]
groups: list[int]
api_base: str | None = None
api_version: str | None = None

View File

@@ -31,7 +31,7 @@ from tests.integration.common_utils.vespa import vespa_fixture
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# add api key to user
# create api key
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
@@ -181,7 +181,7 @@ def test_connector_deletion_for_overlapping_connectors(
"""
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# add api key to user
# create api key
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)

View File

@@ -2,10 +2,10 @@ import requests
from danswer.configs.constants import MessageType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.llm import LLMProviderManager
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestCCPair

View File

@@ -3,10 +3,10 @@ import requests
from danswer.configs.constants import MessageType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.llm import LLMProviderManager
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestCCPair

View File

@@ -16,7 +16,7 @@ def test_multiple_document_sets_syncing_same_connnector(
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# add api key to user
# create api key
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
@@ -70,7 +70,7 @@ def test_removing_connector(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# add api key to user
# create api key
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)

View File

@@ -1,5 +1,5 @@
from tests.integration.common_utils.llm import LLMProviderManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser

View File

@@ -1,5 +1,5 @@
from tests.integration.common_utils.llm import LLMProviderManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser

View File

@@ -0,0 +1,115 @@
"""
This tests the deletion of a user group with the following foreign key constraints:
- connector_credential_pair
- user
- credential
- llm_provider
- document_set
- token_rate_limit (Not Implemented)
- persona
"""
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.document_set import DocumentSetManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.persona import PersonaManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestDocumentSet
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestPersona
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import DATestUserGroup
from tests.integration.common_utils.vespa import vespa_fixture
def test_user_group_deletion(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create connectors
cc_pair = CCPairManager.create_from_scratch(
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
# Create user group with a cc_pair and a user
user_group: DATestUserGroup = UserGroupManager.create(
user_ids=[admin_user.id],
cc_pair_ids=[cc_pair.id],
user_performing_action=admin_user,
)
cc_pair.groups = [user_group.id]
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group], user_performing_action=admin_user
)
UserGroupManager.verify(
user_group=user_group,
user_performing_action=admin_user,
)
CCPairManager.verify(
cc_pair=cc_pair,
user_performing_action=admin_user,
)
credential: DATestCredential = CredentialManager.create(
groups=[user_group.id],
user_performing_action=admin_user,
)
document_set: DATestDocumentSet = DocumentSetManager.create(
cc_pair_ids=[cc_pair.id],
groups=[user_group.id],
user_performing_action=admin_user,
)
llm_provider: DATestLLMProvider = LLMProviderManager.create(
groups=[user_group.id],
user_performing_action=admin_user,
)
persona: DATestPersona = PersonaManager.create(
groups=[user_group.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group], user_performing_action=admin_user
)
UserGroupManager.verify(
user_group=user_group,
user_performing_action=admin_user,
)
UserGroupManager.delete(
user_group=user_group,
user_performing_action=admin_user,
)
UserGroupManager.wait_for_deletion_completion(
user_groups_to_check=[user_group], user_performing_action=admin_user
)
credential.groups = []
document_set.groups = []
llm_provider.groups = []
persona.groups = []
CredentialManager.verify(
credential=credential,
user_performing_action=admin_user,
)
DocumentSetManager.verify(
document_set=document_set,
user_performing_action=admin_user,
)
LLMProviderManager.verify(
llm_provider=llm_provider,
user_performing_action=admin_user,
)
PersonaManager.verify(
persona=persona,
user_performing_action=admin_user,
)

View File

@@ -15,7 +15,7 @@ def test_removing_connector(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# add api key to user
# create api key
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)