572 lines
15 KiB
Python

from datetime import datetime
from typing import Any
from typing import Generic
from typing import Optional
from typing import TypeVar
from uuid import UUID
from pydantic import BaseModel
from pydantic import validator
from pydantic.generics import GenericModel
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import AuthType
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.configs.constants import QAFeedbackType
from danswer.configs.constants import SearchFeedbackType
from danswer.connectors.models import DocumentBase
from danswer.connectors.models import InputType
from danswer.danswerbot.slack.config import VALID_SLACK_FILTERS
from danswer.db.models import AllowedAnswerFilters
from danswer.db.models import ChannelConfig
from danswer.db.models import Connector
from danswer.db.models import Credential
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import TaskStatus
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.search.models import BaseFilters
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.server.utils import mask_credential_dict
# Type parameter for the payload carried by `StatusResponse.data`.
DataT = TypeVar("DataT")


class StatusResponse(GenericModel, Generic[DataT]):
    """Generic envelope: a success flag plus an optional message and payload."""

    success: bool
    # Both of the following default to None when omitted.
    message: Optional[str] = None
    data: Optional[DataT] = None
class AuthTypeResponse(BaseModel):
    """Wraps a single `AuthType` value under `auth_type`."""

    auth_type: AuthType
class VersionResponse(BaseModel):
    """Carries a version string under `backend_version`."""

    backend_version: str
class DataRequest(BaseModel):
    """Request model holding one string under `data`."""

    data: str
class HelperResponse(BaseModel):
    """String-to-string `values` mapping with an optional list of `details`."""

    values: dict[str, str]
    # Defaults to None when omitted.
    details: list[str] | None = None
class UserInfo(BaseModel):
    """A user's id and email together with status flags and a `UserRole`."""

    id: str
    email: str

    # Status flags
    is_active: bool
    is_superuser: bool
    is_verified: bool

    role: UserRole
class GoogleAppWebCredentials(BaseModel):
    """Fields held under the `web` key of `GoogleAppCredentials`."""

    client_id: str
    project_id: str
    auth_uri: str
    token_uri: str
    auth_provider_x509_cert_url: str
    client_secret: str
    redirect_uris: list[str]
    javascript_origins: list[str]
class GoogleAppCredentials(BaseModel):
    """Google app credentials; all values are nested under the `web` key."""

    web: GoogleAppWebCredentials
class GoogleServiceAccountKey(BaseModel):
    """String fields making up a Google service account key; all are required."""

    type: str
    project_id: str
    private_key_id: str
    private_key: str
    client_email: str
    client_id: str
    auth_uri: str
    token_uri: str
    auth_provider_x509_cert_url: str
    client_x509_cert_url: str
    universe_domain: str
class GoogleServiceAccountCredentialRequest(BaseModel):
    """Request model naming the Google Drive user to impersonate, if any."""

    # email of user to impersonate
    google_drive_delegated_user: str | None
class FileUploadResponse(BaseModel):
    """Response model listing file paths as strings."""

    file_paths: list[str]
class ObjectCreationIdResponse(BaseModel):
    """Response model carrying an `id` that may be an int or a string."""

    id: int | str
class AuthStatus(BaseModel):
    """Single boolean flag: `authenticated`."""

    authenticated: bool
class AuthUrl(BaseModel):
    """Carries a URL string under `auth_url`."""

    auth_url: str
class GDriveCallback(BaseModel):
    """Pair of `state` and `code` strings."""

    state: str
    code: str
class UserRoleResponse(BaseModel):
    """Response model carrying a role as a plain string."""

    role: str
class BoostDoc(BaseModel):
    """A document's identifiers and link with its boost value and hidden flag."""

    document_id: str
    semantic_id: str
    link: str

    boost: int
    hidden: bool
class BoostUpdateRequest(BaseModel):
    """Request model pairing a document id with a boost value."""

    document_id: str
    boost: int
class HiddenUpdateRequest(BaseModel):
    """Request model pairing a document id with a hidden flag."""

    document_id: str
    hidden: bool
