fix group sync name capitalization (#3653)

* fix group sync name capitalization

* everything is lowercased now

* comments

* Added test for be2ab2aa50ee migration

* polish
hagen-danswer 2025-01-10 16:51:33 -08:00 committed by GitHub
parent cab7e60542
commit 6afd27f9c9
8 changed files with 285 additions and 86 deletions

View File

@@ -0,0 +1,38 @@
"""fix_capitalization

Revision ID: be2ab2aa50ee
Revises: 369644546676
Create Date: 2025-01-10 13:13:26.228960

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "be2ab2aa50ee"
down_revision = "369644546676"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.execute(
        """
        UPDATE document
        SET
            external_user_group_ids = ARRAY(
                SELECT LOWER(unnest(external_user_group_ids))
            ),
            last_modified = NOW()
        WHERE
            external_user_group_ids IS NOT NULL
            AND external_user_group_ids::text[] <> ARRAY(
                SELECT LOWER(unnest(external_user_group_ids))
            )::text[]
        """
    )


def downgrade() -> None:
    # No way to cleanly persist the bad state through an upgrade/downgrade
    # cycle, so we just pass
    pass
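
Editor's note: as a reading aid (not part of the commit), here is a minimal Python sketch of the row-level effect of the UPDATE above. Group IDs are lowercased and last_modified is bumped, while rows whose IDs are already all lowercase are skipped so their timestamp is left alone. The helper name normalize_row is purely illustrative.

# Illustrative sketch only -- mirrors what the UPDATE in the migration does to one row.
def normalize_row(group_ids: list[str] | None) -> tuple[list[str] | None, bool]:
    """Return (new_group_ids, was_modified) for a single document row."""
    if group_ids is None:
        return None, False  # WHERE external_user_group_ids IS NOT NULL
    lowered = [group_id.lower() for group_id in group_ids]
    if lowered == group_ids:
        return group_ids, False  # already lowercase: row and last_modified untouched
    return lowered, True  # rewritten, and last_modified would be set to NOW()


assert normalize_row(["Group1", "GROUP2", "group3"]) == (["group1", "group2", "group3"], True)
assert normalize_row(["already_lower"]) == (["already_lower"], False)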

View File

@@ -5,7 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.access.models import ExternalAccess
-from onyx.access.utils import prefix_group_w_source
+from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.db.models import Document as DbDocument
@@ -25,7 +25,7 @@ def upsert_document_external_perms__no_commit(
).first()
prefixed_external_groups = [
-prefix_group_w_source(
+build_ext_group_name_for_onyx(
ext_group_name=group_id,
source=source_type,
)
@@ -66,7 +66,7 @@ def upsert_document_external_perms(
).first()
prefixed_external_groups: set[str] = {
-prefix_group_w_source(
+build_ext_group_name_for_onyx(
ext_group_name=group_id,
source=source_type,
)

View File

@@ -6,8 +6,9 @@ from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
-from onyx.access.utils import prefix_group_w_source
+from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
+from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.users import get_user_by_email
@@ -61,8 +62,10 @@ def replace_user__ext_group_for_cc_pair(
all_group_member_emails.add(user_email)
# batch add users if they don't exist and get their ids
-all_group_members = batch_add_ext_perm_user_if_not_exists(
-    db_session=db_session, emails=list(all_group_member_emails)
+all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
+    db_session=db_session,
+    # NOTE: this function handles case sensitivity for emails
+    emails=list(all_group_member_emails),
)
delete_user__ext_group_for_cc_pair__no_commit(
@@ -84,12 +87,14 @@ def replace_user__ext_group_for_cc_pair(
f" with email {user_email} not found"
)
continue
+external_group_id = build_ext_group_name_for_onyx(
+    ext_group_name=external_group.id,
+    source=source,
+)
new_external_permissions.append(
User__ExternalUserGroupId(
user_id=user_id,
-external_user_group_id=prefix_group_w_source(
-    external_group.id, source
-),
+external_user_group_id=external_group_id,
cc_pair_id=cc_pair_id,
)
)

View File

@@ -19,6 +19,9 @@ def prefix_external_group(ext_group_name: str) -> str:
return f"external_group:{ext_group_name}"
-def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
-    """External groups may collide across sources, every source needs its own prefix."""
-    return f"{source.value.upper()}_{ext_group_name}"
+def build_ext_group_name_for_onyx(ext_group_name: str, source: DocumentSource) -> str:
+    """
+    External groups may collide across sources, every source needs its own prefix.
+    NOTE: the name is lowercased to handle case sensitivity for group names
+    """
+    return f"{source.value}_{ext_group_name}".lower()

View File

@@ -391,5 +391,7 @@ def update_external_document_permissions_task(
)
return True
except Exception:
logger.exception("Error Syncing Document Permissions")
logger.exception(
f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
)
return False

View File

@@ -135,32 +135,6 @@ class OnyxConfluence(Confluence):
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
-    def get_current_user(self, expand: str | None = None) -> Any:
-        """
-        Implements a method that isn't in the third party client.
-        Get information about the current user
-        :param expand: OPTIONAL expand for get status of user.
-            Possible param is "status". Results are "Active, Deactivated"
-        :return: Returns the user details
-        """
-        from atlassian.errors import ApiPermissionError  # type:ignore
-        url = "rest/api/user/current"
-        params = {}
-        if expand:
-            params["expand"] = expand
-        try:
-            response = self.get(url, params=params)
-        except HTTPError as e:
-            if e.response.status_code == 403:
-                raise ApiPermissionError(
-                    "The calling user does not have permission", reason=e
-                )
-            raise
-        return response
def _wrap_methods(self) -> None:
"""
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
@@ -363,6 +337,9 @@ class OnyxConfluence(Confluence):
fetch the permissions of a space.
This is better logging than calling the get_space_permissions method
because it returns a jsonrpc response.
+TODO: Make this call these endpoints for newer confluence versions:
+    - /rest/api/space/{spaceKey}/permissions
+    - /rest/api/space/{spaceKey}/permissions/anonymous
"""
url = "rpc/json-rpc/confluenceservice-v2"
data = {
@@ -381,6 +358,32 @@ class OnyxConfluence(Confluence):
return response.get("result", [])
+    def get_current_user(self, expand: str | None = None) -> Any:
+        """
+        Implements a method that isn't in the third party client.
+        Get information about the current user
+        :param expand: OPTIONAL expand for get status of user.
+            Possible param is "status". Results are "Active, Deactivated"
+        :return: Returns the user details
+        """
+        from atlassian.errors import ApiPermissionError  # type:ignore
+        url = "rest/api/user/current"
+        params = {}
+        if expand:
+            params["expand"] = expand
+        try:
+            response = self.get(url, params=params)
+        except HTTPError as e:
+            if e.response.status_code == 403:
+                raise ApiPermissionError(
+                    "The calling user does not have permission", reason=e
+                )
+            raise
+        return response
def _validate_connector_configuration(
credentials: dict[str, Any],
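
Editor's note: the get_current_user body is unchanged by this commit; it is only relocated within the class. A hypothetical usage sketch follows (it assumes an already-constructed OnyxConfluence client bound to confluence_client; the constructor arguments are outside this hunk).

# Hypothetical usage of the relocated helper. Under the hood it issues a GET to
# rest/api/user/current via the wrapped client and returns the user details
# described in the docstring.
current_user = confluence_client.get_current_user(expand="status")

# Per the docstring, expand="status" asks Confluence to include the user's
# status ("Active" or "Deactivated") alongside the usual user details.
print(current_user)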

View File

@@ -63,57 +63,57 @@ def _run_migrations(
logging.getLogger("alembic").setLevel(logging.INFO)
-def reset_postgres(
-    database: str = "postgres", config_name: str = "alembic", setup_onyx: bool = True
+def downgrade_postgres(
+    database: str = "postgres",
+    config_name: str = "alembic",
+    revision: str = "base",
+    clear_data: bool = False,
) -> None:
-    """Reset the Postgres database."""
-    # NOTE: need to delete all rows to allow migrations to be rolled back
-    # as there are a few downgrades that don't properly handle data in tables
-    conn = psycopg2.connect(
-        dbname=database,
-        user=POSTGRES_USER,
-        password=POSTGRES_PASSWORD,
-        host=POSTGRES_HOST,
-        port=POSTGRES_PORT,
-    )
-    cur = conn.cursor()
-    # Disable triggers to prevent foreign key constraints from being checked
-    cur.execute("SET session_replication_role = 'replica';")
-    # Fetch all table names in the current database
-    cur.execute(
-        """
-        SELECT tablename
-        FROM pg_tables
-        WHERE schemaname = 'public'
-        """
-    )
-    tables = cur.fetchall()
-    for table in tables:
-        table_name = table[0]
-        # Don't touch migration history
-        if table_name == "alembic_version":
-            continue
-        # Don't touch Kombu
-        if table_name == "kombu_message" or table_name == "kombu_queue":
-            continue
-        cur.execute(f'DELETE FROM "{table_name}"')
-    # Re-enable triggers
-    cur.execute("SET session_replication_role = 'origin';")
-    conn.commit()
-    cur.close()
-    conn.close()
-    # downgrade to base + upgrade back to head
+    """Downgrade Postgres database to base state."""
+    if clear_data:
+        if revision != "base":
+            logger.warning("Clearing data without rolling back to base state")
+        # Delete all rows to allow migrations to be rolled back
+        conn = psycopg2.connect(
+            dbname=database,
+            user=POSTGRES_USER,
+            password=POSTGRES_PASSWORD,
+            host=POSTGRES_HOST,
+            port=POSTGRES_PORT,
+        )
+        cur = conn.cursor()
+        # Disable triggers to prevent foreign key constraints from being checked
+        cur.execute("SET session_replication_role = 'replica';")
+        # Fetch all table names in the current database
+        cur.execute(
+            """
+            SELECT tablename
+            FROM pg_tables
+            WHERE schemaname = 'public'
+            """
+        )
+        tables = cur.fetchall()
+        for table in tables:
+            table_name = table[0]
+            # Don't touch migration history or Kombu
+            if table_name in ("alembic_version", "kombu_message", "kombu_queue"):
+                continue
+            cur.execute(f'DELETE FROM "{table_name}"')
+        # Re-enable triggers
+        cur.execute("SET session_replication_role = 'origin';")
+        conn.commit()
+        cur.close()
+        conn.close()
+    # Downgrade to base
    conn_str = build_connection_string(
        db=database,
        user=POSTGRES_USER,
@@ -126,20 +126,43 @@ def reset_postgres(
        conn_str,
        config_name,
        direction="downgrade",
-        revision="base",
+        revision=revision,
    )
+
+
+def upgrade_postgres(
+    database: str = "postgres", config_name: str = "alembic", revision: str = "head"
+) -> None:
+    """Upgrade Postgres database to latest version."""
+    conn_str = build_connection_string(
+        db=database,
+        user=POSTGRES_USER,
+        password=POSTGRES_PASSWORD,
+        host=POSTGRES_HOST,
+        port=POSTGRES_PORT,
+        db_api=SYNC_DB_API,
+    )
    _run_migrations(
        conn_str,
        config_name,
        direction="upgrade",
-        revision="head",
+        revision=revision,
    )
-    if not setup_onyx:
-        return
-    # do the same thing as we do on API server startup
-    with get_session_context_manager() as db_session:
-        setup_postgres(db_session)
+
+
+def reset_postgres(
+    database: str = "postgres",
+    config_name: str = "alembic",
+    setup_onyx: bool = True,
+) -> None:
+    """Reset the Postgres database."""
+    downgrade_postgres(
+        database=database, config_name=config_name, revision="base", clear_data=True
+    )
+    upgrade_postgres(database=database, config_name=config_name, revision="head")
+    if setup_onyx:
+        with get_session_context_manager() as db_session:
+            setup_postgres(db_session)
def reset_vespa() -> None:

View File

@@ -0,0 +1,125 @@
import pytest
from sqlalchemy import text

from onyx.configs.constants import DEFAULT_BOOST
from onyx.db.engine import get_session_context_manager
from tests.integration.common_utils.reset import downgrade_postgres
from tests.integration.common_utils.reset import upgrade_postgres


@pytest.mark.skip(
    reason="Migration test no longer needed - migration has been applied to production"
)
def test_fix_capitalization_migration() -> None:
    """Test that the be2ab2aa50ee migration correctly lowercases external_user_group_ids"""
    # Reset the database and run migrations up to the second to last migration
    downgrade_postgres(
        database="postgres", config_name="alembic", revision="base", clear_data=True
    )
    upgrade_postgres(
        database="postgres",
        config_name="alembic",
        # Upgrade it to the migration before the fix
        revision="369644546676",
    )

    # Insert test data with mixed case group IDs
    test_data = [
        {
            "id": "test_doc_1",
            "external_user_group_ids": ["Group1", "GROUP2", "group3"],
            "semantic_id": "test_doc_1",
            "boost": DEFAULT_BOOST,
            "hidden": False,
            "from_ingestion_api": False,
            "last_modified": "NOW()",
        },
        {
            "id": "test_doc_2",
            "external_user_group_ids": ["UPPER1", "upper2", "UPPER3"],
            "semantic_id": "test_doc_2",
            "boost": DEFAULT_BOOST,
            "hidden": False,
            "from_ingestion_api": False,
            "last_modified": "NOW()",
        },
    ]

    # Insert the test data
    with get_session_context_manager() as db_session:
        for doc in test_data:
            db_session.execute(
                text(
                    """
                    INSERT INTO document (
                        id,
                        external_user_group_ids,
                        semantic_id,
                        boost,
                        hidden,
                        from_ingestion_api,
                        last_modified
                    )
                    VALUES (
                        :id,
                        :group_ids,
                        :semantic_id,
                        :boost,
                        :hidden,
                        :from_ingestion_api,
                        :last_modified
                    )
                    """
                ),
                {
                    "id": doc["id"],
                    "group_ids": doc["external_user_group_ids"],
                    "semantic_id": doc["semantic_id"],
                    "boost": doc["boost"],
                    "hidden": doc["hidden"],
                    "from_ingestion_api": doc["from_ingestion_api"],
                    "last_modified": doc["last_modified"],
                },
            )
        db_session.commit()

    # Verify the data was inserted correctly
    with get_session_context_manager() as db_session:
        results = db_session.execute(
            text(
                """
                SELECT id, external_user_group_ids
                FROM document
                WHERE id IN ('test_doc_1', 'test_doc_2')
                ORDER BY id
                """
            )
        ).fetchall()

        # Verify initial state
        assert len(results) == 2
        assert results[0].external_user_group_ids == ["Group1", "GROUP2", "group3"]
        assert results[1].external_user_group_ids == ["UPPER1", "upper2", "UPPER3"]

    # Run migrations again to apply the fix
    upgrade_postgres(
        database="postgres", config_name="alembic", revision="be2ab2aa50ee"
    )

    # Verify the fix was applied
    with get_session_context_manager() as db_session:
        results = db_session.execute(
            text(
                """
                SELECT id, external_user_group_ids
                FROM document
                WHERE id IN ('test_doc_1', 'test_doc_2')
                ORDER BY id
                """
            )
        ).fetchall()

        # Verify all group IDs are lowercase
        assert len(results) == 2
        assert results[0].external_user_group_ids == ["group1", "group2", "group3"]
        assert results[1].external_user_group_ids == ["upper1", "upper2", "upper3"]