add image indexing tests (#4477)

* address file path

* k

* update

* update

* nit- fix typing

* k

* should path

* in a good state

* k

* k

* clean up file

* update

* update

* k

* k

* k
This commit is contained in:
pablonyx 2025-04-11 15:16:37 -07:00 committed by GitHub
parent 6eaa774051
commit 65fd8b90a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 222 additions and 20 deletions

View File

@ -406,7 +406,6 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
headers=headers,
json=payload.model_dump(),
) as response:
print(response)
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant creation failed: {error_text}")

View File

@ -159,6 +159,8 @@ def load_files_from_zip(
zip_metadata = json.load(metadata_file)
if isinstance(zip_metadata, list):
# convert list of dicts to dict of dicts
# Use just the basename for matching since metadata may not include
# the full path within the ZIP file
zip_metadata = {d["filename"]: d for d in zip_metadata}
except json.JSONDecodeError:
logger.warning(f"Unable to load {DANSWER_METADATA_FILENAME}")
@ -176,7 +178,13 @@ def load_files_from_zip(
continue
with zip_file.open(file_info.filename, "r") as subfile:
yield file_info, subfile, zip_metadata.get(file_info.filename, {})
# Try to match by exact filename first
if file_info.filename in zip_metadata:
yield file_info, subfile, zip_metadata.get(file_info.filename, {})
else:
# Then try matching by just the basename
basename = os.path.basename(file_info.filename)
yield file_info, subfile, zip_metadata.get(basename, {})
def _extract_onyx_metadata(line: str) -> dict | None:

View File

@ -126,16 +126,13 @@ def get_default_llm_with_vision(
with get_session_with_current_tenant() as db_session:
# Try the default vision provider first
default_provider = fetch_default_vision_provider(db_session)
if (
default_provider
and default_provider.default_vision_model
and model_supports_image_input(
if default_provider and default_provider.default_vision_model:
if model_supports_image_input(
default_provider.default_vision_model, default_provider.provider
)
):
return create_vision_llm(
default_provider, default_provider.default_vision_model
)
):
return create_vision_llm(
default_provider, default_provider.default_vision_model
)
# Fall back to searching all providers
providers = fetch_existing_llm_providers(db_session)
@ -143,14 +140,36 @@ def get_default_llm_with_vision(
if not providers:
return None
# Find the first provider that supports image input
# Check all providers for viable vision models
for provider in providers:
provider_view = LLMProviderView.from_model(provider)
# First priority: Check if provider has a default_vision_model
if provider.default_vision_model and model_supports_image_input(
provider.default_vision_model, provider.provider
):
return create_vision_llm(
LLMProviderView.from_model(provider), provider.default_vision_model
)
return create_vision_llm(provider_view, provider.default_vision_model)
# If no model_names are specified, try default models in priority order
if not provider.model_names:
# Try default_model_name
if provider.default_model_name and model_supports_image_input(
provider.default_model_name, provider.provider
):
return create_vision_llm(provider_view, provider.default_model_name)
# Try fast_default_model_name
if provider.fast_default_model_name and model_supports_image_input(
provider.fast_default_model_name, provider.provider
):
return create_vision_llm(
provider_view, provider.fast_default_model_name
)
else:
# If model_names is specified, check each model
for model_name in provider.model_names:
if model_supports_image_input(model_name, provider.provider):
return create_vision_llm(provider_view, model_name)
return None

View File

@ -118,6 +118,13 @@ class LLMProviderView(LLMProvider):
@classmethod
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
# Safely get groups - handle detached instance case
try:
groups = [group.id for group in llm_provider_model.groups]
except Exception:
# If groups relationship can't be loaded (detached instance), use empty list
groups = []
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
@ -148,7 +155,7 @@ class LLMProviderView(LLMProvider):
else None
),
is_public=llm_provider_model.is_public,
groups=[group.id for group in llm_provider_model.groups],
groups=groups,
deployment_name=llm_provider_model.deployment_name,
)

View File

@ -47,6 +47,7 @@ from onyx.context.search.models import IndexFilters
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import SqlEngine
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
@ -514,6 +515,7 @@ def get_number_of_chunks_we_think_exist(
class VespaDebugging:
# Class for managing Vespa debugging actions.
def __init__(self, tenant_id: str = POSTGRES_DEFAULT_SCHEMA):
SqlEngine.init_engine(pool_size=20, max_overflow=5)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
self.tenant_id = tenant_id
self.index_name = get_index_name(self.tenant_id)
@ -855,6 +857,7 @@ def delete_documents_for_tenant(
def main() -> None:
SqlEngine.init_engine(pool_size=20, max_overflow=5)
parser = argparse.ArgumentParser(description="Vespa debugging tool")
parser.add_argument(
"--action",

View File

@ -231,6 +231,11 @@ class DocumentManager:
for doc_dict in retrieved_docs_dict:
doc_id = doc_dict["fields"]["document_id"]
doc_content = doc_dict["fields"]["content"]
final_docs.append(SimpleTestDocument(id=doc_id, content=doc_content))
image_file_name = doc_dict["fields"].get("image_file_name", None)
final_docs.append(
SimpleTestDocument(
id=doc_id, content=doc_content, image_file_name=image_file_name
)
)
return final_docs

View File

@ -1,3 +1,4 @@
import io
import mimetypes
from typing import cast
from typing import IO
@ -62,3 +63,43 @@ class FileManager:
)
response.raise_for_status()
return response.content
@staticmethod
def upload_file_for_connector(
file_path: str, file_name: str, user_performing_action: DATestUser
) -> dict:
# Read the file content
with open(file_path, "rb") as f:
file_content = f.read()
# Create a file-like object
file_obj = io.BytesIO(file_content)
# The 'files' form field expects a list of files
files = [("files", (file_name, file_obj, "application/octet-stream"))]
# Use the user's headers but without Content-Type
# as requests will set the correct multipart/form-data Content-Type for us
headers = user_performing_action.headers.copy()
if "Content-Type" in headers:
del headers["Content-Type"]
# Make the request
response = requests.post(
f"{API_SERVER_URL}/manage/admin/connector/file/upload",
files=files,
headers=headers,
)
if not response.ok:
try:
error_detail = response.json().get("detail", "Unknown error")
except Exception:
error_detail = response.text
raise Exception(
f"Unable to upload files - {error_detail} (Status code: {response.status_code})"
)
response_json = response.json()
return response_json

View File

@ -23,7 +23,7 @@ class SettingsManager:
headers.pop("Content-Type", None)
response = requests.get(
f"{API_SERVER_URL}/api/manage/admin/settings",
f"{API_SERVER_URL}/admin/settings",
headers=headers,
)
@ -48,8 +48,8 @@ class SettingsManager:
headers.pop("Content-Type", None)
payload = settings.model_dump()
response = requests.patch(
f"{API_SERVER_URL}/api/manage/admin/settings",
response = requests.put(
f"{API_SERVER_URL}/admin/settings",
json=payload,
headers=headers,
)

View File

@ -76,6 +76,7 @@ class DATestConnector(BaseModel):
class SimpleTestDocument(BaseModel):
id: str
content: str
image_file_name: str | None = None
class DATestCCPair(BaseModel):
@ -177,6 +178,8 @@ class DATestSettings(BaseModel):
gpu_enabled: bool | None = None
product_gating: DATestGatingType = DATestGatingType.NONE
anonymous_user_enabled: bool | None = None
image_extraction_and_analysis_enabled: bool | None = False
search_time_image_analysis_enabled: bool | None = False
@dataclass

View File

@ -0,0 +1,117 @@
import os
from datetime import datetime
from datetime import timezone
import pytest
from onyx.connectors.models import InputType
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.file import FileManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.settings import SettingsManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestSettings
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
FILE_NAME = "Sample.pdf"
FILE_PATH = "tests/integration/common_utils/test_files"
def test_image_indexing(
reset: None,
vespa_client: vespa_fixture,
) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(
email="admin@onyx-test.com",
)
os.makedirs(FILE_PATH, exist_ok=True)
test_file_path = os.path.join(FILE_PATH, FILE_NAME)
# Use FileManager to upload the test file
upload_response = FileManager.upload_file_for_connector(
file_path=test_file_path, file_name=FILE_NAME, user_performing_action=admin_user
)
LLMProviderManager.create(
name="test_llm",
user_performing_action=admin_user,
)
SettingsManager.update_settings(
DATestSettings(
search_time_image_analysis_enabled=True,
image_extraction_and_analysis_enabled=True,
),
user_performing_action=admin_user,
)
file_paths = upload_response.get("file_paths", [])
if not file_paths:
pytest.fail("File upload failed - no file paths returned")
# Create a dummy credential for the file connector
credential = CredentialManager.create(
source=DocumentSource.FILE,
credential_json={},
user_performing_action=admin_user,
)
# Create the connector
connector_name = f"FileConnector-{int(datetime.now().timestamp())}"
connector = ConnectorManager.create(
name=connector_name,
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={"file_locations": file_paths},
access_type=AccessType.PUBLIC,
groups=[],
user_performing_action=admin_user,
)
# Link the credential to the connector
cc_pair = CCPairManager.create(
credential_id=credential.id,
connector_id=connector.id,
access_type=AccessType.PUBLIC,
user_performing_action=admin_user,
)
# Explicitly run the connector to start indexing
CCPairManager.run_once(
cc_pair=cc_pair,
from_beginning=True,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=datetime.now(timezone.utc),
user_performing_action=admin_user,
)
with get_session_context_manager() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
# Ensure we indexed an image from the sample.pdf file
has_sample_pdf_image = False
for doc in documents:
if doc.image_file_name and FILE_NAME in doc.image_file_name:
has_sample_pdf_image = True
# Assert that at least one document has an image file name containing "sample.pdf"
assert (
has_sample_pdf_image
), "No document found with an image file name containing 'sample.pdf'"