class SearchDoc(BaseModel):
    """A single document as returned in search results."""

    document_id: str
    semantic_identifier: str
    link: str | None
    blurb: str
    source_type: str
    boost: int
    # whether the document is hidden when doing a standard search
    # since a standard search will never find a hidden doc, this can only ever
    # be `True` when doing an admin search
    hidden: bool
    score: float | None
    # Matched sections in the doc. Uses Vespa syntax e.g. <hi>TEXT</hi>
    # to specify that a set of words should be highlighted. For example:
    # ["<hi>the</hi> <hi>answer</hi> is 42", "the answer is <hi>42</hi>"]
    match_highlights: list[str]
    # when the doc was last updated
    updated_at: datetime | None

    def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        """Export the model as a dict with `updated_at` as an ISO-format string.

        All arguments are forwarded unchanged to `BaseModel.dict`. When
        `updated_at` is present in the exported dict, its value is replaced by
        `self.updated_at.isoformat()` (or `None` if the field is unset).

        The key is only rewritten when the parent export produced it, so that
        `include` / `exclude` / `exclude_none` are respected instead of the
        key being unconditionally re-inserted.
        """
        exported = super().dict(*args, **kwargs)  # type: ignore
        if "updated_at" in exported:
            exported["updated_at"] = (
                self.updated_at.isoformat() if self.updated_at else None
            )
        return exported
class QuestionRequest(BaseModel):
    """A query string plus filters and the options controlling the search."""

    query: str
    filters: BaseFilters

    # Options below all have defaults.
    collection: str = DOCUMENT_INDEX_NAME
    search_type: SearchType = SearchType.HYBRID
    enable_auto_detect_filters: bool = True
    favor_recent: bool | None = None

    # Is this a real-time/streaming call, or a question where Danswer can
    # take more time?
    real_time: bool = True

    # For pagination; the offset is in batches, not by document count
    offset: int | None = None
class QAFeedbackRequest(BaseModel):
    """Request model attaching a `QAFeedbackType` to a query id."""

    query_id: int
    feedback: QAFeedbackType
class SearchFeedbackRequest(BaseModel):
    """Request model attaching search feedback to one document of a query."""

    query_id: int

    # The document the feedback refers to
    document_id: str
    document_rank: int

    click: bool
    search_feedback: SearchFeedbackType
class RetrievalDocs(BaseModel):
    """A list of `SearchDoc` entries under `top_documents`."""

    top_documents: list[SearchDoc]
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs):
    """Retrieved documents plus the predicted flow/search type and time settings."""

    predicted_flow: QueryFlow
    predicted_search: SearchType
    time_cutoff: datetime | None
    favor_recent: bool

    def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        """Export the model as a dict with `time_cutoff` as an ISO-format string.

        All arguments are forwarded unchanged to the parent `dict`. When
        `time_cutoff` is present in the exported dict, its value is replaced by
        `self.time_cutoff.isoformat()` (or `None` if the field is unset).

        The key is only rewritten when the parent export produced it, so that
        `include` / `exclude` / `exclude_none` are respected instead of the
        key being unconditionally re-inserted.
        """
        exported = super().dict(*args, **kwargs)  # type: ignore
        if "time_cutoff" in exported:
            exported["time_cutoff"] = (
                self.time_cutoff.isoformat() if self.time_cutoff else None
            )
        return exported
# Second chunk of info for streaming QA
class LLMRelevanceFilterResponse(BaseModel):
    """List of integer indices under `relevant_chunk_indices`."""

    relevant_chunk_indices: list[int]
class CreateChatSessionID(BaseModel):
    """Carries an integer `chat_session_id`."""

    chat_session_id: int
class ChatFeedbackRequest(BaseModel):
    """Request model attaching feedback to one message edit in a chat session."""

    # Identifies the message the feedback is about
    chat_session_id: int
    message_number: int
    edit_number: int

    # Feedback content; both default to None when omitted
    is_positive: bool | None = None
    feedback_text: str | None = None
class CreateChatMessageRequest(BaseModel):
    """Request model for a new chat message within a chat session."""

    chat_session_id: int
    message_number: int
    parent_edit_number: int | None
    message: str
    persona_id: int | None
