Small airtable refactor + handle files with uppercase extensions (#3598)

* Small airtable refactor + handle files with uppercase extensions

* Fix mypy
This commit is contained in:
Chris Weaver 2025-01-05 11:27:50 -08:00 committed by GitHub
parent f895e5f7d0
commit 1db778baa8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 70 additions and 47 deletions

View File

@ -3,6 +3,8 @@ from typing import Any
import requests
from pyairtable import Api as AirtableApi
from pyairtable.api.types import RecordDict
from pyairtable.models.schema import TableSchema
from retry import retry
from onyx.configs.app_configs import INDEX_BATCH_SIZE
@ -188,6 +190,66 @@ class AirtableConnector(LoadConnector):
]
return sections, {}
def _process_record(
self,
record: RecordDict,
table_schema: TableSchema,
primary_field_name: str | None,
) -> Document:
"""Process a single Airtable record into a Document.
Args:
record: The Airtable record to process
table_schema: Schema information for the table
table_name: Name of the table
table_id: ID of the table
primary_field_name: Name of the primary field, if any
Returns:
Document object representing the record
"""
table_id = table_schema.id
table_name = table_schema.name
record_id = record["id"]
fields = record["fields"]
sections: list[Section] = []
metadata: dict[str, Any] = {}
# Get primary field value if it exists
primary_field_value = (
fields.get(primary_field_name) if primary_field_name else None
)
for field_schema in table_schema.fields:
field_name = field_schema.name
field_val = fields.get(field_name)
field_type = field_schema.type
field_sections, field_metadata = self._process_field(
field_name=field_name,
field_info=field_val,
field_type=field_type,
table_id=table_id,
record_id=record_id,
)
sections.extend(field_sections)
metadata.update(field_metadata)
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
else table_name
)
return Document(
id=f"airtable__{record_id}",
sections=sections,
source=DocumentSource.AIRTABLE,
semantic_identifier=semantic_id,
metadata=metadata,
)
def load_from_state(self) -> GenerateDocumentsOutput:
"""
Fetch all records from the table.
@ -199,17 +261,9 @@ class AirtableConnector(LoadConnector):
raise AirtableClientNotSetUpError()
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
table_id = table.id
# due to https://community.airtable.com/t5/development-apis/pagination-returns-422-error/td-p/54778,
# we can't user the `iterate()` method - we need to get everything up front
# this also means we can't handle tables that won't fit in memory
records = table.all()
table_schema = table.schema()
# have to get the name from the schema, since the table object will
# give back the ID instead of the name if the ID is used to create
# the table object
table_name = table_schema.name
primary_field_name = None
# Find a primary field from the schema
@ -220,45 +274,12 @@ class AirtableConnector(LoadConnector):
record_documents: list[Document] = []
for record in records:
record_id = record["id"]
fields = record["fields"]
sections: list[Section] = []
metadata: dict[str, Any] = {}
# Possibly retrieve the primary field's value
primary_field_value = (
fields.get(primary_field_name) if primary_field_name else None
document = self._process_record(
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
)
for field_schema in table_schema.fields:
field_name = field_schema.name
field_val = fields.get(field_name)
field_type = field_schema.type
field_sections, field_metadata = self._process_field(
field_name=field_name,
field_info=field_val,
field_type=field_type,
table_id=table_id,
record_id=record_id,
)
sections.extend(field_sections)
metadata.update(field_metadata)
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
else table_name
)
record_document = Document(
id=f"airtable__{record_id}",
sections=sections,
source=DocumentSource.AIRTABLE,
semantic_identifier=semantic_id,
metadata=metadata,
)
record_documents.append(record_document)
record_documents.append(document)
if len(record_documents) >= self.batch_size:
yield record_documents

View File

@ -67,7 +67,9 @@ def is_text_file_extension(file_name: str) -> bool:
def get_file_ext(file_path_or_name: str | Path) -> str:
_, extension = os.path.splitext(file_path_or_name)
return extension
# standardize all extensions to be lowercase so that checks against
# VALID_FILE_EXTENSIONS and similar will work as intended
return extension.lower()
def is_valid_file_ext(ext: str) -> bool: