Fix bugs in Notion connector (#440)

* Fix bad pagination
* Make each block be a section -> we can link to individual blocks 
* Don't have a page include all content from child pages
This commit is contained in:
Chris Weaver
2023-09-13 13:30:41 -07:00
committed by GitHub
parent ffa24e2f09
commit 4e359bc731

View File

@@ -1,9 +1,8 @@
import time import time
from collections.abc import Generator
from dataclasses import dataclass from dataclasses import dataclass
from dataclasses import fields from dataclasses import fields
from typing import Any from typing import Any
from typing import Dict
from typing import List
from typing import Optional from typing import Optional
import requests import requests
@@ -17,6 +16,7 @@ from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document from danswer.connectors.models import Document
from danswer.connectors.models import Section from danswer.connectors.models import Section
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
@@ -30,7 +30,7 @@ class NotionPage:
created_time: str created_time: str
last_edited_time: str last_edited_time: str
archived: bool archived: bool
properties: Dict[str, Any] properties: dict[str, Any]
url: str url: str
def __init__(self, **kwargs: dict[str, Any]) -> None: def __init__(self, **kwargs: dict[str, Any]) -> None:
@@ -44,7 +44,7 @@ class NotionPage:
class NotionSearchResponse: class NotionSearchResponse:
"""Represents the response from the Notion Search API""" """Represents the response from the Notion Search API"""
results: List[Dict[str, Any]] results: list[dict[str, Any]]
next_cursor: Optional[str] next_cursor: Optional[str]
has_more: bool = False has_more: bool = False
@@ -73,22 +73,21 @@ class NotionConnector(LoadConnector, PollConnector):
} }
@retry(tries=3, delay=1, backoff=2) @retry(tries=3, delay=1, backoff=2)
def _fetch_block(self, block_id: str) -> dict[str, Any]: def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, Any]:
"""Fetch a single block via the Notion API.""" """Fetch all child blocks via the Notion API."""
logger.debug(f"Fetching block with ID '{block_id}'") logger.debug(f"Fetching children of block with ID '{block_id}'")
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children" block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
query_dict: Dict[str, Any] = {} query_dict: dict[str, Any] = {} if not cursor else {"start_cursor": cursor}
res = requests.get(block_url, headers=self.headers, json=query_dict) res = requests.get(block_url, headers=self.headers, json=query_dict)
res.raise_for_status() res.raise_for_status()
return res.json() return res.json()
def _read_blocks(self, block_id: str, num_tabs: int = 0) -> str: def _read_blocks(self, page_block_id: str) -> list[tuple[str, str]]:
"""Reads blocks for a page""" """Reads blocks for a page"""
done = False result_lines: list[tuple[str, str]] = []
result_lines_arr = [] cursor = None
cur_block_id = block_id while True:
while not done: data = self._fetch_blocks(page_block_id, cursor)
data = self._fetch_block(cur_block_id)
for result in data["results"]: for result in data["results"]:
result_type = result["type"] result_type = result["type"]
@@ -100,27 +99,18 @@ class NotionConnector(LoadConnector, PollConnector):
# skip if doesn't have text object # skip if doesn't have text object
if "text" in rich_text: if "text" in rich_text:
text = rich_text["text"]["content"] text = rich_text["text"]["content"]
prefix = "\t" * num_tabs cur_result_text_arr.append(text)
cur_result_text_arr.append(prefix + text)
result_block_id = result["id"] result_block_id = result["id"]
has_children = result["has_children"]
if has_children:
children_text = self._read_blocks(
result_block_id, num_tabs=num_tabs + 1
)
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr) cur_result_text = "\n".join(cur_result_text_arr)
result_lines_arr.append(cur_result_text) if cur_result_text:
result_lines.append((cur_result_text, result_block_id))
if data["next_cursor"] is None: if data["next_cursor"] is None:
done = True
break break
else:
cur_block_id = data["next_cursor"]
result_lines = "\n".join(result_lines_arr) cursor = data["next_cursor"]
return result_lines return result_lines
def _read_page_title(self, page: NotionPage) -> str: def _read_page_title(self, page: NotionPage) -> str:
@@ -134,26 +124,34 @@ class NotionConnector(LoadConnector, PollConnector):
page_title = f"Untitled Page [{page.id}]" page_title = f"Untitled Page [{page.id}]"
return page_title return page_title
def _read_pages(self, pages: List[NotionPage]) -> List[Document]: def _read_pages(
self,
pages: list[NotionPage],
) -> Generator[Document, None, None]:
"""Reads pages for rich text content and generates Documents""" """Reads pages for rich text content and generates Documents"""
docs_batch = []
for page in pages: for page in pages:
logger.info(f"Reading page with ID '{page.id}', with url {page.url}") logger.info(f"Reading page with ID '{page.id}', with url {page.url}")
page_text = self._read_blocks(page.id) page_blocks = self._read_blocks(page.id)
page_title = self._read_page_title(page) page_title = self._read_page_title(page)
docs_batch.append( yield (
Document( Document(
id=page.id, id=page.id,
sections=[Section(link=page.url, text=page_text)], sections=[Section(link=page.url, text=f"{page_title}\n")]
+ [
Section(
link=f"{page.url}#{block_id.replace('-', '')}",
text=block_text,
)
for block_text, block_id in page_blocks
],
source=DocumentSource.NOTION, source=DocumentSource.NOTION,
semantic_identifier=page_title, semantic_identifier=page_title,
metadata={}, metadata={},
) )
) )
return docs_batch
@retry(tries=3, delay=1, backoff=2) @retry(tries=3, delay=1, backoff=2)
def _search_notion(self, query_dict: Dict[str, Any]) -> NotionSearchResponse: def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse:
"""Search for pages from a Notion database. Includes some small number of """Search for pages from a Notion database. Includes some small number of
retries to handle misc, flakey failures.""" retries to handle misc, flakey failures."""
logger.debug(f"Searching for pages in Notion with query_dict: {query_dict}") logger.debug(f"Searching for pages in Notion with query_dict: {query_dict}")
@@ -167,17 +165,17 @@ class NotionConnector(LoadConnector, PollConnector):
def _filter_pages_by_time( def _filter_pages_by_time(
self, self,
pages: List[Dict[str, Any]], pages: list[dict[str, Any]],
start: SecondsSinceUnixEpoch, start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch,
filter_field: str = "last_edited_time", filter_field: str = "last_edited_time",
) -> List[NotionPage]: ) -> list[NotionPage]:
"""A helper function to filter out pages outside of a time """A helper function to filter out pages outside of a time
range. This functionality doesn't yet exist in the Notion Search API, range. This functionality doesn't yet exist in the Notion Search API,
but when it does, this approach can be deprecated. but when it does, this approach can be deprecated.
Arguments: Arguments:
pages (List[Dict]) - Pages to filter pages (list[dict]) - Pages to filter
start (float) - start epoch time to filter from start (float) - start epoch time to filter from
end (float) - end epoch time to filter to end (float) - end epoch time to filter to
filter_field (str) - the attribute on the page to apply the filter filter_field (str) - the attribute on the page to apply the filter
@@ -202,7 +200,7 @@ class NotionConnector(LoadConnector, PollConnector):
"""Loads all page data from a Notion workspace. """Loads all page data from a Notion workspace.
Returns: Returns:
List[Document]: List of documents. list[Document]: list of documents.
""" """
query_dict = { query_dict = {
"filter": {"property": "object", "value": "page"}, "filter": {"property": "object", "value": "page"},
@@ -211,7 +209,7 @@ class NotionConnector(LoadConnector, PollConnector):
while True: while True:
db_res = self._search_notion(query_dict) db_res = self._search_notion(query_dict)
pages = [NotionPage(**page) for page in db_res.results] pages = [NotionPage(**page) for page in db_res.results]
yield self._read_pages(pages) yield from batch_generator(self._read_pages(pages), self.batch_size)
if db_res.has_more: if db_res.has_more:
query_dict["start_cursor"] = db_res.next_cursor query_dict["start_cursor"] = db_res.next_cursor
else: else:
@@ -237,7 +235,7 @@ class NotionConnector(LoadConnector, PollConnector):
db_res.results, start, end, filter_field="last_edited_time" db_res.results, start, end, filter_field="last_edited_time"
) )
if len(pages) > 0: if len(pages) > 0:
yield self._read_pages(pages) yield from batch_generator(self._read_pages(pages), self.batch_size)
if db_res.has_more: if db_res.has_more:
query_dict["start_cursor"] = db_res.next_cursor query_dict["start_cursor"] = db_res.next_cursor
else: else:
@@ -252,4 +250,6 @@ if __name__ == "__main__":
{"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")} {"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")}
) )
document_batches = connector.load_from_state() document_batches = connector.load_from_state()
print(next(document_batches)) batch = next(document_batches)
for doc in batch:
print(doc)