Mirror of https://github.com/danswer-ai/danswer.git
add image indexing tests (#4477)
* address file path * k * update * update * nit- fix typing * k * should path * in a good state * k * k * clean up file * update * update * k * k * k
Parent: 6eaa774051
Commit: 65fd8b90a8

@@ -406,7 +406,6 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
             headers=headers,
             json=payload.model_dump(),
         ) as response:
-            print(response)
             if response.status != 200:
                 error_text = await response.text()
                 logger.error(f"Control plane tenant creation failed: {error_text}")

@@ -159,6 +159,8 @@ def load_files_from_zip(
                 zip_metadata = json.load(metadata_file)
                 if isinstance(zip_metadata, list):
                     # convert list of dicts to dict of dicts
+                    # Use just the basename for matching since metadata may not include
+                    # the full path within the ZIP file
                     zip_metadata = {d["filename"]: d for d in zip_metadata}
             except json.JSONDecodeError:
                 logger.warning(f"Unable to load {DANSWER_METADATA_FILENAME}")
@@ -176,7 +178,13 @@ def load_files_from_zip(
            continue

        with zip_file.open(file_info.filename, "r") as subfile:
-            yield file_info, subfile, zip_metadata.get(file_info.filename, {})
+            # Try to match by exact filename first
+            if file_info.filename in zip_metadata:
+                yield file_info, subfile, zip_metadata.get(file_info.filename, {})
+            else:
+                # Then try matching by just the basename
+                basename = os.path.basename(file_info.filename)
+                yield file_info, subfile, zip_metadata.get(basename, {})


 def _extract_onyx_metadata(line: str) -> dict | None:

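The lookup introduced by this hunk boils down to "exact archive path first, basename second". A minimal standalone sketch of that strategy, using a hypothetical lookup_zip_metadata helper that is not part of the repo:

import os


def lookup_zip_metadata(zip_metadata: dict[str, dict], member_name: str) -> dict:
    # Exact match on the full path stored inside the ZIP takes priority
    if member_name in zip_metadata:
        return zip_metadata[member_name]
    # Fall back to the basename, since metadata entries may omit directory prefixes
    return zip_metadata.get(os.path.basename(member_name), {})


# Example: lookup_zip_metadata({"Sample.pdf": {"tag": "demo"}}, "docs/Sample.pdf") -> {"tag": "demo"}
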
@@ -126,16 +126,13 @@ def get_default_llm_with_vision(
     with get_session_with_current_tenant() as db_session:
         # Try the default vision provider first
         default_provider = fetch_default_vision_provider(db_session)
-        if (
-            default_provider
-            and default_provider.default_vision_model
-            and model_supports_image_input(
+        if default_provider and default_provider.default_vision_model:
+            if model_supports_image_input(
                 default_provider.default_vision_model, default_provider.provider
-            )
-        ):
-            return create_vision_llm(
-                default_provider, default_provider.default_vision_model
-            )
+            ):
+                return create_vision_llm(
+                    default_provider, default_provider.default_vision_model
+                )

         # Fall back to searching all providers
         providers = fetch_existing_llm_providers(db_session)
@@ -143,14 +140,36 @@ def get_default_llm_with_vision(
         if not providers:
             return None

-        # Find the first provider that supports image input
+        # Check all providers for viable vision models
         for provider in providers:
+            provider_view = LLMProviderView.from_model(provider)
+
+            # First priority: Check if provider has a default_vision_model
             if provider.default_vision_model and model_supports_image_input(
                 provider.default_vision_model, provider.provider
             ):
-                return create_vision_llm(
-                    LLMProviderView.from_model(provider), provider.default_vision_model
-                )
+                return create_vision_llm(provider_view, provider.default_vision_model)
+
+            # If no model_names are specified, try default models in priority order
+            if not provider.model_names:
+                # Try default_model_name
+                if provider.default_model_name and model_supports_image_input(
+                    provider.default_model_name, provider.provider
+                ):
+                    return create_vision_llm(provider_view, provider.default_model_name)
+
+                # Try fast_default_model_name
+                if provider.fast_default_model_name and model_supports_image_input(
+                    provider.fast_default_model_name, provider.provider
+                ):
+                    return create_vision_llm(
+                        provider_view, provider.fast_default_model_name
+                    )
+            else:
+                # If model_names is specified, check each model
+                for model_name in provider.model_names:
+                    if model_supports_image_input(model_name, provider.provider):
+                        return create_vision_llm(provider_view, model_name)

         return None

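The rewritten loop checks each provider's models in a fixed priority order. A condensed sketch of that order, with a stand-in supports_vision predicate in place of model_supports_image_input; names other than the provider fields are illustrative:

from collections.abc import Callable


def pick_vision_model(provider, supports_vision: Callable[[str], bool]) -> str | None:
    # 1) An explicitly configured default_vision_model wins if it supports images
    if provider.default_vision_model and supports_vision(provider.default_vision_model):
        return provider.default_vision_model
    if not provider.model_names:
        # 2) No explicit model list: try default_model_name, then fast_default_model_name
        for candidate in (provider.default_model_name, provider.fast_default_model_name):
            if candidate and supports_vision(candidate):
                return candidate
        return None
    # 3) Otherwise scan the declared model_names in order
    return next((m for m in provider.model_names if supports_vision(m)), None)
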
@@ -118,6 +118,13 @@ class LLMProviderView(LLMProvider):

     @classmethod
     def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
+        # Safely get groups - handle detached instance case
+        try:
+            groups = [group.id for group in llm_provider_model.groups]
+        except Exception:
+            # If groups relationship can't be loaded (detached instance), use empty list
+            groups = []
+
         return cls(
             id=llm_provider_model.id,
             name=llm_provider_model.name,
@@ -148,7 +155,7 @@ class LLMProviderView(LLMProvider):
                 else None
             ),
             is_public=llm_provider_model.is_public,
-            groups=[group.id for group in llm_provider_model.groups],
+            groups=groups,
             deployment_name=llm_provider_model.deployment_name,
         )

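The broad try/except guards the lazy groups relationship, which raises when the ORM instance is detached from its session. A narrower variant of the same guard is sketched below; catching SQLAlchemy's DetachedInstanceError instead of a bare Exception is an alternative shown for illustration, not what the diff does:

from sqlalchemy.orm.exc import DetachedInstanceError


def safe_group_ids(llm_provider_model) -> list[int]:
    # Accessing a lazy relationship on a detached instance raises DetachedInstanceError;
    # degrade to an empty list so view construction never fails on it.
    try:
        return [group.id for group in llm_provider_model.groups]
    except DetachedInstanceError:
        return []
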
@ -47,6 +47,7 @@ from onyx.context.search.models import IndexFilters
|
|||||||
from onyx.context.search.models import SearchRequest
|
from onyx.context.search.models import SearchRequest
|
||||||
from onyx.db.engine import get_session_with_current_tenant
|
from onyx.db.engine import get_session_with_current_tenant
|
||||||
from onyx.db.engine import get_session_with_tenant
|
from onyx.db.engine import get_session_with_tenant
|
||||||
|
from onyx.db.engine import SqlEngine
|
||||||
from onyx.db.models import ConnectorCredentialPair
|
from onyx.db.models import ConnectorCredentialPair
|
||||||
from onyx.db.models import Document
|
from onyx.db.models import Document
|
||||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||||
@ -514,6 +515,7 @@ def get_number_of_chunks_we_think_exist(
|
|||||||
class VespaDebugging:
|
class VespaDebugging:
|
||||||
# Class for managing Vespa debugging actions.
|
# Class for managing Vespa debugging actions.
|
||||||
def __init__(self, tenant_id: str = POSTGRES_DEFAULT_SCHEMA):
|
def __init__(self, tenant_id: str = POSTGRES_DEFAULT_SCHEMA):
|
||||||
|
SqlEngine.init_engine(pool_size=20, max_overflow=5)
|
||||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.index_name = get_index_name(self.tenant_id)
|
self.index_name = get_index_name(self.tenant_id)
|
||||||
@ -855,6 +857,7 @@ def delete_documents_for_tenant(
|
|||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
SqlEngine.init_engine(pool_size=20, max_overflow=5)
|
||||||
parser = argparse.ArgumentParser(description="Vespa debugging tool")
|
parser = argparse.ArgumentParser(description="Vespa debugging tool")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--action",
|
"--action",
|
||||||
|
@@ -231,6 +231,11 @@ class DocumentManager:
         for doc_dict in retrieved_docs_dict:
             doc_id = doc_dict["fields"]["document_id"]
             doc_content = doc_dict["fields"]["content"]
-            final_docs.append(SimpleTestDocument(id=doc_id, content=doc_content))
+            image_file_name = doc_dict["fields"].get("image_file_name", None)
+            final_docs.append(
+                SimpleTestDocument(
+                    id=doc_id, content=doc_content, image_file_name=image_file_name
+                )
+            )

         return final_docs

@@ -1,3 +1,4 @@
+import io
 import mimetypes
 from typing import cast
 from typing import IO
@@ -62,3 +63,43 @@ class FileManager:
         )
         response.raise_for_status()
         return response.content
+
+    @staticmethod
+    def upload_file_for_connector(
+        file_path: str, file_name: str, user_performing_action: DATestUser
+    ) -> dict:
+        # Read the file content
+        with open(file_path, "rb") as f:
+            file_content = f.read()
+
+        # Create a file-like object
+        file_obj = io.BytesIO(file_content)
+
+        # The 'files' form field expects a list of files
+        files = [("files", (file_name, file_obj, "application/octet-stream"))]
+
+        # Use the user's headers but without Content-Type
+        # as requests will set the correct multipart/form-data Content-Type for us
+        headers = user_performing_action.headers.copy()
+        if "Content-Type" in headers:
+            del headers["Content-Type"]
+
+        # Make the request
+        response = requests.post(
+            f"{API_SERVER_URL}/manage/admin/connector/file/upload",
+            files=files,
+            headers=headers,
+        )
+
+        if not response.ok:
+            try:
+                error_detail = response.json().get("detail", "Unknown error")
+            except Exception:
+                error_detail = response.text
+
+            raise Exception(
+                f"Unable to upload files - {error_detail} (Status code: {response.status_code})"
+            )
+
+        response_json = response.json()
+        return response_json

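A possible usage sketch for the new upload helper, assuming an admin DATestUser has already been created; the file path and the "file_paths" response key mirror how the new integration test below consumes the result:

from tests.integration.common_utils.managers.file import FileManager

# admin_user is a DATestUser created earlier in the test (assumed to exist here)
upload_response = FileManager.upload_file_for_connector(
    file_path="tests/integration/common_utils/test_files/Sample.pdf",
    file_name="Sample.pdf",
    user_performing_action=admin_user,
)

# The returned JSON is expected to include server-side paths usable as the
# file connector's "file_locations" config
file_locations = upload_response.get("file_paths", [])
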
@@ -23,7 +23,7 @@ class SettingsManager:
         headers.pop("Content-Type", None)

         response = requests.get(
-            f"{API_SERVER_URL}/api/manage/admin/settings",
+            f"{API_SERVER_URL}/admin/settings",
             headers=headers,
         )

@@ -48,8 +48,8 @@ class SettingsManager:
         headers.pop("Content-Type", None)

         payload = settings.model_dump()
-        response = requests.patch(
-            f"{API_SERVER_URL}/api/manage/admin/settings",
+        response = requests.put(
+            f"{API_SERVER_URL}/admin/settings",
             json=payload,
             headers=headers,
         )

backend/tests/integration/common_utils/test_files/Sample.pdf (new binary file, not shown)

@@ -76,6 +76,7 @@ class DATestConnector(BaseModel):
 class SimpleTestDocument(BaseModel):
     id: str
     content: str
+    image_file_name: str | None = None


 class DATestCCPair(BaseModel):
@@ -177,6 +178,8 @@ class DATestSettings(BaseModel):
     gpu_enabled: bool | None = None
     product_gating: DATestGatingType = DATestGatingType.NONE
     anonymous_user_enabled: bool | None = None
+    image_extraction_and_analysis_enabled: bool | None = False
+    search_time_image_analysis_enabled: bool | None = False


 @dataclass

@@ -0,0 +1,117 @@
+import os
+from datetime import datetime
+from datetime import timezone
+
+import pytest
+
+from onyx.connectors.models import InputType
+from onyx.db.engine import get_session_context_manager
+from onyx.db.enums import AccessType
+from onyx.server.documents.models import DocumentSource
+from tests.integration.common_utils.managers.cc_pair import CCPairManager
+from tests.integration.common_utils.managers.connector import ConnectorManager
+from tests.integration.common_utils.managers.credential import CredentialManager
+from tests.integration.common_utils.managers.document import DocumentManager
+from tests.integration.common_utils.managers.file import FileManager
+from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
+from tests.integration.common_utils.managers.settings import SettingsManager
+from tests.integration.common_utils.managers.user import UserManager
+from tests.integration.common_utils.test_models import DATestSettings
+from tests.integration.common_utils.test_models import DATestUser
+from tests.integration.common_utils.vespa import vespa_fixture
+
+FILE_NAME = "Sample.pdf"
+FILE_PATH = "tests/integration/common_utils/test_files"
+
+
+def test_image_indexing(
+    reset: None,
+    vespa_client: vespa_fixture,
+) -> None:
+    # Creating an admin user (first user created is automatically an admin)
+    admin_user: DATestUser = UserManager.create(
+        email="admin@onyx-test.com",
+    )
+
+    os.makedirs(FILE_PATH, exist_ok=True)
+    test_file_path = os.path.join(FILE_PATH, FILE_NAME)
+
+    # Use FileManager to upload the test file
+    upload_response = FileManager.upload_file_for_connector(
+        file_path=test_file_path, file_name=FILE_NAME, user_performing_action=admin_user
+    )
+
+    LLMProviderManager.create(
+        name="test_llm",
+        user_performing_action=admin_user,
+    )
+
+    SettingsManager.update_settings(
+        DATestSettings(
+            search_time_image_analysis_enabled=True,
+            image_extraction_and_analysis_enabled=True,
+        ),
+        user_performing_action=admin_user,
+    )
+
+    file_paths = upload_response.get("file_paths", [])
+
+    if not file_paths:
+        pytest.fail("File upload failed - no file paths returned")
+
+    # Create a dummy credential for the file connector
+    credential = CredentialManager.create(
+        source=DocumentSource.FILE,
+        credential_json={},
+        user_performing_action=admin_user,
+    )
+
+    # Create the connector
+    connector_name = f"FileConnector-{int(datetime.now().timestamp())}"
+    connector = ConnectorManager.create(
+        name=connector_name,
+        source=DocumentSource.FILE,
+        input_type=InputType.LOAD_STATE,
+        connector_specific_config={"file_locations": file_paths},
+        access_type=AccessType.PUBLIC,
+        groups=[],
+        user_performing_action=admin_user,
+    )
+
+    # Link the credential to the connector
+    cc_pair = CCPairManager.create(
+        credential_id=credential.id,
+        connector_id=connector.id,
+        access_type=AccessType.PUBLIC,
+        user_performing_action=admin_user,
+    )
+
+    # Explicitly run the connector to start indexing
+    CCPairManager.run_once(
+        cc_pair=cc_pair,
+        from_beginning=True,
+        user_performing_action=admin_user,
+    )
+    CCPairManager.wait_for_indexing_completion(
+        cc_pair=cc_pair,
+        after=datetime.now(timezone.utc),
+        user_performing_action=admin_user,
+    )
+
+    with get_session_context_manager() as db_session:
+        documents = DocumentManager.fetch_documents_for_cc_pair(
+            cc_pair_id=cc_pair.id,
+            db_session=db_session,
+            vespa_client=vespa_client,
+        )
+
+    # Ensure we indexed an image from the sample.pdf file
+    has_sample_pdf_image = False
+    for doc in documents:
+        if doc.image_file_name and FILE_NAME in doc.image_file_name:
+            has_sample_pdf_image = True
+
+    # Assert that at least one document has an image file name containing "sample.pdf"
+    assert (
+        has_sample_pdf_image
+    ), "No document found with an image file name containing 'sample.pdf'"