cleanup connector interface

Author: Weves
Date:   2023-05-11 20:08:42 -07:00
Commit: 560822a327
Parent: 0b610502e0
15 changed files with 101 additions and 172 deletions

View File

@ -1,11 +1,10 @@
import asyncio
import time
from typing import cast
from danswer.configs.constants import DocumentSource
from danswer.connectors.slack.config import get_pull_frequency
from danswer.connectors.slack.pull import SlackPullLoader
from danswer.connectors.web.batch import BatchWebLoader
from danswer.connectors.slack.pull import PeriodicSlackLoader
from danswer.connectors.web.pull import WebLoader
from danswer.db.index_attempt import fetch_index_attempts
from danswer.db.index_attempt import update_index_attempt
from danswer.db.models import IndexingStatus
@ -23,7 +22,11 @@ def _check_should_run(current_time: int, last_pull: int, pull_frequency: int) ->
return current_time - last_pull > pull_frequency * 60
async def run_update() -> None:
def run_update() -> None:
# NOTE: have to make this async due to fastapi users only supporting an async
# driver for postgres. In the future, we should figure out a way to
# make it work with sync drivers so we don't need to make all code touching
# the database async
logger.info("Running update")
# TODO (chris): implement a more generic way to run updates
# so we don't need to edit this file for future connectors
@ -37,7 +40,9 @@ async def run_update() -> None:
except ConfigNotFoundError:
pull_frequency = 0
if pull_frequency:
last_slack_pull_key = LAST_PULL_KEY_TEMPLATE.format(SlackPullLoader.__name__)
last_slack_pull_key = LAST_PULL_KEY_TEMPLATE.format(
PeriodicSlackLoader.__name__
)
try:
last_pull = cast(int, dynamic_config_store.load(last_slack_pull_key))
except ConfigNotFoundError:
@ -47,7 +52,7 @@ async def run_update() -> None:
current_time, last_pull, pull_frequency
):
logger.info(f"Running slack pull from {last_pull or 0} to {current_time}")
for doc_batch in SlackPullLoader().load(last_pull or 0, current_time):
for doc_batch in PeriodicSlackLoader().load(last_pull or 0, current_time):
indexing_pipeline(doc_batch)
dynamic_config_store.store(last_slack_pull_key, current_time)
@ -56,7 +61,7 @@ async def run_update() -> None:
# prevent race conditions across multiple background jobs. For now,
# this assumes we only ever run a single background job at a time
# TODO (chris): make this generic for all pull connectors (not just web)
not_started_index_attempts = await fetch_index_attempts(
not_started_index_attempts = fetch_index_attempts(
sources=[DocumentSource.WEB], statuses=[IndexingStatus.NOT_STARTED]
)
for not_started_index_attempt in not_started_index_attempts:
@ -67,7 +72,7 @@ async def run_update() -> None:
f"{not_started_index_attempt.input_type}, and connector_specific_config: "
f"{not_started_index_attempt.connector_specific_config}"
)
await update_index_attempt(
update_index_attempt(
index_attempt_id=not_started_index_attempt.id,
new_status=IndexingStatus.IN_PROGRESS,
)
@ -75,10 +80,10 @@ async def run_update() -> None:
error_msg = None
base_url = not_started_index_attempt.connector_specific_config["url"]
try:
# TODO (chris): make all connectors async + spawn processes to
# parallelize / take advantage of multiple cores + implement retries
# TODO (chris): spawn processes to parallelize / take advantage of
# multiple cores + implement retries
document_ids: list[str] = []
async for doc_batch in BatchWebLoader(base_url=base_url).async_load():
for doc_batch in WebLoader(base_url=base_url).load():
chunks = indexing_pipeline(doc_batch)
document_ids.extend([chunk.source_document.id for chunk in chunks])
except Exception as e:
@ -87,7 +92,7 @@ async def run_update() -> None:
)
error_msg = str(e)
await update_index_attempt(
update_index_attempt(
index_attempt_id=not_started_index_attempt.id,
new_status=IndexingStatus.FAILED if error_msg else IndexingStatus.SUCCESS,
document_ids=document_ids if not error_msg else None,
@ -95,11 +100,11 @@ async def run_update() -> None:
)
async def update_loop(delay: int = 60):
def update_loop(delay: int = 60):
while True:
start = time.time()
try:
await run_update()
run_update()
except Exception:
logger.exception("Failed to run update")
sleep_time = delay - (time.time() - start)
@ -108,4 +113,4 @@ async def update_loop(delay: int = 60):
if __name__ == "__main__":
asyncio.run(update_loop())
update_loop()
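
With the async driver requirement gone, the background job is plain synchronous Python. A minimal, self-contained sketch of the polling pattern used above, with a stand-in task since run_update depends on the rest of the app; the names and delay below are illustrative only:

import time


def do_update() -> None:
    # stand-in for run_update(); prints instead of indexing
    print("running update")


def update_loop(delay: int = 60) -> None:
    while True:
        start = time.time()
        try:
            do_update()
        except Exception:
            # the real code logs via logger.exception(...)
            print("failed to run update")
        # sleep only for whatever is left of the interval
        sleep_time = delay - (time.time() - start)
        if sleep_time > 0:
            time.sleep(sleep_time)


if __name__ == "__main__":
    update_loop(delay=5)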

View File

@ -4,9 +4,9 @@ from collections.abc import Generator
from danswer.configs.app_configs import GITHUB_ACCESS_TOKEN
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import PullLoader
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.type_aliases import BatchLoader
from danswer.utils.logging import setup_logger
from github import Github
@ -24,7 +24,7 @@ def get_pr_batches(pull_requests, batch_size):
yield batch
class BatchGithubLoader(BatchLoader):
class BatchGithubLoader(PullLoader):
def __init__(
self,
repo_owner: str,

View File

@ -7,9 +7,9 @@ from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import PullLoader
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.type_aliases import BatchLoader
from danswer.utils.logging import setup_logger
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials # type: ignore
@ -103,7 +103,11 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
return "\n".join(page.extract_text() for page in pdf_reader.pages)
class BatchGoogleDriveLoader(BatchLoader):
class BatchGoogleDriveLoader(PullLoader):
"""
Loads everything in a Google Drive account
"""
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,

View File

@ -0,0 +1,29 @@
import abc
from collections.abc import Generator
from typing import Any
from typing import List
from danswer.connectors.models import Document
SecondsSinceUnixEpoch = float
class PullLoader:
@abc.abstractmethod
def load(self) -> Generator[List[Document], None, None]:
raise NotImplementedError
class RangePullLoader:
@abc.abstractmethod
def load(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> Generator[List[Document], None, None]:
raise NotImplementedError
class PushLoader:
@abc.abstractmethod
def load(self, event: Any) -> Generator[List[Document], None, None]:
raise NotImplementedError
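
For illustration, a hypothetical connector written against the new PullLoader interface could look like the sketch below. The class and its input are invented for the example; the Document and Section construction mirrors how the web connector builds documents.

from collections.abc import Generator
from typing import List

from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import PullLoader
from danswer.connectors.models import Document
from danswer.connectors.models import Section


class InMemoryLoader(PullLoader):
    """Hypothetical example: yields pre-fetched pages as a single batch."""

    def __init__(self, pages: dict[str, str]) -> None:
        self.pages = pages

    def load(self) -> Generator[List[Document], None, None]:
        yield [
            Document(
                id=url,
                sections=[Section(link=url, text=text)],
                source=DocumentSource.WEB,
                semantic_identifier=url,
                metadata={},
            )
            for url, text in self.pages.items()
        ]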

View File

@ -7,10 +7,10 @@ from typing import cast
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import PullLoader
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.type_aliases import BatchLoader
def _process_batch_event(
@ -58,7 +58,9 @@ def _process_batch_event(
return None
class BatchSlackLoader(BatchLoader):
class BatchSlackLoader(PullLoader):
"""Loads from an unzipped slack workspace export"""
def __init__(
self, export_path_str: str, batch_size: int = INDEX_BATCH_SIZE
) -> None:

View File

@ -7,12 +7,12 @@ from typing import List
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import RangePullLoader
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.utils import get_client
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.type_aliases import PullLoader
from danswer.connectors.type_aliases import SecondsSinceUnixEpoch
from danswer.utils.logging import setup_logger
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
@ -200,7 +200,7 @@ def get_all_docs(
return docs
class SlackPullLoader(PullLoader):
class PeriodicSlackLoader(RangePullLoader):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.client = get_client()
self.batch_size = batch_size
@ -208,7 +208,6 @@ class SlackPullLoader(PullLoader):
def load(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> Generator[List[Document], None, None]:
# TODO: make this respect start and end
all_docs = get_all_docs(client=self.client, oldest=str(start), latest=str(end))
for i in range(0, len(all_docs), self.batch_size):
yield all_docs[i : i + self.batch_size]
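
A hypothetical caller of the renamed loader, assuming Slack credentials are already configured for get_client(); the time window and batch size are illustrative:

import time

from danswer.connectors.slack.pull import PeriodicSlackLoader

# pull the last 24 hours of messages in batches
end = time.time()
start = end - 24 * 60 * 60
for doc_batch in PeriodicSlackLoader(batch_size=32).load(start, end):
    print(f"got a batch of {len(doc_batch)} documents")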

View File

@ -1,43 +0,0 @@
import abc
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from typing import Any
from typing import List
from typing import Optional
from danswer.connectors.models import Document
ConnectorConfig = dict[str, Any]
# takes in the raw representation of a document from a source and returns a
# Document object
ProcessDocumentFunc = Callable[..., Document]
BuildListenerFunc = Callable[[ConnectorConfig], ProcessDocumentFunc]
# TODO (chris) refactor definition of a connector to match `InputType`
# + make them all async-based
class BatchLoader:
@abc.abstractmethod
def load(self) -> Generator[List[Document], None, None]:
raise NotImplementedError
SecondsSinceUnixEpoch = int
class PullLoader:
@abc.abstractmethod
def load(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> Generator[List[Document], None, None]:
raise NotImplementedError
# Fetches raw representations from a specific source for the specified time
# range. Is used when the source does not support subscriptions to some sort
# of event stream
# TODO: use Protocol instead of Callable
TimeRangeBasedLoad = Callable[[datetime, datetime], list[Any]]

View File

@ -1,4 +1,3 @@
from collections.abc import AsyncGenerator
from collections.abc import Generator
from typing import Any
from typing import cast
@ -8,11 +7,10 @@ from urllib.parse import urlparse
from bs4 import BeautifulSoup
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import PullLoader
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.type_aliases import BatchLoader
from danswer.utils.logging import setup_logger
from playwright.async_api import async_playwright
from playwright.sync_api import sync_playwright
logger = setup_logger()
@ -47,7 +45,7 @@ def get_internal_links(
return internal_links
class BatchWebLoader(BatchLoader):
class WebLoader(PullLoader):
def __init__(
self,
base_url: str,
@ -56,75 +54,6 @@ class BatchWebLoader(BatchLoader):
self.base_url = base_url
self.batch_size = batch_size
async def async_load(self) -> AsyncGenerator[list[Document], None]:
"""NOTE: TEMPORARY UNTIL ALL COMPONENTS ARE CONVERTED TO ASYNC
At that point, this will take over from the regular `load` func.
"""
visited_links: set[str] = set()
to_visit: list[str] = [self.base_url]
doc_batch: list[Document] = []
async with async_playwright() as playwright:
browser = await playwright.chromium.launch(headless=True)
context = await browser.new_context()
while to_visit:
current_url = to_visit.pop()
if current_url in visited_links:
continue
visited_links.add(current_url)
try:
page = await context.new_page()
await page.goto(current_url)
content = await page.content()
soup = BeautifulSoup(content, "html.parser")
title_tag = soup.find("title")
title = None
if title_tag and title_tag.text:
title = title_tag.text
# Heuristics based cleaning
for undesired_tag in ["nav", "header", "footer", "meta"]:
[tag.extract() for tag in soup.find_all(undesired_tag)]
for undesired_div in ["sidebar", "header", "footer"]:
[
tag.extract()
for tag in soup.find_all("div", {"class": undesired_div})
]
page_text = soup.get_text(TAG_SEPARATOR)
doc_batch.append(
Document(
id=current_url,
sections=[Section(link=current_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=title,
metadata={},
)
)
internal_links = get_internal_links(
self.base_url, current_url, soup
)
for link in internal_links:
if link not in visited_links:
to_visit.append(link)
await page.close()
except Exception as e:
logger.error(f"Failed to fetch '{current_url}': {e}")
continue
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load(self) -> Generator[list[Document], None, None]:
"""Traverses through all pages found on the website
and converts them into documents"""
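
With the async path removed, a hypothetical caller simply iterates the synchronous generator; the URL below is illustrative, and Playwright with a Chromium install is assumed:

from danswer.connectors.web.pull import WebLoader

for doc_batch in WebLoader(base_url="https://docs.example.com").load():
    for doc in doc_batch:
        print(doc.id)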

View File

@ -1,9 +1,13 @@
import os
from collections.abc import AsyncGenerator
from collections.abc import Generator
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
ASYNC_DB_API = "asyncpg"
@ -28,6 +32,11 @@ def build_connection_string(
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
def build_engine() -> Engine:
connection_string = build_connection_string()
return create_engine(connection_string)
def build_async_engine() -> AsyncEngine:
connection_string = build_connection_string()
return create_async_engine(connection_string)
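
A short, illustrative sketch of querying through the new synchronous engine, using the same session options as the index attempt helpers below:

from danswer.db.engine import build_engine
from danswer.db.models import IndexAttempt
from sqlalchemy import select
from sqlalchemy.orm import Session

# illustrative query against the synchronous engine
with Session(build_engine(), future=True, expire_on_commit=False) as session:
    attempts = list(session.scalars(select(IndexAttempt)).all())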

View File

@ -1,39 +1,37 @@
from danswer.configs.constants import DocumentSource
from danswer.db.engine import build_async_engine
from danswer.db.engine import build_engine
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.utils.logging import setup_logger
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
logger = setup_logger()
async def insert_index_attempt(index_attempt: IndexAttempt) -> None:
def insert_index_attempt(index_attempt: IndexAttempt) -> None:
logger.info(f"Inserting {index_attempt}")
async with AsyncSession(build_async_engine()) as asession:
asession.add(index_attempt)
await asession.commit()
with Session(build_engine()) as session:
session.add(index_attempt)
session.commit()
async def fetch_index_attempts(
def fetch_index_attempts(
*,
sources: list[DocumentSource] | None = None,
statuses: list[IndexingStatus] | None = None,
) -> list[IndexAttempt]:
async with AsyncSession(
build_async_engine(), future=True, expire_on_commit=False
) as asession:
with Session(build_engine(), future=True, expire_on_commit=False) as session:
stmt = select(IndexAttempt)
if sources:
stmt = stmt.where(IndexAttempt.source.in_(sources))
if statuses:
stmt = stmt.where(IndexAttempt.status.in_(statuses))
results = await asession.scalars(stmt)
results = session.scalars(stmt)
return list(results.all())
async def update_index_attempt(
def update_index_attempt(
*,
index_attempt_id: int,
new_status: IndexingStatus,
@ -41,15 +39,13 @@ async def update_index_attempt(
error_msg: str | None = None,
) -> bool:
"""Returns `True` if successfully updated, `False` if cannot find matching ID"""
async with AsyncSession(
build_async_engine(), future=True, expire_on_commit=False
) as asession:
with Session(build_engine(), future=True, expire_on_commit=False) as session:
stmt = select(IndexAttempt).where(IndexAttempt.id == index_attempt_id)
result = await asession.scalar(stmt)
result = session.scalar(stmt)
if result:
result.status = new_status
result.document_ids = document_ids
result.error_msg = error_msg
await asession.commit()
session.commit()
return True
return False
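
Since these helpers are now synchronous, callers no longer need an event loop. A hypothetical caller that mirrors the background job above:

from danswer.configs.constants import DocumentSource
from danswer.db.index_attempt import fetch_index_attempts
from danswer.db.index_attempt import update_index_attempt
from danswer.db.models import IndexingStatus

# illustrative: claim every queued web index attempt
for attempt in fetch_index_attempts(
    sources=[DocumentSource.WEB], statuses=[IndexingStatus.NOT_STARTED]
):
    update_index_attempt(
        index_attempt_id=attempt.id,
        new_status=IndexingStatus.IN_PROGRESS,
    )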

View File

@ -39,14 +39,14 @@ class WebIndexAttemptRequest(BaseModel):
@router.post("/connectors/web/index-attempt", status_code=201)
async def index_website(web_index_attempt_request: WebIndexAttemptRequest):
def index_website(web_index_attempt_request: WebIndexAttemptRequest) -> None:
index_request = IndexAttempt(
source=DocumentSource.WEB,
input_type=InputType.PULL,
connector_specific_config={"url": web_index_attempt_request.url},
status=IndexingStatus.NOT_STARTED,
)
await insert_index_attempt(index_request)
insert_index_attempt(index_request)
class IndexAttemptSnapshot(BaseModel):
@ -62,8 +62,8 @@ class ListWebsiteIndexAttemptsResponse(BaseModel):
@router.get("/connectors/web/index-attempt")
async def list_website_index_attempts() -> ListWebsiteIndexAttemptsResponse:
index_attempts = await fetch_index_attempts(sources=[DocumentSource.WEB])
def list_website_index_attempts() -> ListWebsiteIndexAttemptsResponse:
index_attempts = fetch_index_attempts(sources=[DocumentSource.WEB])
return ListWebsiteIndexAttemptsResponse(
index_attempts=[
IndexAttemptSnapshot(

View File

@ -1,7 +1,7 @@
import logging
def setup_logger(name=__name__, log_level=logging.INFO):
def setup_logger(name: str = __name__, log_level: int = logging.INFO):
logger = logging.getLogger(name)
# If the logger already has handlers, assume it was already configured and return it.

View File

@ -1,4 +0,0 @@
[mypy]
mypy_path = .
explicit-package-bases = True
no-site-packages = True

View File

@ -7,9 +7,9 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.connectors.github.batch import BatchGithubLoader
from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
from danswer.connectors.interfaces import PullLoader
from danswer.connectors.slack.batch import BatchSlackLoader
from danswer.connectors.type_aliases import BatchLoader
from danswer.connectors.web.batch import BatchWebLoader
from danswer.connectors.web.pull import WebLoader
from danswer.datastores.interfaces import Datastore
from danswer.datastores.qdrant.indexing import recreate_collection
from danswer.datastores.qdrant.store import QdrantDatastore
@ -22,7 +22,7 @@ logger = setup_logger()
def load_batch(
doc_loader: BatchLoader,
doc_loader: PullLoader,
chunker: Chunker,
embedder: Embedder,
datastore: Datastore,
@ -62,7 +62,7 @@ def load_slack_batch(file_path: str, qdrant_collection: str):
def load_web_batch(url: str, qdrant_collection: str):
logger.info("Loading documents from web.")
load_batch(
BatchWebLoader(base_url=url, batch_size=INDEX_BATCH_SIZE),
WebLoader(base_url=url, batch_size=INDEX_BATCH_SIZE),
DefaultChunker(),
DefaultEmbedder(),
QdrantDatastore(collection=qdrant_collection),

View File

@ -1,2 +1,5 @@
[mypy]
plugins = sqlalchemy.ext.mypy.plugin
plugins = sqlalchemy.ext.mypy.plugin
mypy_path = .
explicit_package_bases = True
disallow_untyped_defs = True