diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index 529112004b0c..00ad159c1467 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -18,6 +18,7 @@ from danswer.db.models import Document from danswer.db.models import DocumentByConnectorCredentialPair from danswer.db.models import DocumentSet__UserGroup from danswer.db.models import LLMProvider__UserGroup +from danswer.db.models import Persona__UserGroup from danswer.db.models import TokenRateLimit__UserGroup from danswer.db.models import User from danswer.db.models import User__UserGroup @@ -33,6 +34,93 @@ from ee.danswer.server.user_group.models import UserGroupUpdate logger = setup_logger() +def _cleanup_user__user_group_relationships__no_commit( + db_session: Session, + user_group_id: int, + user_ids: list[UUID] | None = None, +) -> None: + """NOTE: does not commit the transaction.""" + where_clause = User__UserGroup.user_group_id == user_group_id + if user_ids: + where_clause &= User__UserGroup.user_id.in_(user_ids) + + user__user_group_relationships = db_session.scalars( + select(User__UserGroup).where(where_clause) + ).all() + for user__user_group_relationship in user__user_group_relationships: + db_session.delete(user__user_group_relationship) + + +def _cleanup_credential__user_group_relationships__no_commit( + db_session: Session, + user_group_id: int, +) -> None: + """NOTE: does not commit the transaction.""" + db_session.query(Credential__UserGroup).filter( + Credential__UserGroup.user_group_id == user_group_id + ).delete(synchronize_session=False) + + +def _cleanup_llm_provider__user_group_relationships__no_commit( + db_session: Session, user_group_id: int +) -> None: + """NOTE: does not commit the transaction.""" + db_session.query(LLMProvider__UserGroup).filter( + LLMProvider__UserGroup.user_group_id == user_group_id + ).delete(synchronize_session=False) + + +def _cleanup_persona__user_group_relationships__no_commit( + db_session: Session, user_group_id: int +) -> None: + """NOTE: does not commit the transaction.""" + db_session.query(Persona__UserGroup).filter( + Persona__UserGroup.user_group_id == user_group_id + ).delete(synchronize_session=False) + + +def _cleanup_token_rate_limit__user_group_relationships__no_commit( + db_session: Session, user_group_id: int +) -> None: + """NOTE: does not commit the transaction.""" + token_rate_limit__user_group_relationships = db_session.scalars( + select(TokenRateLimit__UserGroup).where( + TokenRateLimit__UserGroup.user_group_id == user_group_id + ) + ).all() + for ( + token_rate_limit__user_group_relationship + ) in token_rate_limit__user_group_relationships: + db_session.delete(token_rate_limit__user_group_relationship) + + +def _cleanup_user_group__cc_pair_relationships__no_commit( + db_session: Session, user_group_id: int, outdated_only: bool +) -> None: + """NOTE: does not commit the transaction.""" + stmt = select(UserGroup__ConnectorCredentialPair).where( + UserGroup__ConnectorCredentialPair.user_group_id == user_group_id + ) + if outdated_only: + stmt = stmt.where( + UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712 + ) + user_group__cc_pair_relationships = db_session.scalars(stmt) + for user_group__cc_pair_relationship in user_group__cc_pair_relationships: + db_session.delete(user_group__cc_pair_relationship) + + +def _cleanup_document_set__user_group_relationships__no_commit( + db_session: Session, user_group_id: int +) -> None: + """NOTE: does not commit the transaction.""" + db_session.execute( + delete(DocumentSet__UserGroup).where( + DocumentSet__UserGroup.user_group_id == user_group_id + ) + ) + + def validate_user_creation_permissions( db_session: Session, user: User | None, @@ -286,42 +374,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG return db_user_group -def _cleanup_user__user_group_relationships__no_commit( - db_session: Session, - user_group_id: int, - user_ids: list[UUID] | None = None, -) -> None: - """NOTE: does not commit the transaction.""" - where_clause = User__UserGroup.user_group_id == user_group_id - if user_ids: - where_clause &= User__UserGroup.user_id.in_(user_ids) - - user__user_group_relationships = db_session.scalars( - select(User__UserGroup).where(where_clause) - ).all() - for user__user_group_relationship in user__user_group_relationships: - db_session.delete(user__user_group_relationship) - - -def _cleanup_credential__user_group_relationships__no_commit( - db_session: Session, - user_group_id: int, -) -> None: - """NOTE: does not commit the transaction.""" - db_session.query(Credential__UserGroup).filter( - Credential__UserGroup.user_group_id == user_group_id - ).delete(synchronize_session=False) - - -def _cleanup_llm_provider__user_group_relationships__no_commit( - db_session: Session, user_group_id: int -) -> None: - """NOTE: does not commit the transaction.""" - db_session.query(LLMProvider__UserGroup).filter( - LLMProvider__UserGroup.user_group_id == user_group_id - ).delete(synchronize_session=False) - - def _mark_user_group__cc_pair_relationships_outdated__no_commit( db_session: Session, user_group_id: int ) -> None: @@ -476,21 +528,6 @@ def update_user_group( return db_user_group -def _cleanup_token_rate_limit__user_group_relationships__no_commit( - db_session: Session, user_group_id: int -) -> None: - """NOTE: does not commit the transaction.""" - token_rate_limit__user_group_relationships = db_session.scalars( - select(TokenRateLimit__UserGroup).where( - TokenRateLimit__UserGroup.user_group_id == user_group_id - ) - ).all() - for ( - token_rate_limit__user_group_relationship - ) in token_rate_limit__user_group_relationships: - db_session.delete(token_rate_limit__user_group_relationship) - - def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None: stmt = select(UserGroup).where(UserGroup.id == user_group_id) db_user_group = db_session.scalar(stmt) @@ -499,16 +536,31 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> _check_user_group_is_modifiable(db_user_group) + _mark_user_group__cc_pair_relationships_outdated__no_commit( + db_session=db_session, user_group_id=user_group_id + ) + _cleanup_credential__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) _cleanup_user__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) - _mark_user_group__cc_pair_relationships_outdated__no_commit( + _cleanup_token_rate_limit__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) - _cleanup_token_rate_limit__user_group_relationships__no_commit( + _cleanup_document_set__user_group_relationships__no_commit( + db_session=db_session, user_group_id=user_group_id + ) + _cleanup_persona__user_group_relationships__no_commit( + db_session=db_session, user_group_id=user_group_id + ) + _cleanup_user_group__cc_pair_relationships__no_commit( + db_session=db_session, + user_group_id=user_group_id, + outdated_only=False, + ) + _cleanup_llm_provider__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) @@ -517,31 +569,12 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> db_session.commit() -def _cleanup_user_group__cc_pair_relationships__no_commit( - db_session: Session, user_group_id: int, outdated_only: bool -) -> None: - """NOTE: does not commit the transaction.""" - stmt = select(UserGroup__ConnectorCredentialPair).where( - UserGroup__ConnectorCredentialPair.user_group_id == user_group_id - ) - if outdated_only: - stmt = stmt.where( - UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712 - ) - user_group__cc_pair_relationships = db_session.scalars(stmt) - for user_group__cc_pair_relationship in user_group__cc_pair_relationships: - db_session.delete(user_group__cc_pair_relationship) - - -def _cleanup_document_set__user_group_relationships__no_commit( - db_session: Session, user_group_id: int -) -> None: - """NOTE: does not commit the transaction.""" - db_session.execute( - delete(DocumentSet__UserGroup).where( - DocumentSet__UserGroup.user_group_id == user_group_id - ) - ) +def delete_user_group(db_session: Session, user_group: UserGroup) -> None: + """ + This assumes that all the fk cleanup has already been done. + """ + db_session.delete(user_group) + db_session.commit() def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None: @@ -553,29 +586,6 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non db_session.commit() -def delete_user_group(db_session: Session, user_group: UserGroup) -> None: - _cleanup_llm_provider__user_group_relationships__no_commit( - db_session=db_session, user_group_id=user_group.id - ) - _cleanup_user__user_group_relationships__no_commit( - db_session=db_session, user_group_id=user_group.id - ) - _cleanup_user_group__cc_pair_relationships__no_commit( - db_session=db_session, - user_group_id=user_group.id, - outdated_only=False, - ) - _cleanup_document_set__user_group_relationships__no_commit( - db_session=db_session, user_group_id=user_group.id - ) - - # need to flush so that we don't get a foreign key error when deleting the user group row - db_session.flush() - - db_session.delete(user_group) - db_session.commit() - - def delete_user_group_cc_pair_relationship__no_commit( cc_pair_id: int, db_session: Session ) -> None: diff --git a/backend/tests/integration/common_utils/llm.py b/backend/tests/integration/common_utils/managers/llm_provider.py similarity index 62% rename from backend/tests/integration/common_utils/llm.py rename to backend/tests/integration/common_utils/managers/llm_provider.py index a3614c9bc242..cde75284ca8f 100644 --- a/backend/tests/integration/common_utils/llm.py +++ b/backend/tests/integration/common_utils/managers/llm_provider.py @@ -3,6 +3,7 @@ from uuid import uuid4 import requests +from danswer.server.manage.llm.models import FullLLMProvider from danswer.server.manage.llm.models import LLMProviderUpsertRequest from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS @@ -49,6 +50,9 @@ class LLMProviderManager: ) llm_response.raise_for_status() response_data = llm_response.json() + import json + + print(json.dumps(response_data, indent=4)) result_llm = DATestLLMProvider( id=response_data["id"], name=response_data["name"], @@ -76,8 +80,6 @@ class LLMProviderManager: llm_provider: DATestLLMProvider, user_performing_action: DATestUser | None = None, ) -> bool: - if not llm_provider.id: - raise ValueError("LLM Provider ID is required to delete a provider") response = requests.delete( f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}", headers=user_performing_action.headers @@ -86,3 +88,43 @@ class LLMProviderManager: ) response.raise_for_status() return True + + @staticmethod + def get_all( + user_performing_action: DATestUser | None = None, + ) -> list[FullLLMProvider]: + response = requests.get( + f"{API_SERVER_URL}/admin/llm/provider", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [FullLLMProvider(**ug) for ug in response.json()] + + @staticmethod + def verify( + llm_provider: DATestLLMProvider, + verify_deleted: bool = False, + user_performing_action: DATestUser | None = None, + ) -> None: + all_llm_providers = LLMProviderManager.get_all(user_performing_action) + for fetched_llm_provider in all_llm_providers: + if llm_provider.id == fetched_llm_provider.id: + if verify_deleted: + raise ValueError( + f"User group {llm_provider.id} found but should be deleted" + ) + fetched_llm_groups = set(fetched_llm_provider.groups) + llm_provider_groups = set(llm_provider.groups) + if ( + fetched_llm_groups == llm_provider_groups + and llm_provider.provider == fetched_llm_provider.provider + and llm_provider.api_key == fetched_llm_provider.api_key + and llm_provider.default_model_name + == fetched_llm_provider.default_model_name + and llm_provider.is_public == fetched_llm_provider.is_public + ): + return + if not verify_deleted: + raise ValueError(f"User group {llm_provider.id} not found") diff --git a/backend/tests/integration/common_utils/managers/persona.py b/backend/tests/integration/common_utils/managers/persona.py index 086dfc373e3a..4e8e58224fb8 100644 --- a/backend/tests/integration/common_utils/managers/persona.py +++ b/backend/tests/integration/common_utils/managers/persona.py @@ -164,31 +164,39 @@ class PersonaManager: @staticmethod def verify( - test_persona: DATestPersona, + persona: DATestPersona, user_performing_action: DATestUser | None = None, ) -> bool: all_personas = PersonaManager.get_all(user_performing_action) - for persona in all_personas: - if persona.id == test_persona.id: + for fetched_persona in all_personas: + if fetched_persona.id == persona.id: return ( - persona.name == test_persona.name - and persona.description == test_persona.description - and persona.num_chunks == test_persona.num_chunks - and persona.llm_relevance_filter - == test_persona.llm_relevance_filter - and persona.is_public == test_persona.is_public - and persona.llm_filter_extraction - == test_persona.llm_filter_extraction - and persona.llm_model_provider_override - == test_persona.llm_model_provider_override - and persona.llm_model_version_override - == test_persona.llm_model_version_override - and set(persona.prompts) == set(test_persona.prompt_ids) - and set(persona.document_sets) == set(test_persona.document_set_ids) - and set(persona.tools) == set(test_persona.tool_ids) - and set(user.email for user in persona.users) - == set(test_persona.users) - and set(persona.groups) == set(test_persona.groups) + fetched_persona.name == persona.name + and fetched_persona.description == persona.description + and fetched_persona.num_chunks == persona.num_chunks + and fetched_persona.llm_relevance_filter + == persona.llm_relevance_filter + and fetched_persona.is_public == persona.is_public + and fetched_persona.llm_filter_extraction + == persona.llm_filter_extraction + and fetched_persona.llm_model_provider_override + == persona.llm_model_provider_override + and fetched_persona.llm_model_version_override + == persona.llm_model_version_override + and set([prompt.id for prompt in fetched_persona.prompts]) + == set(persona.prompt_ids) + and set( + [ + document_set.id + for document_set in fetched_persona.document_sets + ] + ) + == set(persona.document_set_ids) + and set([tool.id for tool in fetched_persona.tools]) + == set(persona.tool_ids) + and set(user.email for user in fetched_persona.users) + == set(persona.users) + and set(fetched_persona.groups) == set(persona.groups) ) return False diff --git a/backend/tests/integration/common_utils/managers/user_group.py b/backend/tests/integration/common_utils/managers/user_group.py index 486dcb64c684..baf2008b9659 100644 --- a/backend/tests/integration/common_utils/managers/user_group.py +++ b/backend/tests/integration/common_utils/managers/user_group.py @@ -47,8 +47,6 @@ class UserGroupManager: user_group: DATestUserGroup, user_performing_action: DATestUser | None = None, ) -> None: - if not user_group.id: - raise ValueError("User group has no ID") response = requests.patch( f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}", json=user_group.model_dump(), @@ -58,6 +56,19 @@ class UserGroupManager: ) response.raise_for_status() + @staticmethod + def delete( + user_group: DATestUserGroup, + user_performing_action: DATestUser | None = None, + ) -> None: + response = requests.delete( + f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + @staticmethod def set_curator_status( test_user_group: DATestUserGroup, @@ -65,8 +76,6 @@ class UserGroupManager: is_curator: bool = True, user_performing_action: DATestUser | None = None, ) -> None: - if not user_to_set_as_curator.id: - raise ValueError("User has no ID") set_curator_request = { "user_id": user_to_set_as_curator.id, "is_curator": is_curator, @@ -130,7 +139,7 @@ class UserGroupManager: check_ids = {user_group.id for user_group in user_groups_to_check} user_group_ids = {user_group.id for user_group in user_groups} if not check_ids.issubset(user_group_ids): - raise RuntimeError("Document set not found") + raise RuntimeError("User group not found") user_groups = [ user_group for user_group in user_groups @@ -146,3 +155,26 @@ class UserGroupManager: else: print("User groups were not synced yet, waiting...") time.sleep(2) + + @staticmethod + def wait_for_deletion_completion( + user_groups_to_check: list[DATestUserGroup], + user_performing_action: DATestUser | None = None, + ) -> None: + start = time.time() + user_group_ids_to_check = {user_group.id for user_group in user_groups_to_check} + while True: + fetched_user_groups = UserGroupManager.get_all(user_performing_action) + fetched_user_group_ids = { + user_group.id for user_group in fetched_user_groups + } + if not user_group_ids_to_check.intersection(fetched_user_group_ids): + return + + if time.time() - start > MAX_DELAY: + raise TimeoutError( + f"User groups deletion was not completed within the {MAX_DELAY} seconds" + ) + else: + print("Some user groups are still being deleted, waiting...") + time.sleep(2) diff --git a/backend/tests/integration/common_utils/test_models.py b/backend/tests/integration/common_utils/test_models.py index 6e1afbba5f27..ca573663e722 100644 --- a/backend/tests/integration/common_utils/test_models.py +++ b/backend/tests/integration/common_utils/test_models.py @@ -87,7 +87,7 @@ class DATestLLMProvider(BaseModel): api_key: str default_model_name: str is_public: bool - groups: list[DATestUserGroup] + groups: list[int] api_base: str | None = None api_version: str | None = None diff --git a/backend/tests/integration/tests/connector/test_connector_deletion.py b/backend/tests/integration/tests/connector/test_connector_deletion.py index 5d160bb431f8..46a65f768a98 100644 --- a/backend/tests/integration/tests/connector/test_connector_deletion.py +++ b/backend/tests/integration/tests/connector/test_connector_deletion.py @@ -31,7 +31,7 @@ from tests.integration.common_utils.vespa import vespa_fixture def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None: # Creating an admin user (first user created is automatically an admin) admin_user: DATestUser = UserManager.create(name="admin_user") - # add api key to user + # create api key api_key: DATestAPIKey = APIKeyManager.create( user_performing_action=admin_user, ) @@ -181,7 +181,7 @@ def test_connector_deletion_for_overlapping_connectors( """ # Creating an admin user (first user created is automatically an admin) admin_user: DATestUser = UserManager.create(name="admin_user") - # add api key to user + # create api key api_key: DATestAPIKey = APIKeyManager.create( user_performing_action=admin_user, ) diff --git a/backend/tests/integration/tests/dev_apis/test_knowledge_chat.py b/backend/tests/integration/tests/dev_apis/test_knowledge_chat.py index 866e51efd1c9..2cf6fd399eaf 100644 --- a/backend/tests/integration/tests/dev_apis/test_knowledge_chat.py +++ b/backend/tests/integration/tests/dev_apis/test_knowledge_chat.py @@ -2,10 +2,10 @@ import requests from danswer.configs.constants import MessageType from tests.integration.common_utils.constants import API_SERVER_URL -from tests.integration.common_utils.llm import LLMProviderManager from tests.integration.common_utils.managers.api_key import APIKeyManager from tests.integration.common_utils.managers.cc_pair import CCPairManager from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.llm_provider import LLMProviderManager from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestAPIKey from tests.integration.common_utils.test_models import DATestCCPair diff --git a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py index 6456b7dc9365..0a4e7b40b570 100644 --- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py +++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py @@ -3,10 +3,10 @@ import requests from danswer.configs.constants import MessageType from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import NUM_DOCS -from tests.integration.common_utils.llm import LLMProviderManager from tests.integration.common_utils.managers.api_key import APIKeyManager from tests.integration.common_utils.managers.cc_pair import CCPairManager from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.llm_provider import LLMProviderManager from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestAPIKey from tests.integration.common_utils.test_models import DATestCCPair diff --git a/backend/tests/integration/tests/document_set/test_syncing.py b/backend/tests/integration/tests/document_set/test_syncing.py index 37fa3397627b..ed00870663a5 100644 --- a/backend/tests/integration/tests/document_set/test_syncing.py +++ b/backend/tests/integration/tests/document_set/test_syncing.py @@ -16,7 +16,7 @@ def test_multiple_document_sets_syncing_same_connnector( # Creating an admin user (first user created is automatically an admin) admin_user: DATestUser = UserManager.create(name="admin_user") - # add api key to user + # create api key api_key: DATestAPIKey = APIKeyManager.create( user_performing_action=admin_user, ) @@ -70,7 +70,7 @@ def test_removing_connector(reset: None, vespa_client: vespa_fixture) -> None: # Creating an admin user (first user created is automatically an admin) admin_user: DATestUser = UserManager.create(name="admin_user") - # add api key to user + # create api key api_key: DATestAPIKey = APIKeyManager.create( user_performing_action=admin_user, ) diff --git a/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py b/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py index 6abaaae8b75b..3eb982ef228a 100644 --- a/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py +++ b/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py @@ -1,5 +1,5 @@ -from tests.integration.common_utils.llm import LLMProviderManager from tests.integration.common_utils.managers.chat import ChatSessionManager +from tests.integration.common_utils.managers.llm_provider import LLMProviderManager from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestUser diff --git a/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py b/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py index a52584cb2e8b..dc4361301dc7 100644 --- a/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py +++ b/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py @@ -1,5 +1,5 @@ -from tests.integration.common_utils.llm import LLMProviderManager from tests.integration.common_utils.managers.chat import ChatSessionManager +from tests.integration.common_utils.managers.llm_provider import LLMProviderManager from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestUser diff --git a/backend/tests/integration/tests/usergroup/test_user_group_deletion.py b/backend/tests/integration/tests/usergroup/test_user_group_deletion.py new file mode 100644 index 000000000000..eed0d9d8c170 --- /dev/null +++ b/backend/tests/integration/tests/usergroup/test_user_group_deletion.py @@ -0,0 +1,115 @@ +""" +This tests the deletion of a user group with the following foreign key constraints: +- connector_credential_pair +- user +- credential +- llm_provider +- document_set +- token_rate_limit (Not Implemented) +- persona +""" +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.managers.document_set import DocumentSetManager +from tests.integration.common_utils.managers.llm_provider import LLMProviderManager +from tests.integration.common_utils.managers.persona import PersonaManager +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager +from tests.integration.common_utils.test_models import DATestCredential +from tests.integration.common_utils.test_models import DATestDocumentSet +from tests.integration.common_utils.test_models import DATestLLMProvider +from tests.integration.common_utils.test_models import DATestPersona +from tests.integration.common_utils.test_models import DATestUser +from tests.integration.common_utils.test_models import DATestUserGroup +from tests.integration.common_utils.vespa import vespa_fixture + + +def test_user_group_deletion(reset: None, vespa_client: vespa_fixture) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: DATestUser = UserManager.create(name="admin_user") + + # create connectors + cc_pair = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + # Create user group with a cc_pair and a user + user_group: DATestUserGroup = UserGroupManager.create( + user_ids=[admin_user.id], + cc_pair_ids=[cc_pair.id], + user_performing_action=admin_user, + ) + cc_pair.groups = [user_group.id] + + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group], user_performing_action=admin_user + ) + UserGroupManager.verify( + user_group=user_group, + user_performing_action=admin_user, + ) + CCPairManager.verify( + cc_pair=cc_pair, + user_performing_action=admin_user, + ) + + credential: DATestCredential = CredentialManager.create( + groups=[user_group.id], + user_performing_action=admin_user, + ) + document_set: DATestDocumentSet = DocumentSetManager.create( + cc_pair_ids=[cc_pair.id], + groups=[user_group.id], + user_performing_action=admin_user, + ) + llm_provider: DATestLLMProvider = LLMProviderManager.create( + groups=[user_group.id], + user_performing_action=admin_user, + ) + persona: DATestPersona = PersonaManager.create( + groups=[user_group.id], + user_performing_action=admin_user, + ) + + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group], user_performing_action=admin_user + ) + UserGroupManager.verify( + user_group=user_group, + user_performing_action=admin_user, + ) + + UserGroupManager.delete( + user_group=user_group, + user_performing_action=admin_user, + ) + + UserGroupManager.wait_for_deletion_completion( + user_groups_to_check=[user_group], user_performing_action=admin_user + ) + + credential.groups = [] + document_set.groups = [] + llm_provider.groups = [] + persona.groups = [] + CredentialManager.verify( + credential=credential, + user_performing_action=admin_user, + ) + + DocumentSetManager.verify( + document_set=document_set, + user_performing_action=admin_user, + ) + + LLMProviderManager.verify( + llm_provider=llm_provider, + user_performing_action=admin_user, + ) + + PersonaManager.verify( + persona=persona, + user_performing_action=admin_user, + ) diff --git a/backend/tests/integration/tests/usergroup/test_usergroup_syncing.py b/backend/tests/integration/tests/usergroup/test_usergroup_syncing.py index 99620c6279c0..5d1ee3b10214 100644 --- a/backend/tests/integration/tests/usergroup/test_usergroup_syncing.py +++ b/backend/tests/integration/tests/usergroup/test_usergroup_syncing.py @@ -15,7 +15,7 @@ def test_removing_connector(reset: None, vespa_client: vespa_fixture) -> None: # Creating an admin user (first user created is automatically an admin) admin_user: DATestUser = UserManager.create(name="admin_user") - # add api key to user + # create api key api_key: DATestAPIKey = APIKeyManager.create( user_performing_action=admin_user, )