Merge branch 'main' of https://github.com/danswer-ai/danswer into feature/reset_indexes

This commit is contained in:
Richard Kuo (Danswer) 2024-10-25 12:00:25 -07:00
commit 0ed77aa8a7
38 changed files with 670 additions and 215 deletions

View File

@ -26,4 +26,4 @@ N/A
## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
[ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)

View File

@ -10,43 +10,83 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
fetch-depth: 0 # Fetch all history for all branches and tags
- name: Set up Git
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Check for Backport Checkbox
id: checkbox-check
run: |
PR_BODY="${{ github.event.pull_request.body }}"
if [[ "$PR_BODY" == *"[x] This PR should be backported"* ]]; then
echo "::set-output name=backport::true"
echo "backport=true" >> $GITHUB_OUTPUT
else
echo "::set-output name=backport::false"
echo "backport=false" >> $GITHUB_OUTPUT
fi
- name: List and sort release branches
id: list-branches
run: |
git fetch --all
BRANCHES=$(git branch -r | grep 'origin/release/v' | sed 's|origin/release/v||' | sort -Vr)
git fetch --all --tags
BRANCHES=$(git for-each-ref --format='%(refname:short)' refs/remotes/origin/release/* | sed 's|origin/release/||' | sort -Vr)
BETA=$(echo "$BRANCHES" | head -n 1)
STABLE=$(echo "$BRANCHES" | head -n 2 | tail -n 1)
echo "::set-output name=beta::$BETA"
echo "::set-output name=stable::$STABLE"
echo "beta=$BETA" >> $GITHUB_OUTPUT
echo "stable=$STABLE" >> $GITHUB_OUTPUT
# Fetch latest tags for beta and stable
LATEST_BETA_TAG=$(git tag -l "v*.*.0-beta.*" | sort -Vr | head -n 1)
LATEST_STABLE_TAG=$(git tag -l "v*.*.*" | grep -v -- "-beta" | sort -Vr | head -n 1)
# Increment latest beta tag
NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 ".0-beta." ($NF+1)}')
# Increment latest stable tag
NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}')
echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT
echo "latest_stable_tag=$LATEST_STABLE_TAG" >> $GITHUB_OUTPUT
echo "new_beta_tag=$NEW_BETA_TAG" >> $GITHUB_OUTPUT
echo "new_stable_tag=$NEW_STABLE_TAG" >> $GITHUB_OUTPUT
- name: Echo branch and tag information
run: |
echo "Beta branch: ${{ steps.list-branches.outputs.beta }}"
echo "Stable branch: ${{ steps.list-branches.outputs.stable }}"
echo "Latest beta tag: ${{ steps.list-branches.outputs.latest_beta_tag }}"
echo "Latest stable tag: ${{ steps.list-branches.outputs.latest_stable_tag }}"
echo "New beta tag: ${{ steps.list-branches.outputs.new_beta_tag }}"
echo "New stable tag: ${{ steps.list-branches.outputs.new_stable_tag }}"
- name: Trigger Backport
if: steps.checkbox-check.outputs.backport == 'true'
run: |
set -e
echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}"
# Fetch all history for all branches and tags
git fetch --prune --unshallow
# Checkout the beta branch
git checkout ${{ steps.list-branches.outputs.beta }}
# Cherry-pick the last commit from the merged PR
git cherry-pick ${{ github.event.pull_request.merge_commit_sha }}
# Push the changes to the beta branch
# Cherry-pick the merge commit from the merged PR
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
echo "Cherry-pick to beta failed due to conflicts."
exit 1
}
# Create new beta tag
git tag ${{ steps.list-branches.outputs.new_beta_tag }}
# Push the changes and tag to the beta branch
git push origin ${{ steps.list-branches.outputs.beta }}
git push origin ${{ steps.list-branches.outputs.new_beta_tag }}
# Checkout the stable branch
git checkout ${{ steps.list-branches.outputs.stable }}
# Cherry-pick the last commit from the merged PR
git cherry-pick ${{ github.event.pull_request.merge_commit_sha }}
# Push the changes to the stable branch
# Cherry-pick the merge commit from the merged PR
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
echo "Cherry-pick to stable failed due to conflicts."
exit 1
}
# Create new stable tag
git tag ${{ steps.list-branches.outputs.new_stable_tag }}
# Push the changes and tag to the stable branch
git push origin ${{ steps.list-branches.outputs.stable }}
git push origin ${{ steps.list-branches.outputs.new_stable_tag }}

View File

@ -31,6 +31,12 @@ def upgrade() -> None:
def downgrade() -> None:
# First, update any null values to a default value
op.execute(
"UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL"
)
# Then, make the column non-nullable
op.alter_column(
"connector_credential_pair",
"last_attempt_status",

View File

@ -17,6 +17,7 @@ from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
@ -161,6 +162,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
r.delete(key)
# @worker_process_init.connect
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:

View File

@ -313,6 +313,8 @@ class RedisConnectorDeletion(RedisObjectHelper):
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""
last_lock_time = time.monotonic()
async_results = []
@ -540,6 +542,29 @@ class RedisConnectorIndexing(RedisObjectHelper):
return False
class RedisConnectorStop(RedisObjectHelper):
"""Used to signal any running tasks for a connector to stop. We should refactor
connector related redis helpers into a single class.
"""
PREFIX = "connectorstop"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists

View File

@ -1,4 +1,3 @@
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
@ -6,6 +5,7 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
@ -79,7 +79,7 @@ def document_batch_to_ids(
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
progress_callback: Callable[[int], None] | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> set[str]:
"""
If the PruneConnector hasnt been implemented for the given connector, just pull
@ -110,8 +110,10 @@ def extract_ids_from_runnable_connector(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
if progress_callback:
progress_callback(len(doc_batch))
if callback:
if callback.should_stop():
raise RuntimeError("Stop signal received")
callback.progress(len(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids

View File

@ -1,3 +1,6 @@
from datetime import datetime
from datetime import timezone
import redis
from celery import Celery
from celery import shared_task
@ -8,6 +11,12 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
@ -15,9 +24,15 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_pool import get_redis_client
class TaskDependencyError(RuntimeError):
"""Raised to the caller to indicate dependent tasks are running that would interfere
with connector deletion."""
@shared_task(
name="check_for_connector_deletion_task",
soft_time_limit=JOB_TIMEOUT,
@ -37,17 +52,30 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
if not lock_beat.acquire(blocking=False):
return
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
cc_pair_ids.append(cc_pair.id)
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
)
rcs = RedisConnectorStop(cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
# Leave a stop signal to clear indexing and pruning tasks more quickly
task_logger.info(str(e))
r.set(rcs.fence_key, cc_pair_id)
else:
# clear the stop signal if it exists ... no longer needed
r.delete(rcs.fence_key)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@ -70,6 +98,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
Returns None if no syncing is required.
Will raise TaskDependencyError if dependent tasks such as indexing and pruning are
still running. In our case, the caller reacts by setting a stop signal in Redis to
exit those tasks as quickly as possible.
"""
lock_beat.reacquire()
@ -90,28 +122,63 @@ def try_generate_document_cc_pair_cleanup_tasks(
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
# set a basic fence to start
fence_value = RedisConnectorDeletionFenceData(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
if tasks_generated is None:
r.set(rcd.fence_key, fence_value.model_dump_json())
try:
# do not proceed if connector indexing or connector pruning are running
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
if r.get(rci.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings.id}"
)
rcp = RedisConnectorPruning(cc_pair_id)
if r.get(rcp.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (pruning in progress): "
f"cc_pair={cc_pair_id}"
)
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
except TaskDependencyError:
r.delete(rcd.fence_key)
raise
except Exception:
task_logger.exception("Unexpected exception")
r.delete(rcd.fence_key)
return None
else:
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
fence_value.num_tasks = tasks_generated
r.set(rcd.fence_key, fence_value.model_dump_json())
# set this only after all tasks have been added
r.set(rcd.fence_key, tasks_generated)
return tasks_generated

View File

@ -5,6 +5,7 @@ from time import sleep
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from celery import shared_task
from celery import Task
@ -13,12 +14,15 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@ -50,6 +54,30 @@ from danswer.utils.variable_functionality import global_version
logger = setup_logger()
class RunIndexingCallback(RunIndexingCallbackInterface):
def __init__(
self,
stop_key: str,
generator_progress_key: str,
redis_lock: redis.lock.Lock,
redis_client: Redis,
):
super().__init__()
self.redis_lock: redis.lock.Lock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
return True
return False
def progress(self, amount: int) -> None:
self.redis_lock.reacquire()
self.redis_client.incrby(self.generator_progress_key, amount)
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
@ -262,6 +290,10 @@ def try_creating_indexing_task(
return None
# skip indexing if the cc_pair is deleting
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
@ -308,13 +340,8 @@ def try_creating_indexing_task(
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=index_attempt_id,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=result.id,
)
fence_value.index_attempt_id = index_attempt_id
fence_value.celery_task_id = result.id
r.set(rci.fence_key, fence_value.model_dump_json())
except Exception:
r.delete(rci.fence_key)
@ -400,6 +427,22 @@ def connector_indexing_task(
r = get_redis_client(tenant_id=tenant_id)
rcd = RedisConnectorDeletion(cc_pair_id)
if r.exists(rcd.fence_key):
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"cc_pair={cc_pair_id} "
f"fence={rcd.fence_key}"
)
rcs = RedisConnectorStop(cc_pair_id)
if r.exists(rcs.fence_key):
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"cc_pair={cc_pair_id} "
f"fence={rcs.fence_key}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
while True:
@ -409,7 +452,7 @@ def connector_indexing_task(
task_logger.info(
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
)
raise
raise RuntimeError(f"Fence not found: fence={rci.fence_key}")
try:
fence_json = fence_value.decode("utf-8")
@ -443,17 +486,20 @@ def connector_indexing_task(
if not acquired:
task_logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}"
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None
fence_data.started = datetime.now(timezone.utc)
r.set(rci.fence_key, fence_data.model_dump_json())
try:
with get_session_with_tenant(tenant_id) as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise ValueError(
f"Index attempt not found: index_attempt_id={index_attempt_id}"
f"Index attempt not found: index_attempt={index_attempt_id}"
)
cc_pair = get_connector_credential_pair_from_id(
@ -462,31 +508,31 @@ def connector_indexing_task(
)
if not cc_pair:
raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}")
raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
if not cc_pair.connector:
raise ValueError(
f"Connector not found: connector_id={cc_pair.connector_id}"
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
)
if not cc_pair.credential:
raise ValueError(
f"Credential not found: credential_id={cc_pair.credential_id}"
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rci.generator_progress_key, amount)
# define a callback class
callback = RunIndexingCallback(
rcs.fence_key, rci.generator_progress_key, lock, r
)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
progress_callback=redis_increment_callback,
callback=callback,
)
# get back the total number of indexed docs and return it
@ -499,9 +545,10 @@ def connector_indexing_task(
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
except Exception as e:
task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.")
task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
if attempt:
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)

View File

@ -11,8 +11,11 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@ -168,6 +171,10 @@ def try_creating_prune_generator_task(
return None
# skip pruning if the cc_pair is deleting
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
@ -234,7 +241,7 @@ def connector_pruning_generator_task(
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}"
f"Pruning task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
@ -252,11 +259,6 @@ def connector_pruning_generator_task(
)
return
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rcp.generator_progress_key, amount)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
@ -265,9 +267,14 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
rcs = RedisConnectorStop(cc_pair_id)
callback = RunIndexingCallback(
rcs.fence_key, rcp.generator_progress_key, lock, r
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector, redis_increment_callback
runnable_connector, callback
)
# a list of docs in our local index
@ -285,7 +292,7 @@ def connector_pruning_generator_task(
task_logger.info(
f"Pruning set collected: "
f"cc_pair_id={cc_pair.id} "
f"cc_pair={cc_pair_id} "
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
)
@ -293,7 +300,7 @@ def connector_pruning_generator_task(
rcp.documents_to_prune = set(doc_ids_to_remove)
task_logger.info(
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair_id}"
)
tasks_generated = rcp.generate_tasks(
self.app, db_session, r, None, tenant_id
@ -303,12 +310,14 @@ def connector_pruning_generator_task(
task_logger.info(
f"RedisConnectorPruning.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(f"Failed to run pruning for connector id {connector_id}.")
task_logger.exception(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)

View File

@ -0,0 +1,8 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorDeletionFenceData(BaseModel):
num_tasks: int | None
submitted: datetime

View File

@ -1,9 +1,6 @@
from datetime import datetime
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import BaseModel
from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
@ -23,13 +20,6 @@ from danswer.server.documents.models import ConnectorCredentialPairIdentifier
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None
@shared_task(
name="document_by_cc_pair_cleanup_task",
bind=True,

View File

@ -23,6 +23,9 @@ from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
@ -368,7 +371,7 @@ def monitor_document_set_taskset(
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"Document set sync progress: document_set_id={document_set_id} "
f"Document set sync progress: document_set={document_set_id} "
f"remaining={count} initial={initial_count}"
)
if count > 0:
@ -383,12 +386,12 @@ def monitor_document_set_taskset(
# if there are no connectors, then delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
f"Successfully deleted document set: document_set={document_set_id}"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
f"Successfully synced document set with ID: '{document_set_id}'!"
f"Successfully synced document set: document_set={document_set_id}"
)
r.delete(rds.taskset_key)
@ -408,19 +411,29 @@ def monitor_connector_deletion_taskset(
rcd = RedisConnectorDeletion(cc_pair_id)
fence_value = r.get(rcd.fence_key)
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rcd.fence_key))
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.error("The value is not an integer.")
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
# the fence is setting up but isn't ready yet
if fence_data.num_tasks is None:
return
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}"
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
)
if count > 0:
return
@ -483,7 +496,7 @@ def monitor_connector_deletion_taskset(
)
if not connector or not len(connector.credentials):
task_logger.info(
"Found no credentials left for connector, deleting connector"
"Connector deletion - Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
@ -493,17 +506,17 @@ def monitor_connector_deletion_taskset(
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Successfully deleted cc_pair: "
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={initial_count}"
f"docs_deleted={fence_data.num_tasks}"
)
r.delete(rcd.taskset_key)
@ -618,6 +631,7 @@ def monitor_ccpair_indexing_taskset(
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result_state = result.state

View File

@ -1,6 +1,7 @@
import time
import traceback
from collections.abc import Callable
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@ -41,6 +42,19 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
class RunIndexingCallbackInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abstractmethod
def progress(self, amount: int) -> None:
"""Send progress updates to the caller."""
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
@ -92,7 +106,7 @@ def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
progress_callback: Callable[[int], None] | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
@ -206,6 +220,11 @@ def _run_indexing(
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise RuntimeError("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
if (
(
@ -263,8 +282,8 @@ def _run_indexing(
# be inaccurate
db_session.commit()
if progress_callback:
progress_callback(len(doc_batch))
if callback:
callback.progress(len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
@ -394,7 +413,7 @@ def run_indexing_entrypoint(
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
progress_callback: Callable[[int], None] | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> None:
try:
if is_ee:
@ -417,7 +436,7 @@ def run_indexing_entrypoint(
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
_run_indexing(db_session, attempt, tenant_id, progress_callback)
_run_indexing(db_session, attempt, tenant_id, callback)
logger.info(
f"Indexing finished for tenant {tenant_id}: "

View File

@ -672,6 +672,7 @@ def stream_chat_message_objects(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
),
prompt_config=prompt_config,
llm=(

View File

@ -237,6 +237,7 @@ class Answer:
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool.force_use else None,
structured_response_format=self.answer_style_config.structured_response_format,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
@ -331,7 +332,10 @@ class Answer:
tool_choice: ToolChoiceOptions | None = None,
) -> Iterator[str | StreamStopInfo]:
for message in self.llm.stream(
prompt=prompt, tools=tools, tool_choice=tool_choice
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
structured_response_format=self.answer_style_config.structured_response_format,
):
if isinstance(message, AIMessageChunk):
if message.content:

View File

@ -116,6 +116,10 @@ class AnswerStyleConfig(BaseModel):
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
# right now, only used by the simple chat API
structured_response_format: dict | None = None
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":

View File

@ -280,6 +280,7 @@ class DefaultMultiLLM(LLM):
tools: list[dict] | None,
tool_choice: ToolChoiceOptions | None,
stream: bool,
structured_response_format: dict | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
if isinstance(prompt, list):
prompt = [
@ -313,6 +314,11 @@ class DefaultMultiLLM(LLM):
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
**({"parallel_tool_calls": False} if tools else {}),
**(
{"response_format": structured_response_format}
if structured_response_format
else {}
),
**self._model_kwargs,
)
except Exception as e:
@ -336,12 +342,16 @@ class DefaultMultiLLM(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
response = cast(
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
litellm.ModelResponse,
self._completion(
prompt, tools, tool_choice, False, structured_response_format
),
)
choice = response.choices[0]
if hasattr(choice, "message"):
@ -354,18 +364,21 @@ class DefaultMultiLLM(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
if DISABLE_LITELLM_STREAMING:
yield self.invoke(prompt)
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
return
output = None
response = cast(
litellm.CustomStreamWrapper,
self._completion(prompt, tools, tool_choice, True),
self._completion(
prompt, tools, tool_choice, True, structured_response_format
),
)
try:
for part in response:

View File

@ -80,6 +80,7 @@ class CustomModelServer(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
return self._execute(prompt)
@ -88,5 +89,6 @@ class CustomModelServer(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
yield self._execute(prompt)

View File

@ -88,11 +88,14 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(prompt, tools, tool_choice)
return self._invoke_implementation(
prompt, tools, tool_choice, structured_response_format
)
@abc.abstractmethod
def _invoke_implementation(
@ -100,6 +103,7 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
raise NotImplementedError
@ -108,11 +112,14 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._stream_implementation(prompt, tools, tool_choice)
return self._stream_implementation(
prompt, tools, tool_choice, structured_response_format
)
@abc.abstractmethod
def _stream_implementation(
@ -120,5 +127,6 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError

View File

@ -19,7 +19,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import (
update_connector_credential_pair_from_id,
)
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import ConnectorCredentialPairStatus
@ -175,15 +174,19 @@ def create_deletion_attempt_for_connector_id(
cc_pair_id=cc_pair.id, db_session=db_session, include_secondary_index=True
)
# TODO(rkuo): 2024-10-24 - check_deletion_attempt_is_allowed shouldn't be necessary
# any more due to background locking improvements.
# Remove the below permanently if everything is behaving for 30 days.
# Check if the deletion attempt should be allowed
deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair, db_session=db_session
)
if deletion_attempt_disallowed_reason:
raise HTTPException(
status_code=400,
detail=deletion_attempt_disallowed_reason,
)
# deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
# connector_credential_pair=cc_pair, db_session=db_session
# )
# if deletion_attempt_disallowed_reason:
# raise HTTPException(
# status_code=400,
# detail=deletion_attempt_disallowed_reason,
# )
# mark as deleting
update_connector_credential_pair_from_id(

View File

@ -108,6 +108,10 @@ class CreateChatMessageRequest(ChunkContext):
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
@model_validator(mode="after")
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
if self.search_doc_ids is None and self.retrieval_options is None:

View File

@ -176,6 +176,7 @@ def handle_simplified_chat_message(
chunks_above=0,
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
)
packets = stream_chat_message_objects(
@ -202,7 +203,7 @@ def handle_send_message_simple_with_history(
raise HTTPException(status_code=400, detail="Messages cannot be zero length")
# This is a sanity check to make sure the chat history is valid
# It must start with a user message and alternate between user and assistant
# It must start with a user message and alternate beteen user and assistant
expected_role = MessageType.USER
for msg in req.messages:
if not msg.message:
@ -296,6 +297,7 @@ def handle_send_message_simple_with_history(
chunks_above=0,
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
)
packets = stream_chat_message_objects(

View File

@ -48,6 +48,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
query_override: str | None = None
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None = None
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
@ -60,6 +63,9 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
skip_rerank: bool | None = None
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None = None
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
class SimpleDoc(BaseModel):

View File

@ -0,0 +1,148 @@
from typing import Any
from typing import Dict
import requests
API_SERVER_URL = "http://localhost:3000" # Adjust this to your Danswer server URL
HEADERS = {"Content-Type": "application/json"}
API_KEY = "danswer-api-key" # API key here, if auth is enabled
def create_connector(
name: str,
source: str,
input_type: str,
connector_specific_config: Dict[str, Any],
is_public: bool = True,
groups: list[int] | None = None,
) -> Dict[str, Any]:
connector_update_request = {
"name": name,
"source": source,
"input_type": input_type,
"connector_specific_config": connector_specific_config,
"is_public": is_public,
"groups": groups or [],
}
response = requests.post(
url=f"{API_SERVER_URL}/api/manage/admin/connector",
json=connector_update_request,
headers=HEADERS,
)
response.raise_for_status()
return response.json()
def create_credential(
name: str,
source: str,
credential_json: Dict[str, Any],
is_public: bool = True,
groups: list[int] | None = None,
) -> Dict[str, Any]:
credential_request = {
"name": name,
"source": source,
"credential_json": credential_json,
"admin_public": is_public,
"groups": groups or [],
}
response = requests.post(
url=f"{API_SERVER_URL}/api/manage/credential",
json=credential_request,
headers=HEADERS,
)
response.raise_for_status()
return response.json()
def create_cc_pair(
connector_id: int,
credential_id: int,
name: str,
access_type: str = "public",
groups: list[int] | None = None,
) -> Dict[str, Any]:
cc_pair_request = {
"name": name,
"access_type": access_type,
"groups": groups or [],
}
response = requests.put(
url=f"{API_SERVER_URL}/api/manage/connector/{connector_id}/credential/{credential_id}",
json=cc_pair_request,
headers=HEADERS,
)
response.raise_for_status()
return response.json()
def main() -> None:
# Create a Web connector
web_connector = create_connector(
name="Example Web Connector",
source="web",
input_type="load_state",
connector_specific_config={
"base_url": "https://example.com",
"web_connector_type": "recursive",
},
)
print(f"Created Web Connector: {web_connector}")
# Create a credential for the Web connector
web_credential = create_credential(
name="Example Web Credential",
source="web",
credential_json={}, # Web connectors typically don't need credentials
is_public=True,
)
print(f"Created Web Credential: {web_credential}")
# Create CC pair for Web connector
web_cc_pair = create_cc_pair(
connector_id=web_connector["id"],
credential_id=web_credential["id"],
name="Example Web CC Pair",
access_type="public",
)
print(f"Created Web CC Pair: {web_cc_pair}")
# Create a GitHub connector
github_connector = create_connector(
name="Example GitHub Connector",
source="github",
input_type="poll",
connector_specific_config={
"repo_owner": "example-owner",
"repo_name": "example-repo",
"include_prs": True,
"include_issues": True,
},
)
print(f"Created GitHub Connector: {github_connector}")
# Create a credential for the GitHub connector
github_credential = create_credential(
name="Example GitHub Credential",
source="github",
credential_json={"github_access_token": "your_github_access_token_here"},
is_public=True,
)
print(f"Created GitHub Credential: {github_credential}")
# Create CC pair for GitHub connector
github_cc_pair = create_cc_pair(
connector_id=github_connector["id"],
credential_id=github_credential["id"],
name="Example GitHub CC Pair",
access_type="public",
)
print(f"Created GitHub CC Pair: {github_cc_pair}")
if __name__ == "__main__":
main()

View File

@ -6,7 +6,9 @@ from sqlalchemy.orm import Session
from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
@ -44,3 +46,11 @@ def vespa_client(db_session: Session) -> vespa_fixture:
@pytest.fixture
def reset() -> None:
reset_all()
@pytest.fixture
def new_admin_user(reset: None) -> DATestUser | None:
try:
return UserManager.create(name="admin_user")
except Exception:
return None

View File

@ -1,7 +1,10 @@
import json
import requests
from danswer.configs.constants import MessageType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
@ -145,3 +148,85 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
# This ensures the the document we think we are referencing when we send the search_doc_ids in the second
# message is the document that we expect it to be
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
def test_send_message_simple_with_history_strict_json(
new_admin_user: DATestUser | None,
) -> None:
# create connectors
LLMProviderManager.create(user_performing_action=new_admin_user)
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
json={
# intentionally not relevant prompt to ensure that the
# structured response format is actually used
"messages": [
{
"message": "What is green?",
"role": MessageType.USER.value,
}
],
"persona_id": 0,
"prompt_id": 0,
"structured_response_format": {
"type": "json_schema",
"json_schema": {
"name": "presidents",
"schema": {
"type": "object",
"properties": {
"presidents": {
"type": "array",
"items": {"type": "string"},
"description": "List of the first three US presidents",
}
},
"required": ["presidents"],
"additionalProperties": False,
},
"strict": True,
},
},
},
headers=new_admin_user.headers if new_admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
# Check that the answer is present
assert "answer" in response_json
assert response_json["answer"] is not None
# helper
def clean_json_string(json_string: str) -> str:
return json_string.strip().removeprefix("```json").removesuffix("```").strip()
# Attempt to parse the answer as JSON
try:
clean_answer = clean_json_string(response_json["answer"])
parsed_answer = json.loads(clean_answer)
# NOTE: do not check content, just the structure
assert isinstance(parsed_answer, dict)
assert "presidents" in parsed_answer
assert isinstance(parsed_answer["presidents"], list)
for president in parsed_answer["presidents"]:
assert isinstance(president, str)
except json.JSONDecodeError:
assert (
False
), f"The answer is not a valid JSON object - '{response_json['answer']}'"
# Check that the answer_citationless is also valid JSON
assert "answer_citationless" in response_json
assert response_json["answer_citationless"] is not None
try:
clean_answer_citationless = clean_json_string(
response_json["answer_citationless"]
)
parsed_answer_citationless = json.loads(clean_answer_citationless)
assert isinstance(parsed_answer_citationless, dict)
except json.JSONDecodeError:
assert False, "The answer_citationless is not a valid JSON object"

View File

@ -173,7 +173,6 @@ export function ProviderCreationModal({
return (
<Modal
width="max-w-3xl"
title={`Configure ${selectedProvider.provider_type}`}
onOutsideClick={onCancel}
icon={selectedProvider.icon}

View File

@ -1,12 +1,12 @@
import React from "react";
import { Formik, Form } from "formik";
import * as Yup from "yup";
import { ModalWrapper } from "@/components/modals/ModalWrapper";
import { Button } from "@tremor/react";
import { BookstackIcon } from "@/components/icons/icons";
import { AddPromptModalProps } from "../interfaces";
import { TextFormField } from "@/components/admin/connectors/Field";
import { Modal } from "@/components/Modal";
const AddPromptSchema = Yup.object().shape({
title: Yup.string().required("Title is required"),
@ -15,7 +15,7 @@ const AddPromptSchema = Yup.object().shape({
const AddPromptModal = ({ onClose, onSubmit }: AddPromptModalProps) => {
return (
<ModalWrapper onClose={onClose} modalClassName="max-w-xl">
<Modal onOutsideClick={onClose} width="max-w-xl">
<Formik
initialValues={{
title: "",
@ -57,7 +57,7 @@ const AddPromptModal = ({ onClose, onSubmit }: AddPromptModalProps) => {
</Form>
)}
</Formik>
</ModalWrapper>
</Modal>
);
};

View File

@ -1,8 +1,9 @@
import React from "react";
import { Formik, Form, Field, ErrorMessage } from "formik";
import * as Yup from "yup";
import { ModalWrapper } from "@/components/modals/ModalWrapper";
import { Modal } from "@/components/Modal";
import { Button, Textarea, TextInput } from "@tremor/react";
import { useInputPrompt } from "../hooks";
import { EditPromptModalProps } from "../interfaces";
@ -25,20 +26,20 @@ const EditPromptModal = ({
if (error)
return (
<ModalWrapper onClose={onClose} modalClassName="max-w-xl">
<Modal onOutsideClick={onClose} width="max-w-xl">
<p>Failed to load prompt data</p>
</ModalWrapper>
</Modal>
);
if (!promptData)
return (
<ModalWrapper onClose={onClose} modalClassName="max-w-xl">
<Modal onOutsideClick={onClose} width="max-w-xl">
<p>Loading...</p>
</ModalWrapper>
</Modal>
);
return (
<ModalWrapper onClose={onClose} modalClassName="max-w-xl">
<Modal onOutsideClick={onClose} width="max-w-xl">
<Formik
initialValues={{
prompt: promptData.prompt,
@ -131,7 +132,7 @@ const EditPromptModal = ({
</Form>
)}
</Formik>
</ModalWrapper>
</Modal>
);
};

View File

@ -2,7 +2,7 @@
import { useState } from "react";
import { FeedbackType } from "../types";
import { ModalWrapper } from "@/components/modals/ModalWrapper";
import { Modal } from "@/components/Modal";
import { FilledLikeIcon } from "@/components/icons/icons";
const predefinedPositiveFeedbackOptions =
@ -49,7 +49,7 @@ export const FeedbackModal = ({
: predefinedNegativeFeedbackOptions;
return (
<ModalWrapper onClose={onClose} modalClassName="max-w-3xl">
<Modal onOutsideClick={onClose} width="max-w-3xl">
<>
<h2 className="text-2xl text-emphasis font-bold mb-4 flex">
<div className="mr-1 my-auto">
@ -112,6 +112,6 @@ export const FeedbackModal = ({
</button>
</div>
</>
</ModalWrapper>
</Modal>
);
};

View File

@ -1,4 +1,4 @@
import { ModalWrapper } from "@/components/modals/ModalWrapper";
import { Modal } from "@/components/Modal";
import { Button, Divider, Text } from "@tremor/react";
export function MakePublicAssistantModal({
@ -11,7 +11,7 @@ export function MakePublicAssistantModal({
onClose: () => void;
}) {
return (
<ModalWrapper onClose={onClose} modalClassName="max-w-3xl">
<Modal onOutsideClick={onClose} width="max-w-3xl">
<div className="space-y-6">
<h2 className="text-2xl font-bold text-emphasis">
{isPublic ? "Public Assistant" : "Make Assistant Public"}
@ -67,6 +67,6 @@ export function MakePublicAssistantModal({
</div>
)}
</div>
</ModalWrapper>
</Modal>
);
}

View File

@ -1,5 +1,5 @@
import { Dispatch, SetStateAction, useEffect, useRef } from "react";
import { ModalWrapper } from "@/components/modals/ModalWrapper";
import { Modal } from "@/components/Modal";
import { Text } from "@tremor/react";
import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
@ -123,10 +123,7 @@ export function SetDefaultModelModal({
);
return (
<ModalWrapper
onClose={onClose}
modalClassName="rounded-lg bg-white max-w-xl"
>
<Modal onOutsideClick={onClose} width="rounded-lg bg-white max-w-xl">
<>
<div className="flex mb-4">
<h2 className="text-2xl text-emphasis font-bold flex my-auto">
@ -203,6 +200,6 @@ export function SetDefaultModelModal({
</div>
</div>
</>
</ModalWrapper>
</Modal>
);
}

View File

@ -1,5 +1,5 @@
import { useState } from "react";
import { ModalWrapper } from "@/components/modals/ModalWrapper";
import { Modal } from "@/components/Modal";
import { Button, Callout, Divider, Text } from "@tremor/react";
import { Spinner } from "@/components/Spinner";
import { ChatSessionSharedStatus } from "../interfaces";
@ -57,7 +57,7 @@ export function ShareChatSessionModal({
);
return (
<ModalWrapper onClose={onClose} modalClassName="max-w-3xl">
<Modal onOutsideClick={onClose} width="max-w-3xl">
<>
<div className="flex mb-4">
<h2 className="text-2xl text-emphasis font-bold flex my-auto">
@ -154,6 +154,6 @@ export function ShareChatSessionModal({
)}
</div>
</>
</ModalWrapper>
</Modal>
);
}

View File

@ -54,9 +54,9 @@ export function Modal({
e.stopPropagation();
}
}}
className={`bg-background text-emphasis rounded shadow-2xl
className={`bg-background text-emphasis rounded shadow-2xl
transform transition-all duration-300 ease-in-out
${width ?? "w-11/12 max-w-5xl"}
${width ?? "w-11/12 max-w-4xl"}
${noPadding ? "" : "p-10"}
${className || ""}`}
>
@ -88,7 +88,7 @@ export function Modal({
{!hideDividerForTitle && <Divider />}
</>
)}
{children}
<div className="max-h-[60vh] overflow-y-scroll">{children}</div>
</div>
</div>
</div>

View File

@ -1,6 +1,6 @@
import { FiTrash, FiX } from "react-icons/fi";
import { ModalWrapper } from "@/components/modals/ModalWrapper";
import { BasicClickable } from "@/components/BasicClickable";
import { Modal } from "../Modal";
export const DeleteEntityModal = ({
onClose,
@ -16,7 +16,7 @@ export const DeleteEntityModal = ({
additionalDetails?: string;
}) => {
return (
<ModalWrapper onClose={onClose}>
<Modal onOutsideClick={onClose}>
<>
<div className="flex mb-4">
<h2 className="my-auto text-2xl font-bold">Delete {entityType}?</h2>
@ -37,6 +37,6 @@ export const DeleteEntityModal = ({
</div>
</div>
</>
</ModalWrapper>
</Modal>
);
};

View File

@ -1,5 +1,5 @@
import { FiCheck } from "react-icons/fi";
import { ModalWrapper } from "./ModalWrapper";
import { Modal } from "@/components/Modal";
import { BasicClickable } from "@/components/BasicClickable";
export const GenericConfirmModal = ({
@ -16,7 +16,7 @@ export const GenericConfirmModal = ({
onConfirm: () => void;
}) => {
return (
<ModalWrapper onClose={onClose}>
<Modal onOutsideClick={onClose}>
<div className="max-w-full">
<div className="flex mb-4">
<h2 className="my-auto text-2xl font-bold whitespace-normal overflow-wrap-normal">
@ -37,6 +37,6 @@ export const GenericConfirmModal = ({
</div>
</div>
</div>
</ModalWrapper>
</Modal>
);
};

View File

@ -1,63 +0,0 @@
"use client";
import { XIcon } from "@/components/icons/icons";
import { isEventWithinRef } from "@/lib/contains";
import { useRef } from "react";
export const ModalWrapper = ({
children,
bgClassName,
modalClassName,
onClose,
}: {
children: JSX.Element;
bgClassName?: string;
modalClassName?: string;
onClose?: () => void;
}) => {
const modalRef = useRef<HTMLDivElement>(null);
const handleMouseDown = (e: React.MouseEvent<HTMLDivElement>) => {
if (
onClose &&
modalRef.current &&
!modalRef.current.contains(e.target as Node) &&
!isEventWithinRef(e.nativeEvent, modalRef)
) {
onClose();
}
};
return (
<div
onMouseDown={handleMouseDown}
className={`fixed inset-0 bg-black bg-opacity-25 backdrop-blur-sm
flex items-center justify-center z-50 transition-opacity duration-300 ease-in-out
${bgClassName || ""}`}
>
<div
ref={modalRef}
onClick={(e) => {
if (onClose) {
e.stopPropagation();
}
}}
className={`bg-background text-emphasis p-10 rounded shadow-2xl
w-11/12 max-w-3xl transform transition-all duration-300 ease-in-out
relative ${modalClassName || ""}`}
>
{onClose && (
<div className="absolute top-2 right-2">
<button
onClick={onClose}
className="cursor-pointer text-text-500 hover:text-text-700 transition-colors duration-200 p-2"
aria-label="Close modal"
>
<XIcon className="w-5 h-5" />
</button>
</div>
)}
<div className="flex w-full flex-col justify-stretch">{children}</div>
</div>
</div>
);
};

View File

@ -1,8 +1,8 @@
import { ModalWrapper } from "@/components/modals/ModalWrapper";
import { Modal } from "@/components/Modal";
export const NoAssistantModal = ({ isAdmin }: { isAdmin: boolean }) => {
return (
<ModalWrapper modalClassName="bg-white max-w-2xl rounded-lg shadow-xl text-center">
<Modal width="bg-white max-w-2xl rounded-lg shadow-xl text-center">
<>
<h2 className="text-3xl font-bold text-gray-800 mb-4">
No Assistant Available
@ -32,6 +32,6 @@ export const NoAssistantModal = ({ isAdmin }: { isAdmin: boolean }) => {
</p>
)}
</>
</ModalWrapper>
</Modal>
);
};