diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index c49908b1d5..8287f9b530 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -26,4 +26,4 @@ N/A ## Backporting (check the box to trigger backport action) Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches. -[ ] This PR should be backported (make sure to check that the backport attempt succeeds) +- [ ] This PR should be backported (make sure to check that the backport attempt succeeds) diff --git a/.github/workflows/pr-backport-autotrigger.yml b/.github/workflows/pr-backport-autotrigger.yml index 4ba5136a82..2d49c39402 100644 --- a/.github/workflows/pr-backport-autotrigger.yml +++ b/.github/workflows/pr-backport-autotrigger.yml @@ -10,43 +10,83 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 + with: + fetch-depth: 0 # Fetch all history for all branches and tags + + - name: Set up Git + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" - name: Check for Backport Checkbox id: checkbox-check run: | PR_BODY="${{ github.event.pull_request.body }}" if [[ "$PR_BODY" == *"[x] This PR should be backported"* ]]; then - echo "::set-output name=backport::true" + echo "backport=true" >> $GITHUB_OUTPUT else - echo "::set-output name=backport::false" + echo "backport=false" >> $GITHUB_OUTPUT fi - name: List and sort release branches id: list-branches run: | - git fetch --all - BRANCHES=$(git branch -r | grep 'origin/release/v' | sed 's|origin/release/v||' | sort -Vr) + git fetch --all --tags + BRANCHES=$(git for-each-ref --format='%(refname:short)' refs/remotes/origin/release/* | sed 's|origin/release/||' | sort -Vr) BETA=$(echo "$BRANCHES" | head -n 1) STABLE=$(echo "$BRANCHES" | head -n 2 | tail -n 1) - echo "::set-output name=beta::$BETA" - echo "::set-output name=stable::$STABLE" + echo "beta=$BETA" >> $GITHUB_OUTPUT + echo "stable=$STABLE" >> $GITHUB_OUTPUT + # Fetch latest tags for beta and stable + LATEST_BETA_TAG=$(git tag -l "v*.*.0-beta.*" | sort -Vr | head -n 1) + LATEST_STABLE_TAG=$(git tag -l "v*.*.*" | grep -v -- "-beta" | sort -Vr | head -n 1) + # Increment latest beta tag + NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 ".0-beta." ($NF+1)}') + # Increment latest stable tag + NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." 
($3+1)}') + echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT + echo "latest_stable_tag=$LATEST_STABLE_TAG" >> $GITHUB_OUTPUT + echo "new_beta_tag=$NEW_BETA_TAG" >> $GITHUB_OUTPUT + echo "new_stable_tag=$NEW_STABLE_TAG" >> $GITHUB_OUTPUT + + - name: Echo branch and tag information + run: | + echo "Beta branch: ${{ steps.list-branches.outputs.beta }}" + echo "Stable branch: ${{ steps.list-branches.outputs.stable }}" + echo "Latest beta tag: ${{ steps.list-branches.outputs.latest_beta_tag }}" + echo "Latest stable tag: ${{ steps.list-branches.outputs.latest_stable_tag }}" + echo "New beta tag: ${{ steps.list-branches.outputs.new_beta_tag }}" + echo "New stable tag: ${{ steps.list-branches.outputs.new_stable_tag }}" - name: Trigger Backport if: steps.checkbox-check.outputs.backport == 'true' run: | + set -e echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}" # Fetch all history for all branches and tags git fetch --prune --unshallow # Checkout the beta branch git checkout ${{ steps.list-branches.outputs.beta }} - # Cherry-pick the last commit from the merged PR - git cherry-pick ${{ github.event.pull_request.merge_commit_sha }} - # Push the changes to the beta branch + # Cherry-pick the merge commit from the merged PR + git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || { + echo "Cherry-pick to beta failed due to conflicts." + exit 1 + } + # Create new beta tag + git tag ${{ steps.list-branches.outputs.new_beta_tag }} + # Push the changes and tag to the beta branch git push origin ${{ steps.list-branches.outputs.beta }} + git push origin ${{ steps.list-branches.outputs.new_beta_tag }} # Checkout the stable branch git checkout ${{ steps.list-branches.outputs.stable }} - # Cherry-pick the last commit from the merged PR - git cherry-pick ${{ github.event.pull_request.merge_commit_sha }} - # Push the changes to the stable branch + # Cherry-pick the merge commit from the merged PR + git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || { + echo "Cherry-pick to stable failed due to conflicts." 
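For readers skimming the workflow: the two awk one-liners above encode a simple bump rule, the beta tag increments its trailing `-beta.N` counter and the stable tag increments its patch number. A rough Python equivalent of that rule (an illustrative sketch only; the `v0.5.x` tag shapes are assumptions, not taken from the repo):

```python
import re


def next_beta_tag(latest: str) -> str:
    """v0.5.0-beta.3 -> v0.5.0-beta.4, mirroring the beta awk one-liner."""
    m = re.fullmatch(r"(v\d+)\.(\d+)\.0-beta\.(\d+)", latest)
    if not m:
        raise ValueError(f"unexpected beta tag: {latest}")
    major, minor, n = m.groups()
    return f"{major}.{minor}.0-beta.{int(n) + 1}"


def next_stable_tag(latest: str) -> str:
    """v0.5.12 -> v0.5.13, mirroring the stable awk one-liner."""
    m = re.fullmatch(r"(v\d+)\.(\d+)\.(\d+)", latest)
    if not m:
        raise ValueError(f"unexpected stable tag: {latest}")
    major, minor, patch = m.groups()
    return f"{major}.{minor}.{int(patch) + 1}"


assert next_beta_tag("v0.5.0-beta.3") == "v0.5.0-beta.4"
assert next_stable_tag("v0.5.12") == "v0.5.13"
```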
+ exit 1 + } + # Create new stable tag + git tag ${{ steps.list-branches.outputs.new_stable_tag }} + # Push the changes and tag to the stable branch git push origin ${{ steps.list-branches.outputs.stable }} + git push origin ${{ steps.list-branches.outputs.new_stable_tag }} diff --git a/backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py b/backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py index a6938e365c..db7b330c3e 100644 --- a/backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py +++ b/backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py @@ -31,6 +31,12 @@ def upgrade() -> None: def downgrade() -> None: + # First, update any null values to a default value + op.execute( + "UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL" + ) + + # Then, make the column non-nullable op.alter_column( "connector_credential_pair", "last_attempt_status", diff --git a/backend/danswer/background/celery/apps/primary.py b/backend/danswer/background/celery/apps/primary.py index 23e25fa92b..983b76773e 100644 --- a/backend/danswer/background/celery/apps/primary.py +++ b/backend/danswer/background/celery/apps/primary.py @@ -17,6 +17,7 @@ from danswer.background.celery.celery_redis import RedisConnectorCredentialPair from danswer.background.celery.celery_redis import RedisConnectorDeletion from danswer.background.celery.celery_redis import RedisConnectorIndexing from danswer.background.celery.celery_redis import RedisConnectorPruning +from danswer.background.celery.celery_redis import RedisConnectorStop from danswer.background.celery.celery_redis import RedisDocumentSet from danswer.background.celery.celery_redis import RedisUserGroup from danswer.background.celery.celery_utils import celery_is_worker_primary @@ -161,6 +162,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None: for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"): r.delete(key) + for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"): + r.delete(key) + # @worker_process_init.connect # def on_worker_process_init(sender: Any, **kwargs: Any) -> None: diff --git a/backend/danswer/background/celery/celery_redis.py b/backend/danswer/background/celery/celery_redis.py index 1ea5e3b176..e412b0bf73 100644 --- a/backend/danswer/background/celery/celery_redis.py +++ b/backend/danswer/background/celery/celery_redis.py @@ -313,6 +313,8 @@ class RedisConnectorDeletion(RedisObjectHelper): lock: redis.lock.Lock, tenant_id: str | None, ) -> int | None: + """Returns None if the cc_pair doesn't exist. + Otherwise, returns an int with the number of generated tasks.""" last_lock_time = time.monotonic() async_results = [] @@ -540,6 +542,29 @@ class RedisConnectorIndexing(RedisObjectHelper): return False +class RedisConnectorStop(RedisObjectHelper): + """Used to signal any running tasks for a connector to stop. We should refactor + connector related redis helpers into a single class. 
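`RedisConnectorStop`, whose definition continues just below, is deliberately thin: it is nothing more than a per-cc_pair fence key. The beat task plants the key when deletion has to wait on indexing or pruning, long-running tasks poll it and bail out early, and the key is deleted once it is no longer needed (or swept on primary worker startup, as in `on_worker_init` above). A minimal sketch of the signaling pattern (key layout simplified; the real keys are assembled by `RedisObjectHelper`):

```python
import redis

r = redis.Redis()  # assumes a local Redis instance


def request_stop(cc_pair_id: int) -> None:
    # plant the fence; cooperating tasks poll for it and exit early
    r.set(f"connectorstop_fence_{cc_pair_id}", cc_pair_id)


def should_stop(cc_pair_id: int) -> bool:
    return bool(r.exists(f"connectorstop_fence_{cc_pair_id}"))


def clear_stop(cc_pair_id: int) -> None:
    # once the dependent tasks have wound down, the signal is no longer needed
    r.delete(f"connectorstop_fence_{cc_pair_id}")
```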
+ """ + + PREFIX = "connectorstop" + FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process + TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's + + def __init__(self, id: int) -> None: + super().__init__(str(id)) + + def generate_tasks( + self, + celery_app: Celery, + db_session: Session, + redis_client: Redis, + lock: redis.lock.Lock | None, + tenant_id: str | None, + ) -> int | None: + return None + + def celery_get_queue_length(queue: str, r: Redis) -> int: """This is a redis specific way to get the length of a celery queue. It is priority aware and knows how to count across the multiple redis lists diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py index b1e9c2113e..18038a349d 100644 --- a/backend/danswer/background/celery/celery_utils.py +++ b/backend/danswer/background/celery/celery_utils.py @@ -1,4 +1,3 @@ -from collections.abc import Callable from datetime import datetime from datetime import timezone from typing import Any @@ -6,6 +5,7 @@ from typing import Any from sqlalchemy.orm import Session from danswer.background.celery.celery_redis import RedisConnectorDeletion +from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, @@ -79,7 +79,7 @@ def document_batch_to_ids( def extract_ids_from_runnable_connector( runnable_connector: BaseConnector, - progress_callback: Callable[[int], None] | None = None, + callback: RunIndexingCallbackInterface | None = None, ) -> set[str]: """ If the PruneConnector hasnt been implemented for the given connector, just pull @@ -110,8 +110,10 @@ def extract_ids_from_runnable_connector( max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60 )(document_batch_to_ids) for doc_batch in doc_batch_generator: - if progress_callback: - progress_callback(len(doc_batch)) + if callback: + if callback.should_stop(): + raise RuntimeError("Stop signal received") + callback.progress(len(doc_batch)) all_connector_doc_ids.update(doc_batch_processing_func(doc_batch)) return all_connector_doc_ids diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py index f6a59d03e3..59d236cde3 100644 --- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py @@ -1,3 +1,6 @@ +from datetime import datetime +from datetime import timezone + import redis from celery import Celery from celery import shared_task @@ -8,6 +11,12 @@ from sqlalchemy.orm import Session from danswer.background.celery.apps.app_base import task_logger from danswer.background.celery.celery_redis import RedisConnectorDeletion +from danswer.background.celery.celery_redis import RedisConnectorIndexing +from danswer.background.celery.celery_redis import RedisConnectorPruning +from danswer.background.celery.celery_redis import RedisConnectorStop +from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import ( + RedisConnectorDeletionFenceData, +) from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerRedisLocks @@ -15,9 +24,15 @@ from danswer.db.connector_credential_pair import 
get_connector_credential_pair_f from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.engine import get_session_with_tenant from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.db.search_settings import get_all_search_settings from danswer.redis.redis_pool import get_redis_client +class TaskDependencyError(RuntimeError): + """Raised to the caller to indicate dependent tasks are running that would interfere + with connector deletion.""" + + @shared_task( name="check_for_connector_deletion_task", soft_time_limit=JOB_TIMEOUT, @@ -37,17 +52,30 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N if not lock_beat.acquire(blocking=False): return + # collect cc_pair_ids cc_pair_ids: list[int] = [] with get_session_with_tenant(tenant_id) as db_session: cc_pairs = get_connector_credential_pairs(db_session) for cc_pair in cc_pairs: cc_pair_ids.append(cc_pair.id) + # try running cleanup on the cc_pair_ids for cc_pair_id in cc_pair_ids: with get_session_with_tenant(tenant_id) as db_session: - try_generate_document_cc_pair_cleanup_tasks( - self.app, cc_pair_id, db_session, r, lock_beat, tenant_id - ) + rcs = RedisConnectorStop(cc_pair_id) + try: + try_generate_document_cc_pair_cleanup_tasks( + self.app, cc_pair_id, db_session, r, lock_beat, tenant_id + ) + except TaskDependencyError as e: + # this means we wanted to start deleting but dependent tasks were running + # Leave a stop signal to clear indexing and pruning tasks more quickly + task_logger.info(str(e)) + r.set(rcs.fence_key, cc_pair_id) + else: + # clear the stop signal if it exists ... no longer needed + r.delete(rcs.fence_key) + except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." @@ -70,6 +98,10 @@ def try_generate_document_cc_pair_cleanup_tasks( """Returns an int if syncing is needed. The int represents the number of sync tasks generated. Note that syncing can still be required even if the number of sync tasks generated is zero. Returns None if no syncing is required. + + Will raise TaskDependencyError if dependent tasks such as indexing and pruning are + still running. In our case, the caller reacts by setting a stop signal in Redis to + exit those tasks as quickly as possible. """ lock_beat.reacquire() @@ -90,28 +122,63 @@ def try_generate_document_cc_pair_cleanup_tasks( if cc_pair.status != ConnectorCredentialPairStatus.DELETING: return None - # add tasks to celery and build up the task set to monitor in redis - r.delete(rcd.taskset_key) - - # Add all documents that need to be updated into the queue - task_logger.info( - f"RedisConnectorDeletion.generate_tasks starting. 
cc_pair_id={cc_pair.id}" + # set a basic fence to start + fence_value = RedisConnectorDeletionFenceData( + num_tasks=None, + submitted=datetime.now(timezone.utc), ) - tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id) - if tasks_generated is None: + r.set(rcd.fence_key, fence_value.model_dump_json()) + + try: + # do not proceed if connector indexing or connector pruning are running + search_settings_list = get_all_search_settings(db_session) + for search_settings in search_settings_list: + rci = RedisConnectorIndexing(cc_pair_id, search_settings.id) + if r.get(rci.fence_key): + raise TaskDependencyError( + f"Connector deletion - Delayed (indexing in progress): " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings.id}" + ) + + rcp = RedisConnectorPruning(cc_pair_id) + if r.get(rcp.fence_key): + raise TaskDependencyError( + f"Connector deletion - Delayed (pruning in progress): " + f"cc_pair={cc_pair_id}" + ) + + # add tasks to celery and build up the task set to monitor in redis + r.delete(rcd.taskset_key) + + # Add all documents that need to be updated into the queue + task_logger.info( + f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}" + ) + tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id) + if tasks_generated is None: + raise ValueError("RedisConnectorDeletion.generate_tasks returned None") + except TaskDependencyError: + r.delete(rcd.fence_key) + raise + except Exception: + task_logger.exception("Unexpected exception") + r.delete(rcd.fence_key) return None + else: + # Currently we are allowing the sync to proceed with 0 tasks. + # It's possible for sets/groups to be generated initially with no entries + # and they still need to be marked as up to date. + # if tasks_generated == 0: + # return 0 - # Currently we are allowing the sync to proceed with 0 tasks. - # It's possible for sets/groups to be generated initially with no entries - # and they still need to be marked as up to date. - # if tasks_generated == 0: - # return 0 + task_logger.info( + f"RedisConnectorDeletion.generate_tasks finished. " + f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}" + ) - task_logger.info( - f"RedisConnectorDeletion.generate_tasks finished. 
" - f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}" - ) + # set this only after all tasks have been added + fence_value.num_tasks = tasks_generated + r.set(rcd.fence_key, fence_value.model_dump_json()) - # set this only after all tasks have been added - r.set(rcd.fence_key, tasks_generated) return tasks_generated diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py index bdd55f77f3..980266ec87 100644 --- a/backend/danswer/background/celery/tasks/indexing/tasks.py +++ b/backend/danswer/background/celery/tasks/indexing/tasks.py @@ -5,6 +5,7 @@ from time import sleep from typing import cast from uuid import uuid4 +import redis from celery import Celery from celery import shared_task from celery import Task @@ -13,12 +14,15 @@ from redis import Redis from sqlalchemy.orm import Session from danswer.background.celery.apps.app_base import task_logger +from danswer.background.celery.celery_redis import RedisConnectorDeletion from danswer.background.celery.celery_redis import RedisConnectorIndexing +from danswer.background.celery.celery_redis import RedisConnectorStop from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import ( RedisConnectorIndexingFenceData, ) from danswer.background.indexing.job_client import SimpleJobClient from danswer.background.indexing.run_indexing import run_indexing_entrypoint +from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT @@ -50,6 +54,30 @@ from danswer.utils.variable_functionality import global_version logger = setup_logger() +class RunIndexingCallback(RunIndexingCallbackInterface): + def __init__( + self, + stop_key: str, + generator_progress_key: str, + redis_lock: redis.lock.Lock, + redis_client: Redis, + ): + super().__init__() + self.redis_lock: redis.lock.Lock = redis_lock + self.stop_key: str = stop_key + self.generator_progress_key: str = generator_progress_key + self.redis_client = redis_client + + def should_stop(self) -> bool: + if self.redis_client.exists(self.stop_key): + return True + return False + + def progress(self, amount: int) -> None: + self.redis_lock.reacquire() + self.redis_client.incrby(self.generator_progress_key, amount) + + @shared_task( name="check_for_indexing", soft_time_limit=300, @@ -262,6 +290,10 @@ def try_creating_indexing_task( return None # skip indexing if the cc_pair is deleting + rcd = RedisConnectorDeletion(cc_pair.id) + if r.exists(rcd.fence_key): + return None + db_session.refresh(cc_pair) if cc_pair.status == ConnectorCredentialPairStatus.DELETING: return None @@ -308,13 +340,8 @@ def try_creating_indexing_task( raise RuntimeError("send_task for connector_indexing_proxy_task failed.") # now fill out the fence with the rest of the data - fence_value = RedisConnectorIndexingFenceData( - index_attempt_id=index_attempt_id, - started=None, - submitted=datetime.now(timezone.utc), - celery_task_id=result.id, - ) - + fence_value.index_attempt_id = index_attempt_id + fence_value.celery_task_id = result.id r.set(rci.fence_key, fence_value.model_dump_json()) except Exception: r.delete(rci.fence_key) @@ -400,6 +427,22 @@ def connector_indexing_task( r = get_redis_client(tenant_id=tenant_id) + rcd = RedisConnectorDeletion(cc_pair_id) + if r.exists(rcd.fence_key): + raise RuntimeError( 
+ f"Indexing will not start because connector deletion is in progress: " + f"cc_pair={cc_pair_id} " + f"fence={rcd.fence_key}" + ) + + rcs = RedisConnectorStop(cc_pair_id) + if r.exists(rcs.fence_key): + raise RuntimeError( + f"Indexing will not start because a connector stop signal was detected: " + f"cc_pair={cc_pair_id} " + f"fence={rcs.fence_key}" + ) + rci = RedisConnectorIndexing(cc_pair_id, search_settings_id) while True: @@ -409,7 +452,7 @@ def connector_indexing_task( task_logger.info( f"connector_indexing_task: fence_value not found: fence={rci.fence_key}" ) - raise + raise RuntimeError(f"Fence not found: fence={rci.fence_key}") try: fence_json = fence_value.decode("utf-8") @@ -443,17 +486,20 @@ def connector_indexing_task( if not acquired: task_logger.warning( f"Indexing task already running, exiting...: " - f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}" + f"cc_pair={cc_pair_id} search_settings={search_settings_id}" ) # r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value) return None + fence_data.started = datetime.now(timezone.utc) + r.set(rci.fence_key, fence_data.model_dump_json()) + try: with get_session_with_tenant(tenant_id) as db_session: attempt = get_index_attempt(db_session, index_attempt_id) if not attempt: raise ValueError( - f"Index attempt not found: index_attempt_id={index_attempt_id}" + f"Index attempt not found: index_attempt={index_attempt_id}" ) cc_pair = get_connector_credential_pair_from_id( @@ -462,31 +508,31 @@ def connector_indexing_task( ) if not cc_pair: - raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}") + raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}") if not cc_pair.connector: raise ValueError( - f"Connector not found: connector_id={cc_pair.connector_id}" + f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}" ) if not cc_pair.credential: raise ValueError( - f"Credential not found: credential_id={cc_pair.credential_id}" + f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}" ) rci = RedisConnectorIndexing(cc_pair_id, search_settings_id) - # Define the callback function - def redis_increment_callback(amount: int) -> None: - lock.reacquire() - r.incrby(rci.generator_progress_key, amount) + # define a callback class + callback = RunIndexingCallback( + rcs.fence_key, rci.generator_progress_key, lock, r + ) run_indexing_entrypoint( index_attempt_id, tenant_id, cc_pair_id, is_ee, - progress_callback=redis_increment_callback, + callback=callback, ) # get back the total number of indexed docs and return it @@ -499,9 +545,10 @@ def connector_indexing_task( r.set(rci.generator_complete_key, HTTPStatus.OK.value) except Exception as e: - task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.") + task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}") if attempt: - mark_attempt_failed(attempt, db_session, failure_reason=str(e)) + with get_session_with_tenant(tenant_id) as db_session: + mark_attempt_failed(attempt, db_session, failure_reason=str(e)) r.delete(rci.generator_lock_key) r.delete(rci.generator_progress_key) diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py index 9f290d6f23..2e68986e83 100644 --- a/backend/danswer/background/celery/tasks/pruning/tasks.py +++ b/backend/danswer/background/celery/tasks/pruning/tasks.py @@ -11,8 +11,11 @@ from redis import Redis from sqlalchemy.orm import Session from danswer.background.celery.apps.app_base import 
task_logger +from danswer.background.celery.celery_redis import RedisConnectorDeletion from danswer.background.celery.celery_redis import RedisConnectorPruning +from danswer.background.celery.celery_redis import RedisConnectorStop from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector +from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT @@ -168,6 +171,10 @@ def try_creating_prune_generator_task( return None # skip pruning if the cc_pair is deleting + rcd = RedisConnectorDeletion(cc_pair.id) + if r.exists(rcd.fence_key): + return None + db_session.refresh(cc_pair) if cc_pair.status == ConnectorCredentialPairStatus.DELETING: return None @@ -234,7 +241,7 @@ def connector_pruning_generator_task( acquired = lock.acquire(blocking=False) if not acquired: task_logger.warning( - f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}" + f"Pruning task already running, exiting...: cc_pair={cc_pair_id}" ) return None @@ -252,11 +259,6 @@ def connector_pruning_generator_task( ) return - # Define the callback function - def redis_increment_callback(amount: int) -> None: - lock.reacquire() - r.incrby(rcp.generator_progress_key, amount) - runnable_connector = instantiate_connector( db_session, cc_pair.connector.source, @@ -265,9 +267,14 @@ def connector_pruning_generator_task( cc_pair.credential, ) + rcs = RedisConnectorStop(cc_pair_id) + + callback = RunIndexingCallback( + rcs.fence_key, rcp.generator_progress_key, lock, r + ) # a list of docs in the source all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector( - runnable_connector, redis_increment_callback + runnable_connector, callback ) # a list of docs in our local index @@ -285,7 +292,7 @@ def connector_pruning_generator_task( task_logger.info( f"Pruning set collected: " - f"cc_pair_id={cc_pair.id} " + f"cc_pair={cc_pair_id} " f"docs_to_remove={len(doc_ids_to_remove)} " f"doc_source={cc_pair.connector.source}" ) @@ -293,7 +300,7 @@ def connector_pruning_generator_task( rcp.documents_to_prune = set(doc_ids_to_remove) task_logger.info( - f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}" + f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair_id}" ) tasks_generated = rcp.generate_tasks( self.app, db_session, r, None, tenant_id @@ -303,12 +310,14 @@ def connector_pruning_generator_task( task_logger.info( f"RedisConnectorPruning.generate_tasks finished. 
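The pruning computation above reduces to a plain set difference: anything the live source no longer returns but the local index still holds is queued for removal. In miniature:

```python
# ids currently retrievable from the source (via extract_ids_from_runnable_connector)
source_doc_ids = {"doc-a", "doc-b", "doc-c"}

# ids already stored in our local index for this cc_pair
local_doc_ids = {"doc-a", "doc-b", "doc-c", "doc-stale"}

doc_ids_to_remove = local_doc_ids - source_doc_ids
assert doc_ids_to_remove == {"doc-stale"}
```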
" - f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}" + f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}" ) r.set(rcp.generator_complete_key, tasks_generated) except Exception as e: - task_logger.exception(f"Failed to run pruning for connector id {connector_id}.") + task_logger.exception( + f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}" + ) r.delete(rcp.generator_progress_key) r.delete(rcp.taskset_key) diff --git a/backend/danswer/background/celery/tasks/shared/RedisConnectorDeletionFenceData.py b/backend/danswer/background/celery/tasks/shared/RedisConnectorDeletionFenceData.py new file mode 100644 index 0000000000..1c664d14b4 --- /dev/null +++ b/backend/danswer/background/celery/tasks/shared/RedisConnectorDeletionFenceData.py @@ -0,0 +1,8 @@ +from datetime import datetime + +from pydantic import BaseModel + + +class RedisConnectorDeletionFenceData(BaseModel): + num_tasks: int | None + submitted: datetime diff --git a/backend/danswer/background/celery/tasks/shared/tasks.py b/backend/danswer/background/celery/tasks/shared/tasks.py index 7ce43454aa..52a49d467e 100644 --- a/backend/danswer/background/celery/tasks/shared/tasks.py +++ b/backend/danswer/background/celery/tasks/shared/tasks.py @@ -1,9 +1,6 @@ -from datetime import datetime - from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded -from pydantic import BaseModel from danswer.access.access import get_access_for_document from danswer.background.celery.apps.app_base import task_logger @@ -23,13 +20,6 @@ from danswer.server.documents.models import ConnectorCredentialPairIdentifier DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3 -class RedisConnectorIndexingFenceData(BaseModel): - index_attempt_id: int | None - started: datetime | None - submitted: datetime - celery_task_id: str | None - - @shared_task( name="document_by_cc_pair_cleanup_task", bind=True, diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index 812074b91e..fcc4d2aa5b 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -23,6 +23,9 @@ from danswer.background.celery.celery_redis import RedisConnectorIndexing from danswer.background.celery.celery_redis import RedisConnectorPruning from danswer.background.celery.celery_redis import RedisDocumentSet from danswer.background.celery.celery_redis import RedisUserGroup +from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import ( + RedisConnectorDeletionFenceData, +) from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import ( RedisConnectorIndexingFenceData, ) @@ -368,7 +371,7 @@ def monitor_document_set_taskset( count = cast(int, r.scard(rds.taskset_key)) task_logger.info( - f"Document set sync progress: document_set_id={document_set_id} " + f"Document set sync progress: document_set={document_set_id} " f"remaining={count} initial={initial_count}" ) if count > 0: @@ -383,12 +386,12 @@ def monitor_document_set_taskset( # if there are no connectors, then delete the document set. delete_document_set(document_set_row=document_set, db_session=db_session) task_logger.info( - f"Successfully deleted document set with ID: '{document_set_id}'!" 
+ f"Successfully deleted document set: document_set={document_set_id}" ) else: mark_document_set_as_synced(document_set_id, db_session) task_logger.info( - f"Successfully synced document set with ID: '{document_set_id}'!" + f"Successfully synced document set: document_set={document_set_id}" ) r.delete(rds.taskset_key) @@ -408,19 +411,29 @@ def monitor_connector_deletion_taskset( rcd = RedisConnectorDeletion(cc_pair_id) - fence_value = r.get(rcd.fence_key) + # read related data and evaluate/print task progress + fence_value = cast(bytes, r.get(rcd.fence_key)) if fence_value is None: return try: - initial_count = int(cast(int, fence_value)) + fence_json = fence_value.decode("utf-8") + fence_data = RedisConnectorDeletionFenceData.model_validate_json( + cast(str, fence_json) + ) except ValueError: - task_logger.error("The value is not an integer.") + task_logger.exception( + "monitor_ccpair_indexing_taskset: fence_data not decodeable." + ) + raise + + # the fence is setting up but isn't ready yet + if fence_data.num_tasks is None: return count = cast(int, r.scard(rcd.taskset_key)) task_logger.info( - f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}" + f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}" ) if count > 0: return @@ -483,7 +496,7 @@ def monitor_connector_deletion_taskset( ) if not connector or not len(connector.credentials): task_logger.info( - "Found no credentials left for connector, deleting connector" + "Connector deletion - Found no credentials left for connector, deleting connector" ) db_session.delete(connector) db_session.commit() @@ -493,17 +506,17 @@ def monitor_connector_deletion_taskset( error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}" add_deletion_failure_message(db_session, cc_pair_id, error_message) task_logger.exception( - f"Failed to run connector_deletion. " + f"Connector deletion exceptioned: " f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}" ) raise e task_logger.info( - f"Successfully deleted cc_pair: " + f"Connector deletion succeeded: " f"cc_pair={cc_pair_id} " f"connector={cc_pair.connector_id} " f"credential={cc_pair.credential_id} " - f"docs_deleted={initial_count}" + f"docs_deleted={fence_data.num_tasks}" ) r.delete(rcd.taskset_key) @@ -618,6 +631,7 @@ def monitor_ccpair_indexing_taskset( return # Read result state BEFORE generator_complete_key to avoid a race condition + # never use any blocking methods on the result from inside a task! 
result: AsyncResult = AsyncResult(fence_data.celery_task_id) result_state = result.state diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index cb50739045..d95a6a70d5 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -1,6 +1,7 @@ import time import traceback -from collections.abc import Callable +from abc import ABC +from abc import abstractmethod from datetime import datetime from datetime import timedelta from datetime import timezone @@ -41,6 +42,19 @@ logger = setup_logger() INDEXING_TRACER_NUM_PRINT_ENTRIES = 5 +class RunIndexingCallbackInterface(ABC): + """Defines a callback interface to be passed to + to run_indexing_entrypoint.""" + + @abstractmethod + def should_stop(self) -> bool: + """Signal to stop the looping function in flight.""" + + @abstractmethod + def progress(self, amount: int) -> None: + """Send progress updates to the caller.""" + + def _get_connector_runner( db_session: Session, attempt: IndexAttempt, @@ -92,7 +106,7 @@ def _run_indexing( db_session: Session, index_attempt: IndexAttempt, tenant_id: str | None, - progress_callback: Callable[[int], None] | None = None, + callback: RunIndexingCallbackInterface | None = None, ) -> None: """ 1. Get documents which are either new or updated from specified application @@ -206,6 +220,11 @@ def _run_indexing( # index being built. We want to populate it even for paused connectors # Often paused connectors are sources that aren't updated frequently but the # contents still need to be initially pulled. + if callback: + if callback.should_stop(): + raise RuntimeError("Connector stop signal detected") + + # TODO: should we move this into the above callback instead? 
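Anything implementing these two hooks can be handed down to `run_indexing_entrypoint`. For example, a thread-event stand-in that could be useful in tests (a sketch, not part of this PR):

```python
import threading

from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface


class StopEventCallback(RunIndexingCallbackInterface):
    """Test double: stop via a threading.Event, tally progress locally."""

    def __init__(self) -> None:
        self.stop_event = threading.Event()
        self.total = 0

    def should_stop(self) -> bool:
        return self.stop_event.is_set()

    def progress(self, amount: int) -> None:
        self.total += amount
```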
db_session.refresh(db_cc_pair) if ( ( @@ -263,8 +282,8 @@ def _run_indexing( # be inaccurate db_session.commit() - if progress_callback: - progress_callback(len(doc_batch)) + if callback: + callback.progress(len(doc_batch)) # This new value is updated every batch, so UI can refresh per batch update update_docs_indexed( @@ -394,7 +413,7 @@ def run_indexing_entrypoint( tenant_id: str | None, connector_credential_pair_id: int, is_ee: bool = False, - progress_callback: Callable[[int], None] | None = None, + callback: RunIndexingCallbackInterface | None = None, ) -> None: try: if is_ee: @@ -417,7 +436,7 @@ def run_indexing_entrypoint( f"credentials='{attempt.connector_credential_pair.connector_id}'" ) - _run_indexing(db_session, attempt, tenant_id, progress_callback) + _run_indexing(db_session, attempt, tenant_id, callback) logger.info( f"Indexing finished for tenant {tenant_id}: " diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index ea4e7be93d..f58a34c324 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -672,6 +672,7 @@ def stream_chat_message_objects( all_docs_useful=selected_db_search_docs is not None ), document_pruning_config=document_pruning_config, + structured_response_format=new_msg_req.structured_response_format, ), prompt_config=prompt_config, llm=( diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 4648e0fe82..d2aeb1b14c 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -237,6 +237,7 @@ class Answer: prompt=prompt, tools=final_tool_definitions if final_tool_definitions else None, tool_choice="required" if self.force_use_tool.force_use else None, + structured_response_format=self.answer_style_config.structured_response_format, ): if isinstance(message, AIMessageChunk) and ( message.tool_call_chunks or message.tool_calls @@ -331,7 +332,10 @@ class Answer: tool_choice: ToolChoiceOptions | None = None, ) -> Iterator[str | StreamStopInfo]: for message in self.llm.stream( - prompt=prompt, tools=tools, tool_choice=tool_choice + prompt=prompt, + tools=tools, + tool_choice=tool_choice, + structured_response_format=self.answer_style_config.structured_response_format, ): if isinstance(message, AIMessageChunk): if message.content: diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py index fb5fa9c313..87c1297fe9 100644 --- a/backend/danswer/llm/answering/models.py +++ b/backend/danswer/llm/answering/models.py @@ -116,6 +116,10 @@ class AnswerStyleConfig(BaseModel): document_pruning_config: DocumentPruningConfig = Field( default_factory=DocumentPruningConfig ) + # forces the LLM to return a structured response, see + # https://platform.openai.com/docs/guides/structured-outputs/introduction + # right now, only used by the simple chat API + structured_response_format: dict | None = None @model_validator(mode="after") def check_quotes_and_citation(self) -> "AnswerStyleConfig": diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index d50f825318..d450fff0a6 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -280,6 +280,7 @@ class DefaultMultiLLM(LLM): tools: list[dict] | None, tool_choice: ToolChoiceOptions | None, stream: bool, + structured_response_format: dict | None = None, ) -> litellm.ModelResponse | litellm.CustomStreamWrapper: if isinstance(prompt, list): prompt = [ @@ -313,6 +314,11 @@ 
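Concretely, the `structured_response_format` dict travels untouched from the request model down to `litellm.completion`'s `response_format` argument (wired up just below), so the call that ultimately goes out looks roughly like this (a sketch; the model name and schema are placeholders):

```python
import litellm

response = litellm.completion(
    model="gpt-4o-mini",  # placeholder; any model supporting OpenAI structured outputs
    messages=[{"role": "user", "content": "Name three colors."}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "colors",
            "schema": {
                "type": "object",
                "properties": {
                    "colors": {"type": "array", "items": {"type": "string"}}
                },
                "required": ["colors"],
                "additionalProperties": False,
            },
            "strict": True,
        },
    },
)
print(response.choices[0].message.content)
```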
class DefaultMultiLLM(LLM): # NOTE: we can't pass this in if tools are not specified # or else OpenAI throws an error **({"parallel_tool_calls": False} if tools else {}), + **( + {"response_format": structured_response_format} + if structured_response_format + else {} + ), **self._model_kwargs, ) except Exception as e: @@ -336,12 +342,16 @@ class DefaultMultiLLM(LLM): prompt: LanguageModelInput, tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, + structured_response_format: dict | None = None, ) -> BaseMessage: if LOG_DANSWER_MODEL_INTERACTIONS: self.log_model_configs() response = cast( - litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False) + litellm.ModelResponse, + self._completion( + prompt, tools, tool_choice, False, structured_response_format + ), ) choice = response.choices[0] if hasattr(choice, "message"): @@ -354,18 +364,21 @@ class DefaultMultiLLM(LLM): prompt: LanguageModelInput, tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, + structured_response_format: dict | None = None, ) -> Iterator[BaseMessage]: if LOG_DANSWER_MODEL_INTERACTIONS: self.log_model_configs() if DISABLE_LITELLM_STREAMING: - yield self.invoke(prompt) + yield self.invoke(prompt, tools, tool_choice, structured_response_format) return output = None response = cast( litellm.CustomStreamWrapper, - self._completion(prompt, tools, tool_choice, True), + self._completion( + prompt, tools, tool_choice, True, structured_response_format + ), ) try: for part in response: diff --git a/backend/danswer/llm/custom_llm.py b/backend/danswer/llm/custom_llm.py index 4a5ba7857c..6b80406cf2 100644 --- a/backend/danswer/llm/custom_llm.py +++ b/backend/danswer/llm/custom_llm.py @@ -80,6 +80,7 @@ class CustomModelServer(LLM): prompt: LanguageModelInput, tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, + structured_response_format: dict | None = None, ) -> BaseMessage: return self._execute(prompt) @@ -88,5 +89,6 @@ class CustomModelServer(LLM): prompt: LanguageModelInput, tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, + structured_response_format: dict | None = None, ) -> Iterator[BaseMessage]: yield self._execute(prompt) diff --git a/backend/danswer/llm/interfaces.py b/backend/danswer/llm/interfaces.py index 6cb58e46c6..7deee11dfa 100644 --- a/backend/danswer/llm/interfaces.py +++ b/backend/danswer/llm/interfaces.py @@ -88,11 +88,14 @@ class LLM(abc.ABC): prompt: LanguageModelInput, tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, + structured_response_format: dict | None = None, ) -> BaseMessage: self._precall(prompt) # TODO add a postcall to log model outputs independent of concrete class # implementation - return self._invoke_implementation(prompt, tools, tool_choice) + return self._invoke_implementation( + prompt, tools, tool_choice, structured_response_format + ) @abc.abstractmethod def _invoke_implementation( @@ -100,6 +103,7 @@ class LLM(abc.ABC): prompt: LanguageModelInput, tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, + structured_response_format: dict | None = None, ) -> BaseMessage: raise NotImplementedError @@ -108,11 +112,14 @@ class LLM(abc.ABC): prompt: LanguageModelInput, tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, + structured_response_format: dict | None = None, ) -> Iterator[BaseMessage]: self._precall(prompt) # TODO add a postcall to log model outputs independent of concrete 
class # implementation - return self._stream_implementation(prompt, tools, tool_choice) + return self._stream_implementation( + prompt, tools, tool_choice, structured_response_format + ) @abc.abstractmethod def _stream_implementation( @@ -120,5 +127,6 @@ class LLM(abc.ABC): prompt: LanguageModelInput, tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, + structured_response_format: dict | None = None, ) -> Iterator[BaseMessage]: raise NotImplementedError diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index d16aa59c4c..1ceeb776ab 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -19,7 +19,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair from danswer.db.connector_credential_pair import ( update_connector_credential_pair_from_id, ) -from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed from danswer.db.engine import get_current_tenant_id from danswer.db.engine import get_session from danswer.db.enums import ConnectorCredentialPairStatus @@ -175,15 +174,19 @@ def create_deletion_attempt_for_connector_id( cc_pair_id=cc_pair.id, db_session=db_session, include_secondary_index=True ) + # TODO(rkuo): 2024-10-24 - check_deletion_attempt_is_allowed shouldn't be necessary + # any more due to background locking improvements. + # Remove the below permanently if everything is behaving for 30 days. + # Check if the deletion attempt should be allowed - deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed( - connector_credential_pair=cc_pair, db_session=db_session - ) - if deletion_attempt_disallowed_reason: - raise HTTPException( - status_code=400, - detail=deletion_attempt_disallowed_reason, - ) + # deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed( + # connector_credential_pair=cc_pair, db_session=db_session + # ) + # if deletion_attempt_disallowed_reason: + # raise HTTPException( + # status_code=400, + # detail=deletion_attempt_disallowed_reason, + # ) # mark as deleting update_connector_credential_pair_from_id( diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 42f4100a24..1ca14f9283 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -108,6 +108,10 @@ class CreateChatMessageRequest(ChunkContext): # used for seeded chats to kick off the generation of an AI answer use_existing_user_message: bool = False + # forces the LLM to return a structured response, see + # https://platform.openai.com/docs/guides/structured-outputs/introduction + structured_response_format: dict | None = None + @model_validator(mode="after") def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest": if self.search_doc_ids is None and self.retrieval_options is None: diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py index dd637dcf08..b25ed8357d 100644 --- a/backend/ee/danswer/server/query_and_chat/chat_backend.py +++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py @@ -176,6 +176,7 @@ def handle_simplified_chat_message( chunks_above=0, chunks_below=0, full_doc=chat_message_req.full_doc, + structured_response_format=chat_message_req.structured_response_format, ) packets = stream_chat_message_objects( @@ -202,7 +203,7 @@ def 
handle_send_message_simple_with_history( raise HTTPException(status_code=400, detail="Messages cannot be zero length") # This is a sanity check to make sure the chat history is valid - # It must start with a user message and alternate beteen user and assistant + # It must start with a user message and alternate between user and assistant expected_role = MessageType.USER for msg in req.messages: if not msg.message: @@ -296,6 +297,7 @@ def handle_send_message_simple_with_history( chunks_above=0, chunks_below=0, full_doc=req.full_doc, + structured_response_format=req.structured_response_format, ) packets = stream_chat_message_objects( diff --git a/backend/ee/danswer/server/query_and_chat/models.py b/backend/ee/danswer/server/query_and_chat/models.py index 052be683e9..4baf17ac8c 100644 --- a/backend/ee/danswer/server/query_and_chat/models.py +++ b/backend/ee/danswer/server/query_and_chat/models.py @@ -48,6 +48,9 @@ class BasicCreateChatMessageRequest(ChunkContext): query_override: str | None = None # If search_doc_ids provided, then retrieval options are unused search_doc_ids: list[int] | None = None + # only works if using an OpenAI model. See the following for more details: + # https://platform.openai.com/docs/guides/structured-outputs/introduction + structured_response_format: dict | None = None class BasicCreateChatMessageWithHistoryRequest(ChunkContext): @@ -60,6 +63,9 @@ skip_rerank: bool | None = None # If search_doc_ids provided, then retrieval options are unused search_doc_ids: list[int] | None = None + # only works if using an OpenAI model. See the following for more details: + # https://platform.openai.com/docs/guides/structured-outputs/introduction + structured_response_format: dict | None = None class SimpleDoc(BaseModel): diff --git a/backend/scripts/add_connector_creation_script.py b/backend/scripts/add_connector_creation_script.py new file mode 100644 index 0000000000..9a1944080c --- /dev/null +++ b/backend/scripts/add_connector_creation_script.py @@ -0,0 +1,148 @@ +from typing import Any +from typing import Dict + +import requests + +API_SERVER_URL = "http://localhost:3000" # Adjust this to your Danswer server URL +HEADERS = {"Content-Type": "application/json"} +API_KEY = "danswer-api-key" # API key here, if auth is enabled + + +def create_connector( + name: str, + source: str, + input_type: str, + connector_specific_config: Dict[str, Any], + is_public: bool = True, + groups: list[int] | None = None, +) -> Dict[str, Any]: + connector_update_request = { + "name": name, + "source": source, + "input_type": input_type, + "connector_specific_config": connector_specific_config, + "is_public": is_public, + "groups": groups or [], + } + + response = requests.post( + url=f"{API_SERVER_URL}/api/manage/admin/connector", + json=connector_update_request, + headers=HEADERS, + ) + response.raise_for_status() + return response.json() + + +def create_credential( + name: str, + source: str, + credential_json: Dict[str, Any], + is_public: bool = True, + groups: list[int] | None = None, +) -> Dict[str, Any]: + credential_request = { + "name": name, + "source": source, + "credential_json": credential_json, + "admin_public": is_public, + "groups": groups or [], + } + + response = requests.post( + url=f"{API_SERVER_URL}/api/manage/credential", + json=credential_request, + headers=HEADERS, + ) + response.raise_for_status() + return response.json() + + +def create_cc_pair( + connector_id: int, + credential_id: int, + name: str,
access_type: str = "public", + groups: list[int] | None = None, +) -> Dict[str, Any]: + cc_pair_request = { + "name": name, + "access_type": access_type, + "groups": groups or [], + } + + response = requests.put( + url=f"{API_SERVER_URL}/api/manage/connector/{connector_id}/credential/{credential_id}", + json=cc_pair_request, + headers=HEADERS, + ) + response.raise_for_status() + return response.json() + + +def main() -> None: + # Create a Web connector + web_connector = create_connector( + name="Example Web Connector", + source="web", + input_type="load_state", + connector_specific_config={ + "base_url": "https://example.com", + "web_connector_type": "recursive", + }, + ) + print(f"Created Web Connector: {web_connector}") + + # Create a credential for the Web connector + web_credential = create_credential( + name="Example Web Credential", + source="web", + credential_json={}, # Web connectors typically don't need credentials + is_public=True, + ) + print(f"Created Web Credential: {web_credential}") + + # Create CC pair for Web connector + web_cc_pair = create_cc_pair( + connector_id=web_connector["id"], + credential_id=web_credential["id"], + name="Example Web CC Pair", + access_type="public", + ) + print(f"Created Web CC Pair: {web_cc_pair}") + + # Create a GitHub connector + github_connector = create_connector( + name="Example GitHub Connector", + source="github", + input_type="poll", + connector_specific_config={ + "repo_owner": "example-owner", + "repo_name": "example-repo", + "include_prs": True, + "include_issues": True, + }, + ) + print(f"Created GitHub Connector: {github_connector}") + + # Create a credential for the GitHub connector + github_credential = create_credential( + name="Example GitHub Credential", + source="github", + credential_json={"github_access_token": "your_github_access_token_here"}, + is_public=True, + ) + print(f"Created GitHub Credential: {github_credential}") + + # Create CC pair for GitHub connector + github_cc_pair = create_cc_pair( + connector_id=github_connector["id"], + credential_id=github_credential["id"], + name="Example GitHub CC Pair", + access_type="public", + ) + print(f"Created GitHub CC Pair: {github_cc_pair}") + + +if __name__ == "__main__": + main() diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index 77d9e0e702..f3d194e22b 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -6,7 +6,9 @@ from sqlalchemy.orm import Session from danswer.db.engine import get_session_context_manager from danswer.db.search_settings import get_current_search_settings +from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.reset import reset_all +from tests.integration.common_utils.test_models import DATestUser from tests.integration.common_utils.vespa import vespa_fixture @@ -44,3 +46,11 @@ def vespa_client(db_session: Session) -> vespa_fixture: @pytest.fixture def reset() -> None: reset_all() + + +@pytest.fixture +def new_admin_user(reset: None) -> DATestUser | None: + try: + return UserManager.create(name="admin_user") + except Exception: + return None diff --git a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py index 0a4e7b40b5..c37d1a6235 100644 --- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py +++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py @@ -1,7 +1,10 @@ +import json + import requests from 
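One gap worth flagging in `add_connector_creation_script.py` above: `API_KEY` is defined but never attached to `HEADERS`, so every request is sent unauthenticated. When auth is enabled, the key presumably needs to ride along as a bearer token (an assumption, based on how the integration tests authenticate):

```python
API_KEY = "danswer-api-key"  # API key here, if auth is enabled
HEADERS = {
    "Content-Type": "application/json",
    # assumption: the API key is presented as a bearer token
    "Authorization": f"Bearer {API_KEY}",
}
```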
danswer.configs.constants import MessageType from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.constants import NUM_DOCS from tests.integration.common_utils.managers.api_key import APIKeyManager from tests.integration.common_utils.managers.cc_pair import CCPairManager @@ -145,3 +148,85 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) -> # This ensures the the document we think we are referencing when we send the search_doc_ids in the second # message is the document that we expect it to be assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id + + +def test_send_message_simple_with_history_strict_json( + new_admin_user: DATestUser | None, +) -> None: + # create connectors + LLMProviderManager.create(user_performing_action=new_admin_user) + + response = requests.post( + f"{API_SERVER_URL}/chat/send-message-simple-with-history", + json={ + # intentionally not relevant prompt to ensure that the + # structured response format is actually used + "messages": [ + { + "message": "What is green?", + "role": MessageType.USER.value, + } + ], + "persona_id": 0, + "prompt_id": 0, + "structured_response_format": { + "type": "json_schema", + "json_schema": { + "name": "presidents", + "schema": { + "type": "object", + "properties": { + "presidents": { + "type": "array", + "items": {"type": "string"}, + "description": "List of the first three US presidents", + } + }, + "required": ["presidents"], + "additionalProperties": False, + }, + "strict": True, + }, + }, + }, + headers=new_admin_user.headers if new_admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + + response_json = response.json() + + # Check that the answer is present + assert "answer" in response_json + assert response_json["answer"] is not None + + # helper + def clean_json_string(json_string: str) -> str: + return json_string.strip().removeprefix("```json").removesuffix("```").strip() + + # Attempt to parse the answer as JSON + try: + clean_answer = clean_json_string(response_json["answer"]) + parsed_answer = json.loads(clean_answer) + + # NOTE: do not check content, just the structure + assert isinstance(parsed_answer, dict) + assert "presidents" in parsed_answer + assert isinstance(parsed_answer["presidents"], list) + for president in parsed_answer["presidents"]: + assert isinstance(president, str) + except json.JSONDecodeError: + assert ( + False + ), f"The answer is not a valid JSON object - '{response_json['answer']}'" + + # Check that the answer_citationless is also valid JSON + assert "answer_citationless" in response_json + assert response_json["answer_citationless"] is not None + try: + clean_answer_citationless = clean_json_string( + response_json["answer_citationless"] + ) + parsed_answer_citationless = json.loads(clean_answer_citationless) + assert isinstance(parsed_answer_citationless, dict) + except json.JSONDecodeError: + assert False, "The answer_citationless is not a valid JSON object" diff --git a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx index cf9d584439..8229d19967 100644 --- a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx +++ b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx @@ -173,7 +173,6 @@ export function ProviderCreationModal({ return ( { return ( - + { )} - + ); }; diff --git 
a/web/src/app/admin/prompt-library/modals/EditPromptModal.tsx b/web/src/app/admin/prompt-library/modals/EditPromptModal.tsx index bd20ce5e45..996b70b9fb 100644 --- a/web/src/app/admin/prompt-library/modals/EditPromptModal.tsx +++ b/web/src/app/admin/prompt-library/modals/EditPromptModal.tsx @@ -1,8 +1,9 @@ import React from "react"; import { Formik, Form, Field, ErrorMessage } from "formik"; import * as Yup from "yup"; -import { ModalWrapper } from "@/components/modals/ModalWrapper"; +import { Modal } from "@/components/Modal"; import { Button, Textarea, TextInput } from "@tremor/react"; + import { useInputPrompt } from "../hooks"; import { EditPromptModalProps } from "../interfaces"; @@ -25,20 +26,20 @@ const EditPromptModal = ({ if (error) return ( - +

Failed to load prompt data

-
+ ); if (!promptData) return ( - +

Loading...

-
+ ); return ( - + )} - + ); }; diff --git a/web/src/app/chat/modal/FeedbackModal.tsx b/web/src/app/chat/modal/FeedbackModal.tsx index 64feffefc7..39c3253b76 100644 --- a/web/src/app/chat/modal/FeedbackModal.tsx +++ b/web/src/app/chat/modal/FeedbackModal.tsx @@ -2,7 +2,7 @@ import { useState } from "react"; import { FeedbackType } from "../types"; -import { ModalWrapper } from "@/components/modals/ModalWrapper"; +import { Modal } from "@/components/Modal"; import { FilledLikeIcon } from "@/components/icons/icons"; const predefinedPositiveFeedbackOptions = @@ -49,7 +49,7 @@ export const FeedbackModal = ({ : predefinedNegativeFeedbackOptions; return ( - + <>

@@ -112,6 +112,6 @@ export const FeedbackModal = ({
- + ); }; diff --git a/web/src/app/chat/modal/MakePublicAssistantModal.tsx b/web/src/app/chat/modal/MakePublicAssistantModal.tsx index a234050a52..757cf060e8 100644 --- a/web/src/app/chat/modal/MakePublicAssistantModal.tsx +++ b/web/src/app/chat/modal/MakePublicAssistantModal.tsx @@ -1,4 +1,4 @@ -import { ModalWrapper } from "@/components/modals/ModalWrapper"; +import { Modal } from "@/components/Modal"; import { Button, Divider, Text } from "@tremor/react"; export function MakePublicAssistantModal({ @@ -11,7 +11,7 @@ export function MakePublicAssistantModal({ onClose: () => void; }) { return ( - +

{isPublic ? "Public Assistant" : "Make Assistant Public"} @@ -67,6 +67,6 @@ export function MakePublicAssistantModal({

)} -
+ ); } diff --git a/web/src/app/chat/modal/SetDefaultModelModal.tsx b/web/src/app/chat/modal/SetDefaultModelModal.tsx index 5a47d9e66f..61190120dc 100644 --- a/web/src/app/chat/modal/SetDefaultModelModal.tsx +++ b/web/src/app/chat/modal/SetDefaultModelModal.tsx @@ -1,5 +1,5 @@ import { Dispatch, SetStateAction, useEffect, useRef } from "react"; -import { ModalWrapper } from "@/components/modals/ModalWrapper"; +import { Modal } from "@/components/Modal"; import { Text } from "@tremor/react"; import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks"; import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces"; @@ -123,10 +123,7 @@ export function SetDefaultModelModal({ ); return ( - + <>

@@ -203,6 +200,6 @@ export function SetDefaultModelModal({

-
+ ); } diff --git a/web/src/app/chat/modal/ShareChatSessionModal.tsx b/web/src/app/chat/modal/ShareChatSessionModal.tsx index 16a9147b52..1b797e77ab 100644 --- a/web/src/app/chat/modal/ShareChatSessionModal.tsx +++ b/web/src/app/chat/modal/ShareChatSessionModal.tsx @@ -1,5 +1,5 @@ import { useState } from "react"; -import { ModalWrapper } from "@/components/modals/ModalWrapper"; +import { Modal } from "@/components/Modal"; import { Button, Callout, Divider, Text } from "@tremor/react"; import { Spinner } from "@/components/Spinner"; import { ChatSessionSharedStatus } from "../interfaces"; @@ -57,7 +57,7 @@ export function ShareChatSessionModal({ ); return ( - + <>

@@ -154,6 +154,6 @@ export function ShareChatSessionModal({ )}

-
+ ); } diff --git a/web/src/components/Modal.tsx b/web/src/components/Modal.tsx index 169e85025d..0f354a264f 100644 --- a/web/src/components/Modal.tsx +++ b/web/src/components/Modal.tsx @@ -54,9 +54,9 @@ export function Modal({ e.stopPropagation(); } }} - className={`bg-background text-emphasis rounded shadow-2xl + className={`bg-background text-emphasis rounded shadow-2xl transform transition-all duration-300 ease-in-out - ${width ?? "w-11/12 max-w-5xl"} + ${width ?? "w-11/12 max-w-4xl"} ${noPadding ? "" : "p-10"} ${className || ""}`} > @@ -88,7 +88,7 @@ export function Modal({ {!hideDividerForTitle && } )} - {children} +
{children}
diff --git a/web/src/components/modals/DeleteEntityModal.tsx b/web/src/components/modals/DeleteEntityModal.tsx index 5ef76f9c85..85cda2fd4d 100644 --- a/web/src/components/modals/DeleteEntityModal.tsx +++ b/web/src/components/modals/DeleteEntityModal.tsx @@ -1,6 +1,6 @@ import { FiTrash, FiX } from "react-icons/fi"; -import { ModalWrapper } from "@/components/modals/ModalWrapper"; import { BasicClickable } from "@/components/BasicClickable"; +import { Modal } from "../Modal"; export const DeleteEntityModal = ({ onClose, @@ -16,7 +16,7 @@ export const DeleteEntityModal = ({ additionalDetails?: string; }) => { return ( - + <>

Delete {entityType}?

@@ -37,6 +37,6 @@ export const DeleteEntityModal = ({
-
+ ); }; diff --git a/web/src/components/modals/GenericConfirmModal.tsx b/web/src/components/modals/GenericConfirmModal.tsx index fe6c2b020a..893ae2f6b9 100644 --- a/web/src/components/modals/GenericConfirmModal.tsx +++ b/web/src/components/modals/GenericConfirmModal.tsx @@ -1,5 +1,5 @@ import { FiCheck } from "react-icons/fi"; -import { ModalWrapper } from "./ModalWrapper"; +import { Modal } from "@/components/Modal"; import { BasicClickable } from "@/components/BasicClickable"; export const GenericConfirmModal = ({ @@ -16,7 +16,7 @@ export const GenericConfirmModal = ({ onConfirm: () => void; }) => { return ( - +

@@ -37,6 +37,6 @@ export const GenericConfirmModal = ({

-
+ ); }; diff --git a/web/src/components/modals/ModalWrapper.tsx b/web/src/components/modals/ModalWrapper.tsx deleted file mode 100644 index f69ff0e2b6..0000000000 --- a/web/src/components/modals/ModalWrapper.tsx +++ /dev/null @@ -1,63 +0,0 @@ -"use client"; -import { XIcon } from "@/components/icons/icons"; -import { isEventWithinRef } from "@/lib/contains"; -import { useRef } from "react"; - -export const ModalWrapper = ({ - children, - bgClassName, - modalClassName, - onClose, -}: { - children: JSX.Element; - bgClassName?: string; - modalClassName?: string; - onClose?: () => void; -}) => { - const modalRef = useRef(null); - - const handleMouseDown = (e: React.MouseEvent) => { - if ( - onClose && - modalRef.current && - !modalRef.current.contains(e.target as Node) && - !isEventWithinRef(e.nativeEvent, modalRef) - ) { - onClose(); - } - }; - return ( -
-
{ - if (onClose) { - e.stopPropagation(); - } - }} - className={`bg-background text-emphasis p-10 rounded shadow-2xl - w-11/12 max-w-3xl transform transition-all duration-300 ease-in-out - relative ${modalClassName || ""}`} - > - {onClose && ( -
- -
- )} - -
{children}
-
-
- ); -}; diff --git a/web/src/components/modals/NoAssistantModal.tsx b/web/src/components/modals/NoAssistantModal.tsx index 0eed887662..94d20aadba 100644 --- a/web/src/components/modals/NoAssistantModal.tsx +++ b/web/src/components/modals/NoAssistantModal.tsx @@ -1,8 +1,8 @@ -import { ModalWrapper } from "@/components/modals/ModalWrapper"; +import { Modal } from "@/components/Modal"; export const NoAssistantModal = ({ isAdmin }: { isAdmin: boolean }) => { return ( - + <>

No Assistant Available @@ -32,6 +32,6 @@ export const NoAssistantModal = ({ isAdmin }: { isAdmin: boolean }) => {

)} - + ); };
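Putting the new pieces together end to end, a client can ask the simple chat API for strictly-shaped JSON like so (a sketch; the base URL, persona/prompt ids, and API key are assumptions, and the fence-stripping mirrors `clean_json_string` from the test above):

```python
import json

import requests

resp = requests.post(
    "http://localhost:8080/chat/send-message-simple-with-history",  # assumed base URL
    headers={"Authorization": "Bearer danswer-api-key"},  # if auth is enabled
    json={
        "messages": [{"message": "What is green?", "role": "user"}],
        "persona_id": 0,
        "prompt_id": 0,
        "structured_response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": "colors",
                "schema": {
                    "type": "object",
                    "properties": {
                        "colors": {"type": "array", "items": {"type": "string"}}
                    },
                    "required": ["colors"],
                    "additionalProperties": False,
                },
                "strict": True,
            },
        },
    },
)
resp.raise_for_status()
answer = resp.json()["answer"]
# models sometimes wrap the JSON in a ```json fence; strip it before parsing
answer = answer.strip().removeprefix("```json").removesuffix("```").strip()
print(json.loads(answer))
```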