Move connector / credential pair deletion to celery

Author: Weves
Date: 2023-09-14 11:01:04 -07:00
Committed by: Chris Weaver
Parent: 3fc7a13a31
Commit: c4e0face9b
13 changed files with 237 additions and 289 deletions

View File

@@ -0,0 +1,73 @@
"""Remove deletion_attempt table
Revision ID: d5645c915d0e
Revises: 5809c0787398
Create Date: 2023-09-14 15:04:14.444909
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "d5645c915d0e"
down_revision = "5809c0787398"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_table("deletion_attempt")
def downgrade() -> None:
op.create_table(
"deletion_attempt",
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("connector_id", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column("credential_id", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column(
"status",
postgresql.ENUM(
"NOT_STARTED",
"IN_PROGRESS",
"SUCCESS",
"FAILED",
name="deletionstatus",
),
autoincrement=False,
nullable=False,
),
sa.Column(
"num_docs_deleted",
sa.INTEGER(),
autoincrement=False,
nullable=False,
),
sa.Column("error_msg", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column(
"time_created",
postgresql.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
autoincrement=False,
nullable=False,
),
sa.Column(
"time_updated",
postgresql.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
autoincrement=False,
nullable=False,
),
sa.ForeignKeyConstraint(
["connector_id"],
["connector.id"],
name="deletion_attempt_connector_id_fkey",
),
sa.ForeignKeyConstraint(
["credential_id"],
["credential.id"],
name="deletion_attempt_credential_id_fkey",
),
sa.PrimaryKeyConstraint("id", name="deletion_attempt_pkey"),
)
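
Since the upgrade here is destructive (it drops deletion_attempt outright, with the downgrade recreating it), it is worth knowing how to apply or revert just this revision. A minimal sketch using alembic's Python API; the "alembic.ini" path is an assumption about the repo layout, not part of this commit:

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # hypothetical path; point at the project's actual ini
command.upgrade(cfg, "d5645c915d0e")    # drop the deletion_attempt table
command.downgrade(cfg, "5809c0787398")  # recreate it via the downgrade above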

View File

@@ -0,0 +1,88 @@
import json
from typing import cast
from celery import Celery
from celery.result import AsyncResult
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.connector_deletion import cleanup_connector_credential_pair
from danswer.background.connector_deletion import get_cleanup_task_id
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import SYNC_DB_API
from danswer.db.models import DeletionStatus
from danswer.server.models import DeletionAttemptSnapshot
celery_broker_url = "sqla+" + build_connection_string(db_api=SYNC_DB_API)
celery_backend_url = "db+" + build_connection_string(db_api=SYNC_DB_API)
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
def cleanup_connector_credential_pair_task(
connector_id: int, credential_id: int
) -> int:
return cleanup_connector_credential_pair(connector_id, credential_id)
def get_deletion_status(
connector_id: int, credential_id: int
) -> DeletionAttemptSnapshot | None:
cleanup_task_id = get_cleanup_task_id(
connector_id=connector_id, credential_id=credential_id
)
deletion_task = get_celery_task(task_id=cleanup_task_id)
deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
deletion_status = None
error_msg = None
num_docs_deleted = 0
if deletion_task_status == "SUCCESS":
deletion_status = DeletionStatus.SUCCESS
num_docs_deleted = cast(int, deletion_task.get(propagate=False))
elif deletion_task_status == "FAILURE":
deletion_status = DeletionStatus.FAILED
error_msg = deletion_task.get(propagate=False)
elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
deletion_status = DeletionStatus.IN_PROGRESS
return (
DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=deletion_status,
error_msg=str(error_msg) if error_msg else None,  # avoid serializing a missing error as the string "None"
num_docs_deleted=num_docs_deleted,
)
if deletion_status
else None
)
def get_celery_task(task_id: str) -> AsyncResult:
"""NOTE: even if the task doesn't exist, celery will still return something
with a `PENDING` state"""
return AsyncResult(task_id, backend=celery_app.backend)
def get_celery_task_status(task_id: str) -> str | None:
"""NOTE: is tightly coupled to the internals of kombu (which is the
translation layer to allow us to use Postgres as a broker). If we change
the broker, this will need to be updated.
This should not be called on any critical flows.
"""
task = get_celery_task(task_id)
# if not pending, then we know the task really exists
if task.status != "PENDING":
return task.status
with Session(get_sqlalchemy_engine()) as session:
rows = session.execute(text("SELECT payload FROM kombu_message WHERE visible"))
for row in rows:
payload = json.loads(row[0])
if payload["headers"]["id"] == task_id:
return "PENDING"
return None
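
Taken together, the functions above let celery's result backend replace the old deletion_attempt table entirely: the deterministic task id stands in for the primary key, and the task state stands in for the status column. A minimal usage sketch mirroring what the admin router below does (the ids 1 and 2 are placeholders):

from danswer.background.celery import (
    cleanup_connector_credential_pair_task,
    get_deletion_status,
)
from danswer.background.connector_deletion import get_cleanup_task_id

# enqueue the deletion under a deterministic task id...
task_id = get_cleanup_task_id(connector_id=1, credential_id=2)
cleanup_connector_credential_pair_task.apply_async(
    kwargs=dict(connector_id=1, credential_id=2), task_id=task_id
)

# ...so any later caller can rebuild the id and poll the status
snapshot = get_deletion_status(connector_id=1, credential_id=2)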

View File

@@ -10,9 +10,7 @@ are multiple connector / credential pairs that have indexed it
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
-import time
from collections import defaultdict
-from datetime import datetime
from sqlalchemy.orm import Session
@@ -24,8 +22,6 @@ from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import delete_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
-from danswer.db.deletion_attempt import delete_deletion_attempts
-from danswer.db.deletion_attempt import get_deletion_attempts
from danswer.db.document import (
delete_document_by_connector_credential_pair_for_connector_credential_pair,
)
@@ -38,9 +34,8 @@ from danswer.db.document import (
)
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
+from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
-from danswer.db.models import DeletionAttempt
-from danswer.db.models import DeletionStatus
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -49,10 +44,23 @@ logger = setup_logger()
def _delete_connector_credential_pair(
db_session: Session,
document_index: DocumentIndex,
-    deletion_attempt: DeletionAttempt,
+    connector_id: int,
+    credential_id: int,
) -> int:
-    connector_id = deletion_attempt.connector_id
-    credential_id = deletion_attempt.credential_id
+    # validate that the connector / credential pair is deletable
+    cc_pair = get_connector_credential_pair(
+        db_session=db_session,
+        connector_id=connector_id,
+        credential_id=credential_id,
+    )
+    if not cc_pair or not check_deletion_attempt_is_allowed(
+        connector_credential_pair=cc_pair
+    ):
+        raise ValueError(
+            "Cannot run deletion attempt - connector_credential_pair is not deletable. "
+            "This is likely because there is an ongoing / planned indexing attempt OR the "
+            "connector is not disabled."
+        )
def _delete_singly_indexed_docs() -> int:
# if a document store entry is only indexed by this connector_credential_pair, delete it
@@ -78,7 +86,9 @@ def _delete_connector_credential_pair(
num_docs_deleted = _delete_singly_indexed_docs()
logger.info(f"Deleted {num_docs_deleted} documents from document stores")
-    def _update_multi_indexed_docs() -> None:
+    def _update_multi_indexed_docs(
+        connector_credential_pair: ConnectorCredentialPair,
+    ) -> None:
# if a document is indexed by multiple connector_credential_pairs, we should
# update its access rather than outright delete it
document_by_connector_credential_pairs_to_update = (
@@ -99,7 +109,7 @@ def _delete_connector_credential_pair(
# find out which documents need to be updated and what their new allowed_users
# should be. This is a bit slow as it requires looping through all the documents
-        to_be_deleted_user = _get_user(deletion_attempt.credential)
+        to_be_deleted_user = _get_user(connector_credential_pair.credential)
document_ids_not_needing_update: set[str] = set()
document_id_to_allowed_users: dict[str, list[str]] = defaultdict(list)
for (
@@ -145,7 +155,7 @@ def _delete_connector_credential_pair(
credential_id=credential_id,
)
-    _update_multi_indexed_docs()
+    _update_multi_indexed_docs(cc_pair)
def _cleanup() -> None:
    # clean up everything else
@@ -162,11 +172,6 @@ def _delete_connector_credential_pair(
connector_id=connector_id,
credential_id=credential_id,
)
-    delete_deletion_attempts(
-        db_session=db_session,
-        connector_id=connector_id,
-        credential_id=credential_id,
-    )
delete_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
@@ -191,88 +196,20 @@ def _delete_connector_credential_pair(
return num_docs_deleted
-def _run_deletion(db_session: Session) -> None:
-    # NOTE: makes the assumption that there is only one deletion job running at a time
-    deletion_attempts = get_deletion_attempts(
-        db_session, statuses=[DeletionStatus.NOT_STARTED], limit=1
-    )
-    if not deletion_attempts:
-        logger.info("No deletion attempts to run")
-        return
-    deletion_attempt = deletion_attempts[0]
-    # validate that the connector / credential pair is deletable
-    cc_pair = get_connector_credential_pair(
-        db_session=db_session,
-        connector_id=deletion_attempt.connector_id,
-        credential_id=deletion_attempt.credential_id,
-    )
-    if not cc_pair or not check_deletion_attempt_is_allowed(
-        connector_credential_pair=cc_pair
-    ):
-        error_msg = (
-            "Cannot run deletion attempt - connector_credential_pair is not deletable. "
-            "This is likely because there is an ongoing / planned indexing attempt OR the "
-            "connector is not disabled."
-        )
-        logger.error(error_msg)
-        deletion_attempt.status = DeletionStatus.FAILED
-        deletion_attempt.error_msg = error_msg
-        db_session.commit()
-        return
-    # kick off the actual deletion process
-    deletion_attempt.status = DeletionStatus.IN_PROGRESS
-    db_session.commit()
-    try:
-        num_docs_deleted = _delete_connector_credential_pair(
-            db_session=db_session,
-            document_index=get_default_document_index(),
-            deletion_attempt=deletion_attempt,
-        )
-    except Exception as e:
-        logger.exception(f"Failed to delete connector_credential_pair due to {e}")
-        deletion_attempt.status = DeletionStatus.FAILED
-        deletion_attempt.error_msg = str(e)
-        db_session.commit()
-        return
-    deletion_attempt.status = DeletionStatus.SUCCESS
-    deletion_attempt.num_docs_deleted = num_docs_deleted
-    db_session.commit()
-def _cleanup_deletion_jobs(db_session: Session) -> None:
-    """Cleanup any deletion jobs that were in progress but failed to complete
-    NOTE: makes the assumption that there is only one deletion job running at a time.
-    If multiple deletion jobs can be run at once, then this behavior no longer makes
-    sense."""
-    deletion_attempts = get_deletion_attempts(
-        db_session,
-        statuses=[DeletionStatus.IN_PROGRESS],
-    )
-    for deletion_attempt in deletion_attempts:
-        deletion_attempt.status = DeletionStatus.FAILED
-    db_session.commit()
-def _update_loop(delay: int = 10) -> None:
+def cleanup_connector_credential_pair(connector_id: int, credential_id: int) -> int:
    engine = get_sqlalchemy_engine()
-    while True:
-        start = time.time()
-        start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
-        logger.info(f"Running connector_deletion, current UTC time: {start_time_utc}")
-        with Session(engine) as db_session:
-            _run_deletion(db_session)
-            _cleanup_deletion_jobs(db_session)
-        sleep_time = delay - (time.time() - start)
-        if sleep_time > 0:
-            time.sleep(sleep_time)
+    try:
+        with Session(engine) as db_session:
+            return _delete_connector_credential_pair(
+                db_session=db_session,
+                document_index=get_default_document_index(),
+                connector_id=connector_id,
+                credential_id=credential_id,
+            )
+    except Exception as e:
+        logger.exception(f"Failed to run connector_deletion due to {e}")
+        raise e
-if __name__ == "__main__":
-    _update_loop()
+def get_cleanup_task_id(connector_id: int, credential_id: int) -> str:
+    return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"

View File

@@ -1,12 +1,4 @@
-from sqlalchemy import and_
-from sqlalchemy import delete
-from sqlalchemy import desc
-from sqlalchemy import select
-from sqlalchemy.orm import Session
from danswer.db.models import ConnectorCredentialPair
-from danswer.db.models import DeletionAttempt
-from danswer.db.models import DeletionStatus
from danswer.db.models import IndexingStatus
@@ -26,60 +18,3 @@ def check_deletion_attempt_is_allowed(
!= IndexingStatus.NOT_STARTED
)
)
-def create_deletion_attempt(
-    connector_id: int,
-    credential_id: int,
-    db_session: Session,
-) -> int:
-    new_attempt = DeletionAttempt(
-        connector_id=connector_id,
-        credential_id=credential_id,
-        status=DeletionStatus.NOT_STARTED,
-    )
-    db_session.add(new_attempt)
-    db_session.commit()
-    return new_attempt.id
-def get_not_started_index_attempts(db_session: Session) -> list[DeletionAttempt]:
-    stmt = select(DeletionAttempt).where(
-        DeletionAttempt.status == DeletionStatus.NOT_STARTED
-    )
-    not_started_deletion_attempts = db_session.scalars(stmt)
-    return list(not_started_deletion_attempts.all())
-def get_deletion_attempts(
-    db_session: Session,
-    connector_ids: list[int] | None = None,
-    statuses: list[DeletionStatus] | None = None,
-    ordered_by_time_updated: bool = False,
-    limit: int | None = None,
-) -> list[DeletionAttempt]:
-    stmt = select(DeletionAttempt)
-    if connector_ids:
-        stmt = stmt.where(DeletionAttempt.connector_id.in_(connector_ids))
-    if statuses:
-        stmt = stmt.where(DeletionAttempt.status.in_(statuses))
-    if ordered_by_time_updated:
-        stmt = stmt.order_by(desc(DeletionAttempt.time_updated))
-    if limit:
-        stmt = stmt.limit(limit)
-    deletion_attempts = db_session.scalars(stmt)
-    return list(deletion_attempts.all())
-def delete_deletion_attempts(
-    db_session: Session, connector_id: int, credential_id: int
-) -> None:
-    stmt = delete(DeletionAttempt).where(
-        and_(
-            DeletionAttempt.connector_id == connector_id,
-            DeletionAttempt.credential_id == credential_id,
-        )
-    )
-    db_session.execute(stmt)
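
The one helper kept in this file is check_deletion_attempt_is_allowed, which now guards deletion inside the celery task itself (see _delete_connector_credential_pair above) instead of inside a polling loop. A hedged sketch of the call pattern, assuming cc_pair is a ConnectorCredentialPair loaded elsewhere:

if not cc_pair or not check_deletion_attempt_is_allowed(
    connector_credential_pair=cc_pair
):
    raise ValueError("connector / credential pair is not deletable")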

View File

@@ -2,7 +2,6 @@ import datetime
from enum import Enum as PyEnum
from typing import Any
from typing import List
-from typing import Optional
from uuid import UUID
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
@@ -141,9 +140,6 @@ class Connector(Base):
index_attempts: Mapped[List["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="connector"
)
-    deletion_attempt: Mapped[Optional["DeletionAttempt"]] = relationship(
-        "DeletionAttempt", back_populates="connector"
-    )
class Credential(Base):
@@ -171,9 +167,6 @@ class Credential(Base):
index_attempts: Mapped[List["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="credential"
)
-    deletion_attempt: Mapped[Optional["DeletionAttempt"]] = relationship(
-        "DeletionAttempt", back_populates="credential"
-    )
user: Mapped[User | None] = relationship("User", back_populates="credentials")
@@ -242,43 +235,6 @@ class IndexAttempt(Base):
)
-class DeletionAttempt(Base):
-    """Represents an attempt to delete all documents indexed by a specific
-    connector / credential pair.
-    """
-    __tablename__ = "deletion_attempt"
-    id: Mapped[int] = mapped_column(primary_key=True)
-    connector_id: Mapped[int] = mapped_column(
-        ForeignKey("connector.id"),
-    )
-    credential_id: Mapped[int] = mapped_column(
-        ForeignKey("credential.id"),
-    )
-    status: Mapped[DeletionStatus] = mapped_column(Enum(DeletionStatus))
-    num_docs_deleted: Mapped[int] = mapped_column(Integer, default=0)
-    error_msg: Mapped[str | None] = mapped_column(
-        Text, default=None
-    )  # only filled if status = "failed"
-    time_created: Mapped[datetime.datetime] = mapped_column(
-        DateTime(timezone=True),
-        server_default=func.now(),
-    )
-    time_updated: Mapped[datetime.datetime] = mapped_column(
-        DateTime(timezone=True),
-        server_default=func.now(),
-        onupdate=func.now(),
-    )
-    connector: Mapped[Connector] = relationship(
-        "Connector", back_populates="deletion_attempt"
-    )
-    credential: Mapped[Credential] = relationship(
-        "Credential", back_populates="deletion_attempt"
-    )
class DocumentByConnectorCredentialPair(Base):
"""Represents an indexing of a document by a specific connector / credential
pair"""

View File

@@ -12,6 +12,11 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
+from danswer.background.celery import cleanup_connector_credential_pair_task
+from danswer.background.celery import get_deletion_status
+from danswer.background.connector_deletion import (
+    get_cleanup_task_id,
+)
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
@@ -46,14 +51,11 @@ from danswer.db.credentials import create_credential
from danswer.db.credentials import delete_google_drive_service_account_credentials
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
-from danswer.db.deletion_attempt import create_deletion_attempt
-from danswer.db.deletion_attempt import get_deletion_attempts
from danswer.db.engine import get_session
from danswer.db.feedback import fetch_docs_ranked_by_boost
from danswer.db.feedback import update_document_boost
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_latest_index_attempts
-from danswer.db.models import DeletionAttempt
from danswer.db.models import User
from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
from danswer.direct_qa.llm_utils import get_default_qa_model
@@ -70,7 +72,6 @@ from danswer.server.models import ConnectorCredentialPairIdentifier
from danswer.server.models import ConnectorIndexingStatus
from danswer.server.models import ConnectorSnapshot
from danswer.server.models import CredentialSnapshot
-from danswer.server.models import DeletionAttemptSnapshot
from danswer.server.models import FileUploadResponse
from danswer.server.models import GDriveCallback
from danswer.server.models import GoogleAppCredentials
@@ -307,25 +308,12 @@ def get_connector_indexing_status(
for index_attempt in latest_index_attempts
}
-    deletion_attempts_by_connector: dict[int, list[DeletionAttempt]] = {
-        cc_pair.connector.id: [] for cc_pair in cc_pairs
-    }
-    for deletion_attempt in get_deletion_attempts(
-        db_session=db_session,
-        connector_ids=[cc_pair.connector.id for cc_pair in cc_pairs],
-        ordered_by_time_updated=True,
-    ):
-        deletion_attempts_by_connector[deletion_attempt.connector_id].append(
-            deletion_attempt
-        )
for cc_pair in cc_pairs:
connector = cc_pair.connector
credential = cc_pair.credential
latest_index_attempt = cc_pair_to_latest_index_attempt.get(
(connector.id, credential.id)
)
-        deletion_attempts = deletion_attempts_by_connector.get(connector.id, [])
indexing_statuses.append(
ConnectorIndexingStatus(
connector=ConnectorSnapshot.from_connector_db_model(connector),
@@ -343,12 +331,9 @@ def get_connector_indexing_status(
)
if latest_index_attempt
else None,
-                deletion_attempts=[
-                    DeletionAttemptSnapshot.from_deletion_attempt_db_model(
-                        deletion_attempt
-                    )
-                    for deletion_attempt in deletion_attempts
-                ],
+                deletion_attempt=get_deletion_status(
+                    connector_id=connector.id, credential_id=credential.id
+                ),
is_deletable=check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair
),
@@ -564,31 +549,13 @@ def create_deletion_attempt_for_connector_id(
"no ongoing / planned indexing attempts.",
)
-    create_deletion_attempt(
-        connector_id=connector_id,
-        credential_id=credential_id,
-        db_session=db_session,
+    task_id = get_cleanup_task_id(
+        connector_id=connector_id, credential_id=credential_id
    )
-@router.get("/admin/deletion-attempt/{connector_id}")
-def get_deletion_attempts_for_connector_id(
-    connector_id: int,
-    _: User = Depends(current_admin_user),
-    db_session: Session = Depends(get_session),
-) -> list[DeletionAttemptSnapshot]:
-    deletion_attempts = get_deletion_attempts(
-        db_session=db_session, connector_ids=[connector_id]
+    cleanup_connector_credential_pair_task.apply_async(
+        kwargs=dict(connector_id=connector_id, credential_id=credential_id),
+        task_id=task_id,
    )
-    return [
-        DeletionAttemptSnapshot(
-            connector_id=connector_id,
-            status=deletion_attempt.status,
-            error_msg=deletion_attempt.error_msg,
-            num_docs_deleted=deletion_attempt.num_docs_deleted,
-        )
-        for deletion_attempt in deletion_attempts
-    ]
"""Endpoints for basic users"""

View File

@@ -18,7 +18,6 @@ from danswer.connectors.models import InputType
from danswer.datastores.interfaces import IndexFilter
from danswer.db.models import Connector
from danswer.db.models import Credential
-from danswer.db.models import DeletionAttempt
from danswer.db.models import DeletionStatus
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
@@ -256,21 +255,11 @@ class IndexAttemptSnapshot(BaseModel):
class DeletionAttemptSnapshot(BaseModel):
connector_id: int
credential_id: int
status: DeletionStatus
error_msg: str | None
num_docs_deleted: int
-    @classmethod
-    def from_deletion_attempt_db_model(
-        cls, deletion_attempt: DeletionAttempt
-    ) -> "DeletionAttemptSnapshot":
-        return DeletionAttemptSnapshot(
-            connector_id=deletion_attempt.connector_id,
-            status=deletion_attempt.status,
-            error_msg=deletion_attempt.error_msg,
-            num_docs_deleted=deletion_attempt.num_docs_deleted,
-        )
class ConnectorBase(BaseModel):
name: str
@@ -347,7 +336,7 @@ class ConnectorIndexingStatus(BaseModel):
docs_indexed: int
error_msg: str | None
latest_index_attempt: IndexAttemptSnapshot | None
-    deletion_attempts: list[DeletionAttemptSnapshot]
+    deletion_attempt: DeletionAttemptSnapshot | None
is_deletable: bool
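
Since celery tracks at most one deletion per connector / credential pair, the status field collapses from a list of snapshots to a single optional one. A minimal sketch of constructing the slimmed-down model directly, which is what get_deletion_status in danswer/background/celery.py now does (values are placeholders):

snapshot = DeletionAttemptSnapshot(
    connector_id=1,
    credential_id=2,
    status=DeletionStatus.IN_PROGRESS,
    error_msg=None,
    num_docs_deleted=0,
)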

View File

@@ -2,6 +2,7 @@ alembic==1.10.4
asyncpg==0.27.0
atlassian-python-api==3.37.0
beautifulsoup4==4.12.0
+celery==5.3.4
dask==2023.8.1
distributed==2023.8.1
python-dateutil==2.8.2

View File

@@ -1,4 +1,5 @@
black==23.3.0
+celery-types==0.19.0
mypy-extensions==1.0.0
mypy==1.1.1
pre-commit==3.2.2

View File

@@ -10,9 +10,9 @@ stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
-[program:connector_deletion]
-command=python danswer/background/connector_deletion.py
-stdout_logfile=/var/log/connector_deletion.log
+[program:celery]
+command=celery -A danswer.background.celery worker --loglevel=INFO
+stdout_logfile=/var/log/celery.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
@@ -39,7 +39,7 @@ startsecs=60
# pushes all logs from the above programs to stdout
[program:log-redirect-handler]
-command=tail -qF /var/log/update.log /var/log/connector_deletion.log /var/log/file_deletion.log /var/log/slack_bot_listener.log
+command=tail -qF /var/log/update.log /var/log/celery.log /var/log/file_deletion.log /var/log/slack_bot_listener.log
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
redirect_stderr=true
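
For local development without supervisord, the same worker can be started from Python. A minimal sketch using celery's app API, equivalent to the command above:

from danswer.background.celery import celery_app

if __name__ == "__main__":
    # same as: celery -A danswer.background.celery worker --loglevel=INFO
    celery_app.worker_main(argv=["worker", "--loglevel=INFO"])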