mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-22 14:00:57 +02:00
Merge branch 'main' of https://github.com/danswer-ai/danswer into feature/reset_indexes
This commit is contained in:
commit
0ed77aa8a7
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@ -26,4 +26,4 @@ N/A
|
||||
|
||||
## Backporting (check the box to trigger backport action)
|
||||
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
|
||||
[ ] This PR should be backported (make sure to check that the backport attempt succeeds)
|
||||
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
|
||||
|
66
.github/workflows/pr-backport-autotrigger.yml
vendored
66
.github/workflows/pr-backport-autotrigger.yml
vendored
@ -10,43 +10,83 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0 # Fetch all history for all branches and tags
|
||||
|
||||
- name: Set up Git
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Check for Backport Checkbox
|
||||
id: checkbox-check
|
||||
run: |
|
||||
PR_BODY="${{ github.event.pull_request.body }}"
|
||||
if [[ "$PR_BODY" == *"[x] This PR should be backported"* ]]; then
|
||||
echo "::set-output name=backport::true"
|
||||
echo "backport=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "::set-output name=backport::false"
|
||||
echo "backport=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: List and sort release branches
|
||||
id: list-branches
|
||||
run: |
|
||||
git fetch --all
|
||||
BRANCHES=$(git branch -r | grep 'origin/release/v' | sed 's|origin/release/v||' | sort -Vr)
|
||||
git fetch --all --tags
|
||||
BRANCHES=$(git for-each-ref --format='%(refname:short)' refs/remotes/origin/release/* | sed 's|origin/release/||' | sort -Vr)
|
||||
BETA=$(echo "$BRANCHES" | head -n 1)
|
||||
STABLE=$(echo "$BRANCHES" | head -n 2 | tail -n 1)
|
||||
echo "::set-output name=beta::$BETA"
|
||||
echo "::set-output name=stable::$STABLE"
|
||||
echo "beta=$BETA" >> $GITHUB_OUTPUT
|
||||
echo "stable=$STABLE" >> $GITHUB_OUTPUT
|
||||
# Fetch latest tags for beta and stable
|
||||
LATEST_BETA_TAG=$(git tag -l "v*.*.0-beta.*" | sort -Vr | head -n 1)
|
||||
LATEST_STABLE_TAG=$(git tag -l "v*.*.*" | grep -v -- "-beta" | sort -Vr | head -n 1)
|
||||
# Increment latest beta tag
|
||||
NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 ".0-beta." ($NF+1)}')
|
||||
# Increment latest stable tag
|
||||
NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}')
|
||||
echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
echo "latest_stable_tag=$LATEST_STABLE_TAG" >> $GITHUB_OUTPUT
|
||||
echo "new_beta_tag=$NEW_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
echo "new_stable_tag=$NEW_STABLE_TAG" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Echo branch and tag information
|
||||
run: |
|
||||
echo "Beta branch: ${{ steps.list-branches.outputs.beta }}"
|
||||
echo "Stable branch: ${{ steps.list-branches.outputs.stable }}"
|
||||
echo "Latest beta tag: ${{ steps.list-branches.outputs.latest_beta_tag }}"
|
||||
echo "Latest stable tag: ${{ steps.list-branches.outputs.latest_stable_tag }}"
|
||||
echo "New beta tag: ${{ steps.list-branches.outputs.new_beta_tag }}"
|
||||
echo "New stable tag: ${{ steps.list-branches.outputs.new_stable_tag }}"
|
||||
|
||||
- name: Trigger Backport
|
||||
if: steps.checkbox-check.outputs.backport == 'true'
|
||||
run: |
|
||||
set -e
|
||||
echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}"
|
||||
# Fetch all history for all branches and tags
|
||||
git fetch --prune --unshallow
|
||||
# Checkout the beta branch
|
||||
git checkout ${{ steps.list-branches.outputs.beta }}
|
||||
# Cherry-pick the last commit from the merged PR
|
||||
git cherry-pick ${{ github.event.pull_request.merge_commit_sha }}
|
||||
# Push the changes to the beta branch
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to beta failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
# Create new beta tag
|
||||
git tag ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
# Push the changes and tag to the beta branch
|
||||
git push origin ${{ steps.list-branches.outputs.beta }}
|
||||
git push origin ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
# Checkout the stable branch
|
||||
git checkout ${{ steps.list-branches.outputs.stable }}
|
||||
# Cherry-pick the last commit from the merged PR
|
||||
git cherry-pick ${{ github.event.pull_request.merge_commit_sha }}
|
||||
# Push the changes to the stable branch
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to stable failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
# Create new stable tag
|
||||
git tag ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
# Push the changes and tag to the stable branch
|
||||
git push origin ${{ steps.list-branches.outputs.stable }}
|
||||
git push origin ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
|
@ -31,6 +31,12 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# First, update any null values to a default value
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL"
|
||||
)
|
||||
|
||||
# Then, make the column non-nullable
|
||||
op.alter_column(
|
||||
"connector_credential_pair",
|
||||
"last_attempt_status",
|
||||
|
@ -17,6 +17,7 @@ from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
@ -161,6 +162,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
# @worker_process_init.connect
|
||||
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
|
||||
|
@ -313,6 +313,8 @@ class RedisConnectorDeletion(RedisObjectHelper):
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns None if the cc_pair doesn't exist.
|
||||
Otherwise, returns an int with the number of generated tasks."""
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
@ -540,6 +542,29 @@ class RedisConnectorIndexing(RedisObjectHelper):
|
||||
return False
|
||||
|
||||
|
||||
class RedisConnectorStop(RedisObjectHelper):
|
||||
"""Used to signal any running tasks for a connector to stop. We should refactor
|
||||
connector related redis helpers into a single class.
|
||||
"""
|
||||
|
||||
PREFIX = "connectorstop"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
|
@ -1,4 +1,3 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@ -6,6 +5,7 @@ from typing import Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
@ -79,7 +79,7 @@ def document_batch_to_ids(
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
runnable_connector: BaseConnector,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
If the PruneConnector hasnt been implemented for the given connector, just pull
|
||||
@ -110,8 +110,10 @@ def extract_ids_from_runnable_connector(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(document_batch_to_ids)
|
||||
for doc_batch in doc_batch_generator:
|
||||
if progress_callback:
|
||||
progress_callback(len(doc_batch))
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("Stop signal received")
|
||||
callback.progress(len(doc_batch))
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
|
||||
return all_connector_doc_ids
|
||||
|
@ -1,3 +1,6 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
@ -8,6 +11,12 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
|
||||
RedisConnectorDeletionFenceData,
|
||||
)
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
@ -15,9 +24,15 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.search_settings import get_all_search_settings
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
class TaskDependencyError(RuntimeError):
|
||||
"""Raised to the caller to indicate dependent tasks are running that would interfere
|
||||
with connector deletion."""
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
@ -37,17 +52,30 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair.id)
|
||||
|
||||
# try running cleanup on the cc_pair_ids
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
try:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
except TaskDependencyError as e:
|
||||
# this means we wanted to start deleting but dependent tasks were running
|
||||
# Leave a stop signal to clear indexing and pruning tasks more quickly
|
||||
task_logger.info(str(e))
|
||||
r.set(rcs.fence_key, cc_pair_id)
|
||||
else:
|
||||
# clear the stop signal if it exists ... no longer needed
|
||||
r.delete(rcs.fence_key)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@ -70,6 +98,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
Returns None if no syncing is required.
|
||||
|
||||
Will raise TaskDependencyError if dependent tasks such as indexing and pruning are
|
||||
still running. In our case, the caller reacts by setting a stop signal in Redis to
|
||||
exit those tasks as quickly as possible.
|
||||
"""
|
||||
|
||||
lock_beat.reacquire()
|
||||
@ -90,28 +122,63 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rcd.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||
# set a basic fence to start
|
||||
fence_value = RedisConnectorDeletionFenceData(
|
||||
num_tasks=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
)
|
||||
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
|
||||
if tasks_generated is None:
|
||||
r.set(rcd.fence_key, fence_value.model_dump_json())
|
||||
|
||||
try:
|
||||
# do not proceed if connector indexing or connector pruning are running
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
|
||||
if r.get(rci.fence_key):
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
if r.get(rcp.fence_key):
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rcd.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
|
||||
if tasks_generated is None:
|
||||
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
|
||||
except TaskDependencyError:
|
||||
r.delete(rcd.fence_key)
|
||||
raise
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
r.delete(rcd.fence_key)
|
||||
return None
|
||||
else:
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
# set this only after all tasks have been added
|
||||
fence_value.num_tasks = tasks_generated
|
||||
r.set(rcd.fence_key, fence_value.model_dump_json())
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rcd.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
@ -5,6 +5,7 @@ from time import sleep
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
@ -13,12 +14,15 @@ from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
|
||||
RedisConnectorIndexingFenceData,
|
||||
)
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
@ -50,6 +54,30 @@ from danswer.utils.variable_functionality import global_version
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class RunIndexingCallback(RunIndexingCallbackInterface):
|
||||
def __init__(
|
||||
self,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_lock: redis.lock.Lock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.redis_lock: redis.lock.Lock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
self.redis_client = redis_client
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_client.exists(self.stop_key):
|
||||
return True
|
||||
return False
|
||||
|
||||
def progress(self, amount: int) -> None:
|
||||
self.redis_lock.reacquire()
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_indexing",
|
||||
soft_time_limit=300,
|
||||
@ -262,6 +290,10 @@ def try_creating_indexing_task(
|
||||
return None
|
||||
|
||||
# skip indexing if the cc_pair is deleting
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
@ -308,13 +340,8 @@ def try_creating_indexing_task(
|
||||
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
|
||||
|
||||
# now fill out the fence with the rest of the data
|
||||
fence_value = RedisConnectorIndexingFenceData(
|
||||
index_attempt_id=index_attempt_id,
|
||||
started=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
celery_task_id=result.id,
|
||||
)
|
||||
|
||||
fence_value.index_attempt_id = index_attempt_id
|
||||
fence_value.celery_task_id = result.id
|
||||
r.set(rci.fence_key, fence_value.model_dump_json())
|
||||
except Exception:
|
||||
r.delete(rci.fence_key)
|
||||
@ -400,6 +427,22 @@ def connector_indexing_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
if r.exists(rcd.fence_key):
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because connector deletion is in progress: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={rcd.fence_key}"
|
||||
)
|
||||
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
if r.exists(rcs.fence_key):
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because a connector stop signal was detected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={rcs.fence_key}"
|
||||
)
|
||||
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
while True:
|
||||
@ -409,7 +452,7 @@ def connector_indexing_task(
|
||||
task_logger.info(
|
||||
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
|
||||
)
|
||||
raise
|
||||
raise RuntimeError(f"Fence not found: fence={rci.fence_key}")
|
||||
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
@ -443,17 +486,20 @@ def connector_indexing_task(
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}"
|
||||
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
)
|
||||
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
|
||||
return None
|
||||
|
||||
fence_data.started = datetime.now(timezone.utc)
|
||||
r.set(rci.fence_key, fence_data.model_dump_json())
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
raise ValueError(
|
||||
f"Index attempt not found: index_attempt_id={index_attempt_id}"
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
@ -462,31 +508,31 @@ def connector_indexing_task(
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}")
|
||||
raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
|
||||
|
||||
if not cc_pair.connector:
|
||||
raise ValueError(
|
||||
f"Connector not found: connector_id={cc_pair.connector_id}"
|
||||
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
|
||||
)
|
||||
|
||||
if not cc_pair.credential:
|
||||
raise ValueError(
|
||||
f"Credential not found: credential_id={cc_pair.credential_id}"
|
||||
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
# Define the callback function
|
||||
def redis_increment_callback(amount: int) -> None:
|
||||
lock.reacquire()
|
||||
r.incrby(rci.generator_progress_key, amount)
|
||||
# define a callback class
|
||||
callback = RunIndexingCallback(
|
||||
rcs.fence_key, rci.generator_progress_key, lock, r
|
||||
)
|
||||
|
||||
run_indexing_entrypoint(
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
cc_pair_id,
|
||||
is_ee,
|
||||
progress_callback=redis_increment_callback,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# get back the total number of indexed docs and return it
|
||||
@ -499,9 +545,10 @@ def connector_indexing_task(
|
||||
|
||||
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.")
|
||||
task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
|
||||
if attempt:
|
||||
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
|
||||
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
|
@ -11,8 +11,11 @@ from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
|
||||
@ -168,6 +171,10 @@ def try_creating_prune_generator_task(
|
||||
return None
|
||||
|
||||
# skip pruning if the cc_pair is deleting
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
@ -234,7 +241,7 @@ def connector_pruning_generator_task(
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}"
|
||||
f"Pruning task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@ -252,11 +259,6 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
return
|
||||
|
||||
# Define the callback function
|
||||
def redis_increment_callback(amount: int) -> None:
|
||||
lock.reacquire()
|
||||
r.incrby(rcp.generator_progress_key, amount)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
cc_pair.connector.source,
|
||||
@ -265,9 +267,14 @@ def connector_pruning_generator_task(
|
||||
cc_pair.credential,
|
||||
)
|
||||
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
|
||||
callback = RunIndexingCallback(
|
||||
rcs.fence_key, rcp.generator_progress_key, lock, r
|
||||
)
|
||||
# a list of docs in the source
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
runnable_connector, redis_increment_callback
|
||||
runnable_connector, callback
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
@ -285,7 +292,7 @@ def connector_pruning_generator_task(
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning set collected: "
|
||||
f"cc_pair_id={cc_pair.id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)} "
|
||||
f"doc_source={cc_pair.connector.source}"
|
||||
)
|
||||
@ -293,7 +300,7 @@ def connector_pruning_generator_task(
|
||||
rcp.documents_to_prune = set(doc_ids_to_remove)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = rcp.generate_tasks(
|
||||
self.app, db_session, r, None, tenant_id
|
||||
@ -303,12 +310,14 @@ def connector_pruning_generator_task(
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorPruning.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
r.set(rcp.generator_complete_key, tasks_generated)
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Failed to run pruning for connector id {connector_id}.")
|
||||
task_logger.exception(
|
||||
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
|
||||
)
|
||||
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
|
@ -0,0 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RedisConnectorDeletionFenceData(BaseModel):
|
||||
num_tasks: int | None
|
||||
submitted: datetime
|
@ -1,9 +1,6 @@
|
||||
from datetime import datetime
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@ -23,13 +20,6 @@ from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
|
||||
|
||||
|
||||
class RedisConnectorIndexingFenceData(BaseModel):
|
||||
index_attempt_id: int | None
|
||||
started: datetime | None
|
||||
submitted: datetime
|
||||
celery_task_id: str | None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
bind=True,
|
||||
|
@ -23,6 +23,9 @@ from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
|
||||
RedisConnectorDeletionFenceData,
|
||||
)
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
|
||||
RedisConnectorIndexingFenceData,
|
||||
)
|
||||
@ -368,7 +371,7 @@ def monitor_document_set_taskset(
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
task_logger.info(
|
||||
f"Document set sync progress: document_set_id={document_set_id} "
|
||||
f"Document set sync progress: document_set={document_set_id} "
|
||||
f"remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
@ -383,12 +386,12 @@ def monitor_document_set_taskset(
|
||||
# if there are no connectors, then delete the document set.
|
||||
delete_document_set(document_set_row=document_set, db_session=db_session)
|
||||
task_logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
f"Successfully deleted document set: document_set={document_set_id}"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(document_set_id, db_session)
|
||||
task_logger.info(
|
||||
f"Successfully synced document set with ID: '{document_set_id}'!"
|
||||
f"Successfully synced document set: document_set={document_set_id}"
|
||||
)
|
||||
|
||||
r.delete(rds.taskset_key)
|
||||
@ -408,19 +411,29 @@ def monitor_connector_deletion_taskset(
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
fence_value = r.get(rcd.fence_key)
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rcd.fence_key))
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
task_logger.exception(
|
||||
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
|
||||
)
|
||||
raise
|
||||
|
||||
# the fence is setting up but isn't ready yet
|
||||
if fence_data.num_tasks is None:
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rcd.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
@ -483,7 +496,7 @@ def monitor_connector_deletion_taskset(
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Found no credentials left for connector, deleting connector"
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
@ -493,17 +506,17 @@ def monitor_connector_deletion_taskset(
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair_id, error_message)
|
||||
task_logger.exception(
|
||||
f"Failed to run connector_deletion. "
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Successfully deleted cc_pair: "
|
||||
f"Connector deletion succeeded: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"docs_deleted={initial_count}"
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
)
|
||||
|
||||
r.delete(rcd.taskset_key)
|
||||
@ -618,6 +631,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
return
|
||||
|
||||
# Read result state BEFORE generator_complete_key to avoid a race condition
|
||||
# never use any blocking methods on the result from inside a task!
|
||||
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
|
||||
result_state = result.state
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@ -41,6 +42,19 @@ logger = setup_logger()
|
||||
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
|
||||
|
||||
|
||||
class RunIndexingCallbackInterface(ABC):
|
||||
"""Defines a callback interface to be passed to
|
||||
to run_indexing_entrypoint."""
|
||||
|
||||
@abstractmethod
|
||||
def should_stop(self) -> bool:
|
||||
"""Signal to stop the looping function in flight."""
|
||||
|
||||
@abstractmethod
|
||||
def progress(self, amount: int) -> None:
|
||||
"""Send progress updates to the caller."""
|
||||
|
||||
|
||||
def _get_connector_runner(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
@ -92,7 +106,7 @@ def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
tenant_id: str | None,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
@ -206,6 +220,11 @@ def _run_indexing(
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
# contents still need to be initially pulled.
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("Connector stop signal detected")
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
db_session.refresh(db_cc_pair)
|
||||
if (
|
||||
(
|
||||
@ -263,8 +282,8 @@ def _run_indexing(
|
||||
# be inaccurate
|
||||
db_session.commit()
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(len(doc_batch))
|
||||
if callback:
|
||||
callback.progress(len(doc_batch))
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
@ -394,7 +413,7 @@ def run_indexing_entrypoint(
|
||||
tenant_id: str | None,
|
||||
connector_credential_pair_id: int,
|
||||
is_ee: bool = False,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
if is_ee:
|
||||
@ -417,7 +436,7 @@ def run_indexing_entrypoint(
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
|
||||
_run_indexing(db_session, attempt, tenant_id, progress_callback)
|
||||
_run_indexing(db_session, attempt, tenant_id, callback)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished for tenant {tenant_id}: "
|
||||
|
@ -672,6 +672,7 @@ def stream_chat_message_objects(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
),
|
||||
prompt_config=prompt_config,
|
||||
llm=(
|
||||
|
@ -237,6 +237,7 @@ class Answer:
|
||||
prompt=prompt,
|
||||
tools=final_tool_definitions if final_tool_definitions else None,
|
||||
tool_choice="required" if self.force_use_tool.force_use else None,
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
@ -331,7 +332,10 @@ class Answer:
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
) -> Iterator[str | StreamStopInfo]:
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt, tools=tools, tool_choice=tool_choice
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk):
|
||||
if message.content:
|
||||
|
@ -116,6 +116,10 @@ class AnswerStyleConfig(BaseModel):
|
||||
document_pruning_config: DocumentPruningConfig = Field(
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
# forces the LLM to return a structured response, see
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
# right now, only used by the simple chat API
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
|
||||
|
@ -280,6 +280,7 @@ class DefaultMultiLLM(LLM):
|
||||
tools: list[dict] | None,
|
||||
tool_choice: ToolChoiceOptions | None,
|
||||
stream: bool,
|
||||
structured_response_format: dict | None = None,
|
||||
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
|
||||
if isinstance(prompt, list):
|
||||
prompt = [
|
||||
@ -313,6 +314,11 @@ class DefaultMultiLLM(LLM):
|
||||
# NOTE: we can't pass this in if tools are not specified
|
||||
# or else OpenAI throws an error
|
||||
**({"parallel_tool_calls": False} if tools else {}),
|
||||
**(
|
||||
{"response_format": structured_response_format}
|
||||
if structured_response_format
|
||||
else {}
|
||||
),
|
||||
**self._model_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
@ -336,12 +342,16 @@ class DefaultMultiLLM(LLM):
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
) -> BaseMessage:
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
response = cast(
|
||||
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
|
||||
litellm.ModelResponse,
|
||||
self._completion(
|
||||
prompt, tools, tool_choice, False, structured_response_format
|
||||
),
|
||||
)
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "message"):
|
||||
@ -354,18 +364,21 @@ class DefaultMultiLLM(LLM):
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
if DISABLE_LITELLM_STREAMING:
|
||||
yield self.invoke(prompt)
|
||||
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
|
||||
return
|
||||
|
||||
output = None
|
||||
response = cast(
|
||||
litellm.CustomStreamWrapper,
|
||||
self._completion(prompt, tools, tool_choice, True),
|
||||
self._completion(
|
||||
prompt, tools, tool_choice, True, structured_response_format
|
||||
),
|
||||
)
|
||||
try:
|
||||
for part in response:
|
||||
|
@ -80,6 +80,7 @@ class CustomModelServer(LLM):
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
) -> BaseMessage:
|
||||
return self._execute(prompt)
|
||||
|
||||
@ -88,5 +89,6 @@ class CustomModelServer(LLM):
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
yield self._execute(prompt)
|
||||
|
@ -88,11 +88,14 @@ class LLM(abc.ABC):
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
) -> BaseMessage:
|
||||
self._precall(prompt)
|
||||
# TODO add a postcall to log model outputs independent of concrete class
|
||||
# implementation
|
||||
return self._invoke_implementation(prompt, tools, tool_choice)
|
||||
return self._invoke_implementation(
|
||||
prompt, tools, tool_choice, structured_response_format
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _invoke_implementation(
|
||||
@ -100,6 +103,7 @@ class LLM(abc.ABC):
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
) -> BaseMessage:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -108,11 +112,14 @@ class LLM(abc.ABC):
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
self._precall(prompt)
|
||||
# TODO add a postcall to log model outputs independent of concrete class
|
||||
# implementation
|
||||
return self._stream_implementation(prompt, tools, tool_choice)
|
||||
return self._stream_implementation(
|
||||
prompt, tools, tool_choice, structured_response_format
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _stream_implementation(
|
||||
@ -120,5 +127,6 @@ class LLM(abc.ABC):
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
raise NotImplementedError
|
||||
|
@ -19,7 +19,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import (
|
||||
update_connector_credential_pair_from_id,
|
||||
)
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.engine import get_current_tenant_id
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
@ -175,15 +174,19 @@ def create_deletion_attempt_for_connector_id(
|
||||
cc_pair_id=cc_pair.id, db_session=db_session, include_secondary_index=True
|
||||
)
|
||||
|
||||
# TODO(rkuo): 2024-10-24 - check_deletion_attempt_is_allowed shouldn't be necessary
|
||||
# any more due to background locking improvements.
|
||||
# Remove the below permanently if everything is behaving for 30 days.
|
||||
|
||||
# Check if the deletion attempt should be allowed
|
||||
deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
|
||||
connector_credential_pair=cc_pair, db_session=db_session
|
||||
)
|
||||
if deletion_attempt_disallowed_reason:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=deletion_attempt_disallowed_reason,
|
||||
)
|
||||
# deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
|
||||
# connector_credential_pair=cc_pair, db_session=db_session
|
||||
# )
|
||||
# if deletion_attempt_disallowed_reason:
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=deletion_attempt_disallowed_reason,
|
||||
# )
|
||||
|
||||
# mark as deleting
|
||||
update_connector_credential_pair_from_id(
|
||||
|
@ -108,6 +108,10 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# used for seeded chats to kick off the generation of an AI answer
|
||||
use_existing_user_message: bool = False
|
||||
|
||||
# forces the LLM to return a structured response, see
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
|
||||
if self.search_doc_ids is None and self.retrieval_options is None:
|
||||
|
@ -176,6 +176,7 @@ def handle_simplified_chat_message(
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
full_doc=chat_message_req.full_doc,
|
||||
structured_response_format=chat_message_req.structured_response_format,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
@ -202,7 +203,7 @@ def handle_send_message_simple_with_history(
|
||||
raise HTTPException(status_code=400, detail="Messages cannot be zero length")
|
||||
|
||||
# This is a sanity check to make sure the chat history is valid
|
||||
# It must start with a user message and alternate between user and assistant
|
||||
# It must start with a user message and alternate beteen user and assistant
|
||||
expected_role = MessageType.USER
|
||||
for msg in req.messages:
|
||||
if not msg.message:
|
||||
@ -296,6 +297,7 @@ def handle_send_message_simple_with_history(
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
full_doc=req.full_doc,
|
||||
structured_response_format=req.structured_response_format,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
|
@ -48,6 +48,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
|
||||
query_override: str | None = None
|
||||
# If search_doc_ids provided, then retrieval options are unused
|
||||
search_doc_ids: list[int] | None = None
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
|
||||
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
@ -60,6 +63,9 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
skip_rerank: bool | None = None
|
||||
# If search_doc_ids provided, then retrieval options are unused
|
||||
search_doc_ids: list[int] | None = None
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
|
||||
class SimpleDoc(BaseModel):
|
||||
|
148
backend/scripts/add_connector_creation_script.py
Normal file
148
backend/scripts/add_connector_creation_script.py
Normal file
@ -0,0 +1,148 @@
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
import requests
|
||||
|
||||
API_SERVER_URL = "http://localhost:3000" # Adjust this to your Danswer server URL
|
||||
HEADERS = {"Content-Type": "application/json"}
|
||||
API_KEY = "danswer-api-key" # API key here, if auth is enabled
|
||||
|
||||
|
||||
def create_connector(
|
||||
name: str,
|
||||
source: str,
|
||||
input_type: str,
|
||||
connector_specific_config: Dict[str, Any],
|
||||
is_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
connector_update_request = {
|
||||
"name": name,
|
||||
"source": source,
|
||||
"input_type": input_type,
|
||||
"connector_specific_config": connector_specific_config,
|
||||
"is_public": is_public,
|
||||
"groups": groups or [],
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/api/manage/admin/connector",
|
||||
json=connector_update_request,
|
||||
headers=HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def create_credential(
|
||||
name: str,
|
||||
source: str,
|
||||
credential_json: Dict[str, Any],
|
||||
is_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
credential_request = {
|
||||
"name": name,
|
||||
"source": source,
|
||||
"credential_json": credential_json,
|
||||
"admin_public": is_public,
|
||||
"groups": groups or [],
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/api/manage/credential",
|
||||
json=credential_request,
|
||||
headers=HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def create_cc_pair(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
name: str,
|
||||
access_type: str = "public",
|
||||
groups: list[int] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
cc_pair_request = {
|
||||
"name": name,
|
||||
"access_type": access_type,
|
||||
"groups": groups or [],
|
||||
}
|
||||
|
||||
response = requests.put(
|
||||
url=f"{API_SERVER_URL}/api/manage/connector/{connector_id}/credential/{credential_id}",
|
||||
json=cc_pair_request,
|
||||
headers=HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# Create a Web connector
|
||||
web_connector = create_connector(
|
||||
name="Example Web Connector",
|
||||
source="web",
|
||||
input_type="load_state",
|
||||
connector_specific_config={
|
||||
"base_url": "https://example.com",
|
||||
"web_connector_type": "recursive",
|
||||
},
|
||||
)
|
||||
print(f"Created Web Connector: {web_connector}")
|
||||
|
||||
# Create a credential for the Web connector
|
||||
web_credential = create_credential(
|
||||
name="Example Web Credential",
|
||||
source="web",
|
||||
credential_json={}, # Web connectors typically don't need credentials
|
||||
is_public=True,
|
||||
)
|
||||
print(f"Created Web Credential: {web_credential}")
|
||||
|
||||
# Create CC pair for Web connector
|
||||
web_cc_pair = create_cc_pair(
|
||||
connector_id=web_connector["id"],
|
||||
credential_id=web_credential["id"],
|
||||
name="Example Web CC Pair",
|
||||
access_type="public",
|
||||
)
|
||||
print(f"Created Web CC Pair: {web_cc_pair}")
|
||||
|
||||
# Create a GitHub connector
|
||||
github_connector = create_connector(
|
||||
name="Example GitHub Connector",
|
||||
source="github",
|
||||
input_type="poll",
|
||||
connector_specific_config={
|
||||
"repo_owner": "example-owner",
|
||||
"repo_name": "example-repo",
|
||||
"include_prs": True,
|
||||
"include_issues": True,
|
||||
},
|
||||
)
|
||||
print(f"Created GitHub Connector: {github_connector}")
|
||||
|
||||
# Create a credential for the GitHub connector
|
||||
github_credential = create_credential(
|
||||
name="Example GitHub Credential",
|
||||
source="github",
|
||||
credential_json={"github_access_token": "your_github_access_token_here"},
|
||||
is_public=True,
|
||||
)
|
||||
print(f"Created GitHub Credential: {github_credential}")
|
||||
|
||||
# Create CC pair for GitHub connector
|
||||
github_cc_pair = create_cc_pair(
|
||||
connector_id=github_connector["id"],
|
||||
credential_id=github_credential["id"],
|
||||
name="Example GitHub CC Pair",
|
||||
access_type="public",
|
||||
)
|
||||
print(f"Created GitHub CC Pair: {github_cc_pair}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -6,7 +6,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
@ -44,3 +46,11 @@ def vespa_client(db_session: Session) -> vespa_fixture:
|
||||
@pytest.fixture
|
||||
def reset() -> None:
|
||||
reset_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def new_admin_user(reset: None) -> DATestUser | None:
|
||||
try:
|
||||
return UserManager.create(name="admin_user")
|
||||
except Exception:
|
||||
return None
|
||||
|
@ -1,7 +1,10 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
@ -145,3 +148,85 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
|
||||
# This ensures the the document we think we are referencing when we send the search_doc_ids in the second
|
||||
# message is the document that we expect it to be
|
||||
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
|
||||
|
||||
|
||||
def test_send_message_simple_with_history_strict_json(
|
||||
new_admin_user: DATestUser | None,
|
||||
) -> None:
|
||||
# create connectors
|
||||
LLMProviderManager.create(user_performing_action=new_admin_user)
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
# intentionally not relevant prompt to ensure that the
|
||||
# structured response format is actually used
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is green?",
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
"prompt_id": 0,
|
||||
"structured_response_format": {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "presidents",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"presidents": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of the first three US presidents",
|
||||
}
|
||||
},
|
||||
"required": ["presidents"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
headers=new_admin_user.headers if new_admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
# Check that the answer is present
|
||||
assert "answer" in response_json
|
||||
assert response_json["answer"] is not None
|
||||
|
||||
# helper
|
||||
def clean_json_string(json_string: str) -> str:
|
||||
return json_string.strip().removeprefix("```json").removesuffix("```").strip()
|
||||
|
||||
# Attempt to parse the answer as JSON
|
||||
try:
|
||||
clean_answer = clean_json_string(response_json["answer"])
|
||||
parsed_answer = json.loads(clean_answer)
|
||||
|
||||
# NOTE: do not check content, just the structure
|
||||
assert isinstance(parsed_answer, dict)
|
||||
assert "presidents" in parsed_answer
|
||||
assert isinstance(parsed_answer["presidents"], list)
|
||||
for president in parsed_answer["presidents"]:
|
||||
assert isinstance(president, str)
|
||||
except json.JSONDecodeError:
|
||||
assert (
|
||||
False
|
||||
), f"The answer is not a valid JSON object - '{response_json['answer']}'"
|
||||
|
||||
# Check that the answer_citationless is also valid JSON
|
||||
assert "answer_citationless" in response_json
|
||||
assert response_json["answer_citationless"] is not None
|
||||
try:
|
||||
clean_answer_citationless = clean_json_string(
|
||||
response_json["answer_citationless"]
|
||||
)
|
||||
parsed_answer_citationless = json.loads(clean_answer_citationless)
|
||||
assert isinstance(parsed_answer_citationless, dict)
|
||||
except json.JSONDecodeError:
|
||||
assert False, "The answer_citationless is not a valid JSON object"
|
||||
|
@ -173,7 +173,6 @@ export function ProviderCreationModal({
|
||||
|
||||
return (
|
||||
<Modal
|
||||
width="max-w-3xl"
|
||||
title={`Configure ${selectedProvider.provider_type}`}
|
||||
onOutsideClick={onCancel}
|
||||
icon={selectedProvider.icon}
|
||||
|
@ -1,12 +1,12 @@
|
||||
import React from "react";
|
||||
import { Formik, Form } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { ModalWrapper } from "@/components/modals/ModalWrapper";
|
||||
import { Button } from "@tremor/react";
|
||||
|
||||
import { BookstackIcon } from "@/components/icons/icons";
|
||||
import { AddPromptModalProps } from "../interfaces";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
import { Modal } from "@/components/Modal";
|
||||
|
||||
const AddPromptSchema = Yup.object().shape({
|
||||
title: Yup.string().required("Title is required"),
|
||||
@ -15,7 +15,7 @@ const AddPromptSchema = Yup.object().shape({
|
||||
|
||||
const AddPromptModal = ({ onClose, onSubmit }: AddPromptModalProps) => {
|
||||
return (
|
||||
<ModalWrapper onClose={onClose} modalClassName="max-w-xl">
|
||||
<Modal onOutsideClick={onClose} width="max-w-xl">
|
||||
<Formik
|
||||
initialValues={{
|
||||
title: "",
|
||||
@ -57,7 +57,7 @@ const AddPromptModal = ({ onClose, onSubmit }: AddPromptModalProps) => {
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</ModalWrapper>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
import React from "react";
|
||||
import { Formik, Form, Field, ErrorMessage } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { ModalWrapper } from "@/components/modals/ModalWrapper";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button, Textarea, TextInput } from "@tremor/react";
|
||||
|
||||
import { useInputPrompt } from "../hooks";
|
||||
import { EditPromptModalProps } from "../interfaces";
|
||||
|
||||
@ -25,20 +26,20 @@ const EditPromptModal = ({
|
||||
|
||||
if (error)
|
||||
return (
|
||||
<ModalWrapper onClose={onClose} modalClassName="max-w-xl">
|
||||
<Modal onOutsideClick={onClose} width="max-w-xl">
|
||||
<p>Failed to load prompt data</p>
|
||||
</ModalWrapper>
|
||||
</Modal>
|
||||
);
|
||||
|
||||
if (!promptData)
|
||||
return (
|
||||
<ModalWrapper onClose={onClose} modalClassName="max-w-xl">
|
||||
<Modal onOutsideClick={onClose} width="max-w-xl">
|
||||
<p>Loading...</p>
|
||||
</ModalWrapper>
|
||||
</Modal>
|
||||
);
|
||||
|
||||
return (
|
||||
<ModalWrapper onClose={onClose} modalClassName="max-w-xl">
|
||||
<Modal onOutsideClick={onClose} width="max-w-xl">
|
||||
<Formik
|
||||
initialValues={{
|
||||
prompt: promptData.prompt,
|
||||
@ -131,7 +132,7 @@ const EditPromptModal = ({
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</ModalWrapper>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
import { useState } from "react";
|
||||
import { FeedbackType } from "../types";
|
||||
import { ModalWrapper } from "@/components/modals/ModalWrapper";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { FilledLikeIcon } from "@/components/icons/icons";
|
||||
|
||||
const predefinedPositiveFeedbackOptions =
|
||||
@ -49,7 +49,7 @@ export const FeedbackModal = ({
|
||||
: predefinedNegativeFeedbackOptions;
|
||||
|
||||
return (
|
||||
<ModalWrapper onClose={onClose} modalClassName="max-w-3xl">
|
||||
<Modal onOutsideClick={onClose} width="max-w-3xl">
|
||||
<>
|
||||
<h2 className="text-2xl text-emphasis font-bold mb-4 flex">
|
||||
<div className="mr-1 my-auto">
|
||||
@ -112,6 +112,6 @@ export const FeedbackModal = ({
|
||||
</button>
|
||||
</div>
|
||||
</>
|
||||
</ModalWrapper>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { ModalWrapper } from "@/components/modals/ModalWrapper";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button, Divider, Text } from "@tremor/react";
|
||||
|
||||
export function MakePublicAssistantModal({
|
||||
@ -11,7 +11,7 @@ export function MakePublicAssistantModal({
|
||||
onClose: () => void;
|
||||
}) {
|
||||
return (
|
||||
<ModalWrapper onClose={onClose} modalClassName="max-w-3xl">
|
||||
<Modal onOutsideClick={onClose} width="max-w-3xl">
|
||||
<div className="space-y-6">
|
||||
<h2 className="text-2xl font-bold text-emphasis">
|
||||
{isPublic ? "Public Assistant" : "Make Assistant Public"}
|
||||
@ -67,6 +67,6 @@ export function MakePublicAssistantModal({
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ModalWrapper>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { Dispatch, SetStateAction, useEffect, useRef } from "react";
|
||||
import { ModalWrapper } from "@/components/modals/ModalWrapper";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Text } from "@tremor/react";
|
||||
import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
@ -123,10 +123,7 @@ export function SetDefaultModelModal({
|
||||
);
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
onClose={onClose}
|
||||
modalClassName="rounded-lg bg-white max-w-xl"
|
||||
>
|
||||
<Modal onOutsideClick={onClose} width="rounded-lg bg-white max-w-xl">
|
||||
<>
|
||||
<div className="flex mb-4">
|
||||
<h2 className="text-2xl text-emphasis font-bold flex my-auto">
|
||||
@ -203,6 +200,6 @@ export function SetDefaultModelModal({
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
</ModalWrapper>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { useState } from "react";
|
||||
import { ModalWrapper } from "@/components/modals/ModalWrapper";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button, Callout, Divider, Text } from "@tremor/react";
|
||||
import { Spinner } from "@/components/Spinner";
|
||||
import { ChatSessionSharedStatus } from "../interfaces";
|
||||
@ -57,7 +57,7 @@ export function ShareChatSessionModal({
|
||||
);
|
||||
|
||||
return (
|
||||
<ModalWrapper onClose={onClose} modalClassName="max-w-3xl">
|
||||
<Modal onOutsideClick={onClose} width="max-w-3xl">
|
||||
<>
|
||||
<div className="flex mb-4">
|
||||
<h2 className="text-2xl text-emphasis font-bold flex my-auto">
|
||||
@ -154,6 +154,6 @@ export function ShareChatSessionModal({
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
</ModalWrapper>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
|
@ -54,9 +54,9 @@ export function Modal({
|
||||
e.stopPropagation();
|
||||
}
|
||||
}}
|
||||
className={`bg-background text-emphasis rounded shadow-2xl
|
||||
className={`bg-background text-emphasis rounded shadow-2xl
|
||||
transform transition-all duration-300 ease-in-out
|
||||
${width ?? "w-11/12 max-w-5xl"}
|
||||
${width ?? "w-11/12 max-w-4xl"}
|
||||
${noPadding ? "" : "p-10"}
|
||||
${className || ""}`}
|
||||
>
|
||||
@ -88,7 +88,7 @@ export function Modal({
|
||||
{!hideDividerForTitle && <Divider />}
|
||||
</>
|
||||
)}
|
||||
{children}
|
||||
<div className="max-h-[60vh] overflow-y-scroll">{children}</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { FiTrash, FiX } from "react-icons/fi";
|
||||
import { ModalWrapper } from "@/components/modals/ModalWrapper";
|
||||
import { BasicClickable } from "@/components/BasicClickable";
|
||||
import { Modal } from "../Modal";
|
||||
|
||||
export const DeleteEntityModal = ({
|
||||
onClose,
|
||||
@ -16,7 +16,7 @@ export const DeleteEntityModal = ({
|
||||
additionalDetails?: string;
|
||||
}) => {
|
||||
return (
|
||||
<ModalWrapper onClose={onClose}>
|
||||
<Modal onOutsideClick={onClose}>
|
||||
<>
|
||||
<div className="flex mb-4">
|
||||
<h2 className="my-auto text-2xl font-bold">Delete {entityType}?</h2>
|
||||
@ -37,6 +37,6 @@ export const DeleteEntityModal = ({
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
</ModalWrapper>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { FiCheck } from "react-icons/fi";
|
||||
import { ModalWrapper } from "./ModalWrapper";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { BasicClickable } from "@/components/BasicClickable";
|
||||
|
||||
export const GenericConfirmModal = ({
|
||||
@ -16,7 +16,7 @@ export const GenericConfirmModal = ({
|
||||
onConfirm: () => void;
|
||||
}) => {
|
||||
return (
|
||||
<ModalWrapper onClose={onClose}>
|
||||
<Modal onOutsideClick={onClose}>
|
||||
<div className="max-w-full">
|
||||
<div className="flex mb-4">
|
||||
<h2 className="my-auto text-2xl font-bold whitespace-normal overflow-wrap-normal">
|
||||
@ -37,6 +37,6 @@ export const GenericConfirmModal = ({
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</ModalWrapper>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
@ -1,63 +0,0 @@
|
||||
"use client";
|
||||
import { XIcon } from "@/components/icons/icons";
|
||||
import { isEventWithinRef } from "@/lib/contains";
|
||||
import { useRef } from "react";
|
||||
|
||||
export const ModalWrapper = ({
|
||||
children,
|
||||
bgClassName,
|
||||
modalClassName,
|
||||
onClose,
|
||||
}: {
|
||||
children: JSX.Element;
|
||||
bgClassName?: string;
|
||||
modalClassName?: string;
|
||||
onClose?: () => void;
|
||||
}) => {
|
||||
const modalRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const handleMouseDown = (e: React.MouseEvent<HTMLDivElement>) => {
|
||||
if (
|
||||
onClose &&
|
||||
modalRef.current &&
|
||||
!modalRef.current.contains(e.target as Node) &&
|
||||
!isEventWithinRef(e.nativeEvent, modalRef)
|
||||
) {
|
||||
onClose();
|
||||
}
|
||||
};
|
||||
return (
|
||||
<div
|
||||
onMouseDown={handleMouseDown}
|
||||
className={`fixed inset-0 bg-black bg-opacity-25 backdrop-blur-sm
|
||||
flex items-center justify-center z-50 transition-opacity duration-300 ease-in-out
|
||||
${bgClassName || ""}`}
|
||||
>
|
||||
<div
|
||||
ref={modalRef}
|
||||
onClick={(e) => {
|
||||
if (onClose) {
|
||||
e.stopPropagation();
|
||||
}
|
||||
}}
|
||||
className={`bg-background text-emphasis p-10 rounded shadow-2xl
|
||||
w-11/12 max-w-3xl transform transition-all duration-300 ease-in-out
|
||||
relative ${modalClassName || ""}`}
|
||||
>
|
||||
{onClose && (
|
||||
<div className="absolute top-2 right-2">
|
||||
<button
|
||||
onClick={onClose}
|
||||
className="cursor-pointer text-text-500 hover:text-text-700 transition-colors duration-200 p-2"
|
||||
aria-label="Close modal"
|
||||
>
|
||||
<XIcon className="w-5 h-5" />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex w-full flex-col justify-stretch">{children}</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
@ -1,8 +1,8 @@
|
||||
import { ModalWrapper } from "@/components/modals/ModalWrapper";
|
||||
import { Modal } from "@/components/Modal";
|
||||
|
||||
export const NoAssistantModal = ({ isAdmin }: { isAdmin: boolean }) => {
|
||||
return (
|
||||
<ModalWrapper modalClassName="bg-white max-w-2xl rounded-lg shadow-xl text-center">
|
||||
<Modal width="bg-white max-w-2xl rounded-lg shadow-xl text-center">
|
||||
<>
|
||||
<h2 className="text-3xl font-bold text-gray-800 mb-4">
|
||||
No Assistant Available
|
||||
@ -32,6 +32,6 @@ export const NoAssistantModal = ({ isAdmin }: { isAdmin: boolean }) => {
|
||||
</p>
|
||||
)}
|
||||
</>
|
||||
</ModalWrapper>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user