class ChatMessageIdentifier(BaseModel):
    """Integer triple of chat session id, message number and edit number."""

    chat_session_id: int
    message_number: int
    edit_number: int
class RegenerateMessageRequest(ChatMessageIdentifier):
    """`ChatMessageIdentifier` extended with a nullable `persona_id`."""

    persona_id: int | None
class ChatRenameRequest(BaseModel):
    """Request model for renaming a chat session; name and first message are nullable."""

    chat_session_id: int
    name: str | None
    first_message: str | None
class RenameChatSessionResponse(BaseModel):
    """Response model carrying the new chat session name."""

    # This is only really useful if the name is generated
    new_name: str
class ChatSession(BaseModel):
    """A chat session's id and name, with its creation time as a string."""

    id: int
    name: str
    time_created: str
class ChatSessionsResponse(BaseModel):
    """A list of `ChatSession` entries under `sessions`."""

    sessions: list[ChatSession]
class ChatMessageDetail(BaseModel):
    """One chat message: its position numbers, content, type and send time."""

    # Position of the message
    message_number: int
    edit_number: int
    parent_edit_number: int | None
    latest: bool

    # Content
    message: str
    context_docs: RetrievalDocs | None
    message_type: MessageType

    time_sent: datetime
class ChatSessionDetailResponse(BaseModel):
    """A chat session's id and description with its `ChatMessageDetail` list."""

    chat_session_id: int
    description: str
    messages: list[ChatMessageDetail]
class QueryValidationResponse(BaseModel):
    """A `reasoning` string paired with an `answerable` flag."""

    reasoning: str
    answerable: bool
class AdminSearchRequest(BaseModel):
    """Request model holding a query string and its `BaseFilters`."""

    query: str
    filters: BaseFilters
class AdminSearchResponse(BaseModel):
    """A list of `SearchDoc` entries under `documents`."""

    documents: list[SearchDoc]
class SearchResponse(RetrievalDocs):
    """`RetrievalDocs` extended with a query event id and source/time settings."""

    query_event_id: int
    source_type: list[DocumentSource] | None
    time_cutoff: datetime | None
    favor_recent: bool
class QAResponse(SearchResponse, DanswerAnswer):
    """Combines `SearchResponse` and `DanswerAnswer` with quotes and predictions."""

    quotes: list[DanswerQuote] | None
    predicted_flow: QueryFlow
    predicted_search: SearchType

    # The remaining fields default to None when omitted.
    eval_res_valid: bool | None = None
    llm_chunks_indices: list[int] | None = None
    error_msg: str | None = None
class UserByEmail(BaseModel):
    """Carries an email string under `user_email`."""

    user_email: str
class IndexAttemptRequest(BaseModel):
    """Request model with an input type and a connector-specific config dict."""

    # Defaults to POLL when omitted.
    input_type: InputType = InputType.POLL
    connector_specific_config: dict[str, Any]
class IndexAttemptSnapshot(BaseModel):
    """Snapshot of an `IndexAttempt`, with timestamps rendered as strings."""

    id: int
    status: IndexingStatus | None
    new_docs_indexed: int  # only includes completely new docs
    total_docs_indexed: int  # includes docs that are updated
    error_msg: str | None
    time_started: str | None
    time_updated: str

    @classmethod
    def from_index_attempt_db_model(
        cls, index_attempt: IndexAttempt
    ) -> "IndexAttemptSnapshot":
        """Build a snapshot from an `IndexAttempt` DB row.

        Falsy doc counts (e.g. `None`) are reported as 0. `time_started` is
        converted with `isoformat()` when set and left as `None` otherwise;
        `time_updated` is always converted with `isoformat()`.

        Instantiates via `cls` so that calling this on a subclass returns an
        instance of that subclass.
        """
        started = index_attempt.time_started
        return cls(
            id=index_attempt.id,
            status=index_attempt.status,
            new_docs_indexed=index_attempt.new_docs_indexed or 0,
            total_docs_indexed=index_attempt.total_docs_indexed or 0,
            error_msg=index_attempt.error_msg,
            time_started=started.isoformat() if started else None,
            time_updated=index_attempt.time_updated.isoformat(),
        )
