Mirror of https://github.com/danswer-ai/danswer.git

Add recursive Notion search

commit 4912beb283
parent db024ad7b7
@@ -133,6 +133,11 @@ WEB_CONNECTOR_OAUTH_CLIENT_ID = os.environ.get("WEB_CONNECTOR_OAUTH_CLIENT_ID")
 WEB_CONNECTOR_OAUTH_CLIENT_SECRET = os.environ.get("WEB_CONNECTOR_OAUTH_CLIENT_SECRET")
 WEB_CONNECTOR_OAUTH_TOKEN_URL = os.environ.get("WEB_CONNECTOR_OAUTH_TOKEN_URL")
 
+NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = (
+    os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
+    == "true"
+)
+
 #####
 # Query Configs
 #####
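A standalone sketch of the parsing rule added above (not part of the commit): only the literal string "true", in any casing, enables the flag; values such as "1" or "yes" leave it disabled.

import os

# check how the flag parses for a few representative values
for raw in ("true", "True", "TRUE", "1", "yes", ""):
    os.environ["NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP"] = raw
    enabled = (
        os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
        == "true"
    )
    print(f"{raw!r:>8} -> {enabled}")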
@@ -9,6 +9,7 @@ import requests
 from retry import retry
 
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
+from danswer.configs.app_configs import NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.interfaces import GenerateDocumentsOutput
 from danswer.connectors.interfaces import LoadConnector
@@ -64,13 +65,26 @@ class NotionConnector(LoadConnector, PollConnector):
         batch_size (int): Number of objects to index in a batch
     """
 
-    def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
+    def __init__(
+        self,
+        batch_size: int = INDEX_BATCH_SIZE,
+        recursive_index_enabled: bool = NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP,
+    ) -> None:
         """Initialize with parameters."""
         self.batch_size = batch_size
         self.headers = {
             "Content-Type": "application/json",
             "Notion-Version": "2022-06-28",
         }
+        self.indexed_pages: set[str] = set()
+        # if enabled, will recursively index child pages as they are found rather
+        # than relying entirely on the `search` API. We have received reports that
+        # the `search` API misses many pages - in those cases, this might need to
+        # be turned on. It's not currently known why/when this is required.
+        # NOTE: this also removes all benefits of polling, since we need to traverse
+        # all pages regardless of whether they are updated. If the Notion workspace
+        # is very large, this may not be practical.
+        self.recursive_index_enabled = recursive_index_enabled
 
     @retry(tries=3, delay=1, backoff=2)
     def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, Any]:
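A minimal usage sketch under the new signature (not part of the diff; the import path is an assumption based on the surrounding repository layout, and credential handling is done elsewhere in the connector): the flag can also be passed explicitly instead of coming from the environment variable.

# assumed module path for the connector shown in this diff
from danswer.connectors.notion.connector import NotionConnector

# opt in to recursive child-page traversal explicitly
connector = NotionConnector(batch_size=16, recursive_index_enabled=True)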
@@ -86,14 +100,31 @@ class NotionConnector(LoadConnector, PollConnector):
             raise e
         return res.json()
 
-    def _read_blocks(self, page_block_id: str) -> list[tuple[str, str]]:
+    @retry(tries=3, delay=1, backoff=2)
+    def _fetch_page(self, page_id: str) -> NotionPage:
+        """Fetch a page from its ID via the Notion API."""
+        logger.debug(f"Fetching page for ID '{page_id}'")
+        block_url = f"https://api.notion.com/v1/pages/{page_id}"
+        res = requests.get(block_url, headers=self.headers)
+        try:
+            res.raise_for_status()
+        except Exception as e:
+            logger.exception(f"Error fetching page - {res.json()}")
+            raise e
+        return NotionPage(**res.json())
+
+    def _read_blocks(
+        self, page_block_id: str
+    ) -> tuple[list[tuple[str, str]], list[str]]:
         """Reads blocks for a page"""
         result_lines: list[tuple[str, str]] = []
+        child_pages: list[str] = []
         cursor = None
         while True:
             data = self._fetch_blocks(page_block_id, cursor)
 
             for result in data["results"]:
+                result_block_id = result["id"]
                 result_type = result["type"]
                 result_obj = result[result_type]
 
@@ -105,7 +136,9 @@ class NotionConnector(LoadConnector, PollConnector):
                         text = rich_text["text"]["content"]
                         cur_result_text_arr.append(text)
 
-                result_block_id = result["id"]
+                if result["has_children"] and result_type == "child_page":
+                    child_pages.append(result_block_id)
+
                 cur_result_text = "\n".join(cur_result_text_arr)
                 if cur_result_text:
                     result_lines.append((cur_result_text, result_block_id))
@@ -115,7 +148,7 @@ class NotionConnector(LoadConnector, PollConnector):
 
             cursor = data["next_cursor"]
 
-        return result_lines
+        return result_lines, child_pages
 
     def _read_page_title(self, page: NotionPage) -> str:
         """Extracts the title from a Notion page"""
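Since the return type changes from a list to a (lines, child page IDs) tuple, existing callers must unpack two values. A self-contained sketch of the pattern, using a stub with the same return shape as the updated _read_blocks (the stub and its values are placeholders, not part of the commit):

def read_blocks_stub(page_block_id: str) -> tuple[list[tuple[str, str]], list[str]]:
    # stub mimicking the new return shape: (text blocks, child page IDs)
    return [("some text", "block-1")], ["child-page-1"]


result_lines, child_pages = read_blocks_stub("page-1")
for text, block_id in result_lines:
    print(block_id, text)
print("child pages found:", child_pages)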
@@ -133,9 +166,15 @@ class NotionConnector(LoadConnector, PollConnector):
         pages: list[NotionPage],
     ) -> Generator[Document, None, None]:
         """Reads pages for rich text content and generates Documents"""
+        all_child_page_ids: list[str] = []
         for page in pages:
+            if page.id in self.indexed_pages:
+                logger.debug(f"Already indexed page with ID '{page.id}'. Skipping.")
+                continue
+
             logger.info(f"Reading page with ID '{page.id}', with url {page.url}")
-            page_blocks = self._read_blocks(page.id)
+            page_blocks, child_page_ids = self._read_blocks(page.id)
+            all_child_page_ids.extend(child_page_ids)
             page_title = self._read_page_title(page)
             yield (
                 Document(
@@ -153,6 +192,17 @@ class NotionConnector(LoadConnector, PollConnector):
                     metadata={},
                 )
             )
+            self.indexed_pages.add(page.id)
+
+        if self.recursive_index_enabled and all_child_page_ids:
+            # NOTE: checking if page_id is in self.indexed_pages to prevent extra
+            # calls to `_fetch_page` for pages we've already indexed
+            all_child_pages = [
+                self._fetch_page(page_id)
+                for page_id in all_child_page_ids
+                if page_id not in self.indexed_pages
+            ]
+            yield from self._read_pages(all_child_pages)
 
     @retry(tries=3, delay=1, backoff=2)
     def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse:
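The recursion and de-duplication in _read_pages can be illustrated with a self-contained sketch (a toy in-memory page graph stands in for the Notion API; names and data here are illustrative only):

from typing import Iterator

# toy page graph standing in for the Notion API; "b" is reachable from two parents
PAGE_CHILDREN = {"root": ["a", "b"], "a": ["b", "c"], "b": [], "c": []}


def read_pages(page_ids: list[str], indexed: set[str]) -> Iterator[str]:
    child_ids: list[str] = []
    for page_id in page_ids:
        if page_id in indexed:
            continue  # same skip as the indexed_pages check above
        yield page_id  # stands in for yielding a Document
        indexed.add(page_id)
        child_ids.extend(PAGE_CHILDREN[page_id])
    # recurse only into children that have not already been indexed
    new_children = [c for c in child_ids if c not in indexed]
    if new_children:
        yield from read_pages(new_children, indexed)


print(list(read_pages(["root"], set())))  # ['root', 'a', 'b', 'c']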
@@ -184,7 +234,7 @@ class NotionConnector(LoadConnector, PollConnector):
             end (float) - end epoch time to filter to
             filter_field (str) - the attribute on the page to apply the filter
         """
-        filtered_pages = []
+        filtered_pages: list[NotionPage] = []
         for page in pages:
             compare_time = time.mktime(
                 time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")