From bb7e1d6e55e823fea4fadfd1cef1006920d39f3b Mon Sep 17 00:00:00 2001
From: Chris Weaver <25087905+Weves@users.noreply.github.com>
Date: Tue, 6 Aug 2024 18:00:19 -0700
Subject: [PATCH] Add integration tests for document set syncing (#1904)

---
 ...5_remove_feedback_foreignkey_constraint.py |   2 +-
 .../27c6ecc08586_permission_framework.py      |  28 ++-
 .../76b60d407dfb_cc_pair_name_not_unique.py   |   8 +-
 ...45c915d0e_remove_deletion_attempt_table.py |   3 +
 .../versions/dbaa756c2ccf_embedding_models.py |   2 +-
 backend/danswer/db/document.py                |   2 +-
 backend/danswer/main.py                       |  82 +++++----
 .../danswer/server/danswer_api/ingestion.py   |   7 +-
 backend/danswer/server/gpts/api.py            |   3 +-
 backend/ee/danswer/auth/users.py              |   7 +-
 backend/tests/integration/common/constants.py |   1 +
 backend/tests/integration/common/reset.py     | 164 ++++++++++++++++++
 .../integration/common/seed_documents.py      |  83 +++++++++
 backend/tests/integration/common/vespa.py     |  27 +++
 backend/tests/integration/conftest.py         |  26 +++
 .../integration/document_set/test_syncing.py  |  79 +++++++++
 .../tests/integration/document_set/utils.py   |  26 +++
 17 files changed, 501 insertions(+), 49 deletions(-)
 create mode 100644 backend/tests/integration/common/constants.py
 create mode 100644 backend/tests/integration/common/reset.py
 create mode 100644 backend/tests/integration/common/seed_documents.py
 create mode 100644 backend/tests/integration/common/vespa.py
 create mode 100644 backend/tests/integration/conftest.py
 create mode 100644 backend/tests/integration/document_set/test_syncing.py
 create mode 100644 backend/tests/integration/document_set/utils.py

diff --git a/backend/alembic/versions/23957775e5f5_remove_feedback_foreignkey_constraint.py b/backend/alembic/versions/23957775e5f5_remove_feedback_foreignkey_constraint.py
index 07b5601013..10d094e0da 100644
--- a/backend/alembic/versions/23957775e5f5_remove_feedback_foreignkey_constraint.py
+++ b/backend/alembic/versions/23957775e5f5_remove_feedback_foreignkey_constraint.py
@@ -79,7 +79,7 @@ def downgrade() -> None:
     )
     op.create_foreign_key(
         "document_retrieval_feedback__chat_message_fk",
-        "document_retrieval",
+        "document_retrieval_feedback",
         "chat_message",
         ["chat_message_id"],
         ["id"],
diff --git a/backend/alembic/versions/27c6ecc08586_permission_framework.py b/backend/alembic/versions/27c6ecc08586_permission_framework.py
index cd869e2ba6..ff41d2f5cf 100644
--- a/backend/alembic/versions/27c6ecc08586_permission_framework.py
+++ b/backend/alembic/versions/27c6ecc08586_permission_framework.py
@@ -160,12 +160,28 @@ def downgrade() -> None:
             nullable=False,
         ),
     )
-    op.drop_constraint(
-        "fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
-    )
-    op.drop_constraint(
-        "fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
-    )
+
+    # Check if the constraint exists before dropping
+    conn = op.get_bind()
+    inspector = sa.inspect(conn)
+    constraints = inspector.get_foreign_keys("index_attempt")
+
+    if any(
+        constraint["name"] == "fk_index_attempt_credential_id"
+        for constraint in constraints
+    ):
+        op.drop_constraint(
+            "fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
+        )
+
+    if any(
+        constraint["name"] == "fk_index_attempt_connector_id"
+        for constraint in constraints
+    ):
+        op.drop_constraint(
+            "fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
+        )
+
     op.drop_column("index_attempt", "credential_id")
     op.drop_column("index_attempt", "connector_id")
     op.drop_table("connector_credential_pair")
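
The downgrade above now drops each foreign key only when the inspector actually reports it, so repeated downgrade/upgrade cycles no longer fail on databases where those constraints were never created. A minimal sketch of the same guard pulled out as a reusable helper (not part of this patch; names are illustrative):

```python
import sqlalchemy as sa
from alembic import op


def drop_fk_if_exists(table_name: str, constraint_name: str) -> None:
    # Only drop the foreign key if the live database actually has it.
    inspector = sa.inspect(op.get_bind())
    existing = {fk["name"] for fk in inspector.get_foreign_keys(table_name)}
    if constraint_name in existing:
        op.drop_constraint(constraint_name, table_name, type_="foreignkey")
```
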
diff --git a/backend/alembic/versions/76b60d407dfb_cc_pair_name_not_unique.py b/backend/alembic/versions/76b60d407dfb_cc_pair_name_not_unique.py
index c609ca4ae0..1dfbb9365d 100644
--- a/backend/alembic/versions/76b60d407dfb_cc_pair_name_not_unique.py
+++ b/backend/alembic/versions/76b60d407dfb_cc_pair_name_not_unique.py
@@ -28,5 +28,9 @@ def upgrade() -> None:
 
 
 def downgrade() -> None:
-    # This wasn't really required by the code either, no good reason to make it unique again
-    pass
+    op.create_unique_constraint(
+        "connector_credential_pair__name__key", "connector_credential_pair", ["name"]
+    )
+    op.alter_column(
+        "connector_credential_pair", "name", existing_type=sa.String(), nullable=True
+    )
diff --git a/backend/alembic/versions/d5645c915d0e_remove_deletion_attempt_table.py b/backend/alembic/versions/d5645c915d0e_remove_deletion_attempt_table.py
index aa4e7c71c1..5ef63ed331 100644
--- a/backend/alembic/versions/d5645c915d0e_remove_deletion_attempt_table.py
+++ b/backend/alembic/versions/d5645c915d0e_remove_deletion_attempt_table.py
@@ -19,6 +19,9 @@ depends_on: None = None
 def upgrade() -> None:
     op.drop_table("deletion_attempt")
 
+    # Remove the DeletionStatus enum
+    op.execute("DROP TYPE IF EXISTS deletionstatus;")
+
 
 def downgrade() -> None:
     op.create_table(
diff --git a/backend/alembic/versions/dbaa756c2ccf_embedding_models.py b/backend/alembic/versions/dbaa756c2ccf_embedding_models.py
index a7c9b8f5ae..6b7302d327 100644
--- a/backend/alembic/versions/dbaa756c2ccf_embedding_models.py
+++ b/backend/alembic/versions/dbaa756c2ccf_embedding_models.py
@@ -136,4 +136,4 @@ def downgrade() -> None:
     )
     op.drop_column("index_attempt", "embedding_model_id")
     op.drop_table("embedding_model")
-    op.execute("DROP TYPE indexmodelstatus;")
+    op.execute("DROP TYPE IF EXISTS indexmodelstatus;")
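
Together with the `DROP TYPE IF EXISTS` changes above, giving `76b60d407dfb` a real downgrade makes the whole migration chain reversible, which is exactly what the integration-test reset further down relies on. A minimal sketch of that round trip (not part of the patch; assumes `alembic.ini` in the working directory and a disposable database URL):

```python
from alembic import command
from alembic.config import Config


def roundtrip_migrations(database_url: str) -> None:
    # Exercise the downgrade-to-base / upgrade-to-head cycle that the
    # integration-test reset performs between test runs.
    alembic_cfg = Config("alembic.ini")
    alembic_cfg.set_main_option("sqlalchemy.url", database_url)

    command.downgrade(alembic_cfg, "base")  # every downgrade must now be runnable
    command.upgrade(alembic_cfg, "head")
```
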
diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py
index befb867574..80281c38b0 100644
--- a/backend/danswer/db/document.py
+++ b/backend/danswer/db/document.py
@@ -311,7 +311,7 @@ def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool
 
 
 _NUM_LOCK_ATTEMPTS = 10
-_LOCK_RETRY_DELAY = 30
+_LOCK_RETRY_DELAY = 10
 
 
 @contextlib.contextmanager
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 5cfe1c1c14..52e6d92f13 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -47,10 +47,12 @@ from danswer.db.engine import init_sqlalchemy_engine
 from danswer.db.engine import warm_up_connections
 from danswer.db.index_attempt import cancel_indexing_attempts_past_model
 from danswer.db.index_attempt import expire_index_attempts
+from danswer.db.models import EmbeddingModel
 from danswer.db.persona import delete_old_default_personas
 from danswer.db.standard_answer import create_initial_default_standard_answer_category
 from danswer.db.swap_index import check_index_swap
 from danswer.document_index.factory import get_default_document_index
+from danswer.document_index.interfaces import DocumentIndex
 from danswer.llm.llm_initialization import load_llm_providers
 from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
 from danswer.search.retrieval.search_runner import download_nltk_data
@@ -158,6 +160,49 @@ def include_router_with_global_prefix_prepended(
     application.include_router(router, **final_kwargs)
 
 
+def setup_postgres(db_session: Session) -> None:
+    logger.info("Verifying default connector/credential exist.")
+    create_initial_public_credential(db_session)
+    create_initial_default_connector(db_session)
+    associate_default_cc_pair(db_session)
+
+    logger.info("Verifying default standard answer category exists.")
+    create_initial_default_standard_answer_category(db_session)
+
+    logger.info("Loading LLM providers from env variables")
+    load_llm_providers(db_session)
+
+    logger.info("Loading default Prompts and Personas")
+    delete_old_default_personas(db_session)
+    load_chat_yamls()
+
+    logger.info("Loading built-in tools")
+    load_builtin_tools(db_session)
+    refresh_built_in_tools_cache(db_session)
+    auto_add_search_tool_to_personas(db_session)
+
+
+def setup_vespa(
+    document_index: DocumentIndex,
+    db_embedding_model: EmbeddingModel,
+    secondary_db_embedding_model: EmbeddingModel | None,
+) -> None:
+    # Vespa startup is a bit slow, so give it a few seconds
+    wait_time = 5
+    for _ in range(5):
+        try:
+            document_index.ensure_indices_exist(
+                index_embedding_dim=db_embedding_model.model_dim,
+                secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
+                if secondary_db_embedding_model
+                else None,
+            )
+            break
+        except Exception:
+            logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
+            time.sleep(wait_time)
+
+
 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator:
     init_sqlalchemy_engine(POSTGRES_WEB_APP_NAME)
@@ -213,26 +258,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
         logger.info("Verifying query preprocessing (NLTK) data is downloaded")
         download_nltk_data()
 
-        logger.info("Verifying default connector/credential exist.")
-        create_initial_public_credential(db_session)
-        create_initial_default_connector(db_session)
-        associate_default_cc_pair(db_session)
-
-        logger.info("Verifying default standard answer category exists.")
-        create_initial_default_standard_answer_category(db_session)
-
-        logger.info("Loading LLM providers from env variables")
-        load_llm_providers(db_session)
-
-        logger.info("Loading default Prompts and Personas")
-        delete_old_default_personas(db_session)
-        load_chat_yamls()
-
-        logger.info("Loading built-in tools")
-        load_builtin_tools(db_session)
-        refresh_built_in_tools_cache(db_session)
-        auto_add_search_tool_to_personas(db_session)
+        # setup Postgres with default credential, llm providers, etc.
+        setup_postgres(db_session)
 
+        # ensure Vespa is setup correctly
         logger.info("Verifying Document Index(s) is/are available.")
         document_index = get_default_document_index(
             primary_index_name=db_embedding_model.index_name,
@@ -240,20 +269,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
             if secondary_db_embedding_model
             else None,
         )
-        # Vespa startup is a bit slow, so give it a few seconds
-        wait_time = 5
-        for attempt in range(5):
-            try:
-                document_index.ensure_indices_exist(
-                    index_embedding_dim=db_embedding_model.model_dim,
-                    secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
-                    if secondary_db_embedding_model
-                    else None,
-                )
-                break
-            except Exception:
-                logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
-                time.sleep(wait_time)
+        setup_vespa(document_index, db_embedding_model, secondary_db_embedding_model)
 
         logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
         if db_embedding_model.cloud_provider_id is None:
diff --git a/backend/danswer/server/danswer_api/ingestion.py b/backend/danswer/server/danswer_api/ingestion.py
index 9127b260d6..45240abf8f 100644
--- a/backend/danswer/server/danswer_api/ingestion.py
+++ b/backend/danswer/server/danswer_api/ingestion.py
@@ -12,6 +12,7 @@ from danswer.db.document import get_ingestion_documents
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.embedding_model import get_secondary_db_embedding_model
 from danswer.db.engine import get_session
+from danswer.db.models import User
 from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.embedder import DefaultIndexingEmbedder
@@ -31,7 +32,7 @@ router = APIRouter(prefix="/danswer-api")
 @router.get("/connector-docs/{cc_pair_id}")
 def get_docs_by_connector_credential_pair(
     cc_pair_id: int,
-    _: str = Depends(api_key_dep),
+    _: User | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> list[DocMinimalInfo]:
     db_docs = get_documents_by_cc_pair(cc_pair_id=cc_pair_id, db_session=db_session)
@@ -47,7 +48,7 @@ def get_docs_by_connector_credential_pair(
 
 @router.get("/ingestion")
 def get_ingestion_docs(
-    _: str = Depends(api_key_dep),
+    _: User | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> list[DocMinimalInfo]:
     db_docs = get_ingestion_documents(db_session)
@@ -64,7 +65,7 @@ def get_ingestion_docs(
 @router.post("/ingestion")
 def upsert_ingestion_doc(
     doc_info: IngestionDocument,
-    _: str = Depends(api_key_dep),
+    _: User | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> IngestionResult:
     doc_info.document.from_ingestion_api = True
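
With `api_key_dep` now typed as returning `User | None` (see the `ee` change below), a server running with `AUTH_TYPE=disabled` resolves the dependency to `None` and accepts ingestion calls without an API key, which is what the new integration tests rely on. A minimal sketch of such a call (not part of the patch; the payload shape mirrors `seed_documents.py` further down, and `cc_pair_id=1` is only a placeholder):

```python
import requests

from danswer.configs.constants import DocumentSource

# Assumes a local API server started with AUTH_TYPE=disabled, so no API key
# header is required. cc_pair_id=1 is a placeholder; use an id from your own
# connector/credential setup.
API_SERVER_URL = "http://localhost:8080"

payload = {
    "document": {
        "id": "example-doc-1",
        "sections": [{"text": "hello world", "link": "example-doc-1"}],
        "source": DocumentSource.NOT_APPLICABLE,
        "metadata": {},
        "semantic_identifier": "Example Document",
        "from_ingestion_api": True,
    },
    "cc_pair_id": 1,
}

response = requests.post(f"{API_SERVER_URL}/danswer-api/ingestion", json=payload)
response.raise_for_status()
print(response.json())
```
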
diff --git a/backend/danswer/server/gpts/api.py b/backend/danswer/server/gpts/api.py
index a3ce59edc3..1bebc3bfc1 100644
--- a/backend/danswer/server/gpts/api.py
+++ b/backend/danswer/server/gpts/api.py
@@ -7,6 +7,7 @@ from pydantic import BaseModel
 from sqlalchemy.orm import Session
 
 from danswer.db.engine import get_session
+from danswer.db.models import User
 from danswer.llm.factory import get_default_llms
 from danswer.search.models import SearchRequest
 from danswer.search.pipeline import SearchPipeline
@@ -64,7 +65,7 @@ class GptSearchResponse(BaseModel):
 @router.post("/gpt-document-search")
 def gpt_search(
     search_request: GptSearchRequest,
-    _: str | None = Depends(api_key_dep),
+    _: User | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> GptSearchResponse:
     llm, fast_llm = get_default_llms()
diff --git a/backend/ee/danswer/auth/users.py b/backend/ee/danswer/auth/users.py
index f5f5dbd58f..e66953fbd0 100644
--- a/backend/ee/danswer/auth/users.py
+++ b/backend/ee/danswer/auth/users.py
@@ -44,7 +44,12 @@ async def optional_user_(
     return user
 
 
-def api_key_dep(request: Request, db_session: Session = Depends(get_session)) -> User:
+def api_key_dep(
+    request: Request, db_session: Session = Depends(get_session)
+) -> User | None:
+    if AUTH_TYPE == AuthType.DISABLED:
+        return None
+
     hashed_api_key = get_hashed_api_key_from_request(request)
     if not hashed_api_key:
         raise HTTPException(status_code=401, detail="Missing API key")
diff --git a/backend/tests/integration/common/constants.py b/backend/tests/integration/common/constants.py
new file mode 100644
index 0000000000..304a31b626
--- /dev/null
+++ b/backend/tests/integration/common/constants.py
@@ -0,0 +1 @@
+API_SERVER_URL = "http://localhost:8080"
diff --git a/backend/tests/integration/common/reset.py b/backend/tests/integration/common/reset.py
new file mode 100644
index 0000000000..56760a74c5
--- /dev/null
+++ b/backend/tests/integration/common/reset.py
@@ -0,0 +1,164 @@
+import logging
+import time
+
+import psycopg2
+import requests
+
+from alembic import command
+from alembic.config import Config
+from danswer.configs.app_configs import POSTGRES_HOST
+from danswer.configs.app_configs import POSTGRES_PASSWORD
+from danswer.configs.app_configs import POSTGRES_PORT
+from danswer.configs.app_configs import POSTGRES_USER
+from danswer.db.embedding_model import get_current_db_embedding_model
+from danswer.db.engine import build_connection_string
+from danswer.db.engine import get_session_context_manager
+from danswer.db.engine import SYNC_DB_API
+from danswer.db.swap_index import check_index_swap
+from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
+from danswer.document_index.vespa.index import VespaIndex
+from danswer.main import setup_postgres
+from danswer.main import setup_vespa
+
+
+def _run_migrations(
+    database_url: str, direction: str = "upgrade", revision: str = "head"
+) -> None:
+    # hide info logs emitted during migration
+    logging.getLogger("alembic").setLevel(logging.CRITICAL)
+
+    # Create an Alembic configuration object
+    alembic_cfg = Config("alembic.ini")
+    alembic_cfg.set_section_option("logger_alembic", "level", "WARN")
+
+    # Set the SQLAlchemy URL in the Alembic configuration
+    alembic_cfg.set_main_option("sqlalchemy.url", database_url)
+
+    # Run the migration
+    if direction == "upgrade":
+        command.upgrade(alembic_cfg, revision)
+    elif direction == "downgrade":
+        command.downgrade(alembic_cfg, revision)
+    else:
+        raise ValueError(
+            f"Invalid direction: {direction}. Must be 'upgrade' or 'downgrade'."
+        )
+
+    logging.getLogger("alembic").setLevel(logging.INFO)
+
+
+def reset_postgres(database: str = "postgres") -> None:
+    """Reset the Postgres database."""
+
+    # NOTE: need to delete all rows to allow migrations to be rolled back
+    # as there are a few downgrades that don't properly handle data in tables
+    conn = psycopg2.connect(
+        dbname=database,
+        user=POSTGRES_USER,
+        password=POSTGRES_PASSWORD,
+        host=POSTGRES_HOST,
+        port=POSTGRES_PORT,
+    )
+    cur = conn.cursor()
+
+    # Disable triggers to prevent foreign key constraints from being checked
+    cur.execute("SET session_replication_role = 'replica';")
+
+    # Fetch all table names in the current database
+    cur.execute(
+        """
+        SELECT tablename
+        FROM pg_tables
+        WHERE schemaname = 'public'
+    """
+    )
+
+    tables = cur.fetchall()
+
+    for table in tables:
+        table_name = table[0]
+
+        # Don't touch migration history
+        if table_name == "alembic_version":
+            continue
+
+        cur.execute(f'DELETE FROM "{table_name}"')
+
+    # Re-enable triggers
+    cur.execute("SET session_replication_role = 'origin';")
+
+    conn.commit()
+    cur.close()
+    conn.close()
+
+    # downgrade to base + upgrade back to head
+    conn_str = build_connection_string(
+        db=database,
+        user=POSTGRES_USER,
+        password=POSTGRES_PASSWORD,
+        host=POSTGRES_HOST,
+        port=POSTGRES_PORT,
+        db_api=SYNC_DB_API,
+    )
+    _run_migrations(
+        conn_str,
+        direction="downgrade",
+        revision="base",
+    )
+    _run_migrations(
+        conn_str,
+        direction="upgrade",
+        revision="head",
+    )
+
+    # do the same thing as we do on API server startup
+    with get_session_context_manager() as db_session:
+        setup_postgres(db_session)
+
+
+def reset_vespa() -> None:
+    """Wipe all data from the Vespa index."""
+    with get_session_context_manager() as db_session:
+        # swap to the correct default model
+        check_index_swap(db_session)
+
+        current_model = get_current_db_embedding_model(db_session)
+        index_name = current_model.index_name
+
+    setup_vespa(
+        document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
+        db_embedding_model=current_model,
+        secondary_db_embedding_model=None,
+    )
+
+    for _ in range(5):
+        try:
+            continuation = None
+            should_continue = True
+            while should_continue:
+                params = {"selection": "true", "cluster": "danswer_index"}
+                if continuation:
+                    params = {**params, "continuation": continuation}
+                response = requests.delete(
+                    DOCUMENT_ID_ENDPOINT.format(index_name=index_name), params=params
+                )
+                response.raise_for_status()
+
+                response_json = response.json()
+
+                continuation = response_json.get("continuation")
+                should_continue = bool(continuation)
+
+            break
+        except Exception as e:
+            print(f"Error deleting documents: {e}")
+            time.sleep(5)
+
+
+def reset_all() -> None:
+    """Reset both Postgres and Vespa."""
+    print("Resetting Postgres...")
+    reset_postgres()
+    print("Resetting Vespa...")
+    reset_vespa()
+    print("Finished resetting all.")
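
The reset helpers can also be driven outside of pytest, for example while iterating on a test locally. A minimal sketch (not part of the patch; it assumes Postgres, Vespa, and the API server are already running with the default dev settings):

```python
# Wipe local state between manual runs. Assumes the default dev settings for
# Postgres and Vespa, matching what the pytest `reset` fixture does below.
from tests.integration.common.reset import reset_all

if __name__ == "__main__":
    # rolls Postgres back to base, re-migrates, re-creates defaults, clears Vespa
    reset_all()
```
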
diff --git a/backend/tests/integration/common/seed_documents.py b/backend/tests/integration/common/seed_documents.py
new file mode 100644
index 0000000000..3993aadf57
--- /dev/null
+++ b/backend/tests/integration/common/seed_documents.py
@@ -0,0 +1,83 @@
+import uuid
+from typing import cast
+
+import requests
+from pydantic import BaseModel
+
+from danswer.configs.constants import DocumentSource
+from tests.integration.common.constants import API_SERVER_URL
+
+
+class SeedDocumentResponse(BaseModel):
+    cc_pair_id: int
+    document_ids: list[str]
+
+
+class TestDocumentClient:
+    @staticmethod
+    def seed_documents(num_docs: int = 5) -> SeedDocumentResponse:
+        unique_id = uuid.uuid4()
+
+        # Create a connector
+        connector_name = f"test_connector_{unique_id}"
+        connector_data = {
+            "name": connector_name,
+            "source": DocumentSource.NOT_APPLICABLE,
+            "input_type": "load_state",
+            "connector_specific_config": {},
+            "refresh_freq": 60,
+            "disabled": True,
+        }
+        response = requests.post(
+            f"{API_SERVER_URL}/manage/admin/connector",
+            json=connector_data,
+        )
+        response.raise_for_status()
+        connector_id = response.json()["id"]
+
+        # Associate the credential with the connector
+        cc_pair_metadata = {"name": f"test_cc_pair_{unique_id}", "is_public": True}
+        response = requests.put(
+            f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/0",
+            json=cc_pair_metadata,
+        )
+        response.raise_for_status()
+        cc_pair_id = cast(int, response.json()["data"])
+
+        # Create and ingest some documents
+        document_ids: list[str] = []
+        for _ in range(num_docs):
+            document_id = f"test-doc-{uuid.uuid4()}"
+            document_ids.append(document_id)
+
+            document = {
+                "document": {
+                    "id": document_id,
+                    "sections": [
+                        {
+                            "text": f"This is test document {document_id}",
+                            "link": f"{document_id}",
+                        }
+                    ],
+                    "source": DocumentSource.NOT_APPLICABLE,
+                    "metadata": {},
+                    "semantic_identifier": f"Test Document {document_id}",
+                    "from_ingestion_api": True,
+                },
+                "cc_pair_id": cc_pair_id,
+            }
+            response = requests.post(
+                f"{API_SERVER_URL}/danswer-api/ingestion",
+                json=document,
+            )
+            response.raise_for_status()
+
+        print("Seeding completed successfully.")
+        return SeedDocumentResponse(
+            cc_pair_id=cc_pair_id,
+            document_ids=document_ids,
+        )
+
+
+if __name__ == "__main__":
+    seed_documents_resp = TestDocumentClient.seed_documents()
diff --git a/backend/tests/integration/common/vespa.py b/backend/tests/integration/common/vespa.py
new file mode 100644
index 0000000000..aff7ef5eca
--- /dev/null
+++ b/backend/tests/integration/common/vespa.py
@@ -0,0 +1,27 @@
+import requests
+
+from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
+
+
+class TestVespaClient:
+    def __init__(self, index_name: str):
+        self.index_name = index_name
+        self.vespa_document_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
+
+    def get_documents_by_id(
+        self, document_ids: list[str], wanted_doc_count: int = 1_000
+    ) -> dict:
+        selection = " or ".join(
+            f"{self.index_name}.document_id=='{document_id}'"
+            for document_id in document_ids
+        )
+        params = {
+            "selection": selection,
+            "wantedDocumentCount": wanted_doc_count,
+        }
+        response = requests.get(
+            self.vespa_document_url,
+            params=params,  # type: ignore
+        )
+        response.raise_for_status()
+        return response.json()
diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py
new file mode 100644
index 0000000000..d56bf99142
--- /dev/null
+++ b/backend/tests/integration/conftest.py
@@ -0,0 +1,26 @@
+from collections.abc import Generator
+
+import pytest
+from sqlalchemy.orm import Session
+
+from danswer.db.embedding_model import get_current_db_embedding_model
+from danswer.db.engine import get_session_context_manager
+from tests.integration.common.reset import reset_all
+from tests.integration.common.vespa import TestVespaClient
+
+
+@pytest.fixture
+def db_session() -> Generator[Session, None, None]:
+    with get_session_context_manager() as session:
+        yield session
+
+
+@pytest.fixture
+def vespa_client(db_session: Session) -> TestVespaClient:
+    current_model = get_current_db_embedding_model(db_session)
+    return TestVespaClient(index_name=current_model.index_name)
+
+
+@pytest.fixture
+def reset() -> None:
+    reset_all()
diff --git a/backend/tests/integration/document_set/test_syncing.py b/backend/tests/integration/document_set/test_syncing.py
new file mode 100644
index 0000000000..797f5a49d5
--- /dev/null
+++ b/backend/tests/integration/document_set/test_syncing.py
@@ -0,0 +1,79 @@
+import time
+
+from danswer.server.features.document_set.models import DocumentSetCreationRequest
+from tests.integration.common.seed_documents import TestDocumentClient
+from tests.integration.common.vespa import TestVespaClient
+from tests.integration.document_set.utils import create_document_set
+from tests.integration.document_set.utils import fetch_document_sets
+
+
+def test_multiple_document_sets_syncing_same_connector(
+    reset: None, vespa_client: TestVespaClient
+) -> None:
+    # Seed documents
+    seed_result = TestDocumentClient.seed_documents(num_docs=5)
+    cc_pair_id = seed_result.cc_pair_id
+
+    # Create first document set
+    doc_set_1_id = create_document_set(
+        DocumentSetCreationRequest(
+            name="Test Document Set 1",
+            description="First test document set",
+            cc_pair_ids=[cc_pair_id],
+            is_public=True,
+            users=[],
+            groups=[],
+        )
+    )
+
+    doc_set_2_id = create_document_set(
+        DocumentSetCreationRequest(
+            name="Test Document Set 2",
+            description="Second test document set",
+            cc_pair_ids=[cc_pair_id],
+            is_public=True,
+            users=[],
+            groups=[],
+        )
+    )
+
+    # wait for syncing to be complete
+    max_delay = 45
+    start = time.time()
+    while True:
+        doc_sets = fetch_document_sets()
+        doc_set_1 = next(
+            (doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None
+        )
+        doc_set_2 = next(
+            (doc_set for doc_set in doc_sets if doc_set.id == doc_set_2_id), None
+        )
+
+        if not doc_set_1 or not doc_set_2:
+            raise RuntimeError("Document set not found")
+
+        if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date:
+            assert [ccp.id for ccp in doc_set_1.cc_pair_descriptors] == [
+                ccp.id for ccp in doc_set_2.cc_pair_descriptors
+            ]
+            break
+
+        if time.time() - start > max_delay:
+            raise TimeoutError("Document sets were not synced within the max delay")
+
+        time.sleep(2)
+
+    # get names so we can compare to what is in vespa
+    doc_sets = fetch_document_sets()
+    doc_set_names = {doc_set.name for doc_set in doc_sets}
+
+    # make sure documents are as expected
+    result = vespa_client.get_documents_by_id(seed_result.document_ids)
+    documents = result["documents"]
+    assert len(documents) == len(seed_result.document_ids)
+    assert all(
+        doc["fields"]["document_id"] in seed_result.document_ids for doc in documents
+    )
+    assert all(
+        set(doc["fields"]["document_sets"].keys()) == doc_set_names for doc in documents
+    )
diff --git a/backend/tests/integration/document_set/utils.py b/backend/tests/integration/document_set/utils.py
new file mode 100644
index 0000000000..c28a0f02fd
--- /dev/null
+++ b/backend/tests/integration/document_set/utils.py
@@ -0,0 +1,26 @@
+from typing import cast
+
+import requests
+
+from danswer.server.features.document_set.models import DocumentSet
+from danswer.server.features.document_set.models import DocumentSetCreationRequest
+from tests.integration.common.constants import API_SERVER_URL
+
+
+def create_document_set(doc_set_creation_request: DocumentSetCreationRequest) -> int:
+    response = requests.post(
+        f"{API_SERVER_URL}/manage/admin/document-set",
+        json=doc_set_creation_request.dict(),
+    )
+    response.raise_for_status()
+    return cast(int, response.json())
+
+
+def fetch_document_sets() -> list[DocumentSet]:
+    response = requests.get(f"{API_SERVER_URL}/manage/admin/document-set")
+    response.raise_for_status()
+
+    document_sets = [
+        DocumentSet.parse_obj(doc_set_data) for doc_set_data in response.json()
+    ]
+    return document_sets
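
The new tests assume a running API server (http://localhost:8080 per constants.py) plus Postgres and Vespa; the patch itself does not pin down how those services are started. A minimal sketch for launching just the document set syncing test from the backend directory (not part of the patch):

```python
# Run the new integration test from Python; equivalent to invoking pytest on
# the file directly. Assumes the API server, Postgres, and Vespa are already up.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(
        pytest.main(
            [
                "tests/integration/document_set/test_syncing.py",
                "-s",  # stream the seeding/reset progress prints
                "-v",
            ]
        )
    )
```
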