Mirror of https://github.com/danswer-ai/danswer.git
DAN-115 Document Polling (#91)

Includes updated document counting for polling.

parent: 97b9b56b03
commit: 6fe54a4eed
@@ -74,7 +74,7 @@ def upgrade() -> None:
         sa.Column(
             "credential_json",
             postgresql.JSONB(astext_type=sa.Text()),
-            nullable=True,
+            nullable=False,
         ),
         sa.Column(
             "user_id",
@@ -101,7 +101,7 @@ def upgrade() -> None:
         sa.PrimaryKeyConstraint("id"),
     )
     op.create_table(
-        "connector_credential_association",
+        "connector_credential_pair",
         sa.Column("connector_id", sa.Integer(), nullable=False),
         sa.Column("credential_id", sa.Integer(), nullable=False),
         sa.ForeignKeyConstraint(
@@ -168,6 +168,6 @@ def downgrade() -> None:
     )
     op.drop_column("index_attempt", "credential_id")
     op.drop_column("index_attempt", "connector_id")
-    op.drop_table("connector_credential_association")
+    op.drop_table("connector_credential_pair")
     op.drop_table("credential")
     op.drop_table("connector")
@@ -0,0 +1,51 @@
+"""Polling Document Count
+
+Revision ID: 3c5e35aa9af0
+Revises: 27c6ecc08586
+Create Date: 2023-06-14 23:45:51.760440
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision = "3c5e35aa9af0"
+down_revision = "27c6ecc08586"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.add_column(
+        "connector_credential_pair",
+        sa.Column(
+            "last_successful_index_time",
+            sa.DateTime(timezone=True),
+            nullable=True,
+        ),
+    )
+    op.add_column(
+        "connector_credential_pair",
+        sa.Column(
+            "last_attempt_status",
+            sa.Enum(
+                "NOT_STARTED",
+                "IN_PROGRESS",
+                "SUCCESS",
+                "FAILED",
+                name="indexingstatus",
+            ),
+            nullable=False,
+        ),
+    )
+    op.add_column(
+        "connector_credential_pair",
+        sa.Column("total_docs_indexed", sa.Integer(), nullable=False),
+    )
+
+
+def downgrade() -> None:
+    op.drop_column("connector_credential_pair", "total_docs_indexed")
+    op.drop_column("connector_credential_pair", "last_attempt_status")
+    op.drop_column("connector_credential_pair", "last_successful_index_time")
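Note: adding a column with `nullable=False` and no `server_default` fails on PostgreSQL when the table already has rows; it works here because `connector_credential_pair` is brand new in this revision chain. The following is a hedged sketch (not part of this commit) of the safer variant for non-empty tables:

import sqlalchemy as sa
from alembic import op

def upgrade() -> None:
    # Sketch only: add the non-nullable column with a default so existing
    # rows are backfilled, then drop the default once data is in place.
    op.add_column(
        "connector_credential_pair",
        sa.Column("total_docs_indexed", sa.Integer(), nullable=False, server_default="0"),
    )
    op.alter_column("connector_credential_pair", "total_docs_indexed", server_default=None)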
@@ -6,18 +6,21 @@ from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.models import InputType
 from danswer.db.connector import disable_connector
 from danswer.db.connector import fetch_connectors
+from danswer.db.connector_credential_pair import update_connector_credential_pair
 from danswer.db.credentials import backend_update_credential_json
 from danswer.db.engine import build_engine
 from danswer.db.engine import get_db_current_time
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.index_attempt import get_inprogress_index_attempts
-from danswer.db.index_attempt import get_last_finished_attempt
+from danswer.db.index_attempt import get_last_successful_attempt
+from danswer.db.index_attempt import get_last_successful_attempt_start_time
 from danswer.db.index_attempt import get_not_started_index_attempts
 from danswer.db.index_attempt import mark_attempt_failed
 from danswer.db.index_attempt import mark_attempt_in_progress
 from danswer.db.index_attempt import mark_attempt_succeeded
 from danswer.db.models import Connector
 from danswer.db.models import IndexAttempt
+from danswer.db.models import IndexingStatus
 from danswer.utils.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logging import setup_logger
 from sqlalchemy.orm import Session
@@ -33,9 +36,7 @@ def should_create_new_indexing(
     if not last_index:
         return True
     current_db_time = get_db_current_time(db_session)
-    time_since_index = (
-        current_db_time - last_index.time_updated
-    )  # Maybe better to do time created
+    time_since_index = current_db_time - last_index.time_updated
     return time_since_index.total_seconds() >= connector.refresh_freq
@@ -51,24 +52,41 @@ def create_indexing_jobs(db_session: Session) -> None:
         # Currently single threaded so any still in-progress must have errored
         for attempt in in_progress_indexing_attempts:
             logger.warning(
-                f"Marking in-progress attempt 'connector: {attempt.connector_id}, credential: {attempt.credential_id}' as failed"
+                f"Marking in-progress attempt 'connector: {attempt.connector_id}, "
+                f"credential: {attempt.credential_id}' as failed"
             )
             mark_attempt_failed(attempt, db_session)
-
-            last_finished_indexing_attempt = get_last_finished_attempt(
-                connector.id, db_session
-            )
-            if not should_create_new_indexing(
-                connector, last_finished_indexing_attempt, db_session
-            ):
-                continue
+            if attempt.connector_id and attempt.credential_id:
+                update_connector_credential_pair(
+                    connector_id=attempt.connector_id,
+                    credential_id=attempt.credential_id,
+                    attempt_status=IndexingStatus.FAILED,
+                    net_docs=None,
+                    db_session=db_session,
+                )
 
+        for association in connector.credentials:
+            credential = association.credential
+
+            last_successful_attempt = get_last_successful_attempt(
+                connector.id, credential.id, db_session
+            )
+            if not should_create_new_indexing(
+                connector, last_successful_attempt, db_session
+            ):
+                continue
+            create_index_attempt(connector.id, credential.id, db_session)
+
+            update_connector_credential_pair(
+                connector_id=connector.id,
+                credential_id=credential.id,
+                attempt_status=IndexingStatus.NOT_STARTED,
+                net_docs=None,
+                db_session=db_session,
+            )
 
 
-def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
+def run_indexing_jobs(db_session: Session) -> None:
     indexing_pipeline = build_indexing_pipeline()
 
     new_indexing_attempts = get_not_started_index_attempts(db_session)
@@ -77,7 +95,7 @@ def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
         logger.info(
             f"Starting new indexing attempt for connector: '{attempt.connector.name}', "
             f"with config: '{attempt.connector.connector_specific_config}', and "
-            f" with credentials: '{[c.credential_id for c in attempt.connector.credentials]}'"
+            f"with credentials: '{[c.credential_id for c in attempt.connector.credentials]}'"
         )
         mark_attempt_in_progress(attempt, db_session)
@@ -85,6 +103,14 @@ def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
         db_credential = attempt.credential
         task = db_connector.input_type
 
+        update_connector_credential_pair(
+            connector_id=db_connector.id,
+            credential_id=db_credential.id,
+            attempt_status=IndexingStatus.IN_PROGRESS,
+            net_docs=None,
+            db_session=db_session,
+        )
+
         try:
             runnable_connector, new_credential_json = instantiate_connector(
                 db_connector.source,
@@ -101,6 +127,7 @@ def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
             disable_connector(db_connector.id, db_session)
             continue
 
+        net_doc_change = 0
         try:
             if task == InputType.LOAD_STATE:
                 assert isinstance(runnable_connector, LoadConnector)
@@ -108,6 +135,14 @@ def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
 
             elif task == InputType.POLL:
                 assert isinstance(runnable_connector, PollConnector)
+                if attempt.connector_id is None or attempt.credential_id is None:
+                    raise ValueError(
+                        f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
+                        f"can't fetch time range."
+                    )
+                last_run_time = get_last_successful_attempt_start_time(
+                    attempt.connector_id, attempt.credential_id, db_session
+                )
                 doc_batch_generator = runnable_connector.poll_source(
                     last_run_time, time.time()
                 )
@@ -121,28 +156,43 @@ def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
                 index_user_id = (
                     None if db_credential.public_doc else db_credential.user_id
                 )
-                indexing_pipeline(documents=doc_batch, user_id=index_user_id)
+                net_doc_change += indexing_pipeline(
+                    documents=doc_batch, user_id=index_user_id
+                )
                 document_ids.extend([doc.id for doc in doc_batch])
 
             mark_attempt_succeeded(attempt, document_ids, db_session)
+            update_connector_credential_pair(
+                connector_id=db_connector.id,
+                credential_id=db_credential.id,
+                attempt_status=IndexingStatus.SUCCESS,
+                net_docs=net_doc_change,
+                db_session=db_session,
+            )
+
             logger.info(f"Indexed {len(document_ids)} documents")
 
         except Exception as e:
             logger.exception(f"Indexing job with id {attempt.id} failed due to {e}")
             mark_attempt_failed(attempt, db_session, failure_reason=str(e))
+            update_connector_credential_pair(
+                connector_id=db_connector.id,
+                credential_id=db_credential.id,
+                attempt_status=IndexingStatus.FAILED,
+                net_docs=net_doc_change,
+                db_session=db_session,
+            )
 
 
 def update_loop(delay: int = 10) -> None:
-    last_run_time = 0.0
+    engine = build_engine()
     while True:
         start = time.time()
         logger.info(f"Running update, current time: {time.ctime(start)}")
         try:
-            with Session(
-                build_engine(), future=True, expire_on_commit=False
-            ) as db_session:
+            with Session(engine, future=True, expire_on_commit=False) as db_session:
                 create_indexing_jobs(db_session)
-                run_indexing_jobs(last_run_time, db_session)
+                # TODO failed poll jobs won't recover data from failed runs, should fix
+                run_indexing_jobs(db_session)
         except Exception as e:
             logger.exception(f"Failed to run update due to {e}")
         sleep_time = delay - (time.time() - start)
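Note: the scheduling contract introduced here is that each poll run covers everything modified since the start of the last successful attempt for that connector/credential pair (0.0 when there is no prior success). A minimal sketch of how a PollConnector is driven, with the loop structure simplified and the pipeline hand-off elided:

import time

# Sketch (names from the diff, structure simplified): poll the window
# [start of last successful attempt, now] and collect the document ids seen.
def run_poll_once(connector, last_success_start: float) -> list[str]:
    document_ids: list[str] = []
    for doc_batch in connector.poll_source(last_success_start, time.time()):
        # each batch would normally be handed to the indexing pipeline here
        document_ids.extend(doc.id for doc in doc_batch)
    return document_ids

Because the window starts at the last *successful* attempt, documents modified during a failed run are picked up again on the next pass.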
@@ -1,3 +1,4 @@
+import datetime
 import io
 from collections.abc import Generator
 from typing import Any
@@ -9,6 +10,8 @@ from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_KEY
 from danswer.connectors.google_drive.connector_auth import get_drive_tokens
 from danswer.connectors.interfaces import GenerateDocumentsOutput
 from danswer.connectors.interfaces import LoadConnector
+from danswer.connectors.interfaces import PollConnector
+from danswer.connectors.interfaces import SecondsSinceUnixEpoch
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger
@@ -35,9 +38,23 @@ def get_file_batches(
     service: discovery.Resource,
     include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
     batch_size: int = INDEX_BATCH_SIZE,
+    time_range_start: SecondsSinceUnixEpoch | None = None,
+    time_range_end: SecondsSinceUnixEpoch | None = None,
 ) -> Generator[list[dict[str, str]], None, None]:
     next_page_token = ""
     while next_page_token is not None:
+        query = ""
+        if time_range_start is not None:
+            time_start = (
+                datetime.datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
+            )
+            query += f"modifiedTime >= '{time_start}' "
+        if time_range_end is not None:
+            time_stop = (
+                datetime.datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
+            )
+            query += f"and modifiedTime <= '{time_stop}'"
+
         results = (
             service.files()
             .list(
@@ -45,6 +62,7 @@ def get_file_batches(
                 supportsAllDrives=include_shared,
                 fields="nextPageToken, files(mimeType, id, name, webViewLink)",
                 pageToken=next_page_token,
+                q=query,
             )
             .execute()
         )
@@ -84,7 +102,7 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
     return "\n".join(page.extract_text() for page in pdf_reader.pages)
 
 
-class GoogleDriveConnector(LoadConnector):
+class GoogleDriveConnector(LoadConnector, PollConnector):
     def __init__(
         self,
         batch_size: int = INDEX_BATCH_SIZE,
@@ -105,13 +123,21 @@ class GoogleDriveConnector(LoadConnector):
             return {DB_CREDENTIALS_DICT_KEY: new_creds_json_str}
         return None
 
-    def load_from_state(self) -> GenerateDocumentsOutput:
+    def _fetch_docs_from_drive(
+        self,
+        start: SecondsSinceUnixEpoch | None = None,
+        end: SecondsSinceUnixEpoch | None = None,
+    ) -> GenerateDocumentsOutput:
         if self.creds is None:
             raise PermissionError("Not logged into Google Drive")
 
         service = discovery.build("drive", "v3", credentials=self.creds)
         for files_batch in get_file_batches(
-            service, self.include_shared, self.batch_size
+            service,
+            self.include_shared,
+            self.batch_size,
+            time_range_start=start,
+            time_range_end=end,
         ):
             doc_batch = []
             for file in files_batch:
@@ -129,3 +155,11 @@ class GoogleDriveConnector(LoadConnector):
             )
 
         yield doc_batch
+
+    def load_from_state(self) -> GenerateDocumentsOutput:
+        yield from self._fetch_docs_from_drive()
+
+    def poll_source(
+        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
+    ) -> GenerateDocumentsOutput:
+        yield from self._fetch_docs_from_drive(start, end)
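Note: the `query` built in `get_file_batches` is a Google Drive `files.list` search expression on `modifiedTime`. For reference, with illustrative epoch values it renders like this:

import datetime

# Illustrative only: reproduce the search string the connector builds.
start, end = 1686787200, 1686873600  # example epochs, one day apart
time_start = datetime.datetime.utcfromtimestamp(start).isoformat() + "Z"
time_stop = datetime.datetime.utcfromtimestamp(end).isoformat() + "Z"
query = f"modifiedTime >= '{time_start}' " + f"and modifiedTime <= '{time_stop}'"
print(query)
# modifiedTime >= '2023-06-15T00:00:00Z' and modifiedTime <= '2023-06-16T00:00:00Z'

If only `time_range_end` were supplied, the string would start with a dangling "and "; in this commit `poll_source` always passes both bounds, so that path is not exercised.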
@@ -46,6 +46,7 @@ def get_drive_tokens(
     try:
         creds.refresh(Request())
         if creds.valid:
+            logger.info("Refreshed Google Drive tokens.")
             return creds
     except Exception as e:
         logger.exception(f"Failed to refresh google drive access token due to: {e}")
@@ -15,7 +15,8 @@ IndexFilter = dict[str, str | list[str] | None]
 
 class DocumentIndex(Generic[T], abc.ABC):
     @abc.abstractmethod
-    def index(self, chunks: list[T], user_id: UUID | None) -> bool:
+    def index(self, chunks: list[T], user_id: UUID | None) -> int:
+        """Indexes document chunks into the Document Index and return the number of new documents"""
         raise NotImplementedError
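Note: with `index` now returning `int`, implementations report the number of net new documents (documents in the batch minus those that already existed and were re-indexed). A minimal conforming implementation, shown as a sketch against an assumed in-memory store (not part of the codebase):

import abc
from typing import Generic, TypeVar
from uuid import UUID

T = TypeVar("T")

class DocumentIndex(Generic[T], abc.ABC):
    @abc.abstractmethod
    def index(self, chunks: list[T], user_id: UUID | None) -> int:
        """Indexes document chunks into the Document Index and return the number of new documents"""
        raise NotImplementedError

# Sketch: an in-memory index satisfying the new contract. Chunks are assumed
# to be dicts carrying a "document_id" key (a hypothetical shape).
class InMemoryIndex(DocumentIndex[dict]):
    def __init__(self) -> None:
        self.docs: dict[str, list[dict]] = {}

    def index(self, chunks: list[dict], user_id: UUID | None) -> int:
        seen_docs = {c["document_id"] for c in chunks}
        # Net new = documents in this batch that were not already present
        new_docs = len(seen_docs - self.docs.keys())
        for c in chunks:
            self.docs.setdefault(c["document_id"], []).append(c)
        return new_docs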
@@ -52,6 +52,7 @@
 def get_qdrant_document_whitelists(
     doc_chunk_id: str, collection_name: str, q_client: QdrantClient
 ) -> tuple[bool, list[str], list[str]]:
+    """Get whether a document is found and the existing whitelists"""
     results = q_client.retrieve(
         collection_name=collection_name,
         ids=[doc_chunk_id],
@@ -69,8 +70,8 @@
 
 def delete_qdrant_doc_chunks(
     document_id: str, collection_name: str, q_client: QdrantClient
-) -> None:
-    q_client.delete(
+) -> bool:
+    res = q_client.delete(
         collection_name=collection_name,
         points_selector=models.FilterSelector(
             filter=models.Filter(
@@ -83,6 +84,7 @@ def delete_qdrant_doc_chunks(
             )
         ),
     )
+    return True
 
 
 def index_qdrant_chunks(
@@ -91,7 +93,7 @@ def index_qdrant_chunks(
     collection: str,
     client: QdrantClient | None = None,
     batch_upsert: bool = True,
-) -> bool:
+) -> int:
     # Public documents will have the PUBLIC string in ALLOWED_USERS
     # If credential that kicked this off has no user associated, either Auth is off or the doc is public
     user_str = PUBLIC_DOC_PAT if user_id is None else str(user_id)
@@ -100,6 +102,7 @@ def index_qdrant_chunks(
     point_structs: list[PointStruct] = []
     # Maps document id to dict of whitelists for users/groups each containing list of users/groups as strings
     doc_user_map: dict[str, dict[str, list[str]]] = {}
+    docs_deleted = 0
     for chunk in chunks:
         document = chunk.source_document
         doc_user_map, delete_doc = update_doc_user_map(
@@ -114,6 +117,8 @@ def index_qdrant_chunks(
         )
 
         if delete_doc:
+            # Processing the first chunk of the doc and the doc exists
+            docs_deleted += 1
             delete_qdrant_doc_chunks(document.id, collection, q_client)
 
         point_structs.extend(
@@ -138,7 +143,6 @@ def index_qdrant_chunks(
             ]
         )
 
-    index_results = None
     if batch_upsert:
         point_struct_batches = [
             point_structs[x : x + DEFAULT_BATCH_SIZE]
@@ -171,4 +175,5 @@ def index_qdrant_chunks(
         logger.info(
             f"Document batch of size {len(point_structs)} indexing status: {index_results.status}"
         )
-    return index_results is not None and index_results.status == UpdateStatus.COMPLETED
+
+    return len(doc_user_map.keys()) - docs_deleted
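Note: the new return value is simple set arithmetic — chunks are grouped by source document, so `len(doc_user_map)` is the number of distinct documents touched, and `docs_deleted` counts those that already existed (their old chunks were deleted before re-insert). A worked example with illustrative numbers:

# Illustrative: a batch touches 3 distinct documents; 1 already existed and
# was re-indexed (old chunks deleted first), so the index grew by 2 documents.
doc_user_map = {"doc-a": {}, "doc-b": {}, "doc-c": {}}  # distinct docs in the batch
docs_deleted = 1                                        # docs that were re-indexed
net_new_docs = len(doc_user_map.keys()) - docs_deleted
assert net_new_docs == 2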
@@ -77,7 +77,7 @@ class QdrantIndex(VectorIndex):
         self.collection = collection
         self.client = get_qdrant_client()
 
-    def index(self, chunks: list[EmbeddedIndexChunk], user_id: UUID | None) -> bool:
+    def index(self, chunks: list[EmbeddedIndexChunk], user_id: UUID | None) -> int:
         return index_qdrant_chunks(
             chunks=chunks,
             user_id=user_id,
@@ -86,7 +86,7 @@ def get_typesense_document_whitelists(
 
 def delete_typesense_doc_chunks(
     document_id: str, collection_name: str, ts_client: typesense.Client
-) -> None:
+) -> bool:
     search_parameters = {
         "q": document_id,
         "query_by": DOCUMENT_ID,
@@ -98,6 +98,7 @@ def delete_typesense_doc_chunks(
         ts_client.collections[collection_name].documents[hit["document"]["id"]].delete()
         for hit in hits["hits"]
     ]
+    return True if hits else False
 
 
 def index_typesense_chunks(
@@ -106,12 +107,13 @@ def index_typesense_chunks(
     collection: str,
     client: typesense.Client | None = None,
     batch_upsert: bool = True,
-) -> bool:
+) -> int:
     user_str = PUBLIC_DOC_PAT if user_id is None else str(user_id)
     ts_client: typesense.Client = client if client else get_typesense_client()
 
     new_documents: list[dict[str, Any]] = []
     doc_user_map: dict[str, dict[str, list[str]]] = {}
+    docs_deleted = 0
     for chunk in chunks:
         document = chunk.source_document
         doc_user_map, delete_doc = update_doc_user_map(
@@ -126,6 +128,8 @@ def index_typesense_chunks(
         )
 
         if delete_doc:
+            # Processing the first chunk of the doc and the doc exists
+            docs_deleted += 1
             delete_typesense_doc_chunks(document.id, collection, ts_client)
 
         new_documents.append(
@@ -168,7 +172,7 @@ def index_typesense_chunks(
             for document in new_documents
         ]
 
-    return True
+    return len(doc_user_map.keys()) - docs_deleted
 
 
 def _build_typesense_filters(
@@ -206,7 +210,7 @@ class TypesenseIndex(KeywordIndex):
         self.collection = collection
         self.ts_client = get_typesense_client()
 
-    def index(self, chunks: list[IndexChunk], user_id: UUID | None) -> bool:
+    def index(self, chunks: list[IndexChunk], user_id: UUID | None) -> int:
         return index_typesense_chunks(
             chunks=chunks,
             user_id=user_id,
@@ -2,11 +2,8 @@ from typing import cast
 
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.models import InputType
-from danswer.db.credentials import fetch_credential_by_id
 from danswer.db.models import Connector
-from danswer.db.models import ConnectorCredentialAssociation
 from danswer.db.models import IndexAttempt
-from danswer.db.models import User
 from danswer.server.models import ConnectorBase
 from danswer.server.models import ObjectCreationIdResponse
 from danswer.server.models import StatusResponse
@@ -147,95 +144,6 @@ def get_connector_credential_ids(
     return [association.credential.id for association in connector.credentials]
 
 
-def add_credential_to_connector(
-    connector_id: int,
-    credential_id: int,
-    user: User,
-    db_session: Session,
-) -> StatusResponse[int]:
-    connector = fetch_connector_by_id(connector_id, db_session)
-    credential = fetch_credential_by_id(credential_id, user, db_session)
-
-    if connector is None:
-        raise HTTPException(status_code=404, detail="Connector does not exist")
-
-    if credential is None:
-        raise HTTPException(
-            status_code=401,
-            detail="Credential does not exist or does not belong to user",
-        )
-
-    existing_association = (
-        db_session.query(ConnectorCredentialAssociation)
-        .filter(
-            ConnectorCredentialAssociation.connector_id == connector_id,
-            ConnectorCredentialAssociation.credential_id == credential_id,
-        )
-        .one_or_none()
-    )
-    if existing_association is not None:
-        return StatusResponse(
-            success=False,
-            message=f"Connector already has Credential {credential_id}",
-            data=connector_id,
-        )
-
-    association = ConnectorCredentialAssociation(
-        connector_id=connector_id, credential_id=credential_id
-    )
-    db_session.add(association)
-    db_session.commit()
-
-    return StatusResponse(
-        success=True,
-        message=f"New Credential {credential_id} added to Connector",
-        data=connector_id,
-    )
-
-
-def remove_credential_from_connector(
-    connector_id: int,
-    credential_id: int,
-    user: User,
-    db_session: Session,
-) -> StatusResponse[int]:
-    connector = fetch_connector_by_id(connector_id, db_session)
-    credential = fetch_credential_by_id(credential_id, user, db_session)
-
-    if connector is None:
-        raise HTTPException(status_code=404, detail="Connector does not exist")
-
-    if credential is None:
-        raise HTTPException(
-            status_code=404,
-            detail="Credential does not exist or does not belong to user",
-        )
-
-    association = (
-        db_session.query(ConnectorCredentialAssociation)
-        .filter(
-            ConnectorCredentialAssociation.connector_id == connector_id,
-            ConnectorCredentialAssociation.credential_id == credential_id,
-        )
-        .one_or_none()
-    )
-
-    if association is not None:
-        db_session.delete(association)
-        db_session.commit()
-        return StatusResponse(
-            success=True,
-            message=f"Credential {credential_id} removed from Connector",
-            data=connector_id,
-        )
-
-    return StatusResponse(
-        success=False,
-        message=f"Connector already does not have Credential {credential_id}",
-        data=connector_id,
-    )
-
-
 def fetch_latest_index_attempt_by_connector(
     db_session: Session,
     source: DocumentSource | None = None,
backend/danswer/db/connector_credential_pair.py (new file, 148 lines)
@@ -0,0 +1,148 @@
+from danswer.db.connector import fetch_connector_by_id
+from danswer.db.credentials import fetch_credential_by_id
+from danswer.db.models import ConnectorCredentialPair
+from danswer.db.models import IndexingStatus
+from danswer.db.models import User
+from danswer.server.models import StatusResponse
+from danswer.utils.logging import setup_logger
+from fastapi import HTTPException
+from sqlalchemy import func
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+logger = setup_logger()
+
+
+def get_connector_credential_pairs(
+    db_session: Session, include_disabled: bool = True
+) -> list[ConnectorCredentialPair]:
+    stmt = select(ConnectorCredentialPair)
+    if not include_disabled:
+        stmt = stmt.where(ConnectorCredentialPair.connector.disabled == False)
+    results = db_session.scalars(stmt)
+    return list(results.all())
+
+
+def get_connector_credential_pair(
+    connector_id: int,
+    credential_id: int,
+    db_session: Session,
+) -> ConnectorCredentialPair | None:
+    stmt = select(ConnectorCredentialPair)
+    stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
+    stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
+    result = db_session.execute(stmt)
+    return result.scalar_one_or_none()
+
+
+def update_connector_credential_pair(
+    connector_id: int,
+    credential_id: int,
+    attempt_status: IndexingStatus,
+    net_docs: int | None,
+    db_session: Session,
+) -> None:
+    cc_pair = get_connector_credential_pair(connector_id, credential_id, db_session)
+    if not cc_pair:
+        logger.warning(
+            f"Attempted to update pair for connector id {connector_id} "
+            f"and credential id {credential_id}"
+        )
+        return
+    cc_pair.last_attempt_status = attempt_status
+    if attempt_status == IndexingStatus.SUCCESS:
+        cc_pair.last_successful_index_time = func.now()  # type:ignore
+    if net_docs is not None:
+        cc_pair.total_docs_indexed += net_docs
+    db_session.commit()
+
+
+def add_credential_to_connector(
+    connector_id: int,
+    credential_id: int,
+    user: User,
+    db_session: Session,
+) -> StatusResponse[int]:
+    connector = fetch_connector_by_id(connector_id, db_session)
+    credential = fetch_credential_by_id(credential_id, user, db_session)
+
+    if connector is None:
+        raise HTTPException(status_code=404, detail="Connector does not exist")
+
+    if credential is None:
+        raise HTTPException(
+            status_code=401,
+            detail="Credential does not exist or does not belong to user",
+        )
+
+    existing_association = (
+        db_session.query(ConnectorCredentialPair)
+        .filter(
+            ConnectorCredentialPair.connector_id == connector_id,
+            ConnectorCredentialPair.credential_id == credential_id,
+        )
+        .one_or_none()
+    )
+    if existing_association is not None:
+        return StatusResponse(
+            success=False,
+            message=f"Connector already has Credential {credential_id}",
+            data=connector_id,
+        )
+
+    association = ConnectorCredentialPair(
+        connector_id=connector_id,
+        credential_id=credential_id,
+        last_attempt_status=IndexingStatus.NOT_STARTED,
+    )
+    db_session.add(association)
+    db_session.commit()
+
+    return StatusResponse(
+        success=True,
+        message=f"New Credential {credential_id} added to Connector",
+        data=connector_id,
+    )
+
+
+def remove_credential_from_connector(
+    connector_id: int,
+    credential_id: int,
+    user: User,
+    db_session: Session,
+) -> StatusResponse[int]:
+    connector = fetch_connector_by_id(connector_id, db_session)
+    credential = fetch_credential_by_id(credential_id, user, db_session)
+
+    if connector is None:
+        raise HTTPException(status_code=404, detail="Connector does not exist")
+
+    if credential is None:
+        raise HTTPException(
+            status_code=404,
+            detail="Credential does not exist or does not belong to user",
+        )
+
+    association = (
+        db_session.query(ConnectorCredentialPair)
+        .filter(
+            ConnectorCredentialPair.connector_id == connector_id,
+            ConnectorCredentialPair.credential_id == credential_id,
+        )
+        .one_or_none()
+    )
+
+    if association is not None:
+        db_session.delete(association)
+        db_session.commit()
+        return StatusResponse(
+            success=True,
+            message=f"Credential {credential_id} removed from Connector",
+            data=connector_id,
+        )
+
+    return StatusResponse(
+        success=False,
+        message=f"Connector already does not have Credential {credential_id}",
+        data=connector_id,
+    )
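Note: the background loop calls `update_connector_credential_pair` at every state transition. A usage sketch mirroring the diff, with placeholder ids:

from sqlalchemy.orm import Session

from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.models import IndexingStatus

# Usage sketch (placeholder ids): record a successful run that netted 42 docs.
# last_successful_index_time is stamped inside the helper via func.now().
def record_success(db_session: Session) -> None:
    update_connector_credential_pair(
        connector_id=1,   # placeholder
        credential_id=1,  # placeholder
        attempt_status=IndexingStatus.SUCCESS,
        net_docs=42,      # net new documents from this run
        db_session=db_session,
    )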
@@ -1,6 +1,7 @@
 from collections.abc import AsyncGenerator
 from collections.abc import Generator
 from datetime import datetime
+from datetime import timezone
 
 from danswer.configs.app_configs import POSTGRES_DB
 from danswer.configs.app_configs import POSTGRES_HOST
@@ -27,6 +28,15 @@ def get_db_current_time(db_session: Session) -> datetime:
     return result
 
 
+def translate_db_time_to_server_time(
+    db_time: datetime, db_session: Session
+) -> datetime:
+    server_now = datetime.now()
+    db_now = get_db_current_time(db_session)
+    time_diff = server_now - db_now.astimezone(timezone.utc).replace(tzinfo=None)
+    return db_time + time_diff
+
+
 def build_connection_string(
     *,
     db_api: str = ASYNC_DB_API,
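Note: `translate_db_time_to_server_time` compensates for clock skew between the application host and Postgres by measuring the current offset and applying it to the stored timestamp. A worked example with made-up clocks:

from datetime import datetime, timedelta, timezone

# Made-up clocks: the DB is 90 seconds ahead of the app server.
server_now = datetime(2023, 6, 15, 12, 0, 0)
db_now = datetime(2023, 6, 15, 12, 1, 30, tzinfo=timezone.utc)

time_diff = server_now - db_now.astimezone(timezone.utc).replace(tzinfo=None)
assert time_diff == timedelta(seconds=-90)

# A row stamped 12:00:00 by the DB corresponds to 11:58:30 on the server clock.
db_time = datetime(2023, 6, 15, 12, 0, 0)
assert db_time + time_diff == datetime(2023, 6, 15, 11, 58, 30)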
@@ -1,3 +1,4 @@
+from danswer.db.engine import translate_db_time_to_server_time
 from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
 from danswer.utils.logging import setup_logger
@@ -74,13 +75,31 @@ def mark_attempt_failed(
     db_session.commit()
 
 
-def get_last_finished_attempt(
+def get_last_successful_attempt(
     connector_id: int,
+    credential_id: int,
     db_session: Session,
 ) -> IndexAttempt | None:
     stmt = select(IndexAttempt)
     stmt = stmt.where(IndexAttempt.connector_id == connector_id)
+    stmt = stmt.where(IndexAttempt.credential_id == credential_id)
+    stmt = stmt.where(IndexAttempt.status == IndexingStatus.SUCCESS)
-    stmt = stmt.order_by(desc(IndexAttempt.time_updated))
+    # Note, the below is using time_created instead of time_updated
+    stmt = stmt.order_by(desc(IndexAttempt.time_created))
 
     return db_session.execute(stmt).scalars().first()
 
 
+def get_last_successful_attempt_start_time(
+    connector_id: int,
+    credential_id: int,
+    db_session: Session,
+) -> float:
+    """Technically the start time is a bit later than creation but for intended use, it doesn't matter"""
+    last_indexing = get_last_successful_attempt(connector_id, credential_id, db_session)
+    if last_indexing is None:
+        return 0.0
+    last_index_start = translate_db_time_to_server_time(
+        last_indexing.time_created, db_session
+    )
+    return last_index_start.timestamp()
@@ -24,6 +24,13 @@ from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 
 
+class IndexingStatus(str, PyEnum):
+    NOT_STARTED = "not_started"
+    IN_PROGRESS = "in_progress"
+    SUCCESS = "success"
+    FAILED = "failed"
+
+
 class Base(DeclarativeBase):
     pass
 
@@ -48,19 +55,25 @@ class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
     pass
 
 
-class ConnectorCredentialAssociation(Base):
+class ConnectorCredentialPair(Base):
     """Connectors and Credentials can have a many-to-many relationship
     I.e. A Confluence Connector may have multiple admin users who can run it with their own credentials
     I.e. An admin user may use the same credential to index multiple Confluence Spaces
     """
 
-    __tablename__ = "connector_credential_association"
+    __tablename__ = "connector_credential_pair"
     connector_id: Mapped[int] = mapped_column(
         ForeignKey("connector.id"), primary_key=True
     )
     credential_id: Mapped[int] = mapped_column(
         ForeignKey("credential.id"), primary_key=True
     )
+    # Time finished, not used for calculating backend jobs which uses time started (created)
+    last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
+        DateTime(timezone=True), default=None
+    )
+    last_attempt_status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus))
+    total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
 
     connector: Mapped["Connector"] = relationship(
         "Connector", back_populates="credentials"
@@ -91,8 +104,8 @@ class Connector(Base):
     )
     disabled: Mapped[bool] = mapped_column(Boolean, default=False)
 
-    credentials: Mapped[List["ConnectorCredentialAssociation"]] = relationship(
-        "ConnectorCredentialAssociation",
+    credentials: Mapped[List["ConnectorCredentialPair"]] = relationship(
+        "ConnectorCredentialPair",
         back_populates="connector",
         cascade="all, delete-orphan",
     )
@@ -115,8 +128,8 @@ class Credential(Base):
         DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
     )
 
-    connectors: Mapped[List["ConnectorCredentialAssociation"]] = relationship(
-        "ConnectorCredentialAssociation",
+    connectors: Mapped[List["ConnectorCredentialPair"]] = relationship(
+        "ConnectorCredentialPair",
        back_populates="credential",
        cascade="all, delete-orphan",
    )
@@ -126,13 +139,6 @@ class Credential(Base):
     user: Mapped[User] = relationship("User", back_populates="credentials")
 
 
-class IndexingStatus(str, PyEnum):
-    NOT_STARTED = "not_started"
-    IN_PROGRESS = "in_progress"
-    SUCCESS = "success"
-    FAILED = "failed"
-
-
 class IndexAttempt(Base):
     """
     Represents an attempt to index a group of 1 or more documents from a
@@ -144,10 +150,12 @@ class IndexAttempt(Base):
 
     id: Mapped[int] = mapped_column(primary_key=True)
     connector_id: Mapped[int | None] = mapped_column(
-        ForeignKey("connector.id"), nullable=True
+        ForeignKey("connector.id"),
+        nullable=True,
     )
     credential_id: Mapped[int | None] = mapped_column(
-        ForeignKey("credential.id"), nullable=True
+        ForeignKey("credential.id"),
+        nullable=True,
     )
     status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus))
     document_ids: Mapped[list[str] | None] = mapped_column(
@@ -157,10 +165,13 @@ class IndexAttempt(Base):
         String(), default=None
     )  # only filled if status = "failed"
     time_created: Mapped[datetime.datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now()
+        DateTime(timezone=True),
+        server_default=func.now(),
     )
     time_updated: Mapped[datetime.datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
+        DateTime(timezone=True),
+        server_default=func.now(),
+        onupdate=func.now(),
     )
 
     connector: Mapped[Connector] = relationship(
@@ -17,7 +17,6 @@ from danswer.connectors.google_drive.connector_auth import (
 )
 from danswer.connectors.google_drive.connector_auth import upsert_google_app_cred
 from danswer.connectors.google_drive.connector_auth import verify_csrf
-from danswer.db.connector import add_credential_to_connector
 from danswer.db.connector import create_connector
 from danswer.db.connector import delete_connector
 from danswer.db.connector import fetch_connector_by_id
@@ -25,8 +24,10 @@ from danswer.db.connector import fetch_connectors
 from danswer.db.connector import fetch_latest_index_attempt_by_connector
 from danswer.db.connector import fetch_latest_index_attempts_by_status
 from danswer.db.connector import get_connector_credential_ids
-from danswer.db.connector import remove_credential_from_connector
 from danswer.db.connector import update_connector
+from danswer.db.connector_credential_pair import add_credential_to_connector
+from danswer.db.connector_credential_pair import get_connector_credential_pairs
+from danswer.db.connector_credential_pair import remove_credential_from_connector
 from danswer.db.credentials import create_credential
 from danswer.db.credentials import delete_credential
 from danswer.db.credentials import fetch_credential_by_id
@@ -172,131 +173,28 @@ def upload_files(
     return FileUploadResponse(file_paths=file_paths)
 
 
-@router.get("/admin/latest-index-attempt")
-def list_all_index_attempts(
-    _: User = Depends(current_admin_user),
-    db_session: Session = Depends(get_session),
-) -> list[IndexAttemptSnapshot]:
-    index_attempts = fetch_latest_index_attempt_by_connector(db_session)
-    return [
-        IndexAttemptSnapshot(
-            source=index_attempt.connector.source,
-            input_type=index_attempt.connector.input_type,
-            status=index_attempt.status,
-            connector_specific_config=index_attempt.connector.connector_specific_config,
-            docs_indexed=0
-            if not index_attempt.document_ids
-            else len(index_attempt.document_ids),
-            time_created=index_attempt.time_created,
-            time_updated=index_attempt.time_updated,
-        )
-        for index_attempt in index_attempts
-    ]
-
-
-@router.get("/admin/latest-index-attempt/{source}")
-def list_index_attempts(
-    source: DocumentSource,
-    _: User = Depends(current_admin_user),
-    db_session: Session = Depends(get_session),
-) -> list[IndexAttemptSnapshot]:
-    index_attempts = fetch_latest_index_attempt_by_connector(db_session, source=source)
-    return [
-        IndexAttemptSnapshot(
-            source=index_attempt.connector.source,
-            input_type=index_attempt.connector.input_type,
-            status=index_attempt.status,
-            connector_specific_config=index_attempt.connector.connector_specific_config,
-            docs_indexed=0
-            if not index_attempt.document_ids
-            else len(index_attempt.document_ids),
-            time_created=index_attempt.time_created,
-            time_updated=index_attempt.time_updated,
-        )
-        for index_attempt in index_attempts
-    ]
-
-
 @router.get("/admin/connector/indexing-status")
 def get_connector_indexing_status(
     _: User = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
 ) -> list[ConnectorIndexingStatus]:
-    connector_id_to_connector: dict[int, Connector] = {
-        connector.id: connector for connector in fetch_connectors(db_session)
-    }
-    index_attempts = fetch_latest_index_attempts_by_status(db_session)
-    connector_credential_pair_to_index_attempts: dict[
-        tuple[int, int], list[IndexAttempt]
-    ] = defaultdict(list)
-    for index_attempt in index_attempts:
-        # don't consider index attempts where the connector has been deleted
-        # or the credential has been deleted
-        if (
-            index_attempt.connector_id is not None
-            and index_attempt.credential_id is not None
-        ):
-            connector_credential_pair_to_index_attempts[
-                (index_attempt.connector_id, index_attempt.credential_id)
-            ].append(index_attempt)
-
     indexing_statuses: list[ConnectorIndexingStatus] = []
-    for (
-        connector_id,
-        credential_id,
-    ), index_attempts in connector_credential_pair_to_index_attempts.items():
-        # NOTE: index_attempts is guaranteed to be length > 0
-        connector = connector_id_to_connector[connector_id]
-        credential = [
-            credential_association.credential
-            for credential_association in connector.credentials
-            if credential_association.credential_id == credential_id
-        ][0]
-
-        index_attempts_sorted = sorted(
-            index_attempts, key=lambda x: x.time_updated, reverse=True
-        )
-        successful_index_attempts_sorted = [
-            index_attempt
-            for index_attempt in index_attempts_sorted
-            if index_attempt.status == IndexingStatus.SUCCESS
-        ]
+    cc_pairs = get_connector_credential_pairs(db_session)
+    for cc_pair in cc_pairs:
+        connector = cc_pair.connector
+        credential = cc_pair.credential
         indexing_statuses.append(
             ConnectorIndexingStatus(
                 connector=ConnectorSnapshot.from_connector_db_model(connector),
                 public_doc=credential.public_doc,
                 owner=credential.user.email if credential.user else "",
-                last_status=index_attempts_sorted[0].status,
-                last_success=successful_index_attempts_sorted[0].time_updated
-                if successful_index_attempts_sorted
-                else None,
-                docs_indexed=len(successful_index_attempts_sorted[0].document_ids)
-                if successful_index_attempts_sorted
-                and successful_index_attempts_sorted[0].document_ids
-                else 0,
-            ),
+                last_status=cc_pair.last_attempt_status,
+                last_success=cc_pair.last_successful_index_time,
+                docs_indexed=cc_pair.total_docs_indexed,
+            )
         )
 
-    # add in the connectors that haven't started indexing yet
-    for connector in connector_id_to_connector.values():
-        for credential_association in connector.credentials:
-            if (
-                connector.id,
-                credential_association.credential_id,
-            ) not in connector_credential_pair_to_index_attempts:
-                indexing_statuses.append(
-                    ConnectorIndexingStatus(
-                        connector=ConnectorSnapshot.from_connector_db_model(connector),
-                        public_doc=credential_association.credential.public_doc,
-                        owner=credential_association.credential.user.email
-                        if credential_association.credential.user
-                        else "",
-                        last_status=IndexingStatus.NOT_STARTED,
-                        last_success=None,
-                        docs_indexed=0,
-                    ),
-                )
-
     return indexing_statuses
@@ -13,7 +13,6 @@ from danswer.direct_qa.question_answer import get_json_line
 from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
 from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
@@ -5,7 +5,6 @@ from uuid import UUID
 
 from danswer.chunking.chunk import Chunker
 from danswer.chunking.chunk import DefaultChunker
-from danswer.chunking.models import EmbeddedIndexChunk
 from danswer.connectors.models import Document
 from danswer.datastores.interfaces import KeywordIndex
 from danswer.datastores.interfaces import VectorIndex
@@ -13,12 +12,13 @@ from danswer.datastores.qdrant.store import QdrantIndex
 from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.search.models import Embedder
 from danswer.search.semantic_search import DefaultEmbedder
+from danswer.utils.logging import setup_logger
 
+logger = setup_logger()
 
 
 class IndexingPipelineProtocol(Protocol):
-    def __call__(
-        self, documents: list[Document], user_id: UUID | None
-    ) -> list[EmbeddedIndexChunk]:
+    def __call__(self, documents: list[Document], user_id: UUID | None) -> int:
         ...
@@ -30,15 +30,19 @@ def _indexing_pipeline(
     keyword_index: KeywordIndex,
     documents: list[Document],
     user_id: UUID | None,
-) -> list[EmbeddedIndexChunk]:
+) -> int:
     # TODO: make entire indexing pipeline async to not block the entire process
     # when running on async endpoints
     chunks = list(chain(*[chunker.chunk(document) for document in documents]))
     # TODO keyword indexing can occur at same time as embedding
-    keyword_index.index(chunks, user_id)
+    net_doc_count_keyword = keyword_index.index(chunks, user_id)
     chunks_with_embeddings = embedder.embed(chunks)
-    vector_index.index(chunks_with_embeddings, user_id)
-    return chunks_with_embeddings
+    net_doc_count_vector = vector_index.index(chunks_with_embeddings, user_id)
+    if net_doc_count_vector != net_doc_count_keyword:
+        logger.exception(
+            "Document count change from keyword/vector indices don't align"
+        )
+    return max(net_doc_count_keyword, net_doc_count_vector)
 
 
 def build_indexing_pipeline(
@@ -24,17 +24,12 @@ PyPDF2==3.0.1
 pytest-playwright==0.3.2
 python-multipart==0.0.6
 qdrant-client==1.2.0
-requests==2.28.2
+requests==2.31.0
 rfc3986==1.5.0
 sentence-transformers==2.2.2
 slack-sdk==3.20.2
 SQLAlchemy[mypy]==2.0.12
 tensorflow==2.12.0
-transformers==4.27.3
-types-beautifulsoup4==4.12.0.3
-types-html5lib==1.1.11.13
-types-regex==2023.3.23.1
-types-requests==2.28.11.17
-types-urllib3==1.26.25.11
+transformers==4.30.1
 typesense==0.15.1
 uvicorn==0.21.1
@@ -1,10 +1,11 @@
-mypy==1.1.1
-mypy-extensions==1.0.0
 black==23.3.0
-reorder-python-imports==3.9.0
+mypy-extensions==1.0.0
+mypy==1.1.1
 pre-commit==3.2.2
+reorder-python-imports==3.9.0
 types-beautifulsoup4==4.12.0.3
 types-html5lib==1.1.11.13
+types-psycopg2==2.9.21.10
+types-regex==2023.3.23.1
 types-requests==2.28.11.17
 types-urllib3==1.26.25.11
-types-regex==2023.3.23.1
backend/scripts/reset_postgres.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import psycopg2
+from danswer.configs.app_configs import POSTGRES_DB
+from danswer.configs.app_configs import POSTGRES_HOST
+from danswer.configs.app_configs import POSTGRES_PASSWORD
+from danswer.configs.app_configs import POSTGRES_PORT
+from danswer.configs.app_configs import POSTGRES_USER
+
+
+def wipe_all_rows(database: str) -> None:
+    conn = psycopg2.connect(
+        dbname=database,
+        user=POSTGRES_USER,
+        password=POSTGRES_PASSWORD,
+        host=POSTGRES_HOST,
+        port=POSTGRES_PORT,
+    )
+
+    cur = conn.cursor()
+    cur.execute(
+        """
+        SELECT table_name
+        FROM information_schema.tables
+        WHERE table_schema = 'public'
+        AND table_type = 'BASE TABLE'
+        """
+    )
+
+    table_names = cur.fetchall()
+
+    for table_name in table_names:
+        if table_name[0] == "alembic_version":
+            continue
+        cur.execute(f'DELETE FROM "{table_name[0]}"')
+        print(f"Deleted all rows from table {table_name[0]}")
+    conn.commit()
+
+    cur.close()
+    conn.close()
+
+
+if __name__ == "__main__":
+    wipe_all_rows(POSTGRES_DB)