Add integration tests for document set syncing (#1904)

@@ -79,7 +79,7 @@ def downgrade() -> None:
     )
     op.create_foreign_key(
         "document_retrieval_feedback__chat_message_fk",
-        "document_retrieval",
+        "document_retrieval_feedback",
         "chat_message",
         ["chat_message_id"],
         ["id"],

@@ -160,12 +160,28 @@ def downgrade() -> None:
             nullable=False,
         ),
     )
-    op.drop_constraint(
-        "fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
-    )
-    op.drop_constraint(
-        "fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
-    )
+    # Check if the constraint exists before dropping
+    conn = op.get_bind()
+    inspector = sa.inspect(conn)
+    constraints = inspector.get_foreign_keys("index_attempt")
+
+    if any(
+        constraint["name"] == "fk_index_attempt_credential_id"
+        for constraint in constraints
+    ):
+        op.drop_constraint(
+            "fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
+        )
+
+    if any(
+        constraint["name"] == "fk_index_attempt_connector_id"
+        for constraint in constraints
+    ):
+        op.drop_constraint(
+            "fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
+        )
+
     op.drop_column("index_attempt", "credential_id")
     op.drop_column("index_attempt", "connector_id")
     op.drop_table("connector_credential_pair")
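
Note: the guarded drop above relies on SQLAlchemy's runtime inspection API, so the downgrade no longer fails when a foreign key was never created. A minimal standalone sketch of the same pattern (the helper name and connection URL are illustrative, not part of this commit):

import sqlalchemy as sa


def drop_fk_if_exists(conn: sa.engine.Connection, table: str, fk_name: str) -> None:
    # get_foreign_keys() returns one dict per foreign key, each with a "name" key
    constraints = sa.inspect(conn).get_foreign_keys(table)
    if any(constraint["name"] == fk_name for constraint in constraints):
        conn.execute(sa.text(f'ALTER TABLE "{table}" DROP CONSTRAINT "{fk_name}"'))


engine = sa.create_engine("postgresql://localhost/danswer")  # illustrative URL
with engine.begin() as conn:
    drop_fk_if_exists(conn, "index_attempt", "fk_index_attempt_credential_id")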

@@ -28,5 +28,9 @@ def upgrade() -> None:
 
 
 def downgrade() -> None:
-    # This wasn't really required by the code either, no good reason to make it unique again
-    pass
+    op.create_unique_constraint(
+        "connector_credential_pair__name__key", "connector_credential_pair", ["name"]
+    )
+    op.alter_column(
+        "connector_credential_pair", "name", existing_type=sa.String(), nullable=True
+    )

@@ -19,6 +19,9 @@ depends_on: None = None
 def upgrade() -> None:
     op.drop_table("deletion_attempt")
 
+    # Remove the DeletionStatus enum
+    op.execute("DROP TYPE IF EXISTS deletionstatus;")
+
 
 def downgrade() -> None:
     op.create_table(

@@ -136,4 +136,4 @@ def downgrade() -> None:
     )
     op.drop_column("index_attempt", "embedding_model_id")
     op.drop_table("embedding_model")
-    op.execute("DROP TYPE indexmodelstatus;")
+    op.execute("DROP TYPE IF EXISTS indexmodelstatus;")

@@ -311,7 +311,7 @@ def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool
 
 
 _NUM_LOCK_ATTEMPTS = 10
-_LOCK_RETRY_DELAY = 30
+_LOCK_RETRY_DELAY = 10
 
 
 @contextlib.contextmanager

@@ -47,10 +47,12 @@ from danswer.db.engine import init_sqlalchemy_engine
 from danswer.db.engine import warm_up_connections
 from danswer.db.index_attempt import cancel_indexing_attempts_past_model
 from danswer.db.index_attempt import expire_index_attempts
+from danswer.db.models import EmbeddingModel
 from danswer.db.persona import delete_old_default_personas
 from danswer.db.standard_answer import create_initial_default_standard_answer_category
 from danswer.db.swap_index import check_index_swap
 from danswer.document_index.factory import get_default_document_index
+from danswer.document_index.interfaces import DocumentIndex
 from danswer.llm.llm_initialization import load_llm_providers
 from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
 from danswer.search.retrieval.search_runner import download_nltk_data

@@ -158,6 +160,49 @@ def include_router_with_global_prefix_prepended(
     application.include_router(router, **final_kwargs)
 
 
+def setup_postgres(db_session: Session) -> None:
+    logger.info("Verifying default connector/credential exist.")
+    create_initial_public_credential(db_session)
+    create_initial_default_connector(db_session)
+    associate_default_cc_pair(db_session)
+
+    logger.info("Verifying default standard answer category exists.")
+    create_initial_default_standard_answer_category(db_session)
+
+    logger.info("Loading LLM providers from env variables")
+    load_llm_providers(db_session)
+
+    logger.info("Loading default Prompts and Personas")
+    delete_old_default_personas(db_session)
+    load_chat_yamls()
+
+    logger.info("Loading built-in tools")
+    load_builtin_tools(db_session)
+    refresh_built_in_tools_cache(db_session)
+    auto_add_search_tool_to_personas(db_session)
+
+
+def setup_vespa(
+    document_index: DocumentIndex,
+    db_embedding_model: EmbeddingModel,
+    secondary_db_embedding_model: EmbeddingModel | None,
+) -> None:
+    # Vespa startup is a bit slow, so give it a few seconds
+    wait_time = 5
+    for _ in range(5):
+        try:
+            document_index.ensure_indices_exist(
+                index_embedding_dim=db_embedding_model.model_dim,
+                secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
+                if secondary_db_embedding_model
+                else None,
+            )
+            break
+        except Exception:
+            logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
+            time.sleep(wait_time)
+
+
 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator:
     init_sqlalchemy_engine(POSTGRES_WEB_APP_NAME)

@@ -213,26 +258,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
         logger.info("Verifying query preprocessing (NLTK) data is downloaded")
         download_nltk_data()
 
-        logger.info("Verifying default connector/credential exist.")
-        create_initial_public_credential(db_session)
-        create_initial_default_connector(db_session)
-        associate_default_cc_pair(db_session)
-
-        logger.info("Verifying default standard answer category exists.")
-        create_initial_default_standard_answer_category(db_session)
-
-        logger.info("Loading LLM providers from env variables")
-        load_llm_providers(db_session)
-
-        logger.info("Loading default Prompts and Personas")
-        delete_old_default_personas(db_session)
-        load_chat_yamls()
-
-        logger.info("Loading built-in tools")
-        load_builtin_tools(db_session)
-        refresh_built_in_tools_cache(db_session)
-        auto_add_search_tool_to_personas(db_session)
-
+        # setup Postgres with default credential, llm providers, etc.
+        setup_postgres(db_session)
+
+        # ensure Vespa is setup correctly
         logger.info("Verifying Document Index(s) is/are available.")
         document_index = get_default_document_index(
             primary_index_name=db_embedding_model.index_name,

@@ -240,20 +269,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
             if secondary_db_embedding_model
             else None,
         )
-        # Vespa startup is a bit slow, so give it a few seconds
-        wait_time = 5
-        for attempt in range(5):
-            try:
-                document_index.ensure_indices_exist(
-                    index_embedding_dim=db_embedding_model.model_dim,
-                    secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
-                    if secondary_db_embedding_model
-                    else None,
-                )
-                break
-            except Exception:
-                logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
-                time.sleep(wait_time)
+        setup_vespa(document_index, db_embedding_model, secondary_db_embedding_model)
 
         logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
         if db_embedding_model.cloud_provider_id is None:
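
Note: setup_postgres() and setup_vespa() carry exactly the startup logic that was previously inlined in lifespan(); pulling them out lets the integration-test harness below (tests/integration/common/reset.py) reuse the same behavior via `from danswer.main import setup_postgres` and `from danswer.main import setup_vespa`.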

@@ -12,6 +12,7 @@ from danswer.db.document import get_ingestion_documents
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.embedding_model import get_secondary_db_embedding_model
 from danswer.db.engine import get_session
+from danswer.db.models import User
 from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.embedder import DefaultIndexingEmbedder

@@ -31,7 +32,7 @@ router = APIRouter(prefix="/danswer-api")
 @router.get("/connector-docs/{cc_pair_id}")
 def get_docs_by_connector_credential_pair(
     cc_pair_id: int,
-    _: str = Depends(api_key_dep),
+    _: User | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> list[DocMinimalInfo]:
     db_docs = get_documents_by_cc_pair(cc_pair_id=cc_pair_id, db_session=db_session)

@@ -47,7 +48,7 @@ def get_docs_by_connector_credential_pair(
 
 @router.get("/ingestion")
 def get_ingestion_docs(
-    _: str = Depends(api_key_dep),
+    _: User | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> list[DocMinimalInfo]:
     db_docs = get_ingestion_documents(db_session)

@@ -64,7 +65,7 @@ def get_ingestion_docs(
 @router.post("/ingestion")
 def upsert_ingestion_doc(
     doc_info: IngestionDocument,
-    _: str = Depends(api_key_dep),
+    _: User | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> IngestionResult:
     doc_info.document.from_ingestion_api = True

@@ -7,6 +7,7 @@ from pydantic import BaseModel
 from sqlalchemy.orm import Session
 
 from danswer.db.engine import get_session
+from danswer.db.models import User
 from danswer.llm.factory import get_default_llms
 from danswer.search.models import SearchRequest
 from danswer.search.pipeline import SearchPipeline

@@ -64,7 +65,7 @@ class GptSearchResponse(BaseModel):
 @router.post("/gpt-document-search")
 def gpt_search(
     search_request: GptSearchRequest,
-    _: str | None = Depends(api_key_dep),
+    _: User | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> GptSearchResponse:
     llm, fast_llm = get_default_llms()
|
@ -44,7 +44,12 @@ async def optional_user_(
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def api_key_dep(request: Request, db_session: Session = Depends(get_session)) -> User:
|
def api_key_dep(
|
||||||
|
request: Request, db_session: Session = Depends(get_session)
|
||||||
|
) -> User | None:
|
||||||
|
if AUTH_TYPE == AuthType.DISABLED:
|
||||||
|
return None
|
||||||
|
|
||||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||||
if not hashed_api_key:
|
if not hashed_api_key:
|
||||||
raise HTTPException(status_code=401, detail="Missing API key")
|
raise HTTPException(status_code=401, detail="Missing API key")
|
||||||
|
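
Note: api_key_dep now returns User | None and short-circuits when auth is disabled, which is what lets the routes above accept `_: User | None`. A hedged sketch of the same behavior as a self-contained FastAPI dependency (the names and env handling here are illustrative, not danswer's actual implementation):

import os

from fastapi import Depends, FastAPI, HTTPException, Request

app = FastAPI()


def api_key_dep_sketch(request: Request) -> str | None:
    # mirror the new behavior: skip API-key checks entirely when auth is disabled
    if os.environ.get("AUTH_TYPE", "disabled") == "disabled":
        return None
    api_key = request.headers.get("Authorization")
    if not api_key:
        raise HTTPException(status_code=401, detail="Missing API key")
    return api_key


@app.get("/ingestion-sketch")
def ingestion_sketch(_: str | None = Depends(api_key_dep_sketch)) -> dict:
    return {"ok": True}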

backend/tests/integration/common/constants.py (new file, 1 line)
@@ -0,0 +1 @@
+API_SERVER_URL = "http://localhost:8080"

backend/tests/integration/common/reset.py (new file, 164 lines)
@@ -0,0 +1,164 @@
+import logging
+import time
+
+import psycopg2
+import requests
+
+from alembic import command
+from alembic.config import Config
+from danswer.configs.app_configs import POSTGRES_HOST
+from danswer.configs.app_configs import POSTGRES_PASSWORD
+from danswer.configs.app_configs import POSTGRES_PORT
+from danswer.configs.app_configs import POSTGRES_USER
+from danswer.db.embedding_model import get_current_db_embedding_model
+from danswer.db.engine import build_connection_string
+from danswer.db.engine import get_session_context_manager
+from danswer.db.engine import SYNC_DB_API
+from danswer.db.swap_index import check_index_swap
+from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
+from danswer.document_index.vespa.index import VespaIndex
+from danswer.main import setup_postgres
+from danswer.main import setup_vespa
+
+
+def _run_migrations(
+    database_url: str, direction: str = "upgrade", revision: str = "head"
+) -> None:
+    # hide info logs emitted during migration
+    logging.getLogger("alembic").setLevel(logging.CRITICAL)
+
+    # Create an Alembic configuration object
+    alembic_cfg = Config("alembic.ini")
+    alembic_cfg.set_section_option("logger_alembic", "level", "WARN")
+
+    # Set the SQLAlchemy URL in the Alembic configuration
+    alembic_cfg.set_main_option("sqlalchemy.url", database_url)
+
+    # Run the migration
+    if direction == "upgrade":
+        command.upgrade(alembic_cfg, revision)
+    elif direction == "downgrade":
+        command.downgrade(alembic_cfg, revision)
+    else:
+        raise ValueError(
+            f"Invalid direction: {direction}. Must be 'upgrade' or 'downgrade'."
+        )
+
+    logging.getLogger("alembic").setLevel(logging.INFO)
+
+
+def reset_postgres(database: str = "postgres") -> None:
+    """Reset the Postgres database."""
+
+    # NOTE: need to delete all rows to allow migrations to be rolled back
+    # as there are a few downgrades that don't properly handle data in tables
+    conn = psycopg2.connect(
+        dbname=database,
+        user=POSTGRES_USER,
+        password=POSTGRES_PASSWORD,
+        host=POSTGRES_HOST,
+        port=POSTGRES_PORT,
+    )
+    cur = conn.cursor()
+
+    # Disable triggers to prevent foreign key constraints from being checked
+    cur.execute("SET session_replication_role = 'replica';")
+
+    # Fetch all table names in the current database
+    cur.execute(
+        """
+        SELECT tablename
+        FROM pg_tables
+        WHERE schemaname = 'public'
+        """
+    )
+
+    tables = cur.fetchall()
+
+    for table in tables:
+        table_name = table[0]
+
+        # Don't touch migration history
+        if table_name == "alembic_version":
+            continue
+
+        cur.execute(f'DELETE FROM "{table_name}"')
+
+    # Re-enable triggers
+    cur.execute("SET session_replication_role = 'origin';")
+
+    conn.commit()
+    cur.close()
+    conn.close()
+
+    # downgrade to base + upgrade back to head
+    conn_str = build_connection_string(
+        db=database,
+        user=POSTGRES_USER,
+        password=POSTGRES_PASSWORD,
+        host=POSTGRES_HOST,
+        port=POSTGRES_PORT,
+        db_api=SYNC_DB_API,
+    )
+    _run_migrations(
+        conn_str,
+        direction="downgrade",
+        revision="base",
+    )
+    _run_migrations(
+        conn_str,
+        direction="upgrade",
+        revision="head",
+    )
+
+    # do the same thing as we do on API server startup
+    with get_session_context_manager() as db_session:
+        setup_postgres(db_session)
+
+
+def reset_vespa() -> None:
+    """Wipe all data from the Vespa index."""
+    with get_session_context_manager() as db_session:
+        # swap to the correct default model
+        check_index_swap(db_session)
+
+        current_model = get_current_db_embedding_model(db_session)
+        index_name = current_model.index_name
+
+    setup_vespa(
+        document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
+        db_embedding_model=current_model,
+        secondary_db_embedding_model=None,
+    )
+
+    for _ in range(5):
+        try:
+            continuation = None
+            should_continue = True
+            while should_continue:
+                params = {"selection": "true", "cluster": "danswer_index"}
+                if continuation:
+                    params = {**params, "continuation": continuation}
+                response = requests.delete(
+                    DOCUMENT_ID_ENDPOINT.format(index_name=index_name), params=params
+                )
+                response.raise_for_status()
+
+                response_json = response.json()
+
+                continuation = response_json.get("continuation")
+                should_continue = bool(continuation)
+
+            break
+        except Exception as e:
+            print(f"Error deleting documents: {e}")
+            time.sleep(5)
+
+
+def reset_all() -> None:
+    """Reset both Postgres and Vespa."""
+    print("Resetting Postgres...")
+    reset_postgres()
+    print("Resetting Vespa...")
+    reset_vespa()
+    print("Finished resetting all.")
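
Note: SET session_replication_role = 'replica' suppresses trigger firing, including the system triggers that enforce foreign keys, which is why the rows can be deleted in arbitrary table order before the downgrade/upgrade cycle. A minimal usage sketch, assuming the local dev stack (API server, Postgres, Vespa) is already running:

from tests.integration.common.reset import reset_all

if __name__ == "__main__":
    # wipes Postgres (downgrade to base, upgrade back to head) and the Vespa
    # index, then re-runs the same setup the API server performs on startup
    reset_all()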

backend/tests/integration/common/seed_documents.py (new file, 83 lines)
@@ -0,0 +1,83 @@
+import uuid
+from typing import cast
+
+import requests
+from pydantic import BaseModel
+
+from danswer.configs.constants import DocumentSource
+from tests.integration.common.constants import API_SERVER_URL
+
+
+class SeedDocumentResponse(BaseModel):
+    cc_pair_id: int
+    document_ids: list[str]
+
+
+class TestDocumentClient:
+    @staticmethod
+    def seed_documents(num_docs: int = 5) -> SeedDocumentResponse:
+        unique_id = uuid.uuid4()
+
+        # Create a connector
+        connector_name = f"test_connector_{unique_id}"
+        connector_data = {
+            "name": connector_name,
+            "source": DocumentSource.NOT_APPLICABLE,
+            "input_type": "load_state",
+            "connector_specific_config": {},
+            "refresh_freq": 60,
+            "disabled": True,
+        }
+        response = requests.post(
+            f"{API_SERVER_URL}/manage/admin/connector",
+            json=connector_data,
+        )
+        response.raise_for_status()
+        connector_id = response.json()["id"]
+
+        # Associate the credential with the connector
+        cc_pair_metadata = {"name": f"test_cc_pair_{unique_id}", "is_public": True}
+        response = requests.put(
+            f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/0",
+            json=cc_pair_metadata,
+        )
+        response.raise_for_status()
+        cc_pair_id = cast(int, response.json()["data"])
+
+        # Create and ingest some documents
+        document_ids: list[str] = []
+        for _ in range(num_docs):
+            document_id = f"test-doc-{uuid.uuid4()}"
+            document_ids.append(document_id)
+
+            document = {
+                "document": {
+                    "id": document_id,
+                    "sections": [
+                        {
+                            "text": f"This is test document {document_id}",
+                            "link": f"{document_id}",
+                        }
+                    ],
+                    "source": DocumentSource.NOT_APPLICABLE,
+                    "metadata": {},
+                    "semantic_identifier": f"Test Document {document_id}",
+                    "from_ingestion_api": True,
+                },
+                "cc_pair_id": cc_pair_id,
+            }
+            response = requests.post(
+                f"{API_SERVER_URL}/danswer-api/ingestion",
+                json=document,
+            )
+            response.raise_for_status()
+
+        print("Seeding completed successfully.")
+        return SeedDocumentResponse(
+            cc_pair_id=cc_pair_id,
+            document_ids=document_ids,
+        )
+
+
+if __name__ == "__main__":
+    seed_documents_resp = TestDocumentClient.seed_documents()

backend/tests/integration/common/vespa.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+import requests
+
+from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
+
+
+class TestVespaClient:
+    def __init__(self, index_name: str):
+        self.index_name = index_name
+        self.vespa_document_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
+
+    def get_documents_by_id(
+        self, document_ids: list[str], wanted_doc_count: int = 1_000
+    ) -> dict:
+        selection = " or ".join(
+            f"{self.index_name}.document_id=='{document_id}'"
+            for document_id in document_ids
+        )
+        params = {
+            "selection": selection,
+            "wantedDocumentCount": wanted_doc_count,
+        }
+        response = requests.get(
+            self.vespa_document_url,
+            params=params,  # type: ignore
+        )
+        response.raise_for_status()
+        return response.json()
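
Note: get_documents_by_id performs a visit against Vespa's /document/v1 API; `selection` is a Vespa document-selection expression and `wantedDocumentCount` caps how many documents one response returns. A usage sketch (the index name is illustrative; the tests derive the real one from the current embedding model, see conftest.py below):

from tests.integration.common.vespa import TestVespaClient

client = TestVespaClient(index_name="danswer_chunk")  # illustrative index name
result = client.get_documents_by_id(["test-doc-123"])
print(len(result["documents"]))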

backend/tests/integration/conftest.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+from collections.abc import Generator
+
+import pytest
+from sqlalchemy.orm import Session
+
+from danswer.db.embedding_model import get_current_db_embedding_model
+from danswer.db.engine import get_session_context_manager
+from tests.integration.common.reset import reset_all
+from tests.integration.common.vespa import TestVespaClient
+
+
+@pytest.fixture
+def db_session() -> Generator[Session, None, None]:
+    with get_session_context_manager() as session:
+        yield session
+
+
+@pytest.fixture
+def vespa_client(db_session: Session) -> TestVespaClient:
+    current_model = get_current_db_embedding_model(db_session)
+    return TestVespaClient(index_name=current_model.index_name)
+
+
+@pytest.fixture
+def reset() -> None:
+    reset_all()
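
Note: pytest instantiates same-scoped fixtures in the order they appear in the test signature, so listing `reset` before `vespa_client` wipes Postgres and Vespa before the client reads the current embedding model. A hypothetical sketch of that ordering:

from tests.integration.common.vespa import TestVespaClient


def test_sketch(reset: None, vespa_client: TestVespaClient) -> None:  # hypothetical test
    # by the time the body runs, reset_all() has already wiped both stores
    assert vespa_client.index_name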

backend/tests/integration/document_set/test_syncing.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+import time
+
+from danswer.server.features.document_set.models import DocumentSetCreationRequest
+from tests.integration.common.seed_documents import TestDocumentClient
+from tests.integration.common.vespa import TestVespaClient
+from tests.integration.document_set.utils import create_document_set
+from tests.integration.document_set.utils import fetch_document_sets
+
+
+def test_multiple_document_sets_syncing_same_connnector(
+    reset: None, vespa_client: TestVespaClient
+) -> None:
+    # Seed documents
+    seed_result = TestDocumentClient.seed_documents(num_docs=5)
+    cc_pair_id = seed_result.cc_pair_id
+
+    # Create first document set
+    doc_set_1_id = create_document_set(
+        DocumentSetCreationRequest(
+            name="Test Document Set 1",
+            description="First test document set",
+            cc_pair_ids=[cc_pair_id],
+            is_public=True,
+            users=[],
+            groups=[],
+        )
+    )
+
+    doc_set_2_id = create_document_set(
+        DocumentSetCreationRequest(
+            name="Test Document Set 2",
+            description="Second test document set",
+            cc_pair_ids=[cc_pair_id],
+            is_public=True,
+            users=[],
+            groups=[],
+        )
+    )
+
+    # wait for syncing to be complete
+    max_delay = 45
+    start = time.time()
+    while True:
+        doc_sets = fetch_document_sets()
+        doc_set_1 = next(
+            (doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None
+        )
+        doc_set_2 = next(
+            (doc_set for doc_set in doc_sets if doc_set.id == doc_set_2_id), None
+        )
+
+        if not doc_set_1 or not doc_set_2:
+            raise RuntimeError("Document set not found")
+
+        if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date:
+            assert [ccp.id for ccp in doc_set_1.cc_pair_descriptors] == [
+                ccp.id for ccp in doc_set_2.cc_pair_descriptors
+            ]
+            break
+
+        if time.time() - start > max_delay:
+            raise TimeoutError("Document sets were not synced within the max delay")
+
+        time.sleep(2)
+
+    # get names so we can compare to what is in vespa
+    doc_sets = fetch_document_sets()
+    doc_set_names = {doc_set.name for doc_set in doc_sets}
+
+    # make sure documents are as expected
+    result = vespa_client.get_documents_by_id(seed_result.document_ids)
+    documents = result["documents"]
+    assert len(documents) == len(seed_result.document_ids)
+    assert all(
+        doc["fields"]["document_id"] in seed_result.document_ids for doc in documents
+    )
+    assert all(
+        set(doc["fields"]["document_sets"].keys()) == doc_set_names for doc in documents
+    )
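
Note: the sync wait above is a poll-until-condition loop with a hard timeout; the same pattern could be factored into a generic helper, sketched here (hypothetical, not part of this commit):

import time
from collections.abc import Callable


def wait_until(
    condition: Callable[[], bool], timeout: float = 45, interval: float = 2
) -> None:
    # poll `condition` until it returns True or `timeout` seconds elapse
    start = time.time()
    while not condition():
        if time.time() - start > timeout:
            raise TimeoutError("Condition was not met within the timeout")
        time.sleep(interval)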

backend/tests/integration/document_set/utils.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+from typing import cast
+
+import requests
+
+from danswer.server.features.document_set.models import DocumentSet
+from danswer.server.features.document_set.models import DocumentSetCreationRequest
+from tests.integration.common.constants import API_SERVER_URL
+
+
+def create_document_set(doc_set_creation_request: DocumentSetCreationRequest) -> int:
+    response = requests.post(
+        f"{API_SERVER_URL}/manage/admin/document-set",
+        json=doc_set_creation_request.dict(),
+    )
+    response.raise_for_status()
+    return cast(int, response.json())
+
+
+def fetch_document_sets() -> list[DocumentSet]:
+    response = requests.get(f"{API_SERVER_URL}/manage/admin/document-set")
+    response.raise_for_status()
+
+    document_sets = [
+        DocumentSet.parse_obj(doc_set_data) for doc_set_data in response.json()
+    ]
+    return document_sets