# Reconstructed from the flattened patch "Enable database reading in recursive
# notion crawl" (commit 68b23b6339c3ee322986c00c841b4be02f52024d, 1 file
# changed, 49 insertions), which adds the two methods below to
# `NotionConnector` in backend/danswer/connectors/notion/connector.py.
# They are written at column 0 here; indent them one level when placing them
# inside the class body.
#
# The same patch also touches `_read_blocks`. That method is only partially
# visible in the hunk context, so it is described rather than reproduced:
#   * a debug log "Entering sub-block: {result_block_id}" is added before, and
#     "Finished sub-block: {result_block_id}" after, the recursive
#     `self._read_blocks(result_block_id)` call;
#   * a new check is added: when `result_type == "child_database"` and
#     `self.recursive_index_enabled` is truthy, `child_pages` is extended with
#     `self._read_pages_from_database(result_block_id)`.


@retry(tries=3, delay=1, backoff=2)
def _fetch_database(
    self, database_id: str, cursor: str | None = None
) -> dict[str, Any]:
    """Query one page of results from a Notion database by its ID.

    Sends a POST request to the Notion ``/v1/databases/{database_id}/query``
    endpoint using ``self.headers``.

    Args:
        database_id: ID of the database to query.
        cursor: Pagination cursor sent as ``start_cursor``. When falsy, the
            request is sent without a JSON body.

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: Whatever ``res.raise_for_status()`` raises for an
            unsuccessful response; it is logged and re-raised unchanged.
    """
    logger.debug(f"Fetching database for ID '{database_id}'")
    query_url = f"https://api.notion.com/v1/databases/{database_id}/query"
    body = {"start_cursor": cursor} if cursor else None
    # NOTE(review): no `timeout=` is passed, so a stalled connection can block
    # indefinitely — consider adding one consistently across the connector.
    res = requests.post(query_url, headers=self.headers, json=body)
    try:
        res.raise_for_status()
    except Exception:
        # An error response is not guaranteed to carry a JSON body (e.g. a
        # proxy/gateway error page). Fall back to the raw text so a decode
        # failure here cannot replace the HTTP error we actually want raised.
        try:
            error_detail = res.json()
        except ValueError:
            error_detail = res.text
        logger.exception(f"Error fetching database - {error_detail}")
        raise
    return res.json()


def _read_pages_from_database(self, database_id: str) -> list[str]:
    """Return the IDs of all pages in the database, following pagination.

    Every result page of the database query is fetched via
    ``self._fetch_database``. Results whose ``object`` is ``"page"`` contribute
    their ID directly; results whose ``object`` is ``"database"`` are read
    recursively and their page IDs are appended in encounter order. Results
    of any other object type are ignored.

    Args:
        database_id: ID of the database to enumerate.

    Returns:
        A list of page IDs found in the database and any nested databases.
    """
    result_pages: list[str] = []
    cursor: str | None = None
    while True:
        data = self._fetch_database(database_id, cursor)

        for result in data["results"]:
            obj_id = result["id"]
            obj_type = result["object"]
            if obj_type == "page":
                logger.debug(
                    f"Found page with ID '{obj_id}' in database '{database_id}'"
                )
                result_pages.append(obj_id)
            elif obj_type == "database":
                logger.debug(
                    f"Found database with ID '{obj_id}' in database '{database_id}'"
                )
                result_pages.extend(self._read_pages_from_database(obj_id))

        # Stop on a missing, null, or empty cursor. An empty cursor must end
        # the loop: `_fetch_database` treats it as "no cursor" and would
        # return the first page again, looping forever.
        cursor = data.get("next_cursor")
        if not cursor:
            break

    return result_pages