mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-19 20:24:32 +02:00)

DAN-93 Standardize Connectors (#70)

This commit renames the connector interfaces (PullLoader -> LoadConnector, RangePullLoader -> PollConnector, PushLoader -> EventConnector), merges the per-source batch and pull loaders into single *Connector classes, and replaces InputType.PULL with InputType.POLL alongside LOAD_STATE and EVENT.
@@ -2,12 +2,10 @@ import time
 from typing import cast
 
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.factory import build_connector
-from danswer.connectors.factory import build_pull_connector
+from danswer.connectors.factory import build_load_connector
 from danswer.connectors.models import InputType
 from danswer.connectors.slack.config import get_pull_frequency
-from danswer.connectors.slack.pull import PeriodicSlackLoader
-from danswer.connectors.web.pull import WebLoader
+from danswer.connectors.slack.connector import SlackConnector
 from danswer.db.index_attempt import fetch_index_attempts
 from danswer.db.index_attempt import insert_index_attempt
 from danswer.db.index_attempt import update_index_attempt
@@ -20,7 +18,7 @@ from danswer.utils.logging import setup_logger
 
 logger = setup_logger()
 
-LAST_PULL_KEY_TEMPLATE = "last_pull_{}"
+LAST_POLL_KEY_TEMPLATE = "last_poll_{}"
 
 
 def _check_should_run(current_time: int, last_pull: int, pull_frequency: int) -> bool:
@@ -43,9 +41,7 @@ def run_update() -> None:
     except ConfigNotFoundError:
         pull_frequency = 0
     if pull_frequency:
-        last_slack_pull_key = LAST_PULL_KEY_TEMPLATE.format(
-            PeriodicSlackLoader.__name__
-        )
+        last_slack_pull_key = LAST_POLL_KEY_TEMPLATE.format(SlackConnector.__name__)
        try:
             last_pull = cast(int, dynamic_config_store.load(last_slack_pull_key))
         except ConfigNotFoundError:
@@ -61,7 +57,7 @@ def run_update() -> None:
         insert_index_attempt(
             IndexAttempt(
                 source=DocumentSource.SLACK,
-                input_type=InputType.PULL,
+                input_type=InputType.POLL,
                 status=IndexingStatus.NOT_STARTED,
                 connector_specific_config={},
             )
@@ -75,7 +71,7 @@ def run_update() -> None:
     # prevent race conditions across multiple background jobs. For now,
     # this assumes we only ever run a single background job at a time
     not_started_index_attempts = fetch_index_attempts(
-        input_types=[InputType.PULL], statuses=[IndexingStatus.NOT_STARTED]
+        input_types=[InputType.LOAD_STATE], statuses=[IndexingStatus.NOT_STARTED]
     )
     for not_started_index_attempt in not_started_index_attempts:
         logger.info(
@@ -94,13 +90,13 @@ def run_update() -> None:
         try:
             # TODO (chris): spawn processes to parallelize / take advantage of
             # multiple cores + implement retries
-            connector = build_pull_connector(
+            connector = build_load_connector(
                 source=not_started_index_attempt.source,
                 connector_specific_config=not_started_index_attempt.connector_specific_config,
             )
 
             document_ids: list[str] = []
-            for doc_batch in connector.load():
+            for doc_batch in connector.load_from_state():
                 indexing_pipeline(doc_batch)
                 document_ids.extend([doc.id for doc in doc_batch])
         except Exception as e:
@@ -1,6 +1,7 @@
 import inspect
 from dataclasses import dataclass
 from typing import Any
+from typing import cast
 
 from danswer.connectors.models import Document
 
@@ -34,10 +35,12 @@ class InferenceChunk(BaseChunk):
 
     @classmethod
     def from_dict(cls, init_dict: dict[str, Any]) -> "InferenceChunk":
-        return cls(
-            **{
-                k: v
-                for k, v in init_dict.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
+        init_kwargs = {
+            k: v for k, v in init_dict.items() if k in inspect.signature(cls).parameters
+        }
+        if "source_links" in init_kwargs:
+            init_kwargs["source_links"] = {
+                int(k): v
+                for k, v in cast(dict[str, str], init_kwargs["source_links"]).items()
+            }
+        return cls(**init_kwargs)
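The int() coercion above exists because JSON serialization stringifies dict keys. A minimal standalone sketch of the round-trip problem (not part of the commit; the URL is made up):

import json

# JSON object keys are always strings, so source_links keyed by ints
# come back keyed by "0", "1", ... after a serialize/deserialize cycle.
source_links = {0: "https://example.com/doc#section-0"}
restored = json.loads(json.dumps(source_links))
assert list(restored) == ["0"]  # keys were stringified
fixed = {int(k): v for k, v in restored.items()}
assert fixed[0] == "https://example.com/doc#section-0"
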
@@ -3,15 +3,16 @@ from collections.abc import Generator
 from typing import Any
 
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.github.batch import BatchGithubLoader
-from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
-from danswer.connectors.interfaces import PullLoader
-from danswer.connectors.interfaces import RangePullLoader
+from danswer.connectors.github.connector import GithubConnector
+from danswer.connectors.google_drive.connector import GoogleDriveConnector
+from danswer.connectors.interfaces import BaseConnector
+from danswer.connectors.interfaces import EventConnector
+from danswer.connectors.interfaces import LoadConnector
+from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import InputType
-from danswer.connectors.slack.batch import BatchSlackLoader
-from danswer.connectors.slack.pull import PeriodicSlackLoader
-from danswer.connectors.web.pull import WebLoader
+from danswer.connectors.slack.connector import SlackConnector
+from danswer.connectors.web.connector import WebConnector
 
 _NUM_SECONDS_IN_DAY = 86400
 
@@ -24,45 +25,52 @@ def build_connector(
     source: DocumentSource,
     input_type: InputType,
     connector_specific_config: dict[str, Any],
-) -> PullLoader | RangePullLoader:
+) -> BaseConnector:
     if source == DocumentSource.SLACK:
-        if input_type == InputType.PULL:
-            return PeriodicSlackLoader(**connector_specific_config)
-        if input_type == InputType.LOAD_STATE:
-            return BatchSlackLoader(**connector_specific_config)
+        connector: BaseConnector = SlackConnector(**connector_specific_config)
     elif source == DocumentSource.GOOGLE_DRIVE:
-        if input_type == InputType.PULL:
-            return BatchGoogleDriveLoader(**connector_specific_config)
+        connector = GoogleDriveConnector(**connector_specific_config)
     elif source == DocumentSource.GITHUB:
-        if input_type == InputType.PULL:
-            return BatchGithubLoader(**connector_specific_config)
+        connector = GithubConnector(**connector_specific_config)
     elif source == DocumentSource.WEB:
-        if input_type == InputType.PULL:
-            return WebLoader(**connector_specific_config)
-
-    raise ConnectorMissingException(
-        f"Connector not found for source={source}, input_type={input_type}"
-    )
+        connector = WebConnector(**connector_specific_config)
+    else:
+        raise ConnectorMissingException(f"Connector not found for source={source}")
+
+    if any(
+        [
+            input_type == InputType.LOAD_STATE
+            and not isinstance(connector, LoadConnector),
+            input_type == InputType.POLL and not isinstance(connector, PollConnector),
+            input_type == InputType.EVENT and not isinstance(connector, EventConnector),
+        ]
+    ):
+        raise ConnectorMissingException(
+            f"Connector for source={source} does not accept input_type={input_type}"
+        )
+
+    return connector
 
 
-def build_pull_connector(
-    source: DocumentSource, connector_specific_config: dict[str, Any]
-) -> PullLoader:
-    connector = build_connector(source, InputType.PULL, connector_specific_config)
-    return (
-        _range_pull_to_pull(connector)
-        if isinstance(connector, RangePullLoader)
-        else connector
-    )
-
-
-def _range_pull_to_pull(range_pull_connector: RangePullLoader) -> PullLoader:
-    class _Connector(PullLoader):
+# TODO this is some jank, rework at some point
+def _poll_to_load_connector(range_pull_connector: PollConnector) -> LoadConnector:
+    class _Connector(LoadConnector):
         def __init__(self) -> None:
             self._connector = range_pull_connector
 
-        def load(self) -> Generator[list[Document], None, None]:
+        def load_from_state(self) -> Generator[list[Document], None, None]:
             # adding some buffer to make sure we get all documents
-            return self._connector.load(0, time.time() + _NUM_SECONDS_IN_DAY)
+            return self._connector.poll_source(0, time.time() + _NUM_SECONDS_IN_DAY)
 
     return _Connector()
+
+
+# TODO this is some jank, rework at some point
+def build_load_connector(
+    source: DocumentSource, connector_specific_config: dict[str, Any]
+) -> LoadConnector:
+    connector = build_connector(source, InputType.LOAD_STATE, connector_specific_config)
+    if isinstance(connector, PollConnector):
+        return _poll_to_load_connector(connector)
+    assert isinstance(connector, LoadConnector)
+    return connector
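For orientation, a minimal sketch of how the reworked factory is meant to be called (the base_url value is made up; the required argument comes from WebConnector.__init__ in this commit):

from danswer.configs.constants import DocumentSource
from danswer.connectors.factory import build_load_connector

# If the chosen source only implements PollConnector, build_load_connector
# transparently wraps it so load_from_state() polls from epoch 0 to now + 1 day.
connector = build_load_connector(
    source=DocumentSource.WEB,
    connector_specific_config={"base_url": "https://docs.example.com"},
)
for doc_batch in connector.load_from_state():
    print(f"indexed batch of {len(doc_batch)} documents")
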
@@ -4,7 +4,7 @@ from collections.abc import Generator
 from danswer.configs.app_configs import GITHUB_ACCESS_TOKEN
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import PullLoader
+from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger
@@ -29,7 +29,7 @@ def get_pr_batches(
         yield batch
 
 
-class BatchGithubLoader(PullLoader):
+class GithubConnector(LoadConnector):
     def __init__(
         self,
         repo_owner: str,
@@ -42,7 +42,7 @@ class BatchGithubLoader(PullLoader):
         self.batch_size = batch_size
         self.state_filter = state_filter
 
-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         repo = github_client.get_repo(f"{self.repo_owner}/{self.repo_name}")
         pull_requests = repo.get_pulls(state=self.state_filter)
         for pr_batch in get_pr_batches(pull_requests, self.batch_size):
@@ -5,7 +5,7 @@ from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.google_drive.connector_auth import get_drive_tokens
-from danswer.connectors.interfaces import PullLoader
+from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger
@@ -81,11 +81,7 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
     return "\n".join(page.extract_text() for page in pdf_reader.pages)
 
 
-class BatchGoogleDriveLoader(PullLoader):
-    """
-    Loads everything in a Google Drive account
-    """
-
+class GoogleDriveConnector(LoadConnector):
     def __init__(
         self,
         batch_size: int = INDEX_BATCH_SIZE,
@@ -98,7 +94,7 @@ class BatchGoogleDriveLoader(PullLoader):
         if not self.creds:
             raise PermissionError("Unable to access Google Drive.")
 
-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         service = discovery.build("drive", "v3", credentials=self.creds)
         for files_batch in get_file_batches(
             service, self.include_shared, self.batch_size
@@ -8,22 +8,29 @@ from danswer.connectors.models import Document
 SecondsSinceUnixEpoch = float
 
 
-# TODO (chris): rename from Loader -> Connector
-class PullLoader:
+class BaseConnector(abc.ABC):
+    # Reserved for future shared uses
+    pass
+
+
+# Large set update or reindex, generally pulling a complete state or from a savestate file
+class LoadConnector(BaseConnector):
     @abc.abstractmethod
-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         raise NotImplementedError
 
 
-class RangePullLoader:
+# Small set updates by time
+class PollConnector(BaseConnector):
     @abc.abstractmethod
-    def load(
+    def poll_source(
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> Generator[list[Document], None, None]:
         raise NotImplementedError
 
 
-class PushLoader:
+# Event driven
+class EventConnector(BaseConnector):
     @abc.abstractmethod
-    def load(self, event: Any) -> Generator[list[Document], None, None]:
+    def handle_event(self, event: Any) -> Generator[list[Document], None, None]:
         raise NotImplementedError
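A minimal sketch of a connector written against the new interfaces, reusing the names defined in this module (the in-memory storage and the "updated" metadata key are hypothetical, not from the commit):

class InMemoryConnector(LoadConnector, PollConnector):
    def __init__(self, docs: list[Document]) -> None:
        self.docs = docs

    def load_from_state(self) -> Generator[list[Document], None, None]:
        # full load: emit everything in one batch
        yield self.docs

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> Generator[list[Document], None, None]:
        # incremental update: emit only documents touched within [start, end)
        yield [
            doc
            for doc in self.docs
            if start <= doc.metadata.get("updated", 0) < end
        ]
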
@@ -26,9 +26,9 @@ def get_raw_document_text(document: Document) -> str:
 
 
 class InputType(str, Enum):
-    PULL = "pull"  # e.g. calling slack API to get all messages in the last hour
-    LOAD_STATE = "load_state"  # e.g. loading the state of a slack workspace from a file
-    EVENT = "event"  # e.g. registered an endpoint as a listener, and processing slack events
+    LOAD_STATE = "load_state"  # e.g. loading a current full state or a save state, such as from a file
+    POLL = "poll"  # e.g. calling an API to get all documents in the last hour
+    EVENT = "event"  # e.g. registered an endpoint as a listener, and processing connector events
 
 
 class ConnectorDescriptor(BaseModel):
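Because InputType subclasses str, members compare equal to their raw wire values, which is what lets API payloads pass plain strings. A quick standalone illustration:

from enum import Enum

class InputType(str, Enum):  # same shape as the enum above
    LOAD_STATE = "load_state"
    POLL = "poll"
    EVENT = "event"

assert InputType.LOAD_STATE == "load_state"  # str subclass compares to raw value
assert InputType("poll") is InputType.POLL   # parses straight from the wire format
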
@@ -1,99 +0,0 @@
-import json
-import os
-from collections.abc import Generator
-from pathlib import Path
-from typing import Any
-from typing import cast
-
-from danswer.configs.app_configs import INDEX_BATCH_SIZE
-from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import PullLoader
-from danswer.connectors.models import Document
-from danswer.connectors.models import Section
-from danswer.connectors.slack.utils import get_message_link
-
-
-def _process_batch_event(
-    slack_event: dict[str, Any],
-    channel: dict[str, Any],
-    matching_doc: Document | None,
-    workspace: str | None = None,
-) -> Document | None:
-    if (
-        slack_event["type"] == "message"
-        and slack_event.get("subtype") != "channel_join"
-    ):
-        if matching_doc:
-            return Document(
-                id=matching_doc.id,
-                sections=matching_doc.sections
-                + [
-                    Section(
-                        link=get_message_link(
-                            slack_event, workspace=workspace, channel_id=channel["id"]
-                        ),
-                        text=slack_event["text"],
-                    )
-                ],
-                source=matching_doc.source,
-                semantic_identifier=matching_doc.semantic_identifier,
-                metadata=matching_doc.metadata,
-            )
-
-        return Document(
-            id=slack_event["ts"],
-            sections=[
-                Section(
-                    link=get_message_link(
-                        slack_event, workspace=workspace, channel_id=channel["id"]
-                    ),
-                    text=slack_event["text"],
-                )
-            ],
-            source=DocumentSource.SLACK,
-            semantic_identifier=channel["name"],
-            metadata={},
-        )
-
-    return None
-
-
-class BatchSlackLoader(PullLoader):
-    """Loads from an unzipped slack workspace export"""
-
-    def __init__(
-        self, export_path_str: str, batch_size: int = INDEX_BATCH_SIZE
-    ) -> None:
-        self.export_path_str = export_path_str
-        self.batch_size = batch_size
-
-    def load(self) -> Generator[list[Document], None, None]:
-        export_path = Path(self.export_path_str)
-
-        with open(export_path / "channels.json") as f:
-            channels = json.load(f)
-
-        document_batch: dict[str, Document] = {}
-        for channel_info in channels:
-            channel_dir_path = export_path / cast(str, channel_info["name"])
-            channel_file_paths = [
-                channel_dir_path / file_name
-                for file_name in os.listdir(channel_dir_path)
-            ]
-            for path in channel_file_paths:
-                with open(path) as f:
-                    events = cast(list[dict[str, Any]], json.load(f))
-                for slack_event in events:
-                    doc = _process_batch_event(
-                        slack_event=slack_event,
-                        channel=channel_info,
-                        matching_doc=document_batch.get(
-                            slack_event.get("thread_ts", "")
-                        ),
-                    )
-                    if doc:
-                        document_batch[doc.id] = doc
-                    if len(document_batch) >= self.batch_size:
-                        yield list(document_batch.values())
-
-        yield list(document_batch.values())
@@ -1,13 +1,17 @@
+import json
+import os
 import time
 from collections.abc import Callable
 from collections.abc import Generator
+from pathlib import Path
 from typing import Any
 from typing import cast
 from typing import List
 
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import RangePullLoader
+from danswer.connectors.interfaces import LoadConnector
+from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.interfaces import SecondsSinceUnixEpoch
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
@@ -200,12 +204,91 @@ def get_all_docs(
     return docs
 
 
-class PeriodicSlackLoader(RangePullLoader):
-    def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
-        self.client = get_client()
-        self.batch_size = batch_size
-
-    def load(
+def _process_batch_event(
+    slack_event: dict[str, Any],
+    channel: dict[str, Any],
+    matching_doc: Document | None,
+    workspace: str | None = None,
+) -> Document | None:
+    if (
+        slack_event["type"] == "message"
+        and slack_event.get("subtype") != "channel_join"
+    ):
+        if matching_doc:
+            return Document(
+                id=matching_doc.id,
+                sections=matching_doc.sections
+                + [
+                    Section(
+                        link=get_message_link(
+                            slack_event, workspace=workspace, channel_id=channel["id"]
+                        ),
+                        text=slack_event["text"],
+                    )
+                ],
+                source=matching_doc.source,
+                semantic_identifier=matching_doc.semantic_identifier,
+                metadata=matching_doc.metadata,
+            )
+
+        return Document(
+            id=slack_event["ts"],
+            sections=[
+                Section(
+                    link=get_message_link(
+                        slack_event, workspace=workspace, channel_id=channel["id"]
+                    ),
+                    text=slack_event["text"],
+                )
+            ],
+            source=DocumentSource.SLACK,
+            semantic_identifier=channel["name"],
+            metadata={},
+        )
+
+    return None
+
+
+class SlackConnector(LoadConnector, PollConnector):
+    def __init__(
+        self, export_path_str: str, batch_size: int = INDEX_BATCH_SIZE
+    ) -> None:
+        self.export_path_str = export_path_str
+        self.batch_size = batch_size
+        self.client = get_client()
+
+    def load_from_state(self) -> Generator[list[Document], None, None]:
+        export_path = Path(self.export_path_str)
+
+        with open(export_path / "channels.json") as f:
+            channels = json.load(f)
+
+        document_batch: dict[str, Document] = {}
+        for channel_info in channels:
+            channel_dir_path = export_path / cast(str, channel_info["name"])
+            channel_file_paths = [
+                channel_dir_path / file_name
+                for file_name in os.listdir(channel_dir_path)
+            ]
+            for path in channel_file_paths:
+                with open(path) as f:
+                    events = cast(list[dict[str, Any]], json.load(f))
+                for slack_event in events:
+                    doc = _process_batch_event(
+                        slack_event=slack_event,
+                        channel=channel_info,
+                        matching_doc=document_batch.get(
+                            slack_event.get("thread_ts", "")
+                        ),
+                    )
+                    if doc:
+                        document_batch[doc.id] = doc
+                    if len(document_batch) >= self.batch_size:
+                        yield list(document_batch.values())
+
+        yield list(document_batch.values())
+
+    def poll_source(
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> Generator[List[Document], None, None]:
         all_docs = get_all_docs(client=self.client, oldest=str(start), latest=str(end))
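A minimal sketch (the path and timing values are made up) of the two entry points the merged SlackConnector now exposes:

import time

connector = SlackConnector(export_path_str="/tmp/slack-export")

# load: replay an unzipped workspace export from disk
for batch in connector.load_from_state():
    print(f"export batch: {len(batch)} docs")

# poll: hit the Slack API for just the last hour of messages
for batch in connector.poll_source(start=time.time() - 3600, end=time.time()):
    print(f"poll batch: {len(batch)} docs")
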
@@ -9,7 +9,7 @@ import requests
 from bs4 import BeautifulSoup
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import PullLoader
+from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger
@@ -51,7 +51,7 @@ def get_internal_links(
     return internal_links
 
 
-class WebLoader(PullLoader):
+class WebConnector(LoadConnector):
     def __init__(
         self,
         base_url: str,
@@ -60,7 +60,7 @@ class WebLoader(PullLoader):
         self.base_url = base_url
         self.batch_size = batch_size
 
-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         """Traverses through all pages found on the website
         and converts them into documents"""
         visited_links: set[str] = set()
@@ -88,8 +88,8 @@ class WebLoader(PullLoader):
                 response = requests.get(current_url)
                 pdf_reader = PdfReader(io.BytesIO(response.content))
                 page_text = ""
-                for page in pdf_reader.pages:
-                    page_text += page.extract_text()
+                for pdf_page in pdf_reader.pages:
+                    page_text += pdf_page.extract_text()
 
                 doc_batch.append(
                     Document(
@@ -376,7 +376,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         found_answer_start = False
         found_answer_end = False
         for event in response:
-            event_dict = cast(str, event["choices"][0]["delta"])
+            event_dict = event["choices"][0]["delta"]
             if (
                 "content" not in event_dict
             ):  # could be a role message or empty termination
@@ -93,7 +93,11 @@ def retrieve_ranked_documents(
         return None
     ranked_chunks = semantic_reranking(query, top_chunks)
 
-    top_docs = [ranked_chunk.source_links["0"] for ranked_chunk in ranked_chunks]
+    top_docs = [
+        ranked_chunk.source_links[0]
+        for ranked_chunk in ranked_chunks
+        if ranked_chunk.source_links is not None
+    ]
     files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
     logger.info(files_log_msg)
 
@@ -87,7 +87,7 @@ def index(
     _: User = Depends(current_admin_user),
 ) -> None:
     # validate that the connector specified by the source / input_type combination
-    # exists AND that the connector_specific_config is valid for that connector type
+    # exists AND that the connector_specific_config is valid for that connector type, should be load
     build_connector(
         source=source,
         input_type=index_attempt_request.input_type,
@@ -1,8 +1,8 @@
 from typing import Any
 
-from danswer.connectors.slack.pull import get_channel_info
-from danswer.connectors.slack.pull import get_thread
-from danswer.connectors.slack.pull import thread_to_doc
+from danswer.connectors.slack.connector import get_channel_info
+from danswer.connectors.slack.connector import get_thread
+from danswer.connectors.slack.connector import thread_to_doc
 from danswer.connectors.slack.utils import get_client
 from danswer.utils.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logging import setup_logger
@@ -32,7 +32,7 @@ class UserRoleResponse(BaseModel):
 
 class SearchDoc(BaseModel):
     semantic_identifier: str
-    link: str
+    link: str | None
     blurb: str
     source_type: str
 
@@ -78,7 +78,7 @@ def direct_qa(
     top_docs = [
         SearchDoc(
             semantic_identifier=chunk.semantic_identifier,
-            link=chunk.source_links.get("0") if chunk.source_links else None,
+            link=chunk.source_links.get(0) if chunk.source_links else None,
             blurb=chunk.blurb,
             source_type=chunk.source_type,
         )
@@ -117,7 +117,7 @@ def stream_direct_qa(
     top_docs = [
         SearchDoc(
             semantic_identifier=chunk.semantic_identifier,
-            link=chunk.source_links.get("0") if chunk.source_links else None,
+            link=chunk.source_links.get(0) if chunk.source_links else None,
             blurb=chunk.blurb,
             source_type=chunk.source_type,
         )
@@ -130,6 +130,8 @@ def stream_direct_qa(
         for response_dict in qa_model.answer_question_stream(
             query, ranked_chunks[:NUM_RERANKED_RESULTS]
         ):
+            if response_dict is None:
+                continue
             logger.debug(response_dict)
             yield get_json_line(response_dict)
         return
@@ -1,6 +1,7 @@
 import time
 from collections.abc import Callable
 from typing import Any
+from typing import cast
 from typing import TypeVar
 
 from danswer.utils.logging import setup_logger
@@ -21,7 +22,7 @@ def log_function_time(
     ...
     """
 
-    def timing_wrapper(func: Callable) -> Callable:
+    def timing_wrapper(func: F) -> F:
         def wrapped_func(*args: Any, **kwargs: Any) -> Any:
             start_time = time.time()
             result = func(*args, **kwargs)
@@ -30,6 +31,6 @@ def log_function_time(
             )
             return result
 
-        return wrapped_func
+        return cast(F, wrapped_func)
 
     return timing_wrapper
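The switch from Callable to a TypeVar means the decorator no longer erases the wrapped function's signature. A standalone sketch, assuming F is declared as F = TypeVar("F", bound=Callable), which the cast implies:

from collections.abc import Callable
from typing import Any, TypeVar, cast

F = TypeVar("F", bound=Callable)

def passthrough(func: F) -> F:
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)
    # cast tells mypy the wrapper still has func's exact signature
    return cast(F, wrapped)

@passthrough
def add(a: int, b: int) -> int:
    return a + b

total: int = add(1, 2)  # type-checks; add(1, "2") would now be flagged
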
@@ -5,12 +5,12 @@ from danswer.chunking.chunk import Chunker
 from danswer.chunking.chunk import DefaultChunker
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
-from danswer.connectors.github.batch import BatchGithubLoader
-from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
+from danswer.connectors.github.connector import GithubConnector
+from danswer.connectors.google_drive.connector import GoogleDriveConnector
 from danswer.connectors.google_drive.connector_auth import backend_get_credentials
-from danswer.connectors.interfaces import PullLoader
-from danswer.connectors.slack.batch import BatchSlackLoader
-from danswer.connectors.web.pull import WebLoader
+from danswer.connectors.interfaces import LoadConnector
+from danswer.connectors.slack.connector import SlackConnector
+from danswer.connectors.web.connector import WebConnector
 from danswer.datastores.interfaces import Datastore
 from danswer.datastores.qdrant.indexing import recreate_collection
 from danswer.datastores.qdrant.store import QdrantDatastore
@@ -23,14 +23,14 @@ logger = setup_logger()
 
 
 def load_batch(
-    doc_loader: PullLoader,
+    doc_loader: LoadConnector,
     chunker: Chunker,
     embedder: Embedder,
     datastore: Datastore,
 ) -> None:
     num_processed = 0
     total_chunks = 0
-    for document_batch in doc_loader.load():
+    for document_batch in doc_loader.load_from_state():
         if not document_batch:
             logger.warning("No parseable documents found in batch")
             continue
@@ -53,7 +53,7 @@ def load_batch(
 def load_slack_batch(file_path: str, qdrant_collection: str) -> None:
     logger.info("Loading documents from Slack.")
     load_batch(
-        BatchSlackLoader(export_path_str=file_path, batch_size=INDEX_BATCH_SIZE),
+        SlackConnector(export_path_str=file_path, batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),
@@ -63,7 +63,7 @@ def load_slack_batch(file_path: str, qdrant_collection: str) -> None:
 def load_web_batch(url: str, qdrant_collection: str) -> None:
     logger.info("Loading documents from web.")
     load_batch(
-        WebLoader(base_url=url, batch_size=INDEX_BATCH_SIZE),
+        WebConnector(base_url=url, batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),
@@ -74,7 +74,7 @@ def load_google_drive_batch(qdrant_collection: str) -> None:
     logger.info("Loading documents from Google Drive.")
     backend_get_credentials()
     load_batch(
-        BatchGoogleDriveLoader(batch_size=INDEX_BATCH_SIZE),
+        GoogleDriveConnector(batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),
@@ -84,9 +84,7 @@ def load_google_drive_batch(qdrant_collection: str) -> None:
 def load_github_batch(owner: str, repo: str, qdrant_collection: str) -> None:
     logger.info("Loading documents from Github.")
     load_batch(
-        BatchGithubLoader(
-            repo_owner=owner, repo_name=repo, batch_size=INDEX_BATCH_SIZE
-        ),
+        GithubConnector(repo_owner=owner, repo_name=repo, batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),
@@ -15,7 +15,6 @@ const MainSection = () => {
   // "/api/admin/connectors/web/index-attempt",
   // fetcher
   // );
-  const router = useRouter();
 
   const { mutate } = useSWRConfig();
   const { data, isLoading, error } = useSWR<SlackConfig>(
@@ -2,11 +2,12 @@ import React, { useState } from "react";
 import { Formik, Form, FormikHelpers } from "formik";
 import * as Yup from "yup";
 import { Popup } from "./Popup";
-import { ValidSources } from "@/lib/types";
+import { ValidInputTypes, ValidSources } from "@/lib/types";
 
 export const submitIndexRequest = async (
   source: ValidSources,
-  values: Yup.AnyObject
+  values: Yup.AnyObject,
+  inputType: ValidInputTypes = "load_state"
 ): Promise<{ message: string; isSuccess: boolean }> => {
   let isSuccess = false;
   try {
@@ -17,7 +18,7 @@ export const submitIndexRequest = async (
       headers: {
         "Content-Type": "application/json",
       },
-      body: JSON.stringify({ connector_specific_config: values }),
+      body: JSON.stringify({ connector_specific_config: values, inputType }),
     }
   );
 
@@ -8,3 +8,4 @@ export interface User {
 }
 
 export type ValidSources = "web" | "github" | "slack" | "google_drive";
+export type ValidInputTypes = "load_state" | "poll" | "event";