Added filter to exclude attachments with unsupported file extensions (#3530)

* Added filter to exclude attachments with unsupported file extensions

* extension
This commit is contained in:
hagen-danswer
2024-12-20 11:48:15 -08:00
committed by GitHub
parent 64b6f15e95
commit 71c5043832

View File

@@ -56,6 +56,23 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"png",
"jpg",
"jpeg",
"gif",
"mp4",
"mov",
"mp3",
"wav",
]
_FULL_EXTENSION_FILTER_STRING = "".join(
[
f" and title!~'*.{extension}'"
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
]
)
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
@@ -64,7 +81,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
is_cloud: bool,
space: str = "",
page_id: str = "",
index_recursively: bool = True,
index_recursively: bool = False,
cql_query: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
@@ -83,22 +100,21 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.wiki_base = wiki_base.rstrip("/")
# if nothing is provided, we will fetch all pages
cql_page_query = "type=page"
base_cql_page_query = "type=page"
if cql_query:
# if a cql_query is provided, we will use it to fetch the pages
cql_page_query = cql_query
base_cql_page_query = cql_query
elif page_id:
# if a cql_query is not provided, we will use the page_id to fetch the page
if index_recursively:
cql_page_query += f" and ancestor='{page_id}'"
base_cql_page_query += f" and ancestor='{page_id}'"
else:
cql_page_query += f" and id='{page_id}'"
base_cql_page_query += f" and id='{page_id}'"
elif space:
# if no cql_query or page_id is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"
base_cql_page_query += f" and space='{quote(space)}'"
self.cql_page_query = cql_page_query
self.cql_time_filter = ""
self.base_cql_page_query = base_cql_page_query
self.cql_label_filter = ""
if labels_to_skip:
@@ -126,6 +142,33 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def _construct_page_query(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> str:
page_query = self.base_cql_page_query + self.cql_label_filter
# Add time filters
if start:
formatted_start_time = datetime.fromtimestamp(
start, tz=self.timezone
).strftime("%Y-%m-%d %H:%M")
page_query += f" and lastmodified >= '{formatted_start_time}'"
if end:
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
page_query += f" and lastmodified <= '{formatted_end_time}'"
return page_query
def _construct_attachment_query(self, confluence_page_id: str) -> str:
attachment_query = f"type=attachment and container='{confluence_page_id}'"
attachment_query += self.cql_label_filter
attachment_query += _FULL_EXTENSION_FILTER_STRING
return attachment_query
def _get_comment_string_for_page_id(self, page_id: str) -> str:
comment_string = ""
@@ -205,11 +248,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
metadata=doc_metadata,
)
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
def _fetch_document_batches(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
page_query = self._construct_page_query(start, end)
logger.debug(f"page_query: {page_query}")
# Fetch pages as Documents
for page in self.confluence_client.paginated_cql_retrieval(
@@ -228,11 +275,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
attachment_cql += self.cql_label_filter
attachment_query = self._construct_attachment_query(confluence_page_id)
# TODO: maybe should add time filter as well?
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_cql,
cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
doc = self._convert_object_to_document(attachment)
@@ -248,17 +294,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_document_batches()
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
# Add time filters
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
return self._fetch_document_batches()
def poll_source(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
return self._fetch_document_batches(start, end)
def retrieve_all_slim_documents(
self,
@@ -269,7 +310,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
page_query = self.cql_page_query + self.cql_label_filter
page_query = self.base_cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
@@ -294,10 +335,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
perm_sync_data=page_perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
attachment_query = self._construct_attachment_query(page["id"])
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
cql=attachment_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):