From 6afd27f9c9d2be6d51d3fba6ad209615b8b46b0e Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Fri, 10 Jan 2025 16:51:33 -0800 Subject: [PATCH] fix group sync name capitalization (#3653) * fix group sync name capitalization * everything is lowercased now * comments * Added test for be2ab2aa50ee migration * polish --- .../be2ab2aa50ee_fix_capitalization.py | 38 ++++++ backend/ee/onyx/db/document.py | 6 +- backend/ee/onyx/db/external_perm.py | 17 ++- backend/onyx/access/utils.py | 9 +- .../tasks/doc_permission_syncing/tasks.py | 4 +- .../connectors/confluence/onyx_confluence.py | 55 ++++---- .../tests/integration/common_utils/reset.py | 117 +++++++++------- .../tests/migrations/test_migrations.py | 125 ++++++++++++++++++ 8 files changed, 285 insertions(+), 86 deletions(-) create mode 100644 backend/alembic/versions/be2ab2aa50ee_fix_capitalization.py create mode 100644 backend/tests/integration/tests/migrations/test_migrations.py diff --git a/backend/alembic/versions/be2ab2aa50ee_fix_capitalization.py b/backend/alembic/versions/be2ab2aa50ee_fix_capitalization.py new file mode 100644 index 000000000..ea6f201cc --- /dev/null +++ b/backend/alembic/versions/be2ab2aa50ee_fix_capitalization.py @@ -0,0 +1,38 @@ +"""fix_capitalization + +Revision ID: be2ab2aa50ee +Revises: 369644546676 +Create Date: 2025-01-10 13:13:26.228960 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "be2ab2aa50ee" +down_revision = "369644546676" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.execute( + """ + UPDATE document + SET + external_user_group_ids = ARRAY( + SELECT LOWER(unnest(external_user_group_ids)) + ), + last_modified = NOW() + WHERE + external_user_group_ids IS NOT NULL + AND external_user_group_ids::text[] <> ARRAY( + SELECT LOWER(unnest(external_user_group_ids)) + )::text[] + """ + ) + + +def downgrade() -> None: + # No way to cleanly persist the bad state through an upgrade/downgrade + # cycle, so we just pass + pass diff --git a/backend/ee/onyx/db/document.py b/backend/ee/onyx/db/document.py index 2ec5a3623..ad61cff4f 100644 --- a/backend/ee/onyx/db/document.py +++ b/backend/ee/onyx/db/document.py @@ -5,7 +5,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from onyx.access.models import ExternalAccess -from onyx.access.utils import prefix_group_w_source +from onyx.access.utils import build_ext_group_name_for_onyx from onyx.configs.constants import DocumentSource from onyx.db.models import Document as DbDocument @@ -25,7 +25,7 @@ def upsert_document_external_perms__no_commit( ).first() prefixed_external_groups = [ - prefix_group_w_source( + build_ext_group_name_for_onyx( ext_group_name=group_id, source=source_type, ) @@ -66,7 +66,7 @@ def upsert_document_external_perms( ).first() prefixed_external_groups: set[str] = { - prefix_group_w_source( + build_ext_group_name_for_onyx( ext_group_name=group_id, source=source_type, ) diff --git a/backend/ee/onyx/db/external_perm.py b/backend/ee/onyx/db/external_perm.py index 16de8bb41..9992f86df 100644 --- a/backend/ee/onyx/db/external_perm.py +++ b/backend/ee/onyx/db/external_perm.py @@ -6,8 +6,9 @@ from sqlalchemy import delete from sqlalchemy import select from sqlalchemy.orm import Session -from onyx.access.utils import prefix_group_w_source +from onyx.access.utils import build_ext_group_name_for_onyx from onyx.configs.constants import DocumentSource +from onyx.db.models import User from onyx.db.models import User__ExternalUserGroupId from onyx.db.users import batch_add_ext_perm_user_if_not_exists from onyx.db.users import get_user_by_email @@ -61,8 +62,10 @@ def replace_user__ext_group_for_cc_pair( all_group_member_emails.add(user_email) # batch add users if they don't exist and get their ids - all_group_members = batch_add_ext_perm_user_if_not_exists( - db_session=db_session, emails=list(all_group_member_emails) + all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists( + db_session=db_session, + # NOTE: this function handles case sensitivity for emails + emails=list(all_group_member_emails), ) delete_user__ext_group_for_cc_pair__no_commit( @@ -84,12 +87,14 @@ def replace_user__ext_group_for_cc_pair( f" with email {user_email} not found" ) continue + external_group_id = build_ext_group_name_for_onyx( + ext_group_name=external_group.id, + source=source, + ) new_external_permissions.append( User__ExternalUserGroupId( user_id=user_id, - external_user_group_id=prefix_group_w_source( - external_group.id, source - ), + external_user_group_id=external_group_id, cc_pair_id=cc_pair_id, ) ) diff --git a/backend/onyx/access/utils.py b/backend/onyx/access/utils.py index 3ff9c42bc..52d0e3274 100644 --- a/backend/onyx/access/utils.py +++ b/backend/onyx/access/utils.py @@ -19,6 +19,9 @@ def prefix_external_group(ext_group_name: str) -> str: return f"external_group:{ext_group_name}" -def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str: - """External groups may collide across sources, every source needs its own prefix.""" - return f"{source.value.upper()}_{ext_group_name}" +def build_ext_group_name_for_onyx(ext_group_name: str, source: DocumentSource) -> str: + """ + External groups may collide across sources, every source needs its own prefix. + NOTE: the name is lowercased to handle case sensitivity for group names + """ + return f"{source.value}_{ext_group_name}".lower() diff --git a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py index 20ad0a075..5e1e3c2c0 100644 --- a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py @@ -391,5 +391,7 @@ def update_external_document_permissions_task( ) return True except Exception: - logger.exception("Error Syncing Document Permissions") + logger.exception( + f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}" + ) return False diff --git a/backend/onyx/connectors/confluence/onyx_confluence.py b/backend/onyx/connectors/confluence/onyx_confluence.py index e6a2b957e..96b9a3702 100644 --- a/backend/onyx/connectors/confluence/onyx_confluence.py +++ b/backend/onyx/connectors/confluence/onyx_confluence.py @@ -135,32 +135,6 @@ class OnyxConfluence(Confluence): super(OnyxConfluence, self).__init__(url, *args, **kwargs) self._wrap_methods() - def get_current_user(self, expand: str | None = None) -> Any: - """ - Implements a method that isn't in the third party client. - - Get information about the current user - :param expand: OPTIONAL expand for get status of user. - Possible param is "status". Results are "Active, Deactivated" - :return: Returns the user details - """ - - from atlassian.errors import ApiPermissionError # type:ignore - - url = "rest/api/user/current" - params = {} - if expand: - params["expand"] = expand - try: - response = self.get(url, params=params) - except HTTPError as e: - if e.response.status_code == 403: - raise ApiPermissionError( - "The calling user does not have permission", reason=e - ) - raise - return response - def _wrap_methods(self) -> None: """ For each attribute that is callable (i.e., a method) and doesn't start with an underscore, @@ -363,6 +337,9 @@ class OnyxConfluence(Confluence): fetch the permissions of a space. This is better logging than calling the get_space_permissions method because it returns a jsonrpc response. + TODO: Make this call these endpoints for newer confluence versions: + - /rest/api/space/{spaceKey}/permissions + - /rest/api/space/{spaceKey}/permissions/anonymous """ url = "rpc/json-rpc/confluenceservice-v2" data = { @@ -381,6 +358,32 @@ class OnyxConfluence(Confluence): return response.get("result", []) + def get_current_user(self, expand: str | None = None) -> Any: + """ + Implements a method that isn't in the third party client. + + Get information about the current user + :param expand: OPTIONAL expand for get status of user. + Possible param is "status". Results are "Active, Deactivated" + :return: Returns the user details + """ + + from atlassian.errors import ApiPermissionError # type:ignore + + url = "rest/api/user/current" + params = {} + if expand: + params["expand"] = expand + try: + response = self.get(url, params=params) + except HTTPError as e: + if e.response.status_code == 403: + raise ApiPermissionError( + "The calling user does not have permission", reason=e + ) + raise + return response + def _validate_connector_configuration( credentials: dict[str, Any], diff --git a/backend/tests/integration/common_utils/reset.py b/backend/tests/integration/common_utils/reset.py index 116d91c42..4c6afb103 100644 --- a/backend/tests/integration/common_utils/reset.py +++ b/backend/tests/integration/common_utils/reset.py @@ -63,57 +63,57 @@ def _run_migrations( logging.getLogger("alembic").setLevel(logging.INFO) -def reset_postgres( - database: str = "postgres", config_name: str = "alembic", setup_onyx: bool = True +def downgrade_postgres( + database: str = "postgres", + config_name: str = "alembic", + revision: str = "base", + clear_data: bool = False, ) -> None: - """Reset the Postgres database.""" + """Downgrade Postgres database to base state.""" + if clear_data: + if revision != "base": + logger.warning("Clearing data without rolling back to base state") + # Delete all rows to allow migrations to be rolled back + conn = psycopg2.connect( + dbname=database, + user=POSTGRES_USER, + password=POSTGRES_PASSWORD, + host=POSTGRES_HOST, + port=POSTGRES_PORT, + ) + cur = conn.cursor() - # NOTE: need to delete all rows to allow migrations to be rolled back - # as there are a few downgrades that don't properly handle data in tables - conn = psycopg2.connect( - dbname=database, - user=POSTGRES_USER, - password=POSTGRES_PASSWORD, - host=POSTGRES_HOST, - port=POSTGRES_PORT, - ) - cur = conn.cursor() + # Disable triggers to prevent foreign key constraints from being checked + cur.execute("SET session_replication_role = 'replica';") - # Disable triggers to prevent foreign key constraints from being checked - cur.execute("SET session_replication_role = 'replica';") - - # Fetch all table names in the current database - cur.execute( + # Fetch all table names in the current database + cur.execute( + """ + SELECT tablename + FROM pg_tables + WHERE schemaname = 'public' """ - SELECT tablename - FROM pg_tables - WHERE schemaname = 'public' - """ - ) + ) - tables = cur.fetchall() + tables = cur.fetchall() - for table in tables: - table_name = table[0] + for table in tables: + table_name = table[0] - # Don't touch migration history - if table_name == "alembic_version": - continue + # Don't touch migration history or Kombu + if table_name in ("alembic_version", "kombu_message", "kombu_queue"): + continue - # Don't touch Kombu - if table_name == "kombu_message" or table_name == "kombu_queue": - continue + cur.execute(f'DELETE FROM "{table_name}"') - cur.execute(f'DELETE FROM "{table_name}"') + # Re-enable triggers + cur.execute("SET session_replication_role = 'origin';") - # Re-enable triggers - cur.execute("SET session_replication_role = 'origin';") + conn.commit() + cur.close() + conn.close() - conn.commit() - cur.close() - conn.close() - - # downgrade to base + upgrade back to head + # Downgrade to base conn_str = build_connection_string( db=database, user=POSTGRES_USER, @@ -126,20 +126,43 @@ def reset_postgres( conn_str, config_name, direction="downgrade", - revision="base", + revision=revision, + ) + + +def upgrade_postgres( + database: str = "postgres", config_name: str = "alembic", revision: str = "head" +) -> None: + """Upgrade Postgres database to latest version.""" + conn_str = build_connection_string( + db=database, + user=POSTGRES_USER, + password=POSTGRES_PASSWORD, + host=POSTGRES_HOST, + port=POSTGRES_PORT, + db_api=SYNC_DB_API, ) _run_migrations( conn_str, config_name, direction="upgrade", - revision="head", + revision=revision, ) - if not setup_onyx: - return - # do the same thing as we do on API server startup - with get_session_context_manager() as db_session: - setup_postgres(db_session) + +def reset_postgres( + database: str = "postgres", + config_name: str = "alembic", + setup_onyx: bool = True, +) -> None: + """Reset the Postgres database.""" + downgrade_postgres( + database=database, config_name=config_name, revision="base", clear_data=True + ) + upgrade_postgres(database=database, config_name=config_name, revision="head") + if setup_onyx: + with get_session_context_manager() as db_session: + setup_postgres(db_session) def reset_vespa() -> None: diff --git a/backend/tests/integration/tests/migrations/test_migrations.py b/backend/tests/integration/tests/migrations/test_migrations.py new file mode 100644 index 000000000..19b6fb1fb --- /dev/null +++ b/backend/tests/integration/tests/migrations/test_migrations.py @@ -0,0 +1,125 @@ +import pytest +from sqlalchemy import text + +from onyx.configs.constants import DEFAULT_BOOST +from onyx.db.engine import get_session_context_manager +from tests.integration.common_utils.reset import downgrade_postgres +from tests.integration.common_utils.reset import upgrade_postgres + + +@pytest.mark.skip( + reason="Migration test no longer needed - migration has been applied to production" +) +def test_fix_capitalization_migration() -> None: + """Test that the be2ab2aa50ee migration correctly lowercases external_user_group_ids""" + # Reset the database and run migrations up to the second to last migration + downgrade_postgres( + database="postgres", config_name="alembic", revision="base", clear_data=True + ) + upgrade_postgres( + database="postgres", + config_name="alembic", + # Upgrade it to the migration before the fix + revision="369644546676", + ) + + # Insert test data with mixed case group IDs + test_data = [ + { + "id": "test_doc_1", + "external_user_group_ids": ["Group1", "GROUP2", "group3"], + "semantic_id": "test_doc_1", + "boost": DEFAULT_BOOST, + "hidden": False, + "from_ingestion_api": False, + "last_modified": "NOW()", + }, + { + "id": "test_doc_2", + "external_user_group_ids": ["UPPER1", "upper2", "UPPER3"], + "semantic_id": "test_doc_2", + "boost": DEFAULT_BOOST, + "hidden": False, + "from_ingestion_api": False, + "last_modified": "NOW()", + }, + ] + + # Insert the test data + with get_session_context_manager() as db_session: + for doc in test_data: + db_session.execute( + text( + """ + INSERT INTO document ( + id, + external_user_group_ids, + semantic_id, + boost, + hidden, + from_ingestion_api, + last_modified + ) + VALUES ( + :id, + :group_ids, + :semantic_id, + :boost, + :hidden, + :from_ingestion_api, + :last_modified + ) + """ + ), + { + "id": doc["id"], + "group_ids": doc["external_user_group_ids"], + "semantic_id": doc["semantic_id"], + "boost": doc["boost"], + "hidden": doc["hidden"], + "from_ingestion_api": doc["from_ingestion_api"], + "last_modified": doc["last_modified"], + }, + ) + db_session.commit() + + # Verify the data was inserted correctly + with get_session_context_manager() as db_session: + results = db_session.execute( + text( + """ + SELECT id, external_user_group_ids + FROM document + WHERE id IN ('test_doc_1', 'test_doc_2') + ORDER BY id + """ + ) + ).fetchall() + + # Verify initial state + assert len(results) == 2 + assert results[0].external_user_group_ids == ["Group1", "GROUP2", "group3"] + assert results[1].external_user_group_ids == ["UPPER1", "upper2", "UPPER3"] + + # Run migrations again to apply the fix + upgrade_postgres( + database="postgres", config_name="alembic", revision="be2ab2aa50ee" + ) + + # Verify the fix was applied + with get_session_context_manager() as db_session: + results = db_session.execute( + text( + """ + SELECT id, external_user_group_ids + FROM document + WHERE id IN ('test_doc_1', 'test_doc_2') + ORDER BY id + """ + ) + ).fetchall() + + # Verify all group IDs are lowercase + assert len(results) == 2 + assert results[0].external_user_group_ids == ["group1", "group2", "group3"] + assert results[1].external_user_group_ids == ["upper1", "upper2", "upper3"]