From 9b19990764c587fc54816f39bbdf6d2868a4821d Mon Sep 17 00:00:00 2001 From: pablonyx Date: Fri, 24 Jan 2025 12:40:08 -0800 Subject: [PATCH] Input shortcut fix in multi tenant case (#3768) * validated fix * nit * k --- backend/onyx/db/input_prompt.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/backend/onyx/db/input_prompt.py b/backend/onyx/db/input_prompt.py index 4437f825a..c9083616e 100644 --- a/backend/onyx/db/input_prompt.py +++ b/backend/onyx/db/input_prompt.py @@ -193,13 +193,13 @@ def fetch_input_prompts_by_user( """ Returns all prompts belonging to the user or public prompts, excluding those the user has specifically disabled. + Also, if `user_id` is None and AUTH_TYPE is DISABLED, then all prompts are returned. """ - # Start with a basic query for InputPrompt query = select(InputPrompt) - # If we have a user, left join to InputPrompt__User so we can check "disabled" if user_id is not None: + # If we have a user, left join to InputPrompt__User to check "disabled" IPU = aliased(InputPrompt__User) query = query.join( IPU, @@ -208,25 +208,30 @@ def fetch_input_prompts_by_user( ) # Exclude disabled prompts - # i.e. keep only those where (IPU.disabled is NULL or False) query = query.where(or_(IPU.disabled.is_(None), IPU.disabled.is_(False))) if include_public: - # user-owned or public + # Return both user-owned and public prompts query = query.where( - (InputPrompt.user_id == user_id) | (InputPrompt.is_public) + or_( + InputPrompt.user_id == user_id, + InputPrompt.is_public, + ) ) else: - # only user-owned prompts + # Return only user-owned prompts query = query.where(InputPrompt.user_id == user_id) - # If no user is logged in, get all prompts (public and private) - if user_id is None and AUTH_TYPE == AuthType.DISABLED: - query = query.where(True) # type: ignore + else: + # user_id is None + if AUTH_TYPE == AuthType.DISABLED: + # If auth is disabled, return all prompts + query = query.where(True) # type: ignore + elif include_public: + # Anonymous usage + query = query.where(InputPrompt.is_public) - # If no user is logged in but we want to include public prompts - elif include_public: - query = query.where(InputPrompt.is_public) + # Default to returning all prompts if active is not None: query = query.where(InputPrompt.active == active)