DAN-93 Standardize Connectors (#70)

Yuhong Sun
2023-05-21 13:24:25 -07:00
committed by GitHub
parent 51e05e3948
commit 7559ba6e9d
21 changed files with 212 additions and 212 deletions

View File

@@ -2,12 +2,10 @@ import time
 from typing import cast

 from danswer.configs.constants import DocumentSource
-from danswer.connectors.factory import build_connector
-from danswer.connectors.factory import build_pull_connector
+from danswer.connectors.factory import build_load_connector
 from danswer.connectors.models import InputType
 from danswer.connectors.slack.config import get_pull_frequency
-from danswer.connectors.slack.pull import PeriodicSlackLoader
-from danswer.connectors.web.pull import WebLoader
+from danswer.connectors.slack.connector import SlackConnector
 from danswer.db.index_attempt import fetch_index_attempts
 from danswer.db.index_attempt import insert_index_attempt
 from danswer.db.index_attempt import update_index_attempt
@@ -20,7 +18,7 @@ from danswer.utils.logging import setup_logger
 logger = setup_logger()

-LAST_PULL_KEY_TEMPLATE = "last_pull_{}"
+LAST_POLL_KEY_TEMPLATE = "last_poll_{}"


 def _check_should_run(current_time: int, last_pull: int, pull_frequency: int) -> bool:
@@ -43,9 +41,7 @@ def run_update() -> None:
     except ConfigNotFoundError:
         pull_frequency = 0
     if pull_frequency:
-        last_slack_pull_key = LAST_PULL_KEY_TEMPLATE.format(
-            PeriodicSlackLoader.__name__
-        )
+        last_slack_pull_key = LAST_POLL_KEY_TEMPLATE.format(SlackConnector.__name__)
         try:
             last_pull = cast(int, dynamic_config_store.load(last_slack_pull_key))
         except ConfigNotFoundError:
@@ -61,7 +57,7 @@ def run_update() -> None:
             insert_index_attempt(
                 IndexAttempt(
                     source=DocumentSource.SLACK,
-                    input_type=InputType.PULL,
+                    input_type=InputType.POLL,
                     status=IndexingStatus.NOT_STARTED,
                     connector_specific_config={},
                 )
@@ -75,7 +71,7 @@ def run_update() -> None:
     # prevent race conditions across multiple background jobs. For now,
     # this assumes we only ever run a single background job at a time
     not_started_index_attempts = fetch_index_attempts(
-        input_types=[InputType.PULL], statuses=[IndexingStatus.NOT_STARTED]
+        input_types=[InputType.LOAD_STATE], statuses=[IndexingStatus.NOT_STARTED]
     )
     for not_started_index_attempt in not_started_index_attempts:
         logger.info(
@@ -94,13 +90,13 @@ def run_update() -> None:
         try:
             # TODO (chris): spawn processes to parallelize / take advantage of
            # multiple cores + implement retries
-            connector = build_pull_connector(
+            connector = build_load_connector(
                 source=not_started_index_attempt.source,
                 connector_specific_config=not_started_index_attempt.connector_specific_config,
             )
             document_ids: list[str] = []
-            for doc_batch in connector.load():
+            for doc_batch in connector.load_from_state():
                 indexing_pipeline(doc_batch)
                 document_ids.extend([doc.id for doc in doc_batch])
         except Exception as e:

View File

@@ -1,6 +1,7 @@
 import inspect
 from dataclasses import dataclass
 from typing import Any
+from typing import cast

 from danswer.connectors.models import Document
@@ -34,10 +35,12 @@ class InferenceChunk(BaseChunk):
     @classmethod
     def from_dict(cls, init_dict: dict[str, Any]) -> "InferenceChunk":
-        return cls(
-            **{
-                k: v
-                for k, v in init_dict.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
+        init_kwargs = {
+            k: v for k, v in init_dict.items() if k in inspect.signature(cls).parameters
+        }
+        if "source_links" in init_kwargs:
+            init_kwargs["source_links"] = {
+                int(k): v
+                for k, v in cast(dict[str, str], init_kwargs["source_links"]).items()
+            }
+        return cls(**init_kwargs)
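
Context for the hunk above: after a JSON round-trip (for example when a chunk payload is stored and reloaded), dict keys come back as strings, so from_dict now normalizes the source_links keys to ints. A minimal standalone sketch of that behavior, using plain dicts rather than the real InferenceChunk:

# Standalone sketch (plain dicts, not the real InferenceChunk) of why from_dict
# now casts source_links keys to int: JSON stringifies dict keys.
import json

chunk_payload = {"blurb": "example", "source_links": {0: "https://example.com/page"}}

restored = json.loads(json.dumps(chunk_payload))
assert list(restored["source_links"]) == ["0"]  # the int key became a string

# Same normalization as the new from_dict: keys go back to ints, so lookups
# like source_links.get(0) in the API layer below keep working.
restored["source_links"] = {int(k): v for k, v in restored["source_links"].items()}
assert restored["source_links"][0] == "https://example.com/page"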

View File

@@ -3,15 +3,16 @@ from collections.abc import Generator
 from typing import Any

 from danswer.configs.constants import DocumentSource
-from danswer.connectors.github.batch import BatchGithubLoader
-from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
-from danswer.connectors.interfaces import PullLoader
-from danswer.connectors.interfaces import RangePullLoader
+from danswer.connectors.github.connector import GithubConnector
+from danswer.connectors.google_drive.connector import GoogleDriveConnector
+from danswer.connectors.interfaces import BaseConnector
+from danswer.connectors.interfaces import EventConnector
+from danswer.connectors.interfaces import LoadConnector
+from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import InputType
-from danswer.connectors.slack.batch import BatchSlackLoader
-from danswer.connectors.slack.pull import PeriodicSlackLoader
-from danswer.connectors.web.pull import WebLoader
+from danswer.connectors.slack.connector import SlackConnector
+from danswer.connectors.web.connector import WebConnector

 _NUM_SECONDS_IN_DAY = 86400
@@ -24,45 +25,52 @@ def build_connector(
     source: DocumentSource,
     input_type: InputType,
     connector_specific_config: dict[str, Any],
-) -> PullLoader | RangePullLoader:
+) -> BaseConnector:
     if source == DocumentSource.SLACK:
-        if input_type == InputType.PULL:
-            return PeriodicSlackLoader(**connector_specific_config)
-        if input_type == InputType.LOAD_STATE:
-            return BatchSlackLoader(**connector_specific_config)
+        connector: BaseConnector = SlackConnector(**connector_specific_config)
     elif source == DocumentSource.GOOGLE_DRIVE:
-        if input_type == InputType.PULL:
-            return BatchGoogleDriveLoader(**connector_specific_config)
+        connector = GoogleDriveConnector(**connector_specific_config)
     elif source == DocumentSource.GITHUB:
-        if input_type == InputType.PULL:
-            return BatchGithubLoader(**connector_specific_config)
+        connector = GithubConnector(**connector_specific_config)
     elif source == DocumentSource.WEB:
-        if input_type == InputType.PULL:
-            return WebLoader(**connector_specific_config)
-
-    raise ConnectorMissingException(
-        f"Connector not found for source={source}, input_type={input_type}"
-    )
+        connector = WebConnector(**connector_specific_config)
+    else:
+        raise ConnectorMissingException(f"Connector not found for source={source}")
+
+    if any(
+        [
+            input_type == InputType.LOAD_STATE
+            and not isinstance(connector, LoadConnector),
+            input_type == InputType.POLL and not isinstance(connector, PollConnector),
+            input_type == InputType.EVENT and not isinstance(connector, EventConnector),
+        ]
+    ):
+        raise ConnectorMissingException(
+            f"Connector for source={source} does not accept input_type={input_type}"
+        )
+    return connector


-def build_pull_connector(
-    source: DocumentSource, connector_specific_config: dict[str, Any]
-) -> PullLoader:
-    connector = build_connector(source, InputType.PULL, connector_specific_config)
-    return (
-        _range_pull_to_pull(connector)
-        if isinstance(connector, RangePullLoader)
-        else connector
-    )
-
-
-def _range_pull_to_pull(range_pull_connector: RangePullLoader) -> PullLoader:
-    class _Connector(PullLoader):
+# TODO this is some jank, rework at some point
+def _poll_to_load_connector(range_pull_connector: PollConnector) -> LoadConnector:
+    class _Connector(LoadConnector):
         def __init__(self) -> None:
             self._connector = range_pull_connector

-        def load(self) -> Generator[list[Document], None, None]:
+        def load_from_state(self) -> Generator[list[Document], None, None]:
             # adding some buffer to make sure we get all documents
-            return self._connector.load(0, time.time() + _NUM_SECONDS_IN_DAY)
+            return self._connector.poll_source(0, time.time() + _NUM_SECONDS_IN_DAY)

     return _Connector()
+
+
+# TODO this is some jank, rework at some point
+def build_load_connector(
+    source: DocumentSource, connector_specific_config: dict[str, Any]
+) -> LoadConnector:
+    connector = build_connector(source, InputType.LOAD_STATE, connector_specific_config)
+    if isinstance(connector, PollConnector):
+        return _poll_to_load_connector(connector)
+    assert isinstance(connector, LoadConnector)
+    return connector
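
The old build_pull_connector / _range_pull_to_pull pair becomes build_load_connector and _poll_to_load_connector: the factory instantiates one connector class per source, checks that it implements the interface matching the requested InputType, and can still drive a PollConnector through the LoadConnector interface via a small adapter. A hedged sketch of that adapter idea, with a hypothetical FakePollConnector rather than danswer classes:

# Sketch of the adapter pattern used by _poll_to_load_connector above;
# FakePollConnector and _LoadAdapter are hypothetical stand-ins.
import time
from collections.abc import Generator

_NUM_SECONDS_IN_DAY = 86400


class FakePollConnector:
    def poll_source(self, start: float, end: float) -> Generator[list[str], None, None]:
        yield [f"documents changed between {start:.0f} and {end:.0f}"]


class _LoadAdapter:
    def __init__(self, poller: FakePollConnector) -> None:
        self._poller = poller

    def load_from_state(self) -> Generator[list[str], None, None]:
        # the one-day buffer mirrors the "make sure we get all documents" comment above
        return self._poller.poll_source(0, time.time() + _NUM_SECONDS_IN_DAY)


for batch in _LoadAdapter(FakePollConnector()).load_from_state():
    print(batch)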

View File

@@ -4,7 +4,7 @@ from collections.abc import Generator
 from danswer.configs.app_configs import GITHUB_ACCESS_TOKEN
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import PullLoader
+from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger
@@ -29,7 +29,7 @@ def get_pr_batches(
         yield batch


-class BatchGithubLoader(PullLoader):
+class GithubConnector(LoadConnector):
     def __init__(
         self,
         repo_owner: str,
@@ -42,7 +42,7 @@ class BatchGithubLoader(PullLoader):
         self.batch_size = batch_size
         self.state_filter = state_filter

-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         repo = github_client.get_repo(f"{self.repo_owner}/{self.repo_name}")
         pull_requests = repo.get_pulls(state=self.state_filter)
         for pr_batch in get_pr_batches(pull_requests, self.batch_size):

View File

@@ -5,7 +5,7 @@ from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.google_drive.connector_auth import get_drive_tokens
-from danswer.connectors.interfaces import PullLoader
+from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger
@@ -81,11 +81,7 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
         return "\n".join(page.extract_text() for page in pdf_reader.pages)


-class BatchGoogleDriveLoader(PullLoader):
-    """
-    Loads everything in a Google Drive account
-    """
-
+class GoogleDriveConnector(LoadConnector):
     def __init__(
         self,
         batch_size: int = INDEX_BATCH_SIZE,
@@ -98,7 +94,7 @@ class BatchGoogleDriveLoader(PullLoader):
         if not self.creds:
             raise PermissionError("Unable to access Google Drive.")

-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         service = discovery.build("drive", "v3", credentials=self.creds)
         for files_batch in get_file_batches(
             service, self.include_shared, self.batch_size

View File

@@ -8,22 +8,29 @@ from danswer.connectors.models import Document
 SecondsSinceUnixEpoch = float


-# TODO (chris): rename from Loader -> Connector
-class PullLoader:
+class BaseConnector(abc.ABC):
+    # Reserved for future shared uses
+    pass
+
+
+# Large set update or reindex, generally pulling a complete state or from a savestate file
+class LoadConnector(BaseConnector):
     @abc.abstractmethod
-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         raise NotImplementedError


-class RangePullLoader:
+# Small set updates by time
+class PollConnector(BaseConnector):
     @abc.abstractmethod
-    def load(
+    def poll_source(
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> Generator[list[Document], None, None]:
         raise NotImplementedError


-class PushLoader:
+# Event driven
+class EventConnector(BaseConnector):
     @abc.abstractmethod
-    def load(self, event: Any) -> Generator[list[Document], None, None]:
+    def handle_event(self, event: Any) -> Generator[list[Document], None, None]:
         raise NotImplementedError
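
With the renamed interfaces, a connector declares how it can be driven by subclassing LoadConnector (full reindex), PollConnector (time-ranged updates), or EventConnector (event driven), and may implement several at once, as SlackConnector does further down. A self-contained sketch with a hypothetical InMemoryConnector; only the interface shapes come from the hunk above:

# Self-contained sketch; InMemoryConnector and its documents are hypothetical.
import abc
from collections.abc import Generator

SecondsSinceUnixEpoch = float


class BaseConnector(abc.ABC):
    pass


class LoadConnector(BaseConnector):
    @abc.abstractmethod
    def load_from_state(self) -> Generator[list[dict], None, None]:
        raise NotImplementedError


class PollConnector(BaseConnector):
    @abc.abstractmethod
    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> Generator[list[dict], None, None]:
        raise NotImplementedError


class InMemoryConnector(LoadConnector, PollConnector):
    """Serves a fixed document list as a full load and as a time-ranged poll."""

    def __init__(self) -> None:
        self._docs = [
            {"id": "1", "updated_at": 100.0},
            {"id": "2", "updated_at": 200.0},
        ]

    def load_from_state(self) -> Generator[list[dict], None, None]:
        yield list(self._docs)

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> Generator[list[dict], None, None]:
        yield [doc for doc in self._docs if start <= doc["updated_at"] <= end]


connector = InMemoryConnector()
print(next(connector.load_from_state()))          # both documents
print(next(connector.poll_source(150.0, 250.0)))  # only the second document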

View File

@@ -26,9 +26,9 @@ def get_raw_document_text(document: Document) -> str:
 class InputType(str, Enum):
-    PULL = "pull"  # e.g. calling slack API to get all messages in the last hour
-    LOAD_STATE = "load_state"  # e.g. loading the state of a slack workspace from a file
-    EVENT = "event"  # e.g. registered an endpoint as a listener, and processing slack events
+    LOAD_STATE = "load_state"  # e.g. loading a current full state or a save state, such as from a file
+    POLL = "poll"  # e.g. calling an API to get all documents in the last hour
+    EVENT = "event"  # e.g. registered an endpoint as a listener, and processing connector events


 class ConnectorDescriptor(BaseModel):

View File

@@ -1,99 +0,0 @@
-import json
-import os
-from collections.abc import Generator
-from pathlib import Path
-from typing import Any
-from typing import cast
-
-from danswer.configs.app_configs import INDEX_BATCH_SIZE
-from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import PullLoader
-from danswer.connectors.models import Document
-from danswer.connectors.models import Section
-from danswer.connectors.slack.utils import get_message_link
-
-
-def _process_batch_event(
-    slack_event: dict[str, Any],
-    channel: dict[str, Any],
-    matching_doc: Document | None,
-    workspace: str | None = None,
-) -> Document | None:
-    if (
-        slack_event["type"] == "message"
-        and slack_event.get("subtype") != "channel_join"
-    ):
-        if matching_doc:
-            return Document(
-                id=matching_doc.id,
-                sections=matching_doc.sections
-                + [
-                    Section(
-                        link=get_message_link(
-                            slack_event, workspace=workspace, channel_id=channel["id"]
-                        ),
-                        text=slack_event["text"],
-                    )
-                ],
-                source=matching_doc.source,
-                semantic_identifier=matching_doc.semantic_identifier,
-                metadata=matching_doc.metadata,
-            )
-
-        return Document(
-            id=slack_event["ts"],
-            sections=[
-                Section(
-                    link=get_message_link(
-                        slack_event, workspace=workspace, channel_id=channel["id"]
-                    ),
-                    text=slack_event["text"],
-                )
-            ],
-            source=DocumentSource.SLACK,
-            semantic_identifier=channel["name"],
-            metadata={},
-        )
-
-    return None
-
-
-class BatchSlackLoader(PullLoader):
-    """Loads from an unzipped slack workspace export"""
-
-    def __init__(
-        self, export_path_str: str, batch_size: int = INDEX_BATCH_SIZE
-    ) -> None:
-        self.export_path_str = export_path_str
-        self.batch_size = batch_size
-
-    def load(self) -> Generator[list[Document], None, None]:
-        export_path = Path(self.export_path_str)
-
-        with open(export_path / "channels.json") as f:
-            channels = json.load(f)
-
-        document_batch: dict[str, Document] = {}
-        for channel_info in channels:
-            channel_dir_path = export_path / cast(str, channel_info["name"])
-            channel_file_paths = [
-                channel_dir_path / file_name
-                for file_name in os.listdir(channel_dir_path)
-            ]
-            for path in channel_file_paths:
-                with open(path) as f:
-                    events = cast(list[dict[str, Any]], json.load(f))
-                for slack_event in events:
-                    doc = _process_batch_event(
-                        slack_event=slack_event,
-                        channel=channel_info,
-                        matching_doc=document_batch.get(
-                            slack_event.get("thread_ts", "")
-                        ),
-                    )
-                    if doc:
-                        document_batch[doc.id] = doc
-                        if len(document_batch) >= self.batch_size:
-                            yield list(document_batch.values())
-
-        yield list(document_batch.values())

View File

@@ -1,13 +1,17 @@
+import json
+import os
 import time
 from collections.abc import Callable
 from collections.abc import Generator
+from pathlib import Path
 from typing import Any
 from typing import cast
 from typing import List

 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import RangePullLoader
+from danswer.connectors.interfaces import LoadConnector
+from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.interfaces import SecondsSinceUnixEpoch
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
@@ -200,12 +204,91 @@ def get_all_docs(
     return docs


-class PeriodicSlackLoader(RangePullLoader):
-    def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
-        self.client = get_client()
-        self.batch_size = batch_size
+def _process_batch_event(
+    slack_event: dict[str, Any],
+    channel: dict[str, Any],
+    matching_doc: Document | None,
+    workspace: str | None = None,
+) -> Document | None:
+    if (
+        slack_event["type"] == "message"
+        and slack_event.get("subtype") != "channel_join"
+    ):
+        if matching_doc:
+            return Document(
+                id=matching_doc.id,
+                sections=matching_doc.sections
+                + [
+                    Section(
+                        link=get_message_link(
+                            slack_event, workspace=workspace, channel_id=channel["id"]
+                        ),
+                        text=slack_event["text"],
+                    )
+                ],
+                source=matching_doc.source,
+                semantic_identifier=matching_doc.semantic_identifier,
+                metadata=matching_doc.metadata,
+            )
+
+        return Document(
+            id=slack_event["ts"],
+            sections=[
+                Section(
+                    link=get_message_link(
+                        slack_event, workspace=workspace, channel_id=channel["id"]
+                    ),
+                    text=slack_event["text"],
+                )
+            ],
+            source=DocumentSource.SLACK,
+            semantic_identifier=channel["name"],
+            metadata={},
+        )
+
+    return None
+
+
+class SlackConnector(LoadConnector, PollConnector):
+    def __init__(
+        self, export_path_str: str, batch_size: int = INDEX_BATCH_SIZE
+    ) -> None:
+        self.export_path_str = export_path_str
+        self.batch_size = batch_size
+        self.client = get_client()
+
+    def load_from_state(self) -> Generator[list[Document], None, None]:
+        export_path = Path(self.export_path_str)
+
+        with open(export_path / "channels.json") as f:
+            channels = json.load(f)
+
+        document_batch: dict[str, Document] = {}
+        for channel_info in channels:
+            channel_dir_path = export_path / cast(str, channel_info["name"])
+            channel_file_paths = [
+                channel_dir_path / file_name
+                for file_name in os.listdir(channel_dir_path)
+            ]
+            for path in channel_file_paths:
+                with open(path) as f:
+                    events = cast(list[dict[str, Any]], json.load(f))
+                for slack_event in events:
+                    doc = _process_batch_event(
+                        slack_event=slack_event,
+                        channel=channel_info,
+                        matching_doc=document_batch.get(
+                            slack_event.get("thread_ts", "")
+                        ),
+                    )
+                    if doc:
+                        document_batch[doc.id] = doc
+                        if len(document_batch) >= self.batch_size:
+                            yield list(document_batch.values())
+
+        yield list(document_batch.values())

-    def load(
+    def poll_source(
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> Generator[List[Document], None, None]:
         all_docs = get_all_docs(client=self.client, oldest=str(start), latest=str(end))
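
The merged SlackConnector can now be driven both ways. A hypothetical usage sketch, assuming the danswer package is importable, a Slack token is configured for get_client(), and /tmp/slack-export holds an unzipped workspace export (paths and time window are illustrative only):

# Hypothetical usage sketch of the merged SlackConnector.
import time

from danswer.connectors.slack.connector import SlackConnector

connector = SlackConnector(export_path_str="/tmp/slack-export", batch_size=32)

# LoadConnector path: full load from the workspace export on disk.
for doc_batch in connector.load_from_state():
    print(f"loaded {len(doc_batch)} documents from the export")

# PollConnector path: only messages from the last hour, via the Slack API.
now = time.time()
for doc_batch in connector.poll_source(start=now - 3600, end=now):
    print(f"polled {len(doc_batch)} recently updated documents")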

View File

@@ -9,7 +9,7 @@ import requests
 from bs4 import BeautifulSoup
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import PullLoader
+from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger
@@ -51,7 +51,7 @@ def get_internal_links(
     return internal_links


-class WebLoader(PullLoader):
+class WebConnector(LoadConnector):
     def __init__(
         self,
         base_url: str,
@@ -60,7 +60,7 @@ class WebLoader(PullLoader):
         self.base_url = base_url
         self.batch_size = batch_size

-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         """Traverses through all pages found on the website
         and converts them into documents"""
         visited_links: set[str] = set()
@@ -88,8 +88,8 @@ class WebLoader(PullLoader):
                 response = requests.get(current_url)
                 pdf_reader = PdfReader(io.BytesIO(response.content))
                 page_text = ""
-                for page in pdf_reader.pages:
-                    page_text += page.extract_text()
+                for pdf_page in pdf_reader.pages:
+                    page_text += pdf_page.extract_text()
                 doc_batch.append(
                     Document(

View File

@@ -376,7 +376,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         found_answer_start = False
         found_answer_end = False
         for event in response:
-            event_dict = cast(str, event["choices"][0]["delta"])
+            event_dict = event["choices"][0]["delta"]
             if (
                 "content" not in event_dict
             ):  # could be a role message or empty termination

View File

@@ -93,7 +93,11 @@ def retrieve_ranked_documents(
         return None
     ranked_chunks = semantic_reranking(query, top_chunks)

-    top_docs = [ranked_chunk.source_links["0"] for ranked_chunk in ranked_chunks]
+    top_docs = [
+        ranked_chunk.source_links[0]
+        for ranked_chunk in ranked_chunks
+        if ranked_chunk.source_links is not None
+    ]
     files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
     logger.info(files_log_msg)

View File

@@ -87,7 +87,7 @@ def index(
     _: User = Depends(current_admin_user),
 ) -> None:
     # validate that the connector specified by the source / input_type combination
-    # exists AND that the connector_specific_config is valid for that connector type
+    # exists AND that the connector_specific_config is valid for that connector type, should be load
     build_connector(
         source=source,
         input_type=index_attempt_request.input_type,

View File

@@ -1,8 +1,8 @@
 from typing import Any

-from danswer.connectors.slack.pull import get_channel_info
-from danswer.connectors.slack.pull import get_thread
-from danswer.connectors.slack.pull import thread_to_doc
+from danswer.connectors.slack.connector import get_channel_info
+from danswer.connectors.slack.connector import get_thread
+from danswer.connectors.slack.connector import thread_to_doc
 from danswer.connectors.slack.utils import get_client
 from danswer.utils.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logging import setup_logger

View File

@@ -32,7 +32,7 @@ class UserRoleResponse(BaseModel):
 class SearchDoc(BaseModel):
     semantic_identifier: str
-    link: str
+    link: str | None
     blurb: str
     source_type: str

View File

@@ -78,7 +78,7 @@ def direct_qa(
     top_docs = [
         SearchDoc(
             semantic_identifier=chunk.semantic_identifier,
-            link=chunk.source_links.get("0") if chunk.source_links else None,
+            link=chunk.source_links.get(0) if chunk.source_links else None,
             blurb=chunk.blurb,
             source_type=chunk.source_type,
         )
@@ -117,7 +117,7 @@ def stream_direct_qa(
     top_docs = [
         SearchDoc(
             semantic_identifier=chunk.semantic_identifier,
-            link=chunk.source_links.get("0") if chunk.source_links else None,
+            link=chunk.source_links.get(0) if chunk.source_links else None,
             blurb=chunk.blurb,
             source_type=chunk.source_type,
         )
@@ -130,6 +130,8 @@ def stream_direct_qa(
         for response_dict in qa_model.answer_question_stream(
             query, ranked_chunks[:NUM_RERANKED_RESULTS]
         ):
+            if response_dict is None:
+                continue
             logger.debug(response_dict)
             yield get_json_line(response_dict)
         return

View File

@@ -1,6 +1,7 @@
 import time
 from collections.abc import Callable
 from typing import Any
+from typing import cast
 from typing import TypeVar

 from danswer.utils.logging import setup_logger
@@ -21,7 +22,7 @@ def log_function_time(
     ...
     """

-    def timing_wrapper(func: Callable) -> Callable:
+    def timing_wrapper(func: F) -> F:
         def wrapped_func(*args: Any, **kwargs: Any) -> Any:
             start_time = time.time()
             result = func(*args, **kwargs)
@@ -30,6 +31,6 @@ def log_function_time(
             )
             return result

-        return wrapped_func
+        return cast(F, wrapped_func)

     return timing_wrapper
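
The decorator is now typed against a TypeVar F (presumably bound to Callable) and the inner wrapper is cast back to F, so type checkers keep seeing the decorated function's original signature. A self-contained sketch of the same pattern; print stands in for the repo's logger:

# Sketch of the TypeVar + cast pattern applied to log_function_time above.
import time
from collections.abc import Callable
from typing import Any, TypeVar, cast

F = TypeVar("F", bound=Callable)


def log_function_time() -> Callable[[F], F]:
    def timing_wrapper(func: F) -> F:
        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
            start_time = time.time()
            result = func(*args, **kwargs)
            print(f"{func.__name__} took {time.time() - start_time:.3f} seconds")
            return result

        # cast back to F so the decorated function keeps its signature for mypy
        return cast(F, wrapped_func)

    return timing_wrapper


@log_function_time()
def add(a: int, b: int) -> int:
    return a + b


print(add(1, 2))  # type checkers still see (int, int) -> int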

View File

@@ -5,12 +5,12 @@ from danswer.chunking.chunk import Chunker
 from danswer.chunking.chunk import DefaultChunker
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
-from danswer.connectors.github.batch import BatchGithubLoader
-from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
+from danswer.connectors.github.connector import GithubConnector
+from danswer.connectors.google_drive.connector import GoogleDriveConnector
 from danswer.connectors.google_drive.connector_auth import backend_get_credentials
-from danswer.connectors.interfaces import PullLoader
-from danswer.connectors.slack.batch import BatchSlackLoader
-from danswer.connectors.web.pull import WebLoader
+from danswer.connectors.interfaces import LoadConnector
+from danswer.connectors.slack.connector import SlackConnector
+from danswer.connectors.web.connector import WebConnector
 from danswer.datastores.interfaces import Datastore
 from danswer.datastores.qdrant.indexing import recreate_collection
 from danswer.datastores.qdrant.store import QdrantDatastore
@@ -23,14 +23,14 @@ logger = setup_logger()
 def load_batch(
-    doc_loader: PullLoader,
+    doc_loader: LoadConnector,
     chunker: Chunker,
     embedder: Embedder,
     datastore: Datastore,
 ) -> None:
     num_processed = 0
     total_chunks = 0
-    for document_batch in doc_loader.load():
+    for document_batch in doc_loader.load_from_state():
         if not document_batch:
             logger.warning("No parseable documents found in batch")
             continue
@@ -53,7 +53,7 @@ def load_batch(
 def load_slack_batch(file_path: str, qdrant_collection: str) -> None:
     logger.info("Loading documents from Slack.")
     load_batch(
-        BatchSlackLoader(export_path_str=file_path, batch_size=INDEX_BATCH_SIZE),
+        SlackConnector(export_path_str=file_path, batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),
@@ -63,7 +63,7 @@ def load_slack_batch(file_path: str, qdrant_collection: str) -> None:
 def load_web_batch(url: str, qdrant_collection: str) -> None:
     logger.info("Loading documents from web.")
     load_batch(
-        WebLoader(base_url=url, batch_size=INDEX_BATCH_SIZE),
+        WebConnector(base_url=url, batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),
@@ -74,7 +74,7 @@ def load_google_drive_batch(qdrant_collection: str) -> None:
     logger.info("Loading documents from Google Drive.")
     backend_get_credentials()
     load_batch(
-        BatchGoogleDriveLoader(batch_size=INDEX_BATCH_SIZE),
+        GoogleDriveConnector(batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),
@@ -84,9 +84,7 @@ def load_google_drive_batch(qdrant_collection: str) -> None:
 def load_github_batch(owner: str, repo: str, qdrant_collection: str) -> None:
     logger.info("Loading documents from Github.")
     load_batch(
-        BatchGithubLoader(
-            repo_owner=owner, repo_name=repo, batch_size=INDEX_BATCH_SIZE
-        ),
+        GithubConnector(repo_owner=owner, repo_name=repo, batch_size=INDEX_BATCH_SIZE),
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),

View File

@@ -15,7 +15,6 @@ const MainSection = () => {
   //   "/api/admin/connectors/web/index-attempt",
   //   fetcher
   // );
-  const router = useRouter();
   const { mutate } = useSWRConfig();
   const { data, isLoading, error } = useSWR<SlackConfig>(

View File

@@ -2,11 +2,12 @@ import React, { useState } from "react";
 import { Formik, Form, FormikHelpers } from "formik";
 import * as Yup from "yup";
 import { Popup } from "./Popup";
-import { ValidSources } from "@/lib/types";
+import { ValidInputTypes, ValidSources } from "@/lib/types";

 export const submitIndexRequest = async (
   source: ValidSources,
-  values: Yup.AnyObject
+  values: Yup.AnyObject,
+  inputType: ValidInputTypes = "load_state"
 ): Promise<{ message: string; isSuccess: boolean }> => {
   let isSuccess = false;
   try {
@@ -17,7 +18,7 @@ export const submitIndexRequest = async (
       headers: {
         "Content-Type": "application/json",
       },
-      body: JSON.stringify({ connector_specific_config: values }),
+      body: JSON.stringify({ connector_specific_config: values, inputType }),
     }
   );

View File

@@ -8,3 +8,4 @@ export interface User {
 }

 export type ValidSources = "web" | "github" | "slack" | "google_drive";
+export type ValidInputTypes = "load_state" | "poll" | "event";