Consolidate versions for easier extension (#495)

This commit is contained in:
Yuhong Sun
2023-10-01 23:49:38 -07:00
committed by GitHub
parent a808c733b8
commit 351475de28
29 changed files with 530 additions and 278 deletions

View File

@@ -0,0 +1,47 @@
"""Add SAML Accounts
Revision ID: ae62505e3acc
Revises: 7da543f5672f
Create Date: 2023-09-26 16:19:30.933183
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ae62505e3acc"
down_revision = "7da543f5672f"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"saml",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column("encrypted_cookie", sa.Text(), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("encrypted_cookie"),
sa.UniqueConstraint("user_id"),
)
def downgrade() -> None:
op.drop_table("saml")

View File

@@ -1,4 +1,3 @@
import contextlib
import os
import smtplib
import uuid
@@ -6,10 +5,13 @@ from collections.abc import AsyncGenerator
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Optional
from typing import Tuple
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi_users import BaseUserManager
from fastapi_users import FastAPIUsers
@@ -18,21 +20,17 @@ from fastapi_users import schemas
from fastapi_users import UUIDIDMixin
from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.db import SQLAlchemyUserDatabase
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import OpenID
from pydantic import EmailStr
from fastapi_users.openapi import OpenAPIResponseType
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import ENABLE_OAUTH
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import OAUTH_TYPE
from danswer.configs.app_configs import OPENID_CONFIG_URL
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
@@ -42,23 +40,32 @@ from danswer.configs.app_configs import SMTP_SERVER
from danswer.configs.app_configs import SMTP_USER
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_async_session
from danswer.db.engine import get_session
from danswer.db.models import AccessToken
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
FAKE_USER_EMAIL = "fakeuser@fakedanswermail.com"
FAKE_USER_PASS = "foobar"
USER_WHITELIST_FILE = "/home/danswer_whitelist.txt"
_user_whitelist: list[str] | None = None
def verify_auth_setting() -> None:
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
raise ValueError(
"User must choose a valid user authentication method: "
"disabled, basic, or google_oauth"
)
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
def get_user_whitelist() -> list[str]:
global _user_whitelist
if _user_whitelist is None:
@@ -204,53 +211,84 @@ auth_backend = AuthenticationBackend(
get_strategy=get_database_strategy,
)
oauth_client = None # type: GoogleOAuth2 | OpenID | None
if ENABLE_OAUTH:
if OAUTH_TYPE == "google":
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
elif OAUTH_TYPE == "openid":
oauth_client = OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL)
else:
raise AssertionError(f"Invalid OAUTH type {OAUTH_TYPE}")
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
def get_logout_router(
self,
backend: AuthenticationBackend,
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
) -> APIRouter:
"""
Provide a router for logout only for OAuth/OIDC Flows.
This way the login router does not need to be included
"""
router = APIRouter()
get_current_user_token = self.authenticator.current_user_token(
active=True, verified=requires_verification
)
logout_responses: OpenAPIResponseType = {
**{
status.HTTP_401_UNAUTHORIZED: {
"description": "Missing token or inactive user."
}
},
**backend.transport.get_openapi_logout_responses_success(),
}
@router.post(
"/logout", name=f"auth:{backend.name}.logout", responses=logout_responses
)
async def logout(
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
) -> Response:
user, token = user_token
return await backend.logout(strategy, user, token)
return router
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
# Currently unused, maybe useful later
async def create_get_fake_user() -> User:
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
) # type:ignore
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
logger.info("Creating fake user due to Auth being turned off")
async with get_async_session_context() as session:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
user = await user_manager.get_by_email(FAKE_USER_EMAIL)
if user:
return user
user = await user_manager.create(
UserCreate(email=EmailStr(FAKE_USER_EMAIL), password=FAKE_USER_PASS)
)
logger.info("Created fake user.")
return user
current_active_user = fastapi_users.current_user(
active=True, verified=REQUIRE_EMAIL_VERIFICATION, optional=DISABLE_AUTH
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
get_user_manager, [auth_backend]
)
async def current_user(user: User = Depends(current_active_user)) -> User | None:
if DISABLE_AUTH:
optional_valid_user = fastapi_users.current_user(
active=True, verified=REQUIRE_EMAIL_VERIFICATION, optional=True
)
async def double_check_user(
request: Request,
user: User | None,
db_session: Session,
optional: bool = DISABLE_AUTH,
) -> User | None:
if optional:
return None
if user is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated.",
)
return user
async def current_admin_user(user: User = Depends(current_user)) -> User | None:
async def current_user(
request: Request,
user: User | None = Depends(optional_valid_user),
db_session: Session = Depends(get_session),
) -> User | None:
double_check_user = fetch_versioned_implementation(
"danswer.auth.users", "double_check_user"
)
user = await double_check_user(request, user, db_session)
return user
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
if DISABLE_AUTH:
return None

View File

@@ -1,5 +1,6 @@
import os
from danswer.configs.constants import AuthType
from danswer.configs.constants import DocumentIndexType
#####
@@ -14,32 +15,27 @@ APP_PORT = 8080
#####
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day
# DISABLE_GENERATIVE_AI will turn of the question answering part of Danswer. Use this
# if you want to use Danswer as a search engine only and/or you are not comfortable sending
# anything to OpenAI. TODO: update this message once we support Azure / open source generative models.
# DISABLE_GENERATIVE_AI will turn of the question answering part of Danswer.
# Use this if you want to use Danswer as a search engine only without the LLM capabilities
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
#####
# Web Configs
#####
# WEB_DOMAIN is used to set the redirect_uri when doing OAuth with Google
# TODO: investigate if this can be done cleaner by overwriting the redirect_uri
# on the frontend and just sending a dummy value (or completely generating the URL)
# on the frontend
WEB_DOMAIN = os.environ.get("WEB_DOMAIN", "http://localhost:3000")
# WEB_DOMAIN is used to set the redirect_uri after login flows
WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
#####
# Auth Configs
#####
DISABLE_AUTH = os.environ.get("DISABLE_AUTH", "").lower() == "true"
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
# Turn off mask if admin users should see full credentials for data connectors.
MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
SMTP_SERVER = os.environ.get("SMTP_SERVER", "smtp.gmail.com")
SMTP_PORT = int(os.environ.get("SMTP_PORT", "587"))
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
SECRET = os.environ.get("SECRET", "")
SESSION_EXPIRE_TIME_SECONDS = int(
@@ -62,18 +58,17 @@ VALID_EMAIL_DOMAINS = (
)
# OAuth Login Flow
ENABLE_OAUTH = os.environ.get("ENABLE_OAUTH", "").lower() != "false"
OAUTH_TYPE = os.environ.get("OAUTH_TYPE", "google").lower()
OAUTH_CLIENT_ID = os.environ.get(
"OAUTH_CLIENT_ID", os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "")
)
OAUTH_CLIENT_SECRET = os.environ.get(
"OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "")
)
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID") or ""
GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET") or ""
# The following Basic Auth configs are not supported by the frontend UI
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
)
SMTP_SERVER = os.environ.get("SMTP_SERVER", "smtp.gmail.com")
SMTP_PORT = int(os.environ.get("SMTP_PORT", "587"))
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
#####
@@ -105,7 +100,7 @@ TYPESENSE_API_KEY = os.environ.get("TYPESENSE_API_KEY", "")
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
INDEX_BATCH_SIZE = 16
# below are intended to match the env variables names used by the official postgres docker image
# Below are intended to match the env variables names used by the official postgres docker image
# https://hub.docker.com/_/postgres
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD") or "password"

View File

@@ -27,6 +27,7 @@ BOOST = "boost"
SCORE = "score"
ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"
# Prompt building constants:
GENERAL_SEP_PAT = "\n-----\n"
@@ -80,6 +81,14 @@ class DanswerGenAIModel(str, Enum):
TRANSFORMERS = "transformers"
class AuthType(str, Enum):
DISABLED = "disabled"
BASIC = "basic"
GOOGLE_OAUTH = "google_oauth"
OIDC = "oidc"
SAML = "saml"
class ModelHostType(str, Enum):
"""For GenAI models interfaced via requests, different services have different
expectations for what fields are included in the request"""

View File

@@ -20,9 +20,9 @@
</nodes>
<tuning>
<resource-limits>
<!-- Default is 75% but this should be increased for Dockerized deployments -->
<!-- Default is 75% but this can be increased for Dockerized deployments -->
<!-- https://docs.vespa.ai/en/operations/feed-block.html -->
<disk>0.98</disk>
<disk>0.75</disk>
</resource-limits>
</tuning>
<config name="vespa.config.search.summary.juniperrc">

View File

@@ -5,25 +5,24 @@ from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from httpx_oauth.clients.google import GoogleOAuth2
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
from danswer.auth.users import oauth_client
from danswer.chat.personas import load_personas_from_yaml
from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import ENABLE_OAUTH
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import OAUTH_TYPE
from danswer.configs.app_configs import OPENID_CONFIG_URL
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
@@ -39,13 +38,14 @@ from danswer.server.chat_backend import router as chat_router
from danswer.server.credential import router as credential_router
from danswer.server.document_set import router as document_set_router
from danswer.server.event_loading import router as event_processing_router
from danswer.server.health import router as health_router
from danswer.server.manage import router as admin_router
from danswer.server.search_backend import router as backend_router
from danswer.server.slack_bot_management import router as slack_bot_management_router
from danswer.server.state import router as state_router
from danswer.server.users import router as user_router
from danswer.utils.acl import set_acl_for_vespa
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
@@ -82,53 +82,41 @@ def get_application() -> FastAPI:
application.include_router(credential_router)
application.include_router(document_set_router)
application.include_router(slack_bot_management_router)
application.include_router(health_router)
application.include_router(state_router)
application.include_router(
fastapi_users.get_auth_router(auth_backend),
prefix="/auth/database",
tags=["auth"],
)
application.include_router(
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)
application.include_router(
fastapi_users.get_reset_password_router(),
prefix="/auth",
tags=["auth"],
)
application.include_router(
fastapi_users.get_verify_router(UserRead),
prefix="/auth",
tags=["auth"],
)
application.include_router(
fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users",
tags=["users"],
)
if ENABLE_OAUTH:
if oauth_client is None:
raise RuntimeError("OAuth is enabled but no OAuth client is configured")
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step
pass
if OAUTH_TYPE == "google":
# special case for google
application.include_router(
fastapi_users.get_oauth_router(
oauth_client,
auth_backend,
SECRET,
associate_by_email=True,
is_verified_by_default=True,
# points the user back to the login page, where we will call the
# /auth/google/callback endpoint + redirect them to the main app
redirect_url=f"{WEB_DOMAIN}/auth/google/callback",
),
prefix="/auth/google",
tags=["auth"],
)
elif AUTH_TYPE == AuthType.BASIC:
application.include_router(
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
application.include_router(
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)
application.include_router(
fastapi_users.get_reset_password_router(),
prefix="/auth",
tags=["auth"],
)
application.include_router(
fastapi_users.get_verify_router(UserRead),
prefix="/auth",
tags=["auth"],
)
application.include_router(
fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users",
tags=["users"],
)
elif AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET)
application.include_router(
fastapi_users.get_oauth_router(
oauth_client,
@@ -136,16 +124,16 @@ def get_application() -> FastAPI:
SECRET,
associate_by_email=True,
is_verified_by_default=True,
# points the user back to the login page, where we will call the
# /auth/oauth/callback endpoint + redirect them to the main app
# points the user back to the login page
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)
# need basic auth router for `logout` endpoint
application.include_router(
fastapi_users.get_oauth_associate_router(oauth_client, UserRead, SECRET),
prefix="/auth/associate/oauth",
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
@@ -170,30 +158,25 @@ def get_application() -> FastAPI:
if API_TYPE_OPENAI == "azure":
logger.info(f"Using Azure OpenAI with Endpoint: {API_BASE_OPENAI}")
auth_status = "off" if DISABLE_AUTH else "on"
logger.info(f"User Authentication is turned {auth_status}")
verify_auth = fetch_versioned_implementation(
"danswer.auth.users", "verify_auth_setting"
)
# Will throw exception if an issue is found
verify_auth()
if not DISABLE_AUTH:
if not ENABLE_OAUTH:
logger.debug("OAuth is turned off")
else:
if not OAUTH_CLIENT_ID:
logger.warning("OAuth is turned on but OAUTH_CLIENT_ID is empty")
if not OAUTH_CLIENT_SECRET:
logger.warning(
"OAuth is turned on but OAUTH_CLIENT_SECRET is empty"
)
if OAUTH_TYPE == "openid" and not OPENID_CONFIG_URL:
logger.warning("OpenID is turned on but OPENID_CONFIG_URL is emtpy")
else:
logger.debug("OAuth is turned on")
if DISABLE_AUTH:
logger.info("User Authentication is turned off.")
if GOOGLE_OAUTH_CLIENT_ID and GOOGLE_OAUTH_CLIENT_SECRET:
logger.info("Found both OAuth Client ID and secret configured.")
if SKIP_RERANKING:
logger.info("Reranking step of search flow is disabled")
logger.info(f'Using Embedding model: "{DOCUMENT_ENCODER_MODEL}"')
logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"')
logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"')
if ASYM_QUERY_PREFIX or ASYM_PASSAGE_PREFIX:
logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"')
logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"')
logger.info("Warming up local NLP models.")
warm_up_models()
@@ -201,9 +184,9 @@ def get_application() -> FastAPI:
qa_model.warm_up_model()
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords")
nltk.download("wordnet")
nltk.download("punkt")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
nltk.download("punkt", quiet=True)
logger.info("Verifying public credential exists.")
create_initial_public_credential()
@@ -219,17 +202,18 @@ def get_application() -> FastAPI:
# does nothing if this has been successfully run before
set_acl_for_vespa(should_check_if_already_done=True)
application.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Change this to the list of allowed origins if needed
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
return application
app = get_application()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Change this to the list of allowed origins if needed
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
if __name__ == "__main__":

View File

@@ -9,6 +9,10 @@ from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
from danswer.llm.build import get_default_llm
from danswer.server.models import QueryValidationResponse
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
logger = setup_logger()
QUERY_PAT = "QUERY: "
REASONING_PAT = "THOUGHT: "
@@ -94,37 +98,44 @@ def get_query_answerability(user_query: str) -> tuple[str, bool]:
def stream_query_answerability(user_query: str) -> Iterator[str]:
messages = get_query_validation_messages(user_query)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
tokens = get_default_llm().stream(filled_llm_prompt)
reasoning_pat_found = False
model_output = ""
hold_answerable = ""
for token in tokens:
model_output = model_output + token
try:
tokens = get_default_llm().stream(filled_llm_prompt)
reasoning_pat_found = False
model_output = ""
hold_answerable = ""
for token in tokens:
model_output = model_output + token
if ANSWERABLE_PAT in model_output:
continue
if not reasoning_pat_found and REASONING_PAT in model_output:
reasoning_pat_found = True
reason_ind = model_output.find(REASONING_PAT)
remaining = model_output[reason_ind + len(REASONING_PAT) :]
if remaining:
yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining)))
continue
if reasoning_pat_found:
hold_answerable = hold_answerable + token
if hold_answerable == ANSWERABLE_PAT[: len(hold_answerable)]:
if ANSWERABLE_PAT in model_output:
continue
yield get_json_line(
asdict(DanswerAnswerPiece(answer_piece=hold_answerable))
)
hold_answerable = ""
reasoning = extract_answerability_reasoning(model_output)
answerable = extract_answerability_bool(model_output)
if not reasoning_pat_found and REASONING_PAT in model_output:
reasoning_pat_found = True
reason_ind = model_output.find(REASONING_PAT)
remaining = model_output[reason_ind + len(REASONING_PAT) :]
if remaining:
yield get_json_line(
asdict(DanswerAnswerPiece(answer_piece=remaining))
)
continue
yield get_json_line(
QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
)
if reasoning_pat_found:
hold_answerable = hold_answerable + token
if hold_answerable == ANSWERABLE_PAT[: len(hold_answerable)]:
continue
yield get_json_line(
asdict(DanswerAnswerPiece(answer_piece=hold_answerable))
)
hold_answerable = ""
reasoning = extract_answerability_reasoning(model_output)
answerable = extract_answerability_bool(model_output)
yield get_json_line(
QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
)
except Exception as e:
# exception is logged in the answer_question method, no need to re-log
yield get_json_line({"error": str(e)})
logger.exception("Failed to validate Query")
return

View File

@@ -1,11 +0,0 @@
from fastapi import APIRouter
from danswer.server.models import StatusResponse
router = APIRouter()
@router.get("/health", response_model=StatusResponse)
def healthcheck() -> StatusResponse:
return StatusResponse(success=True, message="ok")

View File

@@ -1,7 +1,6 @@
from datetime import datetime
from typing import Any
from typing import Generic
from typing import Literal
from typing import Optional
from typing import TypeVar
from uuid import UUID
@@ -9,7 +8,9 @@ from uuid import UUID
from pydantic import BaseModel
from pydantic.generics import GenericModel
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import AuthType
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.configs.constants import QAFeedbackType
@@ -38,6 +39,10 @@ class StatusResponse(GenericModel, Generic[DataT]):
data: Optional[DataT] = None
class AuthTypeResponse(BaseModel):
auth_type: AuthType
class DataRequest(BaseModel):
data: str
@@ -47,6 +52,15 @@ class HelperResponse(BaseModel):
details: list[str] | None = None
class UserInfo(BaseModel):
id: str
email: str
is_active: bool
is_superuser: bool
is_verified: bool
role: UserRole
class GoogleAppWebCredentials(BaseModel):
client_id: str
project_id: str
@@ -84,10 +98,6 @@ class FileUploadResponse(BaseModel):
file_paths: list[str]
class HealthCheckResponse(BaseModel):
status: Literal["ok"]
class ObjectCreationIdResponse(BaseModel):
id: int | str

View File

@@ -0,0 +1,18 @@
from fastapi import APIRouter
from danswer.configs.app_configs import AUTH_TYPE
from danswer.server.models import AuthTypeResponse
from danswer.server.models import StatusResponse
router = APIRouter()
@router.get("/health")
def healthcheck() -> StatusResponse:
return StatusResponse(success=True, message="ok")
@router.get("/auth/type")
def get_auth_type() -> AuthTypeResponse:
return AuthTypeResponse(auth_type=AUTH_TYPE)

View File

@@ -9,12 +9,13 @@ from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserRole
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.models import User
from danswer.db.users import list_users
from danswer.server.models import UserByEmail
from danswer.server.models import UserInfo
router = APIRouter(prefix="/manage")
@@ -43,3 +44,18 @@ def list_all_users(
) -> list[UserRead]:
users = list_users(db_session)
return [UserRead.from_orm(user) for user in users]
@router.get("/me")
def verify_user_logged_in(user: User | None = Depends(current_user)) -> UserInfo:
if user is None:
raise HTTPException(status_code=401, detail="User Not Authenticated")
return UserInfo(
id=str(user.id),
email=user.email,
is_active=user.is_active,
is_superuser=user.is_superuser,
is_verified=user.is_verified,
role=user.role,
)

View File

@@ -0,0 +1,21 @@
import importlib
from typing import Any
class DanswerVersion:
def __init__(self) -> None:
self._is_ee = False
def set_ee(self) -> None:
self._is_ee = True
def get_is_ee_version(self) -> bool:
return self._is_ee
global_version = DanswerVersion()
def fetch_versioned_implementation(module: str, attribute: str) -> Any:
module_full = f"ee.{module}" if global_version.get_is_ee_version() else module
return getattr(importlib.import_module(module_full), attribute)