Add cursor to cql confluence (#2775)

* add cursor to cql confluence

* fixed space indexing issue

* fixed .get

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
Authored by pablodanswer on 2024-10-13 19:09:17 -07:00; committed by GitHub
parent ded42e2036
commit a9bcc89a2c
2 changed files with 183 additions and 296 deletions


@@ -1,13 +1,13 @@
import io
import os
import re
from collections.abc import Callable
from collections.abc import Collection
from datetime import datetime
from datetime import timezone
from functools import lru_cache
from typing import Any
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import urlparse
import bs4
from atlassian import Confluence # type:ignore
@@ -33,7 +33,6 @@ from danswer.connectors.confluence.rate_limit_handler import (
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
@@ -70,86 +69,25 @@ class DanswerConfluence(Confluence):
self,
cql: str,
expand: str | None = None,
start: int = 0,
cursor: str | None = None,
limit: int = 500,
include_archived_spaces: bool = False,
) -> list[dict[str, Any]]:
# Performs the query expansion and start/limit url additions
) -> dict[str, Any]:
url_suffix = f"rest/api/content/search?cql={cql}"
if expand:
url_suffix += f"&expand={expand}"
url_suffix += f"&start={start}&limit={limit}"
if cursor:
url_suffix += f"&cursor={cursor}"
url_suffix += f"&limit={limit}"
if include_archived_spaces:
url_suffix += "&includeArchivedSpaces=true"
try:
response = self.get(url_suffix)
return response.get("results", [])
return response
except Exception as e:
raise e
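
The reworked danswer_cql above drops offset (start/limit) pagination in favor of Confluence's cursor tokens and returns the raw response dict so callers can follow _links.next. A minimal usage sketch, assuming placeholder credentials and site URL (none of these values come from this diff):

    from urllib.parse import parse_qs, urlparse

    client = DanswerConfluence(
        url="https://example.atlassian.net/wiki",  # hypothetical site
        username="user@example.com",
        password="api-token",
    )
    cursor = None
    while True:
        response = client.danswer_cql(cql="type=page", cursor=cursor, limit=500)
        for page in response.get("results", []):
            ...  # process each page
        next_link = response.get("_links", {}).get("next")
        if not next_link:
            break
        # Confluence Cloud encodes the next cursor in the "next" link's query string
        cursor = parse_qs(urlparse(next_link).query).get("cursor", [None])[0]
        if cursor is None:
            break
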
def _replace_cql_time_filter(
cql_query: str, start_time: datetime, end_time: datetime
) -> str:
"""
This function replaces the lastmodified filter in the CQL query with the start and end times.
This selects the more restrictive time range.
"""
# Extract existing lastmodified >= and <= filters
existing_start_match = re.search(
r'lastmodified\s*>=\s*["\']?(\d{4}-\d{2}-\d{2}(?:\s+\d{2}:\d{2})?)["\']?',
cql_query,
flags=re.IGNORECASE,
)
existing_end_match = re.search(
r'lastmodified\s*<=\s*["\']?(\d{4}-\d{2}-\d{2}(?:\s+\d{2}:\d{2})?)["\']?',
cql_query,
flags=re.IGNORECASE,
)
# Remove all existing lastmodified and updated filters
cql_query = re.sub(
r'\s*AND\s+(lastmodified|updated)\s*[<>=]+\s*["\']?[\d-]+(?:\s+[\d:]+)?["\']?',
"",
cql_query,
flags=re.IGNORECASE,
)
# Determine the start time to use
if existing_start_match:
existing_start_str = existing_start_match.group(1)
existing_start = datetime.strptime(
existing_start_str,
"%Y-%m-%d %H:%M" if " " in existing_start_str else "%Y-%m-%d",
)
existing_start = existing_start.replace(
tzinfo=timezone.utc
) # Make offset-aware
start_time_to_use = max(start_time.astimezone(timezone.utc), existing_start)
else:
start_time_to_use = start_time.astimezone(timezone.utc)
# Determine the end time to use
if existing_end_match:
existing_end_str = existing_end_match.group(1)
existing_end = datetime.strptime(
existing_end_str,
"%Y-%m-%d %H:%M" if " " in existing_end_str else "%Y-%m-%d",
)
existing_end = existing_end.replace(tzinfo=timezone.utc) # Make offset-aware
end_time_to_use = min(end_time.astimezone(timezone.utc), existing_end)
else:
end_time_to_use = end_time.astimezone(timezone.utc)
# Add new time filters
cql_query += (
f" and lastmodified >= '{start_time_to_use.strftime('%Y-%m-%d %H:%M')}'"
)
cql_query += f" and lastmodified <= '{end_time_to_use.strftime('%Y-%m-%d %H:%M')}'"
return cql_query.strip()
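
As a concrete illustration of the rules above, the more restrictive bound wins on each side: the later of the two start filters (max) and the earlier of the two end filters (min). With hypothetical query and dates:

    from datetime import datetime, timezone

    q = "type=page and space='ENG' and lastmodified >= '2024-01-01'"
    start = datetime(2024, 2, 1, tzinfo=timezone.utc)
    end = datetime(2024, 3, 1, tzinfo=timezone.utc)
    print(_replace_cql_time_filter(q, start, end))
    # type=page and space='ENG' and lastmodified >= '2024-02-01 00:00'
    #   and lastmodified <= '2024-03-01 00:00'   (printed as one line)
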
@lru_cache()
def _get_user(user_id: str, confluence_client: DanswerConfluence) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
@@ -253,126 +191,86 @@ class RecursiveIndexer:
def __init__(
self,
batch_size: int,
confluence_client: DanswerConfluence,
confluence_client: Confluence,
index_recursively: bool,
origin_page_id: str,
) -> None:
self.batch_size = 1
# batch_size
self.batch_size = batch_size
self.confluence_client = confluence_client
self.index_recursively = index_recursively
self.origin_page_id = origin_page_id
self.pages = self.recurse_children_pages(0, self.origin_page_id)
self.pages = self.recurse_children_pages(self.origin_page_id)
def get_origin_page(self) -> list[dict[str, Any]]:
return [self._fetch_origin_page()]
def get_pages(self, ind: int, size: int) -> list[dict]:
if ind * size > len(self.pages):
return []
return self.pages[ind * size : (ind + 1) * size]
def get_pages(self) -> list[dict[str, Any]]:
return self.pages
def _fetch_origin_page(
self,
) -> dict[str, Any]:
def _fetch_origin_page(self) -> dict[str, Any]:
get_page_by_id = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_by_id
)
try:
origin_page = get_page_by_id(
self.origin_page_id, expand="body.storage.value,version"
self.origin_page_id, expand="body.storage.value,version,space"
)
return origin_page
except Exception as e:
logger.warning(
f"Appending orgin page with id {self.origin_page_id} failed: {e}"
f"Appending origin page with id {self.origin_page_id} failed: {e}"
)
return {}
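
Note that the expand string now also requests "space"; _get_doc_batch below reads the space name from each page for document metadata, so every fetch path must include it:

    # With expand="...,space" the space name is available downstream:
    space_name = page["space"]["name"]  # becomes the "Wiki Space Name" metadata
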
def recurse_children_pages(
self,
start_ind: int,
page_id: str,
) -> list[dict[str, Any]]:
pages: list[dict[str, Any]] = []
current_level_pages: list[dict[str, Any]] = []
next_level_pages: list[dict[str, Any]] = []
# Initial fetch of first level children
index = start_ind
while batch := self._fetch_single_depth_child_pages(
index, self.batch_size, page_id
):
current_level_pages.extend(batch)
index += len(batch)
pages.extend(current_level_pages)
# Recursively index children and children's children, etc.
while current_level_pages:
for child in current_level_pages:
child_index = 0
while child_batch := self._fetch_single_depth_child_pages(
child_index, self.batch_size, child["id"]
):
next_level_pages.extend(child_batch)
child_index += len(child_batch)
pages.extend(next_level_pages)
current_level_pages = next_level_pages
next_level_pages = []
try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
return pages
def _fetch_single_depth_child_pages(
self, start_ind: int, batch_size: int, page_id: str
) -> list[dict[str, Any]]:
child_pages: list[dict[str, Any]] = []
queue: list[str] = [page_id]
visited_pages: set[str] = set()
get_page_child_by_type = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_child_by_type
)
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=start_ind,
limit=batch_size,
expand="body.storage.value,version",
)
while queue:
current_page_id = queue.pop(0)
if current_page_id in visited_pages:
continue
visited_pages.add(current_page_id)
child_pages.extend(child_page)
return child_pages
try:
# Fetch the page itself
page = self.confluence_client.get_page_by_id(
current_page_id, expand="body.storage.value,version,space"
)
pages.append(page)
except Exception as e:
logger.warning(f"Failed to fetch page {current_page_id}: {e}")
continue
except Exception:
logger.warning(
f"Batch failed with page {page_id} at offset {start_ind} "
f"with size {batch_size}, processing pages individually..."
)
if not self.index_recursively:
continue
for i in range(batch_size):
ind = start_ind + i
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=ind,
limit=1,
expand="body.storage.value,version",
)
child_pages.extend(child_page)
except Exception as e:
logger.warning(f"Page {page_id} at offset {ind} failed: {e}")
raise e
# Fetch child pages
start = 0
while True:
child_pages_response = get_page_child_by_type(
current_page_id,
type="page",
start=start,
limit=self.batch_size,
expand="",
)
if not child_pages_response:
break
for child_page in child_pages_response:
child_page_id = child_page["id"]
queue.append(child_page_id)
start += len(child_pages_response)
return child_pages
return pages
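
The rewrite above replaces per-level offset batching with an iterative breadth-first traversal: a queue of page ids is drained while a visited set guards against cycles and duplicate fetches. A hypothetical usage sketch ("client" stands in for an authenticated Confluence client):

    indexer = RecursiveIndexer(
        batch_size=50,
        confluence_client=client,
        index_recursively=True,
        origin_page_id="123456",  # hypothetical page id
    )
    pages = indexer.get_pages()  # origin page plus all descendants, fetched in __init__
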
class ConfluenceConnector(LoadConnector, PollConnector):
@@ -399,7 +297,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
self.space = space
self.page_id = "" if cql_query else page_id
self.space_level_scan = bool(not self.page_id)
@@ -409,16 +306,16 @@ class ConfluenceConnector(LoadConnector, PollConnector):
# if a cql_query is provided, we will use it to fetch the pages
# if no cql_query is provided, we will use the space to fetch the pages
# if no space is provided, we will default to fetching all pages, regardless of space
# if no space is provided and no cql_query, we will default to fetching all pages, regardless of space
if cql_query:
self.cql_query = cql_query
elif self.space:
self.cql_query = f"type=page and space={self.space}"
elif space:
self.cql_query = f"type=page and space='{space}'"
else:
self.cql_query = "type=page"
logger.info(
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
f"wiki_base: {self.wiki_base}, space: {space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively},"
+ f" cql_query: {self.cql_query}"
)
@@ -428,7 +325,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
access_token = credentials["confluence_access_token"]
self.confluence_client = DanswerConfluence(
url=self.wiki_base,
# passing in username causes issues for Confluence data center
username=username if self.is_cloud else None,
password=access_token if self.is_cloud else None,
token=access_token if not self.is_cloud else None,
@@ -437,12 +333,16 @@ class ConfluenceConnector(LoadConnector, PollConnector):
def _fetch_pages(
self,
start_ind: int,
) -> list[dict[str, Any]]:
def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
cursor: str | None,
) -> tuple[list[dict[str, Any]], str | None]:
if self.confluence_client is None:
raise Exception("Confluence client is not initialized")
def _fetch_space(
cursor: str | None, batch_size: int
) -> tuple[list[dict[str, Any]], str | None]:
if not self.confluence_client:
raise Exception("Confluence client is not initialized")
get_all_pages = make_confluence_call_handle_rate_limit(
self.confluence_client.danswer_cql
)
@@ -454,53 +354,84 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)
try:
return get_all_pages(
response = get_all_pages(
cql=self.cql_query,
start=start_ind,
cursor=cursor,
limit=batch_size,
expand="body.storage.value,version",
expand="body.storage.value,version,space",
include_archived_spaces=include_archived_spaces,
)
pages = response.get("results", [])
next_cursor = None
if "_links" in response and "next" in response["_links"]:
next_link = response["_links"]["next"]
parsed_url = urlparse(next_link)
query_params = parse_qs(parsed_url.query)
cursor_list = query_params.get("cursor", [])
if cursor_list:
next_cursor = cursor_list[0]
return pages, next_cursor
except Exception:
logger.warning(
f"Batch failed with cql {self.cql_query} at offset {start_ind} "
f"with size {batch_size}, processing pages individually..."
f"Batch failed with cql {self.cql_query} with cursor {cursor} "
f"and size {batch_size}, processing pages individually..."
)
view_pages: list[dict[str, Any]] = []
for i in range(self.batch_size):
for _ in range(self.batch_size):
try:
# Could be that one of the pages here failed due to this bug:
# https://jira.atlassian.com/browse/CONFCLOUD-76433
view_pages.extend(
get_all_pages(
cql=self.cql_query,
start=start_ind + i,
limit=1,
expand="body.storage.value,version",
include_archived_spaces=include_archived_spaces,
)
response = get_all_pages(
cql=self.cql_query,
cursor=cursor,
limit=1,
expand="body.view.value,version,space",
include_archived_spaces=include_archived_spaces,
)
pages = response.get("results", [])
view_pages.extend(pages)
if "_links" in response and "next" in response["_links"]:
next_link = response["_links"]["next"]
parsed_url = urlparse(next_link)
query_params = parse_qs(parsed_url.query)
cursor_list = query_params.get("cursor", [])
if cursor_list:
cursor = cursor_list[0]
else:
cursor = None
else:
cursor = None
break
except HTTPError as e:
logger.warning(
f"Page failed with cql {self.cql_query} at offset {start_ind + i}, "
f"Page failed with cql {self.cql_query} with cursor {cursor}, "
f"trying alternative expand option: {e}"
)
# Use view instead, which captures most info but is less complete
view_pages.extend(
get_all_pages(
cql=self.cql_query,
start=start_ind + i,
limit=1,
expand="body.view.value,version",
)
response = get_all_pages(
cql=self.cql_query,
cursor=cursor,
limit=1,
expand="body.view.value,version,space",
)
pages = response.get("results", [])
view_pages.extend(pages)
if "_links" in response and "next" in response["_links"]:
next_link = response["_links"]["next"]
parsed_url = urlparse(next_link)
query_params = parse_qs(parsed_url.query)
cursor_list = query_params.get("cursor", [])
if cursor_list:
cursor = cursor_list[0]
else:
cursor = None
else:
cursor = None
break
return view_pages
return view_pages, cursor
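
The cursor-parsing block recurs three times in _fetch_space above; a helper equivalent to it might look like this (a sketch, not part of the diff):

    from urllib.parse import parse_qs, urlparse

    def _extract_next_cursor(response: dict) -> str | None:
        # Confluence Cloud returns the next page's cursor inside _links.next,
        # e.g. .../rest/api/content/search?cql=...&cursor=abc123&limit=500
        next_link = response.get("_links", {}).get("next")
        if not next_link:
            return None
        cursor_values = parse_qs(urlparse(next_link).query).get("cursor", [])
        return cursor_values[0] if cursor_values else None
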
def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
def _fetch_page() -> tuple[list[dict[str, Any]], str | None]:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
raise Exception("Confluence client is not initialized")
if self.recursive_indexer is None:
self.recursive_indexer = RecursiveIndexer(
@@ -510,59 +441,37 @@ class ConfluenceConnector(LoadConnector, PollConnector):
index_recursively=self.index_recursively,
)
if self.index_recursively:
return self.recursive_indexer.get_pages(start_ind, batch_size)
else:
return self.recursive_indexer.get_origin_page()
pages: list[dict[str, Any]] = []
pages = self.recursive_indexer.get_pages()
return pages, None # Since we fetched all pages, no cursor
try:
pages = (
_fetch_space(start_ind, self.batch_size)
pages, next_cursor = (
_fetch_space(cursor, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
else _fetch_page()
)
return pages
return pages, next_cursor
except Exception as e:
if not self.continue_on_failure:
raise e
# error checking phase, only reachable if `self.continue_on_failure=True`
for _ in range(self.batch_size):
try:
pages = (
_fetch_space(start_ind, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
)
return pages
logger.exception("Ran into exception when fetching pages from Confluence")
return [], None
except Exception:
logger.exception(
"Ran into exception when fetching pages from Confluence"
)
return pages
def _fetch_comments(
self, confluence_client: DanswerConfluence, page_id: str
) -> str:
def _fetch_comments(self, confluence_client: Confluence, page_id: str) -> str:
get_page_child_by_type = make_confluence_call_handle_rate_limit(
confluence_client.get_page_child_by_type
)
try:
comment_pages = cast(
Collection[dict[str, Any]],
comment_pages = list(
get_page_child_by_type(
page_id,
type="comment",
start=None,
limit=None,
expand="body.storage.value",
),
)
)
return _comment_dfs("", comment_pages, confluence_client)
except Exception as e:
@@ -574,9 +483,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)
return ""
def _fetch_labels(
self, confluence_client: DanswerConfluence, page_id: str
) -> list[str]:
def _fetch_labels(self, confluence_client: Confluence, page_id: str) -> list[str]:
get_page_labels = make_confluence_call_handle_rate_limit(
confluence_client.get_page_labels
)
@@ -647,22 +554,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return extracted_text
def _fetch_attachments(
self, confluence_client: Confluence, page_id: str, files_in_used: list[str]
self, confluence_client: Confluence, page_id: str, files_in_use: list[str]
) -> tuple[str, list[dict[str, Any]]]:
unused_attachments: list = []
unused_attachments: list[dict[str, Any]] = []
files_attachment_content: list[str] = []
get_attachments_from_content = make_confluence_call_handle_rate_limit(
confluence_client.get_attachments_from_content
)
files_attachment_content: list = []
try:
expand = "history.lastUpdated,metadata.labels"
attachments_container = get_attachments_from_content(
page_id, start=0, limit=500, expand=expand
page_id, start=None, limit=None, expand=expand
)
for attachment in attachments_container["results"]:
if attachment["title"] not in files_in_used:
for attachment in attachments_container.get("results", []):
if attachment["title"] not in files_in_use:
unused_attachments.append(attachment)
continue
@@ -680,7 +587,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
f"User does not have access to attachments on page '{page_id}'"
)
return "", []
if not self.continue_on_failure:
raise e
logger.exception(
@@ -690,24 +596,26 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return "\n".join(files_attachment_content), unused_attachments
def _get_doc_batch(
self, start_ind: int
) -> tuple[list[Document], list[dict[str, Any]], int]:
self, cursor: str | None
) -> tuple[list[Any], str | None, list[dict[str, Any]]]:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
raise Exception("Confluence client is not initialized")
doc_batch: list[Document] = []
doc_batch: list[Any] = []
unused_attachments: list[dict[str, Any]] = []
batch = self._fetch_pages(start_ind)
batch, next_cursor = self._fetch_pages(cursor)
for page in batch:
last_modified = _datetime_from_string(page["version"]["when"])
author = cast(str | None, page["version"].get("by", {}).get("email"))
author = page["version"].get("by", {}).get("email")
page_id = page["id"]
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
page_labels = self._fetch_labels(self.confluence_client, page_id)
else:
page_labels = []
# check disallowed labels
if self.labels_to_skip:
@@ -717,7 +625,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
f"Page with ID '{page_id}' has a label which has been "
f"designated as disallowed: {label_intersection}. Skipping."
)
continue
page_html = (
@@ -732,16 +639,18 @@ class ConfluenceConnector(LoadConnector, PollConnector):
continue
page_text = parse_html_page(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html)
files_in_use = get_used_attachments(page_html)
attachment_text, unused_page_attachments = self._fetch_attachments(
self.confluence_client, page_id, files_in_used
self.confluence_client, page_id, files_in_use
)
unused_attachments.extend(unused_page_attachments)
page_text += "\n" + attachment_text if attachment_text else ""
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
doc_metadata: dict[str, str | list[str]] = {
"Wiki Space Name": page["space"]["name"]
}
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
doc_metadata["labels"] = page_labels
@@ -760,8 +669,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)
return (
doc_batch,
next_cursor,
unused_attachments,
len(batch),
)
def _get_attachment_batch(
@@ -769,8 +678,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
start_ind: int,
attachments: list[dict[str, Any]],
time_filter: Callable[[datetime], bool] | None = None,
) -> tuple[list[Document], int]:
doc_batch: list[Document] = []
) -> tuple[list[Any], int]:
doc_batch: list[Any] = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
@@ -798,7 +707,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
creator_email = attachment["history"]["createdBy"].get("email")
comment = attachment["metadata"].get("comment", "")
doc_metadata: dict[str, str | list[str]] = {"comment": comment}
doc_metadata: dict[str, Any] = {"comment": comment}
attachment_labels: list[str] = []
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
@@ -825,64 +734,36 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return doc_batch, end_ind - start_ind
def load_from_state(self) -> GenerateDocumentsOutput:
unused_attachments: list[dict[str, Any]] = []
start_ind = 0
while True:
doc_batch, unused_attachments, num_pages = self._get_doc_batch(start_ind)
unused_attachments.extend(unused_attachments)
start_ind += num_pages
if doc_batch:
yield doc_batch
if num_pages < self.batch_size:
break
start_ind = 0
while True:
attachment_batch, num_attachments = self._get_attachment_batch(
start_ind, unused_attachments
)
start_ind += num_attachments
if attachment_batch:
yield attachment_batch
if num_attachments < self.batch_size:
break
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
def _handle_batch_retrieval(
self,
start: float | None = None,
end: float | None = None,
) -> GenerateDocumentsOutput:
start_time = datetime.fromtimestamp(start, tz=timezone.utc) if start else None
end_time = datetime.fromtimestamp(end, tz=timezone.utc) if end else None
unused_attachments: list[dict[str, Any]] = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
start_time = datetime.fromtimestamp(start, tz=timezone.utc)
end_time = datetime.fromtimestamp(end, tz=timezone.utc)
self.cql_query = _replace_cql_time_filter(self.cql_query, start_time, end_time)
start_ind = 0
cursor = None
while True:
doc_batch, unused_attachments, num_pages = self._get_doc_batch(start_ind)
unused_attachments.extend(unused_attachments)
start_ind += num_pages
doc_batch, cursor, new_unused_attachments = self._get_doc_batch(cursor)
unused_attachments.extend(new_unused_attachments)
if doc_batch:
yield doc_batch
if num_pages < self.batch_size:
if not cursor:
break
# Process attachments if any
start_ind = 0
while True:
attachment_batch, num_attachments = self._get_attachment_batch(
start_ind,
unused_attachments,
time_filter=lambda t: start_time <= t <= end_time,
start_ind=start_ind,
attachments=unused_attachments,
time_filter=(lambda t: start_time <= t <= end_time)
if start_time and end_time
else None,
)
start_ind += num_attachments
if attachment_batch:
yield attachment_batch
@@ -890,6 +771,12 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if num_attachments < self.batch_size:
break
def load_from_state(self) -> GenerateDocumentsOutput:
return self._handle_batch_retrieval()
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
return self._handle_batch_retrieval(start=start, end=end)
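
With both entry points delegating to _handle_batch_retrieval, driving the connector might look like the sketch below; the constructor arguments and credential keys are assumptions, since the full signatures are not shown in this diff.

    connector = ConfluenceConnector(
        wiki_base="https://example.atlassian.net/wiki",  # hypothetical site
        space="ENG",
    )
    connector.load_credentials({
        "confluence_username": "user@example.com",  # assumed key
        "confluence_access_token": "api-token",     # key shown in the diff
    })
    for batch in connector.load_from_state():       # full load
        print(len(batch))
    for batch in connector.poll_source(start=1728000000.0, end=1728864000.0):
        print(len(batch))                           # incremental time window
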
if __name__ == "__main__":
connector = ConfluenceConnector(