diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 8ff385c21..72be25cfe 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -1,3 +1,4 @@ +from danswer.configs.app_configs import SECRET_JWT_KEY from datetime import timedelta import contextlib import smtplib @@ -250,11 +251,8 @@ async def create_user_session(user: User, tenant_id: str) -> str: "tenant_id": tenant_id, "exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS) } - - # Encode the token - token = jwt.encode(payload, "JWT_SECRET_KEY", algorithm="HS256") - - + + token = jwt.encode(payload, SECRET_JWT_KEY, algorithm="HS256") return token diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index b9cf95f47..87fbf137b 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -37,6 +37,7 @@ DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == " WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000" +SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY") or "JWT_SECRET_KEY" ##### # Auth Configs ##### diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 78fbb289b..67a2092f0 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker +from danswer.configs.app_configs import SECRET_JWT_KEY from danswer.configs.app_configs import DEFAULT_SCHEMA from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS from danswer.configs.app_configs import LOG_POSTGRES_LATENCY @@ -138,57 +139,22 @@ def init_sqlalchemy_engine(app_name: str) -> None: global POSTGRES_APP_NAME POSTGRES_APP_NAME = app_name -_engines = {} -# TODO validate that this is best practice -def get_sqlalchemy_engine(schema: str = DEFAULT_SCHEMA) -> Engine: - if schema in _engines: - return _engines[schema] +_engine = None - connection_string = build_connection_string( - db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync" - ) - engine = create_engine( - connection_string, - pool_size=40, - max_overflow=10, - pool_pre_ping=POSTGRES_POOL_PRE_PING, - pool_recycle=POSTGRES_POOL_RECYCLE, - ) - - @event.listens_for(engine, "connect") - def set_search_path(dbapi_connection, connection_record): - cursor = dbapi_connection.cursor() - cursor.execute(f'SET search_path TO "{schema}"') - cursor.close() - dbapi_connection.commit() - - _engines[schema] = engine - return engine - - - -# def get_sqlalchemy_engine(schema: str = DEFAULT_SCHEMA) -> Engine: -# global _SYNC_ENGINE -# if _SYNC_ENGINE is None: -# connection_string = build_connection_string( -# db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync" -# ) -# _SYNC_ENGINE = create_engine( -# connection_string, -# pool_size=40, -# max_overflow=10, -# pool_pre_ping=POSTGRES_POOL_PRE_PING, -# pool_recycle=POSTGRES_POOL_RECYCLE, -# ) - -# @event.listens_for(_SYNC_ENGINE, "connect") -# def set_search_path(dbapi_connection, connection_record): -# cursor = dbapi_connection.cursor() -# cursor.execute(f"SET search_path TO {schema}") -# cursor.close() -# dbapi_connection.commit() - -# return _SYNC_ENGINE +def get_sqlalchemy_engine() -> Engine: + global _engine + if _engine is None: + connection_string = build_connection_string( + db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync" + ) + _engine = create_engine( + connection_string, + pool_size=40, + max_overflow=10, + pool_pre_ping=POSTGRES_POOL_PRE_PING, + pool_recycle=POSTGRES_POOL_RECYCLE, + ) + return _engine def get_sqlalchemy_async_engine() -> AsyncEngine: @@ -212,21 +178,19 @@ def get_sqlalchemy_async_engine() -> AsyncEngine: def get_session_context_manager() -> ContextManager[Session]: return contextlib.contextmanager(get_session)() - - - def get_current_tenant_id(request: Request) -> str: if not MULTI_TENANT: return DEFAULT_SCHEMA - token = request.cookies.get("fastapiusersauth") + print(request.cookies) + token = request.cookies.get("tenant_details") if not token: logger.warning("No token found in cookies") raise HTTPException(status_code=401, detail="Authentication required") try: logger.info(f"Attempting to decode token: {token[:10]}...") # Log only first 10 characters for security - payload = jwt.decode(token, "JWT_SECRET_KEY", algorithms=["HS256"]) + payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"]) logger.info(f"Decoded payload: {payload}") tenant_id = payload.get("tenant_id") if not tenant_id: @@ -246,29 +210,8 @@ def get_current_tenant_id(request: Request) -> str: def get_session(tenant_id: str | None= Depends(get_current_tenant_id)) -> Generator[Session, None, None]: - print("") - print("\n\n\n\n") - with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session: - session.execute(text(f'SET search_path TO "{tenant_id}"')) - print("SEARCH PATH IS ", tenant_id) + with Session(get_sqlalchemy_engine(schema=tenant_id), expire_on_commit=False) as session: yield session - session.execute(text('SET search_path TO "public"')) - - - # Logic to create or retrieve a database session for the given tenant_id - - - -# def get_session(schema: str = DEFAULT_SCHEMA) -> Generator[Session, None, None]: -# # The line below was added to monitor the latency caused by Postgres connections -# # during API calls. -# # with tracer.trace("db.get_session"): - -# with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session: -# session.execute(text(f"SET search_path TO {schema}")) -# yield session -# session.execute(text("SET search_path TO public")) - async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async with AsyncSession( diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 85791ab1e..319375332 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -139,6 +139,7 @@ def stream_answer_objects( rephrased_query = query_req.query_override or thread_based_query_rephrase( user_query=query_msg.message, history_str=history_str, + db_session=db_session ) # Given back ahead of the documents for latency reasons diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py index 31582f908..519d64ec6 100644 --- a/backend/danswer/search/retrieval/search_runner.py +++ b/backend/danswer/search/retrieval/search_runner.py @@ -237,7 +237,7 @@ def retrieve_chunks( # Currently only uses query expansion on multilingual use cases query_rephrases = multilingual_query_expansion( - query.query, multilingual_expansion + query.query, multilingual_expansion, db_session=db_session ) # Just to be extra sure, add the original query. query_rephrases.append(query.query) diff --git a/backend/danswer/secondary_llm_flows/query_expansion.py b/backend/danswer/secondary_llm_flows/query_expansion.py index 2b21af7da..9c4ef00e5 100644 --- a/backend/danswer/secondary_llm_flows/query_expansion.py +++ b/backend/danswer/secondary_llm_flows/query_expansion.py @@ -15,11 +15,11 @@ from danswer.prompts.miscellaneous_prompts import LANGUAGE_REPHRASE_PROMPT from danswer.utils.logger import setup_logger from danswer.utils.text_processing import count_punctuation from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel - +from sqlalchemy.orm import Session logger = setup_logger() -def llm_multilingual_query_expansion(query: str, language: str) -> str: +def llm_multilingual_query_expansion(query: str, language: str, db_session: Session) -> str: def _get_rephrase_messages() -> list[dict[str, str]]: messages = [ { @@ -51,12 +51,13 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str: def multilingual_query_expansion( query: str, expansion_languages: list[str], + db_session: Session, use_threads: bool = True, ) -> list[str]: languages = [language.strip() for language in expansion_languages] if use_threads: functions_with_args: list[tuple[Callable, tuple]] = [ - (llm_multilingual_query_expansion, (query, language)) + (llm_multilingual_query_expansion, (query, language, db_session)) for language in languages ] @@ -65,7 +66,7 @@ def multilingual_query_expansion( else: query_rephrases = [ - llm_multilingual_query_expansion(query, language) for language in languages + llm_multilingual_query_expansion(query, language, db_session) for language in languages ] return query_rephrases @@ -134,9 +135,10 @@ def history_based_query_rephrase( def thread_based_query_rephrase( user_query: str, history_str: str, + db_session: Session, llm: LLM | None = None, size_heuristic: int = 200, - punctuation_heuristic: int = 10, + punctuation_heuristic: int = 10 ) -> str: if not history_str: return user_query diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py index 190390b90..b219b525b 100644 --- a/backend/danswer/secondary_llm_flows/query_validation.py +++ b/backend/danswer/secondary_llm_flows/query_validation.py @@ -1,4 +1,5 @@ import re +from sqlalchemy.orm import Session from collections.abc import Iterator from danswer.chat.models import DanswerAnswerPiece @@ -46,7 +47,8 @@ def extract_answerability_bool(model_raw: str) -> bool: def get_query_answerability( - user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY + db_session: Session, + user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY, ) -> tuple[str, bool]: if skip_check: return "Query Answerability Evaluation feature is turned off", True @@ -67,7 +69,7 @@ def get_query_answerability( def stream_query_answerability( - user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY + db_session: Session, user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY, ) -> Iterator[str]: if skip_check: yield get_json_line( diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index 704b16d5e..670ee2863 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -117,13 +117,13 @@ def get_tags( @basic_router.post("/query-validation") def query_validation( - simple_query: SimpleQueryRequest, _: User = Depends(current_user) + simple_query: SimpleQueryRequest, _: User = Depends(current_user), db_session: Session = Depends(get_session) ) -> QueryValidationResponse: # Note if weak model prompt is chosen, this check does not occur and will simply return that # the query is valid, this is because weaker models cannot really handle this task well. # Additionally, some weak model servers cannot handle concurrent inferences. logger.notice(f"Validating query: {simple_query.query}") - reasoning, answerable = get_query_answerability(simple_query.query) + reasoning, answerable = get_query_answerability(db_session=db_session, user_query=simple_query.query) return QueryValidationResponse(reasoning=reasoning, answerable=answerable) @@ -226,14 +226,14 @@ def get_search_session( # No search responses are answered with a conversational generative AI response @basic_router.post("/stream-query-validation") def stream_query_validation( - simple_query: SimpleQueryRequest, _: User = Depends(current_user) + simple_query: SimpleQueryRequest, _: User = Depends(current_user), db_session: Session = Depends(get_session) ) -> StreamingResponse: # Note if weak model prompt is chosen, this check does not occur and will simply return that # the query is valid, this is because weaker models cannot really handle this task well. # Additionally, some weak model servers cannot handle concurrent inferences. logger.notice(f"Validating query: {simple_query.query}") return StreamingResponse( - stream_query_answerability(simple_query.query), media_type="application/json" + stream_query_answerability(user_query=simple_query.query, db_session=db_session), media_type="application/json" ) diff --git a/backend/danswer/server/settings/api.py b/backend/danswer/server/settings/api.py index 89259f99a..a8eb4f63c 100644 --- a/backend/danswer/server/settings/api.py +++ b/backend/danswer/server/settings/api.py @@ -105,8 +105,9 @@ async def create_tenant_schema(tenant_id: str) -> None: await session.commit() logger.info(f"Schema created for tenant: {tenant_id}") - with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session: + with Session(get_sqlalchemy_engine()) as db_session: try: + db_session.execute(text(f'Set search_path to "{tenant_id}"')) setup_postgres(db_session) except SQLAlchemyError as e: logger.error(f"Error while loading chat YAMLs for tenant {tenant_id}: {str(e)}") @@ -114,8 +115,6 @@ async def create_tenant_schema(tenant_id: str) -> None: finally: db_session.execute(text('SET search_path TO "public"')) - # db_session.execute(text(f'SET search_path TO "public"')) - logger.info(f"Migrations completed for tenant: {tenant_id}") @@ -153,25 +152,21 @@ async def sso_callback( logger.info(f"Session token created for user: {user.email}") # Set the session cookie with proper flags + response = JSONResponse(content={"message": "Authentication successful"}) response.set_cookie( - key="fastapiusersauth", + key="tenant_details", value=session_token, max_age=SESSION_EXPIRE_TIME_SECONDS, expires=SESSION_EXPIRE_TIME_SECONDS, path="/", domain=WEB_DOMAIN.split("://")[-1], - secure=True, + secure=False, # Set to True in production with HTTPS httponly=True, samesite="lax", ) logger.info("Session cookie set") - logger.info("SSO callback completed successfully") - return JSONResponse( - content={"message": "Authentication successful"}, - status_code=200 - ) - + return response # @basic_router.post("/auth/sso-callback") diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py index a613faaf6..f15d4aae8 100644 --- a/backend/ee/danswer/server/query_and_chat/chat_backend.py +++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py @@ -223,6 +223,7 @@ def handle_send_message_simple_with_history( rephrased_query = req.query_override or thread_based_query_rephrase( user_query=query, history_str=history_str, + db_session=db_session ) full_chat_msg_info = CreateChatMessageRequest( diff --git a/web/src/app/admin/settings/interfaces.ts b/web/src/app/admin/settings/interfaces.ts index b3f07522f..f12d5156f 100644 --- a/web/src/app/admin/settings/interfaces.ts +++ b/web/src/app/admin/settings/interfaces.ts @@ -47,3 +47,31 @@ export interface CombinedSettings { isMobile?: boolean; webVersion: string | null; } + + +export const defaultCombinedSettings: CombinedSettings = { + settings: { + chat_page_enabled: true, + search_page_enabled: true, + default_page: "search", + maximum_chat_retention_days: 30, + notifications: [], + needs_reindexing: false, + }, + enterpriseSettings: { + application_name: "Danswer", + use_custom_logo: false, + use_custom_logotype: false, + custom_lower_disclaimer_content: null, + custom_header_content: null, + custom_popup_header: null, + custom_popup_content: null, + }, + cloudSettings: { + numberOfSeats: 0, + planType: BillingPlanType.FREE, + }, + customAnalyticsScript: null, + isMobile: false, + webVersion: null, +}; \ No newline at end of file diff --git a/web/src/app/auth/sso-callback/page.tsx b/web/src/app/auth/sso-callback/page.tsx index 9e34ce8da..2709c36af 100644 --- a/web/src/app/auth/sso-callback/page.tsx +++ b/web/src/app/auth/sso-callback/page.tsx @@ -1,4 +1,5 @@ "use client"; +import Cookies from "js-cookie"; import { useEffect, useState } from "react"; import { useRouter, useSearchParams } from "next/navigation"; import { Card, Text } from "@tremor/react"; @@ -32,7 +33,16 @@ export default function SSOCallback() { ); if (response.ok) { setAuthStatus("Authentication successful!"); - + const sessionCookie = Cookies.get('tenant_details'); + console.log("Session cookie:", sessionCookie); + + console.log("All cookies:", document.cookie); + + // Log response headers + response.headers.forEach((value, key) => { + console.log(`${key}: ${value}`); + }); + return; // Redirect to the dashboard router.replace("/admin/plan"); } else { diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index acf870f92..46f8d7b9e 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -14,7 +14,7 @@ import { Metadata } from "next"; import { buildClientUrl } from "@/lib/utilsSS"; import { Inter } from "next/font/google"; import Head from "next/head"; -import { EnterpriseSettings } from "./admin/settings/interfaces"; +import { CombinedSettings, defaultCombinedSettings, EnterpriseSettings } from "./admin/settings/interfaces"; import { Card } from "@tremor/react"; import { HeaderTitle } from "@/components/header/HeaderTitle"; import { Logo } from "@/components/Logo"; @@ -53,46 +53,47 @@ export default async function RootLayout({ }: { children: React.ReactNode; }) { -// 00a89749-beab-489a-8b72-88aa3d646274 -// 01fb5963-9ab3-4585-900a-438480857427 + // 00a89749-beab-489a-8b72-88aa3d646274 + // 01fb5963-9ab3-4585-900a-438480857427 // return <>{children} -// SELECT table_name, column_name, data_type, character_maximum_length -// FROM information_schema.columns -// WHERE table_schema = '00a89749-beab-489a-8b72-88aa3d646274' -// ORDER BY table_name, ordinal_position; + // SELECT table_name, column_name, data_type, character_maximum_length + // FROM information_schema.columns + // WHERE table_schema = '00a89749-beab-489a-8b72-88aa3d646274' + // ORDER BY table_name, ordinal_position; - const combinedSettings = await fetchSettingsSS(); - if (!combinedSettings) { - // return <>{children} - // Just display a simple full page error if fetching fails. + const combinedSettings: CombinedSettings | null = await fetchSettingsSS() - return ( - - - Settings Unavailable | Danswer - - -
-
- Danswer - -
+ // if (!combinedSettings) { + // return <>{children} + // // Just display a simple full page error if fetching fails. - -

Error

-

- Your Danswer instance was not configured properly and your - settings could not be loaded. Please contact your admin to fix - this error. -

-
-
- - - ); - } + // return ( + // + // + // Settings Unavailable | Danswer + // + // + //
+ //
+ // Danswer + // + //
+ + // + //

Error

+ //

+ // Your Danswer instance was not configured properly and your + // settings could not be loaded. Please contact your admin to fix + // this error. + //

+ //
+ //
+ // + // + // ); + // } return ( @@ -103,7 +104,7 @@ export default async function RootLayout({ /> - {CUSTOM_ANALYTICS_ENABLED && combinedSettings.customAnalyticsScript && ( + {CUSTOM_ANALYTICS_ENABLED && combinedSettings && combinedSettings.customAnalyticsScript && (