mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-31 10:10:21 +02:00
Tenant integration tests (#2913)
* check for index swap * initial bones * kk * k * k: * nit * nit * rebase + update * nit * minior update * k * minor integration test fixes * nit * ensure we build test docker image * remove one space * k * ensure we wipe volumes * remove log * typo * nit * k * k
This commit is contained in:
parent
bd63119684
commit
9b147ae437
65
.github/workflows/pr-Integration-tests.yml
vendored
65
.github/workflows/pr-Integration-tests.yml
vendored
@ -72,7 +72,7 @@ jobs:
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
@ -85,7 +85,58 @@ jobs:
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Start Docker containers
|
||||
# Start containers for multi-tenant tests
|
||||
- name: Start Docker containers for multi-tenant tests
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
|
||||
- name: Run Multi-Tenant Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
danswer/danswer-integration:test \
|
||||
/app/tests/integration/multitenant_tests
|
||||
continue-on-error: true
|
||||
id: run_multitenant_tests
|
||||
|
||||
- name: Check multi-tenant test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Stop multi-tenant Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
@ -130,7 +181,7 @@ jobs:
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run integration tests
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
@ -145,7 +196,8 @@ jobs:
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
danswer/danswer-integration:test
|
||||
danswer/danswer-integration:test \
|
||||
/app/tests/integration/tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
@ -158,6 +210,11 @@ jobs:
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
|
@ -58,6 +58,7 @@ from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import DISABLE_VERIFICATION
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
@ -133,7 +134,9 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
def user_needs_to_be_verified() -> bool:
|
||||
# all other auth types besides basic should require users to be
|
||||
# verified
|
||||
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
return not DISABLE_VERIFICATION and (
|
||||
AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
)
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
|
@ -8,18 +8,11 @@ from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -67,27 +60,6 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
# TODO: why is this necessary for the indexer to do?
|
||||
engine = SqlEngine.get_engine()
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice("Running a first inference to warm up embedding model")
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
logger.notice("First inference complete.")
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
@ -1,494 +0,0 @@
|
||||
# TODO(rkuo): delete after background indexing via celery is fully vetted
|
||||
# import logging
|
||||
# import time
|
||||
# from datetime import datetime
|
||||
# import dask
|
||||
# from dask.distributed import Client
|
||||
# from dask.distributed import Future
|
||||
# from distributed import LocalCluster
|
||||
# from sqlalchemy import text
|
||||
# from sqlalchemy.exc import ProgrammingError
|
||||
# from sqlalchemy.orm import Session
|
||||
# from danswer.background.indexing.dask_utils import ResourceLogger
|
||||
# from danswer.background.indexing.job_client import SimpleJob
|
||||
# from danswer.background.indexing.job_client import SimpleJobClient
|
||||
# from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
# from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
|
||||
# from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
# from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
# from danswer.configs.app_configs import MULTI_TENANT
|
||||
# from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
# from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
|
||||
# from danswer.configs.constants import DocumentSource
|
||||
# from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
|
||||
# from danswer.configs.constants import TENANT_ID_PREFIX
|
||||
# from danswer.db.connector import fetch_connectors
|
||||
# from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
# from danswer.db.engine import get_db_current_time
|
||||
# from danswer.db.engine import get_session_with_tenant
|
||||
# from danswer.db.engine import get_sqlalchemy_engine
|
||||
# from danswer.db.engine import SqlEngine
|
||||
# from danswer.db.index_attempt import create_index_attempt
|
||||
# from danswer.db.index_attempt import get_index_attempt
|
||||
# from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
# from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
# from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
# from danswer.db.index_attempt import mark_attempt_failed
|
||||
# from danswer.db.models import ConnectorCredentialPair
|
||||
# from danswer.db.models import IndexAttempt
|
||||
# from danswer.db.models import IndexingStatus
|
||||
# from danswer.db.models import IndexModelStatus
|
||||
# from danswer.db.models import SearchSettings
|
||||
# from danswer.db.search_settings import get_current_search_settings
|
||||
# from danswer.db.search_settings import get_secondary_search_settings
|
||||
# from danswer.db.swap_index import check_index_swap
|
||||
# from danswer.document_index.vespa.index import VespaIndex
|
||||
# from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
# from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
# from danswer.utils.logger import setup_logger
|
||||
# from danswer.utils.variable_functionality import global_version
|
||||
# from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
# from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
# from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
# from shared_configs.configs import LOG_LEVEL
|
||||
# logger = setup_logger()
|
||||
# # If the indexing dies, it's most likely due to resource constraints,
|
||||
# # restarting just delays the eventual failure, not useful to the user
|
||||
# dask.config.set({"distributed.scheduler.allowed-failures": 0})
|
||||
# _UNEXPECTED_STATE_FAILURE_REASON = (
|
||||
# "Stopped mid run, likely due to the background process being killed"
|
||||
# )
|
||||
# def _should_create_new_indexing(
|
||||
# cc_pair: ConnectorCredentialPair,
|
||||
# last_index: IndexAttempt | None,
|
||||
# search_settings_instance: SearchSettings,
|
||||
# secondary_index_building: bool,
|
||||
# db_session: Session,
|
||||
# ) -> bool:
|
||||
# connector = cc_pair.connector
|
||||
# # don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
# if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
# return False
|
||||
# # User can still manually create single indexing attempts via the UI for the
|
||||
# # currently in use index
|
||||
# if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
# if (
|
||||
# search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
# and secondary_index_building
|
||||
# ):
|
||||
# return False
|
||||
# # When switching over models, always index at least once
|
||||
# if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
# if last_index:
|
||||
# # No new index if the last index attempt succeeded
|
||||
# # Once is enough. The model will never be able to swap otherwise.
|
||||
# if last_index.status == IndexingStatus.SUCCESS:
|
||||
# return False
|
||||
# # No new index if the last index attempt is waiting to start
|
||||
# if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
# return False
|
||||
# # No new index if the last index attempt is running
|
||||
# if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
# return False
|
||||
# else:
|
||||
# if (
|
||||
# connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
# ): # Ingestion API
|
||||
# return False
|
||||
# return True
|
||||
# # If the connector is paused or is the ingestion API, don't index
|
||||
# # NOTE: during an embedding model switch over, the following logic
|
||||
# # is bypassed by the above check for a future model
|
||||
# if (
|
||||
# not cc_pair.status.is_active()
|
||||
# or connector.id == 0
|
||||
# or connector.source == DocumentSource.INGESTION_API
|
||||
# ):
|
||||
# return False
|
||||
# if not last_index:
|
||||
# return True
|
||||
# if connector.refresh_freq is None:
|
||||
# return False
|
||||
# # Only one scheduled/ongoing job per connector at a time
|
||||
# # this prevents cases where
|
||||
# # (1) the "latest" index_attempt is scheduled so we show
|
||||
# # that in the UI despite another index_attempt being in-progress
|
||||
# # (2) multiple scheduled index_attempts at a time
|
||||
# if (
|
||||
# last_index.status == IndexingStatus.NOT_STARTED
|
||||
# or last_index.status == IndexingStatus.IN_PROGRESS
|
||||
# ):
|
||||
# return False
|
||||
# current_db_time = get_db_current_time(db_session)
|
||||
# time_since_index = current_db_time - last_index.time_updated
|
||||
# return time_since_index.total_seconds() >= connector.refresh_freq
|
||||
# def _mark_run_failed(
|
||||
# db_session: Session, index_attempt: IndexAttempt, failure_reason: str
|
||||
# ) -> None:
|
||||
# """Marks the `index_attempt` row as failed + updates the `
|
||||
# connector_credential_pair` to reflect that the run failed"""
|
||||
# logger.warning(
|
||||
# f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
|
||||
# f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
|
||||
# )
|
||||
# mark_attempt_failed(
|
||||
# index_attempt=index_attempt,
|
||||
# db_session=db_session,
|
||||
# failure_reason=failure_reason,
|
||||
# )
|
||||
# """Main funcs"""
|
||||
# def create_indexing_jobs(
|
||||
# existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None
|
||||
# ) -> None:
|
||||
# """Creates new indexing jobs for each connector / credential pair which is:
|
||||
# 1. Enabled
|
||||
# 2. `refresh_frequency` time has passed since the last indexing run for this pair
|
||||
# 3. There is not already an ongoing indexing attempt for this pair
|
||||
# """
|
||||
# with get_session_with_tenant(tenant_id) as db_session:
|
||||
# ongoing: set[tuple[int | None, int]] = set()
|
||||
# for attempt_id in existing_jobs:
|
||||
# attempt = get_index_attempt(
|
||||
# db_session=db_session, index_attempt_id=attempt_id
|
||||
# )
|
||||
# if attempt is None:
|
||||
# logger.error(
|
||||
# f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
|
||||
# "indexing jobs"
|
||||
# )
|
||||
# continue
|
||||
# ongoing.add(
|
||||
# (
|
||||
# attempt.connector_credential_pair_id,
|
||||
# attempt.search_settings_id,
|
||||
# )
|
||||
# )
|
||||
# # Get the primary search settings
|
||||
# primary_search_settings = get_current_search_settings(db_session)
|
||||
# search_settings = [primary_search_settings]
|
||||
# # Check for secondary search settings
|
||||
# secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
# if secondary_search_settings is not None:
|
||||
# # If secondary settings exist, add them to the list
|
||||
# search_settings.append(secondary_search_settings)
|
||||
# all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
|
||||
# for cc_pair in all_connector_credential_pairs:
|
||||
# for search_settings_instance in search_settings:
|
||||
# # Check if there is an ongoing indexing attempt for this connector credential pair
|
||||
# if (cc_pair.id, search_settings_instance.id) in ongoing:
|
||||
# continue
|
||||
# last_attempt = get_last_attempt_for_cc_pair(
|
||||
# cc_pair.id, search_settings_instance.id, db_session
|
||||
# )
|
||||
# if not _should_create_new_indexing(
|
||||
# cc_pair=cc_pair,
|
||||
# last_index=last_attempt,
|
||||
# search_settings_instance=search_settings_instance,
|
||||
# secondary_index_building=len(search_settings) > 1,
|
||||
# db_session=db_session,
|
||||
# ):
|
||||
# continue
|
||||
# create_index_attempt(
|
||||
# cc_pair.id, search_settings_instance.id, db_session
|
||||
# )
|
||||
# def cleanup_indexing_jobs(
|
||||
# existing_jobs: dict[int, Future | SimpleJob],
|
||||
# tenant_id: str | None,
|
||||
# timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
|
||||
# ) -> dict[int, Future | SimpleJob]:
|
||||
# existing_jobs_copy = existing_jobs.copy()
|
||||
# # clean up completed jobs
|
||||
# with get_session_with_tenant(tenant_id) as db_session:
|
||||
# for attempt_id, job in existing_jobs.items():
|
||||
# index_attempt = get_index_attempt(
|
||||
# db_session=db_session, index_attempt_id=attempt_id
|
||||
# )
|
||||
# # do nothing for ongoing jobs that haven't been stopped
|
||||
# if not job.done():
|
||||
# if not index_attempt:
|
||||
# continue
|
||||
# if not index_attempt.is_finished():
|
||||
# continue
|
||||
# if job.status == "error":
|
||||
# logger.error(job.exception())
|
||||
# job.release()
|
||||
# del existing_jobs_copy[attempt_id]
|
||||
# if not index_attempt:
|
||||
# logger.error(
|
||||
# f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
|
||||
# "up indexing jobs"
|
||||
# )
|
||||
# continue
|
||||
# if (
|
||||
# index_attempt.status == IndexingStatus.IN_PROGRESS
|
||||
# or job.status == "error"
|
||||
# ):
|
||||
# _mark_run_failed(
|
||||
# db_session=db_session,
|
||||
# index_attempt=index_attempt,
|
||||
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
# )
|
||||
# # clean up in-progress jobs that were never completed
|
||||
# try:
|
||||
# connectors = fetch_connectors(db_session)
|
||||
# for connector in connectors:
|
||||
# in_progress_indexing_attempts = get_inprogress_index_attempts(
|
||||
# connector.id, db_session
|
||||
# )
|
||||
# for index_attempt in in_progress_indexing_attempts:
|
||||
# if index_attempt.id in existing_jobs:
|
||||
# # If index attempt is canceled, stop the run
|
||||
# if index_attempt.status == IndexingStatus.FAILED:
|
||||
# existing_jobs[index_attempt.id].cancel()
|
||||
# # check to see if the job has been updated in last `timeout_hours` hours, if not
|
||||
# # assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# # on the fact that the `time_updated` field is constantly updated every
|
||||
# # batch of documents indexed
|
||||
# current_db_time = get_db_current_time(db_session=db_session)
|
||||
# time_since_update = current_db_time - index_attempt.time_updated
|
||||
# if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
|
||||
# existing_jobs[index_attempt.id].cancel()
|
||||
# _mark_run_failed(
|
||||
# db_session=db_session,
|
||||
# index_attempt=index_attempt,
|
||||
# failure_reason="Indexing run frozen - no updates in the last three hours. "
|
||||
# "The run will be re-attempted at next scheduled indexing time.",
|
||||
# )
|
||||
# else:
|
||||
# # If job isn't known, simply mark it as failed
|
||||
# _mark_run_failed(
|
||||
# db_session=db_session,
|
||||
# index_attempt=index_attempt,
|
||||
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
# )
|
||||
# except ProgrammingError:
|
||||
# logger.debug(f"No Connector Table exists for: {tenant_id}")
|
||||
# return existing_jobs_copy
|
||||
# def kickoff_indexing_jobs(
|
||||
# existing_jobs: dict[int, Future | SimpleJob],
|
||||
# client: Client | SimpleJobClient,
|
||||
# secondary_client: Client | SimpleJobClient,
|
||||
# tenant_id: str | None,
|
||||
# ) -> dict[int, Future | SimpleJob]:
|
||||
# existing_jobs_copy = existing_jobs.copy()
|
||||
# current_session = get_session_with_tenant(tenant_id)
|
||||
# # Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
# with current_session as db_session:
|
||||
# # get_not_started_index_attempts orders its returned results from oldest to newest
|
||||
# # we must process attempts in a FIFO manner to prevent connector starvation
|
||||
# new_indexing_attempts = [
|
||||
# (attempt, attempt.search_settings)
|
||||
# for attempt in get_not_started_index_attempts(db_session)
|
||||
# if attempt.id not in existing_jobs
|
||||
# ]
|
||||
# logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
|
||||
# if not new_indexing_attempts:
|
||||
# return existing_jobs
|
||||
# indexing_attempt_count = 0
|
||||
# primary_client_full = False
|
||||
# secondary_client_full = False
|
||||
# for attempt, search_settings in new_indexing_attempts:
|
||||
# if primary_client_full and secondary_client_full:
|
||||
# break
|
||||
# use_secondary_index = (
|
||||
# search_settings.status == IndexModelStatus.FUTURE
|
||||
# if search_settings is not None
|
||||
# else False
|
||||
# )
|
||||
# if attempt.connector_credential_pair.connector is None:
|
||||
# logger.warning(
|
||||
# f"Skipping index attempt as Connector has been deleted: {attempt}"
|
||||
# )
|
||||
# with current_session as db_session:
|
||||
# mark_attempt_failed(
|
||||
# attempt, db_session, failure_reason="Connector is null"
|
||||
# )
|
||||
# continue
|
||||
# if attempt.connector_credential_pair.credential is None:
|
||||
# logger.warning(
|
||||
# f"Skipping index attempt as Credential has been deleted: {attempt}"
|
||||
# )
|
||||
# with current_session as db_session:
|
||||
# mark_attempt_failed(
|
||||
# attempt, db_session, failure_reason="Credential is null"
|
||||
# )
|
||||
# continue
|
||||
# if not use_secondary_index:
|
||||
# if not primary_client_full:
|
||||
# run = client.submit(
|
||||
# run_indexing_entrypoint,
|
||||
# attempt.id,
|
||||
# tenant_id,
|
||||
# attempt.connector_credential_pair_id,
|
||||
# global_version.is_ee_version(),
|
||||
# pure=False,
|
||||
# )
|
||||
# if not run:
|
||||
# primary_client_full = True
|
||||
# else:
|
||||
# if not secondary_client_full:
|
||||
# run = secondary_client.submit(
|
||||
# run_indexing_entrypoint,
|
||||
# attempt.id,
|
||||
# tenant_id,
|
||||
# attempt.connector_credential_pair_id,
|
||||
# global_version.is_ee_version(),
|
||||
# pure=False,
|
||||
# )
|
||||
# if not run:
|
||||
# secondary_client_full = True
|
||||
# if run:
|
||||
# if indexing_attempt_count == 0:
|
||||
# logger.info(
|
||||
# f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
|
||||
# )
|
||||
# indexing_attempt_count += 1
|
||||
# secondary_str = " (secondary index)" if use_secondary_index else ""
|
||||
# logger.info(
|
||||
# f"Indexing dispatched{secondary_str}: "
|
||||
# f"attempt_id={attempt.id} "
|
||||
# f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
# f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
# f"credentials='{attempt.connector_credential_pair.credential_id}'"
|
||||
# )
|
||||
# existing_jobs_copy[attempt.id] = run
|
||||
# if indexing_attempt_count > 0:
|
||||
# logger.info(
|
||||
# f"Indexing dispatch results: "
|
||||
# f"initial_pending={len(new_indexing_attempts)} "
|
||||
# f"started={indexing_attempt_count} "
|
||||
# f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
|
||||
# )
|
||||
# return existing_jobs_copy
|
||||
# def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
# if not MULTI_TENANT:
|
||||
# return [None]
|
||||
# with get_session_with_tenant(tenant_id="public") as session:
|
||||
# result = session.execute(
|
||||
# text(
|
||||
# """
|
||||
# SELECT schema_name
|
||||
# FROM information_schema.schemata
|
||||
# WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
|
||||
# )
|
||||
# )
|
||||
# tenant_ids = [row[0] for row in result]
|
||||
# valid_tenants = [
|
||||
# tenant
|
||||
# for tenant in tenant_ids
|
||||
# if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
|
||||
# ]
|
||||
# return valid_tenants
|
||||
# def update_loop(
|
||||
# delay: int = 10,
|
||||
# num_workers: int = NUM_INDEXING_WORKERS,
|
||||
# num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
|
||||
# ) -> None:
|
||||
# if not MULTI_TENANT:
|
||||
# # We can use this function as we are certain only the public schema exists
|
||||
# # (explicitly for the non-`MULTI_TENANT` case)
|
||||
# engine = get_sqlalchemy_engine()
|
||||
# with Session(engine) as db_session:
|
||||
# check_index_swap(db_session=db_session)
|
||||
# search_settings = get_current_search_settings(db_session)
|
||||
# # So that the first time users aren't surprised by really slow speed of first
|
||||
# # batch of documents indexed
|
||||
# if search_settings.provider_type is None:
|
||||
# logger.notice("Running a first inference to warm up embedding model")
|
||||
# embedding_model = EmbeddingModel.from_db_model(
|
||||
# search_settings=search_settings,
|
||||
# server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
# server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
# )
|
||||
# warm_up_bi_encoder(
|
||||
# embedding_model=embedding_model,
|
||||
# )
|
||||
# logger.notice("First inference complete.")
|
||||
# client_primary: Client | SimpleJobClient
|
||||
# client_secondary: Client | SimpleJobClient
|
||||
# if DASK_JOB_CLIENT_ENABLED:
|
||||
# cluster_primary = LocalCluster(
|
||||
# n_workers=num_workers,
|
||||
# threads_per_worker=1,
|
||||
# silence_logs=logging.ERROR,
|
||||
# )
|
||||
# cluster_secondary = LocalCluster(
|
||||
# n_workers=num_secondary_workers,
|
||||
# threads_per_worker=1,
|
||||
# silence_logs=logging.ERROR,
|
||||
# )
|
||||
# client_primary = Client(cluster_primary)
|
||||
# client_secondary = Client(cluster_secondary)
|
||||
# if LOG_LEVEL.lower() == "debug":
|
||||
# client_primary.register_worker_plugin(ResourceLogger())
|
||||
# else:
|
||||
# client_primary = SimpleJobClient(n_workers=num_workers)
|
||||
# client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
|
||||
# existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
|
||||
# logger.notice("Startup complete. Waiting for indexing jobs...")
|
||||
# while True:
|
||||
# start = time.time()
|
||||
# start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
# logger.debug(f"Running update, current UTC time: {start_time_utc}")
|
||||
# if existing_jobs:
|
||||
# logger.debug(
|
||||
# "Found existing indexing jobs: "
|
||||
# f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
|
||||
# )
|
||||
# try:
|
||||
# tenants = get_all_tenant_ids()
|
||||
# for tenant_id in tenants:
|
||||
# try:
|
||||
# logger.debug(
|
||||
# f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}"
|
||||
# )
|
||||
# with get_session_with_tenant(tenant_id) as db_session:
|
||||
# index_to_expire = check_index_swap(db_session=db_session)
|
||||
# if index_to_expire and tenant_id and MULTI_TENANT:
|
||||
# VespaIndex.delete_entries_by_tenant_id(
|
||||
# tenant_id=tenant_id,
|
||||
# index_name=index_to_expire.index_name,
|
||||
# )
|
||||
# if not MULTI_TENANT:
|
||||
# search_settings = get_current_search_settings(db_session)
|
||||
# if search_settings.provider_type is None:
|
||||
# logger.notice(
|
||||
# "Running a first inference to warm up embedding model"
|
||||
# )
|
||||
# embedding_model = EmbeddingModel.from_db_model(
|
||||
# search_settings=search_settings,
|
||||
# server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
# server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
# )
|
||||
# warm_up_bi_encoder(embedding_model=embedding_model)
|
||||
# logger.notice("First inference complete.")
|
||||
# tenant_jobs = existing_jobs.get(tenant_id, {})
|
||||
# tenant_jobs = cleanup_indexing_jobs(
|
||||
# existing_jobs=tenant_jobs, tenant_id=tenant_id
|
||||
# )
|
||||
# create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id)
|
||||
# tenant_jobs = kickoff_indexing_jobs(
|
||||
# existing_jobs=tenant_jobs,
|
||||
# client=client_primary,
|
||||
# secondary_client=client_secondary,
|
||||
# tenant_id=tenant_id,
|
||||
# )
|
||||
# existing_jobs[tenant_id] = tenant_jobs
|
||||
# except Exception as e:
|
||||
# logger.exception(
|
||||
# f"Failed to process tenant {tenant_id or 'default'}: {e}"
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.exception(f"Failed to run update due to {e}")
|
||||
# sleep_time = delay - (time.time() - start)
|
||||
# if sleep_time > 0:
|
||||
# time.sleep(sleep_time)
|
||||
# def update__main() -> None:
|
||||
# set_is_ee_based_on_env_variable()
|
||||
# # initialize the Postgres connection pool
|
||||
# SqlEngine.set_app_name(POSTGRES_INDEXER_APP_NAME)
|
||||
# logger.notice("Starting indexing service")
|
||||
# update_loop()
|
||||
# if __name__ == "__main__":
|
||||
# update__main()
|
@ -43,6 +43,9 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
|
||||
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
|
||||
|
||||
# Necessary for cloud integration tests
|
||||
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"
|
||||
|
||||
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
|
||||
# information. This provides an extra layer of security on top of Postgres access controls
|
||||
# and is available in Danswer EE
|
||||
|
@ -42,7 +42,6 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
logger.error("More unique indexings than cc pairs, should not occur")
|
||||
|
||||
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
|
||||
# Swap indices
|
||||
now_old_search_settings = get_current_search_settings(db_session)
|
||||
update_search_settings_status(
|
||||
search_settings=now_old_search_settings,
|
||||
@ -69,4 +68,6 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
|
||||
if MULTI_TENANT:
|
||||
return now_old_search_settings
|
||||
else:
|
||||
logger.warning("No need to swap indices")
|
||||
return None
|
||||
|
@ -9,6 +9,7 @@ from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import get_documents_by_cc_pair
|
||||
from danswer.db.document import get_ingestion_documents
|
||||
from danswer.db.engine import get_current_tenant_id
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
@ -67,6 +68,7 @@ def upsert_ingestion_doc(
|
||||
doc_info: IngestionDocument,
|
||||
_: User | None = Depends(api_key_dep),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> IngestionResult:
|
||||
doc_info.document.from_ingestion_api = True
|
||||
|
||||
@ -101,6 +103,7 @@ def upsert_ingestion_doc(
|
||||
document_index=curr_doc_index,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
new_doc, __chunk_count = indexing_pipeline(
|
||||
@ -134,6 +137,7 @@ def upsert_ingestion_doc(
|
||||
document_index=sec_doc_index,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
sec_ind_pipeline(
|
||||
|
@ -102,8 +102,6 @@ def get_settings_notifications(
|
||||
|
||||
try:
|
||||
# Need a transaction in order to prevent under-counting current notifications
|
||||
db_session.begin()
|
||||
|
||||
reindex_notifs = get_notifications(
|
||||
user=user, notif_type=NotificationType.REINDEX, db_session=db_session
|
||||
)
|
||||
|
@ -22,7 +22,6 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
) -> Response:
|
||||
try:
|
||||
logger.info(f"Request route: {request.url.path}")
|
||||
|
||||
if not MULTI_TENANT:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
else:
|
||||
|
@ -83,4 +83,5 @@ COPY ./tests/integration /app/tests/integration
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
CMD ["pytest", "-s", "/app/tests/integration"]
|
||||
ENTRYPOINT ["pytest", "-s"]
|
||||
CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]
|
@ -23,7 +23,7 @@ from tests.integration.common_utils.test_models import StreamedResponse
|
||||
class ChatSessionManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
persona_id: int = -1,
|
||||
persona_id: int = 0,
|
||||
description: str = "Test chat session",
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestChatSession:
|
||||
|
@ -32,6 +32,7 @@ class CredentialManager:
|
||||
"curator_public": curator_public,
|
||||
"groups": groups or [],
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/credential",
|
||||
json=credential_request,
|
||||
|
82
backend/tests/integration/common_utils/managers/tenant.py
Normal file
82
backend/tests/integration/common_utils/managers/tenant.py
Normal file
@ -0,0 +1,82 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
|
||||
from danswer.server.manage.models import AllUsersResponse
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def generate_auth_token() -> str:
|
||||
payload = {
|
||||
"iss": "control_plane",
|
||||
"exp": datetime.utcnow() + timedelta(minutes=5),
|
||||
"iat": datetime.utcnow(),
|
||||
"scope": "tenant:create",
|
||||
}
|
||||
token = jwt.encode(payload, "", algorithm="HS256")
|
||||
return token
|
||||
|
||||
|
||||
class TenantManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
tenant_id: str | None = None,
|
||||
initial_admin_email: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
body = {
|
||||
"tenant_id": tenant_id,
|
||||
"initial_admin_email": initial_admin_email,
|
||||
}
|
||||
|
||||
token = generate_auth_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"X-API-KEY": "",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/tenants/create",
|
||||
json=body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def get_all_users(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> AllUsersResponse:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/users",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
return AllUsersResponse(
|
||||
accepted=[FullUserSnapshot(**user) for user in data["accepted"]],
|
||||
invited=[InvitedUserSnapshot(**user) for user in data["invited"]],
|
||||
accepted_pages=data["accepted_pages"],
|
||||
invited_pages=data["invited_pages"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def verify_user_in_tenant(
|
||||
user: DATestUser, user_performing_action: DATestUser | None = None
|
||||
) -> None:
|
||||
all_users = TenantManager.get_all_users(user_performing_action)
|
||||
for accepted_user in all_users.accepted:
|
||||
if accepted_user.email == user.email and accepted_user.id == user.id:
|
||||
return
|
||||
raise ValueError(f"User {user.email} not found in tenant")
|
@ -65,15 +65,23 @@ class UserManager:
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result_cookie = next(iter(response.cookies), None)
|
||||
|
||||
if not result_cookie:
|
||||
response.raise_for_status()
|
||||
|
||||
cookies = response.cookies.get_dict()
|
||||
session_cookie = cookies.get("fastapiusersauth")
|
||||
tenant_details_cookie = cookies.get("tenant_details")
|
||||
|
||||
if not session_cookie:
|
||||
raise Exception("Failed to login")
|
||||
|
||||
print(f"Logged in as {test_user.email}")
|
||||
cookie = f"{result_cookie.name}={result_cookie.value}"
|
||||
test_user.headers["Cookie"] = cookie
|
||||
|
||||
# Set both cookies in the headers
|
||||
test_user.headers["Cookie"] = (
|
||||
f"fastapiusersauth={session_cookie}; "
|
||||
f"tenant_details={tenant_details_cookie}"
|
||||
)
|
||||
return test_user
|
||||
|
||||
@staticmethod
|
||||
|
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
import psycopg2
|
||||
import requests
|
||||
@ -11,7 +12,9 @@ from danswer.configs.app_configs import POSTGRES_PASSWORD
|
||||
from danswer.configs.app_configs import POSTGRES_PORT
|
||||
from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import SYNC_DB_API
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
@ -26,7 +29,11 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _run_migrations(
|
||||
database_url: str, direction: str = "upgrade", revision: str = "head"
|
||||
database_url: str,
|
||||
config_name: str,
|
||||
direction: str = "upgrade",
|
||||
revision: str = "head",
|
||||
schema: str = "public",
|
||||
) -> None:
|
||||
# hide info logs emitted during migration
|
||||
logging.getLogger("alembic").setLevel(logging.CRITICAL)
|
||||
@ -35,6 +42,10 @@ def _run_migrations(
|
||||
alembic_cfg = Config("alembic.ini")
|
||||
alembic_cfg.set_section_option("logger_alembic", "level", "WARN")
|
||||
alembic_cfg.attributes["configure_logger"] = False
|
||||
alembic_cfg.config_ini_section = config_name
|
||||
|
||||
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
|
||||
alembic_cfg.cmd_opts.x = [f"schema={schema}"] # type: ignore
|
||||
|
||||
# Set the SQLAlchemy URL in the Alembic configuration
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", database_url)
|
||||
@ -52,7 +63,9 @@ def _run_migrations(
|
||||
logging.getLogger("alembic").setLevel(logging.INFO)
|
||||
|
||||
|
||||
def reset_postgres(database: str = "postgres") -> None:
|
||||
def reset_postgres(
|
||||
database: str = "postgres", config_name: str = "alembic", setup_danswer: bool = True
|
||||
) -> None:
|
||||
"""Reset the Postgres database."""
|
||||
|
||||
# NOTE: need to delete all rows to allow migrations to be rolled back
|
||||
@ -111,14 +124,18 @@ def reset_postgres(database: str = "postgres") -> None:
|
||||
)
|
||||
_run_migrations(
|
||||
conn_str,
|
||||
config_name,
|
||||
direction="downgrade",
|
||||
revision="base",
|
||||
)
|
||||
_run_migrations(
|
||||
conn_str,
|
||||
config_name,
|
||||
direction="upgrade",
|
||||
revision="head",
|
||||
)
|
||||
if not setup_danswer:
|
||||
return
|
||||
|
||||
# do the same thing as we do on API server startup
|
||||
with get_session_context_manager() as db_session:
|
||||
@ -127,6 +144,7 @@ def reset_postgres(database: str = "postgres") -> None:
|
||||
|
||||
def reset_vespa() -> None:
|
||||
"""Wipe all data from the Vespa index."""
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
# swap to the correct default model
|
||||
check_index_swap(db_session)
|
||||
@ -166,10 +184,98 @@ def reset_vespa() -> None:
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
def reset_postgres_multitenant() -> None:
|
||||
"""Reset the Postgres database for all tenants in a multitenant setup."""
|
||||
|
||||
conn = psycopg2.connect(
|
||||
dbname="postgres",
|
||||
user=POSTGRES_USER,
|
||||
password=POSTGRES_PASSWORD,
|
||||
host=POSTGRES_HOST,
|
||||
port=POSTGRES_PORT,
|
||||
)
|
||||
conn.autocommit = True
|
||||
cur = conn.cursor()
|
||||
|
||||
# Get all tenant schemas
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name LIKE 'tenant_%'
|
||||
"""
|
||||
)
|
||||
tenant_schemas = cur.fetchall()
|
||||
|
||||
# Drop all tenant schemas
|
||||
for schema in tenant_schemas:
|
||||
schema_name = schema[0]
|
||||
cur.execute(f'DROP SCHEMA "{schema_name}" CASCADE')
|
||||
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
reset_postgres(config_name="schema_private", setup_danswer=False)
|
||||
|
||||
|
||||
def reset_vespa_multitenant() -> None:
|
||||
"""Wipe all data from the Vespa index for all tenants."""
|
||||
|
||||
for tenant_id in get_all_tenant_ids():
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# swap to the correct default model for each tenant
|
||||
check_index_swap(db_session)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
index_name = search_settings.index_name
|
||||
|
||||
success = setup_vespa(
|
||||
document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
|
||||
index_setting=IndexingSetting.from_db_model(search_settings),
|
||||
secondary_index_setting=None,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise RuntimeError(
|
||||
f"Could not connect to Vespa for tenant {tenant_id} within the specified timeout."
|
||||
)
|
||||
|
||||
for _ in range(5):
|
||||
try:
|
||||
continuation = None
|
||||
should_continue = True
|
||||
while should_continue:
|
||||
params = {"selection": "true", "cluster": "danswer_index"}
|
||||
if continuation:
|
||||
params = {**params, "continuation": continuation}
|
||||
response = requests.delete(
|
||||
DOCUMENT_ID_ENDPOINT.format(index_name=index_name),
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
continuation = response_json.get("continuation")
|
||||
should_continue = bool(continuation)
|
||||
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error deleting documents for tenant {tenant_id}: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
def reset_all() -> None:
|
||||
"""Reset both Postgres and Vespa."""
|
||||
logger.info("Resetting Postgres...")
|
||||
reset_postgres()
|
||||
logger.info("Resetting Vespa...")
|
||||
reset_vespa()
|
||||
|
||||
|
||||
def reset_all_multitenant() -> None:
|
||||
"""Reset both Postgres and Vespa for all tenants."""
|
||||
logger.info("Resetting Postgres for all tenants...")
|
||||
reset_postgres_multitenant()
|
||||
logger.info("Resetting Vespa for all tenants...")
|
||||
reset_vespa_multitenant()
|
||||
logger.info("Finished resetting all.")
|
||||
|
@ -8,6 +8,7 @@ from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.reset import reset_all_multitenant
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
@ -54,3 +55,8 @@ def new_admin_user(reset: None) -> DATestUser | None:
|
||||
return UserManager.create(name="admin_user")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_multitenant() -> None:
|
||||
reset_all_multitenant()
|
||||
|
0
backend/tests/integration/multitenant_tests/cc_Pair
Normal file
0
backend/tests/integration/multitenant_tests/cc_Pair
Normal file
@ -0,0 +1,150 @@
|
||||
from danswer.db.models import UserRole
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.tenant import TenantManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestChatSession
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
|
||||
# Create Tenant 1 and its Admin User
|
||||
TenantManager.create("tenant_dev1", "test1@test.com")
|
||||
test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com")
|
||||
assert UserManager.verify_role(test_user1, UserRole.ADMIN)
|
||||
|
||||
# Create Tenant 2 and its Admin User
|
||||
TenantManager.create("tenant_dev2", "test2@test.com")
|
||||
test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com")
|
||||
assert UserManager.verify_role(test_user2, UserRole.ADMIN)
|
||||
|
||||
# Create connectors for Tenant 1
|
||||
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
|
||||
user_performing_action=test_user1,
|
||||
)
|
||||
api_key_1: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=test_user1,
|
||||
)
|
||||
api_key_1.headers.update(test_user1.headers)
|
||||
LLMProviderManager.create(user_performing_action=test_user1)
|
||||
|
||||
# Seed documents for Tenant 1
|
||||
cc_pair_1.documents = []
|
||||
doc1_tenant1 = DocumentManager.seed_doc_with_content(
|
||||
cc_pair=cc_pair_1,
|
||||
content="Tenant 1 Document Content",
|
||||
api_key=api_key_1,
|
||||
)
|
||||
doc2_tenant1 = DocumentManager.seed_doc_with_content(
|
||||
cc_pair=cc_pair_1,
|
||||
content="Tenant 1 Document Content",
|
||||
api_key=api_key_1,
|
||||
)
|
||||
cc_pair_1.documents.extend([doc1_tenant1, doc2_tenant1])
|
||||
|
||||
# Create connectors for Tenant 2
|
||||
cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
|
||||
user_performing_action=test_user2,
|
||||
)
|
||||
api_key_2: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=test_user2,
|
||||
)
|
||||
api_key_2.headers.update(test_user2.headers)
|
||||
LLMProviderManager.create(user_performing_action=test_user2)
|
||||
|
||||
# Seed documents for Tenant 2
|
||||
cc_pair_2.documents = []
|
||||
doc1_tenant2 = DocumentManager.seed_doc_with_content(
|
||||
cc_pair=cc_pair_2,
|
||||
content="Tenant 2 Document Content",
|
||||
api_key=api_key_2,
|
||||
)
|
||||
doc2_tenant2 = DocumentManager.seed_doc_with_content(
|
||||
cc_pair=cc_pair_2,
|
||||
content="Tenant 2 Document Content",
|
||||
api_key=api_key_2,
|
||||
)
|
||||
cc_pair_2.documents.extend([doc1_tenant2, doc2_tenant2])
|
||||
|
||||
tenant1_doc_ids = {doc1_tenant1.id, doc2_tenant1.id}
|
||||
tenant2_doc_ids = {doc1_tenant2.id, doc2_tenant2.id}
|
||||
|
||||
# Create chat sessions for each user
|
||||
chat_session1: DATestChatSession = ChatSessionManager.create(
|
||||
user_performing_action=test_user1
|
||||
)
|
||||
chat_session2: DATestChatSession = ChatSessionManager.create(
|
||||
user_performing_action=test_user2
|
||||
)
|
||||
|
||||
# User 1 sends a message and gets a response
|
||||
response1 = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session1.id,
|
||||
message="What is in Tenant 1's documents?",
|
||||
user_performing_action=test_user1,
|
||||
)
|
||||
# Assert that the search tool was used
|
||||
assert response1.tool_name == "run_search"
|
||||
|
||||
response_doc_ids = {doc["document_id"] for doc in response1.tool_result or []}
|
||||
assert tenant1_doc_ids.issubset(
|
||||
response_doc_ids
|
||||
), "Not all Tenant 1 document IDs are in the response"
|
||||
assert not response_doc_ids.intersection(
|
||||
tenant2_doc_ids
|
||||
), "Tenant 2 document IDs should not be in the response"
|
||||
|
||||
# Assert that the contents are correct
|
||||
for doc in response1.tool_result or []:
|
||||
assert doc["content"] == "Tenant 1 Document Content"
|
||||
|
||||
# User 2 sends a message and gets a response
|
||||
response2 = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session2.id,
|
||||
message="What is in Tenant 2's documents?",
|
||||
user_performing_action=test_user2,
|
||||
)
|
||||
# Assert that the search tool was used
|
||||
assert response2.tool_name == "run_search"
|
||||
# Assert that the tool_result contains Tenant 2's documents
|
||||
response_doc_ids = {doc["document_id"] for doc in response2.tool_result or []}
|
||||
assert tenant2_doc_ids.issubset(
|
||||
response_doc_ids
|
||||
), "Not all Tenant 2 document IDs are in the response"
|
||||
assert not response_doc_ids.intersection(
|
||||
tenant1_doc_ids
|
||||
), "Tenant 1 document IDs should not be in the response"
|
||||
# Assert that the contents are correct
|
||||
for doc in response2.tool_result or []:
|
||||
assert doc["content"] == "Tenant 2 Document Content"
|
||||
|
||||
# User 1 tries to access Tenant 2's documents
|
||||
response_cross = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session1.id,
|
||||
message="What is in Tenant 2's documents?",
|
||||
user_performing_action=test_user1,
|
||||
)
|
||||
# Assert that the search tool was used
|
||||
assert response_cross.tool_name == "run_search"
|
||||
# Assert that the tool_result is empty or does not contain Tenant 2's documents
|
||||
response_doc_ids = {doc["document_id"] for doc in response_cross.tool_result or []}
|
||||
# Ensure none of Tenant 2's document IDs are in the response
|
||||
assert not response_doc_ids.intersection(tenant2_doc_ids)
|
||||
|
||||
# User 2 tries to access Tenant 1's documents
|
||||
response_cross2 = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session2.id,
|
||||
message="What is in Tenant 1's documents?",
|
||||
user_performing_action=test_user2,
|
||||
)
|
||||
# Assert that the search tool was used
|
||||
assert response_cross2.tool_name == "run_search"
|
||||
# Assert that the tool_result is empty or does not contain Tenant 1's documents
|
||||
response_doc_ids = {doc["document_id"] for doc in response_cross2.tool_result or []}
|
||||
# Ensure none of Tenant 1's document IDs are in the response
|
||||
assert not response_doc_ids.intersection(tenant1_doc_ids)
|
@ -0,0 +1,41 @@
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.models import UserRole
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.tenant import TenantManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
# Test flow from creating tenant to registering as a user
|
||||
def test_tenant_creation(reset_multitenant: None) -> None:
|
||||
TenantManager.create("tenant_dev", "test@test.com")
|
||||
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
|
||||
|
||||
assert UserManager.verify_role(test_user, UserRole.ADMIN)
|
||||
|
||||
test_credential = CredentialManager.create(
|
||||
name="admin_test_credential",
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=False,
|
||||
user_performing_action=test_user,
|
||||
)
|
||||
|
||||
test_connector = ConnectorManager.create(
|
||||
name="admin_test_connector",
|
||||
source=DocumentSource.FILE,
|
||||
is_public=False,
|
||||
user_performing_action=test_user,
|
||||
)
|
||||
|
||||
test_cc_pair = CCPairManager.create(
|
||||
connector_id=test_connector.id,
|
||||
credential_id=test_credential.id,
|
||||
name="admin_test_cc_pair",
|
||||
access_type=AccessType.PRIVATE,
|
||||
user_performing_action=test_user,
|
||||
)
|
||||
|
||||
CCPairManager.verify(cc_pair=test_cc_pair, user_performing_action=test_user)
|
@ -119,6 +119,7 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
|
||||
# get the db_doc_id of the top document to use as a search doc id for second message
|
||||
first_db_doc_id = response_json["top_documents"][0]["db_doc_id"]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user