From 0c58c8d6cb8e00f504505953ad880714348611d6 Mon Sep 17 00:00:00 2001
From: Chris Weaver <25087905+Weves@users.noreply.github.com>
Date: Tue, 26 Sep 2023 12:53:19 -0700
Subject: [PATCH] Adding Document Sets (#477)
Adds:
- name for connector credential pairs + frontend changes to start populating this field
- document set table migration
- during indexing, each document's document set memberships are now looked up and written into Vespa
- background job to check if document sets need to be synced
- document set management APIs
- document set management dashboard in the UI
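
As a rough illustration (not part of this change itself), the new management endpoints can be exercised as below. The routes and request bodies come from `backend/danswer/server/document_set.py` and the new Pydantic models; the base URL, the pre-authenticated admin session, and the example names/IDs are assumptions for the sketch only.

```python
import requests

BASE_URL = "http://localhost:8080"  # assumed local API server address
session = requests.Session()  # assumed to already carry an admin user's auth cookie

# create a document set from two existing connector-credential pairs
doc_set_id = session.post(
    f"{BASE_URL}/manage/admin/document-set",
    json={
        "name": "Engineering Docs",            # example name
        "description": "Confluence + GitHub",  # example description
        "cc_pair_ids": [1, 2],                 # example connector-credential pair IDs
    },
).json()

# change the description / attached cc-pairs; this marks the set as not up to date
session.patch(
    f"{BASE_URL}/manage/admin/document-set",
    json={"id": doc_set_id, "description": "Engineering Docs v2", "cc_pair_ids": [1]},
)

# list document sets (also available to non-admin users)
print(session.get(f"{BASE_URL}/manage/document-set").json())

# delete the document set
session.delete(f"{BASE_URL}/manage/admin/document-set/{doc_set_id}")
```

After the PATCH, the new `document_set_sync` background job notices `is_up_to_date == False` and re-syncs the affected documents in Vespa.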
---
.../57b53544726e_add_document_set_tables.py | 59 ++++
...24ae9_add_id_to_connectorcredentialpair.py | 60 ++++
backend/danswer/background/celery.py | 88 ------
backend/danswer/background/celery/celery.py | 29 ++
.../danswer/background/celery/celery_utils.py | 37 +++
.../background/celery/deletion_utils.py | 41 +++
.../background/document_set_sync_script.py | 53 ++++
backend/danswer/background/utils.py | 7 +-
.../danswer/datastores/indexing_pipeline.py | 11 +-
.../danswer/db/connector_credential_pair.py | 2 +
backend/danswer/db/document_set.py | 282 ++++++++++++++++++
backend/danswer/db/models.py | 56 ++++
backend/danswer/document_set/document_set.py | 65 ++++
backend/danswer/main.py | 2 +
backend/danswer/server/document_set.py | 101 +++++++
backend/danswer/server/manage.py | 27 +-
backend/danswer/server/models.py | 33 ++
backend/supervisord.conf | 9 +-
.../app/admin/connectors/bookstack/page.tsx | 19 +-
.../app/admin/connectors/confluence/page.tsx | 32 +-
web/src/app/admin/connectors/file/page.tsx | 98 +++---
web/src/app/admin/connectors/github/page.tsx | 16 +-
.../admin/connectors/google-drive/page.tsx | 7 +-
web/src/app/admin/connectors/guru/page.tsx | 13 +-
web/src/app/admin/connectors/jira/page.tsx | 22 +-
web/src/app/admin/connectors/linear/page.tsx | 14 +-
web/src/app/admin/connectors/notion/page.tsx | 16 +-
.../admin/connectors/productboard/page.tsx | 20 +-
web/src/app/admin/connectors/slab/page.tsx | 16 +-
web/src/app/admin/connectors/slack/page.tsx | 105 +++----
web/src/app/admin/connectors/web/page.tsx | 9 +-
web/src/app/admin/connectors/zulip/page.tsx | 10 +-
.../sets/DocumentSetCreationForm.tsx | 179 +++++++++++
web/src/app/admin/documents/sets/hooks.tsx | 13 +
web/src/app/admin/documents/sets/lib.ts | 56 ++++
web/src/app/admin/documents/sets/page.tsx | 268 +++++++++++++++++
web/src/app/admin/indexing/status/page.tsx | 111 +------
web/src/app/admin/layout.tsx | 10 +
web/src/components/Loading.tsx | 20 ++
.../admin/connectors/ConnectorForm.tsx | 105 +++++--
.../admin/connectors/ConnectorTitle.tsx | 109 +++++++
web/src/components/admin/connectors/Field.tsx | 24 +-
.../table/SingleUseConnectorsTable.tsx | 5 +
web/src/components/icons/icons.tsx | 8 +
web/src/lib/credential.ts | 9 +-
web/src/lib/hooks.ts | 24 +-
web/src/lib/types.ts | 18 ++
47 files changed, 1887 insertions(+), 431 deletions(-)
create mode 100644 backend/alembic/versions/57b53544726e_add_document_set_tables.py
create mode 100644 backend/alembic/versions/800f48024ae9_add_id_to_connectorcredentialpair.py
delete mode 100644 backend/danswer/background/celery.py
create mode 100644 backend/danswer/background/celery/celery.py
create mode 100644 backend/danswer/background/celery/celery_utils.py
create mode 100644 backend/danswer/background/celery/deletion_utils.py
create mode 100644 backend/danswer/background/document_set_sync_script.py
create mode 100644 backend/danswer/db/document_set.py
create mode 100644 backend/danswer/document_set/document_set.py
create mode 100644 backend/danswer/server/document_set.py
create mode 100644 web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx
create mode 100644 web/src/app/admin/documents/sets/hooks.tsx
create mode 100644 web/src/app/admin/documents/sets/lib.ts
create mode 100644 web/src/app/admin/documents/sets/page.tsx
create mode 100644 web/src/components/admin/connectors/ConnectorTitle.tsx
diff --git a/backend/alembic/versions/57b53544726e_add_document_set_tables.py b/backend/alembic/versions/57b53544726e_add_document_set_tables.py
new file mode 100644
index 000000000..719f43f23
--- /dev/null
+++ b/backend/alembic/versions/57b53544726e_add_document_set_tables.py
@@ -0,0 +1,59 @@
+"""Add document set tables
+
+Revision ID: 57b53544726e
+Revises: 800f48024ae9
+Create Date: 2023-09-20 16:59:39.097177
+
+"""
+from alembic import op
+import fastapi_users_db_sqlalchemy
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = "57b53544726e"
+down_revision = "800f48024ae9"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ op.create_table(
+ "document_set",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(), nullable=False),
+ sa.Column("description", sa.String(), nullable=False),
+ sa.Column(
+ "user_id",
+ fastapi_users_db_sqlalchemy.generics.GUID(),
+ nullable=True,
+ ),
+ sa.Column("is_up_to_date", sa.Boolean(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["user.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("name"),
+ )
+ op.create_table(
+ "document_set__connector_credential_pair",
+ sa.Column("document_set_id", sa.Integer(), nullable=False),
+ sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
+ sa.Column("is_current", sa.Boolean(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["connector_credential_pair_id"],
+ ["connector_credential_pair.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["document_set_id"],
+ ["document_set.id"],
+ ),
+ sa.PrimaryKeyConstraint(
+ "document_set_id", "connector_credential_pair_id", "is_current"
+ ),
+ )
+
+
+def downgrade() -> None:
+ op.drop_table("document_set__connector_credential_pair")
+ op.drop_table("document_set")
diff --git a/backend/alembic/versions/800f48024ae9_add_id_to_connectorcredentialpair.py b/backend/alembic/versions/800f48024ae9_add_id_to_connectorcredentialpair.py
new file mode 100644
index 000000000..3074a8af0
--- /dev/null
+++ b/backend/alembic/versions/800f48024ae9_add_id_to_connectorcredentialpair.py
@@ -0,0 +1,60 @@
+"""Add ID to ConnectorCredentialPair
+
+Revision ID: 800f48024ae9
+Revises: 767f1c2a00eb
+Create Date: 2023-09-19 16:13:42.299715
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.schema import Sequence, CreateSequence
+
+# revision identifiers, used by Alembic.
+revision = "800f48024ae9"
+down_revision = "767f1c2a00eb"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ sequence = Sequence("connector_credential_pair_id_seq")
+ op.execute(CreateSequence(sequence)) # type: ignore
+ op.add_column(
+ "connector_credential_pair",
+ sa.Column(
+ "id", sa.Integer(), nullable=True, server_default=sequence.next_value()
+ ),
+ )
+ op.add_column(
+ "connector_credential_pair",
+ sa.Column("name", sa.String(), nullable=True),
+ )
+
+ # fill in IDs for existing rows
+ op.execute(
+ "UPDATE connector_credential_pair SET id = nextval('connector_credential_pair_id_seq') WHERE id IS NULL"
+ )
+ op.alter_column("connector_credential_pair", "id", nullable=False)
+
+ op.create_unique_constraint(
+ "connector_credential_pair__name__key", "connector_credential_pair", ["name"]
+ )
+ op.create_unique_constraint(
+ "connector_credential_pair__id__key", "connector_credential_pair", ["id"]
+ )
+
+
+def downgrade() -> None:
+ op.drop_constraint(
+ "connector_credential_pair__name__key",
+ "connector_credential_pair",
+ type_="unique",
+ )
+ op.drop_constraint(
+ "connector_credential_pair__id__key",
+ "connector_credential_pair",
+ type_="unique",
+ )
+ op.drop_column("connector_credential_pair", "name")
+ op.drop_column("connector_credential_pair", "id")
+ op.execute("DROP SEQUENCE connector_credential_pair_id_seq")
diff --git a/backend/danswer/background/celery.py b/backend/danswer/background/celery.py
deleted file mode 100644
index 2325adcb9..000000000
--- a/backend/danswer/background/celery.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import json
-from typing import cast
-
-from celery import Celery
-from celery.result import AsyncResult
-from sqlalchemy import text
-from sqlalchemy.orm import Session
-
-from danswer.background.connector_deletion import cleanup_connector_credential_pair
-from danswer.background.connector_deletion import get_cleanup_task_id
-from danswer.db.engine import build_connection_string
-from danswer.db.engine import get_sqlalchemy_engine
-from danswer.db.engine import SYNC_DB_API
-from danswer.db.models import DeletionStatus
-from danswer.server.models import DeletionAttemptSnapshot
-
-celery_broker_url = "sqla+" + build_connection_string(db_api=SYNC_DB_API)
-celery_backend_url = "db+" + build_connection_string(db_api=SYNC_DB_API)
-celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
-
-
-@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
-def cleanup_connector_credential_pair_task(
- connector_id: int, credential_id: int
-) -> int:
- return cleanup_connector_credential_pair(connector_id, credential_id)
-
-
-def get_deletion_status(
- connector_id: int, credential_id: int
-) -> DeletionAttemptSnapshot | None:
- cleanup_task_id = get_cleanup_task_id(
- connector_id=connector_id, credential_id=credential_id
- )
- deletion_task = get_celery_task(task_id=cleanup_task_id)
- deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
-
- deletion_status = None
- error_msg = None
- num_docs_deleted = 0
- if deletion_task_status == "SUCCESS":
- deletion_status = DeletionStatus.SUCCESS
- num_docs_deleted = cast(int, deletion_task.get(propagate=False))
- elif deletion_task_status == "FAILURE":
- deletion_status = DeletionStatus.FAILED
- error_msg = deletion_task.get(propagate=False)
- elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
- deletion_status = DeletionStatus.IN_PROGRESS
-
- return (
- DeletionAttemptSnapshot(
- connector_id=connector_id,
- credential_id=credential_id,
- status=deletion_status,
- error_msg=str(error_msg),
- num_docs_deleted=num_docs_deleted,
- )
- if deletion_status
- else None
- )
-
-
-def get_celery_task(task_id: str) -> AsyncResult:
- """NOTE: even if the task doesn't exist, celery will still return something
- with a `PENDING` state"""
- return AsyncResult(task_id, backend=celery_app.backend)
-
-
-def get_celery_task_status(task_id: str) -> str | None:
- """NOTE: is tightly coupled to the internals of kombu (which is the
- translation layer to allow us to use Postgres as a broker). If we change
- the broker, this will need to be updated.
-
- This should not be called on any critical flows.
- """
- task = get_celery_task(task_id)
- # if not pending, then we know the task really exists
- if task.status != "PENDING":
- return task.status
-
- with Session(get_sqlalchemy_engine()) as session:
- rows = session.execute(text("SELECT payload FROM kombu_message WHERE visible"))
- for row in rows:
- payload = json.loads(row[0])
- if payload["headers"]["id"] == task_id:
- return "PENDING"
-
- return None
diff --git a/backend/danswer/background/celery/celery.py b/backend/danswer/background/celery/celery.py
new file mode 100644
index 000000000..47de830b1
--- /dev/null
+++ b/backend/danswer/background/celery/celery.py
@@ -0,0 +1,29 @@
+from celery import Celery
+
+from danswer.background.connector_deletion import cleanup_connector_credential_pair
+from danswer.db.engine import build_connection_string
+from danswer.db.engine import SYNC_DB_API
+from danswer.document_set.document_set import sync_document_set
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+celery_broker_url = "sqla+" + build_connection_string(db_api=SYNC_DB_API)
+celery_backend_url = "db+" + build_connection_string(db_api=SYNC_DB_API)
+celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
+
+
+@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
+def cleanup_connector_credential_pair_task(
+ connector_id: int, credential_id: int
+) -> int:
+ return cleanup_connector_credential_pair(connector_id, credential_id)
+
+
+@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
+def sync_document_set_task(document_set_id: int) -> None:
+ try:
+ return sync_document_set(document_set_id=document_set_id)
+ except Exception:
+ logger.exception("Failed to sync document set %s", document_set_id)
+ raise
diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py
new file mode 100644
index 000000000..c58624065
--- /dev/null
+++ b/backend/danswer/background/celery/celery_utils.py
@@ -0,0 +1,37 @@
+import json
+
+from celery.result import AsyncResult
+from sqlalchemy import text
+from sqlalchemy.orm import Session
+
+from danswer.background.celery.celery import celery_app
+from danswer.db.engine import get_sqlalchemy_engine
+
+
+def get_celery_task(task_id: str) -> AsyncResult:
+ """NOTE: even if the task doesn't exist, celery will still return something
+ with a `PENDING` state"""
+ return AsyncResult(task_id, backend=celery_app.backend)
+
+
+def get_celery_task_status(task_id: str) -> str | None:
+    """NOTE: this is tightly coupled to the internals of kombu (the translation
+    layer that allows us to use Postgres as a broker). If we change the broker,
+    this will need to be updated.
+
+ This should not be called on any critical flows.
+ """
+ # first check for any pending tasks
+ with Session(get_sqlalchemy_engine()) as session:
+ rows = session.execute(text("SELECT payload FROM kombu_message WHERE visible"))
+ for row in rows:
+ payload = json.loads(row[0])
+ if payload["headers"]["id"] == task_id:
+ return "PENDING"
+
+ task = get_celery_task(task_id)
+ # if not pending, then we know the task really exists
+ if task.status != "PENDING":
+ return task.status
+
+ return None
diff --git a/backend/danswer/background/celery/deletion_utils.py b/backend/danswer/background/celery/deletion_utils.py
new file mode 100644
index 000000000..c6d056022
--- /dev/null
+++ b/backend/danswer/background/celery/deletion_utils.py
@@ -0,0 +1,41 @@
+from typing import cast
+
+from danswer.background.celery.celery_utils import get_celery_task
+from danswer.background.celery.celery_utils import get_celery_task_status
+from danswer.background.connector_deletion import get_cleanup_task_id
+from danswer.db.models import DeletionStatus
+from danswer.server.models import DeletionAttemptSnapshot
+
+
+def get_deletion_status(
+ connector_id: int, credential_id: int
+) -> DeletionAttemptSnapshot | None:
+ cleanup_task_id = get_cleanup_task_id(
+ connector_id=connector_id, credential_id=credential_id
+ )
+ deletion_task = get_celery_task(task_id=cleanup_task_id)
+ deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
+
+ deletion_status = None
+ error_msg = None
+ num_docs_deleted = 0
+ if deletion_task_status == "SUCCESS":
+ deletion_status = DeletionStatus.SUCCESS
+ num_docs_deleted = cast(int, deletion_task.get(propagate=False))
+ elif deletion_task_status == "FAILURE":
+ deletion_status = DeletionStatus.FAILED
+ error_msg = deletion_task.get(propagate=False)
+ elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
+ deletion_status = DeletionStatus.IN_PROGRESS
+
+ return (
+ DeletionAttemptSnapshot(
+ connector_id=connector_id,
+ credential_id=credential_id,
+ status=deletion_status,
+ error_msg=str(error_msg),
+ num_docs_deleted=num_docs_deleted,
+ )
+ if deletion_status
+ else None
+ )
diff --git a/backend/danswer/background/document_set_sync_script.py b/backend/danswer/background/document_set_sync_script.py
new file mode 100644
index 000000000..37ce3be05
--- /dev/null
+++ b/backend/danswer/background/document_set_sync_script.py
@@ -0,0 +1,53 @@
+from celery.result import AsyncResult
+from sqlalchemy.orm import Session
+
+from danswer.background.celery.celery import sync_document_set_task
+from danswer.background.utils import interval_run_job
+from danswer.db.document_set import (
+ fetch_document_sets,
+)
+from danswer.db.engine import get_sqlalchemy_engine
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+_ExistingTaskCache: dict[int, AsyncResult] = {}
+
+
+def _document_sync_loop() -> None:
+ # cleanup tasks
+ existing_tasks = list(_ExistingTaskCache.items())
+ for document_set_id, task in existing_tasks:
+ if task.ready():
+ logger.info(
+                f"Document set sync for '{document_set_id}' is complete with status "
+ f"{task.status}. Cleaning up."
+ )
+ del _ExistingTaskCache[document_set_id]
+
+ # kick off new tasks
+ with Session(get_sqlalchemy_engine()) as db_session:
+ # check if any document sets are not synced
+ document_set_info = fetch_document_sets(db_session=db_session)
+ for document_set, _ in document_set_info:
+ if not document_set.is_up_to_date:
+ if document_set.id in _ExistingTaskCache:
+ logger.info(
+ f"Document set '{document_set.id}' is already syncing. Skipping."
+ )
+ continue
+
+ logger.info(
+ f"Document set {document_set.id} is not synced. Syncing now!"
+ )
+ task = sync_document_set_task.apply_async(
+ kwargs=dict(document_set_id=document_set.id),
+ )
+ _ExistingTaskCache[document_set.id] = task
+
+
+if __name__ == "__main__":
+ interval_run_job(
+ job=_document_sync_loop, delay=5, emit_job_start_log=False
+ ) # run every 5 seconds
diff --git a/backend/danswer/background/utils.py b/backend/danswer/background/utils.py
index 8103d82ce..b822e9571 100644
--- a/backend/danswer/background/utils.py
+++ b/backend/danswer/background/utils.py
@@ -8,10 +8,13 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
-def interval_run_job(job: Callable[[], Any], delay: int | float) -> None:
+def interval_run_job(
+ job: Callable[[], Any], delay: int | float, emit_job_start_log: bool = True
+) -> None:
while True:
start = time.time()
- logger.info(f"Running '{job.__name__}', current time: {time.ctime(start)}")
+ if emit_job_start_log:
+ logger.info(f"Running '{job.__name__}', current time: {time.ctime(start)}")
try:
job()
except Exception as e:
diff --git a/backend/danswer/datastores/indexing_pipeline.py b/backend/danswer/datastores/indexing_pipeline.py
index c6b0cfbdd..ea395f368 100644
--- a/backend/danswer/datastores/indexing_pipeline.py
+++ b/backend/danswer/datastores/indexing_pipeline.py
@@ -16,6 +16,7 @@ from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import DocumentMetadata
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document import upsert_documents_complete
+from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.search.models import Embedder
from danswer.search.semantic_search import DefaultEmbedder
@@ -99,11 +100,19 @@ def _indexing_pipeline(
document_id_to_access_info = get_access_for_documents(
document_ids=document_ids, db_session=db_session
)
+ document_id_to_document_set = {
+ document_id: document_sets
+ for document_id, document_sets in fetch_document_sets_for_documents(
+ document_ids=document_ids, db_session=db_session
+ )
+ }
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=document_id_to_access_info[chunk.source_document.id],
- document_sets=set(),
+ document_sets=set(
+ document_id_to_document_set.get(chunk.source_document.id, [])
+ ),
)
for chunk in chunks_with_embeddings
]
diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py
index ac6d5dcc0..2b1767f1e 100644
--- a/backend/danswer/db/connector_credential_pair.py
+++ b/backend/danswer/db/connector_credential_pair.py
@@ -110,6 +110,7 @@ def mark_all_in_progress_cc_pairs_failed(
def add_credential_to_connector(
connector_id: int,
credential_id: int,
+ cc_pair_name: str | None,
user: User,
db_session: Session,
) -> StatusResponse[int]:
@@ -143,6 +144,7 @@ def add_credential_to_connector(
association = ConnectorCredentialPair(
connector_id=connector_id,
credential_id=credential_id,
+ name=cc_pair_name,
)
db_session.add(association)
db_session.commit()
diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py
new file mode 100644
index 000000000..aacc2243d
--- /dev/null
+++ b/backend/danswer/db/document_set.py
@@ -0,0 +1,282 @@
+from collections.abc import Sequence
+from typing import cast
+from uuid import UUID
+
+from sqlalchemy import and_
+from sqlalchemy import delete
+from sqlalchemy import func
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from danswer.db.models import ConnectorCredentialPair
+from danswer.db.models import Document
+from danswer.db.models import DocumentByConnectorCredentialPair
+from danswer.db.models import DocumentSet as DocumentSetDBModel
+from danswer.db.models import DocumentSet__ConnectorCredentialPair
+from danswer.server.models import DocumentSetCreationRequest
+from danswer.server.models import DocumentSetUpdateRequest
+
+
+def _delete_document_set_cc_pairs(
+ db_session: Session, document_set_id: int, is_current: bool | None = None
+) -> None:
+ """NOTE: does not commit transaction, this must be done by the caller"""
+ stmt = delete(DocumentSet__ConnectorCredentialPair).where(
+ DocumentSet__ConnectorCredentialPair.document_set_id == document_set_id
+ )
+ if is_current is not None:
+ stmt = stmt.where(DocumentSet__ConnectorCredentialPair.is_current == is_current)
+ db_session.execute(stmt)
+
+
+def _mark_document_set_cc_pairs_as_outdated(
+ db_session: Session, document_set_id: int
+) -> None:
+ """NOTE: does not commit transaction, this must be done by the caller"""
+ stmt = select(DocumentSet__ConnectorCredentialPair).where(
+ DocumentSet__ConnectorCredentialPair.document_set_id == document_set_id
+ )
+ for row in db_session.scalars(stmt):
+ row.is_current = False
+
+
+def get_document_set_by_id(
+ db_session: Session, document_set_id: int
+) -> DocumentSetDBModel | None:
+ return db_session.scalar(
+ select(DocumentSetDBModel).where(DocumentSetDBModel.id == document_set_id)
+ )
+
+
+def insert_document_set(
+ document_set_creation_request: DocumentSetCreationRequest,
+ user_id: UUID | None,
+ db_session: Session,
+) -> tuple[DocumentSetDBModel, list[DocumentSet__ConnectorCredentialPair]]:
+ if not document_set_creation_request.cc_pair_ids:
+ raise ValueError("Cannot create a document set with no CC pairs")
+
+ # start a transaction
+ db_session.begin()
+
+ try:
+ new_document_set_row = DocumentSetDBModel(
+ name=document_set_creation_request.name,
+ description=document_set_creation_request.description,
+ user_id=user_id,
+ )
+ db_session.add(new_document_set_row)
+ db_session.flush() # ensure the new document set gets assigned an ID
+
+ ds_cc_pairs = [
+ DocumentSet__ConnectorCredentialPair(
+ document_set_id=new_document_set_row.id,
+ connector_credential_pair_id=cc_pair_id,
+ is_current=True,
+ )
+ for cc_pair_id in document_set_creation_request.cc_pair_ids
+ ]
+ db_session.add_all(ds_cc_pairs)
+ db_session.commit()
+ except:
+ db_session.rollback()
+ raise
+
+ return new_document_set_row, ds_cc_pairs
+
+
+def update_document_set(
+ document_set_update_request: DocumentSetUpdateRequest, db_session: Session
+) -> tuple[DocumentSetDBModel, list[DocumentSet__ConnectorCredentialPair]]:
+ if not document_set_update_request.cc_pair_ids:
+        raise ValueError("Cannot update a document set to have no CC pairs")
+
+ # start a transaction
+ db_session.begin()
+
+ try:
+ # update the description
+ document_set_row = get_document_set_by_id(
+ db_session=db_session, document_set_id=document_set_update_request.id
+ )
+ if document_set_row is None:
+ raise ValueError(
+ f"No document set with ID {document_set_update_request.id}"
+ )
+ document_set_row.description = document_set_update_request.description
+ document_set_row.is_up_to_date = False
+
+ # update the attached CC pairs
+ # first, mark all existing CC pairs as not current
+ _mark_document_set_cc_pairs_as_outdated(
+ db_session=db_session, document_set_id=document_set_row.id
+ )
+ # add in rows for the new CC pairs
+ ds_cc_pairs = [
+ DocumentSet__ConnectorCredentialPair(
+ document_set_id=document_set_update_request.id,
+ connector_credential_pair_id=cc_pair_id,
+ is_current=True,
+ )
+ for cc_pair_id in document_set_update_request.cc_pair_ids
+ ]
+ db_session.add_all(ds_cc_pairs)
+ db_session.commit()
+ except:
+ db_session.rollback()
+ raise
+
+ return document_set_row, ds_cc_pairs
+
+
+def mark_document_set_as_synced(document_set_id: int, db_session: Session) -> None:
+ stmt = select(DocumentSetDBModel).where(DocumentSetDBModel.id == document_set_id)
+ document_set = db_session.scalar(stmt)
+ if document_set is None:
+ raise ValueError(f"No document set with ID: {document_set_id}")
+
+ # mark as up to date
+ document_set.is_up_to_date = True
+ # delete outdated relationship table rows
+ _delete_document_set_cc_pairs(
+ db_session=db_session, document_set_id=document_set_id, is_current=False
+ )
+ db_session.commit()
+
+
+def delete_document_set(document_set_id: int, db_session: Session) -> None:
+ # start a transaction
+ db_session.begin()
+
+ try:
+ document_set_row = get_document_set_by_id(
+ db_session=db_session, document_set_id=document_set_id
+ )
+ if document_set_row is None:
+ raise ValueError(f"No document set with ID: '{document_set_id}'")
+
+ # delete all relationships to CC pairs
+ _delete_document_set_cc_pairs(
+ db_session=db_session, document_set_id=document_set_id
+ )
+ # delete the actual document set row
+ db_session.delete(document_set_row)
+ db_session.commit()
+ except:
+ db_session.rollback()
+ raise
+
+
+def fetch_document_sets(
+ db_session: Session,
+) -> list[tuple[DocumentSetDBModel, list[ConnectorCredentialPair]]]:
+    """Returns a list where each element is a tuple of:
+ 1. The document set itself
+ 2. All CC pairs associated with the document set"""
+ results = cast(
+ list[tuple[DocumentSetDBModel, ConnectorCredentialPair]],
+ db_session.execute(
+ select(DocumentSetDBModel, ConnectorCredentialPair)
+ .join(
+ DocumentSet__ConnectorCredentialPair,
+ DocumentSetDBModel.id
+ == DocumentSet__ConnectorCredentialPair.document_set_id,
+ )
+ .join(
+ ConnectorCredentialPair,
+ ConnectorCredentialPair.id
+ == DocumentSet__ConnectorCredentialPair.connector_credential_pair_id,
+ )
+ .where(
+ DocumentSet__ConnectorCredentialPair.is_current == True # noqa: E712
+ )
+ ).all(),
+ )
+
+ aggregated_results: dict[
+ int, tuple[DocumentSetDBModel, list[ConnectorCredentialPair]]
+ ] = {}
+ for document_set, cc_pair in results:
+ if document_set.id not in aggregated_results:
+ aggregated_results[document_set.id] = (document_set, [cc_pair])
+ else:
+ aggregated_results[document_set.id][1].append(cc_pair)
+
+ return [
+ (document_set, cc_pairs)
+ for document_set, cc_pairs in aggregated_results.values()
+ ]
+
+
+def fetch_documents_for_document_set(
+ document_set_id: int, db_session: Session, current_only: bool = True
+) -> Sequence[Document]:
+ stmt = (
+ select(Document)
+ .join(
+ DocumentByConnectorCredentialPair,
+ DocumentByConnectorCredentialPair.id == Document.id,
+ )
+ .join(
+ ConnectorCredentialPair,
+ and_(
+ ConnectorCredentialPair.connector_id
+ == DocumentByConnectorCredentialPair.connector_id,
+ ConnectorCredentialPair.credential_id
+ == DocumentByConnectorCredentialPair.credential_id,
+ ),
+ )
+ .join(
+ DocumentSet__ConnectorCredentialPair,
+ DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
+ == ConnectorCredentialPair.id,
+ )
+ .join(
+ DocumentSetDBModel,
+ DocumentSetDBModel.id
+ == DocumentSet__ConnectorCredentialPair.document_set_id,
+ )
+ .where(DocumentSetDBModel.id == document_set_id)
+ )
+ if current_only:
+ stmt = stmt.where(
+ DocumentSet__ConnectorCredentialPair.is_current == True # noqa: E712
+ )
+ stmt = stmt.distinct()
+
+ return db_session.scalars(stmt).all()
+
+
+def fetch_document_sets_for_documents(
+ document_ids: list[str], db_session: Session
+) -> Sequence[tuple[str, list[str]]]:
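+    """Gives back a list of (document_id, all document set names for that document) tuples."""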
+ stmt = (
+ select(Document.id, func.array_agg(DocumentSetDBModel.name))
+ .join(
+ DocumentSet__ConnectorCredentialPair,
+ DocumentSetDBModel.id
+ == DocumentSet__ConnectorCredentialPair.document_set_id,
+ )
+ .join(
+ ConnectorCredentialPair,
+ ConnectorCredentialPair.id
+ == DocumentSet__ConnectorCredentialPair.connector_credential_pair_id,
+ )
+ .join(
+ DocumentByConnectorCredentialPair,
+ and_(
+ DocumentByConnectorCredentialPair.connector_id
+ == ConnectorCredentialPair.connector_id,
+ DocumentByConnectorCredentialPair.credential_id
+ == ConnectorCredentialPair.credential_id,
+ ),
+ )
+ .join(
+ Document,
+ Document.id == DocumentByConnectorCredentialPair.id,
+ )
+ .where(Document.id.in_(document_ids))
+ .where(DocumentSet__ConnectorCredentialPair.is_current == True) # noqa: E712
+ .group_by(Document.id)
+ )
+ return db_session.execute(stmt).all() # type: ignore
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index 729ddf70b..0130fa88f 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -14,6 +14,7 @@ from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
+from sqlalchemy import Sequence
from sqlalchemy import String
from sqlalchemy import Text
from sqlalchemy.dialects import postgresql
@@ -85,6 +86,17 @@ class ConnectorCredentialPair(Base):
"""
__tablename__ = "connector_credential_pair"
+ # NOTE: this `id` column has to use `Sequence` instead of `autoincrement=True`
+ # due to some SQLAlchemy quirks + this not being a primary key column
+ id: Mapped[int] = mapped_column(
+ Integer,
+ Sequence("connector_credential_pair_id_seq"),
+ unique=True,
+ nullable=False,
+ )
+ name: Mapped[str] = mapped_column(
+ String, unique=True, nullable=True
+    )  # nullable for backwards compatibility
connector_id: Mapped[int] = mapped_column(
ForeignKey("connector.id"), primary_key=True
)
@@ -242,6 +254,7 @@ class DocumentByConnectorCredentialPair(Base):
__tablename__ = "document_by_connector_credential_pair"
id: Mapped[str] = mapped_column(ForeignKey("document.id"), primary_key=True)
+ # TODO: transition this to use the ConnectorCredentialPair id directly
connector_id: Mapped[int] = mapped_column(
ForeignKey("connector.id"), primary_key=True
)
@@ -326,6 +339,49 @@ class Document(Base):
)
+class DocumentSet(Base):
+ __tablename__ = "document_set"
+
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ name: Mapped[str] = mapped_column(String, unique=True)
+ description: Mapped[str] = mapped_column(String)
+ user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
+    # whether or not changes to the document set have been propagated
+ is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+
+ connector_credential_pair_relationships: Mapped[
+ list["DocumentSet__ConnectorCredentialPair"]
+ ] = relationship(
+ "DocumentSet__ConnectorCredentialPair", back_populates="document_set"
+ )
+
+
+class DocumentSet__ConnectorCredentialPair(Base):
+ __tablename__ = "document_set__connector_credential_pair"
+
+ document_set_id: Mapped[int] = mapped_column(
+ ForeignKey("document_set.id"), primary_key=True
+ )
+ connector_credential_pair_id: Mapped[int] = mapped_column(
+ ForeignKey("connector_credential_pair.id"), primary_key=True
+ )
+    # if `True`, then is part of the current state of the document set
+    # if `False`, then is part of the prior state of the document set
+    # rows with `is_current=False` only exist while an update to the document
+    # set is being synced; they are deleted once the sync finishes, so none
+    # should remain once `DocumentSet.is_up_to_date == True`
+ is_current: Mapped[bool] = mapped_column(
+ Boolean,
+ nullable=False,
+ default=True,
+ primary_key=True,
+ )
+
+ document_set: Mapped[DocumentSet] = relationship(
+ "DocumentSet", back_populates="connector_credential_pair_relationships"
+ )
+
+
class ChatSession(Base):
__tablename__ = "chat_session"
diff --git a/backend/danswer/document_set/document_set.py b/backend/danswer/document_set/document_set.py
new file mode 100644
index 000000000..0b9ca99b5
--- /dev/null
+++ b/backend/danswer/document_set/document_set.py
@@ -0,0 +1,65 @@
+from sqlalchemy.orm import Session
+
+from danswer.datastores.document_index import get_default_document_index
+from danswer.datastores.interfaces import DocumentIndex
+from danswer.datastores.interfaces import UpdateRequest
+from danswer.db.document import prepare_to_modify_documents
+from danswer.db.document_set import fetch_document_sets_for_documents
+from danswer.db.document_set import fetch_documents_for_document_set
+from danswer.db.document_set import mark_document_set_as_synced
+from danswer.db.engine import get_sqlalchemy_engine
+from danswer.utils.batching import batch_generator
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+_SYNC_BATCH_SIZE = 1000
+
+
+def _sync_document_batch(
+ document_ids: list[str], document_index: DocumentIndex
+) -> None:
+ logger.debug(f"Syncing document sets for: {document_ids}")
+ # begin a transaction, release lock at the end
+ with Session(get_sqlalchemy_engine()) as db_session:
+ # acquires a lock on the documents so that no other process can modify them
+ prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
+
+ # get current state of document sets for these documents
+ document_set_map = {
+ document_id: document_sets
+ for document_id, document_sets in fetch_document_sets_for_documents(
+ document_ids=document_ids, db_session=db_session
+ )
+ }
+
+ # update Vespa
+ document_index.update(
+ update_requests=[
+ UpdateRequest(
+ document_ids=[document_id],
+ document_sets=set(document_set_map.get(document_id, [])),
+ )
+ for document_id in document_ids
+ ]
+ )
+
+
+def sync_document_set(document_set_id: int) -> None:
+ document_index = get_default_document_index()
+ with Session(get_sqlalchemy_engine()) as db_session:
+ documents_to_update = fetch_documents_for_document_set(
+ document_set_id=document_set_id,
+ db_session=db_session,
+ current_only=False,
+ )
+ for document_batch in batch_generator(documents_to_update, _SYNC_BATCH_SIZE):
+ _sync_document_batch(
+ document_ids=[document.id for document in document_batch],
+ document_index=document_index,
+ )
+
+ mark_document_set_as_synced(
+ document_set_id=document_set_id, db_session=db_session
+ )
+ logger.info(f"Document set sync for '{document_set_id}' complete!")
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 6c3063f4f..70b9e1584 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -37,6 +37,7 @@ from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.server.chat_backend import router as chat_router
from danswer.server.credential import router as credential_router
+from danswer.server.document_set import router as document_set_router
from danswer.server.event_loading import router as event_processing_router
from danswer.server.health import router as health_router
from danswer.server.manage import router as admin_router
@@ -77,6 +78,7 @@ def get_application() -> FastAPI:
application.include_router(admin_router)
application.include_router(user_router)
application.include_router(credential_router)
+ application.include_router(document_set_router)
application.include_router(health_router)
application.include_router(
diff --git a/backend/danswer/server/document_set.py b/backend/danswer/server/document_set.py
new file mode 100644
index 000000000..f0a2f9886
--- /dev/null
+++ b/backend/danswer/server/document_set.py
@@ -0,0 +1,101 @@
+from fastapi import APIRouter
+from fastapi import Depends
+from fastapi import HTTPException
+from sqlalchemy.orm import Session
+
+from danswer.auth.users import current_admin_user
+from danswer.auth.users import current_user
+from danswer.db.document_set import delete_document_set as delete_document_set_from_db
+from danswer.db.document_set import fetch_document_sets
+from danswer.db.document_set import insert_document_set
+from danswer.db.document_set import update_document_set
+from danswer.db.engine import get_session
+from danswer.db.models import User
+from danswer.server.models import ConnectorCredentialPairDescriptor
+from danswer.server.models import ConnectorSnapshot
+from danswer.server.models import CredentialSnapshot
+from danswer.server.models import DocumentSet
+from danswer.server.models import DocumentSetCreationRequest
+from danswer.server.models import DocumentSetUpdateRequest
+
+
+router = APIRouter(prefix="/manage")
+
+
+@router.post("/admin/document-set")
+def create_document_set(
+ document_set_creation_request: DocumentSetCreationRequest,
+ user: User = Depends(current_admin_user),
+ db_session: Session = Depends(get_session),
+) -> int:
+ try:
+ document_set_db_model, _ = insert_document_set(
+ document_set_creation_request=document_set_creation_request,
+ user_id=user.id if user else None,
+ db_session=db_session,
+ )
+ except Exception as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ return document_set_db_model.id
+
+
+@router.patch("/admin/document-set")
+def patch_document_set(
+ document_set_update_request: DocumentSetUpdateRequest,
+ _: User = Depends(current_admin_user),
+ db_session: Session = Depends(get_session),
+) -> None:
+ try:
+ update_document_set(
+ document_set_update_request=document_set_update_request,
+ db_session=db_session,
+ )
+ except Exception as e:
+ raise HTTPException(status_code=400, detail=str(e))
+
+
+@router.delete("/admin/document-set/{document_set_id}")
+def delete_document_set(
+ document_set_id: int,
+ _: User = Depends(current_admin_user),
+ db_session: Session = Depends(get_session),
+) -> None:
+ try:
+ delete_document_set_from_db(
+ document_set_id=document_set_id, db_session=db_session
+ )
+ except Exception as e:
+ raise HTTPException(status_code=400, detail=str(e))
+
+
+"""Endpoints for non-admins"""
+
+
+@router.get("/document-set")
+def list_document_sets(
+ _: User = Depends(current_user),
+ db_session: Session = Depends(get_session),
+) -> list[DocumentSet]:
+ document_set_info = fetch_document_sets(db_session=db_session)
+ return [
+ DocumentSet(
+ id=document_set_db_model.id,
+ name=document_set_db_model.name,
+ description=document_set_db_model.description,
+ cc_pair_descriptors=[
+ ConnectorCredentialPairDescriptor(
+ id=cc_pair.id,
+ name=cc_pair.name,
+ connector=ConnectorSnapshot.from_connector_db_model(
+ cc_pair.connector
+ ),
+ credential=CredentialSnapshot.from_credential_db_model(
+ cc_pair.credential
+ ),
+ )
+ for cc_pair in cc_pairs
+ ],
+ is_up_to_date=document_set_db_model.is_up_to_date,
+ )
+ for document_set_db_model, cc_pairs in document_set_info
+ ]
diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py
index 3803e3296..82dc35b9d 100644
--- a/backend/danswer/server/manage.py
+++ b/backend/danswer/server/manage.py
@@ -8,12 +8,13 @@ from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
+from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
-from danswer.background.celery import cleanup_connector_credential_pair_task
-from danswer.background.celery import get_deletion_status
+from danswer.background.celery.celery import cleanup_connector_credential_pair_task
+from danswer.background.celery.deletion_utils import get_deletion_status
from danswer.background.connector_deletion import (
get_cleanup_task_id,
)
@@ -69,6 +70,7 @@ from danswer.server.models import BoostDoc
from danswer.server.models import BoostUpdateRequest
from danswer.server.models import ConnectorBase
from danswer.server.models import ConnectorCredentialPairIdentifier
+from danswer.server.models import ConnectorCredentialPairMetadata
from danswer.server.models import ConnectorIndexingStatus
from danswer.server.models import ConnectorSnapshot
from danswer.server.models import CredentialSnapshot
@@ -316,6 +318,8 @@ def get_connector_indexing_status(
)
indexing_statuses.append(
ConnectorIndexingStatus(
+ cc_pair_id=cc_pair.id,
+ name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(connector),
credential=CredentialSnapshot.from_credential_db_model(credential),
public_doc=credential.public_doc,
@@ -390,8 +394,11 @@ def delete_connector_by_id(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
- with db_session.begin():
- return delete_connector(db_session=db_session, connector_id=connector_id)
+ try:
+ with db_session.begin():
+ return delete_connector(db_session=db_session, connector_id=connector_id)
+ except AssertionError:
+ raise HTTPException(status_code=400, detail="Connector is not deletable")
@router.post("/admin/connector/run-once")
@@ -650,10 +657,20 @@ def get_connector_by_id(
def associate_credential_to_connector(
connector_id: int,
credential_id: int,
+ metadata: ConnectorCredentialPairMetadata,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
- return add_credential_to_connector(connector_id, credential_id, user, db_session)
+ try:
+ return add_credential_to_connector(
+ connector_id=connector_id,
+ credential_id=credential_id,
+ cc_pair_name=metadata.name,
+ user=user,
+ db_session=db_session,
+ )
+ except IntegrityError:
+ raise HTTPException(status_code=400, detail="Name must be unique")
@router.delete("/connector/{connector_id}/credential/{credential_id}")
diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py
index 6c7ecbe78..24a9474bd 100644
--- a/backend/danswer/server/models.py
+++ b/backend/danswer/server/models.py
@@ -333,6 +333,8 @@ class CredentialSnapshot(CredentialBase):
class ConnectorIndexingStatus(BaseModel):
"""Represents the latest indexing status of a connector"""
+ cc_pair_id: int
+ name: str | None
connector: ConnectorSnapshot
credential: CredentialSnapshot
owner: str
@@ -351,5 +353,36 @@ class ConnectorCredentialPairIdentifier(BaseModel):
credential_id: int
+class ConnectorCredentialPairMetadata(BaseModel):
+ name: str | None
+
+
+class ConnectorCredentialPairDescriptor(BaseModel):
+ id: int
+ name: str | None
+ connector: ConnectorSnapshot
+ credential: CredentialSnapshot
+
+
class ApiKey(BaseModel):
api_key: str
+
+
+class DocumentSetCreationRequest(BaseModel):
+ name: str
+ description: str
+ cc_pair_ids: list[int]
+
+
+class DocumentSetUpdateRequest(BaseModel):
+ id: int
+ description: str
+ cc_pair_ids: list[int]
+
+
+class DocumentSet(BaseModel):
+ id: int
+ name: str
+ description: str
+ cc_pair_descriptors: list[ConnectorCredentialPairDescriptor]
+ is_up_to_date: bool
diff --git a/backend/supervisord.conf b/backend/supervisord.conf
index 47e2c17d8..b28fa0024 100644
--- a/backend/supervisord.conf
+++ b/backend/supervisord.conf
@@ -24,6 +24,13 @@ stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
+[program:document_set_sync]
+command=python danswer/background/document_set_sync_script.py
+stdout_logfile=/var/log/document_set_sync.log
+stdout_logfile_maxbytes=52428800
+redirect_stderr=true
+autorestart=true
+
# Listens for slack messages and responds with answers
# for all channels that the DanswerBot has been added to.
# If not setup, this will just fail 5 times and then stop.
@@ -39,7 +46,7 @@ startsecs=60
# pushes all logs from the above programs to stdout
[program:log-redirect-handler]
-command=tail -qF /var/log/update.log /var/log/celery.log /var/log/file_deletion.log /var/log/slack_bot_listener.log
+command=tail -qF /var/log/update.log /var/log/celery.log /var/log/file_deletion.log /var/log/slack_bot_listener.log /var/log/document_set_sync.log
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
redirect_stderr=true
diff --git a/web/src/app/admin/connectors/bookstack/page.tsx b/web/src/app/admin/connectors/bookstack/page.tsx
index d32985db7..86a95c1ea 100644
--- a/web/src/app/admin/connectors/bookstack/page.tsx
+++ b/web/src/app/admin/connectors/bookstack/page.tsx
@@ -9,6 +9,7 @@ import {
BookstackCredentialJson,
BookstackConfig,
ConnectorIndexingStatus,
+ Credential,
} from "@/lib/types";
import useSWR, { useSWRConfig } from "swr";
import { fetcher } from "@/lib/fetcher";
@@ -60,9 +61,10 @@ const Main = () => {
(connectorIndexingStatus) =>
connectorIndexingStatus.connector.source === "bookstack"
);
- const bookstackCredential = credentialsData.filter(
- (credential) => credential.credential_json?.bookstack_api_token_id
- )[0];
+ const bookstackCredential: Credential
+            Please provide your slack bot token in Step 1 first! Once done with
+            that, you can then specify which Slack channels you want to make
+            searchable.
+          </p>
+        )}
     </>
   );
 };
diff --git a/web/src/app/admin/connectors/web/page.tsx b/web/src/app/admin/connectors/web/page.tsx
index 7d0f40177..0886c2a12 100644
--- a/web/src/app/admin/connectors/web/page.tsx
+++ b/web/src/app/admin/connectors/web/page.tsx
@@ -49,6 +49,8 @@ export default function Web() {