DAN-115 Document Polling (#91)

Includes updated document counting for polling
Yuhong Sun 2023-06-15 21:07:05 -07:00 committed by GitHub
parent 97b9b56b03
commit 6fe54a4eed
20 changed files with 464 additions and 283 deletions

View File

@ -74,7 +74,7 @@ def upgrade() -> None:
sa.Column(
"credential_json",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
nullable=False,
),
sa.Column(
"user_id",
@ -101,7 +101,7 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"connector_credential_association",
"connector_credential_pair",
sa.Column("connector_id", sa.Integer(), nullable=False),
sa.Column("credential_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
@ -168,6 +168,6 @@ def downgrade() -> None:
)
op.drop_column("index_attempt", "credential_id")
op.drop_column("index_attempt", "connector_id")
op.drop_table("connector_credential_association")
op.drop_table("connector_credential_pair")
op.drop_table("credential")
op.drop_table("connector")

View File

@ -0,0 +1,51 @@
"""Polling Document Count
Revision ID: 3c5e35aa9af0
Revises: 27c6ecc08586
Create Date: 2023-06-14 23:45:51.760440
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "3c5e35aa9af0"
down_revision = "27c6ecc08586"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"last_successful_index_time",
sa.DateTime(timezone=True),
nullable=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column(
"last_attempt_status",
sa.Enum(
"NOT_STARTED",
"IN_PROGRESS",
"SUCCESS",
"FAILED",
name="indexingstatus",
),
nullable=False,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("total_docs_indexed", sa.Integer(), nullable=False),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "total_docs_indexed")
op.drop_column("connector_credential_pair", "last_attempt_status")
op.drop_column("connector_credential_pair", "last_successful_index_time")

View File

@ -6,18 +6,21 @@ from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import InputType
from danswer.db.connector import disable_connector
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.engine import build_engine
from danswer.db.engine import get_db_current_time
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_finished_attempt
from danswer.db.index_attempt import get_last_successful_attempt
from danswer.db.index_attempt import get_last_successful_attempt_start_time
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.models import Connector
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.utils.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logging import setup_logger
from sqlalchemy.orm import Session
@ -33,9 +36,7 @@ def should_create_new_indexing(
if not last_index:
return True
current_db_time = get_db_current_time(db_session)
time_since_index = (
current_db_time - last_index.time_updated
) # Maybe better to do time created
time_since_index = current_db_time - last_index.time_updated
return time_since_index.total_seconds() >= connector.refresh_freq
@ -51,24 +52,41 @@ def create_indexing_jobs(db_session: Session) -> None:
# Currently single threaded so any still in-progress must have errored
for attempt in in_progress_indexing_attempts:
logger.warning(
f"Marking in-progress attempt 'connector: {attempt.connector_id}, credential: {attempt.credential_id}' as failed"
f"Marking in-progress attempt 'connector: {attempt.connector_id}, "
f"credential: {attempt.credential_id}' as failed"
)
mark_attempt_failed(attempt, db_session)
last_finished_indexing_attempt = get_last_finished_attempt(
connector.id, db_session
)
if not should_create_new_indexing(
connector, last_finished_indexing_attempt, db_session
):
continue
if attempt.connector_id and attempt.credential_id:
update_connector_credential_pair(
connector_id=attempt.connector_id,
credential_id=attempt.credential_id,
attempt_status=IndexingStatus.FAILED,
net_docs=None,
db_session=db_session,
)
for association in connector.credentials:
credential = association.credential
last_successful_attempt = get_last_successful_attempt(
connector.id, credential.id, db_session
)
if not should_create_new_indexing(
connector, last_successful_attempt, db_session
):
continue
create_index_attempt(connector.id, credential.id, db_session)
update_connector_credential_pair(
connector_id=connector.id,
credential_id=credential.id,
attempt_status=IndexingStatus.NOT_STARTED,
net_docs=None,
db_session=db_session,
)
def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
def run_indexing_jobs(db_session: Session) -> None:
indexing_pipeline = build_indexing_pipeline()
new_indexing_attempts = get_not_started_index_attempts(db_session)
@ -77,7 +95,7 @@ def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
logger.info(
f"Starting new indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f" with credentials: '{[c.credential_id for c in attempt.connector.credentials]}'"
f"with credentials: '{[c.credential_id for c in attempt.connector.credentials]}'"
)
mark_attempt_in_progress(attempt, db_session)
@ -85,6 +103,14 @@ def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
db_credential = attempt.credential
task = db_connector.input_type
update_connector_credential_pair(
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
net_docs=None,
db_session=db_session,
)
try:
runnable_connector, new_credential_json = instantiate_connector(
db_connector.source,
@ -101,6 +127,7 @@ def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
disable_connector(db_connector.id, db_session)
continue
net_doc_change = 0
try:
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
@ -108,6 +135,14 @@ def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
raise ValueError(
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
)
last_run_time = get_last_successful_attempt_start_time(
attempt.connector_id, attempt.credential_id, db_session
)
doc_batch_generator = runnable_connector.poll_source(
last_run_time, time.time()
)
@ -121,28 +156,43 @@ def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
index_user_id = (
None if db_credential.public_doc else db_credential.user_id
)
indexing_pipeline(documents=doc_batch, user_id=index_user_id)
net_doc_change += indexing_pipeline(
documents=doc_batch, user_id=index_user_id
)
document_ids.extend([doc.id for doc in doc_batch])
mark_attempt_succeeded(attempt, document_ids, db_session)
update_connector_credential_pair(
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.SUCCESS,
net_docs=net_doc_change,
db_session=db_session,
)
logger.info(f"Indexed {len(document_ids)} documents")
except Exception as e:
logger.exception(f"Indexing job with id {attempt.id} failed due to {e}")
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
update_connector_credential_pair(
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.FAILED,
net_docs=net_doc_change,
db_session=db_session,
)
def update_loop(delay: int = 10) -> None:
last_run_time = 0.0
engine = build_engine()
while True:
start = time.time()
logger.info(f"Running update, current time: {time.ctime(start)}")
try:
with Session(
build_engine(), future=True, expire_on_commit=False
) as db_session:
with Session(engine, future=True, expire_on_commit=False) as db_session:
create_indexing_jobs(db_session)
# TODO failed poll jobs won't recover data from failed runs, should fix
run_indexing_jobs(last_run_time, db_session)
run_indexing_jobs(db_session)
except Exception as e:
logger.exception(f"Failed to run update due to {e}")
sleep_time = delay - (time.time() - start)

View File

@ -1,3 +1,4 @@
import datetime
import io
from collections.abc import Generator
from typing import Any
@ -9,6 +10,8 @@ from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_K
from danswer.connectors.google_drive.connector_auth import get_drive_tokens
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logging import setup_logger
@ -35,9 +38,23 @@ def get_file_batches(
service: discovery.Resource,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
batch_size: int = INDEX_BATCH_SIZE,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> Generator[list[dict[str, str]], None, None]:
next_page_token = ""
while next_page_token is not None:
query = ""
if time_range_start is not None:
time_start = (
datetime.datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
)
query += f"modifiedTime >= '{time_start}' "
if time_range_end is not None:
time_stop = (
datetime.datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
)
query += f"and modifiedTime <= '{time_stop}'"
results = (
service.files()
.list(
@ -45,6 +62,7 @@ def get_file_batches(
supportsAllDrives=include_shared,
fields="nextPageToken, files(mimeType, id, name, webViewLink)",
pageToken=next_page_token,
q=query,
)
.execute()
)
@ -84,7 +102,7 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
return "\n".join(page.extract_text() for page in pdf_reader.pages)
class GoogleDriveConnector(LoadConnector):
class GoogleDriveConnector(LoadConnector, PollConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@ -105,13 +123,21 @@ class GoogleDriveConnector(LoadConnector):
return {DB_CREDENTIALS_DICT_KEY: new_creds_json_str}
return None
def load_from_state(self) -> GenerateDocumentsOutput:
def _fetch_docs_from_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
if self.creds is None:
raise PermissionError("Not logged into Google Drive")
service = discovery.build("drive", "v3", credentials=self.creds)
for files_batch in get_file_batches(
service, self.include_shared, self.batch_size
service,
self.include_shared,
self.batch_size,
time_range_start=start,
time_range_end=end,
):
doc_batch = []
for file in files_batch:
@ -129,3 +155,11 @@ class GoogleDriveConnector(LoadConnector):
)
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
yield from self._fetch_docs_from_drive()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
yield from self._fetch_docs_from_drive(start, end)
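
For a concrete picture of what the new time-bounded listing sends to the Drive API, here is a standalone sketch of the `q` string built above (the helper and sample timestamps are illustrative, not part of the connector). When only an end bound is supplied the string would start with a dangling "and", which is why poll_source always passes both bounds.

import datetime


def build_drive_time_query(
    start: float | None = None, end: float | None = None
) -> str:
    # Mirrors the construction above: RFC 3339 timestamps with a trailing "Z".
    query = ""
    if start is not None:
        time_start = datetime.datetime.utcfromtimestamp(start).isoformat() + "Z"
        query += f"modifiedTime >= '{time_start}' "
    if end is not None:
        time_stop = datetime.datetime.utcfromtimestamp(end).isoformat() + "Z"
        query += f"and modifiedTime <= '{time_stop}'"
    return query


# A one-hour poll window on 2023-06-15 (timestamps are illustrative):
print(build_drive_time_query(1686787200, 1686790800))
# modifiedTime >= '2023-06-15T00:00:00Z' and modifiedTime <= '2023-06-15T01:00:00Z'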

View File

@ -46,6 +46,7 @@ def get_drive_tokens(
try:
creds.refresh(Request())
if creds.valid:
logger.info("Refreshed Google Drive tokens.")
return creds
except Exception as e:
logger.exception(f"Failed to refresh google drive access token due to: {e}")

View File

@ -15,7 +15,8 @@ IndexFilter = dict[str, str | list[str] | None]
class DocumentIndex(Generic[T], abc.ABC):
@abc.abstractmethod
def index(self, chunks: list[T], user_id: UUID | None) -> bool:
def index(self, chunks: list[T], user_id: UUID | None) -> int:
"""Indexes document chunks into the Document Index and return the number of new documents"""
raise NotImplementedError
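
For illustration only, a toy in-memory stand-in for this contract (it does not subclass the real ABC, and the chunk shape is assumed): index() reports how many previously unseen documents it stored.

from dataclasses import dataclass, field
from uuid import UUID


@dataclass
class FakeChunk:
    # Assumed minimal chunk shape, for the sketch only.
    document_id: str
    content: str


@dataclass
class InMemoryIndex:
    _seen_docs: set[str] = field(default_factory=set)

    def index(self, chunks: list[FakeChunk], user_id: UUID | None) -> int:
        doc_ids = {chunk.document_id for chunk in chunks}
        new_docs = doc_ids - self._seen_docs
        self._seen_docs |= doc_ids
        return len(new_docs)


idx = InMemoryIndex()
print(idx.index([FakeChunk("a", "..."), FakeChunk("b", "...")], None))  # 2
print(idx.index([FakeChunk("b", "updated")], None))  # 0, "b" was merely re-indexed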

View File

@ -52,6 +52,7 @@ def create_qdrant_collection(
def get_qdrant_document_whitelists(
doc_chunk_id: str, collection_name: str, q_client: QdrantClient
) -> tuple[bool, list[str], list[str]]:
"""Get whether a document is found and the existing whitelists"""
results = q_client.retrieve(
collection_name=collection_name,
ids=[doc_chunk_id],
@ -69,8 +70,8 @@ def get_qdrant_document_whitelists(
def delete_qdrant_doc_chunks(
document_id: str, collection_name: str, q_client: QdrantClient
) -> None:
q_client.delete(
) -> bool:
res = q_client.delete(
collection_name=collection_name,
points_selector=models.FilterSelector(
filter=models.Filter(
@ -83,6 +84,7 @@ def delete_qdrant_doc_chunks(
)
),
)
return True
def index_qdrant_chunks(
@ -91,7 +93,7 @@ def index_qdrant_chunks(
collection: str,
client: QdrantClient | None = None,
batch_upsert: bool = True,
) -> bool:
) -> int:
# Public documents will have the PUBLIC string in ALLOWED_USERS
# If credential that kicked this off has no user associated, either Auth is off or the doc is public
user_str = PUBLIC_DOC_PAT if user_id is None else str(user_id)
@ -100,6 +102,7 @@ def index_qdrant_chunks(
point_structs: list[PointStruct] = []
# Maps document id to dict of whitelists for users/groups each containing list of users/groups as strings
doc_user_map: dict[str, dict[str, list[str]]] = {}
docs_deleted = 0
for chunk in chunks:
document = chunk.source_document
doc_user_map, delete_doc = update_doc_user_map(
@ -114,6 +117,8 @@ def index_qdrant_chunks(
)
if delete_doc:
# Processing the first chunk of the doc and the doc exists
docs_deleted += 1
delete_qdrant_doc_chunks(document.id, collection, q_client)
point_structs.extend(
@ -138,7 +143,6 @@ def index_qdrant_chunks(
]
)
index_results = None
if batch_upsert:
point_struct_batches = [
point_structs[x : x + DEFAULT_BATCH_SIZE]
@ -171,4 +175,5 @@ def index_qdrant_chunks(
logger.info(
f"Document batch of size {len(point_structs)} indexing status: {index_results.status}"
)
return index_results is not None and index_results.status == UpdateStatus.COMPLETED
return len(doc_user_map.keys()) - docs_deleted
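
To make the return value concrete: every document in the batch contributes one key to doc_user_map, while docs_deleted counts the documents that already existed and had their stale chunks removed above, so the result is the net number of newly indexed documents. A worked example with hypothetical numbers:

# Hypothetical batch: chunks from 5 distinct documents, 2 of which were already indexed.
docs_in_batch = 5   # len(doc_user_map.keys())
docs_deleted = 2    # existing documents whose old chunks were deleted before re-insertion
net_new_docs = docs_in_batch - docs_deleted
print(net_new_docs)  # 3 -> later added to total_docs_indexed for the connector/credential pair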

View File

@ -77,7 +77,7 @@ class QdrantIndex(VectorIndex):
self.collection = collection
self.client = get_qdrant_client()
def index(self, chunks: list[EmbeddedIndexChunk], user_id: UUID | None) -> bool:
def index(self, chunks: list[EmbeddedIndexChunk], user_id: UUID | None) -> int:
return index_qdrant_chunks(
chunks=chunks,
user_id=user_id,

View File

@ -86,7 +86,7 @@ def get_typesense_document_whitelists(
def delete_typesense_doc_chunks(
document_id: str, collection_name: str, ts_client: typesense.Client
) -> None:
) -> bool:
search_parameters = {
"q": document_id,
"query_by": DOCUMENT_ID,
@ -98,6 +98,7 @@ def delete_typesense_doc_chunks(
ts_client.collections[collection_name].documents[hit["document"]["id"]].delete()
for hit in hits["hits"]
]
return True if hits else False
def index_typesense_chunks(
@ -106,12 +107,13 @@ def index_typesense_chunks(
collection: str,
client: typesense.Client | None = None,
batch_upsert: bool = True,
) -> bool:
) -> int:
user_str = PUBLIC_DOC_PAT if user_id is None else str(user_id)
ts_client: typesense.Client = client if client else get_typesense_client()
new_documents: list[dict[str, Any]] = []
doc_user_map: dict[str, dict[str, list[str]]] = {}
docs_deleted = 0
for chunk in chunks:
document = chunk.source_document
doc_user_map, delete_doc = update_doc_user_map(
@ -126,6 +128,8 @@ def index_typesense_chunks(
)
if delete_doc:
# Processing the first chunk of the doc and the doc exists
docs_deleted += 1
delete_typesense_doc_chunks(document.id, collection, ts_client)
new_documents.append(
@ -168,7 +172,7 @@ def index_typesense_chunks(
for document in new_documents
]
return True
return len(doc_user_map.keys()) - docs_deleted
def _build_typesense_filters(
@ -206,7 +210,7 @@ class TypesenseIndex(KeywordIndex):
self.collection = collection
self.ts_client = get_typesense_client()
def index(self, chunks: list[IndexChunk], user_id: UUID | None) -> bool:
def index(self, chunks: list[IndexChunk], user_id: UUID | None) -> int:
return index_typesense_chunks(
chunks=chunks,
user_id=user_id,

View File

@ -2,11 +2,8 @@ from typing import cast
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialAssociation
from danswer.db.models import IndexAttempt
from danswer.db.models import User
from danswer.server.models import ConnectorBase
from danswer.server.models import ObjectCreationIdResponse
from danswer.server.models import StatusResponse
@ -147,95 +144,6 @@ def get_connector_credential_ids(
return [association.credential.id for association in connector.credentials]
def add_credential_to_connector(
connector_id: int,
credential_id: int,
user: User,
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")
if credential is None:
raise HTTPException(
status_code=401,
detail="Credential does not exist or does not belong to user",
)
existing_association = (
db_session.query(ConnectorCredentialAssociation)
.filter(
ConnectorCredentialAssociation.connector_id == connector_id,
ConnectorCredentialAssociation.credential_id == credential_id,
)
.one_or_none()
)
if existing_association is not None:
return StatusResponse(
success=False,
message=f"Connector already has Credential {credential_id}",
data=connector_id,
)
association = ConnectorCredentialAssociation(
connector_id=connector_id, credential_id=credential_id
)
db_session.add(association)
db_session.commit()
return StatusResponse(
success=True,
message=f"New Credential {credential_id} added to Connector",
data=connector_id,
)
def remove_credential_from_connector(
connector_id: int,
credential_id: int,
user: User,
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")
if credential is None:
raise HTTPException(
status_code=404,
detail="Credential does not exist or does not belong to user",
)
association = (
db_session.query(ConnectorCredentialAssociation)
.filter(
ConnectorCredentialAssociation.connector_id == connector_id,
ConnectorCredentialAssociation.credential_id == credential_id,
)
.one_or_none()
)
if association is not None:
db_session.delete(association)
db_session.commit()
return StatusResponse(
success=True,
message=f"Credential {credential_id} removed from Connector",
data=connector_id,
)
return StatusResponse(
success=False,
message=f"Connector already does not have Credential {credential_id}",
data=connector_id,
)
def fetch_latest_index_attempt_by_connector(
db_session: Session,
source: DocumentSource | None = None,

View File

@ -0,0 +1,148 @@
from danswer.db.connector import fetch_connector_by_id
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexingStatus
from danswer.db.models import User
from danswer.server.models import StatusResponse
from danswer.utils.logging import setup_logger
from fastapi import HTTPException
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = setup_logger()
def get_connector_credential_pairs(
db_session: Session, include_disabled: bool = True
) -> list[ConnectorCredentialPair]:
stmt = select(ConnectorCredentialPair)
if not include_disabled:
stmt = stmt.where(ConnectorCredentialPair.connector.has(disabled=False))
results = db_session.scalars(stmt)
return list(results.all())
def get_connector_credential_pair(
connector_id: int,
credential_id: int,
db_session: Session,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
def update_connector_credential_pair(
connector_id: int,
credential_id: int,
attempt_status: IndexingStatus,
net_docs: int | None,
db_session: Session,
) -> None:
cc_pair = get_connector_credential_pair(connector_id, credential_id, db_session)
if not cc_pair:
logger.warning(
f"Attempted to update pair for connector id {connector_id} "
f"and credential id {credential_id}"
)
return
cc_pair.last_attempt_status = attempt_status
if attempt_status == IndexingStatus.SUCCESS:
cc_pair.last_successful_index_time = func.now() # type:ignore
if net_docs is not None:
cc_pair.total_docs_indexed += net_docs
db_session.commit()
def add_credential_to_connector(
connector_id: int,
credential_id: int,
user: User,
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")
if credential is None:
raise HTTPException(
status_code=401,
detail="Credential does not exist or does not belong to user",
)
existing_association = (
db_session.query(ConnectorCredentialPair)
.filter(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
)
.one_or_none()
)
if existing_association is not None:
return StatusResponse(
success=False,
message=f"Connector already has Credential {credential_id}",
data=connector_id,
)
association = ConnectorCredentialPair(
connector_id=connector_id,
credential_id=credential_id,
last_attempt_status=IndexingStatus.NOT_STARTED,
)
db_session.add(association)
db_session.commit()
return StatusResponse(
success=True,
message=f"New Credential {credential_id} added to Connector",
data=connector_id,
)
def remove_credential_from_connector(
connector_id: int,
credential_id: int,
user: User,
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")
if credential is None:
raise HTTPException(
status_code=404,
detail="Credential does not exist or does not belong to user",
)
association = (
db_session.query(ConnectorCredentialPair)
.filter(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
)
.one_or_none()
)
if association is not None:
db_session.delete(association)
db_session.commit()
return StatusResponse(
success=True,
message=f"Credential {credential_id} removed from Connector",
data=connector_id,
)
return StatusResponse(
success=False,
message=f"Connector already does not have Credential {credential_id}",
data=connector_id,
)
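
A minimal usage sketch of the new helper module, mirroring how the background update loop records a finished attempt (the connector and credential ids are made up):

from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import build_engine
from danswer.db.models import IndexingStatus
from sqlalchemy.orm import Session

# Record that indexing for (connector 1, credential 2) succeeded and netted 3 new documents.
with Session(build_engine(), future=True, expire_on_commit=False) as db_session:
    update_connector_credential_pair(
        connector_id=1,
        credential_id=2,
        attempt_status=IndexingStatus.SUCCESS,
        net_docs=3,
        db_session=db_session,
    )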

View File

@ -1,6 +1,7 @@
from collections.abc import AsyncGenerator
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
@ -27,6 +28,15 @@ def get_db_current_time(db_session: Session) -> datetime:
return result
def translate_db_time_to_server_time(
db_time: datetime, db_session: Session
) -> datetime:
server_now = datetime.now()
db_now = get_db_current_time(db_session)
time_diff = server_now - db_now.astimezone(timezone.utc).replace(tzinfo=None)
return db_time + time_diff
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
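
translate_db_time_to_server_time is a plain clock-offset correction; a self-contained sketch of the arithmetic with a made-up 90-second skew between the two clocks (no database involved):

from datetime import datetime, timezone

server_now = datetime(2023, 6, 15, 12, 0, 0)                     # naive server clock
db_now = datetime(2023, 6, 15, 11, 58, 30, tzinfo=timezone.utc)  # DB clock, 90s behind
time_diff = server_now - db_now.astimezone(timezone.utc).replace(tzinfo=None)

db_time = datetime(2023, 6, 15, 11, 0, 0)  # some timestamp read from the DB
print(db_time + time_diff)                 # 2023-06-15 11:01:30, expressed in server time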

View File

@ -1,3 +1,4 @@
from danswer.db.engine import translate_db_time_to_server_time
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.utils.logging import setup_logger
@ -74,13 +75,31 @@ def mark_attempt_failed(
db_session.commit()
def get_last_finished_attempt(
def get_last_successful_attempt(
connector_id: int,
credential_id: int,
db_session: Session,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(IndexAttempt.connector_id == connector_id)
stmt = stmt.where(IndexAttempt.credential_id == credential_id)
stmt = stmt.where(IndexAttempt.status == IndexingStatus.SUCCESS)
stmt = stmt.order_by(desc(IndexAttempt.time_updated))
# Note, the below is using time_created instead of time_updated
stmt = stmt.order_by(desc(IndexAttempt.time_created))
return db_session.execute(stmt).scalars().first()
def get_last_successful_attempt_start_time(
connector_id: int,
credential_id: int,
db_session: Session,
) -> float:
"""Technically the start time is a bit later than creation but for intended use, it doesn't matter"""
last_indexing = get_last_successful_attempt(connector_id, credential_id, db_session)
if last_indexing is None:
return 0.0
last_index_start = translate_db_time_to_server_time(
last_indexing.time_created, db_session
)
return last_index_start.timestamp()
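
Tied back to run_indexing_jobs above, the poll window handed to poll_source is derived as below (values are hypothetical; with no prior successful attempt the start falls back to the Unix epoch):

import time

last_run_time = 0.0  # get_last_successful_attempt_start_time(...) with no prior success
window_start, window_end = last_run_time, time.time()
# The first poll therefore covers everything since the epoch; later polls start at the
# (server-translated) creation time of the last successful attempt for that pair.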

View File

@ -24,6 +24,13 @@ from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
class IndexingStatus(str, PyEnum):
NOT_STARTED = "not_started"
IN_PROGRESS = "in_progress"
SUCCESS = "success"
FAILED = "failed"
class Base(DeclarativeBase):
pass
@ -48,19 +55,25 @@ class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
pass
class ConnectorCredentialAssociation(Base):
class ConnectorCredentialPair(Base):
"""Connectors and Credentials can have a many-to-many relationship
E.g. a Confluence Connector may have multiple admin users who can run it with their own credentials
E.g. an admin user may use the same credential to index multiple Confluence Spaces
"""
__tablename__ = "connector_credential_association"
__tablename__ = "connector_credential_pair"
connector_id: Mapped[int] = mapped_column(
ForeignKey("connector.id"), primary_key=True
)
credential_id: Mapped[int] = mapped_column(
ForeignKey("credential.id"), primary_key=True
)
# Time the last index attempt finished; backend job scheduling uses the attempt's start (creation) time instead
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), default=None
)
last_attempt_status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus))
total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
connector: Mapped["Connector"] = relationship(
"Connector", back_populates="credentials"
@ -91,8 +104,8 @@ class Connector(Base):
)
disabled: Mapped[bool] = mapped_column(Boolean, default=False)
credentials: Mapped[List["ConnectorCredentialAssociation"]] = relationship(
"ConnectorCredentialAssociation",
credentials: Mapped[List["ConnectorCredentialPair"]] = relationship(
"ConnectorCredentialPair",
back_populates="connector",
cascade="all, delete-orphan",
)
@ -115,8 +128,8 @@ class Credential(Base):
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
connectors: Mapped[List["ConnectorCredentialAssociation"]] = relationship(
"ConnectorCredentialAssociation",
connectors: Mapped[List["ConnectorCredentialPair"]] = relationship(
"ConnectorCredentialPair",
back_populates="credential",
cascade="all, delete-orphan",
)
@ -126,13 +139,6 @@ class Credential(Base):
user: Mapped[User] = relationship("User", back_populates="credentials")
class IndexingStatus(str, PyEnum):
NOT_STARTED = "not_started"
IN_PROGRESS = "in_progress"
SUCCESS = "success"
FAILED = "failed"
class IndexAttempt(Base):
"""
Represents an attempt to index a group of 1 or more documents from a
@ -144,10 +150,12 @@ class IndexAttempt(Base):
id: Mapped[int] = mapped_column(primary_key=True)
connector_id: Mapped[int | None] = mapped_column(
ForeignKey("connector.id"), nullable=True
ForeignKey("connector.id"),
nullable=True,
)
credential_id: Mapped[int | None] = mapped_column(
ForeignKey("credential.id"), nullable=True
ForeignKey("credential.id"),
nullable=True,
)
status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus))
document_ids: Mapped[list[str] | None] = mapped_column(
@ -157,10 +165,13 @@ class IndexAttempt(Base):
String(), default=None
) # only filled if status = "failed"
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
DateTime(timezone=True),
server_default=func.now(),
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
connector: Mapped[Connector] = relationship(
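
Read side of the same columns, as the indexing-status endpoint further down now consumes them (a sketch; session setup as elsewhere in this commit):

from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import build_engine
from sqlalchemy.orm import Session

with Session(build_engine(), future=True, expire_on_commit=False) as db_session:
    for cc_pair in get_connector_credential_pairs(db_session):
        print(
            cc_pair.connector.name,
            cc_pair.last_attempt_status,
            cc_pair.last_successful_index_time,
            cc_pair.total_docs_indexed,
        )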

View File

@ -17,7 +17,6 @@ from danswer.connectors.google_drive.connector_auth import (
)
from danswer.connectors.google_drive.connector_auth import upsert_google_app_cred
from danswer.connectors.google_drive.connector_auth import verify_csrf
from danswer.db.connector import add_credential_to_connector
from danswer.db.connector import create_connector
from danswer.db.connector import delete_connector
from danswer.db.connector import fetch_connector_by_id
@ -25,8 +24,10 @@ from danswer.db.connector import fetch_connectors
from danswer.db.connector import fetch_latest_index_attempt_by_connector
from danswer.db.connector import fetch_latest_index_attempts_by_status
from danswer.db.connector import get_connector_credential_ids
from danswer.db.connector import remove_credential_from_connector
from danswer.db.connector import update_connector
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import remove_credential_from_connector
from danswer.db.credentials import create_credential
from danswer.db.credentials import delete_credential
from danswer.db.credentials import fetch_credential_by_id
@ -172,131 +173,28 @@ def upload_files(
return FileUploadResponse(file_paths=file_paths)
@router.get("/admin/latest-index-attempt")
def list_all_index_attempts(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[IndexAttemptSnapshot]:
index_attempts = fetch_latest_index_attempt_by_connector(db_session)
return [
IndexAttemptSnapshot(
source=index_attempt.connector.source,
input_type=index_attempt.connector.input_type,
status=index_attempt.status,
connector_specific_config=index_attempt.connector.connector_specific_config,
docs_indexed=0
if not index_attempt.document_ids
else len(index_attempt.document_ids),
time_created=index_attempt.time_created,
time_updated=index_attempt.time_updated,
)
for index_attempt in index_attempts
]
@router.get("/admin/latest-index-attempt/{source}")
def list_index_attempts(
source: DocumentSource,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[IndexAttemptSnapshot]:
index_attempts = fetch_latest_index_attempt_by_connector(db_session, source=source)
return [
IndexAttemptSnapshot(
source=index_attempt.connector.source,
input_type=index_attempt.connector.input_type,
status=index_attempt.status,
connector_specific_config=index_attempt.connector.connector_specific_config,
docs_indexed=0
if not index_attempt.document_ids
else len(index_attempt.document_ids),
time_created=index_attempt.time_created,
time_updated=index_attempt.time_updated,
)
for index_attempt in index_attempts
]
@router.get("/admin/connector/indexing-status")
def get_connector_indexing_status(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[ConnectorIndexingStatus]:
connector_id_to_connector: dict[int, Connector] = {
connector.id: connector for connector in fetch_connectors(db_session)
}
index_attempts = fetch_latest_index_attempts_by_status(db_session)
connector_credential_pair_to_index_attempts: dict[
tuple[int, int], list[IndexAttempt]
] = defaultdict(list)
for index_attempt in index_attempts:
# don't consider index attempts where the connector has been deleted
# or the credential has been deleted
if (
index_attempt.connector_id is not None
and index_attempt.credential_id is not None
):
connector_credential_pair_to_index_attempts[
(index_attempt.connector_id, index_attempt.credential_id)
].append(index_attempt)
indexing_statuses: list[ConnectorIndexingStatus] = []
for (
connector_id,
credential_id,
), index_attempts in connector_credential_pair_to_index_attempts.items():
# NOTE: index_attempts is guaranteed to be length > 0
connector = connector_id_to_connector[connector_id]
credential = [
credential_association.credential
for credential_association in connector.credentials
if credential_association.credential_id == credential_id
][0]
index_attempts_sorted = sorted(
index_attempts, key=lambda x: x.time_updated, reverse=True
)
successful_index_attempts_sorted = [
index_attempt
for index_attempt in index_attempts_sorted
if index_attempt.status == IndexingStatus.SUCCESS
]
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
connector = cc_pair.connector
credential = cc_pair.credential
indexing_statuses.append(
ConnectorIndexingStatus(
connector=ConnectorSnapshot.from_connector_db_model(connector),
public_doc=credential.public_doc,
owner=credential.user.email if credential.user else "",
last_status=index_attempts_sorted[0].status,
last_success=successful_index_attempts_sorted[0].time_updated
if successful_index_attempts_sorted
else None,
docs_indexed=len(successful_index_attempts_sorted[0].document_ids)
if successful_index_attempts_sorted
and successful_index_attempts_sorted[0].document_ids
else 0,
),
last_status=cc_pair.last_attempt_status,
last_success=cc_pair.last_successful_index_time,
docs_indexed=cc_pair.total_docs_indexed,
)
)
# add in the connectors that haven't started indexing yet
for connector in connector_id_to_connector.values():
for credential_association in connector.credentials:
if (
connector.id,
credential_association.credential_id,
) not in connector_credential_pair_to_index_attempts:
indexing_statuses.append(
ConnectorIndexingStatus(
connector=ConnectorSnapshot.from_connector_db_model(connector),
public_doc=credential_association.credential.public_doc,
owner=credential_association.credential.user.email
if credential_association.credential.user
else "",
last_status=IndexingStatus.NOT_STARTED,
last_success=None,
docs_indexed=0,
),
)
return indexing_statuses

View File

@ -13,7 +13,6 @@ from danswer.direct_qa.question_answer import get_json_line
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents

View File

@ -5,7 +5,6 @@ from uuid import UUID
from danswer.chunking.chunk import Chunker
from danswer.chunking.chunk import DefaultChunker
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.connectors.models import Document
from danswer.datastores.interfaces import KeywordIndex
from danswer.datastores.interfaces import VectorIndex
@ -13,12 +12,13 @@ from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.search.models import Embedder
from danswer.search.semantic_search import DefaultEmbedder
from danswer.utils.logging import setup_logger
logger = setup_logger()
class IndexingPipelineProtocol(Protocol):
def __call__(
self, documents: list[Document], user_id: UUID | None
) -> list[EmbeddedIndexChunk]:
def __call__(self, documents: list[Document], user_id: UUID | None) -> int:
...
@ -30,15 +30,19 @@ def _indexing_pipeline(
keyword_index: KeywordIndex,
documents: list[Document],
user_id: UUID | None,
) -> list[EmbeddedIndexChunk]:
) -> int:
# TODO: make entire indexing pipeline async to not block the entire process
# when running on async endpoints
chunks = list(chain(*[chunker.chunk(document) for document in documents]))
# TODO keyword indexing can occur at same time as embedding
keyword_index.index(chunks, user_id)
net_doc_count_keyword = keyword_index.index(chunks, user_id)
chunks_with_embeddings = embedder.embed(chunks)
vector_index.index(chunks_with_embeddings, user_id)
return chunks_with_embeddings
net_doc_count_vector = vector_index.index(chunks_with_embeddings, user_id)
if net_doc_count_vector != net_doc_count_keyword:
logger.exception(
"Document count changes from the keyword and vector indices don't align"
)
return max(net_doc_count_keyword, net_doc_count_vector)
def build_indexing_pipeline(

View File

@ -24,17 +24,12 @@ PyPDF2==3.0.1
pytest-playwright==0.3.2
python-multipart==0.0.6
qdrant-client==1.2.0
requests==2.28.2
requests==2.31.0
rfc3986==1.5.0
sentence-transformers==2.2.2
slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.12
tensorflow==2.12.0
transformers==4.27.3
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
types-regex==2023.3.23.1
types-requests==2.28.11.17
types-urllib3==1.26.25.11
transformers==4.30.1
typesense==0.15.1
uvicorn==0.21.1

View File

@ -1,10 +1,11 @@
mypy==1.1.1
mypy-extensions==1.0.0
black==23.3.0
reorder-python-imports==3.9.0
mypy-extensions==1.0.0
mypy==1.1.1
pre-commit==3.2.2
reorder-python-imports==3.9.0
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
types-psycopg2==2.9.21.10
types-regex==2023.3.23.1
types-requests==2.28.11.17
types-urllib3==1.26.25.11
types-regex==2023.3.23.1

View File

@ -0,0 +1,42 @@
import psycopg2
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
def wipe_all_rows(database: str) -> None:
conn = psycopg2.connect(
dbname=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
)
cur = conn.cursor()
cur.execute(
"""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_type = 'BASE TABLE'
"""
)
table_names = cur.fetchall()
for table_name in table_names:
if table_name[0] == "alembic_version":
continue
cur.execute(f'DELETE FROM "{table_name[0]}"')
print(f"Deleted all rows from table {table_name[0]}")
conn.commit()
cur.close()
conn.close()
if __name__ == "__main__":
wipe_all_rows(POSTGRES_DB)