From 8403b9472295fecf12da84594959c1618a0029ac Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 22 Oct 2023 16:57:16 -0700 Subject: [PATCH] Default Personas to have Document Sets (#614) --- backend/danswer/chat/chat_llm.py | 35 +++++++++++++++++------ backend/danswer/chat/personas.py | 15 ++++++++++ backend/danswer/chat/personas.yaml | 17 +++++++++-- backend/danswer/db/chat.py | 25 ++++++++++++++++ backend/danswer/db/document_set.py | 33 +++++++++++++++++++++ backend/scripts/simulate_chat_frontend.py | 2 +- 6 files changed, 116 insertions(+), 11 deletions(-) diff --git a/backend/danswer/chat/chat_llm.py b/backend/danswer/chat/chat_llm.py index bea0b1e94..4a32e51cc 100644 --- a/backend/danswer/chat/chat_llm.py +++ b/backend/danswer/chat/chat_llm.py @@ -36,9 +36,10 @@ from danswer.llm.build import get_default_llm from danswer.llm.llm import LLM from danswer.llm.utils import get_default_llm_tokenizer from danswer.llm.utils import translate_danswer_msg_to_langchain -from danswer.search.access_filters import build_user_only_filters +from danswer.search.access_filters import build_access_filters_for_user from danswer.search.semantic_search import chunks_to_search_docs from danswer.search.semantic_search import retrieve_ranked_documents +from danswer.server.models import IndexFilters from danswer.server.models import RetrievalDocs from danswer.utils.logger import setup_logger from danswer.utils.text_processing import extract_embedded_json @@ -120,8 +121,7 @@ def danswer_chat_retrieval( query_message: ChatMessage, history: list[ChatMessage], llm: LLM, - user: User | None, - db_session: Session, + filters: IndexFilters, ) -> list[InferenceChunk]: if history: query_combination_msgs = build_combined_query(query_message, history) @@ -132,7 +132,7 @@ def danswer_chat_retrieval( # Good Debug/Breakpoint ranked_chunks, unranked_chunks = retrieve_ranked_documents( query=reworded_query, - filters=build_user_only_filters(user, db_session), + filters=filters, favor_recent=False, 
datastore=get_default_document_index(), ) @@ -322,14 +322,23 @@ def llm_contextual_chat_answer( # Be a little forgiving though, if we match yes, it's good enough retrieved_chunks: list[InferenceChunk] = [] if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower(): + user_acl_filters = build_access_filters_for_user(user, db_session) + doc_set_filter = [doc_set.name for doc_set in persona.document_sets] or None + final_filters = IndexFilters( + source_type=None, + document_set=doc_set_filter, + time_cutoff=None, + access_control_list=user_acl_filters, + ) + retrieved_chunks = danswer_chat_retrieval( query_message=last_message, history=previous_messages, llm=llm, - user=user, - db_session=db_session, + filters=final_filters, ) yield retrieved_chunks + tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks) last_user_msg_text = form_tool_less_followup_text( @@ -449,14 +458,24 @@ def llm_tools_enabled_chat_answer( retrieval_enabled and final_result.action.lower() == DANSWER_TOOL_NAME.lower() ): + user_acl_filters = build_access_filters_for_user(user, db_session) + doc_set_filter = [doc_set.name for doc_set in persona.document_sets] or None + + final_filters = IndexFilters( + source_type=None, + document_set=doc_set_filter, + time_cutoff=None, + access_control_list=user_acl_filters, + ) + retrieved_chunks = danswer_chat_retrieval( query_message=last_message, history=previous_messages, llm=llm, - user=user, - db_session=db_session, + filters=final_filters, ) yield retrieved_chunks + tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks) else: tool_result_str = call_tool(final_result) diff --git a/backend/danswer/chat/personas.py b/backend/danswer/chat/personas.py index 6d803ea61..bb8c6822a 100644 --- a/backend/danswer/chat/personas.py +++ b/backend/danswer/chat/personas.py @@ -6,7 +6,9 @@ from sqlalchemy.orm import Session from danswer.configs.app_configs import PERSONAS_YAML from danswer.db.chat import upsert_persona +from 
danswer.db.document_set import get_or_create_document_set_by_name from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.db.models import Persona from danswer.db.models import ToolInfo @@ -49,6 +51,18 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None: for persona in all_personas: tools = [validate_tool_info(tool) for tool in persona["tools"]] + doc_set_names = persona["document_sets"] + doc_sets: list[DocumentSetDBModel] | None = [ + get_or_create_document_set_by_name(db_session, name) + for name in doc_set_names + ] + + # Assume if user hasn't set any document sets for the persona, the user may want + # to later attach document sets to the persona manually, therefore, don't overwrite/reset + # the document sets for the persona + if not doc_sets: + doc_sets = None + upsert_persona( name=persona["name"], retrieval_enabled=persona.get("retrieval_enabled", True), @@ -62,5 +76,6 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None: tools=tools, hint_text=persona.get("hint"), default_persona=True, + document_sets=doc_sets, db_session=db_session, ) diff --git a/backend/danswer/chat/personas.yaml b/backend/danswer/chat/personas.yaml index ef29fe7eb..8041ed29b 100644 --- a/backend/danswer/chat/personas.yaml +++ b/backend/danswer/chat/personas.yaml @@ -5,17 +5,30 @@ personas: You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries. Your responses are as INFORMATIVE and DETAILED as possible. Cite relevant statements using the format [1], [2], etc to reference the document number, do not provide any links following the citation. + # Document Sets that this persona has access to, specified as a list of names here. 
+ # If left empty, the persona has access to all and only public docs + # If a document set with the given name exists, it will be attached to the persona + # If no document set with that name exists, one will be created as an empty document set with no connectors + # The admin can then use the UI to add new connectors to the document set + # Example: + # document_sets: + # - "HR Resources" + # - "Engineer Onboarding" + # - "Benefits" + document_sets: [] # Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled. retrieval_enabled: true # Inject a statement at the end of system prompt to inform the LLM of the current date/time # Format looks like: "October 16, 2023 14:30" datetime_aware: true + # Personas can be given tools for Agentifying Danswer; however, the tool call must be implemented in the code + # Once implemented, it can be given to personas via the config. # Example of adding tools, it must follow this structure: # tools: # - name: "Calculator" # description: "Use this tool to accurately process math equations, counting, etc." - # - name: "Current Time" - # description: "Call this to get the current date and time." + # - name: "Current Weather" + # description: "Call this to get the current weather info." tools: [] # Short tip to pass near the end of the prompt to emphasize some requirement hint: "Try to be as informative as possible!" 
diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index bf922dc80..3b0502c00 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -13,6 +13,7 @@ from danswer.configs.app_configs import HARD_DELETE_CHATS from danswer.configs.constants import MessageType from danswer.db.models import ChatMessage from danswer.db.models import ChatSession +from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.db.models import Persona from danswer.db.models import ToolInfo @@ -275,6 +276,21 @@ def fetch_default_persona_by_name( return result +def fetch_persona_by_name(persona_name: str, db_session: Session) -> Persona | None: + """Try to fetch a default persona by name first, + if not exist, try to find any persona with the name + Note that name is not guaranteed unique unless default is true""" + persona = fetch_default_persona_by_name(persona_name, db_session) + if persona is not None: + return persona + + stmt = select(Persona).where(Persona.name == persona_name) # noqa: E712 + result = db_session.execute(stmt).first() + if result: + return result[0] + return None + + def upsert_persona( name: str, retrieval_enabled: bool, @@ -285,6 +301,7 @@ def upsert_persona( db_session: Session, persona_id: int | None = None, default_persona: bool = False, + document_sets: list[DocumentSetDBModel] | None = None, commit: bool = True, ) -> Persona: persona = db_session.query(Persona).filter_by(id=persona_id).first() @@ -301,6 +318,13 @@ def upsert_persona( persona.tools = tools persona.hint_text = hint_text persona.default_persona = default_persona + + # Do not delete any associations manually added unless + # a new updated list is provided + if document_sets is not None: + persona.document_sets.clear() + persona.document_sets = document_sets + else: persona = Persona( name=name, @@ -310,6 +334,7 @@ def upsert_persona( tools=tools, hint_text=hint_text, default_persona=default_persona, + document_sets=document_sets if document_sets 
else [], ) db_session.add(persona) diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py index 3a1522602..328abb268 100644 --- a/backend/danswer/db/document_set.py +++ b/backend/danswer/db/document_set.py @@ -49,6 +49,14 @@ def get_document_set_by_id( ) +def get_document_set_by_name( + db_session: Session, document_set_name: str +) -> DocumentSetDBModel | None: + return db_session.scalar( + select(DocumentSetDBModel).where(DocumentSetDBModel.name == document_set_name) + ) + + def get_document_sets_by_ids( db_session: Session, document_set_ids: list[int] ) -> Sequence[DocumentSetDBModel]: @@ -363,3 +371,28 @@ def fetch_document_sets_for_documents( .group_by(Document.id) ) return db_session.execute(stmt).all() # type: ignore + + +def get_or_create_document_set_by_name( + db_session: Session, + document_set_name: str, + document_set_description: str = "Default Persona created Document-Set, " + "please update description", +) -> DocumentSetDBModel: + """This is used by the default personas which need to attach to document sets + on server startup""" + doc_set = get_document_set_by_name(db_session, document_set_name) + if doc_set is not None: + return doc_set + + new_doc_set = DocumentSetDBModel( + name=document_set_name, + description=document_set_description, + user_id=None, + is_up_to_date=True, + ) + + db_session.add(new_doc_set) + db_session.commit() + + return new_doc_set diff --git a/backend/scripts/simulate_chat_frontend.py b/backend/scripts/simulate_chat_frontend.py index e3a4064c4..8c49c2e04 100644 --- a/backend/scripts/simulate_chat_frontend.py +++ b/backend/scripts/simulate_chat_frontend.py @@ -90,7 +90,7 @@ if __name__ == "__main__": "-c", "--contextual", action="store_true", - help="If this flag is set, the chat is able to call tools.", + help="If this flag is set, the chat is able to use retrieval", ) args = parser.parse_args()