Danswer APIs Document Ingestion Endpoint (#716)
Commit 39d09a162a (parent d291fea020)
@@ -0,0 +1,37 @@
+"""Introduce Danswer APIs
+
+Revision ID: 15326fcec57e
+Revises: 77d07dffae64
+Create Date: 2023-11-11 20:51:24.228999
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+from danswer.configs.constants import DocumentSource
+
+# revision identifiers, used by Alembic.
+revision = "15326fcec57e"
+down_revision = "77d07dffae64"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.alter_column("credential", "is_admin", new_column_name="admin_public")
+    op.add_column(
+        "document",
+        sa.Column("from_ingestion_api", sa.Boolean(), nullable=True),
+    )
+    op.alter_column(
+        "connector",
+        "source",
+        type_=sa.String(length=50),
+        existing_type=sa.Enum(DocumentSource, native_enum=False),
+        existing_nullable=False,
+    )
+
+
+def downgrade() -> None:
+    op.drop_column("document", "from_ingestion_api")
+    op.alter_column("credential", "admin_public", new_column_name="is_admin")

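For context, a revision like this is applied with Alembic's upgrade command. A minimal sketch (not part of this commit), assuming the repo's standard alembic.ini in the backend directory; equivalent to running "alembic upgrade 15326fcec57e" from the CLI:

# Apply the migration programmatically via Alembic's command API.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")  # assumed path to the project's Alembic config
command.upgrade(alembic_cfg, "15326fcec57e")
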
@@ -40,6 +40,8 @@ LLM_CHUNKS = "llm_chunks"


 class DocumentSource(str, Enum):
+    # Special case, document passed in via Danswer APIs without specifying a source type
+    INGESTION_API = "ingestion_api"
     SLACK = "slack"
     WEB = "web"
     GOOGLE_DRIVE = "google_drive"

@@ -65,7 +65,7 @@ def _process_file(
         Document(
             id=file_name,
             sections=[
-                Section(link=metadata.get("link", ""), text=file_content_raw.strip())
+                Section(link=metadata.get("link", None), text=file_content_raw.strip())
             ],
             source=DocumentSource.FILE,
             semantic_identifier=file_name,

@@ -130,7 +130,7 @@ def build_service_account_creds(

     return CredentialBase(
         credential_json=credential_dict,
-        is_admin=True,
+        admin_public=True,
     )


@@ -1,9 +1,17 @@
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
 from typing import Any

+from pydantic import BaseModel
+
 from danswer.configs.constants import DocumentSource
+from danswer.utils.text_processing import make_url_compatible


+class InputType(str, Enum):
+    LOAD_STATE = "load_state"  # e.g. loading a current full state or a save state, such as from a file
+    POLL = "poll"  # e.g. calling an API to get all documents in the last hour
+    EVENT = "event"  # e.g. registered an endpoint as a listener, and processing connector events
+
+
 class ConnectorMissingCredentialError(PermissionError):

@@ -14,17 +22,17 @@ class ConnectorMissingCredentialError(PermissionError):
         )


-@dataclass
-class Section:
-    link: str
+class Section(BaseModel):
     text: str
+    link: str | None


-@dataclass
-class Document:
-    id: str  # This must be unique or during indexing/reindexing, chunks will be overwritten
+class DocumentBase(BaseModel):
+    """Used for Danswer ingestion api, the ID is inferred before use if not provided"""
+
+    id: str | None = None
     sections: list[Section]
-    source: DocumentSource
+    source: DocumentSource | None = None
     semantic_identifier: str  # displayed in the UI as the main identifier for the doc
     metadata: dict[str, Any]
     # UTC time

@@ -36,22 +44,38 @@ class Document:
     # `title` is used when computing best matches for a query
     # if `None`, then we will use the `semantic_identifier` as the title in Vespa
     title: str | None = None
+    from_ingestion_api: bool = False

     def get_title_for_document_index(self) -> str:
         return self.semantic_identifier if self.title is None else self.title

+
+class Document(DocumentBase):
+    id: str  # This must be unique or during indexing/reindexing, chunks will be overwritten
+    source: DocumentSource
+
     def to_short_descriptor(self) -> str:
         """Used when logging the identity of a document"""
         return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"

-
-class InputType(str, Enum):
-    LOAD_STATE = "load_state"  # e.g. loading a current full state or a save state, such as from a file
-    POLL = "poll"  # e.g. calling an API to get all documents in the last hour
-    EVENT = "event"  # e.g. registered an endpoint as a listener, and processing connector events
+    @classmethod
+    def from_base(cls, base: DocumentBase) -> "Document":
+        return cls(
+            id=make_url_compatible(base.id)
+            if base.id
+            else "ingestion_api_" + make_url_compatible(base.semantic_identifier),
+            sections=base.sections,
+            source=base.source or DocumentSource.INGESTION_API,
+            semantic_identifier=base.semantic_identifier,
+            metadata=base.metadata,
+            doc_updated_at=base.doc_updated_at,
+            primary_owners=base.primary_owners,
+            secondary_owners=base.secondary_owners,
+            title=base.title,
+            from_ingestion_api=base.from_ingestion_api,
+        )


-@dataclass
-class IndexAttemptMetadata:
+class IndexAttemptMetadata(BaseModel):
     connector_id: int
     credential_id: int

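A minimal sketch (not part of this commit) of how Document.from_base infers the ID when the ingestion caller omits one; the field values are illustrative, and the remaining optional DocumentBase fields are assumed to default to None as in the surrounding model:

from danswer.connectors.models import Document
from danswer.connectors.models import DocumentBase
from danswer.connectors.models import Section

base = DocumentBase(
    sections=[Section(text="Refunds are processed within 5 days.", link=None)],
    semantic_identifier="Refund Policy",
    metadata={},
)
doc = Document.from_base(base)
# With no explicit id, the id is derived from the semantic identifier:
#   doc.id == "ingestion_api_Refund_Policy"
# With no explicit source, it falls back to DocumentSource.INGESTION_API.
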
@@ -50,6 +50,19 @@ def fetch_connector_by_id(connector_id: int, db_session: Session) -> Connector | None:
     return connector


+def fetch_ingestion_connector_by_name(
+    connector_name: str, db_session: Session
+) -> Connector | None:
+    stmt = (
+        select(Connector)
+        .where(Connector.name == connector_name)
+        .where(Connector.source == DocumentSource.INGESTION_API)
+    )
+    result = db_session.execute(stmt)
+    connector = result.scalar_one_or_none()
+    return connector
+
+
 def create_connector(
     connector_data: ConnectorBase,
     db_session: Session,

@@ -210,3 +223,32 @@ def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
     sources = [source[0] for source in distinct_sources]

     return sources
+
+
+def create_initial_default_connector(db_session: Session) -> None:
+    default_connector_id = 0
+    default_connector = fetch_connector_by_id(default_connector_id, db_session)
+
+    if default_connector is not None:
+        if (
+            default_connector.source != DocumentSource.INGESTION_API
+            or default_connector.input_type != InputType.LOAD_STATE
+            or default_connector.refresh_freq is not None
+            or default_connector.disabled
+        ):
+            raise ValueError(
+                "DB is not in a valid initial state. "
+                "Default connector does not have expected values."
+            )
+        return
+
+    connector = Connector(
+        id=default_connector_id,
+        name="Ingestion API",
+        source=DocumentSource.INGESTION_API,
+        input_type=InputType.LOAD_STATE,
+        connector_specific_config={},
+        refresh_freq=None,
+    )
+    db_session.add(connector)
+    db_session.commit()

@@ -122,6 +122,27 @@ def mark_all_in_progress_cc_pairs_failed(
     db_session.commit()


+def associate_default_cc_pair(db_session: Session) -> None:
+    existing_association = (
+        db_session.query(ConnectorCredentialPair)
+        .filter(
+            ConnectorCredentialPair.connector_id == 0,
+            ConnectorCredentialPair.credential_id == 0,
+        )
+        .one_or_none()
+    )
+    if existing_association is not None:
+        return
+
+    association = ConnectorCredentialPair(
+        connector_id=0,
+        credential_id=0,
+        name="DefaultCCPair",
+    )
+    db_session.add(association)
+    db_session.commit()
+
+
 def add_credential_to_connector(
     connector_id: int,
     credential_id: int,

@@ -9,18 +9,20 @@ from danswer.auth.schemas import UserRole
 from danswer.connectors.google_drive.constants import (
     DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
 )
-from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.models import Credential
 from danswer.db.models import User
 from danswer.server.models import CredentialBase
-from danswer.server.models import ObjectCreationIdResponse
 from danswer.utils.logger import setup_logger


 logger = setup_logger()


-def _attach_user_filters(stmt: Select[tuple[Credential]], user: User | None) -> Select:
+def _attach_user_filters(
+    stmt: Select[tuple[Credential]],
+    user: User | None,
+    assume_admin: bool = False,  # Used with API key
+) -> Select:
     """Attaches filters to the statement to ensure that the user can only
     access the appropriate credentials"""
     if user:

@@ -29,11 +31,18 @@ def _attach_user_filters(stmt: Select[tuple[Credential]], user: User | None) ->
                 or_(
                     Credential.user_id == user.id,
                     Credential.user_id.is_(None),
-                    Credential.is_admin == True,  # noqa: E712
+                    Credential.admin_public == True,  # noqa: E712
                 )
             )
         else:
             stmt = stmt.where(Credential.user_id == user.id)
+    elif assume_admin:
+        stmt = stmt.where(
+            or_(
+                Credential.user_id.is_(None),
+                Credential.admin_public == True,  # noqa: E712
+            )
+        )

     return stmt

@@ -49,10 +58,13 @@ def fetch_credentials(


 def fetch_credential_by_id(
-    credential_id: int, user: User | None, db_session: Session
+    credential_id: int,
+    user: User | None,
+    db_session: Session,
+    assume_admin: bool = False,
 ) -> Credential | None:
     stmt = select(Credential).where(Credential.id == credential_id)
-    stmt = _attach_user_filters(stmt, user)
+    stmt = _attach_user_filters(stmt, user, assume_admin=assume_admin)
     result = db_session.execute(stmt)
     credential = result.scalar_one_or_none()
     return credential

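The assume_admin path is what the new ingestion endpoint (later in this diff) relies on when a request is authenticated by the Danswer API key rather than by a user. A sketch of that call shape:

# user=None because the request carries only the API key; assume_admin restricts
# the lookup to public/admin credentials instead of user-owned ones.
credential = fetch_credential_by_id(
    credential_id=0,
    user=None,
    db_session=db_session,
    assume_admin=True,
)
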
@@ -62,16 +74,16 @@ def create_credential(
     credential_data: CredentialBase,
     user: User | None,
     db_session: Session,
-) -> ObjectCreationIdResponse:
+) -> Credential:
     credential = Credential(
         credential_json=credential_data.credential_json,
         user_id=user.id if user else None,
-        is_admin=credential_data.is_admin,
+        admin_public=credential_data.admin_public,
     )
     db_session.add(credential)
     db_session.commit()

-    return ObjectCreationIdResponse(id=credential.id)
+    return credential


 def update_credential(

@@ -131,30 +143,26 @@ def delete_credential(
     db_session.commit()


-def create_initial_public_credential() -> None:
+def create_initial_public_credential(db_session: Session) -> None:
     public_cred_id = 0
     error_msg = (
         "DB is not in a valid initial state."
         "There must exist an empty public credential for data connectors that do not require additional Auth."
     )
-    with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
-        first_credential = fetch_credential_by_id(public_cred_id, None, db_session)
+    first_credential = fetch_credential_by_id(public_cred_id, None, db_session)

-        if first_credential is not None:
-            if (
-                first_credential.credential_json != {}
-                or first_credential.user is not None
-            ):
-                raise ValueError(error_msg)
-            return
+    if first_credential is not None:
+        if first_credential.credential_json != {} or first_credential.user is not None:
+            raise ValueError(error_msg)
+        return

-        credential = Credential(
-            id=public_cred_id,
-            credential_json={},
-            user_id=None,
-        )
-        db_session.add(credential)
-        db_session.commit()
+    credential = Credential(
+        id=public_cred_id,
+        credential_json={},
+        user_id=None,
+    )
+    db_session.add(credential)
+    db_session.commit()


 def delete_google_drive_service_account_credentials(

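The refactor inverts session ownership: callers now open and pass the session instead of the function creating its own. A minimal before/after sketch (mirroring the startup hook later in this diff):

# Before: the function opened (and committed in) its own session.
# create_initial_public_credential()

# After: the caller owns the session.
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
    create_initial_public_credential(db_session)
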
@@ -168,6 +168,7 @@ def upsert_documents(
             model_to_dict(
                 DbDocument(
                     id=doc.document_id,
+                    from_ingestion_api=doc.from_ingestion_api,
                     boost=initial_boost,
                     hidden=False,
                     semantic_id=doc.semantic_identifier,

@@ -226,7 +226,7 @@ class Credential(Base):
     credential_json: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB())
     user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
     # if `true`, then all Admins will have access to the credential
-    is_admin: Mapped[bool] = mapped_column(Boolean, default=True)
+    admin_public: Mapped[bool] = mapped_column(Boolean, default=True)
     time_created: Mapped[datetime.datetime] = mapped_column(
         DateTime(timezone=True), server_default=func.now()
     )

@@ -399,6 +399,9 @@ class Document(Base):
     # this should correspond to the ID of the document
     # (as is passed around in Danswer)
     id: Mapped[str] = mapped_column(String, primary_key=True)
+    from_ingestion_api: Mapped[bool] = mapped_column(
+        Boolean, default=False, nullable=True
+    )
     # 0 for neutral, positive for mostly endorse, negative for mostly reject
     boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
     hidden: Mapped[bool] = mapped_column(Boolean, default=False)

@@ -27,6 +27,7 @@ class DocumentMetadata:
     # Users may not be in Danswer
     primary_owners: list[str] | None = None
     secondary_owners: list[str] | None = None
+    from_ingestion_api: bool = False


 @dataclass

@@ -38,6 +38,7 @@ def chunk_large_section(
     blurb_size: int = BLURB_SIZE,
 ) -> list[DocAwareChunk]:
     section_text = section.text
+    section_link_text = section.link or ""
     blurb = extract_blurb(section_text, blurb_size)

     sentence_aware_splitter = SentenceSplitter(

@@ -52,7 +53,7 @@ def chunk_large_section(
             chunk_id=start_chunk_id + chunk_ind,
             blurb=blurb,
             content=chunk_str,
-            source_links={0: section.link},
+            source_links={0: section_link_text},
             section_continuation=(chunk_ind != 0),
         )
         for chunk_ind, chunk_str in enumerate(split_texts)

@@ -72,6 +73,7 @@ def chunk_document(
     link_offsets: dict[int, str] = {}
     chunk_text = ""
     for section in document.sections:
+        section_link_text = section.link or ""
         section_tok_length = len(tokenizer.tokenize(section.text))
         current_tok_length = len(tokenizer.tokenize(chunk_text))
         curr_offset_len = len(shared_precompare_cleanup(chunk_text))

@@ -115,7 +117,7 @@ def chunk_document(
                 chunk_text += (
                     SECTION_SEPARATOR + section.text if chunk_text else section.text
                 )
-                link_offsets[curr_offset_len] = section.link
+                link_offsets[curr_offset_len] = section_link_text
             else:
                 chunks.append(
                     DocAwareChunk(

@@ -127,7 +129,7 @@ def chunk_document(
                         section_continuation=False,
                     )
                 )
-                link_offsets = {0: section.link}
+                link_offsets = {0: section_link_text}
                 chunk_text = section.text

     # Once we hit the end, if we're still in the process of building a chunk, add what we have

@@ -34,7 +34,7 @@ class IndexingPipelineProtocol(Protocol):
         ...


-def _upsert_documents(
+def upsert_documents_in_db(
     documents: list[Document],
     index_attempt_metadata: IndexAttemptMetadata,
     db_session: Session,

@@ -52,6 +52,7 @@ def _upsert_documents(
             first_link=first_link,
             primary_owners=doc.primary_owners,
             secondary_owners=doc.secondary_owners,
+            from_ingestion_api=doc.from_ingestion_api,
         )
         doc_m_batch.append(db_doc_metadata)

@@ -101,7 +102,7 @@ def _indexing_pipeline(

     # Create records in the source of truth about these documents,
     # does not include doc_updated_at which is also used to indicate a successful update
-    _upsert_documents(
+    upsert_documents_in_db(
         documents=updatable_docs,
         index_attempt_metadata=index_attempt_metadata,
         db_session=db_session,

@@ -7,6 +7,7 @@ from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from httpx_oauth.clients.google import GoogleOAuth2
+from sqlalchemy.orm import Session

 from danswer import __version__
 from danswer.auth.schemas import UserCreate

@@ -35,16 +36,21 @@ from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.db.connector import create_initial_default_connector
+from danswer.db.connector_credential_pair import associate_default_cc_pair
 from danswer.db.credentials import create_initial_public_credential
+from danswer.db.engine import get_sqlalchemy_engine
 from danswer.direct_qa.factory import get_default_qa_model
 from danswer.document_index.factory import get_default_document_index
 from danswer.llm.factory import get_default_llm
+from danswer.search.search_nlp_models import warm_up_models
 from danswer.server.cc_pair.api import router as cc_pair_router
 from danswer.server.chat_backend import router as chat_router
 from danswer.server.connector import router as connector_router
 from danswer.server.credential import router as credential_router
+from danswer.server.danswer_api import get_danswer_api_key
+from danswer.server.danswer_api import router as danswer_api_router
 from danswer.server.document_set import router as document_set_router
-from danswer.server.event_loading import router as event_processing_router
 from danswer.server.manage import router as admin_router
 from danswer.server.search_backend import router as backend_router
 from danswer.server.slack_bot_management import router as slack_bot_management_router

@@ -84,7 +90,6 @@ def get_application() -> FastAPI:
     application = FastAPI(title="Danswer Backend", version=__version__)
     application.include_router(backend_router)
     application.include_router(chat_router)
-    application.include_router(event_processing_router)
     application.include_router(admin_router)
     application.include_router(user_router)
     application.include_router(connector_router)

@@ -93,6 +98,7 @@ def get_application() -> FastAPI:
     application.include_router(document_set_router)
     application.include_router(slack_bot_management_router)
     application.include_router(state_router)
+    application.include_router(danswer_api_router)

     if AUTH_TYPE == AuthType.DISABLED:
         # Server logs this during auth setup verification step

@@ -155,17 +161,16 @@ def get_application() -> FastAPI:

     @application.on_event("startup")
     def startup_event() -> None:
-        # To avoid circular imports
-        from danswer.search.search_nlp_models import (
-            warm_up_models,
-        )
-
         verify_auth = fetch_versioned_implementation(
             "danswer.auth.users", "verify_auth_setting"
         )
         # Will throw exception if an issue is found
         verify_auth()

+        # Danswer APIs key
+        api_key = get_danswer_api_key()
+        logger.info(f"Danswer API Key: {api_key}")
+
         if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
             logger.info("Both OAuth Client ID and Secret are configured.")

@@ -217,8 +222,11 @@ def get_application() -> FastAPI:
         nltk.download("wordnet", quiet=True)
         nltk.download("punkt", quiet=True)

-        logger.info("Verifying public credential exists.")
-        create_initial_public_credential()
+        logger.info("Verifying default connector/credential exist.")
+        with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
+            create_initial_public_credential(db_session)
+            create_initial_default_connector(db_session)
+            associate_default_cc_pair(db_session)

         logger.info("Loading default Chat Personas")
         load_personas_from_yaml()

@@ -168,9 +168,10 @@ def upsert_service_account_credential(
     # first delete all existing service account credentials
     delete_google_drive_service_account_credentials(user, db_session)
     # `user=None` since this credential is not a personal credential
-    return create_credential(
+    credential = create_credential(
         credential_data=credential_base, user=user, db_session=db_session
     )
+    return ObjectCreationIdResponse(id=credential.id)


 @router.get("/admin/connector/google-drive/check-auth/{credential_id}")

@@ -259,6 +260,10 @@ def get_connector_indexing_status(
     }

     for cc_pair in cc_pairs:
+        # TODO remove this to enable ingestion API
+        if cc_pair.name == "DefaultCCPair":
+            continue
+
         connector = cc_pair.connector
         credential = cc_pair.credential
         latest_index_attempt = cc_pair_to_latest_index_attempt.get(

@@ -72,13 +72,14 @@ def create_credential_from_model(
     user: User | None = Depends(current_user),
     db_session: Session = Depends(get_session),
 ) -> ObjectCreationIdResponse:
-    if user and user.role != UserRole.ADMIN:
+    if user and user.role != UserRole.ADMIN and credential_info.admin_public:
         raise HTTPException(
             status_code=400,
             detail="Non-admin cannot create admin credential",
         )

-    return create_credential(credential_info, user, db_session)
+    credential = create_credential(credential_info, user, db_session)
+    return ObjectCreationIdResponse(id=credential.id)


 @router.get("/credential/{credential_id}")

@@ -117,7 +118,7 @@ def update_credential_from_model(
         id=updated_credential.id,
         credential_json=updated_credential.credential_json,
         user_id=updated_credential.user_id,
-        is_admin=updated_credential.is_admin,
+        admin_public=updated_credential.admin_public,
         time_created=updated_credential.time_created,
         time_updated=updated_credential.time_updated,
     )

backend/danswer/server/danswer_api.py (new file, 154 lines)
@@ -0,0 +1,154 @@
+import secrets
+from typing import cast
+
+from fastapi import APIRouter
+from fastapi import Depends
+from fastapi import Header
+from fastapi import HTTPException
+from sqlalchemy.orm import Session
+
+from danswer.configs.constants import DocumentSource
+from danswer.connectors.models import Document
+from danswer.connectors.models import IndexAttemptMetadata
+from danswer.db.connector import fetch_connector_by_id
+from danswer.db.connector import fetch_ingestion_connector_by_name
+from danswer.db.connector_credential_pair import get_connector_credential_pair
+from danswer.db.credentials import fetch_credential_by_id
+from danswer.db.engine import get_session
+from danswer.dynamic_configs import get_dynamic_config_store
+from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.indexing.indexing_pipeline import build_indexing_pipeline
+from danswer.server.models import ApiKey
+from danswer.server.models import IngestionDocument
+from danswer.server.models import IngestionResult
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+# not using /api to avoid confusion with nginx api path routing
+router = APIRouter(prefix="/danswer-api")
+
+# Assumes this gives admin privileges, basic users should not be allowed to call any Danswer apis
+_DANSWER_API_KEY = "danswer_api_key"
+
+
+def get_danswer_api_key(key_len: int = 30, dont_regenerate: bool = False) -> str | None:
+    kv_store = get_dynamic_config_store()
+    try:
+        return str(kv_store.load(_DANSWER_API_KEY))
+    except ConfigNotFoundError:
+        if dont_regenerate:
+            return None
+
+    logger.info("Generating Danswer API Key")
+
+    api_key = "dn_" + secrets.token_urlsafe(key_len)
+    kv_store.store(_DANSWER_API_KEY, api_key)
+
+    return api_key
+
+
+def delete_danswer_api_key() -> None:
+    kv_store = get_dynamic_config_store()
+    try:
+        kv_store.delete(_DANSWER_API_KEY)
+    except ConfigNotFoundError:
+        pass
+
+
+def api_key_dep(authorization: str = Header(...)) -> str:
+    saved_key = get_danswer_api_key(dont_regenerate=True)
+    token = authorization.removeprefix("Bearer ").strip()
+    if token != saved_key or not saved_key:
+        raise HTTPException(status_code=401, detail="Invalid API key")
+    return token
+
+
+# Provides a way to recover if the api key is deleted for some reason
+# Can also just restart the server to regenerate a new one
+def api_key_dep_if_exist(authorization: str | None = Header(None)) -> str | None:
+    token = authorization.removeprefix("Bearer ").strip() if authorization else None
+    saved_key = get_danswer_api_key(dont_regenerate=True)
+    if not saved_key:
+        return None
+
+    if token != saved_key:
+        raise HTTPException(status_code=401, detail="Invalid API key")
+
+    return token
+
+
+@router.post("/regenerate-key")
+def regenerate_key(_: str | None = Depends(api_key_dep_if_exist)) -> ApiKey:
+    delete_danswer_api_key()
+    return ApiKey(api_key=cast(str, get_danswer_api_key()))
+
+
+@router.post("/doc-ingestion")
+def document_ingestion(
+    doc_info: IngestionDocument,
+    _: str = Depends(api_key_dep),
+    db_session: Session = Depends(get_session),
+) -> IngestionResult:
+    """Currently only attaches docs to existing connectors (cc-pairs).
+    Or to the default ingestion connector that is accessible to all users
+
+    Things to note:
+    - The document id if not provided is automatically generated from the semantic identifier
+      so if the document source type etc is updated, it won't create a duplicate
+    """
+    if doc_info.credential_id:
+        credential_id = doc_info.credential_id
+        credential = fetch_credential_by_id(
+            credential_id=credential_id,
+            user=None,
+            db_session=db_session,
+            assume_admin=True,
+        )
+        if credential is None:
+            raise ValueError("Invalid Credential for doc, does not exist.")
+    else:
+        credential_id = 0
+
+    connector_id = doc_info.connector_id
+    # If user provides id and name, id takes precedence
+    if connector_id is not None:
+        connector = fetch_connector_by_id(connector_id, db_session)
+        if connector is None:
+            raise ValueError("Invalid Connector for doc, id does not exist.")
+    elif doc_info.connector_name:
+        connector = fetch_ingestion_connector_by_name(
+            doc_info.connector_name, db_session
+        )
+        if connector is None:
+            raise ValueError("Invalid Connector for doc, name does not exist.")
+        connector_id = connector.id
+    else:
+        connector_id = 0
+
+    cc_pair = get_connector_credential_pair(
+        connector_id=connector_id, credential_id=credential_id, db_session=db_session
+    )
+    if cc_pair is None:
+        raise ValueError("Connector and Credential not associated.")
+
+    # Disregard whatever value is passed, this must be True
+    doc_info.document.from_ingestion_api = True
+
+    document = Document.from_base(doc_info.document)
+
+    # TODO once the frontend is updated with this enum, remove this logic
+    if document.source == DocumentSource.INGESTION_API:
+        document.source = DocumentSource.FILE
+
+    indexing_pipeline = build_indexing_pipeline()
+
+    new_doc, chunks = indexing_pipeline(
+        documents=[document],
+        index_attempt_metadata=IndexAttemptMetadata(
+            connector_id=connector_id,
+            credential_id=credential_id,
+        ),
+    )
+
+    return IngestionResult(document_id=document.id, already_existed=not bool(new_doc))

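A hypothetical client sketch for the new endpoint (not part of this commit). The base URL, port, and payload values are illustrative assumptions; the key itself ("dn_..." prefix) is printed to the server log on startup:

import requests  # assumed available; any HTTP client works

resp = requests.post(
    "http://localhost:8080/danswer-api/doc-ingestion",  # assumed host/port
    headers={"Authorization": "Bearer dn_<key-from-server-log>"},
    json={
        "document": {
            "sections": [
                {"text": "Refunds are processed within 5 days.", "link": None}
            ],
            "semantic_identifier": "Refund Policy",
            "metadata": {},
        }
        # connector_id / connector_name / credential_id omitted: the endpoint
        # falls back to the default ingestion connector (id 0) and the public
        # credential (id 0)
    },
)
resp.raise_for_status()
print(resp.json())  # e.g. {"document_id": "...", "already_existed": false}
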
@@ -1,63 +0,0 @@
-from typing import Any
-
-from fastapi import APIRouter
-from pydantic import BaseModel
-from pydantic import Extra
-
-from danswer.utils.logger import setup_logger
-
-router = APIRouter()
-
-logger = setup_logger()
-
-
-class SlackEvent(BaseModel, extra=Extra.allow):
-    type: str
-    challenge: str | None
-    event: dict[str, Any] | None
-
-
-class EventHandlingResponse(BaseModel):
-    challenge: str | None
-
-
-# TODO: just store entry in DB and process in the background, until then this
-# won't work cleanly since the slack bot token is not easily accessible
-# @router.post("/process_slack_event", response_model=EventHandlingResponse)
-# def process_slack_event(event: SlackEvent) -> EventHandlingResponse:
-#     logger.info("Recieved slack event: %s", event.dict())
-
-#     if event.type == "url_verification":
-#         return EventHandlingResponse(challenge=event.challenge)
-
-#     if event.type == "event_callback" and event.event:
-#         try:
-#             # TODO: process in the background as per slack guidelines
-#             message_type = event.event.get("subtype")
-#             if message_type == "message_changed":
-#                 message = event.event["message"]
-#             else:
-#                 message = event.event
-
-#             channel_id = event.event["channel"]
-#             thread_ts = message.get("thread_ts")
-#             slack_client = get_client()
-#             doc = thread_to_doc(
-#                 channel=get_channel_info(client=slack_client, channel_id=channel_id),
-#                 thread=get_thread(
-#                     client=slack_client, channel_id=channel_id, thread_id=thread_ts
-#                 )
-#                 if thread_ts
-#                 else [message],
-#             )
-#             if doc is None:
-#                 logger.info("Message was determined to not be indexable")
-#                 return EventHandlingResponse(challenge=None)
-
-#             build_indexing_pipeline()([doc])
-#         except Exception:
-#             logger.exception("Failed to process slack message")
-#         return EventHandlingResponse(challenge=None)
-
-#     logger.error("Unsupported event type: %s", event.type)
-#     return EventHandlingResponse(challenge=None)

@@ -17,6 +17,7 @@ from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import MessageType
 from danswer.configs.constants import QAFeedbackType
 from danswer.configs.constants import SearchFeedbackType
+from danswer.connectors.models import DocumentBase
 from danswer.connectors.models import InputType
 from danswer.danswerbot.slack.config import VALID_SLACK_FILTERS
 from danswer.db.models import AllowedAnswerFilters

@@ -27,6 +28,7 @@ from danswer.db.models import DocumentSet as DocumentSetDBModel
 from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
 from danswer.db.models import TaskStatus
+from danswer.direct_qa.interfaces import DanswerAnswer
 from danswer.direct_qa.interfaces import DanswerQuote
 from danswer.search.models import BaseFilters
 from danswer.search.models import QueryFlow

@@ -198,32 +200,10 @@ class SearchFeedbackRequest(BaseModel):
     search_feedback: SearchFeedbackType


-class QueryValidationResponse(BaseModel):
-    reasoning: str
-    answerable: bool
-
-
 class RetrievalDocs(BaseModel):
     top_documents: list[SearchDoc]


-class SearchResponse(RetrievalDocs):
-    query_event_id: int
-    source_type: list[DocumentSource] | None
-    time_cutoff: datetime | None
-    favor_recent: bool
-
-
-class QAResponse(SearchResponse):
-    answer: str | None  # DanswerAnswer
-    quotes: list[DanswerQuote] | None
-    predicted_flow: QueryFlow
-    predicted_search: SearchType
-    eval_res_valid: bool | None = None
-    llm_chunks_indices: list[int] | None = None
-    error_msg: str | None = None
-
-
 # First chunk of info for streaming QA
 class QADocsResponse(RetrievalDocs):
     predicted_flow: QueryFlow

@@ -311,6 +291,36 @@ class ChatSessionDetailResponse(BaseModel):
     messages: list[ChatMessageDetail]


+class QueryValidationResponse(BaseModel):
+    reasoning: str
+    answerable: bool
+
+
+class AdminSearchRequest(BaseModel):
+    query: str
+    filters: BaseFilters
+
+
+class AdminSearchResponse(BaseModel):
+    documents: list[SearchDoc]
+
+
+class SearchResponse(RetrievalDocs):
+    query_event_id: int
+    source_type: list[DocumentSource] | None
+    time_cutoff: datetime | None
+    favor_recent: bool
+
+
+class QAResponse(SearchResponse, DanswerAnswer):
+    quotes: list[DanswerQuote] | None
+    predicted_flow: QueryFlow
+    predicted_search: SearchType
+    eval_res_valid: bool | None = None
+    llm_chunks_indices: list[int] | None = None
+    error_msg: str | None = None
+
+
 class UserByEmail(BaseModel):
     user_email: str

@@ -392,7 +402,8 @@ class RunConnectorRequest(BaseModel):

 class CredentialBase(BaseModel):
     credential_json: dict[str, Any]
-    is_admin: bool
+    # if `true`, then all Admins will have access to the credential
+    admin_public: bool


 class CredentialSnapshot(CredentialBase):

@@ -409,7 +420,7 @@ class CredentialSnapshot(CredentialBase):
             if MASK_CREDENTIAL_PREFIX
             else credential.credential_json,
             user_id=credential.user_id,
-            is_admin=credential.is_admin,
+            admin_public=credential.admin_public,
             time_created=credential.time_created,
             time_updated=credential.time_updated,
         )

@@ -510,6 +521,20 @@ class DocumentSet(BaseModel):
         )


+class IngestionDocument(BaseModel):
+    document: DocumentBase
+    connector_id: int | None = None  # Takes precedence over the name
+    connector_name: str | None = None
+    credential_id: int | None = None
+    create_connector: bool = False  # Currently not allowed
+    public_doc: bool = True  # To attach to the cc_pair, currently unused
+
+
+class IngestionResult(BaseModel):
+    document_id: str
+    already_existed: bool
+
+
 class SlackBotTokens(BaseModel):
     bot_token: str
     app_token: str

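Sketch of the minimal model-level payload these classes accept (illustrative values, not part of this commit):

from danswer.connectors.models import DocumentBase
from danswer.connectors.models import Section
from danswer.server.models import IngestionDocument

req = IngestionDocument(
    document=DocumentBase(
        sections=[Section(text="Hello world", link=None)],
        semantic_identifier="hello-doc",
        metadata={},
    )
)
# connector_id / connector_name / credential_id default to None, so the endpoint
# resolves the default ingestion connector (id 0) and public credential (id 0).
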
@@ -2,7 +2,6 @@ from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
 from sqlalchemy.orm import Session

 from danswer.auth.users import current_admin_user

@@ -17,7 +16,6 @@ from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.vespa.index import VespaIndex
 from danswer.search.access_filters import build_access_filters_for_user
 from danswer.search.danswer_helper import recommend_search_flow
-from danswer.search.models import BaseFilters
 from danswer.search.models import IndexFilters
 from danswer.search.search_runner import chunks_to_search_docs
 from danswer.search.search_runner import danswer_search

@@ -25,6 +23,8 @@ from danswer.secondary_llm_flows.query_validation import get_query_answerability
 from danswer.secondary_llm_flows.query_validation import stream_query_answerability
 from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
 from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
+from danswer.server.models import AdminSearchRequest
+from danswer.server.models import AdminSearchResponse
 from danswer.server.models import HelperResponse
 from danswer.server.models import QAFeedbackRequest
 from danswer.server.models import QAResponse

@@ -45,15 +45,6 @@ router = APIRouter()
 """Admin-only search endpoints"""


-class AdminSearchRequest(BaseModel):
-    query: str
-    filters: BaseFilters
-
-
-class AdminSearchResponse(BaseModel):
-    documents: list[SearchDoc]
-
-
 @router.post("/admin/search")
 def admin_search(
     question: AdminSearchRequest,

@@ -1,5 +1,11 @@
 import json
 import re
+from urllib.parse import quote
+
+
+def make_url_compatible(s: str) -> str:
+    s_with_underscores = s.replace(" ", "_")
+    return quote(s_with_underscores, safe="")


 def has_unescaped_quote(s: str) -> bool:

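Behavior sketch for the new helper (follows from urllib.parse.quote with safe=""): spaces become underscores, everything else that is not URL-safe is percent-encoded.

from danswer.utils.text_processing import make_url_compatible

assert make_url_compatible("Q3 Finance Report") == "Q3_Finance_Report"
assert make_url_compatible("a/b?c") == "a%2Fb%3Fc"
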
@@ -2,6 +2,9 @@ import os
 import sys

 import psycopg2
+from sqlalchemy.orm import Session
+
+from danswer.db.engine import get_sqlalchemy_engine

 # makes it so `PYTHONPATH=.` is not required when running this script
 parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

@@ -61,7 +64,8 @@ def wipe_all_rows(database: str) -> None:
 if __name__ == "__main__":
     print("Cleaning up all Danswer tables")
     wipe_all_rows(POSTGRES_DB)
-    create_initial_public_credential()
+    with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
+        create_initial_public_credential(db_session)
     print("To keep data consistent, it's best to wipe the document index as well.")
     print(
         "To be safe, it's best to restart the Danswer services (API Server and Background Tasks"

@@ -321,7 +321,7 @@ const Main = () => {
     | Credential<GoogleDriveCredentialJson>
     | undefined = credentialsData.find(
     (credential) =>
-      credential.credential_json?.google_drive_tokens && credential.is_admin
+      credential.credential_json?.google_drive_tokens && credential.admin_public
   );
   const googleDriveServiceAccountCredential:
     | Credential<GoogleDriveServiceAccountCredentialJson>

@@ -57,7 +57,7 @@ export function CredentialForm<T extends Yup.AnyObject>({
         formikHelpers.setSubmitting(true);
         submitCredential<T>({
           credential_json: values,
-          is_admin: true,
+          admin_public: true,
         }).then(({ message, isSuccess }) => {
           setPopup({ message, type: isSuccess ? "success" : "error" });
           formikHelpers.setSubmitting(false);

@@ -11,7 +11,7 @@ export const setupGoogleDriveOAuth = async ({
       "Content-Type": "application/json",
     },
     body: JSON.stringify({
-      is_admin: isAdmin,
+      admin_public: isAdmin,
       credential_json: {},
     }),
   });

@@ -168,7 +168,7 @@ export interface ConnectorIndexingStatus<
 // CREDENTIALS
 export interface CredentialBase<T> {
   credential_json: T;
-  is_admin: boolean;
+  admin_public: boolean;
 }

 export interface Credential<T> extends CredentialBase<T> {