class DeletionAttemptSnapshot(BaseModel):
    """A `TaskStatus` for a connector id / credential id pair."""

    connector_id: int
    credential_id: int
    status: TaskStatus
class ConnectorBase(BaseModel):
    """Connector fields that `ConnectorSnapshot` extends with ids and timestamps."""

    name: str
    source: DocumentSource
    input_type: InputType
    connector_specific_config: dict[str, Any]
    # In seconds, None for one time index with no refresh
    refresh_freq: int | None
    disabled: bool
class ConnectorSnapshot(ConnectorBase):
    """`ConnectorBase` plus the connector's id, credential ids and timestamps."""

    id: int
    credential_ids: list[int]
    time_created: datetime
    time_updated: datetime

    @classmethod
    def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
        """Build a snapshot from a `Connector` DB row.

        `credential_ids` is collected from the `credential.id` of every entry
        in `connector.credentials`; all other fields are copied across as-is.

        Instantiates via `cls` so that calling this on a subclass returns an
        instance of that subclass.
        """
        credential_ids = [
            association.credential.id for association in connector.credentials
        ]
        return cls(
            id=connector.id,
            name=connector.name,
            source=connector.source,
            input_type=connector.input_type,
            connector_specific_config=connector.connector_specific_config,
            refresh_freq=connector.refresh_freq,
            credential_ids=credential_ids,
            time_created=connector.time_created,
            time_updated=connector.time_updated,
            disabled=connector.disabled,
        )
class RunConnectorRequest(BaseModel):
    """Request model naming a connector id and an optional list of credential ids."""

    connector_id: int
    credential_ids: list[int] | None
class CredentialBase(BaseModel):
    """Credential JSON dict plus its admin-visibility flag."""

    credential_json: dict[str, Any]
    # if `true`, then all Admins will have access to the credential
    admin_public: bool
class CredentialSnapshot(CredentialBase):
    """`CredentialBase` plus the credential's id, owning user id and timestamps."""

    id: int
    user_id: UUID | None
    time_created: datetime
    time_updated: datetime

    @classmethod
    def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot":
        """Build a snapshot from a `Credential` DB row.

        When `MASK_CREDENTIAL_PREFIX` is truthy, `credential_json` is passed
        through `mask_credential_dict`; otherwise it is used unchanged.

        Instantiates via `cls` so that calling this on a subclass returns an
        instance of that subclass.
        """
        raw_json = credential.credential_json
        return cls(
            id=credential.id,
            credential_json=mask_credential_dict(raw_json)
            if MASK_CREDENTIAL_PREFIX
            else raw_json,
            user_id=credential.user_id,
            admin_public=credential.admin_public,
            time_created=credential.time_created,
            time_updated=credential.time_updated,
        )
class ConnectorIndexingStatus(BaseModel):
    """Represents the latest indexing status of a connector"""

    # The connector / credential pair this status belongs to
    cc_pair_id: int
    name: str | None
    connector: ConnectorSnapshot
    credential: CredentialSnapshot
    owner: str
    public_doc: bool

    # Indexing state
    last_status: IndexingStatus | None
    last_success: datetime | None
    docs_indexed: int
    error_msg: str | None
    latest_index_attempt: IndexAttemptSnapshot | None

    # Deletion state
    deletion_attempt: DeletionAttemptSnapshot | None
    is_deletable: bool
class ConnectorCredentialPairIdentifier(BaseModel):
    """Pair of a connector id and a credential id."""

    connector_id: int
    credential_id: int
class ConnectorCredentialPairMetadata(BaseModel):
    """Carries a nullable `name`."""

    name: str | None
class ConnectorCredentialPairDescriptor(BaseModel):
    """An id and nullable name with snapshots of the connector and credential."""

    id: int
    name: str | None
    connector: ConnectorSnapshot
    credential: CredentialSnapshot
class ApiKey(BaseModel):
    """Carries a string under `api_key`."""

    api_key: str
