Fix bugs in Notion connector (#440)

* Fix bad pagination
* Make each block be a section -> we can link to individual blocks 
* Don't have a page include all content from child pages
This commit is contained in:
Chris Weaver 2023-09-13 13:30:41 -07:00 committed by GitHub
parent ffa24e2f09
commit 4e359bc731
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,9 +1,8 @@
import time
from collections.abc import Generator
from dataclasses import dataclass
from dataclasses import fields
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
import requests
@ -17,6 +16,7 @@ from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -30,7 +30,7 @@ class NotionPage:
created_time: str
last_edited_time: str
archived: bool
properties: Dict[str, Any]
properties: dict[str, Any]
url: str
def __init__(self, **kwargs: dict[str, Any]) -> None:
@ -44,7 +44,7 @@ class NotionPage:
class NotionSearchResponse:
"""Represents the response from the Notion Search API"""
results: List[Dict[str, Any]]
results: list[dict[str, Any]]
next_cursor: Optional[str]
has_more: bool = False
@ -73,22 +73,21 @@ class NotionConnector(LoadConnector, PollConnector):
}
@retry(tries=3, delay=1, backoff=2)
def _fetch_block(self, block_id: str) -> dict[str, Any]:
"""Fetch a single block via the Notion API."""
logger.debug(f"Fetching block with ID '{block_id}'")
def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, Any]:
"""Fetch all child blocks via the Notion API."""
logger.debug(f"Fetching children of block with ID '{block_id}'")
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
query_dict: Dict[str, Any] = {}
query_dict: dict[str, Any] = {} if not cursor else {"start_cursor": cursor}
res = requests.get(block_url, headers=self.headers, json=query_dict)
res.raise_for_status()
return res.json()
def _read_blocks(self, block_id: str, num_tabs: int = 0) -> str:
def _read_blocks(self, page_block_id: str) -> list[tuple[str, str]]:
"""Reads blocks for a page"""
done = False
result_lines_arr = []
cur_block_id = block_id
while not done:
data = self._fetch_block(cur_block_id)
result_lines: list[tuple[str, str]] = []
cursor = None
while True:
data = self._fetch_blocks(page_block_id, cursor)
for result in data["results"]:
result_type = result["type"]
@ -100,27 +99,18 @@ class NotionConnector(LoadConnector, PollConnector):
# skip if doesn't have text object
if "text" in rich_text:
text = rich_text["text"]["content"]
prefix = "\t" * num_tabs
cur_result_text_arr.append(prefix + text)
cur_result_text_arr.append(text)
result_block_id = result["id"]
has_children = result["has_children"]
if has_children:
children_text = self._read_blocks(
result_block_id, num_tabs=num_tabs + 1
)
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr)
result_lines_arr.append(cur_result_text)
if cur_result_text:
result_lines.append((cur_result_text, result_block_id))
if data["next_cursor"] is None:
done = True
break
else:
cur_block_id = data["next_cursor"]
result_lines = "\n".join(result_lines_arr)
cursor = data["next_cursor"]
return result_lines
def _read_page_title(self, page: NotionPage) -> str:
@ -134,26 +124,34 @@ class NotionConnector(LoadConnector, PollConnector):
page_title = f"Untitled Page [{page.id}]"
return page_title
def _read_pages(self, pages: List[NotionPage]) -> List[Document]:
def _read_pages(
self,
pages: list[NotionPage],
) -> Generator[Document, None, None]:
"""Reads pages for rich text content and generates Documents"""
docs_batch = []
for page in pages:
logger.info(f"Reading page with ID '{page.id}', with url {page.url}")
page_text = self._read_blocks(page.id)
page_blocks = self._read_blocks(page.id)
page_title = self._read_page_title(page)
docs_batch.append(
yield (
Document(
id=page.id,
sections=[Section(link=page.url, text=page_text)],
sections=[Section(link=page.url, text=f"{page_title}\n")]
+ [
Section(
link=f"{page.url}#{block_id.replace('-', '')}",
text=block_text,
)
for block_text, block_id in page_blocks
],
source=DocumentSource.NOTION,
semantic_identifier=page_title,
metadata={},
)
)
return docs_batch
@retry(tries=3, delay=1, backoff=2)
def _search_notion(self, query_dict: Dict[str, Any]) -> NotionSearchResponse:
def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse:
"""Search for pages from a Notion database. Includes some small number of
retries to handle misc, flakey failures."""
logger.debug(f"Searching for pages in Notion with query_dict: {query_dict}")
@ -167,17 +165,17 @@ class NotionConnector(LoadConnector, PollConnector):
def _filter_pages_by_time(
self,
pages: List[Dict[str, Any]],
pages: list[dict[str, Any]],
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
filter_field: str = "last_edited_time",
) -> List[NotionPage]:
) -> list[NotionPage]:
"""A helper function to filter out pages outside of a time
range. This functionality doesn't yet exist in the Notion Search API,
but when it does, this approach can be deprecated.
Arguments:
pages (List[Dict]) - Pages to filter
pages (list[dict]) - Pages to filter
start (float) - start epoch time to filter from
end (float) - end epoch time to filter to
filter_field (str) - the attribute on the page to apply the filter
@ -202,7 +200,7 @@ class NotionConnector(LoadConnector, PollConnector):
"""Loads all page data from a Notion workspace.
Returns:
List[Document]: List of documents.
list[Document]: list of documents.
"""
query_dict = {
"filter": {"property": "object", "value": "page"},
@ -211,7 +209,7 @@ class NotionConnector(LoadConnector, PollConnector):
while True:
db_res = self._search_notion(query_dict)
pages = [NotionPage(**page) for page in db_res.results]
yield self._read_pages(pages)
yield from batch_generator(self._read_pages(pages), self.batch_size)
if db_res.has_more:
query_dict["start_cursor"] = db_res.next_cursor
else:
@ -237,7 +235,7 @@ class NotionConnector(LoadConnector, PollConnector):
db_res.results, start, end, filter_field="last_edited_time"
)
if len(pages) > 0:
yield self._read_pages(pages)
yield from batch_generator(self._read_pages(pages), self.batch_size)
if db_res.has_more:
query_dict["start_cursor"] = db_res.next_cursor
else:
@ -252,4 +250,6 @@ if __name__ == "__main__":
{"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")}
)
document_batches = connector.load_from_state()
print(next(document_batches))
batch = next(document_batches)
for doc in batch:
print(doc)