diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py
index aee35f661..6ba1f1001 100755
--- a/backend/danswer/background/update.py
+++ b/backend/danswer/background/update.py
@@ -2,12 +2,10 @@ import time
 from typing import cast
 
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.factory import build_connector
-from danswer.connectors.factory import build_pull_connector
+from danswer.connectors.factory import build_load_connector
 from danswer.connectors.models import InputType
 from danswer.connectors.slack.config import get_pull_frequency
-from danswer.connectors.slack.pull import PeriodicSlackLoader
-from danswer.connectors.web.pull import WebLoader
+from danswer.connectors.slack.connector import SlackConnector
 from danswer.db.index_attempt import fetch_index_attempts
 from danswer.db.index_attempt import insert_index_attempt
 from danswer.db.index_attempt import update_index_attempt
@@ -20,7 +18,7 @@ from danswer.utils.logging import setup_logger
 
 logger = setup_logger()
 
-LAST_PULL_KEY_TEMPLATE = "last_pull_{}"
+LAST_POLL_KEY_TEMPLATE = "last_poll_{}"
 
 
 def _check_should_run(current_time: int, last_pull: int, pull_frequency: int) -> bool:
@@ -43,9 +41,7 @@ def run_update() -> None:
         except ConfigNotFoundError:
             pull_frequency = 0
         if pull_frequency:
-            last_slack_pull_key = LAST_PULL_KEY_TEMPLATE.format(
-                PeriodicSlackLoader.__name__
-            )
+            last_slack_pull_key = LAST_POLL_KEY_TEMPLATE.format(SlackConnector.__name__)
             try:
                 last_pull = cast(int, dynamic_config_store.load(last_slack_pull_key))
             except ConfigNotFoundError:
@@ -61,7 +57,7 @@ def run_update() -> None:
                 insert_index_attempt(
                     IndexAttempt(
                         source=DocumentSource.SLACK,
-                        input_type=InputType.PULL,
+                        input_type=InputType.POLL,
                         status=IndexingStatus.NOT_STARTED,
                         connector_specific_config={},
                     )
@@ -75,7 +71,7 @@ def run_update() -> None:
     # prevent race conditions across multiple background jobs. For now,
     # this assumes we only ever run a single background job at a time
     not_started_index_attempts = fetch_index_attempts(
-        input_types=[InputType.PULL], statuses=[IndexingStatus.NOT_STARTED]
+        input_types=[InputType.LOAD_STATE], statuses=[IndexingStatus.NOT_STARTED]
     )
     for not_started_index_attempt in not_started_index_attempts:
         logger.info(
@@ -94,13 +90,13 @@ def run_update() -> None:
         try:
             # TODO (chris): spawn processes to parallelize / take advantage of
             # multiple cores + implement retries
-            connector = build_pull_connector(
+            connector = build_load_connector(
                 source=not_started_index_attempt.source,
                 connector_specific_config=not_started_index_attempt.connector_specific_config,
             )
 
             document_ids: list[str] = []
-            for doc_batch in connector.load():
+            for doc_batch in connector.load_from_state():
                 indexing_pipeline(doc_batch)
                 document_ids.extend([doc.id for doc in doc_batch])
         except Exception as e:
diff --git a/backend/danswer/chunking/models.py b/backend/danswer/chunking/models.py
index 061f7b7d6..b2b6591e3 100644
--- a/backend/danswer/chunking/models.py
+++ b/backend/danswer/chunking/models.py
@@ -1,6 +1,7 @@
 import inspect
 from dataclasses import dataclass
 from typing import Any
+from typing import cast
 
 from danswer.connectors.models import Document
 
@@ -34,10 +35,12 @@ class InferenceChunk(BaseChunk):
 
     @classmethod
     def from_dict(cls, init_dict: dict[str, Any]) -> "InferenceChunk":
-        return cls(
-            **{
-                k: v
-                for k, v in init_dict.items()
-                if k in inspect.signature(cls).parameters
+        init_kwargs = {
+            k: v for k, v in init_dict.items() if k in inspect.signature(cls).parameters
+        }
+        if "source_links" in init_kwargs:
+            init_kwargs["source_links"] = {
+                int(k): v
+                for k, v in cast(dict[str, str], init_kwargs["source_links"]).items()
             }
-        )
+        return cls(**init_kwargs)
diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py
index 2f6ebcb9a..d4a5b37b8 100644
--- a/backend/danswer/connectors/factory.py
+++ b/backend/danswer/connectors/factory.py
@@ -3,15 +3,16 @@ from collections.abc import Generator
 from typing import Any
 
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.github.batch import BatchGithubLoader
-from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
-from danswer.connectors.interfaces import PullLoader
-from danswer.connectors.interfaces import RangePullLoader
+from danswer.connectors.github.connector import GithubConnector
+from danswer.connectors.google_drive.connector import GoogleDriveConnector
+from danswer.connectors.interfaces import BaseConnector
+from danswer.connectors.interfaces import EventConnector
+from danswer.connectors.interfaces import LoadConnector
+from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import InputType
-from danswer.connectors.slack.batch import BatchSlackLoader
-from danswer.connectors.slack.pull import PeriodicSlackLoader
-from danswer.connectors.web.pull import WebLoader
+from danswer.connectors.slack.connector import SlackConnector
+from danswer.connectors.web.connector import WebConnector
 
 _NUM_SECONDS_IN_DAY = 86400
 
@@ -24,45 +25,52 @@ def build_connector(
     source: DocumentSource,
     input_type: InputType,
     connector_specific_config: dict[str, Any],
-) -> PullLoader | RangePullLoader:
+) -> BaseConnector:
     if source == DocumentSource.SLACK:
-        if input_type == InputType.PULL:
-            return PeriodicSlackLoader(**connector_specific_config)
-        if input_type == InputType.LOAD_STATE:
-            return BatchSlackLoader(**connector_specific_config)
+        connector: BaseConnector = SlackConnector(**connector_specific_config)
     elif source == DocumentSource.GOOGLE_DRIVE:
-        if input_type == InputType.PULL:
-            return BatchGoogleDriveLoader(**connector_specific_config)
+        connector = GoogleDriveConnector(**connector_specific_config)
     elif source == DocumentSource.GITHUB:
-        if input_type == InputType.PULL:
-            return BatchGithubLoader(**connector_specific_config)
+        connector = GithubConnector(**connector_specific_config)
     elif source == DocumentSource.WEB:
-        if input_type == InputType.PULL:
-            return WebLoader(**connector_specific_config)
+        connector = WebConnector(**connector_specific_config)
+    else:
+        raise ConnectorMissingException(f"Connector not found for source={source}")
 
-    raise ConnectorMissingException(
-        f"Connector not found for source={source}, input_type={input_type}"
-    )
+    if any(
+        [
+            input_type == InputType.LOAD_STATE
+            and not isinstance(connector, LoadConnector),
+            input_type == InputType.POLL and not isinstance(connector, PollConnector),
+            input_type == InputType.EVENT and not isinstance(connector, EventConnector),
+        ]
+    ):
+        raise ConnectorMissingException(
+            f"Connector for source={source} does not accept input_type={input_type}"
+        )
+
+    return connector
 
 
-def build_pull_connector(
-    source: DocumentSource, connector_specific_config: dict[str, Any]
-) -> PullLoader:
-    connector = build_connector(source, InputType.PULL, connector_specific_config)
-    return (
-        _range_pull_to_pull(connector)
-        if isinstance(connector, RangePullLoader)
-        else connector
-    )
-
-
-def _range_pull_to_pull(range_pull_connector: RangePullLoader) -> PullLoader:
-    class _Connector(PullLoader):
+# TODO this is some jank, rework at some point
+def _poll_to_load_connector(range_pull_connector: PollConnector) -> LoadConnector:
+    class _Connector(LoadConnector):
         def __init__(self) -> None:
             self._connector = range_pull_connector
 
-        def load(self) -> Generator[list[Document], None, None]:
+        def load_from_state(self) -> Generator[list[Document], None, None]:
             # adding some buffer to make sure we get all documents
-            return self._connector.load(0, time.time() + _NUM_SECONDS_IN_DAY)
+            return self._connector.poll_source(0, time.time() + _NUM_SECONDS_IN_DAY)
 
     return _Connector()
+
+
+# TODO this is some jank, rework at some point
+def build_load_connector(
+    source: DocumentSource, connector_specific_config: dict[str, Any]
+) -> LoadConnector:
+    connector = build_connector(source, InputType.LOAD_STATE, connector_specific_config)
+    if isinstance(connector, PollConnector):
+        return _poll_to_load_connector(connector)
+    assert isinstance(connector, LoadConnector)
+    return connector
diff --git a/backend/danswer/connectors/github/batch.py b/backend/danswer/connectors/github/connector.py
similarity index 93%
rename from backend/danswer/connectors/github/batch.py
rename to backend/danswer/connectors/github/connector.py
index 6d34ffdb9..ff2b81976 100644
--- a/backend/danswer/connectors/github/batch.py
+++ b/backend/danswer/connectors/github/connector.py
@@ -4,7 +4,7 @@ from collections.abc import Generator
 from danswer.configs.app_configs import GITHUB_ACCESS_TOKEN
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import PullLoader
+from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger
@@ -29,7 +29,7 @@ def get_pr_batches(
         yield batch
 
 
-class BatchGithubLoader(PullLoader):
+class GithubConnector(LoadConnector):
     def __init__(
         self,
         repo_owner: str,
@@ -42,7 +42,7 @@ class BatchGithubLoader(PullLoader):
         self.batch_size = batch_size
         self.state_filter = state_filter
 
-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         repo = github_client.get_repo(f"{self.repo_owner}/{self.repo_name}")
         pull_requests = repo.get_pulls(state=self.state_filter)
         for pr_batch in get_pr_batches(pull_requests, self.batch_size):
diff --git a/backend/danswer/connectors/google_drive/batch.py b/backend/danswer/connectors/google_drive/connector.py
similarity index 94%
rename from backend/danswer/connectors/google_drive/batch.py
rename to backend/danswer/connectors/google_drive/connector.py
index e9875f9b3..169e366a3 100644
--- a/backend/danswer/connectors/google_drive/batch.py
+++ b/backend/danswer/connectors/google_drive/connector.py
@@ -5,7 +5,7 @@ from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.google_drive.connector_auth import get_drive_tokens
-from danswer.connectors.interfaces import PullLoader
+from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger
@@ -81,11 +81,7 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
         return "\n".join(page.extract_text() for page in pdf_reader.pages)
 
 
-class BatchGoogleDriveLoader(PullLoader):
-    """
-    Loads everything in a Google Drive account
-    """
-
+class GoogleDriveConnector(LoadConnector):
     def __init__(
         self,
         batch_size: int = INDEX_BATCH_SIZE,
@@ -98,7 +94,7 @@ class BatchGoogleDriveLoader(PullLoader):
         if not self.creds:
             raise PermissionError("Unable to access Google Drive.")
 
-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         service = discovery.build("drive", "v3", credentials=self.creds)
         for files_batch in get_file_batches(
             service, self.include_shared, self.batch_size
diff --git a/backend/danswer/connectors/interfaces.py b/backend/danswer/connectors/interfaces.py
index da01d99a0..98c199898 100644
--- a/backend/danswer/connectors/interfaces.py
+++ b/backend/danswer/connectors/interfaces.py
@@ -8,22 +8,29 @@ from danswer.connectors.models import Document
 SecondsSinceUnixEpoch = float
 
 
-# TODO (chris): rename from Loader -> Connector
-class PullLoader:
+class BaseConnector(abc.ABC):
+    # Reserved for future shared uses
+    pass
+
+
+# Large set update or reindex, generally pulling a complete state or from a savestate file
+class LoadConnector(BaseConnector):
     @abc.abstractmethod
-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         raise NotImplementedError
 
 
-class RangePullLoader:
+# Small set updates by time
+class PollConnector(BaseConnector):
     @abc.abstractmethod
-    def load(
+    def poll_source(
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> Generator[list[Document], None, None]:
         raise NotImplementedError
 
 
-class PushLoader:
+# Event driven
+class EventConnector(BaseConnector):
     @abc.abstractmethod
-    def load(self, event: Any) -> Generator[list[Document], None, None]:
+    def handle_event(self, event: Any) -> Generator[list[Document], None, None]:
         raise NotImplementedError
diff --git a/backend/danswer/connectors/models.py b/backend/danswer/connectors/models.py
index 930418891..f4dcf0053 100644
--- a/backend/danswer/connectors/models.py
+++ b/backend/danswer/connectors/models.py
@@ -26,9 +26,9 @@ def get_raw_document_text(document: Document) -> str:
 
 
 class InputType(str, Enum):
-    PULL = "pull"  # e.g. calling slack API to get all messages in the last hour
-    LOAD_STATE = "load_state"  # e.g. loading the state of a slack workspace from a file
-    EVENT = "event"  # e.g. registered an endpoint as a listener, and processing slack events
+    LOAD_STATE = "load_state"  # e.g. loading a current full state or a save state, such as from a file
+    POLL = "poll"  # e.g. calling an API to get all documents in the last hour
+    EVENT = "event"  # e.g. registered an endpoint as a listener, and processing connector events
 
 
 class ConnectorDescriptor(BaseModel):
diff --git a/backend/danswer/connectors/slack/batch.py b/backend/danswer/connectors/slack/batch.py
deleted file mode 100644
index 90dff7807..000000000
--- a/backend/danswer/connectors/slack/batch.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import json
-import os
-from collections.abc import Generator
-from pathlib import Path
-from typing import Any
-from typing import cast
-
-from danswer.configs.app_configs import INDEX_BATCH_SIZE
-from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import PullLoader
-from danswer.connectors.models import Document
-from danswer.connectors.models import Section
-from danswer.connectors.slack.utils import get_message_link
-
-
-def _process_batch_event(
-    slack_event: dict[str, Any],
-    channel: dict[str, Any],
-    matching_doc: Document | None,
-    workspace: str | None = None,
-) -> Document | None:
-    if (
-        slack_event["type"] == "message"
-        and slack_event.get("subtype") != "channel_join"
-    ):
-        if matching_doc:
-            return Document(
-                id=matching_doc.id,
-                sections=matching_doc.sections
-                + [
-                    Section(
-                        link=get_message_link(
-                            slack_event, workspace=workspace, channel_id=channel["id"]
-                        ),
-                        text=slack_event["text"],
-                    )
-                ],
-                source=matching_doc.source,
-                semantic_identifier=matching_doc.semantic_identifier,
-                metadata=matching_doc.metadata,
-            )
-
-        return Document(
-            id=slack_event["ts"],
-            sections=[
-                Section(
-                    link=get_message_link(
-                        slack_event, workspace=workspace, channel_id=channel["id"]
-                    ),
-                    text=slack_event["text"],
-                )
-            ],
-            source=DocumentSource.SLACK,
-            semantic_identifier=channel["name"],
-            metadata={},
-        )
-
-    return None
-
-
-class BatchSlackLoader(PullLoader):
-    """Loads from an unzipped slack workspace export"""
-
-    def __init__(
-        self, export_path_str: str, batch_size: int = INDEX_BATCH_SIZE
-    ) -> None:
-        self.export_path_str = export_path_str
-        self.batch_size = batch_size
-
-    def load(self) -> Generator[list[Document], None, None]:
-        export_path = Path(self.export_path_str)
-
-        with open(export_path / "channels.json") as f:
-            channels = json.load(f)
-
-        document_batch: dict[str, Document] = {}
-        for channel_info in channels:
-            channel_dir_path = export_path / cast(str, channel_info["name"])
-            channel_file_paths = [
-                channel_dir_path / file_name
-                for file_name in os.listdir(channel_dir_path)
-            ]
-            for path in channel_file_paths:
-                with open(path) as f:
-                    events = cast(list[dict[str, Any]], json.load(f))
-                for slack_event in events:
-                    doc = _process_batch_event(
-                        slack_event=slack_event,
-                        channel=channel_info,
-                        matching_doc=document_batch.get(
-                            slack_event.get("thread_ts", "")
"") - ), - ) - if doc: - document_batch[doc.id] = doc - if len(document_batch) >= self.batch_size: - yield list(document_batch.values()) - - yield list(document_batch.values()) diff --git a/backend/danswer/connectors/slack/pull.py b/backend/danswer/connectors/slack/connector.py similarity index 70% rename from backend/danswer/connectors/slack/pull.py rename to backend/danswer/connectors/slack/connector.py index 8186bfe12..737d470cb 100644 --- a/backend/danswer/connectors/slack/pull.py +++ b/backend/danswer/connectors/slack/connector.py @@ -1,13 +1,17 @@ +import json +import os import time from collections.abc import Callable from collections.abc import Generator +from pathlib import Path from typing import Any from typing import cast from typing import List from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource -from danswer.connectors.interfaces import RangePullLoader +from danswer.connectors.interfaces import LoadConnector +from danswer.connectors.interfaces import PollConnector from danswer.connectors.interfaces import SecondsSinceUnixEpoch from danswer.connectors.models import Document from danswer.connectors.models import Section @@ -200,12 +204,91 @@ def get_all_docs( return docs -class PeriodicSlackLoader(RangePullLoader): - def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None: - self.client = get_client() - self.batch_size = batch_size +def _process_batch_event( + slack_event: dict[str, Any], + channel: dict[str, Any], + matching_doc: Document | None, + workspace: str | None = None, +) -> Document | None: + if ( + slack_event["type"] == "message" + and slack_event.get("subtype") != "channel_join" + ): + if matching_doc: + return Document( + id=matching_doc.id, + sections=matching_doc.sections + + [ + Section( + link=get_message_link( + slack_event, workspace=workspace, channel_id=channel["id"] + ), + text=slack_event["text"], + ) + ], + source=matching_doc.source, + semantic_identifier=matching_doc.semantic_identifier, + metadata=matching_doc.metadata, + ) - def load( + return Document( + id=slack_event["ts"], + sections=[ + Section( + link=get_message_link( + slack_event, workspace=workspace, channel_id=channel["id"] + ), + text=slack_event["text"], + ) + ], + source=DocumentSource.SLACK, + semantic_identifier=channel["name"], + metadata={}, + ) + + return None + + +class SlackConnector(LoadConnector, PollConnector): + def __init__( + self, export_path_str: str, batch_size: int = INDEX_BATCH_SIZE + ) -> None: + self.export_path_str = export_path_str + self.batch_size = batch_size + self.client = get_client() + + def load_from_state(self) -> Generator[list[Document], None, None]: + export_path = Path(self.export_path_str) + + with open(export_path / "channels.json") as f: + channels = json.load(f) + + document_batch: dict[str, Document] = {} + for channel_info in channels: + channel_dir_path = export_path / cast(str, channel_info["name"]) + channel_file_paths = [ + channel_dir_path / file_name + for file_name in os.listdir(channel_dir_path) + ] + for path in channel_file_paths: + with open(path) as f: + events = cast(list[dict[str, Any]], json.load(f)) + for slack_event in events: + doc = _process_batch_event( + slack_event=slack_event, + channel=channel_info, + matching_doc=document_batch.get( + slack_event.get("thread_ts", "") + ), + ) + if doc: + document_batch[doc.id] = doc + if len(document_batch) >= self.batch_size: + yield list(document_batch.values()) + + yield list(document_batch.values()) + + def 
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> Generator[List[Document], None, None]:
         all_docs = get_all_docs(client=self.client, oldest=str(start), latest=str(end))
diff --git a/backend/danswer/connectors/web/pull.py b/backend/danswer/connectors/web/connector.py
similarity index 95%
rename from backend/danswer/connectors/web/pull.py
rename to backend/danswer/connectors/web/connector.py
index aa22b4395..df7d98fe9 100644
--- a/backend/danswer/connectors/web/pull.py
+++ b/backend/danswer/connectors/web/connector.py
@@ -9,7 +9,7 @@ import requests
 from bs4 import BeautifulSoup
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.interfaces import PullLoader
+from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logging import setup_logger
@@ -51,7 +51,7 @@ def get_internal_links(
     return internal_links
 
 
-class WebLoader(PullLoader):
+class WebConnector(LoadConnector):
     def __init__(
         self,
         base_url: str,
@@ -60,7 +60,7 @@ class WebLoader(PullLoader):
         self.base_url = base_url
         self.batch_size = batch_size
 
-    def load(self) -> Generator[list[Document], None, None]:
+    def load_from_state(self) -> Generator[list[Document], None, None]:
         """Traverses through all pages found on the website
         and converts them into documents"""
         visited_links: set[str] = set()
@@ -88,8 +88,8 @@ class WebLoader(PullLoader):
                 response = requests.get(current_url)
                 pdf_reader = PdfReader(io.BytesIO(response.content))
                 page_text = ""
-                for page in pdf_reader.pages:
-                    page_text += page.extract_text()
+                for pdf_page in pdf_reader.pages:
+                    page_text += pdf_page.extract_text()
 
                 doc_batch.append(
                     Document(
diff --git a/backend/danswer/direct_qa/question_answer.py b/backend/danswer/direct_qa/question_answer.py
index 5574cf947..cfd8457e3 100644
--- a/backend/danswer/direct_qa/question_answer.py
+++ b/backend/danswer/direct_qa/question_answer.py
@@ -376,7 +376,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         found_answer_start = False
         found_answer_end = False
         for event in response:
-            event_dict = cast(str, event["choices"][0]["delta"])
+            event_dict = event["choices"][0]["delta"]
             if (
                 "content" not in event_dict
             ):  # could be a role message or empty termination
diff --git a/backend/danswer/semantic_search/semantic_search.py b/backend/danswer/semantic_search/semantic_search.py
index ea7f8c5ce..8a871e4d5 100644
--- a/backend/danswer/semantic_search/semantic_search.py
+++ b/backend/danswer/semantic_search/semantic_search.py
@@ -93,7 +93,11 @@ def retrieve_ranked_documents(
         return None
     ranked_chunks = semantic_reranking(query, top_chunks)
 
-    top_docs = [ranked_chunk.source_links["0"] for ranked_chunk in ranked_chunks]
+    top_docs = [
+        ranked_chunk.source_links[0]
+        for ranked_chunk in ranked_chunks
+        if ranked_chunk.source_links is not None
+    ]
     files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
     logger.info(files_log_msg)
 
diff --git a/backend/danswer/server/admin.py b/backend/danswer/server/admin.py
index 9fe57f65f..f8b0e2154 100644
--- a/backend/danswer/server/admin.py
+++ b/backend/danswer/server/admin.py
@@ -87,7 +87,7 @@ def index(
     _: User = Depends(current_admin_user),
 ) -> None:
     # validate that the connector specified by the source / input_type combination
-    # exists AND that the connector_specific_config is valid for that connector type
+    # exists AND that the connector_specific_config is valid for that connector type, should be load
     build_connector(
         source=source,
         input_type=index_attempt_request.input_type,
diff --git a/backend/danswer/server/event_loading.py b/backend/danswer/server/event_loading.py
index 4af6e23d4..ecc8477ea 100644
--- a/backend/danswer/server/event_loading.py
+++ b/backend/danswer/server/event_loading.py
@@ -1,8 +1,8 @@
 from typing import Any
 
-from danswer.connectors.slack.pull import get_channel_info
-from danswer.connectors.slack.pull import get_thread
-from danswer.connectors.slack.pull import thread_to_doc
+from danswer.connectors.slack.connector import get_channel_info
+from danswer.connectors.slack.connector import get_thread
+from danswer.connectors.slack.connector import thread_to_doc
 from danswer.connectors.slack.utils import get_client
 from danswer.utils.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logging import setup_logger
diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py
index 92ce9df3e..14abf4f19 100644
--- a/backend/danswer/server/models.py
+++ b/backend/danswer/server/models.py
@@ -32,7 +32,7 @@ class UserRoleResponse(BaseModel):
 
 
 class SearchDoc(BaseModel):
     semantic_identifier: str
-    link: str
+    link: str | None
     blurb: str
diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py
index c4aaf3a7c..da89a8d3c 100644
--- a/backend/danswer/server/search_backend.py
+++ b/backend/danswer/server/search_backend.py
@@ -78,7 +78,7 @@ def direct_qa(
     top_docs = [
         SearchDoc(
             semantic_identifier=chunk.semantic_identifier,
-            link=chunk.source_links.get("0") if chunk.source_links else None,
+            link=chunk.source_links.get(0) if chunk.source_links else None,
             blurb=chunk.blurb,
             source_type=chunk.source_type,
         )
@@ -117,7 +117,7 @@ def stream_direct_qa(
     top_docs = [
         SearchDoc(
             semantic_identifier=chunk.semantic_identifier,
-            link=chunk.source_links.get("0") if chunk.source_links else None,
+            link=chunk.source_links.get(0) if chunk.source_links else None,
             blurb=chunk.blurb,
             source_type=chunk.source_type,
         )
@@ -130,6 +130,8 @@ def stream_direct_qa(
         for response_dict in qa_model.answer_question_stream(
             query, ranked_chunks[:NUM_RERANKED_RESULTS]
         ):
+            if response_dict is None:
+                continue
             logger.debug(response_dict)
             yield get_json_line(response_dict)
         return
diff --git a/backend/danswer/utils/timing.py b/backend/danswer/utils/timing.py
index 7cc4cc467..0c3ae52e6 100644
--- a/backend/danswer/utils/timing.py
+++ b/backend/danswer/utils/timing.py
@@ -1,6 +1,7 @@
 import time
 from collections.abc import Callable
 from typing import Any
+from typing import cast
 from typing import TypeVar
 
 from danswer.utils.logging import setup_logger
@@ -21,7 +22,7 @@ def log_function_time(
         ...
""" - def timing_wrapper(func: Callable) -> Callable: + def timing_wrapper(func: F) -> F: def wrapped_func(*args: Any, **kwargs: Any) -> Any: start_time = time.time() result = func(*args, **kwargs) @@ -30,6 +31,6 @@ def log_function_time( ) return result - return wrapped_func + return cast(F, wrapped_func) return timing_wrapper diff --git a/backend/scripts/ingestion.py b/backend/scripts/ingestion.py index 6de7ce830..3a0d06e09 100644 --- a/backend/scripts/ingestion.py +++ b/backend/scripts/ingestion.py @@ -5,12 +5,12 @@ from danswer.chunking.chunk import Chunker from danswer.chunking.chunk import DefaultChunker from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION -from danswer.connectors.github.batch import BatchGithubLoader -from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader +from danswer.connectors.github.connector import GithubConnector +from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.google_drive.connector_auth import backend_get_credentials -from danswer.connectors.interfaces import PullLoader -from danswer.connectors.slack.batch import BatchSlackLoader -from danswer.connectors.web.pull import WebLoader +from danswer.connectors.interfaces import LoadConnector +from danswer.connectors.slack.connector import SlackConnector +from danswer.connectors.web.connector import WebConnector from danswer.datastores.interfaces import Datastore from danswer.datastores.qdrant.indexing import recreate_collection from danswer.datastores.qdrant.store import QdrantDatastore @@ -23,14 +23,14 @@ logger = setup_logger() def load_batch( - doc_loader: PullLoader, + doc_loader: LoadConnector, chunker: Chunker, embedder: Embedder, datastore: Datastore, ) -> None: num_processed = 0 total_chunks = 0 - for document_batch in doc_loader.load(): + for document_batch in doc_loader.load_from_state(): if not document_batch: logger.warning("No parseable documents found in batch") continue @@ -53,7 +53,7 @@ def load_batch( def load_slack_batch(file_path: str, qdrant_collection: str) -> None: logger.info("Loading documents from Slack.") load_batch( - BatchSlackLoader(export_path_str=file_path, batch_size=INDEX_BATCH_SIZE), + SlackConnector(export_path_str=file_path, batch_size=INDEX_BATCH_SIZE), DefaultChunker(), DefaultEmbedder(), QdrantDatastore(collection=qdrant_collection), @@ -63,7 +63,7 @@ def load_slack_batch(file_path: str, qdrant_collection: str) -> None: def load_web_batch(url: str, qdrant_collection: str) -> None: logger.info("Loading documents from web.") load_batch( - WebLoader(base_url=url, batch_size=INDEX_BATCH_SIZE), + WebConnector(base_url=url, batch_size=INDEX_BATCH_SIZE), DefaultChunker(), DefaultEmbedder(), QdrantDatastore(collection=qdrant_collection), @@ -74,7 +74,7 @@ def load_google_drive_batch(qdrant_collection: str) -> None: logger.info("Loading documents from Google Drive.") backend_get_credentials() load_batch( - BatchGoogleDriveLoader(batch_size=INDEX_BATCH_SIZE), + GoogleDriveConnector(batch_size=INDEX_BATCH_SIZE), DefaultChunker(), DefaultEmbedder(), QdrantDatastore(collection=qdrant_collection), @@ -84,9 +84,7 @@ def load_google_drive_batch(qdrant_collection: str) -> None: def load_github_batch(owner: str, repo: str, qdrant_collection: str) -> None: logger.info("Loading documents from Github.") load_batch( - BatchGithubLoader( - repo_owner=owner, repo_name=repo, batch_size=INDEX_BATCH_SIZE - ), + GithubConnector(repo_owner=owner, repo_name=repo, 
         DefaultChunker(),
         DefaultEmbedder(),
         QdrantDatastore(collection=qdrant_collection),
     )
diff --git a/web/src/app/admin/connectors/slack/page.tsx b/web/src/app/admin/connectors/slack/page.tsx
index 2040cd4a3..55511028d 100644
--- a/web/src/app/admin/connectors/slack/page.tsx
+++ b/web/src/app/admin/connectors/slack/page.tsx
@@ -15,7 +15,6 @@ const MainSection = () => {
 
   //   "/api/admin/connectors/web/index-attempt",
   //   fetcher
   // );
-  const router = useRouter();
   const { mutate } = useSWRConfig();
   const { data, isLoading, error } = useSWR(
diff --git a/web/src/components/admin/connectors/Form.tsx b/web/src/components/admin/connectors/Form.tsx
index ba0a679a1..f2bd870b6 100644
--- a/web/src/components/admin/connectors/Form.tsx
+++ b/web/src/components/admin/connectors/Form.tsx
@@ -2,11 +2,12 @@ import React, { useState } from "react";
 import { Formik, Form, FormikHelpers } from "formik";
 import * as Yup from "yup";
 import { Popup } from "./Popup";
-import { ValidSources } from "@/lib/types";
+import { ValidInputTypes, ValidSources } from "@/lib/types";
 
 export const submitIndexRequest = async (
   source: ValidSources,
-  values: Yup.AnyObject
+  values: Yup.AnyObject,
+  inputType: ValidInputTypes = "load_state"
 ): Promise<{ message: string; isSuccess: boolean }> => {
   let isSuccess = false;
   try {
@@ -17,7 +18,7 @@ export const submitIndexRequest = async (
       headers: {
         "Content-Type": "application/json",
       },
-      body: JSON.stringify({ connector_specific_config: values }),
+      body: JSON.stringify({ connector_specific_config: values, input_type: inputType }),
     }
   );
 
diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts
index 2e47f696c..cfc0e9317 100644
--- a/web/src/lib/types.ts
+++ b/web/src/lib/types.ts
@@ -8,3 +8,4 @@ export interface User {
 }
 
 export type ValidSources = "web" | "github" | "slack" | "google_drive";
+export type ValidInputTypes = "load_state" | "poll" | "event";
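
For reference, a minimal sketch of a connector written against the interfaces introduced above. Only the imported names and method signatures come from this change; ToyConnector, its constructor, and its in-memory document list are hypothetical, and a real connector would fetch from its source and honor the poll window.

# Hypothetical connector -- illustrates the new interfaces, not code from this commit.
from collections.abc import Generator

from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document


class ToyConnector(LoadConnector, PollConnector):
    """Subclasses both interfaces, so build_connector would accept this
    connector for InputType.LOAD_STATE as well as InputType.POLL."""

    def __init__(self, docs: list[Document]) -> None:
        self._docs = docs

    def load_from_state(self) -> Generator[list[Document], None, None]:
        # Full reindex: emit the complete known state as one batch.
        yield self._docs

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> Generator[list[Document], None, None]:
        # Incremental update: a real connector would return only documents
        # modified within the [start, end] window.
        yield self._docs

Because ToyConnector implements PollConnector as well, build_load_connector would use it directly via its LoadConnector side; for poll-only connectors, the _poll_to_load_connector shim above adapts poll_source(0, now + 1 day) into a load_from_state call.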