Add connector document pruning task (#1652)
parent 58b5e25c97
commit 54c2547d89
@@ -0,0 +1,22 @@
+"""added-prune-frequency
+
+Revision ID: e209dc5a8156
+Revises: 48d14957fe80
+Create Date: 2024-06-16 16:02:35.273231
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+revision = "e209dc5a8156"
+down_revision = "48d14957fe80"
+branch_labels = None  # type: ignore
+depends_on = None  # type: ignore
+
+
+def upgrade() -> None:
+    op.add_column("connector", sa.Column("prune_freq", sa.Integer(), nullable=True))
+
+
+def downgrade() -> None:
+    op.drop_column("connector", "prune_freq")
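The migration only adds a nullable integer column, so it is cheap to apply. A minimal sketch (not part of the commit) of running it programmatically, assuming a standard Alembic setup; `alembic upgrade head` from the shell is equivalent:

# Sketch only; assumes an alembic.ini in the working directory.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")  # path to the project's Alembic config
command.upgrade(alembic_cfg, "e209dc5a8156")  # or "head" for all pending revisions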
@@ -4,13 +4,22 @@ from typing import cast
 from celery import Celery  # type: ignore
 from sqlalchemy.orm import Session
 
+from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
+from danswer.background.celery.celery_utils import should_prune_cc_pair
+from danswer.background.celery.celery_utils import should_sync_doc_set
 from danswer.background.connector_deletion import delete_connector_credential_pair
+from danswer.background.connector_deletion import delete_connector_credential_pair_batch
 from danswer.background.task_utils import build_celery_task_wrapper
 from danswer.background.task_utils import name_cc_cleanup_task
+from danswer.background.task_utils import name_cc_prune_task
 from danswer.background.task_utils import name_document_set_sync_task
 from danswer.configs.app_configs import JOB_TIMEOUT
+from danswer.connectors.factory import instantiate_connector
+from danswer.connectors.models import InputType
 from danswer.db.connector_credential_pair import get_connector_credential_pair
+from danswer.db.connector_credential_pair import get_connector_credential_pairs
 from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
+from danswer.db.document import get_documents_for_connector_credential_pair
 from danswer.db.document import prepare_to_modify_documents
 from danswer.db.document_set import delete_document_set
 from danswer.db.document_set import fetch_document_sets
@@ -22,8 +31,6 @@ from danswer.db.engine import build_connection_string
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.engine import SYNC_DB_API
 from danswer.db.models import DocumentSet
-from danswer.db.tasks import check_live_task_not_timed_out
-from danswer.db.tasks import get_latest_task
 from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import UpdateRequest
@@ -90,6 +97,74 @@ def cleanup_connector_credential_pair_task(
         raise e
 
 
+@build_celery_task_wrapper(name_cc_prune_task)
+@celery_app.task(soft_time_limit=JOB_TIMEOUT)
+def prune_documents_task(connector_id: int, credential_id: int) -> None:
+    """Connector pruning task. For a cc pair, this task pulls all document IDs from the
+    source, compares them against the locally stored documents, and deletes every locally
+    stored document whose ID is missing from the freshly pulled ID list."""
+    with Session(get_sqlalchemy_engine()) as db_session:
+        try:
+            cc_pair = get_connector_credential_pair(
+                db_session=db_session,
+                connector_id=connector_id,
+                credential_id=credential_id,
+            )
+
+            if not cc_pair:
+                logger.warning(f"ccpair not found for {connector_id} {credential_id}")
+                return
+
+            runnable_connector = instantiate_connector(
+                cc_pair.connector.source,
+                InputType.PRUNE,
+                cc_pair.connector.connector_specific_config,
+                cc_pair.credential,
+                db_session,
+            )
+
+            all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
+                runnable_connector
+            )
+
+            all_indexed_document_ids = {
+                doc.id
+                for doc in get_documents_for_connector_credential_pair(
+                    db_session=db_session,
+                    connector_id=connector_id,
+                    credential_id=credential_id,
+                )
+            }
+
+            doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
+
+            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
+            document_index = get_default_document_index(
+                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
+            )
+
+            if len(doc_ids_to_remove) == 0:
+                logger.info(
+                    f"No docs to prune from {cc_pair.connector.source} connector"
+                )
+                return
+
+            logger.info(
+                f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
+            )
+            delete_connector_credential_pair_batch(
+                document_ids=doc_ids_to_remove,
+                connector_id=connector_id,
+                credential_id=credential_id,
+                document_index=document_index,
+            )
+        except Exception as e:
+            logger.exception(
+                f"Failed to run pruning for connector id {connector_id} due to {e}"
+            )
+            raise e
+
+
 @build_celery_task_wrapper(name_document_set_sync_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def sync_document_set_task(document_set_id: int) -> None:
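The heart of the task is a plain set difference between the IDs the source currently has and the IDs indexed locally. A toy illustration (names invented here, not from the commit):

indexed_ids = {"doc_a", "doc_b", "doc_c"}   # IDs stored locally
source_ids = {"doc_a", "doc_c"}             # IDs still present at the source
doc_ids_to_remove = list(indexed_ids - source_ids)
assert doc_ids_to_remove == ["doc_b"]       # "doc_b" was deleted upstream, so prune it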
@@ -177,32 +252,48 @@ def sync_document_set_task(document_set_id: int) -> None:
     soft_time_limit=JOB_TIMEOUT,
 )
 def check_for_document_sets_sync_task() -> None:
-    """Runs periodically to check if any document sets are out of sync
-    Creates a task to sync the set if needed"""
+    """Runs periodically to check if any sync tasks should be run and adds them
+    to the queue"""
     with Session(get_sqlalchemy_engine()) as db_session:
         # check if any document sets are not synced
         document_set_info = fetch_document_sets(
             user_id=None, db_session=db_session, include_outdated=True
         )
         for document_set, _ in document_set_info:
-            if not document_set.is_up_to_date:
-                task_name = name_document_set_sync_task(document_set.id)
-                latest_sync = get_latest_task(task_name, db_session)
-
-                if latest_sync and check_live_task_not_timed_out(
-                    latest_sync, db_session
-                ):
-                    logger.info(
-                        f"Document set '{document_set.id}' is already syncing. Skipping."
-                    )
-                    continue
-
-                logger.info(f"Document set {document_set.id} syncing now!")
+            if should_sync_doc_set(document_set, db_session):
+                logger.info(f"Syncing the {document_set.name} document set")
                 sync_document_set_task.apply_async(
                     kwargs=dict(document_set_id=document_set.id),
                 )
 
 
+@celery_app.task(
+    name="check_for_prune_task",
+    soft_time_limit=JOB_TIMEOUT,
+)
+def check_for_prune_task() -> None:
+    """Runs periodically to check if any prune tasks should be run and adds them
+    to the queue"""
+    with Session(get_sqlalchemy_engine()) as db_session:
+        all_cc_pairs = get_connector_credential_pairs(db_session)
+
+        for cc_pair in all_cc_pairs:
+            if should_prune_cc_pair(
+                connector=cc_pair.connector,
+                credential=cc_pair.credential,
+                db_session=db_session,
+            ):
+                logger.info(f"Pruning the {cc_pair.connector.name} connector")
+
+                prune_documents_task.apply_async(
+                    kwargs=dict(
+                        connector_id=cc_pair.connector.id,
+                        credential_id=cc_pair.credential.id,
+                    )
+                )
 
 
 #####
 # Celery Beat (Periodic Tasks) Settings
 #####
@@ -212,3 +303,11 @@ celery_app.conf.beat_schedule = {
         "schedule": timedelta(seconds=5),
     },
 }
+celery_app.conf.beat_schedule.update(
+    {
+        "check-for-prune": {
+            "task": "check_for_prune_task",
+            "schedule": timedelta(seconds=5),
+        },
+    }
+)
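Note that the 5-second beat interval is only how often the check fires; should_prune_cc_pair (below) gates the actual work, so with the default frequency each connector is pruned at most once a day. A quick arithmetic sketch (illustrative only, values from the config added in this commit):

beat_interval_s = 5                  # how often check_for_prune_task fires
default_prune_freq_s = 60 * 60 * 24  # DEFAULT_PRUNING_FREQ: once a day
checks_per_prune = default_prune_freq_s // beat_interval_s
assert checks_per_prune == 17_280    # most checks are cheap no-ops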
@@ -1,8 +1,25 @@
+from datetime import datetime
+from datetime import timezone
+
 from sqlalchemy.orm import Session
 
 from danswer.background.task_utils import name_cc_cleanup_task
+from danswer.background.task_utils import name_cc_prune_task
+from danswer.background.task_utils import name_document_set_sync_task
+from danswer.connectors.interfaces import BaseConnector
+from danswer.connectors.interfaces import IdConnector
+from danswer.connectors.interfaces import LoadConnector
+from danswer.connectors.interfaces import PollConnector
+from danswer.db.engine import get_db_current_time
+from danswer.db.models import Connector
+from danswer.db.models import Credential
+from danswer.db.models import DocumentSet
+from danswer.db.tasks import check_live_task_not_timed_out
 from danswer.db.tasks import get_latest_task
 from danswer.server.documents.models import DeletionAttemptSnapshot
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
 
 
 def get_deletion_status(
@@ -21,3 +38,71 @@ def get_deletion_status(
         credential_id=credential_id,
         status=task_state.status,
     )
+
+
+def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
+    if document_set.is_up_to_date:
+        return False
+
+    task_name = name_document_set_sync_task(document_set.id)
+    latest_sync = get_latest_task(task_name, db_session)
+
+    if latest_sync and check_live_task_not_timed_out(latest_sync, db_session):
+        logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
+        return False
+
+    logger.info(f"Document set {document_set.id} syncing now!")
+    return True
+
+
+def should_prune_cc_pair(
+    connector: Connector, credential: Credential, db_session: Session
+) -> bool:
+    if not connector.prune_freq:
+        return False
+
+    pruning_task_name = name_cc_prune_task(
+        connector_id=connector.id, credential_id=credential.id
+    )
+    last_pruning_task = get_latest_task(pruning_task_name, db_session)
+    current_db_time = get_db_current_time(db_session)
+
+    if not last_pruning_task:
+        time_since_initialization = current_db_time - connector.time_created
+        if time_since_initialization.total_seconds() >= connector.prune_freq:
+            return True
+        return False
+
+    if check_live_task_not_timed_out(last_pruning_task, db_session):
+        logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
+        return False
+
+    if not last_pruning_task.start_time:
+        return False
+
+    time_since_last_pruning = current_db_time - last_pruning_task.start_time
+    return time_since_last_pruning.total_seconds() >= connector.prune_freq
+
+
+def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
+    """
+    If the PruneConnector hasn't been implemented for the given connector, just pull
+    all docs using load_from_state and grab out the IDs.
+    """
+    all_connector_doc_ids: set[str] = set()
+    if isinstance(runnable_connector, IdConnector):
+        all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
+    elif isinstance(runnable_connector, LoadConnector):
+        doc_batch_generator = runnable_connector.load_from_state()
+        for doc_batch in doc_batch_generator:
+            all_connector_doc_ids.update(doc.id for doc in doc_batch)
+    elif isinstance(runnable_connector, PollConnector):
+        start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
+        end = datetime.now(timezone.utc).timestamp()
+        doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
+        for doc_batch in doc_batch_generator:
+            all_connector_doc_ids.update(doc.id for doc in doc_batch)
+    else:
+        raise RuntimeError("Pruning job could not find a valid runnable_connector.")
+
+    return all_connector_doc_ids
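A worked example of the timing rule in should_prune_cc_pair (values invented for illustration): with prune_freq = 86400 seconds, a connector whose last prune started 90,000 seconds ago is due again, while one that started 3,600 seconds ago is not.

from datetime import datetime, timedelta, timezone

prune_freq = 60 * 60 * 24  # seconds
now = datetime(2024, 6, 17, tzinfo=timezone.utc)  # stand-in for get_db_current_time()

due = now - timedelta(seconds=90_000)     # last prune started 25h ago -> prune again
not_due = now - timedelta(seconds=3_600)  # last prune started 1h ago  -> skip
assert (now - due).total_seconds() >= prune_freq
assert (now - not_due).total_seconds() < prune_freq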
@@ -41,7 +41,7 @@ logger = setup_logger()
 _DELETION_BATCH_SIZE = 1000
 
 
-def _delete_connector_credential_pair_batch(
+def delete_connector_credential_pair_batch(
     document_ids: list[str],
     connector_id: int,
     credential_id: int,
@@ -169,7 +169,7 @@ def delete_connector_credential_pair(
         if not documents:
             break
 
-        _delete_connector_credential_pair_batch(
+        delete_connector_credential_pair_batch(
            document_ids=[document.id for document in documents],
            connector_id=connector_id,
            credential_id=credential_id,
@@ -6,11 +6,7 @@ from datetime import timezone
 
 from sqlalchemy.orm import Session
 
-from danswer.background.connector_deletion import (
-    _delete_connector_credential_pair_batch,
-)
 from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
-from danswer.configs.app_configs import DISABLE_DOCUMENT_CLEANUP
 from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
 from danswer.connectors.factory import instantiate_connector
 from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -21,8 +17,6 @@ from danswer.connectors.models import InputType
 from danswer.db.connector import disable_connector
 from danswer.db.connector_credential_pair import get_last_successful_attempt_time
 from danswer.db.connector_credential_pair import update_connector_credential_pair
-from danswer.db.credentials import backend_update_credential_json
-from danswer.db.document import get_documents_for_connector_credential_pair
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import get_index_attempt
 from danswer.db.index_attempt import mark_attempt_failed
@@ -46,7 +40,7 @@ def _get_document_generator(
     attempt: IndexAttempt,
     start_time: datetime,
     end_time: datetime,
-) -> tuple[GenerateDocumentsOutput, bool]:
+) -> GenerateDocumentsOutput:
     """
     NOTE: `start_time` and `end_time` are only used for poll connectors
 
@@ -57,16 +51,13 @@ def _get_document_generator(
     task = attempt.connector.input_type
 
     try:
-        runnable_connector, new_credential_json = instantiate_connector(
+        runnable_connector = instantiate_connector(
             attempt.connector.source,
             task,
             attempt.connector.connector_specific_config,
-            attempt.credential.credential_json,
+            attempt.credential,
+            db_session,
         )
-        if new_credential_json is not None:
-            backend_update_credential_json(
-                attempt.credential, new_credential_json, db_session
-            )
     except Exception as e:
         logger.exception(f"Unable to instantiate connector due to {e}")
         disable_connector(attempt.connector.id, db_session)
@@ -75,7 +66,7 @@ def _get_document_generator(
     if task == InputType.LOAD_STATE:
         assert isinstance(runnable_connector, LoadConnector)
         doc_batch_generator = runnable_connector.load_from_state()
-        is_listing_complete = True
 
     elif task == InputType.POLL:
         assert isinstance(runnable_connector, PollConnector)
         if attempt.connector_id is None or attempt.credential_id is None:
@@ -88,13 +79,12 @@ def _get_document_generator(
         doc_batch_generator = runnable_connector.poll_source(
             start=start_time.timestamp(), end=end_time.timestamp()
         )
-        is_listing_complete = False
 
     else:
         # Event types cannot be handled by a background type
         raise RuntimeError(f"Invalid task type: {task}")
 
-    return doc_batch_generator, is_listing_complete
+    return doc_batch_generator
 
 
 def _run_indexing(
@@ -166,7 +156,7 @@ def _run_indexing(
                 datetime(1970, 1, 1, tzinfo=timezone.utc),
             )
 
-            doc_batch_generator, is_listing_complete = _get_document_generator(
+            doc_batch_generator = _get_document_generator(
                 db_session=db_session,
                 attempt=index_attempt,
                 start_time=window_start,
@@ -224,39 +214,6 @@ def _run_indexing(
                     docs_removed_from_index=0,
                 )
 
-            if is_listing_complete and not DISABLE_DOCUMENT_CLEANUP:
-                # clean up all documents from the index that have not been returned from the connector
-                all_indexed_document_ids = {
-                    d.id
-                    for d in get_documents_for_connector_credential_pair(
-                        db_session=db_session,
-                        connector_id=db_connector.id,
-                        credential_id=db_credential.id,
-                    )
-                }
-                doc_ids_to_remove = list(
-                    all_indexed_document_ids - all_connector_doc_ids
-                )
-                logger.debug(
-                    f"Cleaning up {len(doc_ids_to_remove)} documents that are not contained in the newest connector state"
-                )
-
-                # delete docs from cc-pair and receive the number of completely deleted docs in return
-                _delete_connector_credential_pair_batch(
-                    document_ids=doc_ids_to_remove,
-                    connector_id=db_connector.id,
-                    credential_id=db_credential.id,
-                    document_index=document_index,
-                )
-
-                update_docs_indexed(
-                    db_session=db_session,
-                    index_attempt=index_attempt,
-                    total_docs_indexed=document_count,
-                    new_docs_indexed=net_doc_change,
-                    docs_removed_from_index=len(doc_ids_to_remove),
-                )
-
             run_end_dt = window_end
             if is_primary:
                 update_connector_credential_pair(
@@ -22,6 +22,10 @@ def name_document_set_sync_task(document_set_id: int) -> str:
     return f"sync_doc_set_{document_set_id}"
 
 
+def name_cc_prune_task(connector_id: int, credential_id: int) -> str:
+    return f"prune_connector_credential_pair_{connector_id}_{credential_id}"
+
+
 T = TypeVar("T", bound=Callable)
@@ -199,6 +199,8 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
     os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
 )
 
+DEFAULT_PRUNING_FREQ = 60 * 60 * 24  # Once a day
+
 
 #####
 # Indexing Configs
@@ -228,10 +230,6 @@ ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
 MINI_CHUNK_SIZE = 150
 # Timeout to wait for job's last update before killing it, in hours
 CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
-# If set to true, then will not clean up documents that "no longer exist" when running Load connectors
-DISABLE_DOCUMENT_CLEANUP = (
-    os.environ.get("DISABLE_DOCUMENT_CLEANUP", "").lower() == "true"
-)
 
 
 #####
@@ -1,6 +1,8 @@
 from typing import Any
 from typing import Type
 
+from sqlalchemy.orm import Session
+
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.axero.connector import AxeroConnector
 from danswer.connectors.bookstack.connector import BookstackConnector
@@ -40,6 +42,8 @@ from danswer.connectors.web.connector import WebConnector
 from danswer.connectors.wikipedia.connector import WikipediaConnector
 from danswer.connectors.zendesk.connector import ZendeskConnector
 from danswer.connectors.zulip.connector import ZulipConnector
+from danswer.db.credentials import backend_update_credential_json
+from danswer.db.models import Credential
 
 
 class ConnectorMissingException(Exception):
@@ -119,10 +123,14 @@ def instantiate_connector(
     source: DocumentSource,
     input_type: InputType,
     connector_specific_config: dict[str, Any],
-    credentials: dict[str, Any],
-) -> tuple[BaseConnector, dict[str, Any] | None]:
+    credential: Credential,
+    db_session: Session,
+) -> BaseConnector:
     connector_class = identify_connector_class(source, input_type)
     connector = connector_class(**connector_specific_config)
-    new_credentials = connector.load_credentials(credentials)
+    new_credentials = connector.load_credentials(credential.credential_json)
 
-    return connector, new_credentials
+    if new_credentials is not None:
+        backend_update_credential_json(credential, new_credentials, db_session)
+
+    return connector
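With this change, credential refreshing is handled inside the factory rather than by every caller. A usage sketch of the new signature; the cc_pair and db_session values are assumed to come from the caller's context, as in prune_documents_task above:

from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType

runnable_connector = instantiate_connector(
    cc_pair.connector.source,                     # a DocumentSource value
    InputType.PRUNE,
    cc_pair.connector.connector_specific_config,
    cc_pair.credential,                           # a db Credential row, not raw JSON
    db_session,
)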
@@ -50,6 +50,12 @@ class PollConnector(BaseConnector):
         raise NotImplementedError
 
 
+class IdConnector(BaseConnector):
+    @abc.abstractmethod
+    def retrieve_all_source_ids(self) -> set[str]:
+        raise NotImplementedError
+
+
 # Event driven
 class EventConnector(BaseConnector):
     @abc.abstractmethod
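To opt a connector into cheap pruning, it implements this one-method interface. A hypothetical toy implementation (not in the commit) showing the contract, namely returning every document ID currently present at the source:

from danswer.connectors.interfaces import IdConnector


class ToyConnector(IdConnector):
    """Hypothetical source backed by an in-memory dict of id -> content."""

    def __init__(self, records: dict[str, str]) -> None:
        self.records = records

    def load_credentials(self, credentials: dict) -> dict | None:
        return None  # nothing to store or refresh for this toy source

    def retrieve_all_source_ids(self) -> set[str]:
        return set(self.records)  # the full, current ID set at the source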
@@ -13,6 +13,7 @@ class InputType(str, Enum):
     LOAD_STATE = "load_state"  # e.g. loading a current full state or a save state, such as from a file
     POLL = "poll"  # e.g. calling an API to get all documents in the last hour
     EVENT = "event"  # e.g. registered an endpoint as a listener, and processing connector events
+    PRUNE = "prune"
 
 
 class ConnectorMissingCredentialError(PermissionError):
@@ -11,6 +11,7 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
 from danswer.connectors.interfaces import GenerateDocumentsOutput
+from danswer.connectors.interfaces import IdConnector
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.interfaces import SecondsSinceUnixEpoch
@@ -23,11 +24,12 @@ from danswer.utils.logger import setup_logger
 
 DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
 MAX_QUERY_LENGTH = 10000  # max query length is 20,000 characters
+ID_PREFIX = "SALESFORCE_"
 
 logger = setup_logger()
 
 
-class SalesforceConnector(LoadConnector, PollConnector):
+class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
     def __init__(
         self,
         batch_size: int = INDEX_BATCH_SIZE,
@@ -77,7 +79,7 @@ class SalesforceConnector(LoadConnector, PollConnector):
         if self.sf_client is None:
             raise ConnectorMissingCredentialError("Salesforce")
 
-        extracted_id = f"SALESFORCE_{object_dict['Id']}"
+        extracted_id = f"{ID_PREFIX}{object_dict['Id']}"
         extracted_link = f"https://{self.sf_client.sf_instance}/{extracted_id}"
         extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
         extracted_object_text = extract_dict_text(object_dict)
@@ -229,8 +231,6 @@ class SalesforceConnector(LoadConnector, PollConnector):
             yield doc_batch
 
     def load_from_state(self) -> GenerateDocumentsOutput:
         if self.sf_client is None:
             raise ConnectorMissingCredentialError("Salesforce")
         return self._fetch_from_salesforce()
 
     def poll_source(
@@ -242,6 +242,20 @@ class SalesforceConnector(LoadConnector, PollConnector):
         end_datetime = datetime.utcfromtimestamp(end)
         return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
 
+    def retrieve_all_source_ids(self) -> set[str]:
+        if self.sf_client is None:
+            raise ConnectorMissingCredentialError("Salesforce")
+        all_retrieved_ids: set[str] = set()
+        for parent_object_type in self.parent_object_list:
+            query = f"SELECT Id FROM {parent_object_type}"
+            query_result = self.sf_client.query_all(query)
+            all_retrieved_ids.update(
+                f"{ID_PREFIX}{instance_dict.get('Id', '')}"
+                for instance_dict in query_result["records"]
+            )
+
+        return all_retrieved_ids
 
 
 if __name__ == "__main__":
     connector = SalesforceConnector(
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.models import Connector
|
||||
@ -84,6 +85,9 @@ def create_connector(
|
||||
input_type=connector_data.input_type,
|
||||
connector_specific_config=connector_data.connector_specific_config,
|
||||
refresh_freq=connector_data.refresh_freq,
|
||||
prune_freq=connector_data.prune_freq
|
||||
if connector_data.prune_freq is not None
|
||||
else DEFAULT_PRUNING_FREQ,
|
||||
disabled=connector_data.disabled,
|
||||
)
|
||||
db_session.add(connector)
|
||||
@ -113,6 +117,11 @@ def update_connector(
|
||||
connector.input_type = connector_data.input_type
|
||||
connector.connector_specific_config = connector_data.connector_specific_config
|
||||
connector.refresh_freq = connector_data.refresh_freq
|
||||
connector.prune_freq = (
|
||||
connector_data.prune_freq
|
||||
if connector_data.prune_freq is not None
|
||||
else DEFAULT_PRUNING_FREQ
|
||||
)
|
||||
connector.disabled = connector_data.disabled
|
||||
|
||||
db_session.commit()
|
||||
@ -259,6 +268,7 @@ def create_initial_default_connector(db_session: Session) -> None:
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={},
|
||||
refresh_freq=None,
|
||||
prune_freq=None,
|
||||
)
|
||||
db_session.add(connector)
|
||||
db_session.commit()
|
||||
|
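The fallback deliberately distinguishes None from 0: a request with prune_freq=None gets the daily default, while an explicit 0 is stored as-is and later disables pruning entirely, since should_prune_cc_pair treats a falsy prune_freq as "never prune". A behavior sketch (helper name invented):

DEFAULT_PRUNING_FREQ = 60 * 60 * 24  # mirrors the config value above


def resolve_prune_freq(requested: int | None) -> int:
    return requested if requested is not None else DEFAULT_PRUNING_FREQ


assert resolve_prune_freq(None) == 86_400  # unset -> prune once a day
assert resolve_prune_freq(0) == 0          # explicit 0 -> never prune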
@@ -383,6 +383,7 @@ class Connector(Base):
         postgresql.JSONB()
     )
     refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
     time_created: Mapped[datetime.datetime] = mapped_column(
         DateTime(timezone=True), server_default=func.now()
     )
@@ -511,6 +511,7 @@ def update_connector_from_model(
         input_type=updated_connector.input_type,
         connector_specific_config=updated_connector.connector_specific_config,
         refresh_freq=updated_connector.refresh_freq,
+        prune_freq=updated_connector.prune_freq,
         credential_ids=[
             association.credential.id for association in updated_connector.credentials
         ],
@@ -726,6 +727,7 @@ def get_connector_by_id(
         input_type=connector.input_type,
         connector_specific_config=connector.connector_specific_config,
         refresh_freq=connector.refresh_freq,
+        prune_freq=connector.prune_freq,
         credential_ids=[
             association.credential.id for association in connector.credentials
         ],
@@ -68,6 +68,7 @@ class ConnectorBase(BaseModel):
     input_type: InputType
     connector_specific_config: dict[str, Any]
     refresh_freq: int | None  # In seconds, None for one time index with no refresh
+    prune_freq: int | None
     disabled: bool
 
 
@@ -86,6 +87,7 @@ class ConnectorSnapshot(ConnectorBase):
         input_type=connector.input_type,
         connector_specific_config=connector.connector_specific_config,
         refresh_freq=connector.refresh_freq,
+        prune_freq=connector.prune_freq,
         credential_ids=[
             association.credential.id for association in connector.credentials
         ],
@@ -160,7 +160,6 @@ services:
       - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-}
       - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
       - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
-      - DISABLE_DOCUMENT_CLEANUP=${DISABLE_DOCUMENT_CLEANUP:-}
       # Danswer SlackBot Configs
       - DANSWER_BOT_SLACK_APP_TOKEN=${DANSWER_BOT_SLACK_APP_TOKEN:-}
       - DANSWER_BOT_SLACK_BOT_TOKEN=${DANSWER_BOT_SLACK_BOT_TOKEN:-}
@@ -159,7 +159,6 @@ services:
       - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-}
       - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
       - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
-      - DISABLE_DOCUMENT_CLEANUP=${DISABLE_DOCUMENT_CLEANUP:-}
       # Danswer SlackBot Configs
       - DANSWER_BOT_SLACK_APP_TOKEN=${DANSWER_BOT_SLACK_APP_TOKEN:-}
       - DANSWER_BOT_SLACK_BOT_TOKEN=${DANSWER_BOT_SLACK_BOT_TOKEN:-}
@@ -122,6 +122,7 @@ const Main = () => {
         file_locations: filePaths,
       },
       refresh_freq: null,
+      prune_freq: 0,
       disabled: false,
     });
     if (connectorErrorMsg || !connector) {
@@ -114,6 +114,7 @@ export default function GoogleSites() {
         zip_path: filePaths[0],
       },
       refresh_freq: null,
+      prune_freq: 0,
       disabled: false,
     });
     if (connectorErrorMsg || !connector) {
@@ -118,6 +118,7 @@ export default function Web() {
         web_connector_type: undefined,
       }}
       refreshFreq={60 * 60 * 24} // 1 day
+      pruneFreq={0} // Don't prune
     />
   </Card>
@@ -70,6 +70,7 @@ interface BaseProps<T extends Yup.AnyObject> {
     responseJson: Connector<T> | undefined
   ) => void;
   refreshFreq?: number;
+  pruneFreq?: number;
   // If specified, then we will create an empty credential and associate
   // the connector with it. If credentialId is specified, then this will be ignored
   shouldCreateEmptyCredentialForConnector?: boolean;
@@ -91,6 +92,7 @@ export function ConnectorForm<T extends Yup.AnyObject>({
   validationSchema,
   initialValues,
   refreshFreq,
+  pruneFreq,
   onSubmit,
   shouldCreateEmptyCredentialForConnector,
 }: ConnectorFormProps<T>): JSX.Element {
@@ -144,6 +146,7 @@ export function ConnectorForm<T extends Yup.AnyObject>({
     input_type: inputType,
     connector_specific_config: connectorConfig,
     refresh_freq: refreshFreq || 0,
+    prune_freq: pruneFreq ?? null,
     disabled: false,
   });
 
@@ -281,6 +284,7 @@ export function UpdateConnectorForm<T extends Yup.AnyObject>({
     input_type: existingConnector.input_type,
     connector_specific_config: values,
     refresh_freq: existingConnector.refresh_freq,
+    prune_freq: existingConnector.prune_freq,
     disabled: false,
   },
   existingConnector.id
@@ -85,6 +85,7 @@ export interface ConnectorBase<T> {
   source: ValidSources;
   connector_specific_config: T;
   refresh_freq: number | null;
+  prune_freq: number | null;
   disabled: boolean;
 }