Add recursive Notion search

Weves 2023-09-20 14:21:28 -07:00 committed by Chris Weaver
parent db024ad7b7
commit 4912beb283
2 changed files with 61 additions and 6 deletions

danswer/configs/app_configs.py

@@ -133,6 +133,11 @@ WEB_CONNECTOR_OAUTH_CLIENT_ID = os.environ.get("WEB_CONNECTOR_OAUTH_CLIENT_ID")
WEB_CONNECTOR_OAUTH_CLIENT_SECRET = os.environ.get("WEB_CONNECTOR_OAUTH_CLIENT_SECRET")
WEB_CONNECTOR_OAUTH_TOKEN_URL = os.environ.get("WEB_CONNECTOR_OAUTH_TOKEN_URL")
NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = (
os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
== "true"
)
#####
# Query Configs
#####
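For reference, a minimal sketch of enabling the new flag via the environment, mirroring the parse above: only the literal string "true" (any casing) turns it on, and it must be set before danswer.configs.app_configs is imported, since the module-level value is read once at import time.

import os

os.environ["NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP"] = "True"
enabled = (
    os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
    == "true"
)
print(enabled)  # True; any other value (or unset) leaves recursion off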

danswer/connectors/notion/connector.py

@@ -9,6 +9,7 @@ import requests
from retry import retry
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
@@ -64,13 +65,26 @@ class NotionConnector(LoadConnector, PollConnector):
batch_size (int): Number of objects to index in a batch
"""
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
recursive_index_enabled: bool = NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP,
) -> None:
"""Initialize with parameters."""
self.batch_size = batch_size
self.headers = {
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
}
self.indexed_pages: set[str] = set()
# if enabled, will recursively index child pages as they are found, rather
# than relying entirely on the `search` API. We have received reports that
# the `search` API misses many pages - in those cases, this might need to
# be turned on. It's not currently known why/when this is required.
# NOTE: this also removes all benefits of polling, since we need to traverse
# all pages regardless of whether they have been updated. If the Notion
# workspace is very large, this may not be practical.
self.recursive_index_enabled = recursive_index_enabled
@retry(tries=3, delay=1, backoff=2)
def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, Any]:
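A short usage sketch of the new constructor parameter (hypothetical wiring; credential handling and the load/poll entry points are outside this diff, and the module path is assumed to be danswer.connectors.notion.connector):

from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.connectors.notion.connector import NotionConnector

# force recursive child-page traversal regardless of the env var default
connector = NotionConnector(
    batch_size=INDEX_BATCH_SIZE,
    recursive_index_enabled=True,
)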
@@ -86,14 +100,31 @@ class NotionConnector(LoadConnector, PollConnector):
raise e
return res.json()
def _read_blocks(self, page_block_id: str) -> list[tuple[str, str]]:
@retry(tries=3, delay=1, backoff=2)
def _fetch_page(self, page_id: str) -> NotionPage:
"""Fetch a page from it's ID via the Notion API."""
logger.debug(f"Fetching page for ID '{page_id}'")
block_url = f"https://api.notion.com/v1/pages/{page_id}"
res = requests.get(block_url, headers=self.headers)
try:
res.raise_for_status()
except Exception as e:
logger.exception(f"Error fetching page - {res.json()}")
raise e
return NotionPage(**res.json())
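For context, the @retry decorator here and on _fetch_blocks retries with exponential backoff: with tries=3, delay=1, backoff=2 it sleeps 1s after the first failure and 2s after the second, then lets the exception propagate. A minimal illustration with the same library:

from retry import retry

@retry(tries=3, delay=1, backoff=2)
def flaky_call() -> None:
    # attempt 1 fails -> sleep 1s; attempt 2 fails -> sleep 2s;
    # attempt 3 fails -> the exception reaches the caller
    raise RuntimeError("transient error")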
def _read_blocks(
self, page_block_id: str
) -> tuple[list[tuple[str, str]], list[str]]:
"""Reads blocks for a page"""
result_lines: list[tuple[str, str]] = []
child_pages: list[str] = []
cursor = None
while True:
data = self._fetch_blocks(page_block_id, cursor)
for result in data["results"]:
result_block_id = result["id"]
result_type = result["type"]
result_obj = result[result_type]
@@ -105,7 +136,9 @@ class NotionConnector(LoadConnector, PollConnector):
text = rich_text["text"]["content"]
cur_result_text_arr.append(text)
result_block_id = result["id"]
if result["has_children"] and result_type == "child_page":
child_pages.append(result_block_id)
cur_result_text = "\n".join(cur_result_text_arr)
if cur_result_text:
result_lines.append((cur_result_text, result_block_id))
@@ -115,7 +148,7 @@ class NotionConnector(LoadConnector, PollConnector):
cursor = data["next_cursor"]
return result_lines
return result_lines, child_pages
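For reference, child pages surface in the blocks payload as blocks of type "child_page"; a trimmed sketch of such a block object, per the 2022-06-28 Notion API version pinned in the headers above (the ID and title are hypothetical):

# roughly what `result` looks like when the child_page branch above triggers
child_page_block = {
    "object": "block",
    "id": "59833787-2cf9-4fdf-8782-e53db20768a5",
    "type": "child_page",
    "has_children": True,
    "child_page": {"title": "Some child page"},
}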
def _read_page_title(self, page: NotionPage) -> str:
"""Extracts the title from a Notion page"""
@@ -133,9 +166,15 @@ class NotionConnector(LoadConnector, PollConnector):
pages: list[NotionPage],
) -> Generator[Document, None, None]:
"""Reads pages for rich text content and generates Documents"""
all_child_page_ids: list[str] = []
for page in pages:
if page.id in self.indexed_pages:
logger.debug(f"Already indexed page with ID '{page.id}'. Skipping.")
continue
logger.info(f"Reading page with ID '{page.id}', with url {page.url}")
page_blocks = self._read_blocks(page.id)
page_blocks, child_page_ids = self._read_blocks(page.id)
all_child_page_ids.extend(child_page_ids)
page_title = self._read_page_title(page)
yield (
Document(
@@ -153,6 +192,17 @@ class NotionConnector(LoadConnector, PollConnector):
metadata={},
)
)
self.indexed_pages.add(page.id)
if self.recursive_index_enabled and all_child_page_ids:
# NOTE: checking if page_id is in self.indexed_pages to prevent extra
# calls to `_fetch_page` for pages we've already indexed
all_child_pages = [
self._fetch_page(page_id)
for page_id in all_child_page_ids
if page_id not in self.indexed_pages
]
yield from self._read_pages(all_child_pages)
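The indexed_pages set is what keeps this recursion bounded: a page reachable from several parents, or through a cycle of links, is fetched and yielded at most once. A toy sketch of the same visited-set pattern (not the connector's exact code; the page graph is hypothetical):

children = {"A": ["B", "C"], "B": ["A"], "C": []}  # A <-> B form a cycle

def walk(page_id: str, visited: set[str]) -> list[str]:
    if page_id in visited:  # mirrors the `page_id not in self.indexed_pages` guard
        return []
    visited.add(page_id)
    collected = [page_id]
    for child in children[page_id]:
        collected.extend(walk(child, visited))
    return collected

print(walk("A", set()))  # ['A', 'B', 'C'], each page exactly once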
@retry(tries=3, delay=1, backoff=2)
def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse:
@@ -184,7 +234,7 @@ class NotionConnector(LoadConnector, PollConnector):
end (float) - end epoch time to filter to
filter_field (str) - the attribute on the page to apply the filter
"""
filtered_pages = []
filtered_pages: list[NotionPage] = []
for page in pages:
compare_time = time.mktime(
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
)
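For illustration, that strptime format matches Notion's fixed-millisecond UTC timestamps, with the trailing ".000Z" treated as a literal; a standalone parse of a hypothetical last_edited_time value (note that time.mktime interprets the struct_time in the machine's local timezone):

import time

epoch = time.mktime(
    time.strptime("2023-09-20T14:21:28.000Z", "%Y-%m-%dT%H:%M:%S.000Z")
)
print(epoch)  # epoch seconds, comparable against the start/end bounds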