squash - need to update sql alchemy engine

This commit is contained in:
pablodanswer 2024-09-22 16:45:06 -07:00
parent 5b220ac7b1
commit ae3218f941
14 changed files with 127 additions and 144 deletions

View File

@ -1,3 +1,4 @@
from danswer.configs.app_configs import SECRET_JWT_KEY
from datetime import timedelta
import contextlib
import smtplib
@ -250,11 +251,8 @@ async def create_user_session(user: User, tenant_id: str) -> str:
"tenant_id": tenant_id,
"exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS)
}
# Encode the token
token = jwt.encode(payload, "JWT_SECRET_KEY", algorithm="HS256")
token = jwt.encode(payload, SECRET_JWT_KEY, algorithm="HS256")
return token

View File

@ -37,6 +37,7 @@ DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "
WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY") or "JWT_SECRET_KEY"
#####
# Auth Configs
#####

View File

@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.app_configs import DEFAULT_SCHEMA
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
@ -138,57 +139,22 @@ def init_sqlalchemy_engine(app_name: str) -> None:
global POSTGRES_APP_NAME
POSTGRES_APP_NAME = app_name
_engines = {}
# TODO validate that this is best practice
def get_sqlalchemy_engine(schema: str = DEFAULT_SCHEMA) -> Engine:
if schema in _engines:
return _engines[schema]
_engine = None
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
)
engine = create_engine(
connection_string,
pool_size=40,
max_overflow=10,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
@event.listens_for(engine, "connect")
def set_search_path(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute(f'SET search_path TO "{schema}"')
cursor.close()
dbapi_connection.commit()
_engines[schema] = engine
return engine
# def get_sqlalchemy_engine(schema: str = DEFAULT_SCHEMA) -> Engine:
# global _SYNC_ENGINE
# if _SYNC_ENGINE is None:
# connection_string = build_connection_string(
# db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
# )
# _SYNC_ENGINE = create_engine(
# connection_string,
# pool_size=40,
# max_overflow=10,
# pool_pre_ping=POSTGRES_POOL_PRE_PING,
# pool_recycle=POSTGRES_POOL_RECYCLE,
# )
# @event.listens_for(_SYNC_ENGINE, "connect")
# def set_search_path(dbapi_connection, connection_record):
# cursor = dbapi_connection.cursor()
# cursor.execute(f"SET search_path TO {schema}")
# cursor.close()
# dbapi_connection.commit()
# return _SYNC_ENGINE
def get_sqlalchemy_engine() -> Engine:
global _engine
if _engine is None:
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
)
_engine = create_engine(
connection_string,
pool_size=40,
max_overflow=10,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
return _engine
def get_sqlalchemy_async_engine() -> AsyncEngine:
@ -212,21 +178,19 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
def get_session_context_manager() -> ContextManager[Session]:
return contextlib.contextmanager(get_session)()
def get_current_tenant_id(request: Request) -> str:
if not MULTI_TENANT:
return DEFAULT_SCHEMA
token = request.cookies.get("fastapiusersauth")
print(request.cookies)
token = request.cookies.get("tenant_details")
if not token:
logger.warning("No token found in cookies")
raise HTTPException(status_code=401, detail="Authentication required")
try:
logger.info(f"Attempting to decode token: {token[:10]}...") # Log only first 10 characters for security
payload = jwt.decode(token, "JWT_SECRET_KEY", algorithms=["HS256"])
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
logger.info(f"Decoded payload: {payload}")
tenant_id = payload.get("tenant_id")
if not tenant_id:
@ -246,29 +210,8 @@ def get_current_tenant_id(request: Request) -> str:
def get_session(tenant_id: str | None= Depends(get_current_tenant_id)) -> Generator[Session, None, None]:
print("")
print("\n\n\n\n")
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
session.execute(text(f'SET search_path TO "{tenant_id}"'))
print("SEARCH PATH IS ", tenant_id)
with Session(get_sqlalchemy_engine(schema=tenant_id), expire_on_commit=False) as session:
yield session
session.execute(text('SET search_path TO "public"'))
# Logic to create or retrieve a database session for the given tenant_id
# def get_session(schema: str = DEFAULT_SCHEMA) -> Generator[Session, None, None]:
# # The line below was added to monitor the latency caused by Postgres connections
# # during API calls.
# # with tracer.trace("db.get_session"):
# with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
# session.execute(text(f"SET search_path TO {schema}"))
# yield session
# session.execute(text("SET search_path TO public"))
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async with AsyncSession(

View File

@ -139,6 +139,7 @@ def stream_answer_objects(
rephrased_query = query_req.query_override or thread_based_query_rephrase(
user_query=query_msg.message,
history_str=history_str,
db_session=db_session
)
# Given back ahead of the documents for latency reasons

View File

@ -237,7 +237,7 @@ def retrieve_chunks(
# Currently only uses query expansion on multilingual use cases
query_rephrases = multilingual_query_expansion(
query.query, multilingual_expansion
query.query, multilingual_expansion, db_session=db_session
)
# Just to be extra sure, add the original query.
query_rephrases.append(query.query)

View File

@ -15,11 +15,11 @@ from danswer.prompts.miscellaneous_prompts import LANGUAGE_REPHRASE_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import count_punctuation
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from sqlalchemy.orm import Session
logger = setup_logger()
def llm_multilingual_query_expansion(query: str, language: str) -> str:
def llm_multilingual_query_expansion(query: str, language: str, db_session: Session) -> str:
def _get_rephrase_messages() -> list[dict[str, str]]:
messages = [
{
@ -51,12 +51,13 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
def multilingual_query_expansion(
query: str,
expansion_languages: list[str],
db_session: Session,
use_threads: bool = True,
) -> list[str]:
languages = [language.strip() for language in expansion_languages]
if use_threads:
functions_with_args: list[tuple[Callable, tuple]] = [
(llm_multilingual_query_expansion, (query, language))
(llm_multilingual_query_expansion, (query, language, db_session))
for language in languages
]
@ -65,7 +66,7 @@ def multilingual_query_expansion(
else:
query_rephrases = [
llm_multilingual_query_expansion(query, language) for language in languages
llm_multilingual_query_expansion(query, language, db_session) for language in languages
]
return query_rephrases
@ -134,9 +135,10 @@ def history_based_query_rephrase(
def thread_based_query_rephrase(
user_query: str,
history_str: str,
db_session: Session,
llm: LLM | None = None,
size_heuristic: int = 200,
punctuation_heuristic: int = 10,
punctuation_heuristic: int = 10
) -> str:
if not history_str:
return user_query

View File

@ -1,4 +1,5 @@
import re
from sqlalchemy.orm import Session
from collections.abc import Iterator
from danswer.chat.models import DanswerAnswerPiece
@ -46,7 +47,8 @@ def extract_answerability_bool(model_raw: str) -> bool:
def get_query_answerability(
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY
db_session: Session,
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY,
) -> tuple[str, bool]:
if skip_check:
return "Query Answerability Evaluation feature is turned off", True
@ -67,7 +69,7 @@ def get_query_answerability(
def stream_query_answerability(
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY
db_session: Session, user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY,
) -> Iterator[str]:
if skip_check:
yield get_json_line(

View File

@ -117,13 +117,13 @@ def get_tags(
@basic_router.post("/query-validation")
def query_validation(
simple_query: SimpleQueryRequest, _: User = Depends(current_user)
simple_query: SimpleQueryRequest, _: User = Depends(current_user), db_session: Session = Depends(get_session)
) -> QueryValidationResponse:
# Note if weak model prompt is chosen, this check does not occur and will simply return that
# the query is valid, this is because weaker models cannot really handle this task well.
# Additionally, some weak model servers cannot handle concurrent inferences.
logger.notice(f"Validating query: {simple_query.query}")
reasoning, answerable = get_query_answerability(simple_query.query)
reasoning, answerable = get_query_answerability(db_session=db_session, user_query=simple_query.query)
return QueryValidationResponse(reasoning=reasoning, answerable=answerable)
@ -226,14 +226,14 @@ def get_search_session(
# No search responses are answered with a conversational generative AI response
@basic_router.post("/stream-query-validation")
def stream_query_validation(
simple_query: SimpleQueryRequest, _: User = Depends(current_user)
simple_query: SimpleQueryRequest, _: User = Depends(current_user), db_session: Session = Depends(get_session)
) -> StreamingResponse:
# Note if weak model prompt is chosen, this check does not occur and will simply return that
# the query is valid, this is because weaker models cannot really handle this task well.
# Additionally, some weak model servers cannot handle concurrent inferences.
logger.notice(f"Validating query: {simple_query.query}")
return StreamingResponse(
stream_query_answerability(simple_query.query), media_type="application/json"
stream_query_answerability(user_query=simple_query.query, db_session=db_session), media_type="application/json"
)

View File

@ -105,8 +105,9 @@ async def create_tenant_schema(tenant_id: str) -> None:
await session.commit()
logger.info(f"Schema created for tenant: {tenant_id}")
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
try:
db_session.execute(text(f'Set search_path to "{tenant_id}"'))
setup_postgres(db_session)
except SQLAlchemyError as e:
logger.error(f"Error while loading chat YAMLs for tenant {tenant_id}: {str(e)}")
@ -114,8 +115,6 @@ async def create_tenant_schema(tenant_id: str) -> None:
finally:
db_session.execute(text('SET search_path TO "public"'))
# db_session.execute(text(f'SET search_path TO "public"'))
logger.info(f"Migrations completed for tenant: {tenant_id}")
@ -153,25 +152,21 @@ async def sso_callback(
logger.info(f"Session token created for user: {user.email}")
# Set the session cookie with proper flags
response = JSONResponse(content={"message": "Authentication successful"})
response.set_cookie(
key="fastapiusersauth",
key="tenant_details",
value=session_token,
max_age=SESSION_EXPIRE_TIME_SECONDS,
expires=SESSION_EXPIRE_TIME_SECONDS,
path="/",
domain=WEB_DOMAIN.split("://")[-1],
secure=True,
secure=False, # Set to True in production with HTTPS
httponly=True,
samesite="lax",
)
logger.info("Session cookie set")
logger.info("SSO callback completed successfully")
return JSONResponse(
content={"message": "Authentication successful"},
status_code=200
)
return response
# @basic_router.post("/auth/sso-callback")

View File

@ -223,6 +223,7 @@ def handle_send_message_simple_with_history(
rephrased_query = req.query_override or thread_based_query_rephrase(
user_query=query,
history_str=history_str,
db_session=db_session
)
full_chat_msg_info = CreateChatMessageRequest(

View File

@ -47,3 +47,31 @@ export interface CombinedSettings {
isMobile?: boolean;
webVersion: string | null;
}
export const defaultCombinedSettings: CombinedSettings = {
settings: {
chat_page_enabled: true,
search_page_enabled: true,
default_page: "search",
maximum_chat_retention_days: 30,
notifications: [],
needs_reindexing: false,
},
enterpriseSettings: {
application_name: "Danswer",
use_custom_logo: false,
use_custom_logotype: false,
custom_lower_disclaimer_content: null,
custom_header_content: null,
custom_popup_header: null,
custom_popup_content: null,
},
cloudSettings: {
numberOfSeats: 0,
planType: BillingPlanType.FREE,
},
customAnalyticsScript: null,
isMobile: false,
webVersion: null,
};

View File

@ -1,4 +1,5 @@
"use client";
import Cookies from "js-cookie";
import { useEffect, useState } from "react";
import { useRouter, useSearchParams } from "next/navigation";
import { Card, Text } from "@tremor/react";
@ -32,7 +33,16 @@ export default function SSOCallback() {
);
if (response.ok) {
setAuthStatus("Authentication successful!");
const sessionCookie = Cookies.get('tenant_details');
console.log("Session cookie:", sessionCookie);
console.log("All cookies:", document.cookie);
// Log response headers
response.headers.forEach((value, key) => {
console.log(`${key}: ${value}`);
});
return;
// Redirect to the dashboard
router.replace("/admin/plan");
} else {

View File

@ -14,7 +14,7 @@ import { Metadata } from "next";
import { buildClientUrl } from "@/lib/utilsSS";
import { Inter } from "next/font/google";
import Head from "next/head";
import { EnterpriseSettings } from "./admin/settings/interfaces";
import { CombinedSettings, defaultCombinedSettings, EnterpriseSettings } from "./admin/settings/interfaces";
import { Card } from "@tremor/react";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { Logo } from "@/components/Logo";
@ -53,46 +53,47 @@ export default async function RootLayout({
}: {
children: React.ReactNode;
}) {
// 00a89749-beab-489a-8b72-88aa3d646274
// 01fb5963-9ab3-4585-900a-438480857427
// 00a89749-beab-489a-8b72-88aa3d646274
// 01fb5963-9ab3-4585-900a-438480857427
// return <>{children}</>
// SELECT table_name, column_name, data_type, character_maximum_length
// FROM information_schema.columns
// WHERE table_schema = '00a89749-beab-489a-8b72-88aa3d646274'
// ORDER BY table_name, ordinal_position;
// SELECT table_name, column_name, data_type, character_maximum_length
// FROM information_schema.columns
// WHERE table_schema = '00a89749-beab-489a-8b72-88aa3d646274'
// ORDER BY table_name, ordinal_position;
const combinedSettings = await fetchSettingsSS();
if (!combinedSettings) {
// return <>{children}</>
// Just display a simple full page error if fetching fails.
const combinedSettings: CombinedSettings | null = await fetchSettingsSS()
return (
<html lang="en" className={`${inter.variable} font-sans`}>
<Head>
<title>Settings Unavailable | Danswer</title>
</Head>
<body className="bg-background text-default">
<div className="flex flex-col items-center justify-center min-h-screen">
<div className="mb-2 flex items-center max-w-[175px]">
<HeaderTitle>Danswer</HeaderTitle>
<Logo height={40} width={40} />
</div>
// if (!combinedSettings) {
// return <>{children}</>
// // Just display a simple full page error if fetching fails.
<Card className="p-8 max-w-md">
<h1 className="text-2xl font-bold mb-4 text-error">Error</h1>
<p className="text-text-500">
Your Danswer instance was not configured properly and your
settings could not be loaded. Please contact your admin to fix
this error.
</p>
</Card>
</div>
</body>
</html>
);
}
// return (
// <html lang="en" className={`${inter.variable} font-sans`}>
// <Head>
// <title>Settings Unavailable | Danswer</title>
// </Head>
// <body className="bg-background text-default">
// <div className="flex flex-col items-center justify-center min-h-screen">
// <div className="mb-2 flex items-center max-w-[175px]">
// <HeaderTitle>Danswer</HeaderTitle>
// <Logo height={40} width={40} />
// </div>
// <Card className="p-8 max-w-md">
// <h1 className="text-2xl font-bold mb-4 text-error">Error</h1>
// <p className="text-text-500">
// Your Danswer instance was not configured properly and your
// settings could not be loaded. Please contact your admin to fix
// this error.
// </p>
// </Card>
// </div>
// </body>
// </html>
// );
// }
return (
<html lang="en">
@ -103,7 +104,7 @@ export default async function RootLayout({
/>
</Head>
{CUSTOM_ANALYTICS_ENABLED && combinedSettings.customAnalyticsScript && (
{CUSTOM_ANALYTICS_ENABLED && combinedSettings && combinedSettings.customAnalyticsScript && (
<head>
<script
type="text/javascript"
@ -119,10 +120,10 @@ export default async function RootLayout({
className={`text-default bg-background ${
// TODO: remove this once proper dark mode exists
process.env.THEME_IS_DARK?.toLowerCase() === "true" ? "dark" : ""
}`}
}`}
>
<UserProvider>
<SettingsProvider settings={combinedSettings}>
<SettingsProvider settings={combinedSettings || defaultCombinedSettings}>
{children}
</SettingsProvider>
</UserProvider>

View File

@ -29,6 +29,7 @@ export async function fetchCustomAnalyticsScriptSS() {
}
export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
console.log("\n\n\n\nfetch settings")
const tasks = [fetchStandardSettingsSS()];
if (SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED) {
tasks.push(fetchEnterpriseSettingsSS());