Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-10-06 18:14:35 +02:00
Add cursor to cql confluence (#2775)
* add cursor to cql confluence
* k
* k
* fixed space indexing issue
* fixed .get

Co-authored-by: hagen-danswer <hagen@danswer.ai>
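The commit moves the connector's CQL paging from start/limit offsets to cursor-based pagination. A minimal sketch of how the next cursor can be pulled out of a Confluence search response (the helper name below is illustrative, not part of the commit):

from typing import Any
from urllib.parse import parse_qs, urlparse


def extract_next_cursor(response: dict[str, Any]) -> str | None:
    # The search API advertises the next page via a `_links.next` URL;
    # the cursor travels as a query parameter on that URL.
    next_link = response.get("_links", {}).get("next")
    if not next_link:
        return None
    cursor_values = parse_qs(urlparse(next_link).query).get("cursor", [])
    return cursor_values[0] if cursor_values else None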
@@ -1,13 +1,13 @@
import io
import os
import re
from collections.abc import Callable
from collections.abc import Collection
from datetime import datetime
from datetime import timezone
from functools import lru_cache
from typing import Any
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import urlparse

import bs4
from atlassian import Confluence # type:ignore
@@ -33,7 +33,6 @@ from danswer.connectors.confluence.rate_limit_handler import (
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
@@ -70,86 +69,25 @@ class DanswerConfluence(Confluence):
self,
cql: str,
expand: str | None = None,
start: int = 0,
cursor: str | None = None,
limit: int = 500,
include_archived_spaces: bool = False,
) -> list[dict[str, Any]]:
# Performs the query expansion and start/limit url additions
) -> dict[str, Any]:
url_suffix = f"rest/api/content/search?cql={cql}"
if expand:
url_suffix += f"&expand={expand}"
url_suffix += f"&start={start}&limit={limit}"
if cursor:
url_suffix += f"&cursor={cursor}"
url_suffix += f"&limit={limit}"
if include_archived_spaces:
url_suffix += "&includeArchivedSpaces=true"
try:
response = self.get(url_suffix)
return response.get("results", [])
return response
except Exception as e:
raise e


def _replace_cql_time_filter(
cql_query: str, start_time: datetime, end_time: datetime
) -> str:
"""
This function replaces the lastmodified filter in the CQL query with the start and end times.
This selects the more restrictive time range.
"""
# Extract existing lastmodified >= and <= filters
existing_start_match = re.search(
r'lastmodified\s*>=\s*["\']?(\d{4}-\d{2}-\d{2}(?:\s+\d{2}:\d{2})?)["\']?',
cql_query,
flags=re.IGNORECASE,
)
existing_end_match = re.search(
r'lastmodified\s*<=\s*["\']?(\d{4}-\d{2}-\d{2}(?:\s+\d{2}:\d{2})?)["\']?',
cql_query,
flags=re.IGNORECASE,
)

# Remove all existing lastmodified and updated filters
cql_query = re.sub(
r'\s*AND\s+(lastmodified|updated)\s*[<>=]+\s*["\']?[\d-]+(?:\s+[\d:]+)?["\']?',
"",
cql_query,
flags=re.IGNORECASE,
)

# Determine the start time to use
if existing_start_match:
existing_start_str = existing_start_match.group(1)
existing_start = datetime.strptime(
existing_start_str,
"%Y-%m-%d %H:%M" if " " in existing_start_str else "%Y-%m-%d",
)
existing_start = existing_start.replace(
tzinfo=timezone.utc
) # Make offset-aware
start_time_to_use = max(start_time.astimezone(timezone.utc), existing_start)
else:
start_time_to_use = start_time.astimezone(timezone.utc)

# Determine the end time to use
if existing_end_match:
existing_end_str = existing_end_match.group(1)
existing_end = datetime.strptime(
existing_end_str,
"%Y-%m-%d %H:%M" if " " in existing_end_str else "%Y-%m-%d",
)
existing_end = existing_end.replace(tzinfo=timezone.utc) # Make offset-aware
end_time_to_use = min(end_time.astimezone(timezone.utc), existing_end)
else:
end_time_to_use = end_time.astimezone(timezone.utc)

# Add new time filters
cql_query += (
f" and lastmodified >= '{start_time_to_use.strftime('%Y-%m-%d %H:%M')}'"
)
cql_query += f" and lastmodified <= '{end_time_to_use.strftime('%Y-%m-%d %H:%M')}'"

return cql_query.strip()
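# Illustrative example of the "more restrictive range" behavior (assumed, not part of the diff):
#   _replace_cql_time_filter(
#       "type=page and space=ENG and lastmodified >= '2024-01-01'",
#       datetime(2024, 2, 1, tzinfo=timezone.utc),
#       datetime(2024, 3, 1, tzinfo=timezone.utc),
#   )
# keeps the narrower of the existing and requested bounds and yields roughly:
#   "type=page and space=ENG and lastmodified >= '2024-02-01 00:00' and lastmodified <= '2024-03-01 00:00'"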

@lru_cache()
def _get_user(user_id: str, confluence_client: DanswerConfluence) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
@@ -253,126 +191,86 @@ class RecursiveIndexer:
def __init__(
self,
batch_size: int,
confluence_client: DanswerConfluence,
confluence_client: Confluence,
index_recursively: bool,
origin_page_id: str,
) -> None:
self.batch_size = 1
# batch_size
self.batch_size = batch_size
self.confluence_client = confluence_client
self.index_recursively = index_recursively
self.origin_page_id = origin_page_id
self.pages = self.recurse_children_pages(0, self.origin_page_id)
self.pages = self.recurse_children_pages(self.origin_page_id)

def get_origin_page(self) -> list[dict[str, Any]]:
return [self._fetch_origin_page()]

def get_pages(self, ind: int, size: int) -> list[dict]:
if ind * size > len(self.pages):
return []
return self.pages[ind * size : (ind + 1) * size]
def get_pages(self) -> list[dict[str, Any]]:
return self.pages

def _fetch_origin_page(
self,
) -> dict[str, Any]:
def _fetch_origin_page(self) -> dict[str, Any]:
get_page_by_id = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_by_id
)
try:
origin_page = get_page_by_id(
self.origin_page_id, expand="body.storage.value,version"
self.origin_page_id, expand="body.storage.value,version,space"
)
return origin_page
except Exception as e:
logger.warning(
f"Appending orgin page with id {self.origin_page_id} failed: {e}"
f"Appending origin page with id {self.origin_page_id} failed: {e}"
)
return {}

def recurse_children_pages(
self,
start_ind: int,
page_id: str,
) -> list[dict[str, Any]]:
pages: list[dict[str, Any]] = []
current_level_pages: list[dict[str, Any]] = []
next_level_pages: list[dict[str, Any]] = []

# Initial fetch of first level children
index = start_ind
while batch := self._fetch_single_depth_child_pages(
index, self.batch_size, page_id
):
current_level_pages.extend(batch)
index += len(batch)

pages.extend(current_level_pages)

# Recursively index children and children's children, etc.
while current_level_pages:
for child in current_level_pages:
child_index = 0
while child_batch := self._fetch_single_depth_child_pages(
child_index, self.batch_size, child["id"]
):
next_level_pages.extend(child_batch)
child_index += len(child_batch)

pages.extend(next_level_pages)
current_level_pages = next_level_pages
next_level_pages = []

try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")

return pages

def _fetch_single_depth_child_pages(
self, start_ind: int, batch_size: int, page_id: str
) -> list[dict[str, Any]]:
child_pages: list[dict[str, Any]] = []
queue: list[str] = [page_id]
visited_pages: set[str] = set()

get_page_child_by_type = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_child_by_type
)

try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=start_ind,
limit=batch_size,
expand="body.storage.value,version",
)
while queue:
current_page_id = queue.pop(0)
if current_page_id in visited_pages:
continue
visited_pages.add(current_page_id)

child_pages.extend(child_page)
return child_pages
try:
# Fetch the page itself
page = self.confluence_client.get_page_by_id(
current_page_id, expand="body.storage.value,version,space"
)
pages.append(page)
except Exception as e:
logger.warning(f"Failed to fetch page {current_page_id}: {e}")
continue

except Exception:
logger.warning(
f"Batch failed with page {page_id} at offset {start_ind} "
f"with size {batch_size}, processing pages individually..."
)
if not self.index_recursively:
continue

for i in range(batch_size):
ind = start_ind + i
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=ind,
limit=1,
expand="body.storage.value,version",
)
child_pages.extend(child_page)
except Exception as e:
logger.warning(f"Page {page_id} at offset {ind} failed: {e}")
raise e
# Fetch child pages
start = 0
while True:
child_pages_response = get_page_child_by_type(
current_page_id,
type="page",
start=start,
limit=self.batch_size,
expand="",
)
if not child_pages_response:
break
for child_page in child_pages_response:
child_page_id = child_page["id"]
queue.append(child_page_id)
start += len(child_pages_response)

return child_pages
return pages


class ConfluenceConnector(LoadConnector, PollConnector):
@@ -399,7 +297,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):

# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
self.space = space
self.page_id = "" if cql_query else page_id
self.space_level_scan = bool(not self.page_id)
@@ -409,16 +306,16 @@ class ConfluenceConnector(LoadConnector, PollConnector):

# if a cql_query is provided, we will use it to fetch the pages
# if no cql_query is provided, we will use the space to fetch the pages
# if no space is provided, we will default to fetching all pages, regardless of space
# if no space is provided and no cql_query, we will default to fetching all pages, regardless of space
if cql_query:
self.cql_query = cql_query
elif self.space:
self.cql_query = f"type=page and space={self.space}"
elif space:
self.cql_query = f"type=page and space='{space}'"
else:
self.cql_query = "type=page"
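# Illustrative (assumed): a connector configured with space "ENG" and no cql_query would,
# with the new quoting, end up with: type=page and space='ENG'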

logger.info(
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
f"wiki_base: {self.wiki_base}, space: {space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively},"
+ f" cql_query: {self.cql_query}"
)
@@ -428,7 +325,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
access_token = credentials["confluence_access_token"]
self.confluence_client = DanswerConfluence(
url=self.wiki_base,
# passing in username causes issues for Confluence data center
username=username if self.is_cloud else None,
password=access_token if self.is_cloud else None,
token=access_token if not self.is_cloud else None,
@@ -437,12 +333,16 @@ class ConfluenceConnector(LoadConnector, PollConnector):

def _fetch_pages(
self,
start_ind: int,
) -> list[dict[str, Any]]:
def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
cursor: str | None,
) -> tuple[list[dict[str, Any]], str | None]:
if self.confluence_client is None:
raise Exception("Confluence client is not initialized")

def _fetch_space(
cursor: str | None, batch_size: int
) -> tuple[list[dict[str, Any]], str | None]:
if not self.confluence_client:
raise Exception("Confluence client is not initialized")
get_all_pages = make_confluence_call_handle_rate_limit(
self.confluence_client.danswer_cql
)
@@ -454,53 +354,84 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)

try:
return get_all_pages(
response = get_all_pages(
cql=self.cql_query,
start=start_ind,
cursor=cursor,
limit=batch_size,
expand="body.storage.value,version",
expand="body.storage.value,version,space",
include_archived_spaces=include_archived_spaces,
)
pages = response.get("results", [])
next_cursor = None
if "_links" in response and "next" in response["_links"]:
next_link = response["_links"]["next"]
parsed_url = urlparse(next_link)
query_params = parse_qs(parsed_url.query)
cursor_list = query_params.get("cursor", [])
if cursor_list:
next_cursor = cursor_list[0]
return pages, next_cursor
except Exception:
logger.warning(
f"Batch failed with cql {self.cql_query} at offset {start_ind} "
f"with size {batch_size}, processing pages individually..."
f"Batch failed with cql {self.cql_query} with cursor {cursor} "
f"and size {batch_size}, processing pages individually..."
)

view_pages: list[dict[str, Any]] = []
for i in range(self.batch_size):
for _ in range(self.batch_size):
try:
# Could be that one of the pages here failed due to this bug:
# https://jira.atlassian.com/browse/CONFCLOUD-76433
view_pages.extend(
get_all_pages(
cql=self.cql_query,
start=start_ind + i,
limit=1,
expand="body.storage.value,version",
include_archived_spaces=include_archived_spaces,
)
response = get_all_pages(
cql=self.cql_query,
cursor=cursor,
limit=1,
expand="body.view.value,version,space",
include_archived_spaces=include_archived_spaces,
)
pages = response.get("results", [])
view_pages.extend(pages)
if "_links" in response and "next" in response["_links"]:
next_link = response["_links"]["next"]
parsed_url = urlparse(next_link)
query_params = parse_qs(parsed_url.query)
cursor_list = query_params.get("cursor", [])
if cursor_list:
cursor = cursor_list[0]
else:
cursor = None
else:
cursor = None
break
except HTTPError as e:
logger.warning(
f"Page failed with cql {self.cql_query} at offset {start_ind + i}, "
f"Page failed with cql {self.cql_query} with cursor {cursor}, "
f"trying alternative expand option: {e}"
)
# Use view instead, which captures most info but is less complete
view_pages.extend(
get_all_pages(
cql=self.cql_query,
start=start_ind + i,
limit=1,
expand="body.view.value,version",
)
response = get_all_pages(
cql=self.cql_query,
cursor=cursor,
limit=1,
expand="body.view.value,version,space",
)
pages = response.get("results", [])
view_pages.extend(pages)
if "_links" in response and "next" in response["_links"]:
next_link = response["_links"]["next"]
parsed_url = urlparse(next_link)
query_params = parse_qs(parsed_url.query)
cursor_list = query_params.get("cursor", [])
if cursor_list:
cursor = cursor_list[0]
else:
cursor = None
else:
cursor = None
break

return view_pages
return view_pages, cursor

def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
def _fetch_page() -> tuple[list[dict[str, Any]], str | None]:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
raise Exception("Confluence client is not initialized")

if self.recursive_indexer is None:
self.recursive_indexer = RecursiveIndexer(
@@ -510,59 +441,37 @@ class ConfluenceConnector(LoadConnector, PollConnector):
index_recursively=self.index_recursively,
)

if self.index_recursively:
return self.recursive_indexer.get_pages(start_ind, batch_size)
else:
return self.recursive_indexer.get_origin_page()

pages: list[dict[str, Any]] = []
pages = self.recursive_indexer.get_pages()
return pages, None # Since we fetched all pages, no cursor

try:
pages = (
_fetch_space(start_ind, self.batch_size)
pages, next_cursor = (
_fetch_space(cursor, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
else _fetch_page()
)
return pages

return pages, next_cursor
except Exception as e:
if not self.continue_on_failure:
raise e

# error checking phase, only reachable if `self.continue_on_failure=True`
for _ in range(self.batch_size):
try:
pages = (
_fetch_space(start_ind, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
)
return pages
logger.exception("Ran into exception when fetching pages from Confluence")
return [], None

except Exception:
logger.exception(
"Ran into exception when fetching pages from Confluence"
)

return pages

def _fetch_comments(
self, confluence_client: DanswerConfluence, page_id: str
) -> str:
def _fetch_comments(self, confluence_client: Confluence, page_id: str) -> str:
get_page_child_by_type = make_confluence_call_handle_rate_limit(
confluence_client.get_page_child_by_type
)

try:
comment_pages = cast(
Collection[dict[str, Any]],
comment_pages = list(
get_page_child_by_type(
page_id,
type="comment",
start=None,
limit=None,
expand="body.storage.value",
),
)
)
return _comment_dfs("", comment_pages, confluence_client)
except Exception as e:
@@ -574,9 +483,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)
return ""

def _fetch_labels(
self, confluence_client: DanswerConfluence, page_id: str
) -> list[str]:
def _fetch_labels(self, confluence_client: Confluence, page_id: str) -> list[str]:
get_page_labels = make_confluence_call_handle_rate_limit(
confluence_client.get_page_labels
)
@@ -647,22 +554,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return extracted_text

def _fetch_attachments(
self, confluence_client: Confluence, page_id: str, files_in_used: list[str]
self, confluence_client: Confluence, page_id: str, files_in_use: list[str]
) -> tuple[str, list[dict[str, Any]]]:
unused_attachments: list = []
unused_attachments: list[dict[str, Any]] = []
files_attachment_content: list[str] = []

get_attachments_from_content = make_confluence_call_handle_rate_limit(
confluence_client.get_attachments_from_content
)
files_attachment_content: list = []

try:
expand = "history.lastUpdated,metadata.labels"
attachments_container = get_attachments_from_content(
page_id, start=0, limit=500, expand=expand
page_id, start=None, limit=None, expand=expand
)
for attachment in attachments_container["results"]:
if attachment["title"] not in files_in_used:
for attachment in attachments_container.get("results", []):
if attachment["title"] not in files_in_use:
unused_attachments.append(attachment)
continue
@@ -680,7 +587,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
f"User does not have access to attachments on page '{page_id}'"
)
return "", []

if not self.continue_on_failure:
raise e
logger.exception(
@@ -690,24 +596,26 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return "\n".join(files_attachment_content), unused_attachments

def _get_doc_batch(
self, start_ind: int
) -> tuple[list[Document], list[dict[str, Any]], int]:
self, cursor: str | None
) -> tuple[list[Any], str | None, list[dict[str, Any]]]:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
raise Exception("Confluence client is not initialized")

doc_batch: list[Document] = []
doc_batch: list[Any] = []
unused_attachments: list[dict[str, Any]] = []

batch = self._fetch_pages(start_ind)
batch, next_cursor = self._fetch_pages(cursor)

for page in batch:
last_modified = _datetime_from_string(page["version"]["when"])
author = cast(str | None, page["version"].get("by", {}).get("email"))
author = page["version"].get("by", {}).get("email")

page_id = page["id"]

if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
page_labels = self._fetch_labels(self.confluence_client, page_id)
else:
page_labels = []

# check disallowed labels
if self.labels_to_skip:
@@ -717,7 +625,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
f"Page with ID '{page_id}' has a label which has been "
f"designated as disallowed: {label_intersection}. Skipping."
)

continue

page_html = (
@@ -732,16 +639,18 @@ class ConfluenceConnector(LoadConnector, PollConnector):
continue
page_text = parse_html_page(page_html, self.confluence_client)

files_in_used = get_used_attachments(page_html)
files_in_use = get_used_attachments(page_html)
attachment_text, unused_page_attachments = self._fetch_attachments(
self.confluence_client, page_id, files_in_used
self.confluence_client, page_id, files_in_use
)
unused_attachments.extend(unused_page_attachments)

page_text += "\n" + attachment_text if attachment_text else ""
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
doc_metadata: dict[str, str | list[str]] = {
"Wiki Space Name": page["space"]["name"]
}
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
doc_metadata["labels"] = page_labels
@@ -760,8 +669,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)
return (
doc_batch,
next_cursor,
unused_attachments,
len(batch),
)

def _get_attachment_batch(
@@ -769,8 +678,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
start_ind: int,
attachments: list[dict[str, Any]],
time_filter: Callable[[datetime], bool] | None = None,
) -> tuple[list[Document], int]:
doc_batch: list[Document] = []
) -> tuple[list[Any], int]:
doc_batch: list[Any] = []

if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
@@ -798,7 +707,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
creator_email = attachment["history"]["createdBy"].get("email")

comment = attachment["metadata"].get("comment", "")
doc_metadata: dict[str, str | list[str]] = {"comment": comment}
doc_metadata: dict[str, Any] = {"comment": comment}

attachment_labels: list[str] = []
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
@@ -825,64 +734,36 @@ class ConfluenceConnector(LoadConnector, PollConnector):

return doc_batch, end_ind - start_ind

def load_from_state(self) -> GenerateDocumentsOutput:
unused_attachments: list[dict[str, Any]] = []

start_ind = 0
while True:
doc_batch, unused_attachments, num_pages = self._get_doc_batch(start_ind)
unused_attachments.extend(unused_attachments)
start_ind += num_pages
if doc_batch:
yield doc_batch

if num_pages < self.batch_size:
break

start_ind = 0
while True:
attachment_batch, num_attachments = self._get_attachment_batch(
start_ind, unused_attachments
)
start_ind += num_attachments
if attachment_batch:
yield attachment_batch

if num_attachments < self.batch_size:
break

def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
def _handle_batch_retrieval(
self,
start: float | None = None,
end: float | None = None,
) -> GenerateDocumentsOutput:
start_time = datetime.fromtimestamp(start, tz=timezone.utc) if start else None
end_time = datetime.fromtimestamp(end, tz=timezone.utc) if end else None

unused_attachments: list[dict[str, Any]] = []

if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")

start_time = datetime.fromtimestamp(start, tz=timezone.utc)
end_time = datetime.fromtimestamp(end, tz=timezone.utc)

self.cql_query = _replace_cql_time_filter(self.cql_query, start_time, end_time)

start_ind = 0
cursor = None
while True:
doc_batch, unused_attachments, num_pages = self._get_doc_batch(start_ind)
unused_attachments.extend(unused_attachments)

start_ind += num_pages
doc_batch, cursor, new_unused_attachments = self._get_doc_batch(cursor)
unused_attachments.extend(new_unused_attachments)
if doc_batch:
yield doc_batch

if num_pages < self.batch_size:
if not cursor:
break

# Process attachments if any
start_ind = 0
while True:
attachment_batch, num_attachments = self._get_attachment_batch(
start_ind,
unused_attachments,
time_filter=lambda t: start_time <= t <= end_time,
start_ind=start_ind,
attachments=unused_attachments,
time_filter=(lambda t: start_time <= t <= end_time)
if start_time and end_time
else None,
)

start_ind += num_attachments
if attachment_batch:
yield attachment_batch
@@ -890,6 +771,12 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if num_attachments < self.batch_size:
break

def load_from_state(self) -> GenerateDocumentsOutput:
return self._handle_batch_retrieval()

def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
return self._handle_batch_retrieval(start=start, end=end)


if __name__ == "__main__":
connector = ConfluenceConnector(