diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index c49908b1d5..8287f9b530 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -26,4 +26,4 @@ N/A
## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
-[ ] This PR should be backported (make sure to check that the backport attempt succeeds)
+- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
diff --git a/.github/workflows/pr-backport-autotrigger.yml b/.github/workflows/pr-backport-autotrigger.yml
index 4ba5136a82..2d49c39402 100644
--- a/.github/workflows/pr-backport-autotrigger.yml
+++ b/.github/workflows/pr-backport-autotrigger.yml
@@ -10,43 +10,83 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0 # Fetch all history for all branches and tags
+
+ - name: Set up Git
+ run: |
+ git config user.name "github-actions[bot]"
+ git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Check for Backport Checkbox
id: checkbox-check
run: |
PR_BODY="${{ github.event.pull_request.body }}"
if [[ "$PR_BODY" == *"[x] This PR should be backported"* ]]; then
- echo "::set-output name=backport::true"
+ echo "backport=true" >> $GITHUB_OUTPUT
else
- echo "::set-output name=backport::false"
+ echo "backport=false" >> $GITHUB_OUTPUT
fi
- name: List and sort release branches
id: list-branches
run: |
- git fetch --all
- BRANCHES=$(git branch -r | grep 'origin/release/v' | sed 's|origin/release/v||' | sort -Vr)
+ git fetch --all --tags
+ BRANCHES=$(git for-each-ref --format='%(refname:short)' refs/remotes/origin/release/* | sed 's|origin/release/||' | sort -Vr)
BETA=$(echo "$BRANCHES" | head -n 1)
STABLE=$(echo "$BRANCHES" | head -n 2 | tail -n 1)
- echo "::set-output name=beta::$BETA"
- echo "::set-output name=stable::$STABLE"
+ echo "beta=$BETA" >> $GITHUB_OUTPUT
+ echo "stable=$STABLE" >> $GITHUB_OUTPUT
+ # Fetch latest tags for beta and stable
+ LATEST_BETA_TAG=$(git tag -l "v*.*.0-beta.*" | sort -Vr | head -n 1)
+ LATEST_STABLE_TAG=$(git tag -l "v*.*.*" | grep -v -- "-beta" | sort -Vr | head -n 1)
+ # Increment latest beta tag
+ NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 ".0-beta." ($NF+1)}')
+ # Increment latest stable tag
+ NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}')
+ echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT
+ echo "latest_stable_tag=$LATEST_STABLE_TAG" >> $GITHUB_OUTPUT
+ echo "new_beta_tag=$NEW_BETA_TAG" >> $GITHUB_OUTPUT
+ echo "new_stable_tag=$NEW_STABLE_TAG" >> $GITHUB_OUTPUT
+
+ - name: Echo branch and tag information
+ run: |
+ echo "Beta branch: ${{ steps.list-branches.outputs.beta }}"
+ echo "Stable branch: ${{ steps.list-branches.outputs.stable }}"
+ echo "Latest beta tag: ${{ steps.list-branches.outputs.latest_beta_tag }}"
+ echo "Latest stable tag: ${{ steps.list-branches.outputs.latest_stable_tag }}"
+ echo "New beta tag: ${{ steps.list-branches.outputs.new_beta_tag }}"
+ echo "New stable tag: ${{ steps.list-branches.outputs.new_stable_tag }}"
- name: Trigger Backport
if: steps.checkbox-check.outputs.backport == 'true'
run: |
+ set -e
echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}"
# Fetch all history for all branches and tags
git fetch --prune --unshallow
# Checkout the beta branch
git checkout ${{ steps.list-branches.outputs.beta }}
- # Cherry-pick the last commit from the merged PR
- git cherry-pick ${{ github.event.pull_request.merge_commit_sha }}
- # Push the changes to the beta branch
+ # Cherry-pick the merge commit from the merged PR
+ git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
+ echo "Cherry-pick to beta failed due to conflicts."
+ exit 1
+ }
+ # Create new beta tag
+ git tag ${{ steps.list-branches.outputs.new_beta_tag }}
+ # Push the changes and tag to the beta branch
git push origin ${{ steps.list-branches.outputs.beta }}
+ git push origin ${{ steps.list-branches.outputs.new_beta_tag }}
# Checkout the stable branch
git checkout ${{ steps.list-branches.outputs.stable }}
- # Cherry-pick the last commit from the merged PR
- git cherry-pick ${{ github.event.pull_request.merge_commit_sha }}
- # Push the changes to the stable branch
+ # Cherry-pick the merge commit from the merged PR
+ git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
+ echo "Cherry-pick to stable failed due to conflicts."
+ exit 1
+ }
+ # Create new stable tag
+ git tag ${{ steps.list-branches.outputs.new_stable_tag }}
+ # Push the changes and tag to the stable branch
git push origin ${{ steps.list-branches.outputs.stable }}
+ git push origin ${{ steps.list-branches.outputs.new_stable_tag }}
diff --git a/backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py b/backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py
index a6938e365c..db7b330c3e 100644
--- a/backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py
+++ b/backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py
@@ -31,6 +31,12 @@ def upgrade() -> None:
def downgrade() -> None:
+ # First, update any null values to a default value
+ op.execute(
+ "UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL"
+ )
+
+ # Then, make the column non-nullable
op.alter_column(
"connector_credential_pair",
"last_attempt_status",
diff --git a/backend/danswer/background/celery/apps/primary.py b/backend/danswer/background/celery/apps/primary.py
index 23e25fa92b..983b76773e 100644
--- a/backend/danswer/background/celery/apps/primary.py
+++ b/backend/danswer/background/celery/apps/primary.py
@@ -17,6 +17,7 @@ from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
+from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
@@ -161,6 +162,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
+ for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
+ r.delete(key)
+
# @worker_process_init.connect
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
diff --git a/backend/danswer/background/celery/celery_redis.py b/backend/danswer/background/celery/celery_redis.py
index 1ea5e3b176..e412b0bf73 100644
--- a/backend/danswer/background/celery/celery_redis.py
+++ b/backend/danswer/background/celery/celery_redis.py
@@ -313,6 +313,8 @@ class RedisConnectorDeletion(RedisObjectHelper):
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
+ """Returns None if the cc_pair doesn't exist.
+ Otherwise, returns an int with the number of generated tasks."""
last_lock_time = time.monotonic()
async_results = []
@@ -540,6 +542,29 @@ class RedisConnectorIndexing(RedisObjectHelper):
return False
+class RedisConnectorStop(RedisObjectHelper):
+ """Used to signal any running tasks for a connector to stop. We should refactor
+ connector related redis helpers into a single class.
+ """
+
+ PREFIX = "connectorstop"
+ FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
+ TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
+
+ def __init__(self, id: int) -> None:
+ super().__init__(str(id))
+
+ def generate_tasks(
+ self,
+ celery_app: Celery,
+ db_session: Session,
+ redis_client: Redis,
+ lock: redis.lock.Lock | None,
+ tenant_id: str | None,
+ ) -> int | None:
+ return None
+
+
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists
diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py
index b1e9c2113e..18038a349d 100644
--- a/backend/danswer/background/celery/celery_utils.py
+++ b/backend/danswer/background/celery/celery_utils.py
@@ -1,4 +1,3 @@
-from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -6,6 +5,7 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
+from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
@@ -79,7 +79,7 @@ def document_batch_to_ids(
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
- progress_callback: Callable[[int], None] | None = None,
+ callback: RunIndexingCallbackInterface | None = None,
) -> set[str]:
"""
If the PruneConnector hasnt been implemented for the given connector, just pull
@@ -110,8 +110,10 @@ def extract_ids_from_runnable_connector(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
- if progress_callback:
- progress_callback(len(doc_batch))
+ if callback:
+ if callback.should_stop():
+ raise RuntimeError("Stop signal received")
+ callback.progress(len(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids
diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
index f6a59d03e3..59d236cde3 100644
--- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
+++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
@@ -1,3 +1,6 @@
+from datetime import datetime
+from datetime import timezone
+
import redis
from celery import Celery
from celery import shared_task
@@ -8,6 +11,12 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
+from danswer.background.celery.celery_redis import RedisConnectorIndexing
+from danswer.background.celery.celery_redis import RedisConnectorPruning
+from danswer.background.celery.celery_redis import RedisConnectorStop
+from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
+ RedisConnectorDeletionFenceData,
+)
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
@@ -15,9 +24,15 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
+from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_pool import get_redis_client
+class TaskDependencyError(RuntimeError):
+ """Raised to the caller to indicate dependent tasks are running that would interfere
+ with connector deletion."""
+
+
@shared_task(
name="check_for_connector_deletion_task",
soft_time_limit=JOB_TIMEOUT,
@@ -37,17 +52,30 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
if not lock_beat.acquire(blocking=False):
return
+ # collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
cc_pair_ids.append(cc_pair.id)
+ # try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:
with get_session_with_tenant(tenant_id) as db_session:
- try_generate_document_cc_pair_cleanup_tasks(
- self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
- )
+ rcs = RedisConnectorStop(cc_pair_id)
+ try:
+ try_generate_document_cc_pair_cleanup_tasks(
+ self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
+ )
+ except TaskDependencyError as e:
+ # this means we wanted to start deleting but dependent tasks were running
+ # Leave a stop signal to clear indexing and pruning tasks more quickly
+ task_logger.info(str(e))
+ r.set(rcs.fence_key, cc_pair_id)
+ else:
+ # clear the stop signal if it exists ... no longer needed
+ r.delete(rcs.fence_key)
+
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -70,6 +98,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
Returns None if no syncing is required.
+
+ Will raise TaskDependencyError if dependent tasks such as indexing and pruning are
+ still running. In our case, the caller reacts by setting a stop signal in Redis to
+ exit those tasks as quickly as possible.
"""
lock_beat.reacquire()
@@ -90,28 +122,63 @@ def try_generate_document_cc_pair_cleanup_tasks(
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return None
- # add tasks to celery and build up the task set to monitor in redis
- r.delete(rcd.taskset_key)
-
- # Add all documents that need to be updated into the queue
- task_logger.info(
- f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
+ # set a basic fence to start
+ fence_value = RedisConnectorDeletionFenceData(
+ num_tasks=None,
+ submitted=datetime.now(timezone.utc),
)
- tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
- if tasks_generated is None:
+ r.set(rcd.fence_key, fence_value.model_dump_json())
+
+ try:
+ # do not proceed if connector indexing or connector pruning are running
+ search_settings_list = get_all_search_settings(db_session)
+ for search_settings in search_settings_list:
+ rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
+ if r.get(rci.fence_key):
+ raise TaskDependencyError(
+ f"Connector deletion - Delayed (indexing in progress): "
+ f"cc_pair={cc_pair_id} "
+ f"search_settings={search_settings.id}"
+ )
+
+ rcp = RedisConnectorPruning(cc_pair_id)
+ if r.get(rcp.fence_key):
+ raise TaskDependencyError(
+ f"Connector deletion - Delayed (pruning in progress): "
+ f"cc_pair={cc_pair_id}"
+ )
+
+ # add tasks to celery and build up the task set to monitor in redis
+ r.delete(rcd.taskset_key)
+
+ # Add all documents that need to be updated into the queue
+ task_logger.info(
+ f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
+ )
+ tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
+ if tasks_generated is None:
+ raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
+ except TaskDependencyError:
+ r.delete(rcd.fence_key)
+ raise
+ except Exception:
+ task_logger.exception("Unexpected exception")
+ r.delete(rcd.fence_key)
return None
+ else:
+ # Currently we are allowing the sync to proceed with 0 tasks.
+ # It's possible for sets/groups to be generated initially with no entries
+ # and they still need to be marked as up to date.
+ # if tasks_generated == 0:
+ # return 0
- # Currently we are allowing the sync to proceed with 0 tasks.
- # It's possible for sets/groups to be generated initially with no entries
- # and they still need to be marked as up to date.
- # if tasks_generated == 0:
- # return 0
+ task_logger.info(
+ f"RedisConnectorDeletion.generate_tasks finished. "
+ f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
+ )
- task_logger.info(
- f"RedisConnectorDeletion.generate_tasks finished. "
- f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
- )
+ # set this only after all tasks have been added
+ fence_value.num_tasks = tasks_generated
+ r.set(rcd.fence_key, fence_value.model_dump_json())
- # set this only after all tasks have been added
- r.set(rcd.fence_key, tasks_generated)
return tasks_generated
diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py
index bdd55f77f3..980266ec87 100644
--- a/backend/danswer/background/celery/tasks/indexing/tasks.py
+++ b/backend/danswer/background/celery/tasks/indexing/tasks.py
@@ -5,6 +5,7 @@ from time import sleep
from typing import cast
from uuid import uuid4
+import redis
from celery import Celery
from celery import shared_task
from celery import Task
@@ -13,12 +14,15 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
+from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
+from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
+from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -50,6 +54,30 @@ from danswer.utils.variable_functionality import global_version
logger = setup_logger()
+class RunIndexingCallback(RunIndexingCallbackInterface):
+ def __init__(
+ self,
+ stop_key: str,
+ generator_progress_key: str,
+ redis_lock: redis.lock.Lock,
+ redis_client: Redis,
+ ):
+ super().__init__()
+ self.redis_lock: redis.lock.Lock = redis_lock
+ self.stop_key: str = stop_key
+ self.generator_progress_key: str = generator_progress_key
+ self.redis_client = redis_client
+
+ def should_stop(self) -> bool:
+ if self.redis_client.exists(self.stop_key):
+ return True
+ return False
+
+ def progress(self, amount: int) -> None:
+ self.redis_lock.reacquire()
+ self.redis_client.incrby(self.generator_progress_key, amount)
+
+
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
@@ -262,6 +290,10 @@ def try_creating_indexing_task(
return None
# skip indexing if the cc_pair is deleting
+ rcd = RedisConnectorDeletion(cc_pair.id)
+ if r.exists(rcd.fence_key):
+ return None
+
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
@@ -308,13 +340,8 @@ def try_creating_indexing_task(
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
- fence_value = RedisConnectorIndexingFenceData(
- index_attempt_id=index_attempt_id,
- started=None,
- submitted=datetime.now(timezone.utc),
- celery_task_id=result.id,
- )
-
+ fence_value.index_attempt_id = index_attempt_id
+ fence_value.celery_task_id = result.id
r.set(rci.fence_key, fence_value.model_dump_json())
except Exception:
r.delete(rci.fence_key)
@@ -400,6 +427,22 @@ def connector_indexing_task(
r = get_redis_client(tenant_id=tenant_id)
+ rcd = RedisConnectorDeletion(cc_pair_id)
+ if r.exists(rcd.fence_key):
+ raise RuntimeError(
+ f"Indexing will not start because connector deletion is in progress: "
+ f"cc_pair={cc_pair_id} "
+ f"fence={rcd.fence_key}"
+ )
+
+ rcs = RedisConnectorStop(cc_pair_id)
+ if r.exists(rcs.fence_key):
+ raise RuntimeError(
+ f"Indexing will not start because a connector stop signal was detected: "
+ f"cc_pair={cc_pair_id} "
+ f"fence={rcs.fence_key}"
+ )
+
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
while True:
@@ -409,7 +452,7 @@ def connector_indexing_task(
task_logger.info(
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
)
- raise
+ raise RuntimeError(f"Fence not found: fence={rci.fence_key}")
try:
fence_json = fence_value.decode("utf-8")
@@ -443,17 +486,20 @@ def connector_indexing_task(
if not acquired:
task_logger.warning(
f"Indexing task already running, exiting...: "
- f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}"
+ f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None
+ fence_data.started = datetime.now(timezone.utc)
+ r.set(rci.fence_key, fence_data.model_dump_json())
+
try:
with get_session_with_tenant(tenant_id) as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise ValueError(
- f"Index attempt not found: index_attempt_id={index_attempt_id}"
+ f"Index attempt not found: index_attempt={index_attempt_id}"
)
cc_pair = get_connector_credential_pair_from_id(
@@ -462,31 +508,31 @@ def connector_indexing_task(
)
if not cc_pair:
- raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}")
+ raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
if not cc_pair.connector:
raise ValueError(
- f"Connector not found: connector_id={cc_pair.connector_id}"
+ f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
)
if not cc_pair.credential:
raise ValueError(
- f"Credential not found: credential_id={cc_pair.credential_id}"
+ f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
- # Define the callback function
- def redis_increment_callback(amount: int) -> None:
- lock.reacquire()
- r.incrby(rci.generator_progress_key, amount)
+ # define a callback class
+ callback = RunIndexingCallback(
+ rcs.fence_key, rci.generator_progress_key, lock, r
+ )
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
- progress_callback=redis_increment_callback,
+ callback=callback,
)
# get back the total number of indexed docs and return it
@@ -499,9 +545,10 @@ def connector_indexing_task(
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
except Exception as e:
- task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.")
+ task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
if attempt:
- mark_attempt_failed(attempt, db_session, failure_reason=str(e))
+ with get_session_with_tenant(tenant_id) as db_session:
+ mark_attempt_failed(attempt, db_session, failure_reason=str(e))
r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)
diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py
index 9f290d6f23..2e68986e83 100644
--- a/backend/danswer/background/celery/tasks/pruning/tasks.py
+++ b/backend/danswer/background/celery/tasks/pruning/tasks.py
@@ -11,8 +11,11 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
+from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
+from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
+from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -168,6 +171,10 @@ def try_creating_prune_generator_task(
return None
# skip pruning if the cc_pair is deleting
+ rcd = RedisConnectorDeletion(cc_pair.id)
+ if r.exists(rcd.fence_key):
+ return None
+
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
@@ -234,7 +241,7 @@ def connector_pruning_generator_task(
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
- f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}"
+ f"Pruning task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
@@ -252,11 +259,6 @@ def connector_pruning_generator_task(
)
return
- # Define the callback function
- def redis_increment_callback(amount: int) -> None:
- lock.reacquire()
- r.incrby(rcp.generator_progress_key, amount)
-
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
@@ -265,9 +267,14 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
+ rcs = RedisConnectorStop(cc_pair_id)
+
+ callback = RunIndexingCallback(
+ rcs.fence_key, rcp.generator_progress_key, lock, r
+ )
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
- runnable_connector, redis_increment_callback
+ runnable_connector, callback
)
# a list of docs in our local index
@@ -285,7 +292,7 @@ def connector_pruning_generator_task(
task_logger.info(
f"Pruning set collected: "
- f"cc_pair_id={cc_pair.id} "
+ f"cc_pair={cc_pair_id} "
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
)
@@ -293,7 +300,7 @@ def connector_pruning_generator_task(
rcp.documents_to_prune = set(doc_ids_to_remove)
task_logger.info(
- f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
+ f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair_id}"
)
tasks_generated = rcp.generate_tasks(
self.app, db_session, r, None, tenant_id
@@ -303,12 +310,14 @@ def connector_pruning_generator_task(
task_logger.info(
f"RedisConnectorPruning.generate_tasks finished. "
- f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
+ f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
- task_logger.exception(f"Failed to run pruning for connector id {connector_id}.")
+ task_logger.exception(
+ f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
+ )
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
diff --git a/backend/danswer/background/celery/tasks/shared/RedisConnectorDeletionFenceData.py b/backend/danswer/background/celery/tasks/shared/RedisConnectorDeletionFenceData.py
new file mode 100644
index 0000000000..1c664d14b4
--- /dev/null
+++ b/backend/danswer/background/celery/tasks/shared/RedisConnectorDeletionFenceData.py
@@ -0,0 +1,8 @@
+from datetime import datetime
+
+from pydantic import BaseModel
+
+
+class RedisConnectorDeletionFenceData(BaseModel):
+ num_tasks: int | None
+ submitted: datetime
diff --git a/backend/danswer/background/celery/tasks/shared/tasks.py b/backend/danswer/background/celery/tasks/shared/tasks.py
index 7ce43454aa..52a49d467e 100644
--- a/backend/danswer/background/celery/tasks/shared/tasks.py
+++ b/backend/danswer/background/celery/tasks/shared/tasks.py
@@ -1,9 +1,6 @@
-from datetime import datetime
-
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
-from pydantic import BaseModel
from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
@@ -23,13 +20,6 @@ from danswer.server.documents.models import ConnectorCredentialPairIdentifier
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
-class RedisConnectorIndexingFenceData(BaseModel):
- index_attempt_id: int | None
- started: datetime | None
- submitted: datetime
- celery_task_id: str | None
-
-
@shared_task(
name="document_by_cc_pair_cleanup_task",
bind=True,
diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py
index 812074b91e..fcc4d2aa5b 100644
--- a/backend/danswer/background/celery/tasks/vespa/tasks.py
+++ b/backend/danswer/background/celery/tasks/vespa/tasks.py
@@ -23,6 +23,9 @@ from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
+from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
+ RedisConnectorDeletionFenceData,
+)
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
@@ -368,7 +371,7 @@ def monitor_document_set_taskset(
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
- f"Document set sync progress: document_set_id={document_set_id} "
+ f"Document set sync progress: document_set={document_set_id} "
f"remaining={count} initial={initial_count}"
)
if count > 0:
@@ -383,12 +386,12 @@ def monitor_document_set_taskset(
# if there are no connectors, then delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
- f"Successfully deleted document set with ID: '{document_set_id}'!"
+ f"Successfully deleted document set: document_set={document_set_id}"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
- f"Successfully synced document set with ID: '{document_set_id}'!"
+ f"Successfully synced document set: document_set={document_set_id}"
)
r.delete(rds.taskset_key)
@@ -408,19 +411,29 @@ def monitor_connector_deletion_taskset(
rcd = RedisConnectorDeletion(cc_pair_id)
- fence_value = r.get(rcd.fence_key)
+ # read related data and evaluate/print task progress
+ fence_value = cast(bytes, r.get(rcd.fence_key))
if fence_value is None:
return
try:
- initial_count = int(cast(int, fence_value))
+ fence_json = fence_value.decode("utf-8")
+ fence_data = RedisConnectorDeletionFenceData.model_validate_json(
+ cast(str, fence_json)
+ )
except ValueError:
- task_logger.error("The value is not an integer.")
+ task_logger.exception(
+ "monitor_ccpair_indexing_taskset: fence_data not decodeable."
+ )
+ raise
+
+ # the fence is setting up but isn't ready yet
+ if fence_data.num_tasks is None:
return
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
- f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}"
+ f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
)
if count > 0:
return
@@ -483,7 +496,7 @@ def monitor_connector_deletion_taskset(
)
if not connector or not len(connector.credentials):
task_logger.info(
- "Found no credentials left for connector, deleting connector"
+ "Connector deletion - Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
@@ -493,17 +506,17 @@ def monitor_connector_deletion_taskset(
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
task_logger.exception(
- f"Failed to run connector_deletion. "
+ f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e
task_logger.info(
- f"Successfully deleted cc_pair: "
+ f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
- f"docs_deleted={initial_count}"
+ f"docs_deleted={fence_data.num_tasks}"
)
r.delete(rcd.taskset_key)
@@ -618,6 +631,7 @@ def monitor_ccpair_indexing_taskset(
return
# Read result state BEFORE generator_complete_key to avoid a race condition
+ # never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result_state = result.state
diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py
index cb50739045..d95a6a70d5 100644
--- a/backend/danswer/background/indexing/run_indexing.py
+++ b/backend/danswer/background/indexing/run_indexing.py
@@ -1,6 +1,7 @@
import time
import traceback
-from collections.abc import Callable
+from abc import ABC
+from abc import abstractmethod
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -41,6 +42,19 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
+class RunIndexingCallbackInterface(ABC):
+ """Defines a callback interface to be passed to
+ to run_indexing_entrypoint."""
+
+ @abstractmethod
+ def should_stop(self) -> bool:
+ """Signal to stop the looping function in flight."""
+
+ @abstractmethod
+ def progress(self, amount: int) -> None:
+ """Send progress updates to the caller."""
+
+
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
@@ -92,7 +106,7 @@ def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
- progress_callback: Callable[[int], None] | None = None,
+ callback: RunIndexingCallbackInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -206,6 +220,11 @@ def _run_indexing(
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
+ if callback:
+ if callback.should_stop():
+ raise RuntimeError("Connector stop signal detected")
+
+ # TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
if (
(
@@ -263,8 +282,8 @@ def _run_indexing(
# be inaccurate
db_session.commit()
- if progress_callback:
- progress_callback(len(doc_batch))
+ if callback:
+ callback.progress(len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
@@ -394,7 +413,7 @@ def run_indexing_entrypoint(
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
- progress_callback: Callable[[int], None] | None = None,
+ callback: RunIndexingCallbackInterface | None = None,
) -> None:
try:
if is_ee:
@@ -417,7 +436,7 @@ def run_indexing_entrypoint(
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
- _run_indexing(db_session, attempt, tenant_id, progress_callback)
+ _run_indexing(db_session, attempt, tenant_id, callback)
logger.info(
f"Indexing finished for tenant {tenant_id}: "
diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index ea4e7be93d..f58a34c324 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -672,6 +672,7 @@ def stream_chat_message_objects(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
+ structured_response_format=new_msg_req.structured_response_format,
),
prompt_config=prompt_config,
llm=(
diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py
index 4648e0fe82..d2aeb1b14c 100644
--- a/backend/danswer/llm/answering/answer.py
+++ b/backend/danswer/llm/answering/answer.py
@@ -237,6 +237,7 @@ class Answer:
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool.force_use else None,
+ structured_response_format=self.answer_style_config.structured_response_format,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
@@ -331,7 +332,10 @@ class Answer:
tool_choice: ToolChoiceOptions | None = None,
) -> Iterator[str | StreamStopInfo]:
for message in self.llm.stream(
- prompt=prompt, tools=tools, tool_choice=tool_choice
+ prompt=prompt,
+ tools=tools,
+ tool_choice=tool_choice,
+ structured_response_format=self.answer_style_config.structured_response_format,
):
if isinstance(message, AIMessageChunk):
if message.content:
diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py
index fb5fa9c313..87c1297fe9 100644
--- a/backend/danswer/llm/answering/models.py
+++ b/backend/danswer/llm/answering/models.py
@@ -116,6 +116,10 @@ class AnswerStyleConfig(BaseModel):
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
+ # forces the LLM to return a structured response, see
+ # https://platform.openai.com/docs/guides/structured-outputs/introduction
+ # right now, only used by the simple chat API
+ structured_response_format: dict | None = None
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py
index d50f825318..d450fff0a6 100644
--- a/backend/danswer/llm/chat_llm.py
+++ b/backend/danswer/llm/chat_llm.py
@@ -280,6 +280,7 @@ class DefaultMultiLLM(LLM):
tools: list[dict] | None,
tool_choice: ToolChoiceOptions | None,
stream: bool,
+ structured_response_format: dict | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
if isinstance(prompt, list):
prompt = [
@@ -313,6 +314,11 @@ class DefaultMultiLLM(LLM):
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
**({"parallel_tool_calls": False} if tools else {}),
+ **(
+ {"response_format": structured_response_format}
+ if structured_response_format
+ else {}
+ ),
**self._model_kwargs,
)
except Exception as e:
@@ -336,12 +342,16 @@ class DefaultMultiLLM(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
+ structured_response_format: dict | None = None,
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
response = cast(
- litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
+ litellm.ModelResponse,
+ self._completion(
+ prompt, tools, tool_choice, False, structured_response_format
+ ),
)
choice = response.choices[0]
if hasattr(choice, "message"):
@@ -354,18 +364,21 @@ class DefaultMultiLLM(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
+ structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
if DISABLE_LITELLM_STREAMING:
- yield self.invoke(prompt)
+ yield self.invoke(prompt, tools, tool_choice, structured_response_format)
return
output = None
response = cast(
litellm.CustomStreamWrapper,
- self._completion(prompt, tools, tool_choice, True),
+ self._completion(
+ prompt, tools, tool_choice, True, structured_response_format
+ ),
)
try:
for part in response:
diff --git a/backend/danswer/llm/custom_llm.py b/backend/danswer/llm/custom_llm.py
index 4a5ba7857c..6b80406cf2 100644
--- a/backend/danswer/llm/custom_llm.py
+++ b/backend/danswer/llm/custom_llm.py
@@ -80,6 +80,7 @@ class CustomModelServer(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
+ structured_response_format: dict | None = None,
) -> BaseMessage:
return self._execute(prompt)
@@ -88,5 +89,6 @@ class CustomModelServer(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
+ structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
yield self._execute(prompt)
diff --git a/backend/danswer/llm/interfaces.py b/backend/danswer/llm/interfaces.py
index 6cb58e46c6..7deee11dfa 100644
--- a/backend/danswer/llm/interfaces.py
+++ b/backend/danswer/llm/interfaces.py
@@ -88,11 +88,14 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
+ structured_response_format: dict | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
- return self._invoke_implementation(prompt, tools, tool_choice)
+ return self._invoke_implementation(
+ prompt, tools, tool_choice, structured_response_format
+ )
@abc.abstractmethod
def _invoke_implementation(
@@ -100,6 +103,7 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
+ structured_response_format: dict | None = None,
) -> BaseMessage:
raise NotImplementedError
@@ -108,11 +112,14 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
+ structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
- return self._stream_implementation(prompt, tools, tool_choice)
+ return self._stream_implementation(
+ prompt, tools, tool_choice, structured_response_format
+ )
@abc.abstractmethod
def _stream_implementation(
@@ -120,5 +127,6 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
+ structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError
diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py
index d16aa59c4c..1ceeb776ab 100644
--- a/backend/danswer/server/manage/administrative.py
+++ b/backend/danswer/server/manage/administrative.py
@@ -19,7 +19,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import (
update_connector_credential_pair_from_id,
)
-from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import ConnectorCredentialPairStatus
@@ -175,15 +174,19 @@ def create_deletion_attempt_for_connector_id(
cc_pair_id=cc_pair.id, db_session=db_session, include_secondary_index=True
)
+ # TODO(rkuo): 2024-10-24 - check_deletion_attempt_is_allowed shouldn't be necessary
+ # any more due to background locking improvements.
+ # Remove the below permanently if everything is behaving for 30 days.
+
# Check if the deletion attempt should be allowed
- deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
- connector_credential_pair=cc_pair, db_session=db_session
- )
- if deletion_attempt_disallowed_reason:
- raise HTTPException(
- status_code=400,
- detail=deletion_attempt_disallowed_reason,
- )
+ # deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
+ # connector_credential_pair=cc_pair, db_session=db_session
+ # )
+ # if deletion_attempt_disallowed_reason:
+ # raise HTTPException(
+ # status_code=400,
+ # detail=deletion_attempt_disallowed_reason,
+ # )
# mark as deleting
update_connector_credential_pair_from_id(
diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py
index 42f4100a24..1ca14f9283 100644
--- a/backend/danswer/server/query_and_chat/models.py
+++ b/backend/danswer/server/query_and_chat/models.py
@@ -108,6 +108,10 @@ class CreateChatMessageRequest(ChunkContext):
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False
+ # forces the LLM to return a structured response, see
+ # https://platform.openai.com/docs/guides/structured-outputs/introduction
+ structured_response_format: dict | None = None
+
@model_validator(mode="after")
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
if self.search_doc_ids is None and self.retrieval_options is None:
diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py
index dd637dcf08..b25ed8357d 100644
--- a/backend/ee/danswer/server/query_and_chat/chat_backend.py
+++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py
@@ -176,6 +176,7 @@ def handle_simplified_chat_message(
chunks_above=0,
chunks_below=0,
full_doc=chat_message_req.full_doc,
+ structured_response_format=chat_message_req.structured_response_format,
)
packets = stream_chat_message_objects(
@@ -202,7 +203,7 @@ def handle_send_message_simple_with_history(
raise HTTPException(status_code=400, detail="Messages cannot be zero length")
# This is a sanity check to make sure the chat history is valid
- # It must start with a user message and alternate between user and assistant
+ # It must start with a user message and alternate beteen user and assistant
expected_role = MessageType.USER
for msg in req.messages:
if not msg.message:
@@ -296,6 +297,7 @@ def handle_send_message_simple_with_history(
chunks_above=0,
chunks_below=0,
full_doc=req.full_doc,
+ structured_response_format=req.structured_response_format,
)
packets = stream_chat_message_objects(
diff --git a/backend/ee/danswer/server/query_and_chat/models.py b/backend/ee/danswer/server/query_and_chat/models.py
index 052be683e9..4baf17ac8c 100644
--- a/backend/ee/danswer/server/query_and_chat/models.py
+++ b/backend/ee/danswer/server/query_and_chat/models.py
@@ -48,6 +48,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
query_override: str | None = None
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None = None
+ # only works if using an OpenAI model. See the following for more details:
+ # https://platform.openai.com/docs/guides/structured-outputs/introduction
+ structured_response_format: dict | None = None
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
@@ -60,6 +63,9 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
skip_rerank: bool | None = None
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None = None
+ # only works if using an OpenAI model. See the following for more details:
+ # https://platform.openai.com/docs/guides/structured-outputs/introduction
+ structured_response_format: dict | None = None
class SimpleDoc(BaseModel):
diff --git a/backend/scripts/add_connector_creation_script.py b/backend/scripts/add_connector_creation_script.py
new file mode 100644
index 0000000000..9a1944080c
--- /dev/null
+++ b/backend/scripts/add_connector_creation_script.py
@@ -0,0 +1,148 @@
+from typing import Any
+from typing import Dict
+
+import requests
+
+API_SERVER_URL = "http://localhost:3000" # Adjust this to your Danswer server URL
+HEADERS = {"Content-Type": "application/json"}
+API_KEY = "danswer-api-key" # API key here, if auth is enabled
+
+
+def create_connector(
+ name: str,
+ source: str,
+ input_type: str,
+ connector_specific_config: Dict[str, Any],
+ is_public: bool = True,
+ groups: list[int] | None = None,
+) -> Dict[str, Any]:
+ connector_update_request = {
+ "name": name,
+ "source": source,
+ "input_type": input_type,
+ "connector_specific_config": connector_specific_config,
+ "is_public": is_public,
+ "groups": groups or [],
+ }
+
+ response = requests.post(
+ url=f"{API_SERVER_URL}/api/manage/admin/connector",
+ json=connector_update_request,
+ headers=HEADERS,
+ )
+ response.raise_for_status()
+ return response.json()
+
+
+def create_credential(
+ name: str,
+ source: str,
+ credential_json: Dict[str, Any],
+ is_public: bool = True,
+ groups: list[int] | None = None,
+) -> Dict[str, Any]:
+ credential_request = {
+ "name": name,
+ "source": source,
+ "credential_json": credential_json,
+ "admin_public": is_public,
+ "groups": groups or [],
+ }
+
+ response = requests.post(
+ url=f"{API_SERVER_URL}/api/manage/credential",
+ json=credential_request,
+ headers=HEADERS,
+ )
+ response.raise_for_status()
+ return response.json()
+
+
+def create_cc_pair(
+ connector_id: int,
+ credential_id: int,
+ name: str,
+ access_type: str = "public",
+ groups: list[int] | None = None,
+) -> Dict[str, Any]:
+ cc_pair_request = {
+ "name": name,
+ "access_type": access_type,
+ "groups": groups or [],
+ }
+
+ response = requests.put(
+ url=f"{API_SERVER_URL}/api/manage/connector/{connector_id}/credential/{credential_id}",
+ json=cc_pair_request,
+ headers=HEADERS,
+ )
+ response.raise_for_status()
+ return response.json()
+
+
+def main() -> None:
+ # Create a Web connector
+ web_connector = create_connector(
+ name="Example Web Connector",
+ source="web",
+ input_type="load_state",
+ connector_specific_config={
+ "base_url": "https://example.com",
+ "web_connector_type": "recursive",
+ },
+ )
+ print(f"Created Web Connector: {web_connector}")
+
+ # Create a credential for the Web connector
+ web_credential = create_credential(
+ name="Example Web Credential",
+ source="web",
+ credential_json={}, # Web connectors typically don't need credentials
+ is_public=True,
+ )
+ print(f"Created Web Credential: {web_credential}")
+
+ # Create CC pair for Web connector
+ web_cc_pair = create_cc_pair(
+ connector_id=web_connector["id"],
+ credential_id=web_credential["id"],
+ name="Example Web CC Pair",
+ access_type="public",
+ )
+ print(f"Created Web CC Pair: {web_cc_pair}")
+
+ # Create a GitHub connector
+ github_connector = create_connector(
+ name="Example GitHub Connector",
+ source="github",
+ input_type="poll",
+ connector_specific_config={
+ "repo_owner": "example-owner",
+ "repo_name": "example-repo",
+ "include_prs": True,
+ "include_issues": True,
+ },
+ )
+ print(f"Created GitHub Connector: {github_connector}")
+
+ # Create a credential for the GitHub connector
+ github_credential = create_credential(
+ name="Example GitHub Credential",
+ source="github",
+ credential_json={"github_access_token": "your_github_access_token_here"},
+ is_public=True,
+ )
+ print(f"Created GitHub Credential: {github_credential}")
+
+ # Create CC pair for GitHub connector
+ github_cc_pair = create_cc_pair(
+ connector_id=github_connector["id"],
+ credential_id=github_credential["id"],
+ name="Example GitHub CC Pair",
+ access_type="public",
+ )
+ print(f"Created GitHub CC Pair: {github_cc_pair}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py
index 77d9e0e702..f3d194e22b 100644
--- a/backend/tests/integration/conftest.py
+++ b/backend/tests/integration/conftest.py
@@ -6,7 +6,9 @@ from sqlalchemy.orm import Session
from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
+from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
+from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
@@ -44,3 +46,11 @@ def vespa_client(db_session: Session) -> vespa_fixture:
@pytest.fixture
def reset() -> None:
reset_all()
+
+
+@pytest.fixture
+def new_admin_user(reset: None) -> DATestUser | None:
+ try:
+ return UserManager.create(name="admin_user")
+ except Exception:
+ return None
diff --git a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py
index 0a4e7b40b5..c37d1a6235 100644
--- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py
+++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py
@@ -1,7 +1,10 @@
+import json
+
import requests
from danswer.configs.constants import MessageType
from tests.integration.common_utils.constants import API_SERVER_URL
+from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
@@ -145,3 +148,85 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
# This ensures the the document we think we are referencing when we send the search_doc_ids in the second
# message is the document that we expect it to be
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
+
+
+def test_send_message_simple_with_history_strict_json(
+ new_admin_user: DATestUser | None,
+) -> None:
+ # create connectors
+ LLMProviderManager.create(user_performing_action=new_admin_user)
+
+ response = requests.post(
+ f"{API_SERVER_URL}/chat/send-message-simple-with-history",
+ json={
+ # intentionally not relevant prompt to ensure that the
+ # structured response format is actually used
+ "messages": [
+ {
+ "message": "What is green?",
+ "role": MessageType.USER.value,
+ }
+ ],
+ "persona_id": 0,
+ "prompt_id": 0,
+ "structured_response_format": {
+ "type": "json_schema",
+ "json_schema": {
+ "name": "presidents",
+ "schema": {
+ "type": "object",
+ "properties": {
+ "presidents": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": "List of the first three US presidents",
+ }
+ },
+ "required": ["presidents"],
+ "additionalProperties": False,
+ },
+ "strict": True,
+ },
+ },
+ },
+ headers=new_admin_user.headers if new_admin_user else GENERAL_HEADERS,
+ )
+ assert response.status_code == 200
+
+ response_json = response.json()
+
+ # Check that the answer is present
+ assert "answer" in response_json
+ assert response_json["answer"] is not None
+
+ # helper
+ def clean_json_string(json_string: str) -> str:
+ return json_string.strip().removeprefix("```json").removesuffix("```").strip()
+
+ # Attempt to parse the answer as JSON
+ try:
+ clean_answer = clean_json_string(response_json["answer"])
+ parsed_answer = json.loads(clean_answer)
+
+ # NOTE: do not check content, just the structure
+ assert isinstance(parsed_answer, dict)
+ assert "presidents" in parsed_answer
+ assert isinstance(parsed_answer["presidents"], list)
+ for president in parsed_answer["presidents"]:
+ assert isinstance(president, str)
+ except json.JSONDecodeError:
+ assert (
+ False
+ ), f"The answer is not a valid JSON object - '{response_json['answer']}'"
+
+ # Check that the answer_citationless is also valid JSON
+ assert "answer_citationless" in response_json
+ assert response_json["answer_citationless"] is not None
+ try:
+ clean_answer_citationless = clean_json_string(
+ response_json["answer_citationless"]
+ )
+ parsed_answer_citationless = json.loads(clean_answer_citationless)
+ assert isinstance(parsed_answer_citationless, dict)
+ except json.JSONDecodeError:
+ assert False, "The answer_citationless is not a valid JSON object"
diff --git a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx
index cf9d584439..8229d19967 100644
--- a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx
+++ b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx
@@ -173,7 +173,6 @@ export function ProviderCreationModal({
return (
Failed to load prompt data Loading...
{isPublic ? "Public Assistant" : "Make Assistant Public"}
@@ -67,6 +67,6 @@ export function MakePublicAssistantModal({
@@ -203,6 +200,6 @@ export function SetDefaultModelModal({
@@ -154,6 +154,6 @@ export function ShareChatSessionModal({
)}