Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-03 11:40:01 +02:00)
Tenant integration tests (#2913)
* check for index swap * initial bones * kk * k * k: * nit * nit * rebase + update * nit * minior update * k * minor integration test fixes * nit * ensure we build test docker image * remove one space * k * ensure we wipe volumes * remove log * typo * nit * k * k
This commit is contained in:
  parent bd63119684
  commit 9b147ae437
.github/workflows/pr-Integration-tests.yml (vendored, 61 changes)
@@ -85,6 +85,57 @@ jobs:
           cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
           cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
 
+      # Start containers for multi-tenant tests
+      - name: Start Docker containers for multi-tenant tests
+        run: |
+          cd deployment/docker_compose
+          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
+          MULTI_TENANT=true \
+          AUTH_TYPE=basic \
+          REQUIRE_EMAIL_VERIFICATION=false \
+          DISABLE_TELEMETRY=true \
+          IMAGE_TAG=test \
+          docker compose -f docker-compose.dev.yml -p danswer-stack up -d
+        id: start_docker_multi_tenant
+
+      # In practice, `cloud` Auth type would require OAUTH credentials to be set.
+      - name: Run Multi-Tenant Integration Tests
+        run: |
+          echo "Running integration tests..."
+          docker run --rm --network danswer-stack_default \
+            --name test-runner \
+            -e POSTGRES_HOST=relational_db \
+            -e POSTGRES_USER=postgres \
+            -e POSTGRES_PASSWORD=password \
+            -e POSTGRES_DB=postgres \
+            -e VESPA_HOST=index \
+            -e REDIS_HOST=cache \
+            -e API_SERVER_HOST=api_server \
+            -e OPENAI_API_KEY=${OPENAI_API_KEY} \
+            -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
+            -e TEST_WEB_HOSTNAME=test-runner \
+            -e AUTH_TYPE=cloud \
+            -e MULTI_TENANT=true \
+            danswer/danswer-integration:test \
+            /app/tests/integration/multitenant_tests
+        continue-on-error: true
+        id: run_multitenant_tests
+
+      - name: Check multi-tenant test results
+        run: |
+          if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
+            echo "Integration tests failed. Exiting with error."
+            exit 1
+          else
+            echo "All integration tests passed successfully."
+          fi
+
+      - name: Stop multi-tenant Docker containers
+        run: |
+          cd deployment/docker_compose
+          docker compose -f docker-compose.dev.yml -p danswer-stack down -v
+
       - name: Start Docker containers
         run: |
           cd deployment/docker_compose
@@ -130,7 +181,7 @@ jobs:
           done
           echo "Finished waiting for service."
 
-      - name: Run integration tests
+      - name: Run Standard Integration Tests
         run: |
          echo "Running integration tests..."
          docker run --rm --network danswer-stack_default \
@@ -145,7 +196,8 @@ jobs:
            -e OPENAI_API_KEY=${OPENAI_API_KEY} \
            -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
            -e TEST_WEB_HOSTNAME=test-runner \
-           danswer/danswer-integration:test
+           danswer/danswer-integration:test \
+           /app/tests/integration/tests
        continue-on-error: true
        id: run_tests
 
@@ -158,6 +210,11 @@ jobs:
            echo "All integration tests passed successfully."
          fi
 
+      - name: Stop Docker containers
+        run: |
+          cd deployment/docker_compose
+          docker compose -f docker-compose.dev.yml -p danswer-stack down -v
+
       - name: Save Docker logs
         if: success() || failure()
         run: |
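Aside: the `-e` flags above are how the test-runner container discovers the compose services by their network aliases. A minimal sketch of resolving those environment variables on the test side is shown below; the variable names match the workflow, but this module is only an illustration and not the repository's actual tests/integration/common_utils/constants.py (the port and defaults are assumptions).

# Illustrative sketch only: build service endpoints from the env vars the workflow
# injects via `docker run -e ...`. The real constants module may differ.
import os

API_SERVER_HOST = os.environ.get("API_SERVER_HOST", "localhost")
API_SERVER_PORT = os.environ.get("API_SERVER_PORT", "8080")  # port is an assumption
API_SERVER_URL = f"http://{API_SERVER_HOST}:{API_SERVER_PORT}"

POSTGRES_DSN = (
    f"postgresql://{os.environ.get('POSTGRES_USER', 'postgres')}:"
    f"{os.environ.get('POSTGRES_PASSWORD', 'password')}@"
    f"{os.environ.get('POSTGRES_HOST', 'localhost')}:5432/"
    f"{os.environ.get('POSTGRES_DB', 'postgres')}"
)

if __name__ == "__main__":
    # Quick sanity check that the endpoints were assembled as expected.
    print(API_SERVER_URL)
    print(POSTGRES_DSN)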
@@ -58,6 +58,7 @@ from danswer.auth.schemas import UserRole
 from danswer.auth.schemas import UserUpdate
 from danswer.configs.app_configs import AUTH_TYPE
 from danswer.configs.app_configs import DISABLE_AUTH
+from danswer.configs.app_configs import DISABLE_VERIFICATION
 from danswer.configs.app_configs import EMAIL_FROM
 from danswer.configs.app_configs import MULTI_TENANT
 from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
@@ -133,7 +134,9 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
 def user_needs_to_be_verified() -> bool:
     # all other auth types besides basic should require users to be
     # verified
-    return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
+    return not DISABLE_VERIFICATION and (
+        AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
+    )
 
 
 def verify_email_is_invited(email: str) -> None:
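Aside: the hunk above adds DISABLE_VERIFICATION as a blanket off-switch on top of the existing rule. A minimal, self-contained restatement of that rule, with the three settings passed explicitly so the behaviour is easy to check, might look like the sketch below; it mirrors the diff rather than the surrounding module, and the enum values are trimmed for illustration.

# Illustrative restatement of the verification rule added above, with the three
# config values passed as parameters instead of module-level constants.
from enum import Enum


class AuthType(str, Enum):
    BASIC = "basic"
    CLOUD = "cloud"


def needs_verification(
    auth_type: AuthType,
    require_email_verification: bool,
    disable_verification: bool,
) -> bool:
    # DISABLE_VERIFICATION short-circuits everything (used by the cloud integration
    # tests); otherwise any non-basic auth type, or an explicit opt-in, requires it.
    return not disable_verification and (
        auth_type != AuthType.BASIC or require_email_verification
    )


assert needs_verification(AuthType.CLOUD, False, disable_verification=True) is False
assert needs_verification(AuthType.CLOUD, False, disable_verification=False) is True
assert needs_verification(AuthType.BASIC, False, disable_verification=False) is False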
@@ -8,18 +8,11 @@ from celery.signals import celeryd_init
 from celery.signals import worker_init
 from celery.signals import worker_ready
 from celery.signals import worker_shutdown
-from sqlalchemy.orm import Session
 
 import danswer.background.celery.apps.app_base as app_base
 from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
 from danswer.db.engine import SqlEngine
-from danswer.db.search_settings import get_current_search_settings
-from danswer.db.swap_index import check_index_swap
-from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
-from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
 from danswer.utils.logger import setup_logger
-from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
-from shared_configs.configs import MODEL_SERVER_PORT
 
 
 logger = setup_logger()
@@ -67,27 +60,6 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
     SqlEngine.init_engine(pool_size=8, max_overflow=0)
 
-    # TODO: why is this necessary for the indexer to do?
-    engine = SqlEngine.get_engine()
-    with Session(engine) as db_session:
-        check_index_swap(db_session=db_session)
-        search_settings = get_current_search_settings(db_session)
-
-        # So that the first time users aren't surprised by really slow speed of first
-        # batch of documents indexed
-        if search_settings.provider_type is None:
-            logger.notice("Running a first inference to warm up embedding model")
-            embedding_model = EmbeddingModel.from_db_model(
-                search_settings=search_settings,
-                server_host=INDEXING_MODEL_SERVER_HOST,
-                server_port=MODEL_SERVER_PORT,
-            )
-
-            warm_up_bi_encoder(
-                embedding_model=embedding_model,
-            )
-            logger.notice("First inference complete.")
-
     app_base.wait_for_redis(sender, **kwargs)
     app_base.on_secondary_worker_init(sender, **kwargs)
@@ -1,494 +0,0 @@
(File deleted. The entire file was commented-out legacy code for the old Dask-based
background indexing loop, beginning with
"# TODO(rkuo): delete after background indexing via celery is fully vetted"
and consisting only of commented-out imports plus the commented-out definitions of
_should_create_new_indexing, _mark_run_failed, create_indexing_jobs,
cleanup_indexing_jobs, kickoff_indexing_jobs, get_all_tenant_ids, update_loop,
and update__main.)
@@ -43,6 +43,9 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
 AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
 DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
 
+# Necessary for cloud integration tests
+DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"
+
 # Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
 # information. This provides an extra layer of security on top of Postgres access controls
 # and is available in Danswer EE
@@ -42,7 +42,6 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
         logger.error("More unique indexings than cc pairs, should not occur")
 
     if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
-        # Swap indices
         now_old_search_settings = get_current_search_settings(db_session)
         update_search_settings_status(
             search_settings=now_old_search_settings,
@@ -69,4 +68,6 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
 
     if MULTI_TENANT:
         return now_old_search_settings
+    else:
+        logger.warning("No need to swap indices")
     return None
@@ -9,6 +9,7 @@ from danswer.connectors.models import IndexAttemptMetadata
 from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
 from danswer.db.document import get_documents_by_cc_pair
 from danswer.db.document import get_ingestion_documents
+from danswer.db.engine import get_current_tenant_id
 from danswer.db.engine import get_session
 from danswer.db.models import User
 from danswer.db.search_settings import get_current_search_settings
@@ -67,6 +68,7 @@ def upsert_ingestion_doc(
     doc_info: IngestionDocument,
     _: User | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
+    tenant_id: str = Depends(get_current_tenant_id),
 ) -> IngestionResult:
     doc_info.document.from_ingestion_api = True
 
@@ -101,6 +103,7 @@ def upsert_ingestion_doc(
         document_index=curr_doc_index,
         ignore_time_skip=True,
         db_session=db_session,
+        tenant_id=tenant_id,
     )
 
     new_doc, __chunk_count = indexing_pipeline(
@@ -134,6 +137,7 @@ def upsert_ingestion_doc(
         document_index=sec_doc_index,
         ignore_time_skip=True,
         db_session=db_session,
+        tenant_id=tenant_id,
     )
 
     sec_ind_pipeline(
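Aside: the hunk above threads a tenant id into the ingestion route via a FastAPI dependency. The real get_current_tenant_id lives in danswer.db.engine and is not part of this diff, so the sketch below is only a hypothetical illustration of the pattern: a per-request ContextVar set by middleware and read back by the dependency (the contextvar name, default, and route path are assumptions).

# Hypothetical sketch of a tenant-id dependency; not the actual danswer implementation.
from contextvars import ContextVar

from fastapi import Depends, FastAPI

CURRENT_TENANT_ID: ContextVar[str] = ContextVar("current_tenant_id", default="public")


def get_current_tenant_id() -> str:
    # Returns whatever the tenant middleware resolved for this request.
    return CURRENT_TENANT_ID.get()


app = FastAPI()


@app.post("/danswer-api/ingestion")  # route path chosen for illustration only
def upsert_ingestion_doc(tenant_id: str = Depends(get_current_tenant_id)) -> dict:
    # The tenant id is then passed down to the indexing pipeline so documents land
    # in the right schema / index partition for that tenant.
    return {"tenant_id": tenant_id}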
@@ -102,8 +102,6 @@ def get_settings_notifications(
 
     try:
         # Need a transaction in order to prevent under-counting current notifications
-        db_session.begin()
-
         reindex_notifs = get_notifications(
             user=user, notif_type=NotificationType.REINDEX, db_session=db_session
         )
@@ -22,7 +22,6 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
     ) -> Response:
         try:
             logger.info(f"Request route: {request.url.path}")
-
             if not MULTI_TENANT:
                 tenant_id = POSTGRES_DEFAULT_SCHEMA
             else:
@@ -83,4 +83,5 @@ COPY ./tests/integration /app/tests/integration
 
 ENV PYTHONPATH=/app
 
-CMD ["pytest", "-s", "/app/tests/integration"]
+ENTRYPOINT ["pytest", "-s"]
+CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]
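Aside: with `pytest -s` moved into ENTRYPOINT, the test path becomes the container command, so passing a different directory to `docker run` (as the workflow does for the multi-tenant suite) only overrides CMD. A rough local equivalent, sketched with the docker SDK for Python, is below; the SDK is an extra dependency and not something this commit adds, and the environment values simply echo the workflow step.

# Sketch, not part of the commit: run the integration-test image against the compose
# stack, overriding the image's CMD with the multi-tenant test path.
import docker  # pip install docker

client = docker.from_env()
logs = client.containers.run(
    image="danswer/danswer-integration:test",
    command=["/app/tests/integration/multitenant_tests"],  # replaces CMD, keeps ENTRYPOINT
    network="danswer-stack_default",
    environment={
        "API_SERVER_HOST": "api_server",
        "POSTGRES_HOST": "relational_db",
        "VESPA_HOST": "index",
        "REDIS_HOST": "cache",
        "AUTH_TYPE": "cloud",
        "MULTI_TENANT": "true",
    },
    remove=True,
)
print(logs.decode())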
@@ -23,7 +23,7 @@ from tests.integration.common_utils.test_models import StreamedResponse
 class ChatSessionManager:
     @staticmethod
     def create(
-        persona_id: int = -1,
+        persona_id: int = 0,
         description: str = "Test chat session",
         user_performing_action: DATestUser | None = None,
     ) -> DATestChatSession:
@@ -32,6 +32,7 @@ class CredentialManager:
             "curator_public": curator_public,
             "groups": groups or [],
         }
+
         response = requests.post(
             url=f"{API_SERVER_URL}/manage/credential",
             json=credential_request,
backend/tests/integration/common_utils/managers/tenant.py (new file, 82 lines)
@@ -0,0 +1,82 @@
+from datetime import datetime
+from datetime import timedelta
+
+import jwt
+import requests
+
+from danswer.server.manage.models import AllUsersResponse
+from danswer.server.models import FullUserSnapshot
+from danswer.server.models import InvitedUserSnapshot
+from tests.integration.common_utils.constants import API_SERVER_URL
+from tests.integration.common_utils.constants import GENERAL_HEADERS
+from tests.integration.common_utils.test_models import DATestUser
+
+
+def generate_auth_token() -> str:
+    payload = {
+        "iss": "control_plane",
+        "exp": datetime.utcnow() + timedelta(minutes=5),
+        "iat": datetime.utcnow(),
+        "scope": "tenant:create",
+    }
+    token = jwt.encode(payload, "", algorithm="HS256")
+    return token
+
+
+class TenantManager:
+    @staticmethod
+    def create(
+        tenant_id: str | None = None,
+        initial_admin_email: str | None = None,
+    ) -> dict[str, str]:
+        body = {
+            "tenant_id": tenant_id,
+            "initial_admin_email": initial_admin_email,
+        }
+
+        token = generate_auth_token()
+        headers = {
+            "Authorization": f"Bearer {token}",
+            "X-API-KEY": "",
+            "Content-Type": "application/json",
+        }
+
+        response = requests.post(
+            url=f"{API_SERVER_URL}/tenants/create",
+            json=body,
+            headers=headers,
+        )
+
+        response.raise_for_status()
+
+        return response.json()
+
+    @staticmethod
+    def get_all_users(
+        user_performing_action: DATestUser | None = None,
+    ) -> AllUsersResponse:
+        response = requests.get(
+            url=f"{API_SERVER_URL}/manage/users",
+            headers=user_performing_action.headers
+            if user_performing_action
+            else GENERAL_HEADERS,
+        )
+        response.raise_for_status()
+
+        data = response.json()
+        return AllUsersResponse(
+            accepted=[FullUserSnapshot(**user) for user in data["accepted"]],
+            invited=[InvitedUserSnapshot(**user) for user in data["invited"]],
+            accepted_pages=data["accepted_pages"],
+            invited_pages=data["invited_pages"],
+        )
+
+    @staticmethod
+    def verify_user_in_tenant(
+        user: DATestUser, user_performing_action: DATestUser | None = None
+    ) -> None:
+        all_users = TenantManager.get_all_users(user_performing_action)
+        for accepted_user in all_users.accepted:
+            if accepted_user.email == user.email and accepted_user.id == user.id:
+                return
+        raise ValueError(f"User {user.email} not found in tenant")
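Aside: generate_auth_token above produces an HS256 JWT signed with an empty secret, which is what the local integration setup accepts. A minimal sketch of verifying such a token on the receiving side with PyJWT could look like the following; the claim checks are assumptions for illustration, not the actual control-plane validation code.

# Hypothetical verification sketch (the real /tenants/create validation is not in this diff).
import jwt


def verify_control_plane_token(token: str, secret: str = "") -> dict:
    # Raises jwt.InvalidTokenError subclasses on a bad signature or an expired token.
    payload = jwt.decode(token, secret, algorithms=["HS256"])
    if payload.get("iss") != "control_plane":
        raise ValueError("unexpected issuer")
    if payload.get("scope") != "tenant:create":
        raise ValueError("token not scoped for tenant creation")
    return payload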
@@ -65,15 +65,23 @@ class UserManager:
             data=data,
             headers=headers,
         )
-        response.raise_for_status()
-        result_cookie = next(iter(response.cookies), None)
-
-        if not result_cookie:
+        response.raise_for_status()
+
+        cookies = response.cookies.get_dict()
+        session_cookie = cookies.get("fastapiusersauth")
+        tenant_details_cookie = cookies.get("tenant_details")
+
+        if not session_cookie:
             raise Exception("Failed to login")
 
         print(f"Logged in as {test_user.email}")
-        cookie = f"{result_cookie.name}={result_cookie.value}"
-        test_user.headers["Cookie"] = cookie
+        # Set both cookies in the headers
+        test_user.headers["Cookie"] = (
+            f"fastapiusersauth={session_cookie}; "
+            f"tenant_details={tenant_details_cookie}"
+        )
         return test_user
 
     @staticmethod
@@ -1,5 +1,6 @@
 import logging
 import time
+from types import SimpleNamespace
 
 import psycopg2
 import requests
@@ -11,7 +12,9 @@ from danswer.configs.app_configs import POSTGRES_PASSWORD
 from danswer.configs.app_configs import POSTGRES_PORT
 from danswer.configs.app_configs import POSTGRES_USER
 from danswer.db.engine import build_connection_string
+from danswer.db.engine import get_all_tenant_ids
 from danswer.db.engine import get_session_context_manager
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.engine import SYNC_DB_API
 from danswer.db.search_settings import get_current_search_settings
 from danswer.db.swap_index import check_index_swap
@@ -26,7 +29,11 @@ logger = setup_logger()
 
 
 def _run_migrations(
-    database_url: str, direction: str = "upgrade", revision: str = "head"
+    database_url: str,
+    config_name: str,
+    direction: str = "upgrade",
+    revision: str = "head",
+    schema: str = "public",
 ) -> None:
     # hide info logs emitted during migration
     logging.getLogger("alembic").setLevel(logging.CRITICAL)
@@ -35,6 +42,10 @@ def _run_migrations(
     alembic_cfg = Config("alembic.ini")
     alembic_cfg.set_section_option("logger_alembic", "level", "WARN")
     alembic_cfg.attributes["configure_logger"] = False
+    alembic_cfg.config_ini_section = config_name
+
+    alembic_cfg.cmd_opts = SimpleNamespace()  # type: ignore
+    alembic_cfg.cmd_opts.x = [f"schema={schema}"]  # type: ignore
 
     # Set the SQLAlchemy URL in the Alembic configuration
     alembic_cfg.set_main_option("sqlalchemy.url", database_url)
@@ -52,7 +63,9 @@ def _run_migrations(
     logging.getLogger("alembic").setLevel(logging.INFO)
 
 
-def reset_postgres(database: str = "postgres") -> None:
+def reset_postgres(
+    database: str = "postgres", config_name: str = "alembic", setup_danswer: bool = True
+) -> None:
     """Reset the Postgres database."""
 
     # NOTE: need to delete all rows to allow migrations to be rolled back
@@ -111,14 +124,18 @@ def reset_postgres(database: str = "postgres") -> None:
     )
     _run_migrations(
         conn_str,
+        config_name,
         direction="downgrade",
         revision="base",
     )
     _run_migrations(
         conn_str,
+        config_name,
         direction="upgrade",
         revision="head",
     )
+    if not setup_danswer:
+        return
 
     # do the same thing as we do on API server startup
     with get_session_context_manager() as db_session:
@@ -127,6 +144,7 @@ def reset_postgres(database: str = "postgres") -> None:
+
 def reset_vespa() -> None:
     """Wipe all data from the Vespa index."""
 
     with get_session_context_manager() as db_session:
         # swap to the correct default model
         check_index_swap(db_session)
@@ -166,10 +184,98 @@ def reset_vespa() -> None:
             time.sleep(5)
 
 
+def reset_postgres_multitenant() -> None:
+    """Reset the Postgres database for all tenants in a multitenant setup."""
+
+    conn = psycopg2.connect(
+        dbname="postgres",
+        user=POSTGRES_USER,
+        password=POSTGRES_PASSWORD,
+        host=POSTGRES_HOST,
+        port=POSTGRES_PORT,
+    )
+    conn.autocommit = True
+    cur = conn.cursor()
+
+    # Get all tenant schemas
+    cur.execute(
+        """
+        SELECT schema_name
+        FROM information_schema.schemata
+        WHERE schema_name LIKE 'tenant_%'
+        """
+    )
+    tenant_schemas = cur.fetchall()
+
+    # Drop all tenant schemas
+    for schema in tenant_schemas:
+        schema_name = schema[0]
+        cur.execute(f'DROP SCHEMA "{schema_name}" CASCADE')
+
+    cur.close()
+    conn.close()
+
+    reset_postgres(config_name="schema_private", setup_danswer=False)
+
+
+def reset_vespa_multitenant() -> None:
+    """Wipe all data from the Vespa index for all tenants."""
+
+    for tenant_id in get_all_tenant_ids():
+        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+            # swap to the correct default model for each tenant
+            check_index_swap(db_session)
+
+            search_settings = get_current_search_settings(db_session)
+            index_name = search_settings.index_name
+
+        success = setup_vespa(
+            document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
+            index_setting=IndexingSetting.from_db_model(search_settings),
+            secondary_index_setting=None,
+        )
+
+        if not success:
+            raise RuntimeError(
+                f"Could not connect to Vespa for tenant {tenant_id} within the specified timeout."
+            )
+
+        for _ in range(5):
+            try:
+                continuation = None
+                should_continue = True
+                while should_continue:
+                    params = {"selection": "true", "cluster": "danswer_index"}
+                    if continuation:
+                        params = {**params, "continuation": continuation}
+                    response = requests.delete(
+                        DOCUMENT_ID_ENDPOINT.format(index_name=index_name),
+                        params=params,
+                    )
+                    response.raise_for_status()
+
+                    response_json = response.json()
+
+                    continuation = response_json.get("continuation")
+                    should_continue = bool(continuation)
+
+                break
+            except Exception as e:
+                print(f"Error deleting documents for tenant {tenant_id}: {e}")
+                time.sleep(5)
+
+
 def reset_all() -> None:
-    """Reset both Postgres and Vespa."""
     logger.info("Resetting Postgres...")
     reset_postgres()
     logger.info("Resetting Vespa...")
     reset_vespa()
+
+
+def reset_all_multitenant() -> None:
+    """Reset both Postgres and Vespa for all tenants."""
+    logger.info("Resetting Postgres for all tenants...")
+    reset_postgres_multitenant()
+    logger.info("Resetting Vespa for all tenants...")
+    reset_vespa_multitenant()
     logger.info("Finished resetting all.")
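Aside: the `alembic_cfg.cmd_opts.x = [f"schema={schema}"]` line above only has an effect if the Alembic env.py reads the x-argument back. Danswer's actual env.py is not part of this diff, so the excerpt below is a hedged sketch of what that end of the contract typically looks like; the helper name and defaults are assumptions.

# Hypothetical excerpt of an Alembic env.py that honours `-x schema=<name>`,
# which is what setting alembic_cfg.cmd_opts.x emulates programmatically.
from alembic import context


def get_schema_name(default: str = "public") -> str:
    # get_x_argument(as_dictionary=True) parses ["schema=tenant_abc"]
    # into {"schema": "tenant_abc"}.
    x_args = context.get_x_argument(as_dictionary=True)
    return x_args.get("schema", default)


# The schema is then typically applied when configuring the migration context, e.g.
# context.configure(connection=connection, version_table_schema=get_schema_name(),
#                   include_schemas=True)
# so each tenant schema gets its own alembic_version table.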
@@ -8,6 +8,7 @@ from danswer.db.engine import get_session_context_manager
 from danswer.db.search_settings import get_current_search_settings
 from tests.integration.common_utils.managers.user import UserManager
 from tests.integration.common_utils.reset import reset_all
+from tests.integration.common_utils.reset import reset_all_multitenant
 from tests.integration.common_utils.test_models import DATestUser
 from tests.integration.common_utils.vespa import vespa_fixture
 
@@ -54,3 +55,8 @@ def new_admin_user(reset: None) -> DATestUser | None:
         return UserManager.create(name="admin_user")
     except Exception:
         return None
+
+
+@pytest.fixture
+def reset_multitenant() -> None:
+    reset_all_multitenant()

backend/tests/integration/multitenant_tests/cc_Pair (new file, empty)
@ -0,0 +1,150 @@
from danswer.db.models import UserRole
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.tenant import TenantManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestChatSession
from tests.integration.common_utils.test_models import DATestUser


def test_multi_tenant_access_control(reset_multitenant: None) -> None:
    # Create Tenant 1 and its Admin User
    TenantManager.create("tenant_dev1", "test1@test.com")
    test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com")
    assert UserManager.verify_role(test_user1, UserRole.ADMIN)

    # Create Tenant 2 and its Admin User
    TenantManager.create("tenant_dev2", "test2@test.com")
    test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com")
    assert UserManager.verify_role(test_user2, UserRole.ADMIN)

    # Create connectors for Tenant 1
    cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
        user_performing_action=test_user1,
    )
    api_key_1: DATestAPIKey = APIKeyManager.create(
        user_performing_action=test_user1,
    )
    api_key_1.headers.update(test_user1.headers)
    LLMProviderManager.create(user_performing_action=test_user1)

    # Seed documents for Tenant 1
    cc_pair_1.documents = []
    doc1_tenant1 = DocumentManager.seed_doc_with_content(
        cc_pair=cc_pair_1,
        content="Tenant 1 Document Content",
        api_key=api_key_1,
    )
    doc2_tenant1 = DocumentManager.seed_doc_with_content(
        cc_pair=cc_pair_1,
        content="Tenant 1 Document Content",
        api_key=api_key_1,
    )
    cc_pair_1.documents.extend([doc1_tenant1, doc2_tenant1])

    # Create connectors for Tenant 2
    cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
        user_performing_action=test_user2,
    )
    api_key_2: DATestAPIKey = APIKeyManager.create(
        user_performing_action=test_user2,
    )
    api_key_2.headers.update(test_user2.headers)
    LLMProviderManager.create(user_performing_action=test_user2)

    # Seed documents for Tenant 2
    cc_pair_2.documents = []
    doc1_tenant2 = DocumentManager.seed_doc_with_content(
        cc_pair=cc_pair_2,
        content="Tenant 2 Document Content",
        api_key=api_key_2,
    )
    doc2_tenant2 = DocumentManager.seed_doc_with_content(
        cc_pair=cc_pair_2,
        content="Tenant 2 Document Content",
        api_key=api_key_2,
    )
    cc_pair_2.documents.extend([doc1_tenant2, doc2_tenant2])

    tenant1_doc_ids = {doc1_tenant1.id, doc2_tenant1.id}
    tenant2_doc_ids = {doc1_tenant2.id, doc2_tenant2.id}

    # Create chat sessions for each user
    chat_session1: DATestChatSession = ChatSessionManager.create(
        user_performing_action=test_user1
    )
    chat_session2: DATestChatSession = ChatSessionManager.create(
        user_performing_action=test_user2
    )

    # User 1 sends a message and gets a response
    response1 = ChatSessionManager.send_message(
        chat_session_id=chat_session1.id,
        message="What is in Tenant 1's documents?",
        user_performing_action=test_user1,
    )
    # Assert that the search tool was used
    assert response1.tool_name == "run_search"

    response_doc_ids = {doc["document_id"] for doc in response1.tool_result or []}
    assert tenant1_doc_ids.issubset(
        response_doc_ids
    ), "Not all Tenant 1 document IDs are in the response"
    assert not response_doc_ids.intersection(
        tenant2_doc_ids
    ), "Tenant 2 document IDs should not be in the response"

    # Assert that the contents are correct
    for doc in response1.tool_result or []:
        assert doc["content"] == "Tenant 1 Document Content"

    # User 2 sends a message and gets a response
    response2 = ChatSessionManager.send_message(
        chat_session_id=chat_session2.id,
        message="What is in Tenant 2's documents?",
        user_performing_action=test_user2,
    )
    # Assert that the search tool was used
    assert response2.tool_name == "run_search"
    # Assert that the tool_result contains Tenant 2's documents
    response_doc_ids = {doc["document_id"] for doc in response2.tool_result or []}
    assert tenant2_doc_ids.issubset(
        response_doc_ids
    ), "Not all Tenant 2 document IDs are in the response"
    assert not response_doc_ids.intersection(
        tenant1_doc_ids
    ), "Tenant 1 document IDs should not be in the response"
    # Assert that the contents are correct
    for doc in response2.tool_result or []:
        assert doc["content"] == "Tenant 2 Document Content"

    # User 1 tries to access Tenant 2's documents
    response_cross = ChatSessionManager.send_message(
        chat_session_id=chat_session1.id,
        message="What is in Tenant 2's documents?",
        user_performing_action=test_user1,
    )
    # Assert that the search tool was used
    assert response_cross.tool_name == "run_search"
    # Assert that the tool_result is empty or does not contain Tenant 2's documents
    response_doc_ids = {doc["document_id"] for doc in response_cross.tool_result or []}
    # Ensure none of Tenant 2's document IDs are in the response
    assert not response_doc_ids.intersection(tenant2_doc_ids)

    # User 2 tries to access Tenant 1's documents
    response_cross2 = ChatSessionManager.send_message(
        chat_session_id=chat_session2.id,
        message="What is in Tenant 1's documents?",
        user_performing_action=test_user2,
    )
    # Assert that the search tool was used
    assert response_cross2.tool_name == "run_search"
    # Assert that the tool_result is empty or does not contain Tenant 1's documents
    response_doc_ids = {doc["document_id"] for doc in response_cross2.tool_result or []}
    # Ensure none of Tenant 1's document IDs are in the response
    assert not response_doc_ids.intersection(tenant1_doc_ids)
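The same pair of assertions appears four times in the test above. If it were ever factored out, a helper along these lines would capture the isolation check; this is a sketch only, not part of this commit, and it assumes the response shape used above (tool_result as a list of dicts keyed by "document_id").

from typing import Any


def assert_tenant_isolation(
    tool_result: list[dict[str, Any]] | None,
    expected_doc_ids: set[str],
    forbidden_doc_ids: set[str],
) -> None:
    # Collect the document IDs the search tool actually returned.
    returned_ids = {doc["document_id"] for doc in tool_result or []}
    # Every document belonging to the requesting tenant should be present...
    assert expected_doc_ids.issubset(returned_ids), "Missing documents for this tenant"
    # ...and nothing from the other tenant should leak across.
    assert not returned_ids.intersection(forbidden_doc_ids), "Cross-tenant leak detected"

Usage would mirror the inline checks, e.g. assert_tenant_isolation(response1.tool_result, tenant1_doc_ids, tenant2_doc_ids).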
@ -0,0 +1,41 @@
from danswer.configs.constants import DocumentSource
from danswer.db.enums import AccessType
from danswer.db.models import UserRole
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.tenant import TenantManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser


# Test flow from creating tenant to registering as a user
def test_tenant_creation(reset_multitenant: None) -> None:
    TenantManager.create("tenant_dev", "test@test.com")
    test_user: DATestUser = UserManager.create(name="test", email="test@test.com")

    assert UserManager.verify_role(test_user, UserRole.ADMIN)

    test_credential = CredentialManager.create(
        name="admin_test_credential",
        source=DocumentSource.FILE,
        curator_public=False,
        user_performing_action=test_user,
    )

    test_connector = ConnectorManager.create(
        name="admin_test_connector",
        source=DocumentSource.FILE,
        is_public=False,
        user_performing_action=test_user,
    )

    test_cc_pair = CCPairManager.create(
        connector_id=test_connector.id,
        credential_id=test_credential.id,
        name="admin_test_cc_pair",
        access_type=AccessType.PRIVATE,
        user_performing_action=test_user,
    )

    CCPairManager.verify(cc_pair=test_cc_pair, user_performing_action=test_user)
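For local iteration, a minimal sketch of running just the new suite; the directory path mirrors the files added above, and it assumes the backend services the tests talk to are already running and configured for multi-tenancy.

import sys

import pytest

# Run only the multi-tenant integration tests added in this commit.
sys.exit(pytest.main(["backend/tests/integration/multitenant_tests", "-v"]))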
@ -119,6 +119,7 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
    )
    assert response.status_code == 200
    response_json = response.json()

    # get the db_doc_id of the top document to use as a search doc id for second message
    first_db_doc_id = response_json["top_documents"][0]["db_doc_id"]