diff --git a/backend/onyx/connectors/airtable/airtable_connector.py b/backend/onyx/connectors/airtable/airtable_connector.py index deee6abbb..7d815528e 100644 --- a/backend/onyx/connectors/airtable/airtable_connector.py +++ b/backend/onyx/connectors/airtable/airtable_connector.py @@ -3,6 +3,8 @@ from typing import Any import requests from pyairtable import Api as AirtableApi +from pyairtable.api.types import RecordDict +from pyairtable.models.schema import TableSchema from retry import retry from onyx.configs.app_configs import INDEX_BATCH_SIZE @@ -188,6 +190,66 @@ class AirtableConnector(LoadConnector): ] return sections, {} + def _process_record( + self, + record: RecordDict, + table_schema: TableSchema, + primary_field_name: str | None, + ) -> Document: + """Process a single Airtable record into a Document. + + Args: + record: The Airtable record to process + table_schema: Schema information for the table + table_name: Name of the table + table_id: ID of the table + primary_field_name: Name of the primary field, if any + + Returns: + Document object representing the record + """ + table_id = table_schema.id + table_name = table_schema.name + record_id = record["id"] + fields = record["fields"] + sections: list[Section] = [] + metadata: dict[str, Any] = {} + + # Get primary field value if it exists + primary_field_value = ( + fields.get(primary_field_name) if primary_field_name else None + ) + + for field_schema in table_schema.fields: + field_name = field_schema.name + field_val = fields.get(field_name) + field_type = field_schema.type + + field_sections, field_metadata = self._process_field( + field_name=field_name, + field_info=field_val, + field_type=field_type, + table_id=table_id, + record_id=record_id, + ) + + sections.extend(field_sections) + metadata.update(field_metadata) + + semantic_id = ( + f"{table_name}: {primary_field_value}" + if primary_field_value + else table_name + ) + + return Document( + id=f"airtable__{record_id}", + sections=sections, + source=DocumentSource.AIRTABLE, + semantic_identifier=semantic_id, + metadata=metadata, + ) + def load_from_state(self) -> GenerateDocumentsOutput: """ Fetch all records from the table. @@ -199,17 +261,9 @@ class AirtableConnector(LoadConnector): raise AirtableClientNotSetUpError() table = self.airtable_client.table(self.base_id, self.table_name_or_id) - table_id = table.id - # due to https://community.airtable.com/t5/development-apis/pagination-returns-422-error/td-p/54778, - # we can't user the `iterate()` method - we need to get everything up front - # this also means we can't handle tables that won't fit in memory records = table.all() table_schema = table.schema() - # have to get the name from the schema, since the table object will - # give back the ID instead of the name if the ID is used to create - # the table object - table_name = table_schema.name primary_field_name = None # Find a primary field from the schema @@ -220,45 +274,12 @@ class AirtableConnector(LoadConnector): record_documents: list[Document] = [] for record in records: - record_id = record["id"] - fields = record["fields"] - sections: list[Section] = [] - metadata: dict[str, Any] = {} - - # Possibly retrieve the primary field's value - primary_field_value = ( - fields.get(primary_field_name) if primary_field_name else None + document = self._process_record( + record=record, + table_schema=table_schema, + primary_field_name=primary_field_name, ) - for field_schema in table_schema.fields: - field_name = field_schema.name - field_val = fields.get(field_name) - field_type = field_schema.type - - field_sections, field_metadata = self._process_field( - field_name=field_name, - field_info=field_val, - field_type=field_type, - table_id=table_id, - record_id=record_id, - ) - - sections.extend(field_sections) - metadata.update(field_metadata) - - semantic_id = ( - f"{table_name}: {primary_field_value}" - if primary_field_value - else table_name - ) - - record_document = Document( - id=f"airtable__{record_id}", - sections=sections, - source=DocumentSource.AIRTABLE, - semantic_identifier=semantic_id, - metadata=metadata, - ) - record_documents.append(record_document) + record_documents.append(document) if len(record_documents) >= self.batch_size: yield record_documents diff --git a/backend/onyx/file_processing/extract_file_text.py b/backend/onyx/file_processing/extract_file_text.py index b5cbe4556..fbc4fdc52 100644 --- a/backend/onyx/file_processing/extract_file_text.py +++ b/backend/onyx/file_processing/extract_file_text.py @@ -67,7 +67,9 @@ def is_text_file_extension(file_name: str) -> bool: def get_file_ext(file_path_or_name: str | Path) -> str: _, extension = os.path.splitext(file_path_or_name) - return extension + # standardize all extensions to be lowercase so that checks against + # VALID_FILE_EXTENSIONS and similar will work as intended + return extension.lower() def is_valid_file_ext(ext: str) -> bool: