mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-22 17:16:20 +02:00
Fix bugs in Notion connector (#440)
* Fix bad pagination
* Make each block its own section, so that we can link to individual blocks
* Don't include the content of child pages in the parent page
This commit is contained in:
@@ -1,9 +1,8 @@
|
|||||||
import time
|
import time
|
||||||
|
from collections.abc import Generator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from dataclasses import fields
|
from dataclasses import fields
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Dict
|
|
||||||
from typing import List
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@@ -17,6 +16,7 @@ from danswer.connectors.interfaces import PollConnector
|
|||||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||||
from danswer.connectors.models import Document
|
from danswer.connectors.models import Document
|
||||||
from danswer.connectors.models import Section
|
from danswer.connectors.models import Section
|
||||||
|
from danswer.utils.batching import batch_generator
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@@ -30,7 +30,7 @@ class NotionPage:
|
|||||||
created_time: str
|
created_time: str
|
||||||
last_edited_time: str
|
last_edited_time: str
|
||||||
archived: bool
|
archived: bool
|
||||||
properties: Dict[str, Any]
|
properties: dict[str, Any]
|
||||||
url: str
|
url: str
|
||||||
|
|
||||||
def __init__(self, **kwargs: dict[str, Any]) -> None:
|
def __init__(self, **kwargs: dict[str, Any]) -> None:
|
||||||
@@ -44,7 +44,7 @@ class NotionPage:
|
|||||||
class NotionSearchResponse:
|
class NotionSearchResponse:
|
||||||
"""Represents the response from the Notion Search API"""
|
"""Represents the response from the Notion Search API"""
|
||||||
|
|
||||||
results: List[Dict[str, Any]]
|
results: list[dict[str, Any]]
|
||||||
next_cursor: Optional[str]
|
next_cursor: Optional[str]
|
||||||
has_more: bool = False
|
has_more: bool = False
|
||||||
|
|
||||||
@@ -73,22 +73,21 @@ class NotionConnector(LoadConnector, PollConnector):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@retry(tries=3, delay=1, backoff=2)
|
@retry(tries=3, delay=1, backoff=2)
|
||||||
def _fetch_block(self, block_id: str) -> dict[str, Any]:
|
def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, Any]:
|
||||||
"""Fetch a single block via the Notion API."""
|
"""Fetch all child blocks via the Notion API."""
|
||||||
logger.debug(f"Fetching block with ID '{block_id}'")
|
logger.debug(f"Fetching children of block with ID '{block_id}'")
|
||||||
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||||
query_dict: Dict[str, Any] = {}
|
query_dict: dict[str, Any] = {} if not cursor else {"start_cursor": cursor}
|
||||||
res = requests.get(block_url, headers=self.headers, json=query_dict)
|
res = requests.get(block_url, headers=self.headers, json=query_dict)
|
||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
return res.json()
|
return res.json()
|
||||||
|
|
||||||
def _read_blocks(self, block_id: str, num_tabs: int = 0) -> str:
|
def _read_blocks(self, page_block_id: str) -> list[tuple[str, str]]:
|
||||||
"""Reads blocks for a page"""
|
"""Reads blocks for a page"""
|
||||||
done = False
|
result_lines: list[tuple[str, str]] = []
|
||||||
result_lines_arr = []
|
cursor = None
|
||||||
cur_block_id = block_id
|
while True:
|
||||||
while not done:
|
data = self._fetch_blocks(page_block_id, cursor)
|
||||||
data = self._fetch_block(cur_block_id)
|
|
||||||
|
|
||||||
for result in data["results"]:
|
for result in data["results"]:
|
||||||
result_type = result["type"]
|
result_type = result["type"]
|
||||||
@@ -100,27 +99,18 @@ class NotionConnector(LoadConnector, PollConnector):
|
|||||||
# skip if doesn't have text object
|
# skip if doesn't have text object
|
||||||
if "text" in rich_text:
|
if "text" in rich_text:
|
||||||
text = rich_text["text"]["content"]
|
text = rich_text["text"]["content"]
|
||||||
prefix = "\t" * num_tabs
|
cur_result_text_arr.append(text)
|
||||||
cur_result_text_arr.append(prefix + text)
|
|
||||||
|
|
||||||
result_block_id = result["id"]
|
result_block_id = result["id"]
|
||||||
has_children = result["has_children"]
|
|
||||||
if has_children:
|
|
||||||
children_text = self._read_blocks(
|
|
||||||
result_block_id, num_tabs=num_tabs + 1
|
|
||||||
)
|
|
||||||
cur_result_text_arr.append(children_text)
|
|
||||||
|
|
||||||
cur_result_text = "\n".join(cur_result_text_arr)
|
cur_result_text = "\n".join(cur_result_text_arr)
|
||||||
result_lines_arr.append(cur_result_text)
|
if cur_result_text:
|
||||||
|
result_lines.append((cur_result_text, result_block_id))
|
||||||
|
|
||||||
if data["next_cursor"] is None:
|
if data["next_cursor"] is None:
|
||||||
done = True
|
|
||||||
break
|
break
|
||||||
else:
|
|
||||||
cur_block_id = data["next_cursor"]
|
|
||||||
|
|
||||||
result_lines = "\n".join(result_lines_arr)
|
cursor = data["next_cursor"]
|
||||||
|
|
||||||
return result_lines
|
return result_lines
|
||||||
|
|
||||||
def _read_page_title(self, page: NotionPage) -> str:
|
def _read_page_title(self, page: NotionPage) -> str:
|
||||||
@@ -134,26 +124,34 @@ class NotionConnector(LoadConnector, PollConnector):
|
|||||||
page_title = f"Untitled Page [{page.id}]"
|
page_title = f"Untitled Page [{page.id}]"
|
||||||
return page_title
|
return page_title
|
||||||
|
|
||||||
def _read_pages(self, pages: List[NotionPage]) -> List[Document]:
|
def _read_pages(
|
||||||
|
self,
|
||||||
|
pages: list[NotionPage],
|
||||||
|
) -> Generator[Document, None, None]:
|
||||||
"""Reads pages for rich text content and generates Documents"""
|
"""Reads pages for rich text content and generates Documents"""
|
||||||
docs_batch = []
|
|
||||||
for page in pages:
|
for page in pages:
|
||||||
logger.info(f"Reading page with ID '{page.id}', with url {page.url}")
|
logger.info(f"Reading page with ID '{page.id}', with url {page.url}")
|
||||||
page_text = self._read_blocks(page.id)
|
page_blocks = self._read_blocks(page.id)
|
||||||
page_title = self._read_page_title(page)
|
page_title = self._read_page_title(page)
|
||||||
docs_batch.append(
|
yield (
|
||||||
Document(
|
Document(
|
||||||
id=page.id,
|
id=page.id,
|
||||||
sections=[Section(link=page.url, text=page_text)],
|
sections=[Section(link=page.url, text=f"{page_title}\n")]
|
||||||
|
+ [
|
||||||
|
Section(
|
||||||
|
link=f"{page.url}#{block_id.replace('-', '')}",
|
||||||
|
text=block_text,
|
||||||
|
)
|
||||||
|
for block_text, block_id in page_blocks
|
||||||
|
],
|
||||||
source=DocumentSource.NOTION,
|
source=DocumentSource.NOTION,
|
||||||
semantic_identifier=page_title,
|
semantic_identifier=page_title,
|
||||||
metadata={},
|
metadata={},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return docs_batch
|
|
||||||
|
|
||||||
@retry(tries=3, delay=1, backoff=2)
|
@retry(tries=3, delay=1, backoff=2)
|
||||||
def _search_notion(self, query_dict: Dict[str, Any]) -> NotionSearchResponse:
|
def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse:
|
||||||
"""Search for pages from a Notion database. Includes some small number of
|
"""Search for pages from a Notion database. Includes some small number of
|
||||||
retries to handle misc, flakey failures."""
|
retries to handle misc, flakey failures."""
|
||||||
logger.debug(f"Searching for pages in Notion with query_dict: {query_dict}")
|
logger.debug(f"Searching for pages in Notion with query_dict: {query_dict}")
|
||||||
@@ -167,17 +165,17 @@ class NotionConnector(LoadConnector, PollConnector):
|
|||||||
|
|
||||||
def _filter_pages_by_time(
|
def _filter_pages_by_time(
|
||||||
self,
|
self,
|
||||||
pages: List[Dict[str, Any]],
|
pages: list[dict[str, Any]],
|
||||||
start: SecondsSinceUnixEpoch,
|
start: SecondsSinceUnixEpoch,
|
||||||
end: SecondsSinceUnixEpoch,
|
end: SecondsSinceUnixEpoch,
|
||||||
filter_field: str = "last_edited_time",
|
filter_field: str = "last_edited_time",
|
||||||
) -> List[NotionPage]:
|
) -> list[NotionPage]:
|
||||||
"""A helper function to filter out pages outside of a time
|
"""A helper function to filter out pages outside of a time
|
||||||
range. This functionality doesn't yet exist in the Notion Search API,
|
range. This functionality doesn't yet exist in the Notion Search API,
|
||||||
but when it does, this approach can be deprecated.
|
but when it does, this approach can be deprecated.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
pages (List[Dict]) - Pages to filter
|
pages (list[dict]) - Pages to filter
|
||||||
start (float) - start epoch time to filter from
|
start (float) - start epoch time to filter from
|
||||||
end (float) - end epoch time to filter to
|
end (float) - end epoch time to filter to
|
||||||
filter_field (str) - the attribute on the page to apply the filter
|
filter_field (str) - the attribute on the page to apply the filter
|
||||||
@@ -202,7 +200,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
|||||||
"""Loads all page data from a Notion workspace.
|
"""Loads all page data from a Notion workspace.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Document]: List of documents.
|
list[Document]: list of documents.
|
||||||
"""
|
"""
|
||||||
query_dict = {
|
query_dict = {
|
||||||
"filter": {"property": "object", "value": "page"},
|
"filter": {"property": "object", "value": "page"},
|
||||||
@@ -211,7 +209,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
|||||||
while True:
|
while True:
|
||||||
db_res = self._search_notion(query_dict)
|
db_res = self._search_notion(query_dict)
|
||||||
pages = [NotionPage(**page) for page in db_res.results]
|
pages = [NotionPage(**page) for page in db_res.results]
|
||||||
yield self._read_pages(pages)
|
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||||
if db_res.has_more:
|
if db_res.has_more:
|
||||||
query_dict["start_cursor"] = db_res.next_cursor
|
query_dict["start_cursor"] = db_res.next_cursor
|
||||||
else:
|
else:
|
||||||
@@ -237,7 +235,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
|||||||
db_res.results, start, end, filter_field="last_edited_time"
|
db_res.results, start, end, filter_field="last_edited_time"
|
||||||
)
|
)
|
||||||
if len(pages) > 0:
|
if len(pages) > 0:
|
||||||
yield self._read_pages(pages)
|
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||||
if db_res.has_more:
|
if db_res.has_more:
|
||||||
query_dict["start_cursor"] = db_res.next_cursor
|
query_dict["start_cursor"] = db_res.next_cursor
|
||||||
else:
|
else:
|
||||||
@@ -252,4 +250,6 @@ if __name__ == "__main__":
|
|||||||
{"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")}
|
{"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")}
|
||||||
)
|
)
|
||||||
document_batches = connector.load_from_state()
|
document_batches = connector.load_from_state()
|
||||||
print(next(document_batches))
|
batch = next(document_batches)
|
||||||
|
for doc in batch:
|
||||||
|
print(doc)
|
||||||
|
Reference in New Issue
Block a user