Add cursor to cql confluence (#2775)

* add cursor to cql confluence

* k

* k

* fixed space indexing issue

* fixed .get

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
pablodanswer 2024-10-13 19:09:17 -07:00 committed by GitHub
parent ded42e2036
commit a9bcc89a2c
2 changed files with 183 additions and 296 deletions
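The change replaces start/limit offset paging of the CQL content search with Confluence's cursor-based pagination: each search response can carry a `_links.next` URL, and the `cursor` query parameter of that URL is what the next request must send. Below is a minimal, self-contained sketch of that loop, assuming a plain HTTP client; `fetch_all_cql_pages`, `wiki_base`, and `auth` are illustrative names only, and the connector itself issues the request through its rate-limited `danswer_cql` helper instead.

```python
from typing import Any
from urllib.parse import parse_qs, urlparse

import requests  # illustrative HTTP client; the connector uses the atlassian SDK's get()


def fetch_all_cql_pages(
    wiki_base: str, cql: str, auth: tuple[str, str], limit: int = 500
) -> list[dict[str, Any]]:
    """Walk Confluence's cursor-based pagination for a CQL content search.

    Each response may carry a `_links.next` URL whose `cursor` query parameter
    points at the following page; iteration stops when no next link is returned.
    """
    pages: list[dict[str, Any]] = []
    cursor: str | None = None
    while True:
        params: dict[str, Any] = {"cql": cql, "limit": limit}
        if cursor:
            params["cursor"] = cursor
        response = requests.get(
            f"{wiki_base}/rest/api/content/search", params=params, auth=auth
        )
        response.raise_for_status()
        body = response.json()
        pages.extend(body.get("results", []))

        next_link = body.get("_links", {}).get("next")
        if not next_link:
            break
        # The next cursor is embedded in the `next` link's query string.
        cursor_values = parse_qs(urlparse(next_link).query).get("cursor", [])
        if not cursor_values:
            break
        cursor = cursor_values[0]
    return pages
```

Cursor paging keeps the crawl stable when pages are created or deleted between requests, which is the usual motivation for preferring it over offset paging.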

View File

@@ -1,13 +1,13 @@
 import io
 import os
-import re
 from collections.abc import Callable
 from collections.abc import Collection
 from datetime import datetime
 from datetime import timezone
 from functools import lru_cache
 from typing import Any
-from typing import cast
+from urllib.parse import parse_qs
+from urllib.parse import urlparse
 
 import bs4
 from atlassian import Confluence  # type:ignore
@@ -33,7 +33,6 @@ from danswer.connectors.confluence.rate_limit_handler import (
 from danswer.connectors.interfaces import GenerateDocumentsOutput
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.interfaces import PollConnector
-from danswer.connectors.interfaces import SecondsSinceUnixEpoch
 from danswer.connectors.models import BasicExpertInfo
 from danswer.connectors.models import ConnectorMissingCredentialError
 from danswer.connectors.models import Document
@@ -70,86 +69,25 @@ class DanswerConfluence(Confluence):
         self,
         cql: str,
         expand: str | None = None,
-        start: int = 0,
+        cursor: str | None = None,
         limit: int = 500,
         include_archived_spaces: bool = False,
-    ) -> list[dict[str, Any]]:
-        # Performs the query expansion and start/limit url additions
+    ) -> dict[str, Any]:
         url_suffix = f"rest/api/content/search?cql={cql}"
         if expand:
             url_suffix += f"&expand={expand}"
-        url_suffix += f"&start={start}&limit={limit}"
+        if cursor:
+            url_suffix += f"&cursor={cursor}"
+        url_suffix += f"&limit={limit}"
         if include_archived_spaces:
             url_suffix += "&includeArchivedSpaces=true"
         try:
             response = self.get(url_suffix)
-            return response.get("results", [])
+            return response
         except Exception as e:
             raise e
 
 
-def _replace_cql_time_filter(
-    cql_query: str, start_time: datetime, end_time: datetime
-) -> str:
-    """
-    This function replaces the lastmodified filter in the CQL query with the start and end times.
-    This selects the more restrictive time range.
-    """
-    # Extract existing lastmodified >= and <= filters
-    existing_start_match = re.search(
-        r'lastmodified\s*>=\s*["\']?(\d{4}-\d{2}-\d{2}(?:\s+\d{2}:\d{2})?)["\']?',
-        cql_query,
-        flags=re.IGNORECASE,
-    )
-    existing_end_match = re.search(
-        r'lastmodified\s*<=\s*["\']?(\d{4}-\d{2}-\d{2}(?:\s+\d{2}:\d{2})?)["\']?',
-        cql_query,
-        flags=re.IGNORECASE,
-    )
-
-    # Remove all existing lastmodified and updated filters
-    cql_query = re.sub(
-        r'\s*AND\s+(lastmodified|updated)\s*[<>=]+\s*["\']?[\d-]+(?:\s+[\d:]+)?["\']?',
-        "",
-        cql_query,
-        flags=re.IGNORECASE,
-    )
-
-    # Determine the start time to use
-    if existing_start_match:
-        existing_start_str = existing_start_match.group(1)
-        existing_start = datetime.strptime(
-            existing_start_str,
-            "%Y-%m-%d %H:%M" if " " in existing_start_str else "%Y-%m-%d",
-        )
-        existing_start = existing_start.replace(
-            tzinfo=timezone.utc
-        )  # Make offset-aware
-        start_time_to_use = max(start_time.astimezone(timezone.utc), existing_start)
-    else:
-        start_time_to_use = start_time.astimezone(timezone.utc)
-
-    # Determine the end time to use
-    if existing_end_match:
-        existing_end_str = existing_end_match.group(1)
-        existing_end = datetime.strptime(
-            existing_end_str,
-            "%Y-%m-%d %H:%M" if " " in existing_end_str else "%Y-%m-%d",
-        )
-        existing_end = existing_end.replace(tzinfo=timezone.utc)  # Make offset-aware
-        end_time_to_use = min(end_time.astimezone(timezone.utc), existing_end)
-    else:
-        end_time_to_use = end_time.astimezone(timezone.utc)
-
-    # Add new time filters
-    cql_query += (
-        f" and lastmodified >= '{start_time_to_use.strftime('%Y-%m-%d %H:%M')}'"
-    )
-    cql_query += f" and lastmodified <= '{end_time_to_use.strftime('%Y-%m-%d %H:%M')}'"
-
-    return cql_query.strip()
-
-
 @lru_cache()
 def _get_user(user_id: str, confluence_client: DanswerConfluence) -> str:
     """Get Confluence Display Name based on the account-id or userkey value
@@ -253,126 +191,86 @@ class RecursiveIndexer:
     def __init__(
         self,
         batch_size: int,
-        confluence_client: DanswerConfluence,
+        confluence_client: Confluence,
        index_recursively: bool,
         origin_page_id: str,
     ) -> None:
-        self.batch_size = 1
-        # batch_size
+        self.batch_size = batch_size
         self.confluence_client = confluence_client
         self.index_recursively = index_recursively
         self.origin_page_id = origin_page_id
-        self.pages = self.recurse_children_pages(0, self.origin_page_id)
+        self.pages = self.recurse_children_pages(self.origin_page_id)
 
     def get_origin_page(self) -> list[dict[str, Any]]:
         return [self._fetch_origin_page()]
 
-    def get_pages(self, ind: int, size: int) -> list[dict]:
-        if ind * size > len(self.pages):
-            return []
-        return self.pages[ind * size : (ind + 1) * size]
+    def get_pages(self) -> list[dict[str, Any]]:
+        return self.pages
 
-    def _fetch_origin_page(
-        self,
-    ) -> dict[str, Any]:
+    def _fetch_origin_page(self) -> dict[str, Any]:
         get_page_by_id = make_confluence_call_handle_rate_limit(
             self.confluence_client.get_page_by_id
         )
         try:
             origin_page = get_page_by_id(
-                self.origin_page_id, expand="body.storage.value,version"
+                self.origin_page_id, expand="body.storage.value,version,space"
             )
             return origin_page
         except Exception as e:
             logger.warning(
-                f"Appending orgin page with id {self.origin_page_id} failed: {e}"
+                f"Appending origin page with id {self.origin_page_id} failed: {e}"
             )
             return {}
 
     def recurse_children_pages(
         self,
-        start_ind: int,
         page_id: str,
     ) -> list[dict[str, Any]]:
         pages: list[dict[str, Any]] = []
-        current_level_pages: list[dict[str, Any]] = []
-        next_level_pages: list[dict[str, Any]] = []
-
-        # Initial fetch of first level children
-        index = start_ind
-        while batch := self._fetch_single_depth_child_pages(
-            index, self.batch_size, page_id
-        ):
-            current_level_pages.extend(batch)
-            index += len(batch)
-
-        pages.extend(current_level_pages)
-
-        # Recursively index children and children's children, etc.
-        while current_level_pages:
-            for child in current_level_pages:
-                child_index = 0
-                while child_batch := self._fetch_single_depth_child_pages(
-                    child_index, self.batch_size, child["id"]
-                ):
-                    next_level_pages.extend(child_batch)
-                    child_index += len(child_batch)
-
-            pages.extend(next_level_pages)
-            current_level_pages = next_level_pages
-            next_level_pages = []
-
-        try:
-            origin_page = self._fetch_origin_page()
-            pages.append(origin_page)
-        except Exception as e:
-            logger.warning(f"Appending origin page with id {page_id} failed: {e}")
-
-        return pages
-
-    def _fetch_single_depth_child_pages(
-        self, start_ind: int, batch_size: int, page_id: str
-    ) -> list[dict[str, Any]]:
-        child_pages: list[dict[str, Any]] = []
+        queue: list[str] = [page_id]
+        visited_pages: set[str] = set()
 
         get_page_child_by_type = make_confluence_call_handle_rate_limit(
             self.confluence_client.get_page_child_by_type
         )
 
-        try:
-            child_page = get_page_child_by_type(
-                page_id,
-                type="page",
-                start=start_ind,
-                limit=batch_size,
-                expand="body.storage.value,version",
-            )
-            child_pages.extend(child_page)
-            return child_pages
-        except Exception:
-            logger.warning(
-                f"Batch failed with page {page_id} at offset {start_ind} "
-                f"with size {batch_size}, processing pages individually..."
-            )
-            for i in range(batch_size):
-                ind = start_ind + i
-                try:
-                    child_page = get_page_child_by_type(
-                        page_id,
-                        type="page",
-                        start=ind,
-                        limit=1,
-                        expand="body.storage.value,version",
-                    )
-                    child_pages.extend(child_page)
-                except Exception as e:
-                    logger.warning(f"Page {page_id} at offset {ind} failed: {e}")
-                    raise e
-
-        return child_pages
+        while queue:
+            current_page_id = queue.pop(0)
+            if current_page_id in visited_pages:
+                continue
+            visited_pages.add(current_page_id)
+
+            try:
+                # Fetch the page itself
+                page = self.confluence_client.get_page_by_id(
+                    current_page_id, expand="body.storage.value,version,space"
+                )
+                pages.append(page)
+            except Exception as e:
+                logger.warning(f"Failed to fetch page {current_page_id}: {e}")
+                continue
+
+            if not self.index_recursively:
+                continue
+
+            # Fetch child pages
+            start = 0
+            while True:
+                child_pages_response = get_page_child_by_type(
+                    current_page_id,
+                    type="page",
+                    start=start,
+                    limit=self.batch_size,
+                    expand="",
+                )
+                if not child_pages_response:
+                    break
+                for child_page in child_pages_response:
+                    child_page_id = child_page["id"]
+                    queue.append(child_page_id)
+                start += len(child_pages_response)
+
+        return pages
 
 
 class ConfluenceConnector(LoadConnector, PollConnector):
@@ -399,7 +297,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
         # Remove trailing slash from wiki_base if present
         self.wiki_base = wiki_base.rstrip("/")
-        self.space = space
         self.page_id = "" if cql_query else page_id
         self.space_level_scan = bool(not self.page_id)
@@ -409,16 +306,16 @@ class ConfluenceConnector(LoadConnector, PollConnector):
         # if a cql_query is provided, we will use it to fetch the pages
         # if no cql_query is provided, we will use the space to fetch the pages
-        # if no space is provided, we will default to fetching all pages, regardless of space
+        # if no space is provided and no cql_query, we will default to fetching all pages, regardless of space
         if cql_query:
             self.cql_query = cql_query
-        elif self.space:
-            self.cql_query = f"type=page and space={self.space}"
+        elif space:
+            self.cql_query = f"type=page and space='{space}'"
         else:
             self.cql_query = "type=page"
 
         logger.info(
-            f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
+            f"wiki_base: {self.wiki_base}, space: {space}, page_id: {self.page_id},"
             + f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively},"
             + f" cql_query: {self.cql_query}"
         )
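To make the fallback described in the comments above concrete, the constructor ends up with one of three CQL strings; the space key and label below are hypothetical placeholders, not values from this change:

```python
# cql_query supplied by the user: used as-is (hypothetical example)
cql_query = "type=page and label='runbook'"

# no cql_query, but a space key (placeholder "ENG"): scoped to that space
cql_query = "type=page and space='ENG'"

# neither supplied: every page the credentials can see
cql_query = "type=page"
```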
@@ -428,7 +325,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
         access_token = credentials["confluence_access_token"]
         self.confluence_client = DanswerConfluence(
             url=self.wiki_base,
-            # passing in username causes issues for Confluence data center
             username=username if self.is_cloud else None,
             password=access_token if self.is_cloud else None,
             token=access_token if not self.is_cloud else None,
@@ -437,12 +333,16 @@ class ConfluenceConnector(LoadConnector, PollConnector):
     def _fetch_pages(
         self,
-        start_ind: int,
-    ) -> list[dict[str, Any]]:
-        def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
-            if self.confluence_client is None:
-                raise ConnectorMissingCredentialError("Confluence")
+        cursor: str | None,
+    ) -> tuple[list[dict[str, Any]], str | None]:
+        if self.confluence_client is None:
+            raise Exception("Confluence client is not initialized")
+
+        def _fetch_space(
+            cursor: str | None, batch_size: int
+        ) -> tuple[list[dict[str, Any]], str | None]:
+            if not self.confluence_client:
+                raise Exception("Confluence client is not initialized")
             get_all_pages = make_confluence_call_handle_rate_limit(
                 self.confluence_client.danswer_cql
             )
@@ -454,53 +354,84 @@
             )
 
             try:
-                return get_all_pages(
+                response = get_all_pages(
                     cql=self.cql_query,
-                    start=start_ind,
+                    cursor=cursor,
                     limit=batch_size,
-                    expand="body.storage.value,version",
+                    expand="body.storage.value,version,space",
                     include_archived_spaces=include_archived_spaces,
                 )
+                pages = response.get("results", [])
+                next_cursor = None
+                if "_links" in response and "next" in response["_links"]:
+                    next_link = response["_links"]["next"]
+                    parsed_url = urlparse(next_link)
+                    query_params = parse_qs(parsed_url.query)
+                    cursor_list = query_params.get("cursor", [])
+                    if cursor_list:
+                        next_cursor = cursor_list[0]
+                return pages, next_cursor
             except Exception:
                 logger.warning(
-                    f"Batch failed with cql {self.cql_query} at offset {start_ind} "
-                    f"with size {batch_size}, processing pages individually..."
+                    f"Batch failed with cql {self.cql_query} with cursor {cursor} "
+                    f"and size {batch_size}, processing pages individually..."
                 )
 
                 view_pages: list[dict[str, Any]] = []
-                for i in range(self.batch_size):
+                for _ in range(self.batch_size):
                     try:
-                        # Could be that one of the pages here failed due to this bug:
-                        # https://jira.atlassian.com/browse/CONFCLOUD-76433
-                        view_pages.extend(
-                            get_all_pages(
-                                cql=self.cql_query,
-                                start=start_ind + i,
-                                limit=1,
-                                expand="body.storage.value,version",
-                                include_archived_spaces=include_archived_spaces,
-                            )
-                        )
+                        response = get_all_pages(
+                            cql=self.cql_query,
+                            cursor=cursor,
+                            limit=1,
+                            expand="body.view.value,version,space",
+                            include_archived_spaces=include_archived_spaces,
+                        )
+                        pages = response.get("results", [])
+                        view_pages.extend(pages)
+                        if "_links" in response and "next" in response["_links"]:
+                            next_link = response["_links"]["next"]
+                            parsed_url = urlparse(next_link)
+                            query_params = parse_qs(parsed_url.query)
+                            cursor_list = query_params.get("cursor", [])
+                            if cursor_list:
+                                cursor = cursor_list[0]
+                            else:
+                                cursor = None
+                        else:
+                            cursor = None
+                            break
                     except HTTPError as e:
                         logger.warning(
-                            f"Page failed with cql {self.cql_query} at offset {start_ind + i}, "
+                            f"Page failed with cql {self.cql_query} with cursor {cursor}, "
                             f"trying alternative expand option: {e}"
                         )
-                        # Use view instead, which captures most info but is less complete
-                        view_pages.extend(
-                            get_all_pages(
-                                cql=self.cql_query,
-                                start=start_ind + i,
-                                limit=1,
-                                expand="body.view.value,version",
-                            )
-                        )
+                        response = get_all_pages(
+                            cql=self.cql_query,
+                            cursor=cursor,
+                            limit=1,
+                            expand="body.view.value,version,space",
+                        )
+                        pages = response.get("results", [])
+                        view_pages.extend(pages)
+                        if "_links" in response and "next" in response["_links"]:
+                            next_link = response["_links"]["next"]
+                            parsed_url = urlparse(next_link)
+                            query_params = parse_qs(parsed_url.query)
+                            cursor_list = query_params.get("cursor", [])
+                            if cursor_list:
+                                cursor = cursor_list[0]
+                            else:
+                                cursor = None
+                        else:
+                            cursor = None
+                            break
 
-                return view_pages
+                return view_pages, cursor
 
-        def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
+        def _fetch_page() -> tuple[list[dict[str, Any]], str | None]:
             if self.confluence_client is None:
-                raise ConnectorMissingCredentialError("Confluence")
+                raise Exception("Confluence client is not initialized")
 
             if self.recursive_indexer is None:
                 self.recursive_indexer = RecursiveIndexer(
@@ -510,59 +441,37 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                     index_recursively=self.index_recursively,
                 )
 
-            if self.index_recursively:
-                return self.recursive_indexer.get_pages(start_ind, batch_size)
-            else:
-                return self.recursive_indexer.get_origin_page()
-
-        pages: list[dict[str, Any]] = []
+            pages = self.recursive_indexer.get_pages()
+            return pages, None  # Since we fetched all pages, no cursor
 
         try:
-            pages = (
-                _fetch_space(start_ind, self.batch_size)
+            pages, next_cursor = (
+                _fetch_space(cursor, self.batch_size)
                 if self.space_level_scan
-                else _fetch_page(start_ind, self.batch_size)
+                else _fetch_page()
             )
-            return pages
+            return pages, next_cursor
         except Exception as e:
             if not self.continue_on_failure:
                 raise e
 
-            # error checking phase, only reachable if `self.continue_on_failure=True`
-            for _ in range(self.batch_size):
-                try:
-                    pages = (
-                        _fetch_space(start_ind, self.batch_size)
-                        if self.space_level_scan
-                        else _fetch_page(start_ind, self.batch_size)
-                    )
-                    return pages
-                except Exception:
-                    logger.exception(
-                        "Ran into exception when fetching pages from Confluence"
-                    )
-            return pages
+            logger.exception("Ran into exception when fetching pages from Confluence")
+            return [], None
 
-    def _fetch_comments(
-        self, confluence_client: DanswerConfluence, page_id: str
-    ) -> str:
+    def _fetch_comments(self, confluence_client: Confluence, page_id: str) -> str:
         get_page_child_by_type = make_confluence_call_handle_rate_limit(
             confluence_client.get_page_child_by_type
         )
 
         try:
-            comment_pages = cast(
-                Collection[dict[str, Any]],
+            comment_pages = list(
                 get_page_child_by_type(
                     page_id,
                     type="comment",
                     start=None,
                     limit=None,
                     expand="body.storage.value",
-                ),
+                )
             )
             return _comment_dfs("", comment_pages, confluence_client)
         except Exception as e:
@@ -574,9 +483,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
             )
             return ""
 
-    def _fetch_labels(
-        self, confluence_client: DanswerConfluence, page_id: str
-    ) -> list[str]:
+    def _fetch_labels(self, confluence_client: Confluence, page_id: str) -> list[str]:
         get_page_labels = make_confluence_call_handle_rate_limit(
             confluence_client.get_page_labels
         )
@@ -647,22 +554,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
         return extracted_text
 
     def _fetch_attachments(
-        self, confluence_client: Confluence, page_id: str, files_in_used: list[str]
+        self, confluence_client: Confluence, page_id: str, files_in_use: list[str]
     ) -> tuple[str, list[dict[str, Any]]]:
-        unused_attachments: list = []
+        unused_attachments: list[dict[str, Any]] = []
+        files_attachment_content: list[str] = []
 
         get_attachments_from_content = make_confluence_call_handle_rate_limit(
             confluence_client.get_attachments_from_content
         )
-        files_attachment_content: list = []
 
         try:
             expand = "history.lastUpdated,metadata.labels"
             attachments_container = get_attachments_from_content(
-                page_id, start=0, limit=500, expand=expand
+                page_id, start=None, limit=None, expand=expand
             )
-            for attachment in attachments_container["results"]:
-                if attachment["title"] not in files_in_used:
+            for attachment in attachments_container.get("results", []):
+                if attachment["title"] not in files_in_use:
                     unused_attachments.append(attachment)
                     continue
@@ -680,7 +587,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                     f"User does not have access to attachments on page '{page_id}'"
                 )
                 return "", []
-
             if not self.continue_on_failure:
                 raise e
             logger.exception(
@@ -690,24 +596,26 @@ class ConfluenceConnector(LoadConnector, PollConnector):
         return "\n".join(files_attachment_content), unused_attachments
 
     def _get_doc_batch(
-        self, start_ind: int
-    ) -> tuple[list[Document], list[dict[str, Any]], int]:
+        self, cursor: str | None
+    ) -> tuple[list[Any], str | None, list[dict[str, Any]]]:
         if self.confluence_client is None:
-            raise ConnectorMissingCredentialError("Confluence")
+            raise Exception("Confluence client is not initialized")
 
-        doc_batch: list[Document] = []
+        doc_batch: list[Any] = []
         unused_attachments: list[dict[str, Any]] = []
 
-        batch = self._fetch_pages(start_ind)
+        batch, next_cursor = self._fetch_pages(cursor)
 
         for page in batch:
             last_modified = _datetime_from_string(page["version"]["when"])
-            author = cast(str | None, page["version"].get("by", {}).get("email"))
+            author = page["version"].get("by", {}).get("email")
             page_id = page["id"]
 
             if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
                 page_labels = self._fetch_labels(self.confluence_client, page_id)
+            else:
+                page_labels = []
 
             # check disallowed labels
             if self.labels_to_skip:
@@ -717,7 +625,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                         f"Page with ID '{page_id}' has a label which has been "
                        f"designated as disallowed: {label_intersection}. Skipping."
                     )
-
                     continue
 
             page_html = (
@@ -732,16 +639,18 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                 continue
 
             page_text = parse_html_page(page_html, self.confluence_client)
-            files_in_used = get_used_attachments(page_html)
+            files_in_use = get_used_attachments(page_html)
 
             attachment_text, unused_page_attachments = self._fetch_attachments(
-                self.confluence_client, page_id, files_in_used
+                self.confluence_client, page_id, files_in_use
             )
             unused_attachments.extend(unused_page_attachments)
 
             page_text += "\n" + attachment_text if attachment_text else ""
             comments_text = self._fetch_comments(self.confluence_client, page_id)
             page_text += comments_text
 
-            doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
+            doc_metadata: dict[str, str | list[str]] = {
+                "Wiki Space Name": page["space"]["name"]
+            }
 
             if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
                 doc_metadata["labels"] = page_labels
@@ -760,8 +669,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
             )
 
         return (
             doc_batch,
+            next_cursor,
             unused_attachments,
-            len(batch),
         )
 
     def _get_attachment_batch(
@@ -769,8 +678,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
         start_ind: int,
         attachments: list[dict[str, Any]],
         time_filter: Callable[[datetime], bool] | None = None,
-    ) -> tuple[list[Document], int]:
-        doc_batch: list[Document] = []
+    ) -> tuple[list[Any], int]:
+        doc_batch: list[Any] = []
 
         if self.confluence_client is None:
             raise ConnectorMissingCredentialError("Confluence")
@@ -798,7 +707,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
             creator_email = attachment["history"]["createdBy"].get("email")
 
             comment = attachment["metadata"].get("comment", "")
-            doc_metadata: dict[str, str | list[str]] = {"comment": comment}
+            doc_metadata: dict[str, Any] = {"comment": comment}
             attachment_labels: list[str] = []
             if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
@@ -825,64 +734,36 @@ class ConfluenceConnector(LoadConnector, PollConnector):
         return doc_batch, end_ind - start_ind
 
-    def load_from_state(self) -> GenerateDocumentsOutput:
-        unused_attachments: list[dict[str, Any]] = []
-
-        start_ind = 0
-        while True:
-            doc_batch, unused_attachments, num_pages = self._get_doc_batch(start_ind)
-            unused_attachments.extend(unused_attachments)
-            start_ind += num_pages
-            if doc_batch:
-                yield doc_batch
-
-            if num_pages < self.batch_size:
-                break
-
-        start_ind = 0
-        while True:
-            attachment_batch, num_attachments = self._get_attachment_batch(
-                start_ind, unused_attachments
-            )
-            start_ind += num_attachments
-            if attachment_batch:
-                yield attachment_batch
-
-            if num_attachments < self.batch_size:
-                break
-
-    def poll_source(
-        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
+    def _handle_batch_retrieval(
+        self,
+        start: float | None = None,
+        end: float | None = None,
     ) -> GenerateDocumentsOutput:
+        start_time = datetime.fromtimestamp(start, tz=timezone.utc) if start else None
+        end_time = datetime.fromtimestamp(end, tz=timezone.utc) if end else None
+
         unused_attachments: list[dict[str, Any]] = []
-
-        if self.confluence_client is None:
-            raise ConnectorMissingCredentialError("Confluence")
-
-        start_time = datetime.fromtimestamp(start, tz=timezone.utc)
-        end_time = datetime.fromtimestamp(end, tz=timezone.utc)
-
-        self.cql_query = _replace_cql_time_filter(self.cql_query, start_time, end_time)
-
-        start_ind = 0
+        cursor = None
         while True:
-            doc_batch, unused_attachments, num_pages = self._get_doc_batch(start_ind)
-            unused_attachments.extend(unused_attachments)
-            start_ind += num_pages
+            doc_batch, cursor, new_unused_attachments = self._get_doc_batch(cursor)
+            unused_attachments.extend(new_unused_attachments)
             if doc_batch:
                 yield doc_batch
 
-            if num_pages < self.batch_size:
+            if not cursor:
                 break
 
+        # Process attachments if any
         start_ind = 0
         while True:
             attachment_batch, num_attachments = self._get_attachment_batch(
-                start_ind,
-                unused_attachments,
-                time_filter=lambda t: start_time <= t <= end_time,
+                start_ind=start_ind,
+                attachments=unused_attachments,
+                time_filter=(lambda t: start_time <= t <= end_time)
+                if start_time and end_time
+                else None,
             )
             start_ind += num_attachments
             if attachment_batch:
                 yield attachment_batch
@@ -890,6 +771,12 @@ class ConfluenceConnector(LoadConnector, PollConnector):
             if num_attachments < self.batch_size:
                 break
 
+    def load_from_state(self) -> GenerateDocumentsOutput:
+        return self._handle_batch_retrieval()
+
+    def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
+        return self._handle_batch_retrieval(start=start, end=end)
+
 
 if __name__ == "__main__":
     connector = ConfluenceConnector(

View File

@@ -306,7 +306,7 @@ export const connectorConfigs: Record<
         name: "cql_query",
         optional: true,
         description:
-          "IMPORTANT: This will overwrite all other selected connector settings (besides Wiki Base URL). We currently only support CQL queries that return objects of type 'page'. This means all CQL queries must contain 'type=page' as the only type filter. We will still get all attachments and comments for the pages returned by the CQL query. Any 'lastmodified' filters will be overwritten. See https://developer.atlassian.com/server/confluence/advanced-searching-using-cql/ for more details.",
+          "IMPORTANT: This will overwrite all other selected connector settings (besides Wiki Base URL). We currently only support CQL queries that return objects of type 'page'. This means all CQL queries must contain 'type=page' as the only type filter. It is also important that no filters for 'lastModified' are used as it will cause issues with our connector polling logic. We will still get all attachments and comments for the pages returned by the CQL query. Any 'lastmodified' filters will be overwritten. See https://developer.atlassian.com/server/confluence/advanced-searching-using-cql/ for more details.",
       },
     ],
   },
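The tightened description above adds one practical rule for user-supplied queries: keep 'type=page' as the only type filter and avoid any 'lastmodified'/'lastModified' clause, since such a clause conflicts with the connector's polling logic and will be overwritten. A hypothetical pair of queries for contrast (the label value is a placeholder):

```python
# Acceptable: 'type=page' is the only type filter and there is no time clause.
ok_cql = "type=page and label='meeting-notes'"

# Problematic per the new guidance: the lastmodified clause will be overwritten
# and can interfere with polling.
bad_cql = "type=page and lastmodified >= '2024-01-01'"
```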