Add integration tests for document set syncing (#1904)

Author: Chris Weaver (committed by GitHub)
Date: 2024-08-06 18:00:19 -07:00
Parent: fcc4c30ead
Commit: bb7e1d6e55
17 changed files with 501 additions and 49 deletions

View File

@@ -79,7 +79,7 @@ def downgrade() -> None:
    )
    op.create_foreign_key(
        "document_retrieval_feedback__chat_message_fk",
        "document_retrieval",
        "document_retrieval_feedback",
        "chat_message",
        ["chat_message_id"],
        ["id"],

View File

@@ -160,12 +160,28 @@ def downgrade() -> None:
            nullable=False,
        ),
    )
    op.drop_constraint(
        "fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
    )
    op.drop_constraint(
        "fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
    )
    # Check if the constraint exists before dropping
    conn = op.get_bind()
    inspector = sa.inspect(conn)
    constraints = inspector.get_foreign_keys("index_attempt")
    if any(
        constraint["name"] == "fk_index_attempt_credential_id"
        for constraint in constraints
    ):
        op.drop_constraint(
            "fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
        )
    if any(
        constraint["name"] == "fk_index_attempt_connector_id"
        for constraint in constraints
    ):
        op.drop_constraint(
            "fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
        )
    op.drop_column("index_attempt", "credential_id")
    op.drop_column("index_attempt", "connector_id")
    op.drop_table("connector_credential_pair")
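
Note: the hunk above makes the constraint drops idempotent by inspecting the live schema first. A minimal standalone sketch of the same pattern, using only stock Alembic/SQLAlchemy APIs; the helper name is illustrative, not part of the diff:

import sqlalchemy as sa
from alembic import op

def drop_fk_if_exists(table_name: str, constraint_name: str) -> None:
    # Inspect the live connection rather than assuming the constraint exists,
    # so the downgrade can be re-run safely against partially migrated schemas.
    inspector = sa.inspect(op.get_bind())
    existing = {fk["name"] for fk in inspector.get_foreign_keys(table_name)}
    if constraint_name in existing:
        op.drop_constraint(constraint_name, table_name, type_="foreignkey")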

View File

@@ -28,5 +28,9 @@ def upgrade() -> None:
def downgrade() -> None:
    # This wasn't really required by the code either, no good reason to make it unique again
    pass
    op.create_unique_constraint(
        "connector_credential_pair__name__key", "connector_credential_pair", ["name"]
    )
    op.alter_column(
        "connector_credential_pair", "name", existing_type=sa.String(), nullable=True
    )

View File

@@ -19,6 +19,9 @@ depends_on: None = None
def upgrade() -> None:
    op.drop_table("deletion_attempt")

    # Remove the DeletionStatus enum
    op.execute("DROP TYPE IF EXISTS deletionstatus;")


def downgrade() -> None:
    op.create_table(

View File

@@ -136,4 +136,4 @@ def downgrade() -> None:
    )
    op.drop_column("index_attempt", "embedding_model_id")
    op.drop_table("embedding_model")
    op.execute("DROP TYPE indexmodelstatus;")
    op.execute("DROP TYPE IF EXISTS indexmodelstatus;")

View File

@@ -311,7 +311,7 @@ def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool
_NUM_LOCK_ATTEMPTS = 10

_LOCK_RETRY_DELAY = 30
_LOCK_RETRY_DELAY = 10


@contextlib.contextmanager
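
Note on the change above: with `_NUM_LOCK_ATTEMPTS = 10`, and assuming the delay is applied once per failed attempt, the worst-case wait drops from 10 x 30 s = 300 s to 10 x 10 s = 100 s, presumably so lock contention does not stall the new integration tests.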

View File

@@ -47,10 +47,12 @@ from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.engine import warm_up_connections
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.models import EmbeddingModel
from danswer.db.persona import delete_old_default_personas
from danswer.db.standard_answer import create_initial_default_standard_answer_category
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.llm.llm_initialization import load_llm_providers
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.search.retrieval.search_runner import download_nltk_data
@@ -158,6 +160,49 @@ def include_router_with_global_prefix_prepended(
    application.include_router(router, **final_kwargs)


def setup_postgres(db_session: Session) -> None:
    logger.info("Verifying default connector/credential exist.")
    create_initial_public_credential(db_session)
    create_initial_default_connector(db_session)
    associate_default_cc_pair(db_session)

    logger.info("Verifying default standard answer category exists.")
    create_initial_default_standard_answer_category(db_session)

    logger.info("Loading LLM providers from env variables")
    load_llm_providers(db_session)

    logger.info("Loading default Prompts and Personas")
    delete_old_default_personas(db_session)
    load_chat_yamls()

    logger.info("Loading built-in tools")
    load_builtin_tools(db_session)
    refresh_built_in_tools_cache(db_session)
    auto_add_search_tool_to_personas(db_session)


def setup_vespa(
    document_index: DocumentIndex,
    db_embedding_model: EmbeddingModel,
    secondary_db_embedding_model: EmbeddingModel | None,
) -> None:
    # Vespa startup is a bit slow, so give it a few seconds
    wait_time = 5
    for _ in range(5):
        try:
            document_index.ensure_indices_exist(
                index_embedding_dim=db_embedding_model.model_dim,
                secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
                if secondary_db_embedding_model
                else None,
            )
            break
        except Exception:
            logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
            time.sleep(wait_time)


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    init_sqlalchemy_engine(POSTGRES_WEB_APP_NAME)
@@ -213,26 +258,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
        logger.info("Verifying query preprocessing (NLTK) data is downloaded")
        download_nltk_data()

        logger.info("Verifying default connector/credential exist.")
        create_initial_public_credential(db_session)
        create_initial_default_connector(db_session)
        associate_default_cc_pair(db_session)

        logger.info("Verifying default standard answer category exists.")
        create_initial_default_standard_answer_category(db_session)

        logger.info("Loading LLM providers from env variables")
        load_llm_providers(db_session)

        logger.info("Loading default Prompts and Personas")
        delete_old_default_personas(db_session)
        load_chat_yamls()

        logger.info("Loading built-in tools")
        load_builtin_tools(db_session)
        refresh_built_in_tools_cache(db_session)
        auto_add_search_tool_to_personas(db_session)

        # setup Postgres with default credential, llm providers, etc.
        setup_postgres(db_session)

        # ensure Vespa is setup correctly
        logger.info("Verifying Document Index(s) is/are available.")
        document_index = get_default_document_index(
            primary_index_name=db_embedding_model.index_name,
@@ -240,20 +269,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
            if secondary_db_embedding_model
            else None,
        )

        # Vespa startup is a bit slow, so give it a few seconds
        wait_time = 5
        for attempt in range(5):
            try:
                document_index.ensure_indices_exist(
                    index_embedding_dim=db_embedding_model.model_dim,
                    secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
                    if secondary_db_embedding_model
                    else None,
                )
                break
            except Exception:
                logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
                time.sleep(wait_time)

        setup_vespa(document_index, db_embedding_model, secondary_db_embedding_model)

        logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
        if db_embedding_model.cloud_provider_id is None:
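
Note: pulling the bootstrap logic out of `lifespan` into `setup_postgres` / `setup_vespa` is what lets the integration-test reset code later in this PR reuse the exact server startup path. A rough usage sketch, mirroring what the test helpers below do (all imports are taken from this diff; no extra assumptions beyond a reachable Postgres and Vespa):

from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.document_index.vespa.index import VespaIndex
from danswer.main import setup_postgres, setup_vespa

with get_session_context_manager() as db_session:
    # Same bootstrap the API server runs on startup, now callable from scripts/tests.
    setup_postgres(db_session)
    current_model = get_current_db_embedding_model(db_session)
    setup_vespa(
        document_index=VespaIndex(
            index_name=current_model.index_name, secondary_index_name=None
        ),
        db_embedding_model=current_model,
        secondary_db_embedding_model=None,
    )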

View File

@@ -12,6 +12,7 @@ from danswer.db.document import get_ingestion_documents
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
@@ -31,7 +32,7 @@ router = APIRouter(prefix="/danswer-api")
@router.get("/connector-docs/{cc_pair_id}")
def get_docs_by_connector_credential_pair(
    cc_pair_id: int,
    _: str = Depends(api_key_dep),
    _: User | None = Depends(api_key_dep),
    db_session: Session = Depends(get_session),
) -> list[DocMinimalInfo]:
    db_docs = get_documents_by_cc_pair(cc_pair_id=cc_pair_id, db_session=db_session)
@@ -47,7 +48,7 @@ def get_docs_by_connector_credential_pair(
@router.get("/ingestion")
def get_ingestion_docs(
    _: str = Depends(api_key_dep),
    _: User | None = Depends(api_key_dep),
    db_session: Session = Depends(get_session),
) -> list[DocMinimalInfo]:
    db_docs = get_ingestion_documents(db_session)
@@ -64,7 +65,7 @@ def get_ingestion_docs(
@router.post("/ingestion")
def upsert_ingestion_doc(
    doc_info: IngestionDocument,
    _: str = Depends(api_key_dep),
    _: User | None = Depends(api_key_dep),
    db_session: Session = Depends(get_session),
) -> IngestionResult:
    doc_info.document.from_ingestion_api = True

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.llm.factory import get_default_llms
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
@@ -64,7 +65,7 @@ class GptSearchResponse(BaseModel):
@router.post("/gpt-document-search")
def gpt_search(
    search_request: GptSearchRequest,
    _: str | None = Depends(api_key_dep),
    _: User | None = Depends(api_key_dep),
    db_session: Session = Depends(get_session),
) -> GptSearchResponse:
    llm, fast_llm = get_default_llms()

View File

@@ -44,7 +44,12 @@ async def optional_user_(
    return user


def api_key_dep(request: Request, db_session: Session = Depends(get_session)) -> User:
def api_key_dep(
    request: Request, db_session: Session = Depends(get_session)
) -> User | None:
    if AUTH_TYPE == AuthType.DISABLED:
        return None

    hashed_api_key = get_hashed_api_key_from_request(request)
    if not hashed_api_key:
        raise HTTPException(status_code=401, detail="Missing API key")
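
Note: returning `None` when `AUTH_TYPE` is disabled (instead of demanding an API key) is what allows the integration tests to hit the api-key-protected routes with bare requests. A hedged sketch of what the test environment relies on, assuming the API server from this PR is running locally with auth disabled:

import requests

# No Authorization header needed: api_key_dep short-circuits to None.
resp = requests.get("http://localhost:8080/danswer-api/ingestion")
resp.raise_for_status()
print(resp.json())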

View File

@@ -0,0 +1 @@
API_SERVER_URL = "http://localhost:8080"

View File

@@ -0,0 +1,164 @@
import logging
import time

import psycopg2
import requests
from alembic import command
from alembic.config import Config

from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_session_context_manager
from danswer.db.engine import SYNC_DB_API
from danswer.db.swap_index import check_index_swap
from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from danswer.document_index.vespa.index import VespaIndex
from danswer.main import setup_postgres
from danswer.main import setup_vespa


def _run_migrations(
    database_url: str, direction: str = "upgrade", revision: str = "head"
) -> None:
    # hide info logs emitted during migration
    logging.getLogger("alembic").setLevel(logging.CRITICAL)

    # Create an Alembic configuration object
    alembic_cfg = Config("alembic.ini")
    alembic_cfg.set_section_option("logger_alembic", "level", "WARN")

    # Set the SQLAlchemy URL in the Alembic configuration
    alembic_cfg.set_main_option("sqlalchemy.url", database_url)

    # Run the migration
    if direction == "upgrade":
        command.upgrade(alembic_cfg, revision)
    elif direction == "downgrade":
        command.downgrade(alembic_cfg, revision)
    else:
        raise ValueError(
            f"Invalid direction: {direction}. Must be 'upgrade' or 'downgrade'."
        )

    logging.getLogger("alembic").setLevel(logging.INFO)


def reset_postgres(database: str = "postgres") -> None:
    """Reset the Postgres database."""
    # NOTE: need to delete all rows to allow migrations to be rolled back
    # as there are a few downgrades that don't properly handle data in tables
    conn = psycopg2.connect(
        dbname=database,
        user=POSTGRES_USER,
        password=POSTGRES_PASSWORD,
        host=POSTGRES_HOST,
        port=POSTGRES_PORT,
    )
    cur = conn.cursor()

    # Disable triggers to prevent foreign key constraints from being checked
    cur.execute("SET session_replication_role = 'replica';")

    # Fetch all table names in the current database
    cur.execute(
        """
        SELECT tablename
        FROM pg_tables
        WHERE schemaname = 'public'
        """
    )
    tables = cur.fetchall()

    for table in tables:
        table_name = table[0]

        # Don't touch migration history
        if table_name == "alembic_version":
            continue

        cur.execute(f'DELETE FROM "{table_name}"')

    # Re-enable triggers
    cur.execute("SET session_replication_role = 'origin';")

    conn.commit()
    cur.close()
    conn.close()

    # downgrade to base + upgrade back to head
    conn_str = build_connection_string(
        db=database,
        user=POSTGRES_USER,
        password=POSTGRES_PASSWORD,
        host=POSTGRES_HOST,
        port=POSTGRES_PORT,
        db_api=SYNC_DB_API,
    )
    _run_migrations(
        conn_str,
        direction="downgrade",
        revision="base",
    )
    _run_migrations(
        conn_str,
        direction="upgrade",
        revision="head",
    )

    # do the same thing as we do on API server startup
    with get_session_context_manager() as db_session:
        setup_postgres(db_session)


def reset_vespa() -> None:
    """Wipe all data from the Vespa index."""
    with get_session_context_manager() as db_session:
        # swap to the correct default model
        check_index_swap(db_session)

        current_model = get_current_db_embedding_model(db_session)
        index_name = current_model.index_name

    setup_vespa(
        document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
        db_embedding_model=current_model,
        secondary_db_embedding_model=None,
    )

    for _ in range(5):
        try:
            continuation = None
            should_continue = True
            while should_continue:
                params = {"selection": "true", "cluster": "danswer_index"}
                if continuation:
                    params = {**params, "continuation": continuation}

                response = requests.delete(
                    DOCUMENT_ID_ENDPOINT.format(index_name=index_name), params=params
                )
                response.raise_for_status()

                response_json = response.json()

                continuation = response_json.get("continuation")
                should_continue = bool(continuation)

            break
        except Exception as e:
            print(f"Error deleting documents: {e}")
            time.sleep(5)


def reset_all() -> None:
    """Reset both Postgres and Vespa."""
    print("Resetting Postgres...")
    reset_postgres()
    print("Resetting Vespa...")
    reset_vespa()
    print("Finished resetting all.")
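
Note: the downgrade-to-base / upgrade-to-head round trip in `reset_postgres` is likely why the earlier migration hunks add the `IF EXISTS` and inspector checks. A minimal sketch of driving this module directly, assuming the dev Postgres/Vespa stack is reachable with the default env vars (the standalone entry point is illustrative, not part of the diff):

from tests.integration.common.reset import reset_all

if __name__ == "__main__":
    # Deletes all Postgres rows (except alembic_version), rolls migrations back to
    # base and forward to head, re-runs setup_postgres, then re-creates the Vespa
    # schema and deletes every Vespa document.
    reset_all()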

View File

@@ -0,0 +1,83 @@
import uuid
from typing import cast

import requests
from pydantic import BaseModel

from danswer.configs.constants import DocumentSource
from tests.integration.common.constants import API_SERVER_URL


class SeedDocumentResponse(BaseModel):
    cc_pair_id: int
    document_ids: list[str]


class TestDocumentClient:
    @staticmethod
    def seed_documents(num_docs: int = 5) -> SeedDocumentResponse:
        unique_id = uuid.uuid4()

        # Create a connector
        connector_name = f"test_connector_{unique_id}"
        connector_data = {
            "name": connector_name,
            "source": DocumentSource.NOT_APPLICABLE,
            "input_type": "load_state",
            "connector_specific_config": {},
            "refresh_freq": 60,
            "disabled": True,
        }
        response = requests.post(
            f"{API_SERVER_URL}/manage/admin/connector",
            json=connector_data,
        )
        response.raise_for_status()
        connector_id = response.json()["id"]

        # Associate the credential with the connector
        cc_pair_metadata = {"name": f"test_cc_pair_{unique_id}", "is_public": True}
        response = requests.put(
            f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/0",
            json=cc_pair_metadata,
        )
        response.raise_for_status()
        cc_pair_id = cast(int, response.json()["data"])

        # Create and ingest some documents
        document_ids: list[str] = []
        for _ in range(num_docs):
            document_id = f"test-doc-{uuid.uuid4()}"
            document_ids.append(document_id)
            document = {
                "document": {
                    "id": document_id,
                    "sections": [
                        {
                            "text": f"This is test document {document_id}",
                            "link": f"{document_id}",
                        }
                    ],
                    "source": DocumentSource.NOT_APPLICABLE,
                    "metadata": {},
                    "semantic_identifier": f"Test Document {document_id}",
                    "from_ingestion_api": True,
                },
                "cc_pair_id": cc_pair_id,
            }
            response = requests.post(
                f"{API_SERVER_URL}/danswer-api/ingestion",
                json=document,
            )
            response.raise_for_status()

        print("Seeding completed successfully.")
        return SeedDocumentResponse(
            cc_pair_id=cc_pair_id,
            document_ids=document_ids,
        )


if __name__ == "__main__":
    seed_documents_resp = TestDocumentClient.seed_documents()
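
A hedged usage sketch: seed a connector/credential pair, then read the documents back through the connector-docs route touched earlier in this diff (routes and constants are from the diff; only a running API server is assumed):

import requests

from tests.integration.common.constants import API_SERVER_URL
from tests.integration.common.seed_documents import TestDocumentClient

seeded = TestDocumentClient.seed_documents(num_docs=3)

# Every seeded document should be visible for the connector/credential pair.
response = requests.get(
    f"{API_SERVER_URL}/danswer-api/connector-docs/{seeded.cc_pair_id}"
)
response.raise_for_status()
assert len(response.json()) >= len(seeded.document_ids)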

View File

@@ -0,0 +1,27 @@
import requests

from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT


class TestVespaClient:
    def __init__(self, index_name: str):
        self.index_name = index_name
        self.vespa_document_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)

    def get_documents_by_id(
        self, document_ids: list[str], wanted_doc_count: int = 1_000
    ) -> dict:
        selection = " or ".join(
            f"{self.index_name}.document_id=='{document_id}'"
            for document_id in document_ids
        )
        params = {
            "selection": selection,
            "wantedDocumentCount": wanted_doc_count,
        }
        response = requests.get(
            self.vespa_document_url,
            params=params,  # type: ignore
        )
        response.raise_for_status()
        return response.json()
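
For reference, the response shape the test below relies on (as implied by its assertions): the visit endpoint returns a `documents` list whose entries carry a `fields` dict. A small sketch, with the index name chosen purely for illustration:

client = TestVespaClient(index_name="danswer_chunk")  # index name is illustrative
result = client.get_documents_by_id(["test-doc-123"])
for doc in result["documents"]:
    # document_id and document_sets are the fields the syncing test inspects
    print(doc["fields"]["document_id"], doc["fields"].get("document_sets", {}).keys())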

View File

@@ -0,0 +1,26 @@
from collections.abc import Generator

import pytest
from sqlalchemy.orm import Session

from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from tests.integration.common.reset import reset_all
from tests.integration.common.vespa import TestVespaClient


@pytest.fixture
def db_session() -> Generator[Session, None, None]:
    with get_session_context_manager() as session:
        yield session


@pytest.fixture
def vespa_client(db_session: Session) -> TestVespaClient:
    current_model = get_current_db_embedding_model(db_session)
    return TestVespaClient(index_name=current_model.index_name)


@pytest.fixture
def reset() -> None:
    reset_all()
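
Minimal sketch of how the fixtures compose in a test (the real test follows in the next file): requesting `reset` wipes Postgres and Vespa before the test body runs, and `vespa_client` points at the currently active index:

from tests.integration.common.vespa import TestVespaClient


def test_example(reset: None, vespa_client: TestVespaClient) -> None:
    # `reset` has already wiped Postgres + Vespa by the time this body runs.
    assert vespa_client.index_name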

View File

@@ -0,0 +1,79 @@
import time

from danswer.server.features.document_set.models import DocumentSetCreationRequest
from tests.integration.common.seed_documents import TestDocumentClient
from tests.integration.common.vespa import TestVespaClient
from tests.integration.document_set.utils import create_document_set
from tests.integration.document_set.utils import fetch_document_sets


def test_multiple_document_sets_syncing_same_connector(
    reset: None, vespa_client: TestVespaClient
) -> None:
    # Seed documents
    seed_result = TestDocumentClient.seed_documents(num_docs=5)
    cc_pair_id = seed_result.cc_pair_id

    # Create first document set
    doc_set_1_id = create_document_set(
        DocumentSetCreationRequest(
            name="Test Document Set 1",
            description="First test document set",
            cc_pair_ids=[cc_pair_id],
            is_public=True,
            users=[],
            groups=[],
        )
    )

    doc_set_2_id = create_document_set(
        DocumentSetCreationRequest(
            name="Test Document Set 2",
            description="Second test document set",
            cc_pair_ids=[cc_pair_id],
            is_public=True,
            users=[],
            groups=[],
        )
    )

    # wait for syncing to be complete
    max_delay = 45
    start = time.time()
    while True:
        doc_sets = fetch_document_sets()
        doc_set_1 = next(
            (doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None
        )
        doc_set_2 = next(
            (doc_set for doc_set in doc_sets if doc_set.id == doc_set_2_id), None
        )

        if not doc_set_1 or not doc_set_2:
            raise RuntimeError("Document set not found")

        if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date:
            assert [ccp.id for ccp in doc_set_1.cc_pair_descriptors] == [
                ccp.id for ccp in doc_set_2.cc_pair_descriptors
            ]
            break

        if time.time() - start > max_delay:
            raise TimeoutError("Document sets were not synced within the max delay")

        time.sleep(2)

    # get names so we can compare to what is in vespa
    doc_sets = fetch_document_sets()
    doc_set_names = {doc_set.name for doc_set in doc_sets}

    # make sure documents are as expected
    result = vespa_client.get_documents_by_id(seed_result.document_ids)
    documents = result["documents"]
    assert len(documents) == len(seed_result.document_ids)
    assert all(
        doc["fields"]["document_id"] in seed_result.document_ids for doc in documents
    )
    assert all(
        set(doc["fields"]["document_sets"].keys()) == doc_set_names for doc in documents
    )
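
The poll-until-synced loop above is a pattern most of these integration tests will need; a small generic helper, sketched here as a possible refactor (the names are hypothetical, not part of the PR):

import time
from collections.abc import Callable


def wait_until(
    check: Callable[[], bool], timeout: float = 45.0, interval: float = 2.0
) -> None:
    """Poll `check` until it returns True or `timeout` seconds elapse."""
    start = time.monotonic()
    while not check():
        if time.monotonic() - start > timeout:
            raise TimeoutError(f"Condition not met within {timeout} seconds")
        time.sleep(interval)


# e.g. wait_until(lambda: all(ds.is_up_to_date for ds in fetch_document_sets()))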

View File

@@ -0,0 +1,26 @@
from typing import cast

import requests

from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from tests.integration.common.constants import API_SERVER_URL


def create_document_set(doc_set_creation_request: DocumentSetCreationRequest) -> int:
    response = requests.post(
        f"{API_SERVER_URL}/manage/admin/document-set",
        json=doc_set_creation_request.dict(),
    )
    response.raise_for_status()
    return cast(int, response.json())


def fetch_document_sets() -> list[DocumentSet]:
    response = requests.get(f"{API_SERVER_URL}/manage/admin/document-set")
    response.raise_for_status()

    document_sets = [
        DocumentSet.parse_obj(doc_set_data) for doc_set_data in response.json()
    ]
    return document_sets
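
A sketch of driving the new suite directly from Python once the API server and its dependencies are up; the exact invocation and working directory (the backend folder, as the `tests.integration.*` imports suggest) are assumptions, not taken from the diff:

import sys

import pytest

if __name__ == "__main__":
    # Run the document set syncing integration tests verbosely.
    sys.exit(pytest.main(["-v", "tests/integration/document_set"]))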