Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-10 12:59:59 +02:00)

DAN-93 Standardize Connectors (#70)

This commit is contained in:
parent 51e05e3948
commit 7559ba6e9d
@@ -2,12 +2,10 @@ import time

 from typing import cast

 from danswer.configs.constants import DocumentSource
 from danswer.connectors.factory import build_connector
-from danswer.connectors.factory import build_pull_connector
+from danswer.connectors.factory import build_load_connector
 from danswer.connectors.models import InputType
 from danswer.connectors.slack.config import get_pull_frequency
-from danswer.connectors.slack.pull import PeriodicSlackLoader
-from danswer.connectors.web.pull import WebLoader
+from danswer.connectors.slack.connector import SlackConnector
 from danswer.db.index_attempt import fetch_index_attempts
 from danswer.db.index_attempt import insert_index_attempt
 from danswer.db.index_attempt import update_index_attempt
@@ -20,7 +18,7 @@ from danswer.utils.logging import setup_logger

 logger = setup_logger()

-LAST_PULL_KEY_TEMPLATE = "last_pull_{}"
+LAST_POLL_KEY_TEMPLATE = "last_poll_{}"


 def _check_should_run(current_time: int, last_pull: int, pull_frequency: int) -> bool:
@@ -43,9 +41,7 @@ def run_update() -> None:
     except ConfigNotFoundError:
         pull_frequency = 0
     if pull_frequency:
-        last_slack_pull_key = LAST_PULL_KEY_TEMPLATE.format(
-            PeriodicSlackLoader.__name__
-        )
+        last_slack_pull_key = LAST_POLL_KEY_TEMPLATE.format(SlackConnector.__name__)
        try:
            last_pull = cast(int, dynamic_config_store.load(last_slack_pull_key))
        except ConfigNotFoundError:
@@ -61,7 +57,7 @@ def run_update() -> None:
            insert_index_attempt(
                IndexAttempt(
                    source=DocumentSource.SLACK,
-                    input_type=InputType.PULL,
+                    input_type=InputType.POLL,
                    status=IndexingStatus.NOT_STARTED,
                    connector_specific_config={},
                )
@@ -75,7 +71,7 @@ def run_update() -> None:
     # prevent race conditions across multiple background jobs. For now,
     # this assumes we only ever run a single background job at a time
     not_started_index_attempts = fetch_index_attempts(
-        input_types=[InputType.PULL], statuses=[IndexingStatus.NOT_STARTED]
+        input_types=[InputType.LOAD_STATE], statuses=[IndexingStatus.NOT_STARTED]
     )
     for not_started_index_attempt in not_started_index_attempts:
         logger.info(
@@ -94,13 +90,13 @@ def run_update() -> None:
         try:
             # TODO (chris): spawn processes to parallelize / take advantage of
             # multiple cores + implement retries
-            connector = build_pull_connector(
+            connector = build_load_connector(
                 source=not_started_index_attempt.source,
                 connector_specific_config=not_started_index_attempt.connector_specific_config,
             )

             document_ids: list[str] = []
-            for doc_batch in connector.load():
+            for doc_batch in connector.load_from_state():
                 indexing_pipeline(doc_batch)
                 document_ids.extend([doc.id for doc in doc_batch])
         except Exception as e:
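Note on the scheduling gate above: _check_should_run is referenced by run_update but its body is not part of this hunk. A minimal sketch of how such a frequency check could look, assuming pull_frequency is a number of minutes between runs (an assumption, not confirmed by this diff):

def _check_should_run(current_time: int, last_pull: int, pull_frequency: int) -> bool:
    # pull_frequency assumed to be minutes between runs (hypothetical unit)
    return current_time - last_pull > pull_frequency * 60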
@@ -1,6 +1,7 @@
 import inspect
 from dataclasses import dataclass
 from typing import Any
+from typing import cast

 from danswer.connectors.models import Document
@@ -34,10 +35,12 @@ class InferenceChunk(BaseChunk):

     @classmethod
     def from_dict(cls, init_dict: dict[str, Any]) -> "InferenceChunk":
-        return cls(
-            **{
-                k: v
-                for k, v in init_dict.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
+        init_kwargs = {
+            k: v for k, v in init_dict.items() if k in inspect.signature(cls).parameters
+        }
+        if "source_links" in init_kwargs:
+            init_kwargs["source_links"] = {
+                int(k): v
+                for k, v in cast(dict[str, str], init_kwargs["source_links"]).items()
+            }
+        return cls(**init_kwargs)
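The from_dict change matters because source_links keys come back as strings after a JSON round trip, while the rest of the code now indexes them with ints. A small standalone illustration of the key conversion (the URLs are placeholders for the example):

from typing import cast

# Hypothetical source_links as returned from a JSON datastore: string keys.
serialized_links = {"0": "https://example.com/page", "1": "https://example.com/page#anchor"}
rehydrated_links = {int(k): v for k, v in cast(dict[str, str], serialized_links).items()}
assert rehydrated_links[0] == "https://example.com/page"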
@@ -3,15 +3,16 @@ from collections.abc import Generator
 from typing import Any

 from danswer.configs.constants import DocumentSource
-from danswer.connectors.github.batch import BatchGithubLoader
-from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
-from danswer.connectors.interfaces import PullLoader
-from danswer.connectors.interfaces import RangePullLoader
+from danswer.connectors.github.connector import GithubConnector
+from danswer.connectors.google_drive.connector import GoogleDriveConnector
+from danswer.connectors.interfaces import BaseConnector
+from danswer.connectors.interfaces import EventConnector
+from danswer.connectors.interfaces import LoadConnector
+from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import InputType
-from danswer.connectors.slack.batch import BatchSlackLoader
-from danswer.connectors.slack.pull import PeriodicSlackLoader
-from danswer.connectors.web.pull import WebLoader
+from danswer.connectors.slack.connector import SlackConnector
+from danswer.connectors.web.connector import WebConnector

 _NUM_SECONDS_IN_DAY = 86400
@@ -24,45 +25,52 @@ def build_connector(
     source: DocumentSource,
     input_type: InputType,
     connector_specific_config: dict[str, Any],
-) -> PullLoader | RangePullLoader:
+) -> BaseConnector:
     if source == DocumentSource.SLACK:
-        if input_type == InputType.PULL:
-            return PeriodicSlackLoader(**connector_specific_config)
-        if input_type == InputType.LOAD_STATE:
-            return BatchSlackLoader(**connector_specific_config)
+        connector: BaseConnector = SlackConnector(**connector_specific_config)
     elif source == DocumentSource.GOOGLE_DRIVE:
-        if input_type == InputType.PULL:
-            return BatchGoogleDriveLoader(**connector_specific_config)
+        connector = GoogleDriveConnector(**connector_specific_config)
     elif source == DocumentSource.GITHUB:
-        if input_type == InputType.PULL:
-            return BatchGithubLoader(**connector_specific_config)
+        connector = GithubConnector(**connector_specific_config)
     elif source == DocumentSource.WEB:
-        if input_type == InputType.PULL:
-            return WebLoader(**connector_specific_config)
-
-    raise ConnectorMissingException(
-        f"Connector not found for source={source}, input_type={input_type}"
-    )
-
-
-def build_pull_connector(
-    source: DocumentSource, connector_specific_config: dict[str, Any]
-) -> PullLoader:
-    connector = build_connector(source, InputType.PULL, connector_specific_config)
-    return (
-        _range_pull_to_pull(connector)
-        if isinstance(connector, RangePullLoader)
-        else connector
-    )
+        connector = WebConnector(**connector_specific_config)
+    else:
+        raise ConnectorMissingException(f"Connector not found for source={source}")
+
+    if any(
+        [
+            input_type == InputType.LOAD_STATE
+            and not isinstance(connector, LoadConnector),
+            input_type == InputType.POLL and not isinstance(connector, PollConnector),
+            input_type == InputType.EVENT and not isinstance(connector, EventConnector),
+        ]
+    ):
+        raise ConnectorMissingException(
+            f"Connector for source={source} does not accept input_type={input_type}"
+        )
+
+    return connector


-def _range_pull_to_pull(range_pull_connector: RangePullLoader) -> PullLoader:
-    class _Connector(PullLoader):
+# TODO this is some jank, rework at some point
+def _poll_to_load_connector(range_pull_connector: PollConnector) -> LoadConnector:
+    class _Connector(LoadConnector):
         def __init__(self) -> None:
             self._connector = range_pull_connector

-        def load(self) -> Generator[list[Document], None, None]:
+        def load_from_state(self) -> Generator[list[Document], None, None]:
             # adding some buffer to make sure we get all documents
-            return self._connector.load(0, time.time() + _NUM_SECONDS_IN_DAY)
+            return self._connector.poll_source(0, time.time() + _NUM_SECONDS_IN_DAY)

     return _Connector()
+
+
+# TODO this is some jank, rework at some point
+def build_load_connector(
+    source: DocumentSource, connector_specific_config: dict[str, Any]
+) -> LoadConnector:
+    connector = build_connector(source, InputType.LOAD_STATE, connector_specific_config)
+    if isinstance(connector, PollConnector):
+        return _poll_to_load_connector(connector)
+    assert isinstance(connector, LoadConnector)
+    return connector
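A usage sketch of the reworked factory, assuming a web source; the URL is a placeholder and the config keys follow WebConnector's constructor shown later in this diff:

# Build a connector and let build_connector verify it supports the requested input type.
connector = build_connector(
    source=DocumentSource.WEB,
    input_type=InputType.LOAD_STATE,
    connector_specific_config={"base_url": "https://docs.example.com"},  # placeholder URL
)

# Convenience wrapper: always returns a LoadConnector, adapting a PollConnector
# through _poll_to_load_connector when the source only supports polling.
load_connector = build_load_connector(
    source=DocumentSource.WEB,
    connector_specific_config={"base_url": "https://docs.example.com"},
)
for doc_batch in load_connector.load_from_state():
    ...  # hand each batch to the indexing pipeline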
@@ -4,7 +4,7 @@ from collections.abc import Generator
 from danswer.configs.app_configs import GITHUB_ACCESS_TOKEN
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import PullLoader
+from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger

@@ -29,7 +29,7 @@ def get_pr_batches(
         yield batch


-class BatchGithubLoader(PullLoader):
+class GithubConnector(LoadConnector):
     def __init__(
         self,
         repo_owner: str,

@@ -42,7 +42,7 @@ class BatchGithubLoader(PullLoader):
         self.batch_size = batch_size
         self.state_filter = state_filter

-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         repo = github_client.get_repo(f"{self.repo_owner}/{self.repo_name}")
         pull_requests = repo.get_pulls(state=self.state_filter)
         for pr_batch in get_pr_batches(pull_requests, self.batch_size):
@@ -5,7 +5,7 @@ from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.google_drive.connector_auth import get_drive_tokens
-from danswer.connectors.interfaces import PullLoader
+from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger

@@ -81,11 +81,7 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
     return "\n".join(page.extract_text() for page in pdf_reader.pages)


-class BatchGoogleDriveLoader(PullLoader):
-    """
-    Loads everything in a Google Drive account
-    """
-
+class GoogleDriveConnector(LoadConnector):
     def __init__(
         self,
         batch_size: int = INDEX_BATCH_SIZE,

@@ -98,7 +94,7 @@ class BatchGoogleDriveLoader(PullLoader):
         if not self.creds:
             raise PermissionError("Unable to access Google Drive.")

-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         service = discovery.build("drive", "v3", credentials=self.creds)
         for files_batch in get_file_batches(
             service, self.include_shared, self.batch_size
@@ -8,22 +8,29 @@ from danswer.connectors.models import Document
 SecondsSinceUnixEpoch = float


-# TODO (chris): rename from Loader -> Connector
-class PullLoader:
+class BaseConnector(abc.ABC):
+    # Reserved for future shared uses
+    pass
+
+
+# Large set update or reindex, generally pulling a complete state or from a savestate file
+class LoadConnector(BaseConnector):
     @abc.abstractmethod
-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         raise NotImplementedError


-class RangePullLoader:
+# Small set updates by time
+class PollConnector(BaseConnector):
     @abc.abstractmethod
-    def load(
+    def poll_source(
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> Generator[list[Document], None, None]:
         raise NotImplementedError


-class PushLoader:
+# Event driven
+class EventConnector(BaseConnector):
     @abc.abstractmethod
-    def load(self, event: Any) -> Generator[list[Document], None, None]:
+    def handle_event(self, event: Any) -> Generator[list[Document], None, None]:
         raise NotImplementedError
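For orientation, a minimal connector written against the new interfaces. MyToyConnector and its in-memory records are hypothetical and only meant to show which methods each base class requires; the Document and Section fields mirror the ones used elsewhere in this diff:

from collections.abc import Generator
from typing import Any

from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section


class MyToyConnector(LoadConnector, PollConnector):
    """Hypothetical connector over a static list of records."""

    def __init__(self, records: list[dict[str, Any]]) -> None:
        self.records = records

    def _to_doc(self, record: dict[str, Any]) -> Document:
        return Document(
            id=record["id"],
            sections=[Section(link=record["link"], text=record["text"])],
            source=DocumentSource.WEB,  # stand-in source for the example
            semantic_identifier=record["id"],
            metadata={},
        )

    def load_from_state(self) -> Generator[list[Document], None, None]:
        # full re-index: emit everything in one batch
        yield [self._to_doc(r) for r in self.records]

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> Generator[list[Document], None, None]:
        # incremental update: only records touched inside the [start, end] window
        yield [
            self._to_doc(r) for r in self.records if start <= r["updated_at"] <= end
        ]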
@@ -26,9 +26,9 @@ def get_raw_document_text(document: Document) -> str:


 class InputType(str, Enum):
-    PULL = "pull"  # e.g. calling slack API to get all messages in the last hour
-    LOAD_STATE = "load_state"  # e.g. loading the state of a slack workspace from a file
-    EVENT = "event"  # e.g. registered an endpoint as a listener, and processing slack events
+    LOAD_STATE = "load_state"  # e.g. loading a current full state or a save state, such as from a file
+    POLL = "poll"  # e.g. calling an API to get all documents in the last hour
+    EVENT = "event"  # e.g. registered an endpoint as a listener, and processing connector events


 class ConnectorDescriptor(BaseModel):
@@ -1,99 +0,0 @@
-import json
-import os
-from collections.abc import Generator
-from pathlib import Path
-from typing import Any
-from typing import cast
-
-from danswer.configs.app_configs import INDEX_BATCH_SIZE
-from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import PullLoader
-from danswer.connectors.models import Document
-from danswer.connectors.models import Section
-from danswer.connectors.slack.utils import get_message_link
-
-
-def _process_batch_event(
-    slack_event: dict[str, Any],
-    channel: dict[str, Any],
-    matching_doc: Document | None,
-    workspace: str | None = None,
-) -> Document | None:
-    if (
-        slack_event["type"] == "message"
-        and slack_event.get("subtype") != "channel_join"
-    ):
-        if matching_doc:
-            return Document(
-                id=matching_doc.id,
-                sections=matching_doc.sections
-                + [
-                    Section(
-                        link=get_message_link(
-                            slack_event, workspace=workspace, channel_id=channel["id"]
-                        ),
-                        text=slack_event["text"],
-                    )
-                ],
-                source=matching_doc.source,
-                semantic_identifier=matching_doc.semantic_identifier,
-                metadata=matching_doc.metadata,
-            )
-
-        return Document(
-            id=slack_event["ts"],
-            sections=[
-                Section(
-                    link=get_message_link(
-                        slack_event, workspace=workspace, channel_id=channel["id"]
-                    ),
-                    text=slack_event["text"],
-                )
-            ],
-            source=DocumentSource.SLACK,
-            semantic_identifier=channel["name"],
-            metadata={},
-        )
-
-    return None
-
-
-class BatchSlackLoader(PullLoader):
-    """Loads from an unzipped slack workspace export"""
-
-    def __init__(
-        self, export_path_str: str, batch_size: int = INDEX_BATCH_SIZE
-    ) -> None:
-        self.export_path_str = export_path_str
-        self.batch_size = batch_size
-
-    def load(self) -> Generator[list[Document], None, None]:
-        export_path = Path(self.export_path_str)
-
-        with open(export_path / "channels.json") as f:
-            channels = json.load(f)
-
-        document_batch: dict[str, Document] = {}
-        for channel_info in channels:
-            channel_dir_path = export_path / cast(str, channel_info["name"])
-            channel_file_paths = [
-                channel_dir_path / file_name
-                for file_name in os.listdir(channel_dir_path)
-            ]
-            for path in channel_file_paths:
-                with open(path) as f:
-                    events = cast(list[dict[str, Any]], json.load(f))
-                for slack_event in events:
-                    doc = _process_batch_event(
-                        slack_event=slack_event,
-                        channel=channel_info,
-                        matching_doc=document_batch.get(
-                            slack_event.get("thread_ts", "")
-                        ),
-                    )
-                    if doc:
-                        document_batch[doc.id] = doc
-                        if len(document_batch) >= self.batch_size:
-                            yield list(document_batch.values())
-
-        yield list(document_batch.values())
@@ -1,13 +1,17 @@
+import json
+import os
 import time
 from collections.abc import Callable
 from collections.abc import Generator
+from pathlib import Path
 from typing import Any
 from typing import cast
 from typing import List

 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import RangePullLoader
+from danswer.connectors.interfaces import LoadConnector
+from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.interfaces import SecondsSinceUnixEpoch
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section

@@ -200,12 +204,91 @@ def get_all_docs(
     return docs


-class PeriodicSlackLoader(RangePullLoader):
-    def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
-        self.client = get_client()
-        self.batch_size = batch_size
-
-    def load(
+def _process_batch_event(
+    slack_event: dict[str, Any],
+    channel: dict[str, Any],
+    matching_doc: Document | None,
+    workspace: str | None = None,
+) -> Document | None:
+    if (
+        slack_event["type"] == "message"
+        and slack_event.get("subtype") != "channel_join"
+    ):
+        if matching_doc:
+            return Document(
+                id=matching_doc.id,
+                sections=matching_doc.sections
+                + [
+                    Section(
+                        link=get_message_link(
+                            slack_event, workspace=workspace, channel_id=channel["id"]
+                        ),
+                        text=slack_event["text"],
+                    )
+                ],
+                source=matching_doc.source,
+                semantic_identifier=matching_doc.semantic_identifier,
+                metadata=matching_doc.metadata,
+            )
+
+        return Document(
+            id=slack_event["ts"],
+            sections=[
+                Section(
+                    link=get_message_link(
+                        slack_event, workspace=workspace, channel_id=channel["id"]
+                    ),
+                    text=slack_event["text"],
+                )
+            ],
+            source=DocumentSource.SLACK,
+            semantic_identifier=channel["name"],
+            metadata={},
+        )
+
+    return None
+
+
+class SlackConnector(LoadConnector, PollConnector):
+    def __init__(
+        self, export_path_str: str, batch_size: int = INDEX_BATCH_SIZE
+    ) -> None:
+        self.export_path_str = export_path_str
+        self.batch_size = batch_size
+        self.client = get_client()
+
+    def load_from_state(self) -> Generator[list[Document], None, None]:
+        export_path = Path(self.export_path_str)
+
+        with open(export_path / "channels.json") as f:
+            channels = json.load(f)
+
+        document_batch: dict[str, Document] = {}
+        for channel_info in channels:
+            channel_dir_path = export_path / cast(str, channel_info["name"])
+            channel_file_paths = [
+                channel_dir_path / file_name
+                for file_name in os.listdir(channel_dir_path)
+            ]
+            for path in channel_file_paths:
+                with open(path) as f:
+                    events = cast(list[dict[str, Any]], json.load(f))
+                for slack_event in events:
+                    doc = _process_batch_event(
+                        slack_event=slack_event,
+                        channel=channel_info,
+                        matching_doc=document_batch.get(
+                            slack_event.get("thread_ts", "")
+                        ),
+                    )
+                    if doc:
+                        document_batch[doc.id] = doc
+                        if len(document_batch) >= self.batch_size:
+                            yield list(document_batch.values())
+
+        yield list(document_batch.values())
+
+    def poll_source(
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> Generator[List[Document], None, None]:
         all_docs = get_all_docs(client=self.client, oldest=str(start), latest=str(end))
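A sketch of how the combined SlackConnector might be driven now that it implements both interfaces; the export path and time window are placeholders:

import time

connector = SlackConnector(export_path_str="/path/to/slack_export")  # placeholder path

# Initial indexing from a workspace export on disk (LoadConnector side).
for doc_batch in connector.load_from_state():
    ...  # index the batch

# Later, incremental updates over a recent time window (PollConnector side).
one_hour_ago = time.time() - 3600
for doc_batch in connector.poll_source(start=one_hour_ago, end=time.time()):
    ...  # index only newly fetched messages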
@@ -9,7 +9,7 @@ import requests
 from bs4 import BeautifulSoup
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import PullLoader
+from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger

@@ -51,7 +51,7 @@ def get_internal_links(
     return internal_links


-class WebLoader(PullLoader):
+class WebConnector(LoadConnector):
     def __init__(
         self,
         base_url: str,

@@ -60,7 +60,7 @@ class WebLoader(PullLoader):
         self.base_url = base_url
         self.batch_size = batch_size

-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         """Traverses through all pages found on the website
         and converts them into documents"""
         visited_links: set[str] = set()

@@ -88,8 +88,8 @@ class WebLoader(PullLoader):
                 response = requests.get(current_url)
                 pdf_reader = PdfReader(io.BytesIO(response.content))
                 page_text = ""
-                for page in pdf_reader.pages:
-                    page_text += page.extract_text()
+                for pdf_page in pdf_reader.pages:
+                    page_text += pdf_page.extract_text()

                 doc_batch.append(
                     Document(
@@ -376,7 +376,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         found_answer_start = False
         found_answer_end = False
         for event in response:
-            event_dict = cast(str, event["choices"][0]["delta"])
+            event_dict = event["choices"][0]["delta"]
             if (
                 "content" not in event_dict
             ):  # could be a role message or empty termination
@@ -93,7 +93,11 @@ def retrieve_ranked_documents(
         return None
     ranked_chunks = semantic_reranking(query, top_chunks)

-    top_docs = [ranked_chunk.source_links["0"] for ranked_chunk in ranked_chunks]
+    top_docs = [
+        ranked_chunk.source_links[0]
+        for ranked_chunk in ranked_chunks
+        if ranked_chunk.source_links is not None
+    ]
     files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
     logger.info(files_log_msg)
@@ -87,7 +87,7 @@ def index(
     _: User = Depends(current_admin_user),
 ) -> None:
     # validate that the connector specified by the source / input_type combination
-    # exists AND that the connector_specific_config is valid for that connector type
+    # exists AND that the connector_specific_config is valid for that connector type, should be load
     build_connector(
         source=source,
         input_type=index_attempt_request.input_type,
@@ -1,8 +1,8 @@
 from typing import Any

-from danswer.connectors.slack.pull import get_channel_info
-from danswer.connectors.slack.pull import get_thread
-from danswer.connectors.slack.pull import thread_to_doc
+from danswer.connectors.slack.connector import get_channel_info
+from danswer.connectors.slack.connector import get_thread
+from danswer.connectors.slack.connector import thread_to_doc
 from danswer.connectors.slack.utils import get_client
 from danswer.utils.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logging import setup_logger
@@ -32,7 +32,7 @@ class UserRoleResponse(BaseModel):

 class SearchDoc(BaseModel):
     semantic_identifier: str
-    link: str
+    link: str | None
     blurb: str
     source_type: str
@@ -78,7 +78,7 @@ def direct_qa(
     top_docs = [
         SearchDoc(
             semantic_identifier=chunk.semantic_identifier,
-            link=chunk.source_links.get("0") if chunk.source_links else None,
+            link=chunk.source_links.get(0) if chunk.source_links else None,
             blurb=chunk.blurb,
             source_type=chunk.source_type,
         )
@@ -117,7 +117,7 @@ def stream_direct_qa(
         top_docs = [
             SearchDoc(
                 semantic_identifier=chunk.semantic_identifier,
-                link=chunk.source_links.get("0") if chunk.source_links else None,
+                link=chunk.source_links.get(0) if chunk.source_links else None,
                 blurb=chunk.blurb,
                 source_type=chunk.source_type,
             )
@@ -130,6 +130,8 @@ def stream_direct_qa(
         for response_dict in qa_model.answer_question_stream(
             query, ranked_chunks[:NUM_RERANKED_RESULTS]
         ):
+            if response_dict is None:
+                continue
             logger.debug(response_dict)
             yield get_json_line(response_dict)
         return
@@ -1,6 +1,7 @@
 import time
 from collections.abc import Callable
 from typing import Any
 from typing import cast
+from typing import TypeVar

 from danswer.utils.logging import setup_logger

@@ -21,7 +22,7 @@ def log_function_time(
     ...
     """

-    def timing_wrapper(func: Callable) -> Callable:
+    def timing_wrapper(func: F) -> F:
         def wrapped_func(*args: Any, **kwargs: Any) -> Any:
             start_time = time.time()
             result = func(*args, **kwargs)

@@ -30,6 +31,6 @@ def log_function_time(
             )
             return result

-        return wrapped_func
+        return cast(F, wrapped_func)

     return timing_wrapper
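The timing decorator now types its wrapper with a TypeVar F so that decorating a function does not erase its signature for type checkers. F itself is defined outside this hunk; a plausible definition and the pattern it enables, as a self-contained sketch:

import time
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import TypeVar

F = TypeVar("F", bound=Callable[..., Any])  # assumed definition, not shown in the diff


def log_time(func: F) -> F:
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        start = time.time()
        result = func(*args, **kwargs)
        print(f"{func.__name__} took {time.time() - start:.3f}s")
        return result

    # cast keeps the decorated function's original signature for callers
    return cast(F, wrapped)


@log_time
def add(a: int, b: int) -> int:
    return a + b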
@@ -5,12 +5,12 @@ from danswer.chunking.chunk import Chunker
 from danswer.chunking.chunk import DefaultChunker
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
-from danswer.connectors.github.batch import BatchGithubLoader
-from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
+from danswer.connectors.github.connector import GithubConnector
+from danswer.connectors.google_drive.connector import GoogleDriveConnector
 from danswer.connectors.google_drive.connector_auth import backend_get_credentials
-from danswer.connectors.interfaces import PullLoader
-from danswer.connectors.slack.batch import BatchSlackLoader
-from danswer.connectors.web.pull import WebLoader
+from danswer.connectors.interfaces import LoadConnector
+from danswer.connectors.slack.connector import SlackConnector
+from danswer.connectors.web.connector import WebConnector
 from danswer.datastores.interfaces import Datastore
 from danswer.datastores.qdrant.indexing import recreate_collection
 from danswer.datastores.qdrant.store import QdrantDatastore

@@ -23,14 +23,14 @@ logger = setup_logger()


 def load_batch(
-    doc_loader: PullLoader,
+    doc_loader: LoadConnector,
     chunker: Chunker,
     embedder: Embedder,
     datastore: Datastore,
 ) -> None:
     num_processed = 0
     total_chunks = 0
-    for document_batch in doc_loader.load():
+    for document_batch in doc_loader.load_from_state():
         if not document_batch:
             logger.warning("No parseable documents found in batch")
             continue

@@ -53,7 +53,7 @@ def load_batch(
 def load_slack_batch(file_path: str, qdrant_collection: str) -> None:
     logger.info("Loading documents from Slack.")
     load_batch(
-        BatchSlackLoader(export_path_str=file_path, batch_size=INDEX_BATCH_SIZE),
+        SlackConnector(export_path_str=file_path, batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),

@@ -63,7 +63,7 @@ def load_slack_batch(file_path: str, qdrant_collection: str) -> None:
 def load_web_batch(url: str, qdrant_collection: str) -> None:
     logger.info("Loading documents from web.")
     load_batch(
-        WebLoader(base_url=url, batch_size=INDEX_BATCH_SIZE),
+        WebConnector(base_url=url, batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),

@@ -74,7 +74,7 @@ def load_google_drive_batch(qdrant_collection: str) -> None:
     logger.info("Loading documents from Google Drive.")
     backend_get_credentials()
     load_batch(
-        BatchGoogleDriveLoader(batch_size=INDEX_BATCH_SIZE),
+        GoogleDriveConnector(batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),

@@ -84,9 +84,7 @@ def load_google_drive_batch(qdrant_collection: str) -> None:
 def load_github_batch(owner: str, repo: str, qdrant_collection: str) -> None:
     logger.info("Loading documents from Github.")
     load_batch(
-        BatchGithubLoader(
-            repo_owner=owner, repo_name=repo, batch_size=INDEX_BATCH_SIZE
-        ),
+        GithubConnector(repo_owner=owner, repo_name=repo, batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),
@@ -15,7 +15,6 @@ const MainSection = () => {
   //   "/api/admin/connectors/web/index-attempt",
   //   fetcher
   // );
-  const router = useRouter();

   const { mutate } = useSWRConfig();
   const { data, isLoading, error } = useSWR<SlackConfig>(
@@ -2,11 +2,12 @@ import React, { useState } from "react";
 import { Formik, Form, FormikHelpers } from "formik";
 import * as Yup from "yup";
 import { Popup } from "./Popup";
-import { ValidSources } from "@/lib/types";
+import { ValidInputTypes, ValidSources } from "@/lib/types";

 export const submitIndexRequest = async (
   source: ValidSources,
-  values: Yup.AnyObject
+  values: Yup.AnyObject,
+  inputType: ValidInputTypes = "load_state"
 ): Promise<{ message: string; isSuccess: boolean }> => {
   let isSuccess = false;
   try {

@@ -17,7 +18,7 @@ export const submitIndexRequest = async (
         headers: {
           "Content-Type": "application/json",
         },
-        body: JSON.stringify({ connector_specific_config: values }),
+        body: JSON.stringify({ connector_specific_config: values, inputType }),
       }
     );
@@ -8,3 +8,4 @@ export interface User {
 }

 export type ValidSources = "web" | "github" | "slack" | "google_drive";
+export type ValidInputTypes = "load_state" | "poll" | "event";