class DocumentSetCreationRequest(BaseModel):
    """Request model with a name, a description and a list of cc-pair ids."""

    name: str
    description: str
    cc_pair_ids: list[int]
class DocumentSetUpdateRequest(BaseModel):
    """Request model with an id, a description and a list of cc-pair ids."""

    id: int
    description: str
    cc_pair_ids: list[int]
class CheckDocSetPublicRequest(BaseModel):
    """Request model listing document set ids."""

    document_set_ids: list[int]
class CheckDocSetPublicResponse(BaseModel):
    """Single boolean flag: `is_public`."""

    is_public: bool
class DocumentSet(BaseModel):
    """A document set with descriptors for its connector/credential pairs."""

    id: int
    name: str
    description: str
    cc_pair_descriptors: list[ConnectorCredentialPairDescriptor]
    is_up_to_date: bool
    contains_non_public: bool

    @classmethod
    def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet":
        """Build a `DocumentSet` from its DB model.

        `contains_non_public` is True when at least one of the model's
        connector/credential pairs has a falsy `is_public`. Each pair is
        converted to a `ConnectorCredentialPairDescriptor` whose connector and
        credential are built with the respective snapshot factories.
        """
        cc_pairs = document_set_model.connector_credential_pairs
        return cls(
            id=document_set_model.id,
            name=document_set_model.name,
            description=document_set_model.description,
            # Generator expression lets `any` stop at the first non-public pair
            # instead of building a full list first.
            contains_non_public=any(not cc_pair.is_public for cc_pair in cc_pairs),
            cc_pair_descriptors=[
                ConnectorCredentialPairDescriptor(
                    id=cc_pair.id,
                    name=cc_pair.name,
                    connector=ConnectorSnapshot.from_connector_db_model(
                        cc_pair.connector
                    ),
                    credential=CredentialSnapshot.from_credential_db_model(
                        cc_pair.credential
                    ),
                )
                for cc_pair in cc_pairs
            ],
            is_up_to_date=document_set_model.is_up_to_date,
        )
class IngestionDocument(BaseModel):
    """A `DocumentBase` plus optional connector/credential targeting fields."""

    document: DocumentBase

    # Takes precedence over the name
    connector_id: int | None = None
    connector_name: str | None = None
    credential_id: int | None = None

    # Currently not allowed
    create_connector: bool = False
    # To attach to the cc_pair, currently unused
    public_doc: bool = True
class IngestionResult(BaseModel):
    """A document id paired with an `already_existed` flag."""

    document_id: str
    already_existed: bool
class SlackBotTokens(BaseModel):
    """Pair of token strings: `bot_token` and `app_token`."""

    bot_token: str
    app_token: str
class SlackBotConfigCreationRequest(BaseModel):
    """Request model for a Slack bot config: document sets, channels and options."""

    # currently, a persona is created for each slack bot config
    # in the future, `document_sets` will probably be replaced
    # by an optional `PersonaSnapshot` object. Keeping it like this
    # for now for simplicity / speed of development
    document_sets: list[int]
    channel_names: list[str]
    respond_tag_only: bool = False
    # If no team members, assume respond in the channel to everyone
    respond_team_member_list: list[str] = []
    answer_filters: list[AllowedAnswerFilters] = []

    @validator("answer_filters", pre=True)
    def validate_filters(cls, value: list[str]) -> list[str]:
        """Return `value` unchanged if every entry is in `VALID_SLACK_FILTERS`.

        Registered with `pre=True`. Raises `ValueError` as soon as an entry
        outside `VALID_SLACK_FILTERS` is found.
        """
        if not all(candidate in VALID_SLACK_FILTERS for candidate in value):
            raise ValueError(
                f"Slack Answer filters must be one of {VALID_SLACK_FILTERS}"
            )
        return value
class SlackBotConfig(BaseModel):
    """A Slack bot config: its id, its document sets and its channel config."""

    id: int
    # currently, a persona is created for each slack bot config
    # in the future, `document_sets` will probably be replaced
    # by an optional `PersonaSnapshot` object. Keeping it like this
    # for now for simplicity / speed of development
    document_sets: list[DocumentSet]
    channel_config: ChannelConfig