pablonyx 48ac690a70
Multi tenant tests (#3919)
* ensure fail on multi tenant successfully

* attempted fix

* udpate ingration tests

* minor update

* improve

* improve workflow

* fix migrations

* many more logs

* quick fix

* improve

* fix typo

* quick nit

* attempted fix

* very minor clean up
2025-02-07 01:24:00 +00:00

208 lines
5.0 KiB
Python

from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from onyx.auth.schemas import UserRole
from onyx.configs.constants import QAFeedbackType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import IndexingStatus
from onyx.server.documents.models import InputType
"""
These data models are used to represent the data on the testing side of things.
This means the flow is:
1. Make request that changes data in db
2. Make a change to the testing model
3. Retrieve data from db
4. Compare db data with testing model to verify
"""
class DATestAPIKey(BaseModel):
api_key_id: int
api_key_display: str
api_key: str | None = None # only present on initial creation
api_key_name: str | None = None
api_key_role: UserRole
user_id: UUID
headers: dict
class DATestUser(BaseModel):
id: str
email: str
password: str
headers: dict
role: UserRole
is_active: bool
cookies: dict = {}
class DATestPersonaLabel(BaseModel):
id: int | None = None
name: str
class DATestCredential(BaseModel):
id: int
name: str
credential_json: dict[str, Any]
admin_public: bool
source: DocumentSource
curator_public: bool
groups: list[int]
class DATestConnector(BaseModel):
id: int
name: str
source: DocumentSource
input_type: InputType
connector_specific_config: dict[str, Any]
groups: list[int] | None = None
access_type: AccessType | None = None
class SimpleTestDocument(BaseModel):
id: str
content: str
class DATestCCPair(BaseModel):
id: int
name: str
connector_id: int
credential_id: int
access_type: AccessType
groups: list[int]
documents: list[SimpleTestDocument] = Field(default_factory=list)
class DATestUserGroup(BaseModel):
id: int
name: str
user_ids: list[str]
cc_pair_ids: list[int]
class DATestLLMProvider(BaseModel):
id: int
name: str
provider: str
api_key: str
default_model_name: str
is_public: bool
groups: list[int]
api_base: str | None = None
api_version: str | None = None
class DATestDocumentSet(BaseModel):
id: int
name: str
description: str
cc_pair_ids: list[int] = Field(default_factory=list)
is_public: bool
is_up_to_date: bool
users: list[str] = Field(default_factory=list)
groups: list[int] = Field(default_factory=list)
class DATestPersona(BaseModel):
id: int
name: str
description: str
num_chunks: float
llm_relevance_filter: bool
is_public: bool
llm_filter_extraction: bool
recency_bias: RecencyBiasSetting
prompt_ids: list[int]
document_set_ids: list[int]
tool_ids: list[int]
llm_model_provider_override: str | None
llm_model_version_override: str | None
users: list[str]
groups: list[int]
label_ids: list[int]
class DATestChatMessage(BaseModel):
id: int
chat_session_id: UUID
parent_message_id: int | None
message: str
class DATestChatSession(BaseModel):
id: UUID
persona_id: int
description: str
class DAQueryHistoryEntry(DATestChatSession):
feedback_type: QAFeedbackType | None
class StreamedResponse(BaseModel):
full_message: str = ""
rephrased_query: str | None = None
tool_name: str | None = None
top_documents: list[dict[str, Any]] | None = None
relevance_summaries: list[dict[str, Any]] | None = None
tool_result: Any | None = None
user: str | None = None
class DATestGatingType(str, Enum):
FULL = "full"
PARTIAL = "partial"
NONE = "none"
class DATestSettings(BaseModel):
"""General settings"""
maximum_chat_retention_days: int | None = None
gpu_enabled: bool | None = None
product_gating: DATestGatingType = DATestGatingType.NONE
anonymous_user_enabled: bool | None = None
@dataclass
class DATestIndexAttempt:
id: int
status: IndexingStatus | None
new_docs_indexed: int | None
total_docs_indexed: int | None
docs_removed_from_index: int | None
error_msg: str | None
time_started: datetime | None
time_updated: datetime | None
@classmethod
def from_index_attempt_snapshot(
cls, index_attempt: IndexAttemptSnapshot
) -> "DATestIndexAttempt":
return cls(
id=index_attempt.id,
status=index_attempt.status,
new_docs_indexed=index_attempt.new_docs_indexed,
total_docs_indexed=index_attempt.total_docs_indexed,
docs_removed_from_index=index_attempt.docs_removed_from_index,
error_msg=index_attempt.error_msg,
time_started=datetime.fromisoformat(index_attempt.time_started)
if index_attempt.time_started
else None,
time_updated=datetime.fromisoformat(index_attempt.time_updated),
)