mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 20:08:36 +02:00
Address rate limiting for Notion
This commit is contained in:
parent
5b93e786ad
commit
3fb68af405
@ -5,6 +5,8 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -84,3 +86,45 @@ class _RateLimitDecorator:
|
||||
|
||||
|
||||
rate_limit_builder = _RateLimitDecorator
|
||||
|
||||
|
||||
"""If you want to allow the external service to tell you when you've hit the rate limit,
|
||||
use the following instead"""
|
||||
|
||||
# Type variable for any callable that returns a `requests.Response`; it lets
# the wrapper below be typed as returning the same kind of callable it is given.
R = TypeVar("R", bound=Callable[..., requests.Response])
|
||||
|
||||
|
||||
def wrap_request_to_handle_ratelimiting(
    request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30
) -> R:
    """Wrap a request function so HTTP 429 responses are retried after a wait.

    The wrapped function calls `request_fn` with the caller's arguments. When
    the response status is 429 it sleeps for the number of seconds given by the
    `Retry-After` response header and tries again; any other response is
    returned to the caller unchanged.

    Args:
        request_fn: callable returning a `requests.Response`
            (e.g. `requests.get`).
        default_wait_time_sec: seconds to wait when `Retry-After` is missing
            or is not an integer (e.g. the HTTP-date form of the header).
        max_waits: maximum number of times `request_fn` is called before
            giving up.

    Returns:
        A callable with the same call signature as `request_fn`.

    Raises:
        RateLimitTriedTooManyTimesError: raised by the wrapped callable when
            every one of the `max_waits` attempts was answered with a 429.
    """
    # Local import so this function does not rely on a module-level import
    # that may not be present in the file.
    import functools

    @functools.wraps(request_fn)
    def wrapped_request(*args: Any, **kwargs: Any) -> requests.Response:
        for attempt in range(max_waits):
            response = request_fn(*args, **kwargs)
            if response.status_code != 429:
                return response

            if attempt == max_waits - 1:
                # No attempts left: sleeping here would only delay the error.
                break

            try:
                wait_time = int(
                    response.headers.get("Retry-After", default_wait_time_sec)
                )
            except ValueError:
                # Header present but not an integer number of seconds.
                wait_time = default_wait_time_sec

            # time.sleep raises ValueError on negative input; a bogus negative
            # header value should mean "retry now", not crash the caller.
            time.sleep(max(wait_time, 0))

        raise RateLimitTriedTooManyTimesError(f"Exceeded '{max_waits}' retries")

    return cast(R, wrapped_request)
|
||||
|
||||
|
||||
# `requests.get` wrapped so that 429 responses are waited out and retried.
_rate_limited_get = wrap_request_to_handle_ratelimiting(requests.get)
|
||||
# `requests.post` wrapped so that 429 responses are waited out and retried.
_rate_limited_post = wrap_request_to_handle_ratelimiting(requests.post)
|
||||
|
||||
|
||||
class _RateLimitedRequest:
    """Namespace exposing rate-limit-aware `get` and `post` request functions.

    Intended to be used through the class itself (e.g. `_RateLimitedRequest.get(url)`)
    as a stand-in for the corresponding `requests` functions.
    """

    # staticmethod prevents Python from binding these as instance methods, so
    # an instance (should one ever be created) would not inject `self` as the
    # first positional argument, i.e. as the URL. Class-level access is
    # unaffected.
    get = staticmethod(_rate_limited_get)
    post = staticmethod(_rate_limited_post)
|
||||
|
||||
|
||||
# Alias for the class itself (not an instance); callers use it in place of
# `requests`, e.g. `rl_requests.get(...)` / `rl_requests.post(...)`.
rl_requests = _RateLimitedRequest
|
||||
|
@ -7,12 +7,14 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.app_configs import NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rl_requests,
|
||||
)
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@ -100,7 +102,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
logger.debug(f"Fetching children of block with ID '{block_id}'")
|
||||
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
query_params = None if not cursor else {"start_cursor": cursor}
|
||||
res = requests.get(
|
||||
res = rl_requests.get(
|
||||
block_url,
|
||||
headers=self.headers,
|
||||
params=query_params,
|
||||
@ -127,7 +129,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
"""Fetch a page from it's ID via the Notion API."""
|
||||
logger.debug(f"Fetching page for ID '{page_id}'")
|
||||
block_url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
res = requests.get(
|
||||
res = rl_requests.get(
|
||||
block_url,
|
||||
headers=self.headers,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
@ -147,7 +149,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
logger.debug(f"Fetching database for ID '{database_id}'")
|
||||
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
|
||||
body = None if not cursor else {"start_cursor": cursor}
|
||||
res = requests.post(
|
||||
res = rl_requests.post(
|
||||
block_url,
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
@ -327,7 +329,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
"""Search for pages from a Notion database. Includes some small number of
|
||||
retries to handle misc, flakey failures."""
|
||||
logger.debug(f"Searching for pages in Notion with query_dict: {query_dict}")
|
||||
res = requests.post(
|
||||
res = rl_requests.post(
|
||||
"https://api.notion.com/v1/search",
|
||||
headers=self.headers,
|
||||
json=query_dict,
|
||||
@ -435,6 +437,8 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||
if db_res.has_more:
|
||||
query_dict["start_cursor"] = db_res.next_cursor
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user