Small airtable refactor + handle files with uppercase extensions (#3598)

* Small airtable refactor + handle files with uppercase extensions

* Fix mypy
Authored by Chris Weaver on 2025-01-05 11:27:50 -08:00; committed by GitHub
parent f895e5f7d0
commit 1db778baa8
2 changed files with 70 additions and 47 deletions


@@ -3,6 +3,8 @@ from typing import Any

 import requests
 from pyairtable import Api as AirtableApi
+from pyairtable.api.types import RecordDict
+from pyairtable.models.schema import TableSchema
 from retry import retry

 from onyx.configs.app_configs import INDEX_BATCH_SIZE
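
For orientation, the two new imports type the helper added in the next hunk: RecordDict is pyairtable's typed dict for a single record, and TableSchema describes the table's fields. A record has roughly the shape sketched below; only the top-level keys ("id", "createdTime", "fields") come from the Airtable record format, the field names and values are invented for illustration.

```python
# Minimal sketch of the record shape that _process_record consumes.
example_record = {
    "id": "recXXXXXXXXXXXXXX",
    "createdTime": "2025-01-05T19:27:50.000Z",
    "fields": {
        "Name": "Example row",
        "Status": "Done",
    },
}

record_id = example_record["id"]  # becomes part of the Document id below
fields = example_record["fields"]  # per-field values, keyed by field name
```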
@@ -188,6 +190,66 @@ class AirtableConnector(LoadConnector):
             ]
             return sections, {}

+    def _process_record(
+        self,
+        record: RecordDict,
+        table_schema: TableSchema,
+        primary_field_name: str | None,
+    ) -> Document:
+        """Process a single Airtable record into a Document.
+
+        Args:
+            record: The Airtable record to process
+            table_schema: Schema information for the table
+            primary_field_name: Name of the primary field, if any
+
+        Returns:
+            Document object representing the record
+        """
+        table_id = table_schema.id
+        table_name = table_schema.name
+        record_id = record["id"]
+        fields = record["fields"]
+
+        sections: list[Section] = []
+        metadata: dict[str, Any] = {}
+
+        # Get primary field value if it exists
+        primary_field_value = (
+            fields.get(primary_field_name) if primary_field_name else None
+        )
+
+        for field_schema in table_schema.fields:
+            field_name = field_schema.name
+            field_val = fields.get(field_name)
+            field_type = field_schema.type
+
+            field_sections, field_metadata = self._process_field(
+                field_name=field_name,
+                field_info=field_val,
+                field_type=field_type,
+                table_id=table_id,
+                record_id=record_id,
+            )
+
+            sections.extend(field_sections)
+            metadata.update(field_metadata)
+
+        semantic_id = (
+            f"{table_name}: {primary_field_value}"
+            if primary_field_value
+            else table_name
+        )
+
+        return Document(
+            id=f"airtable__{record_id}",
+            sections=sections,
+            source=DocumentSource.AIRTABLE,
+            semantic_identifier=semantic_id,
+            metadata=metadata,
+        )
+
     def load_from_state(self) -> GenerateDocumentsOutput:
         """
         Fetch all records from the table.
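
The remaining hunks switch load_from_state over to the new helper. As a side note, here is a minimal standalone sketch of the id and semantic-identifier formats produced by _process_record, with a plain dataclass standing in for onyx's Document model; FakeDocument and build_record_ids are illustrative names, not part of the codebase.

```python
from dataclasses import dataclass


@dataclass
class FakeDocument:
    # Stand-in for onyx's Document; only the two fields relevant here.
    id: str
    semantic_identifier: str


def build_record_ids(
    table_name: str,
    record: dict,
    primary_field_name: str | None,
) -> FakeDocument:
    record_id = record["id"]
    primary_value = (
        record["fields"].get(primary_field_name) if primary_field_name else None
    )
    # Same formats as _process_record above: "airtable__<record id>" for the
    # document id, and "<table name>: <primary value>" (falling back to just
    # the table name) for the semantic identifier.
    semantic_id = f"{table_name}: {primary_value}" if primary_value else table_name
    return FakeDocument(id=f"airtable__{record_id}", semantic_identifier=semantic_id)


print(
    build_record_ids(
        table_name="Tasks",
        record={"id": "rec123", "fields": {"Name": "Write docs"}},
        primary_field_name="Name",
    )
)
# FakeDocument(id='airtable__rec123', semantic_identifier='Tasks: Write docs')
```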
@@ -199,17 +261,9 @@ class AirtableConnector(LoadConnector):
             raise AirtableClientNotSetUpError()

         table = self.airtable_client.table(self.base_id, self.table_name_or_id)
-        table_id = table.id
-        # due to https://community.airtable.com/t5/development-apis/pagination-returns-422-error/td-p/54778,
-        # we can't user the `iterate()` method - we need to get everything up front
-        # this also means we can't handle tables that won't fit in memory
         records = table.all()

         table_schema = table.schema()
-        # have to get the name from the schema, since the table object will
-        # give back the ID instead of the name if the ID is used to create
-        # the table object
-        table_name = table_schema.name

         primary_field_name = None
         # Find a primary field from the schema
@@ -220,45 +274,12 @@ class AirtableConnector(LoadConnector):
         record_documents: list[Document] = []

         for record in records:
-            record_id = record["id"]
-            fields = record["fields"]
-
-            sections: list[Section] = []
-            metadata: dict[str, Any] = {}
-
-            # Possibly retrieve the primary field's value
-            primary_field_value = (
-                fields.get(primary_field_name) if primary_field_name else None
-            )
-
-            for field_schema in table_schema.fields:
-                field_name = field_schema.name
-                field_val = fields.get(field_name)
-                field_type = field_schema.type
-
-                field_sections, field_metadata = self._process_field(
-                    field_name=field_name,
-                    field_info=field_val,
-                    field_type=field_type,
-                    table_id=table_id,
-                    record_id=record_id,
-                )
-
-                sections.extend(field_sections)
-                metadata.update(field_metadata)
-
-            semantic_id = (
-                f"{table_name}: {primary_field_value}"
-                if primary_field_value
-                else table_name
-            )
-
-            record_document = Document(
-                id=f"airtable__{record_id}",
-                sections=sections,
-                source=DocumentSource.AIRTABLE,
-                semantic_identifier=semantic_id,
-                metadata=metadata,
-            )
-            record_documents.append(record_document)
+            document = self._process_record(
+                record=record,
+                table_schema=table_schema,
+                primary_field_name=primary_field_name,
+            )
+            record_documents.append(document)

             if len(record_documents) >= self.batch_size:
                 yield record_documents
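
With the per-record work moved into _process_record, the loop above is just a batch-and-yield generator. Below is a generic sketch of that pattern, under the assumption that the connector (outside the lines shown in this diff) resets the batch after yielding and emits any final partial batch; process_item here stands in for self._process_record.

```python
from collections.abc import Callable, Iterable, Iterator
from typing import TypeVar

T = TypeVar("T")
D = TypeVar("D")


def batched(
    items: Iterable[T],
    process_item: Callable[[T], D],
    batch_size: int,
) -> Iterator[list[D]]:
    # Accumulate processed items and yield them batch_size at a time.
    batch: list[D] = []
    for item in items:
        batch.append(process_item(item))
        if len(batch) >= batch_size:
            yield batch
            batch = []
    # Emit whatever is left over as a final, possibly smaller batch.
    if batch:
        yield batch


# Example: squares in batches of 2 -> [0, 1], [4, 9], [16]
for chunk in batched(range(5), lambda x: x * x, batch_size=2):
    print(chunk)
```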


@@ -67,7 +67,9 @@ def is_text_file_extension(file_name: str) -> bool:

 def get_file_ext(file_path_or_name: str | Path) -> str:
     _, extension = os.path.splitext(file_path_or_name)
-    return extension
+    # standardize all extensions to be lowercase so that checks against
+    # VALID_FILE_EXTENSIONS and similar will work as intended
+    return extension.lower()


 def is_valid_file_ext(ext: str) -> bool:
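
The second file carries the uppercase-extension fix from the commit title. A minimal sketch of the before/after behaviour follows; VALID_FILE_EXTENSIONS here is a made-up stand-in for the real lowercase allow-list that callers of get_file_ext check against.

```python
import os
from pathlib import Path

# Hypothetical stand-in for the codebase's lowercase extension allow-list.
VALID_FILE_EXTENSIONS = {".txt", ".pdf", ".docx"}


def get_file_ext(file_path_or_name: str | Path) -> str:
    _, extension = os.path.splitext(file_path_or_name)
    # As in the diff: lowercase so "Report.PDF" and "report.pdf" are treated alike.
    return extension.lower()


print(get_file_ext("Quarterly Report.PDF"))  # .pdf
print(get_file_ext("Quarterly Report.PDF") in VALID_FILE_EXTENSIONS)  # True

# Before this change the raw ".PDF" came back unchanged, so the same
# membership check against a lowercase allow-list returned False.
print(".PDF" in VALID_FILE_EXTENSIONS)  # False
```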