cleanup connector interface

parent 0b610502e0
commit 560822a327
backend/danswer/background/update.py
@@ -1,11 +1,10 @@
-import asyncio
 import time
 from typing import cast

 from danswer.configs.constants import DocumentSource
 from danswer.connectors.slack.config import get_pull_frequency
-from danswer.connectors.slack.pull import SlackPullLoader
-from danswer.connectors.web.batch import BatchWebLoader
+from danswer.connectors.slack.pull import PeriodicSlackLoader
+from danswer.connectors.web.pull import WebLoader
 from danswer.db.index_attempt import fetch_index_attempts
 from danswer.db.index_attempt import update_index_attempt
 from danswer.db.models import IndexingStatus
@@ -23,7 +22,11 @@ def _check_should_run(current_time: int, last_pull: int, pull_frequency: int) ->
     return current_time - last_pull > pull_frequency * 60


-async def run_update() -> None:
+def run_update() -> None:
+    # NOTE: have to make this async due to fastapi users only supporting an async
+    # driver for postgres. In the future, we should figure out a way to
+    # make it work with sync drivers so we don't need to make all code touching
+    # the database async
     logger.info("Running update")
     # TODO (chris): implement a more generic way to run updates
     # so we don't need to edit this file for future connectors
@@ -37,7 +40,9 @@ async def run_update() -> None:
     except ConfigNotFoundError:
         pull_frequency = 0
     if pull_frequency:
-        last_slack_pull_key = LAST_PULL_KEY_TEMPLATE.format(SlackPullLoader.__name__)
+        last_slack_pull_key = LAST_PULL_KEY_TEMPLATE.format(
+            PeriodicSlackLoader.__name__
+        )
         try:
             last_pull = cast(int, dynamic_config_store.load(last_slack_pull_key))
         except ConfigNotFoundError:
@@ -47,7 +52,7 @@ async def run_update() -> None:
             current_time, last_pull, pull_frequency
         ):
             logger.info(f"Running slack pull from {last_pull or 0} to {current_time}")
-            for doc_batch in SlackPullLoader().load(last_pull or 0, current_time):
+            for doc_batch in PeriodicSlackLoader().load(last_pull or 0, current_time):
                 indexing_pipeline(doc_batch)
             dynamic_config_store.store(last_slack_pull_key, current_time)

@@ -56,7 +61,7 @@ async def run_update() -> None:
     # prevent race conditions across multiple background jobs. For now,
     # this assumes we only ever run a single background job at a time
     # TODO (chris): make this generic for all pull connectors (not just web)
-    not_started_index_attempts = await fetch_index_attempts(
+    not_started_index_attempts = fetch_index_attempts(
         sources=[DocumentSource.WEB], statuses=[IndexingStatus.NOT_STARTED]
     )
     for not_started_index_attempt in not_started_index_attempts:
@@ -67,7 +72,7 @@ async def run_update() -> None:
             f"{not_started_index_attempt.input_type}, and connector_specific_config: "
             f"{not_started_index_attempt.connector_specific_config}"
         )
-        await update_index_attempt(
+        update_index_attempt(
             index_attempt_id=not_started_index_attempt.id,
             new_status=IndexingStatus.IN_PROGRESS,
         )
@@ -75,10 +80,10 @@ async def run_update() -> None:
         error_msg = None
         base_url = not_started_index_attempt.connector_specific_config["url"]
         try:
-            # TODO (chris): make all connectors async + spawn processes to
-            # parallelize / take advantage of multiple cores + implement retries
+            # TODO (chris): spawn processes to parallelize / take advantage of
+            # multiple cores + implement retries
            document_ids: list[str] = []
-            async for doc_batch in BatchWebLoader(base_url=base_url).async_load():
+            for doc_batch in WebLoader(base_url=base_url).load():
                chunks = indexing_pipeline(doc_batch)
                document_ids.extend([chunk.source_document.id for chunk in chunks])
         except Exception as e:
@@ -87,7 +92,7 @@ async def run_update() -> None:
             )
             error_msg = str(e)

-        await update_index_attempt(
+        update_index_attempt(
             index_attempt_id=not_started_index_attempt.id,
             new_status=IndexingStatus.FAILED if error_msg else IndexingStatus.SUCCESS,
             document_ids=document_ids if not error_msg else None,
@@ -95,11 +100,11 @@ async def run_update() -> None:
         )


-async def update_loop(delay: int = 60):
+def update_loop(delay: int = 60):
     while True:
         start = time.time()
         try:
-            await run_update()
+            run_update()
         except Exception:
             logger.exception("Failed to run update")
         sleep_time = delay - (time.time() - start)
@@ -108,4 +113,4 @@ async def update_loop(delay: int = 60):


 if __name__ == "__main__":
-    asyncio.run(update_loop())
+    update_loop()
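A note on the scheduling check shown as context above: _check_should_run (untouched by this commit) treats pull_frequency as minutes and both timestamps as seconds since the Unix epoch. A standalone sketch with made-up numbers follows; the "-> bool" annotation is implied by the body, since the hunk header truncates the signature.

    import time


    # Mirrors the context line in the hunk above: pull_frequency is in minutes,
    # the timestamps are in seconds since the epoch.
    def _check_should_run(current_time: int, last_pull: int, pull_frequency: int) -> bool:
        return current_time - last_pull > pull_frequency * 60


    now = int(time.time())
    print(_check_should_run(now, now - 900, 10))  # 15 min since last pull vs. 10 min frequency -> True
    print(_check_should_run(now, now - 300, 10))  # 5 min since last pull -> False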
backend/danswer/connectors/github/batch.py
@@ -4,9 +4,9 @@ from collections.abc import Generator
 from danswer.configs.app_configs import GITHUB_ACCESS_TOKEN
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
+from danswer.connectors.interfaces import PullLoader
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
-from danswer.connectors.type_aliases import BatchLoader
 from danswer.utils.logging import setup_logger
 from github import Github

@@ -24,7 +24,7 @@ def get_pr_batches(pull_requests, batch_size):
         yield batch


-class BatchGithubLoader(BatchLoader):
+class BatchGithubLoader(PullLoader):
     def __init__(
         self,
         repo_owner: str,
backend/danswer/connectors/google_drive/batch.py
@@ -7,9 +7,9 @@ from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
 from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
+from danswer.connectors.interfaces import PullLoader
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
-from danswer.connectors.type_aliases import BatchLoader
 from danswer.utils.logging import setup_logger
 from google.auth.transport.requests import Request  # type: ignore
 from google.oauth2.credentials import Credentials  # type: ignore
@@ -103,7 +103,11 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
     return "\n".join(page.extract_text() for page in pdf_reader.pages)


-class BatchGoogleDriveLoader(BatchLoader):
+class BatchGoogleDriveLoader(PullLoader):
+    """
+    Loads everything in a Google Drive account
+    """
+
     def __init__(
         self,
         batch_size: int = INDEX_BATCH_SIZE,
backend/danswer/connectors/interfaces.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import abc
+from collections.abc import Generator
+from typing import Any
+from typing import List
+
+from danswer.connectors.models import Document
+
+
+SecondsSinceUnixEpoch = float
+
+
+class PullLoader:
+    @abc.abstractmethod
+    def load(self) -> Generator[List[Document], None, None]:
+        raise NotImplementedError
+
+
+class RangePullLoader:
+    @abc.abstractmethod
+    def load(
+        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
+    ) -> Generator[List[Document], None, None]:
+        raise NotImplementedError
+
+
+class PushLoader:
+    @abc.abstractmethod
+    def load(self, event: Any) -> Generator[List[Document], None, None]:
+        raise NotImplementedError
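The new module replaces the deleted type_aliases definitions: PullLoader.load() now means "load everything", RangePullLoader.load(start, end) covers a time window, and PushLoader.load(event) accepts a pushed event. A minimal sketch of a connector targeting the new interface (not part of this commit: ChangelogLoader and its records are invented, and the Document/Section fields mirror the web connector later in this diff):

    from collections.abc import Generator
    from typing import List

    from danswer.configs.constants import DocumentSource
    from danswer.connectors.interfaces import RangePullLoader
    from danswer.connectors.interfaces import SecondsSinceUnixEpoch
    from danswer.connectors.models import Document
    from danswer.connectors.models import Section


    class ChangelogLoader(RangePullLoader):
        """Toy connector over an in-memory list of (timestamp, url, text) records."""

        def __init__(self, records: list[tuple[float, str, str]], batch_size: int = 2) -> None:
            self.records = records
            self.batch_size = batch_size

        def load(
            self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
        ) -> Generator[List[Document], None, None]:
            # Keep only records that fall inside the requested time window.
            docs = [
                Document(
                    id=url,
                    sections=[Section(link=url, text=text)],
                    source=DocumentSource.WEB,  # assumption: reuse an existing source value
                    semantic_identifier=url,
                    metadata={},
                )
                for ts, url, text in self.records
                if start <= ts <= end
            ]
            # Yield fixed-size batches, the same shape PeriodicSlackLoader produces.
            for i in range(0, len(docs), self.batch_size):
                yield docs[i : i + self.batch_size]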
backend/danswer/connectors/slack/batch.py
@@ -7,10 +7,10 @@ from typing import cast

 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
+from danswer.connectors.interfaces import PullLoader
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.connectors.slack.utils import get_message_link
-from danswer.connectors.type_aliases import BatchLoader


 def _process_batch_event(
@@ -58,7 +58,9 @@ def _process_batch_event(
     return None


-class BatchSlackLoader(BatchLoader):
+class BatchSlackLoader(PullLoader):
+    """Loads from an unzipped slack workspace export"""
+
     def __init__(
         self, export_path_str: str, batch_size: int = INDEX_BATCH_SIZE
     ) -> None:
backend/danswer/connectors/slack/pull.py
@@ -7,12 +7,12 @@ from typing import List

 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
+from danswer.connectors.interfaces import RangePullLoader
+from danswer.connectors.interfaces import SecondsSinceUnixEpoch
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.connectors.slack.utils import get_client
 from danswer.connectors.slack.utils import get_message_link
-from danswer.connectors.type_aliases import PullLoader
-from danswer.connectors.type_aliases import SecondsSinceUnixEpoch
 from danswer.utils.logging import setup_logger
 from slack_sdk import WebClient
 from slack_sdk.errors import SlackApiError
@@ -200,7 +200,7 @@ def get_all_docs(
     return docs


-class SlackPullLoader(PullLoader):
+class PeriodicSlackLoader(RangePullLoader):
     def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
         self.client = get_client()
         self.batch_size = batch_size
@@ -208,7 +208,6 @@ class SlackPullLoader(PullLoader):
     def load(
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> Generator[List[Document], None, None]:
-        # TODO: make this respect start and end
         all_docs = get_all_docs(client=self.client, oldest=str(start), latest=str(end))
         for i in range(0, len(all_docs), self.batch_size):
             yield all_docs[i : i + self.batch_size]
backend/danswer/connectors/type_aliases.py (deleted)
@@ -1,43 +0,0 @@
-import abc
-from collections.abc import Callable
-from collections.abc import Generator
-from datetime import datetime
-from typing import Any
-from typing import List
-from typing import Optional
-
-from danswer.connectors.models import Document
-
-
-ConnectorConfig = dict[str, Any]
-
-# takes in the raw representation of a document from a source and returns a
-# Document object
-ProcessDocumentFunc = Callable[..., Document]
-BuildListenerFunc = Callable[[ConnectorConfig], ProcessDocumentFunc]
-
-
-# TODO (chris) refactor definition of a connector to match `InputType`
-# + make them all async-based
-class BatchLoader:
-    @abc.abstractmethod
-    def load(self) -> Generator[List[Document], None, None]:
-        raise NotImplementedError
-
-
-SecondsSinceUnixEpoch = int
-
-
-class PullLoader:
-    @abc.abstractmethod
-    def load(
-        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
-    ) -> Generator[List[Document], None, None]:
-        raise NotImplementedError
-
-
-# Fetches raw representations from a specific source for the specified time
-# range. Is used when the source does not support subscriptions to some sort
-# of event stream
-# TODO: use Protocol instead of Callable
-TimeRangeBasedLoad = Callable[[datetime, datetime], list[Any]]
backend/danswer/connectors/web/pull.py (renamed from web/batch.py)
@@ -1,4 +1,3 @@
-from collections.abc import AsyncGenerator
 from collections.abc import Generator
 from typing import Any
 from typing import cast
@@ -8,11 +7,10 @@ from urllib.parse import urlparse
 from bs4 import BeautifulSoup
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
+from danswer.connectors.interfaces import PullLoader
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
-from danswer.connectors.type_aliases import BatchLoader
 from danswer.utils.logging import setup_logger
-from playwright.async_api import async_playwright
 from playwright.sync_api import sync_playwright

 logger = setup_logger()
@@ -47,7 +45,7 @@ def get_internal_links(
     return internal_links


-class BatchWebLoader(BatchLoader):
+class WebLoader(PullLoader):
     def __init__(
         self,
         base_url: str,
@@ -56,75 +54,6 @@ class BatchWebLoader(BatchLoader):
         self.base_url = base_url
         self.batch_size = batch_size

-    async def async_load(self) -> AsyncGenerator[list[Document], None]:
-        """NOTE: TEMPORARY UNTIL ALL COMPONENTS ARE CONVERTED TO ASYNC
-        At that point, this will take over from the regular `load` func.
-        """
-        visited_links: set[str] = set()
-        to_visit: list[str] = [self.base_url]
-        doc_batch: list[Document] = []
-
-        async with async_playwright() as playwright:
-            browser = await playwright.chromium.launch(headless=True)
-            context = await browser.new_context()
-
-            while to_visit:
-                current_url = to_visit.pop()
-                if current_url in visited_links:
-                    continue
-                visited_links.add(current_url)
-
-                try:
-                    page = await context.new_page()
-                    await page.goto(current_url)
-                    content = await page.content()
-                    soup = BeautifulSoup(content, "html.parser")
-
-                    title_tag = soup.find("title")
-                    title = None
-                    if title_tag and title_tag.text:
-                        title = title_tag.text
-
-                    # Heuristics based cleaning
-                    for undesired_tag in ["nav", "header", "footer", "meta"]:
-                        [tag.extract() for tag in soup.find_all(undesired_tag)]
-                    for undesired_div in ["sidebar", "header", "footer"]:
-                        [
-                            tag.extract()
-                            for tag in soup.find_all("div", {"class": undesired_div})
-                        ]
-
-                    page_text = soup.get_text(TAG_SEPARATOR)
-
-                    doc_batch.append(
-                        Document(
-                            id=current_url,
-                            sections=[Section(link=current_url, text=page_text)],
-                            source=DocumentSource.WEB,
-                            semantic_identifier=title,
-                            metadata={},
-                        )
-                    )
-
-                    internal_links = get_internal_links(
-                        self.base_url, current_url, soup
-                    )
-                    for link in internal_links:
-                        if link not in visited_links:
-                            to_visit.append(link)
-
-                    await page.close()
-                except Exception as e:
-                    logger.error(f"Failed to fetch '{current_url}': {e}")
-                    continue
-
-                if len(doc_batch) >= self.batch_size:
-                    yield doc_batch
-                    doc_batch = []
-
-        if doc_batch:
-            yield doc_batch
-
     def load(self) -> Generator[list[Document], None, None]:
         """Traverses through all pages found on the website
         and converts them into documents"""
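With the Playwright-based async_load gone, the renamed WebLoader is driven like any other PullLoader; the call pattern below matches load_web_batch and run_update elsewhere in this diff, while the crawl helper and URL are placeholders, not code from the commit.

    from danswer.connectors.web.pull import WebLoader


    def crawl(base_url: str = "https://docs.example.com") -> int:
        """Count the documents produced by a full crawl of base_url."""
        total = 0
        # load() yields lists of Document objects, one batch at a time.
        for doc_batch in WebLoader(base_url=base_url).load():
            total += len(doc_batch)
        return total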
backend/danswer/db/engine.py
@@ -1,9 +1,13 @@
 import os
 from collections.abc import AsyncGenerator
+from collections.abc import Generator

+from sqlalchemy.engine import create_engine
+from sqlalchemy.engine import Engine
 from sqlalchemy.ext.asyncio import AsyncEngine
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.orm import Session


 ASYNC_DB_API = "asyncpg"
@@ -28,6 +32,11 @@ def build_connection_string(
     return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"


+def build_engine() -> Engine:
+    connection_string = build_connection_string()
+    return create_engine(connection_string)
+
+
 def build_async_engine() -> AsyncEngine:
     connection_string = build_connection_string()
     return create_async_engine(connection_string)
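build_engine gives callers a plain synchronous engine alongside the existing async one. A short usage sketch (not from this commit; count_index_attempts is a made-up helper, but the session options and query mirror danswer.db.index_attempt below):

    from danswer.db.engine import build_engine
    from danswer.db.models import IndexAttempt
    from sqlalchemy import select
    from sqlalchemy.orm import Session


    def count_index_attempts() -> int:
        # Plain (non-async) Session on top of the new build_engine() helper.
        with Session(build_engine(), future=True, expire_on_commit=False) as session:
            return len(session.scalars(select(IndexAttempt)).all())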
backend/danswer/db/index_attempt.py
@@ -1,39 +1,37 @@
 from danswer.configs.constants import DocumentSource
-from danswer.db.engine import build_async_engine
+from danswer.db.engine import build_engine
 from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
 from danswer.utils.logging import setup_logger
 from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import Session

 logger = setup_logger()


-async def insert_index_attempt(index_attempt: IndexAttempt) -> None:
+def insert_index_attempt(index_attempt: IndexAttempt) -> None:
     logger.info(f"Inserting {index_attempt}")
-    async with AsyncSession(build_async_engine()) as asession:
-        asession.add(index_attempt)
-        await asession.commit()
+    with Session(build_engine()) as session:
+        session.add(index_attempt)
+        session.commit()


-async def fetch_index_attempts(
+def fetch_index_attempts(
     *,
     sources: list[DocumentSource] | None = None,
     statuses: list[IndexingStatus] | None = None,
 ) -> list[IndexAttempt]:
-    async with AsyncSession(
-        build_async_engine(), future=True, expire_on_commit=False
-    ) as asession:
+    with Session(build_engine(), future=True, expire_on_commit=False) as session:
         stmt = select(IndexAttempt)
         if sources:
             stmt = stmt.where(IndexAttempt.source.in_(sources))
         if statuses:
             stmt = stmt.where(IndexAttempt.status.in_(statuses))
-        results = await asession.scalars(stmt)
+        results = session.scalars(stmt)
         return list(results.all())


-async def update_index_attempt(
+def update_index_attempt(
     *,
     index_attempt_id: int,
     new_status: IndexingStatus,
@@ -41,15 +39,13 @@ async def update_index_attempt(
     error_msg: str | None = None,
 ) -> bool:
     """Returns `True` if successfully updated, `False` if cannot find matching ID"""
-    async with AsyncSession(
-        build_async_engine(), future=True, expire_on_commit=False
-    ) as asession:
+    with Session(build_engine(), future=True, expire_on_commit=False) as session:
         stmt = select(IndexAttempt).where(IndexAttempt.id == index_attempt_id)
-        result = await asession.scalar(stmt)
+        result = session.scalar(stmt)
         if result:
             result.status = new_status
             result.document_ids = document_ids
             result.error_msg = error_msg
-            await asession.commit()
+            session.commit()
             return True
         return False
(web connector API routes; file path not shown in this view)
@@ -39,14 +39,14 @@ class WebIndexAttemptRequest(BaseModel):


 @router.post("/connectors/web/index-attempt", status_code=201)
-async def index_website(web_index_attempt_request: WebIndexAttemptRequest):
+def index_website(web_index_attempt_request: WebIndexAttemptRequest) -> None:
     index_request = IndexAttempt(
         source=DocumentSource.WEB,
         input_type=InputType.PULL,
         connector_specific_config={"url": web_index_attempt_request.url},
         status=IndexingStatus.NOT_STARTED,
     )
-    await insert_index_attempt(index_request)
+    insert_index_attempt(index_request)


 class IndexAttemptSnapshot(BaseModel):
@@ -62,8 +62,8 @@ class ListWebsiteIndexAttemptsResponse(BaseModel):


 @router.get("/connectors/web/index-attempt")
-async def list_website_index_attempts() -> ListWebsiteIndexAttemptsResponse:
-    index_attempts = await fetch_index_attempts(sources=[DocumentSource.WEB])
+def list_website_index_attempts() -> ListWebsiteIndexAttemptsResponse:
+    index_attempts = fetch_index_attempts(sources=[DocumentSource.WEB])
     return ListWebsiteIndexAttemptsResponse(
         index_attempts=[
             IndexAttemptSnapshot(
backend/danswer/utils/logging.py
@@ -1,7 +1,7 @@
 import logging


-def setup_logger(name=__name__, log_level=logging.INFO):
+def setup_logger(name: str = __name__, log_level: int = logging.INFO):
     logger = logging.getLogger(name)

     # If the logger already has handlers, assume it was already configured and return it.
(mypy config, deleted; file path not shown in this view)
@@ -1,4 +0,0 @@
-[mypy]
-mypy_path = .
-explicit-package-bases = True
-no-site-packages = True
(bulk loading script; file path not shown in this view)
@@ -7,9 +7,9 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
 from danswer.connectors.github.batch import BatchGithubLoader
 from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
+from danswer.connectors.interfaces import PullLoader
 from danswer.connectors.slack.batch import BatchSlackLoader
-from danswer.connectors.type_aliases import BatchLoader
-from danswer.connectors.web.batch import BatchWebLoader
+from danswer.connectors.web.pull import WebLoader
 from danswer.datastores.interfaces import Datastore
 from danswer.datastores.qdrant.indexing import recreate_collection
 from danswer.datastores.qdrant.store import QdrantDatastore
@@ -22,7 +22,7 @@ logger = setup_logger()


 def load_batch(
-    doc_loader: BatchLoader,
+    doc_loader: PullLoader,
     chunker: Chunker,
     embedder: Embedder,
     datastore: Datastore,
@@ -62,7 +62,7 @@ def load_slack_batch(file_path: str, qdrant_collection: str):
 def load_web_batch(url: str, qdrant_collection: str):
     logger.info("Loading documents from web.")
     load_batch(
-        BatchWebLoader(base_url=url, batch_size=INDEX_BATCH_SIZE),
+        WebLoader(base_url=url, batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),
(mypy config; file path not shown in this view)
@@ -1,2 +1,5 @@
 [mypy]
 plugins = sqlalchemy.ext.mypy.plugin
+mypy_path = .
+explicit_package_bases = True
+disallow_untyped_defs = True