diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index ccc754437..1b1e615bb 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -24,7 +24,7 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None: with Session(get_sqlalchemy_engine()) as db_session: for prompt in all_prompts: upsert_prompt( - user_id=None, + user=None, prompt_id=prompt.get("id"), name=prompt["name"], description=prompt["description"].strip(), @@ -34,7 +34,6 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None: datetime_aware=prompt.get("datetime_aware", True), default_prompt=True, personas=None, - shared=True, db_session=db_session, commit=True, ) @@ -67,9 +66,7 @@ def load_personas_from_yaml( prompts: list[PromptDBModel | None] | None = None else: prompts = [ - get_prompt_by_name( - prompt_name, user_id=None, shared=True, db_session=db_session - ) + get_prompt_by_name(prompt_name, user=None, db_session=db_session) for prompt_name in prompt_set_names ] if any([prompt is None for prompt in prompts]): @@ -80,7 +77,7 @@ def load_personas_from_yaml( p_id = persona.get("id") upsert_persona( - user_id=None, + user=None, # Negative to not conflict with existing personas persona_id=(-1 * p_id) if p_id is not None else None, name=persona["name"], @@ -96,7 +93,6 @@ def load_personas_from_yaml( prompts=cast(list[PromptDBModel] | None, prompts), document_sets=doc_sets, default_persona=True, - shared=True, is_public=True, db_session=db_session, ) diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 738d02a16..d45fe95a7 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -12,6 +12,7 @@ from sqlalchemy import update from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.orm import Session +from danswer.auth.schemas import UserRole from danswer.configs.chat_configs import HARD_DELETE_CHATS from danswer.configs.constants import MessageType from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX @@ -27,6 +28,7 @@ from danswer.db.models import Prompt from danswer.db.models import SearchDoc from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import StarterMessage +from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.llm.override_models import LLMOverride from danswer.llm.override_models import PromptOverride @@ -313,13 +315,16 @@ def set_as_latest_chat_message( def get_prompt_by_id( prompt_id: int, - user_id: UUID | None, + user: User | None, db_session: Session, include_deleted: bool = False, ) -> Prompt: - stmt = select(Prompt).where( - Prompt.id == prompt_id, or_(Prompt.user_id == user_id, Prompt.user_id.is_(None)) - ) + stmt = select(Prompt).where(Prompt.id == prompt_id) + + # if user is not specified OR they are an admin, they should + # have access to all prompts, so this where clause is not needed + if user and user.role != UserRole.ADMIN: + stmt = stmt.where(or_(Prompt.user_id == user.id, Prompt.user_id.is_(None))) if not include_deleted: stmt = stmt.where(Prompt.deleted.is_(False)) @@ -351,14 +356,16 @@ def get_default_prompt() -> Prompt: def get_persona_by_id( persona_id: int, - # if user_id is `None` assume the user is an admin or auth is disabled - user_id: UUID | None, + # if user is `None` assume the user is an admin or auth is disabled + user: User | None, db_session: Session, include_deleted: bool = False, ) -> Persona: stmt = select(Persona).where(Persona.id == persona_id) - if user_id is not None: - stmt = stmt.where(or_(Persona.user_id == user_id, Persona.user_id.is_(None))) + + # if user is an admin, they should have access to all Personas + if user is not None and user.role != UserRole.ADMIN: + stmt = stmt.where(or_(Persona.user_id == user.id, Persona.user_id.is_(None))) if not include_deleted: stmt = stmt.where(Persona.deleted.is_(False)) @@ -397,33 +404,33 @@ def get_personas_by_ids( def get_prompt_by_name( - prompt_name: str, user_id: UUID | None, shared: bool, db_session: Session + prompt_name: str, user: User | None, db_session: Session ) -> Prompt | None: - """Cannot do shared and user owned simultaneously as there may be two of those""" stmt = select(Prompt).where(Prompt.name == prompt_name) - if shared: - stmt = stmt.where(Prompt.user_id.is_(None)) - else: - stmt = stmt.where(Prompt.user_id == user_id) + + # if user is not specified OR they are an admin, they should + # have access to all prompts, so this where clause is not needed + if user and user.role != UserRole.ADMIN: + stmt = stmt.where(Prompt.user_id == user.id) + result = db_session.execute(stmt).scalar_one_or_none() return result def get_persona_by_name( - persona_name: str, user_id: UUID | None, shared: bool, db_session: Session + persona_name: str, user: User | None, db_session: Session ) -> Persona | None: - """Cannot do shared and user owned simultaneously as there may be two of those""" + """Admins can see all, regular users can only fetch their own. + If user is None, assume the user is an admin or auth is disabled.""" stmt = select(Persona).where(Persona.name == persona_name) - if shared: - stmt = stmt.where(Persona.user_id.is_(None)) - else: - stmt = stmt.where(Persona.user_id == user_id) + if user and user.role != UserRole.ADMIN: + stmt = stmt.where(Persona.user_id == user.id) result = db_session.execute(stmt).scalar_one_or_none() return result def upsert_prompt( - user_id: UUID | None, + user: User | None, name: str, description: str, system_prompt: str, @@ -431,7 +438,6 @@ def upsert_prompt( include_citations: bool, datetime_aware: bool, personas: list[Persona] | None, - shared: bool, db_session: Session, prompt_id: int | None = None, default_prompt: bool = True, @@ -440,9 +446,7 @@ def upsert_prompt( if prompt_id is not None: prompt = db_session.query(Prompt).filter_by(id=prompt_id).first() else: - prompt = get_prompt_by_name( - prompt_name=name, user_id=user_id, shared=shared, db_session=db_session - ) + prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session) if prompt: if not default_prompt and prompt.default_prompt: @@ -463,7 +467,7 @@ def upsert_prompt( else: prompt = Prompt( id=prompt_id, - user_id=None if shared else user_id, + user_id=user.id if user else None, name=name, description=description, system_prompt=system_prompt, @@ -485,7 +489,7 @@ def upsert_prompt( def upsert_persona( - user_id: UUID | None, + user: User | None, name: str, description: str, num_chunks: float, @@ -496,7 +500,6 @@ def upsert_persona( document_sets: list[DBDocumentSet] | None, llm_model_version_override: str | None, starter_messages: list[StarterMessage] | None, - shared: bool, is_public: bool, db_session: Session, persona_id: int | None = None, @@ -507,7 +510,7 @@ def upsert_persona( persona = db_session.query(Persona).filter_by(id=persona_id).first() else: persona = get_persona_by_name( - persona_name=name, user_id=user_id, shared=shared, db_session=db_session + persona_name=name, user=user, db_session=db_session ) if persona: @@ -539,7 +542,7 @@ def upsert_persona( else: persona = Persona( id=persona_id, - user_id=None if shared else user_id, + user_id=user.id if user else None, is_public=is_public, name=name, description=description, @@ -566,24 +569,20 @@ def upsert_persona( def mark_prompt_as_deleted( prompt_id: int, - user_id: UUID | None, + user: User | None, db_session: Session, ) -> None: - prompt = get_prompt_by_id( - prompt_id=prompt_id, user_id=user_id, db_session=db_session - ) + prompt = get_prompt_by_id(prompt_id=prompt_id, user=user, db_session=db_session) prompt.deleted = True db_session.commit() def mark_persona_as_deleted( persona_id: int, - user_id: UUID | None, + user: User | None, db_session: Session, ) -> None: - persona = get_persona_by_id( - persona_id=persona_id, user_id=user_id, db_session=db_session - ) + persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session) persona.deleted = True db_session.commit() @@ -621,9 +620,7 @@ def update_persona_visibility( is_visible: bool, db_session: Session, ) -> None: - persona = get_persona_by_id( - persona_id=persona_id, user_id=None, db_session=db_session - ) + persona = get_persona_by_id(persona_id=persona_id, user=None, db_session=db_session) persona.is_visible = is_visible db_session.commit() diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 004025d7e..8e1540f20 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -736,7 +736,6 @@ class Prompt(Base): __tablename__ = "prompt" id: Mapped[int] = mapped_column(primary_key=True) - # If not belong to a user, then it's shared user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) name: Mapped[str] = mapped_column(String) description: Mapped[str] = mapped_column(String) @@ -770,7 +769,6 @@ class Persona(Base): __tablename__ = "persona" id: Mapped[int] = mapped_column(primary_key=True) - # If not belong to a user, then it's shared user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) name: Mapped[str] = mapped_column(String) description: Mapped[str] = mapped_column(String) @@ -824,7 +822,7 @@ class Persona(Base): back_populates="personas", ) # Owner - user: Mapped[User] = relationship("User", back_populates="personas") + user: Mapped[User | None] = relationship("User", back_populates="personas") # Other users with access users: Mapped[list[User]] = relationship( "User", diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index 38351b18b..7b1116b5f 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from danswer.db.chat import get_prompts_by_ids from danswer.db.chat import upsert_persona from danswer.db.document_set import get_document_sets_by_ids +from danswer.db.models import Persona__User from danswer.db.models import User from danswer.server.features.persona.models import CreatePersonaRequest from danswer.server.features.persona.models import PersonaSnapshot @@ -21,9 +22,19 @@ def make_persona_private( group_ids: list[int] | None, db_session: Session, ) -> None: + if user_ids is not None: + db_session.query(Persona__User).filter( + Persona__User.persona_id == persona_id + ).delete(synchronize_session="fetch") + + for user_uuid in user_ids: + db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid)) + + db_session.commit() + # May cause error if someone switches down to MIT from EE - if user_ids or group_ids: - raise NotImplementedError("Danswer MIT does not support private Document Sets") + if group_ids: + raise NotImplementedError("Danswer MIT does not support private Personas") def create_update_persona( @@ -32,8 +43,6 @@ def create_update_persona( user: User | None, db_session: Session, ) -> PersonaSnapshot: - user_id = user.id if user is not None else None - # Permission to actually use these is checked later document_sets = list( get_document_sets_by_ids( @@ -51,7 +60,7 @@ def create_update_persona( try: persona = upsert_persona( persona_id=persona_id, - user_id=user_id, + user=user, name=create_persona_request.name, description=create_persona_request.description, num_chunks=create_persona_request.num_chunks, @@ -62,7 +71,6 @@ def create_update_persona( document_sets=document_sets, llm_model_version_override=create_persona_request.llm_model_version_override, starter_messages=create_persona_request.starter_messages, - shared=create_persona_request.shared, is_public=create_persona_request.is_public, db_session=db_session, ) diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py index c3b463e35..9b792ff08 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -49,7 +49,7 @@ def create_slack_bot_persona( # create/update persona associated with the slack bot persona_name = _build_persona_name(channel_names) persona = upsert_persona( - user_id=None, # Slack Bot Personas are not attached to users + user=None, # Slack Bot Personas are not attached to users persona_id=existing_persona_id, name=persona_name, description="", @@ -61,7 +61,6 @@ def create_slack_bot_persona( document_sets=document_sets, llm_model_version_override=None, starter_messages=None, - shared=True, is_public=True, default_persona=False, db_session=db_session, diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index c0c036339..ff6e04a21 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -173,7 +173,7 @@ def stream_answer_objects( prompt = None if query_req.prompt_id is not None: prompt = get_prompt_by_id( - prompt_id=query_req.prompt_id, user_id=user_id, db_session=db_session + prompt_id=query_req.prompt_id, user=user, db_session=db_session ) if prompt is None: if not chat_session.persona.prompts: diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index b4359f6a1..bfaea792f 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -28,35 +28,6 @@ admin_router = APIRouter(prefix="/admin/persona") basic_router = APIRouter(prefix="/persona") -@admin_router.post("") -def create_persona( - create_persona_request: CreatePersonaRequest, - user: User | None = Depends(current_admin_user), - db_session: Session = Depends(get_session), -) -> PersonaSnapshot: - return create_update_persona( - persona_id=None, - create_persona_request=create_persona_request, - user=user, - db_session=db_session, - ) - - -@admin_router.patch("/{persona_id}") -def update_persona( - persona_id: int, - update_persona_request: CreatePersonaRequest, - user: User | None = Depends(current_admin_user), - db_session: Session = Depends(get_session), -) -> PersonaSnapshot: - return create_update_persona( - persona_id=persona_id, - create_persona_request=update_persona_request, - user=user, - db_session=db_session, - ) - - class IsVisibleRequest(BaseModel): is_visible: bool @@ -92,19 +63,6 @@ def patch_persona_display_priority( ) -@admin_router.delete("/{persona_id}") -def delete_persona( - persona_id: int, - user: User | None = Depends(current_admin_user), - db_session: Session = Depends(get_session), -) -> None: - mark_persona_as_deleted( - persona_id=persona_id, - user_id=user.id if user is not None else None, - db_session=db_session, - ) - - @admin_router.get("") def list_personas_admin( _: User | None = Depends(current_admin_user), @@ -124,6 +82,48 @@ def list_personas_admin( """Endpoints for all""" +@basic_router.post("") +def create_persona( + create_persona_request: CreatePersonaRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> PersonaSnapshot: + return create_update_persona( + persona_id=None, + create_persona_request=create_persona_request, + user=user, + db_session=db_session, + ) + + +@basic_router.patch("/{persona_id}") +def update_persona( + persona_id: int, + update_persona_request: CreatePersonaRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> PersonaSnapshot: + return create_update_persona( + persona_id=persona_id, + create_persona_request=update_persona_request, + user=user, + db_session=db_session, + ) + + +@basic_router.delete("/{persona_id}") +def delete_persona( + persona_id: int, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + mark_persona_as_deleted( + persona_id=persona_id, + user=user, + db_session=db_session, + ) + + @basic_router.get("") def list_personas( user: User | None = Depends(current_user), @@ -148,7 +148,7 @@ def get_persona( return PersonaSnapshot.from_model( get_persona_by_id( persona_id=persona_id, - user_id=user.id if user is not None else None, + user=user, db_session=db_session, ) ) @@ -194,9 +194,9 @@ GPT_3_5_TURBO_MODEL_VERSIONS = [ ] -@admin_router.get("/utils/list-available-models") +@basic_router.get("/utils/list-available-models") def list_available_model_versions( - _: User | None = Depends(current_admin_user), + _: User | None = Depends(current_user), ) -> list[str]: # currently only support selecting different models for OpenAI if GEN_AI_MODEL_PROVIDER != "openai": @@ -205,9 +205,9 @@ def list_available_model_versions( return GPT_4_MODEL_VERSIONS + GPT_3_5_TURBO_MODEL_VERSIONS -@admin_router.get("/utils/default-model") +@basic_router.get("/utils/default-model") def get_default_model( - _: User | None = Depends(current_admin_user), + _: User | None = Depends(current_user), ) -> str: # currently only support selecting different models for OpenAI if GEN_AI_MODEL_PROVIDER != "openai": diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index 4cc80eec0..8826be2c3 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -7,12 +7,12 @@ from danswer.db.models import StarterMessage from danswer.search.enums import RecencyBiasSetting from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.prompt.models import PromptSnapshot +from danswer.server.models import MinimalUserSnapshot class CreatePersonaRequest(BaseModel): name: str description: str - shared: bool num_chunks: float llm_relevance_filter: bool is_public: bool @@ -29,8 +29,8 @@ class CreatePersonaRequest(BaseModel): class PersonaSnapshot(BaseModel): id: int + owner: MinimalUserSnapshot | None name: str - shared: bool is_visible: bool is_public: bool display_priority: int | None @@ -43,6 +43,7 @@ class PersonaSnapshot(BaseModel): default_persona: bool prompts: list[PromptSnapshot] document_sets: list[DocumentSet] + users: list[UUID] groups: list[int] @classmethod @@ -53,7 +54,11 @@ class PersonaSnapshot(BaseModel): return PersonaSnapshot( id=persona.id, name=persona.name, - shared=persona.user_id is None, + owner=( + MinimalUserSnapshot(id=persona.user.id, email=persona.user.email) + if persona.user + else None + ), is_visible=persona.is_visible, is_public=persona.is_public, display_priority=persona.display_priority, @@ -69,6 +74,7 @@ class PersonaSnapshot(BaseModel): DocumentSet.from_model(document_set_model) for document_set_model in persona.document_sets ], + users=[user.id for user in persona.users], groups=[user_group.id for user_group in persona.groups], ) diff --git a/backend/danswer/server/features/prompt/api.py b/backend/danswer/server/features/prompt/api.py index b9f27675d..24c886ab9 100644 --- a/backend/danswer/server/features/prompt/api.py +++ b/backend/danswer/server/features/prompt/api.py @@ -4,7 +4,6 @@ from fastapi import HTTPException from sqlalchemy.orm import Session from starlette import status -from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.db.chat import get_personas_by_ids from danswer.db.chat import get_prompt_by_id @@ -32,8 +31,6 @@ def create_update_prompt( user: User | None, db_session: Session, ) -> PromptSnapshot: - user_id = user.id if user is not None else None - personas = ( list( get_personas_by_ids( @@ -47,7 +44,7 @@ def create_update_prompt( prompt = upsert_prompt( prompt_id=prompt_id, - user_id=user_id, + user=user, name=create_prompt_request.name, description=create_prompt_request.description, system_prompt=create_prompt_request.system_prompt, @@ -55,7 +52,6 @@ def create_update_prompt( include_citations=create_prompt_request.include_citations, datetime_aware=create_prompt_request.datetime_aware, personas=personas, - shared=create_prompt_request.shared, db_session=db_session, ) return PromptSnapshot.from_model(prompt) @@ -64,7 +60,7 @@ def create_update_prompt( @basic_router.post("") def create_prompt( create_prompt_request: CreatePromptRequest, - user: User | None = Depends(current_admin_user), + user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> PromptSnapshot: try: @@ -124,7 +120,7 @@ def delete_prompt( ) -> None: mark_prompt_as_deleted( prompt_id=prompt_id, - user_id=user.id if user is not None else None, + user=user, db_session=db_session, ) @@ -150,7 +146,7 @@ def get_prompt( return PromptSnapshot.from_model( get_prompt_by_id( prompt_id=prompt_id, - user_id=user.id if user is not None else None, + user=user, db_session=db_session, ) ) diff --git a/backend/danswer/server/features/prompt/models.py b/backend/danswer/server/features/prompt/models.py index 0ae70c58d..1cc9452f4 100644 --- a/backend/danswer/server/features/prompt/models.py +++ b/backend/danswer/server/features/prompt/models.py @@ -6,7 +6,6 @@ from danswer.db.models import Prompt class CreatePromptRequest(BaseModel): name: str description: str - shared: bool system_prompt: str task_prompt: str include_citations: bool = False @@ -17,7 +16,6 @@ class CreatePromptRequest(BaseModel): class PromptSnapshot(BaseModel): id: int name: str - shared: bool description: str system_prompt: str task_prompt: str @@ -34,7 +32,6 @@ class PromptSnapshot(BaseModel): return PromptSnapshot( id=prompt.id, name=prompt.name, - shared=prompt.user_id is None, description=prompt.description, system_prompt=prompt.system_prompt, task_prompt=prompt.task_prompt, diff --git a/backend/danswer/server/manage/slack_bot.py b/backend/danswer/server/manage/slack_bot.py index 19003f09d..40e8663b0 100644 --- a/backend/danswer/server/manage/slack_bot.py +++ b/backend/danswer/server/manage/slack_bot.py @@ -140,7 +140,7 @@ def patch_slack_bot_config( existing_persona_id = existing_slack_bot_config.persona_id if existing_persona_id is not None: persona = get_persona_by_id( - persona_id=existing_persona_id, user_id=None, db_session=db_session + persona_id=existing_persona_id, user=None, db_session=db_session ) if not persona.name.startswith(SLACK_BOT_PERSONA_PREFIX): diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index d616edd4f..ca23f0a15 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -1,6 +1,7 @@ from typing import Generic from typing import Optional from typing import TypeVar +from uuid import UUID from pydantic import BaseModel from pydantic.generics import GenericModel @@ -21,3 +22,8 @@ class ApiKey(BaseModel): class IdReturn(BaseModel): id: int + + +class MinimalUserSnapshot(BaseModel): + id: UUID + email: str diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 52d879dfe..bbc8eb425 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -327,7 +327,7 @@ def get_max_document_tokens( try: persona = get_persona_by_id( persona_id=persona_id, - user_id=user.id if user else None, + user=user, db_session=db_session, ) except ValueError: diff --git a/web/src/app/admin/personas/PersonaEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx similarity index 74% rename from web/src/app/admin/personas/PersonaEditor.tsx rename to web/src/app/admin/assistants/AssistantEditor.tsx index 6ce77edb5..05c5a71a2 100644 --- a/web/src/app/admin/personas/PersonaEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -1,7 +1,7 @@ "use client"; -import { DocumentSet, UserGroup } from "@/lib/types"; -import { Button, Divider, Text } from "@tremor/react"; +import { CCPairBasicInfo, DocumentSet, User, UserGroup } from "@/lib/types"; +import { Button, Divider, Italic, Text } from "@tremor/react"; import { ArrayHelpers, ErrorMessage, @@ -29,6 +29,8 @@ import { EE_ENABLED } from "@/lib/constants"; import { useUserGroups } from "@/lib/hooks"; import { Bubble } from "@/components/Bubble"; import { GroupsIcon } from "@/components/icons/icons"; +import { SuccessfulPersonaUpdateRedirectType } from "./enums"; +import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; function Label({ children }: { children: string | JSX.Element }) { return ( @@ -40,16 +42,24 @@ function SubLabel({ children }: { children: string | JSX.Element }) { return
{children}
; } -export function PersonaEditor({ +export function AssistantEditor({ existingPersona, + ccPairs, documentSets, llmOverrideOptions, defaultLLM, + user, + defaultPublic, + redirectType, }: { existingPersona?: Persona | null; + ccPairs: CCPairBasicInfo[]; documentSets: DocumentSet[]; llmOverrideOptions: string[]; defaultLLM: string; + user: User | null; + defaultPublic: boolean; + redirectType: SuccessfulPersonaUpdateRedirectType; }) { const router = useRouter(); const { popup, setPopup } = usePopup(); @@ -99,7 +109,7 @@ export function PersonaEditor({ system_prompt: existingPrompt?.system_prompt ?? "", task_prompt: existingPrompt?.task_prompt ?? "", disable_retrieval: (existingPersona?.num_chunks ?? 10) === 0, - is_public: existingPersona?.is_public ?? true, + is_public: existingPersona?.is_public ?? defaultPublic, document_set_ids: existingPersona?.document_sets?.map( (documentSet) => documentSet.id @@ -116,9 +126,9 @@ export function PersonaEditor({ }} validationSchema={Yup.object() .shape({ - name: Yup.string().required("Must give the Persona a name!"), + name: Yup.string().required("Must give the Assistant a name!"), description: Yup.string().required( - "Must give the Persona a description!" + "Must give the Assistant a description!" ), system_prompt: Yup.string(), task_prompt: Yup.string(), @@ -187,12 +197,14 @@ export function PersonaEditor({ existingPromptId: existingPrompt?.id, ...values, num_chunks: numChunks, + users: user ? [user.id] : undefined, groups, }); } else { [promptResponse, personaResponse] = await createPersona({ ...values, num_chunks: numChunks, + users: user ? [user.id] : undefined, groups, }); } @@ -201,51 +213,53 @@ export function PersonaEditor({ if (!promptResponse.ok) { error = await promptResponse.text(); } - if (personaResponse && !personaResponse.ok) { + if (!personaResponse) { + error = "Failed to create Assistant - no response received"; + } else if (!personaResponse.ok) { error = await personaResponse.text(); } - if (error) { + if (error || !personaResponse) { setPopup({ type: "error", - message: `Failed to create Persona - ${error}`, + message: `Failed to create Assistant - ${error}`, }); formikHelpers.setSubmitting(false); } else { - router.push(`/admin/personas?u=${Date.now()}`); + router.push( + redirectType === SuccessfulPersonaUpdateRedirectType.ADMIN + ? `/admin/assistants?u=${Date.now()}` + : `/chat?assistantId=${ + ((await personaResponse.json()) as Persona).id + }` + ); } }} > {({ isSubmitting, values, setFieldValue }) => (
- + <> - - - - - - <> { setFieldValue("system_prompt", e.target.value); @@ -260,11 +274,11 @@ export function PersonaEditor({ { setFieldValue("task_prompt", e.target.value); triggerFinalPromptUpdate( @@ -276,35 +290,6 @@ export function PersonaEditor({ error={finalPromptError} /> - {!values.disable_retrieval && ( - - )} - - { - setFieldValue("disable_retrieval", e.target.checked); - triggerFinalPromptUpdate( - values.system_prompt, - values.task_prompt, - e.target.checked - ); - }} - /> - {finalPrompt ? ( @@ -319,73 +304,100 @@ export function PersonaEditor({ - {!values.disable_retrieval && ( + {ccPairs.length > 0 && ( <> - + <> - ( + { + setFieldValue("disable_retrieval", e.target.checked); + triggerFinalPromptUpdate( + values.system_prompt, + values.task_prompt, + e.target.checked + ); + }} + /> + + {!values.disable_retrieval && ( + <>
-
- - <> - Select which{" "} + + <> + Select which{" "} + {!user || user.role === "admin" ? ( Document Sets - {" "} - that this Persona should search through. If - none are specified, the Persona will search - through all available documents in order to - try and response to queries. - - -
-
- {documentSets.map((documentSet) => { - const ind = values.document_set_ids.indexOf( - documentSet.id - ); - let isSelected = ind !== -1; - return ( -
{ - if (isSelected) { - arrayHelpers.remove(ind); - } else { - arrayHelpers.push(documentSet.id); - } - }} - > -
- {documentSet.name} -
-
- ); - })} -
+ + ) : ( + "Document Sets" + )}{" "} + that this Assistant should search through. If + none are specified, the Assistant will search + through all available documents in order to try + and respond to queries. + +
- )} - /> + + {documentSets.length > 0 ? ( + ( +
+
+ {documentSets.map((documentSet) => { + const ind = + values.document_set_ids.indexOf( + documentSet.id + ); + let isSelected = ind !== -1; + return ( + { + if (isSelected) { + arrayHelpers.remove(ind); + } else { + arrayHelpers.push(documentSet.id); + } + }} + /> + ); + })} +
+
+ )} + /> + ) : ( + + No Document Sets available.{" "} + {user?.role !== "admin" && ( + <> + If this functionality would be useful, reach + out to the administrators of Danswer for + assistance. + + )} + + )} + + )}
@@ -393,73 +405,38 @@ export function PersonaEditor({ )} - {EE_ENABLED && userGroups && ( + {!values.disable_retrieval && ( <> - + <> - - {userGroups && - userGroups.length > 0 && - !values.is_public && ( -
- - Select which User Groups should have access to - this Persona. - -
- {userGroups.map((userGroup) => { - const isSelected = values.groups.includes( - userGroup.id - ); - return ( - { - if (isSelected) { - setFieldValue( - "groups", - values.groups.filter( - (id) => id !== userGroup.id - ) - ); - } else { - setFieldValue("groups", [ - ...values.groups, - userGroup.id, - ]); - } - }} - > -
- -
- {userGroup.name} -
-
-
- ); - })} -
-
- )}
+ )} {llmOverrideOptions.length > 0 && defaultLLM && ( <> - + <> - Pick which LLM to use for this Persona. If left as + Pick which LLM to use for this Assistant. If left as Default, will use {defaultLLM} .
@@ -496,7 +473,10 @@ export function PersonaEditor({ {!values.disable_retrieval && ( <> - + <> How many chunks should we feed into the LLM when generating the final response? Each chunk is ~400 - words long. If you are using gpt-3.5-turbo or other - similar models, setting this to a value greater than - 5 will result in errors at query time due to the - model's input length limit. + words long.

If unspecified, will use 10 chunks. @@ -537,14 +514,17 @@ export function PersonaEditor({ )} - + <>
- Starter Messages help guide users to use this Persona. + Starter Messages help guide users to use this Assistant. They are shown to the user as clickable options when they - select this Persona. When selected, the specified message - is sent to the LLM as the initial user message. + select this Assistant. When selected, the specified + message is sent to the LLM as the initial user message.
@@ -686,6 +666,67 @@ export function PersonaEditor({ + {EE_ENABLED && userGroups && (!user || user.role === "admin") && ( + <> + + <> + + + {userGroups && + userGroups.length > 0 && + !values.is_public && ( +
+ + Select which User Groups should have access to + this Assistant. + +
+ {userGroups.map((userGroup) => { + const isSelected = values.groups.includes( + userGroup.id + ); + return ( + { + if (isSelected) { + setFieldValue( + "groups", + values.groups.filter( + (id) => id !== userGroup.id + ) + ); + } else { + setFieldValue("groups", [ + ...values.groups, + userGroup.id, + ]); + } + }} + > +
+ +
+ {userGroup.name} +
+
+
+ ); + })} +
+
+ )} + +
+ + + )} +