Default Personas to have Document Sets (#614)

This commit is contained in:
Yuhong Sun 2023-10-22 16:57:16 -07:00 committed by GitHub
parent 4fa96788f6
commit 8403b94722
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 116 additions and 11 deletions

View File

@ -36,9 +36,10 @@ from danswer.llm.build import get_default_llm
from danswer.llm.llm import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.search.access_filters import build_user_only_filters
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.server.models import IndexFilters
from danswer.server.models import RetrievalDocs
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import extract_embedded_json
@ -120,8 +121,7 @@ def danswer_chat_retrieval(
query_message: ChatMessage,
history: list[ChatMessage],
llm: LLM,
user: User | None,
db_session: Session,
filters: IndexFilters,
) -> list[InferenceChunk]:
if history:
query_combination_msgs = build_combined_query(query_message, history)
@ -132,7 +132,7 @@ def danswer_chat_retrieval(
# Good Debug/Breakpoint
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query=reworded_query,
filters=build_user_only_filters(user, db_session),
filters=filters,
favor_recent=False,
datastore=get_default_document_index(),
)
@ -322,14 +322,23 @@ def llm_contextual_chat_answer(
# Be a little forgiving though, if we match yes, it's good enough
retrieved_chunks: list[InferenceChunk] = []
if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
user_acl_filters = build_access_filters_for_user(user, db_session)
doc_set_filter = [doc_set.name for doc_set in persona.document_sets] or None
final_filters = IndexFilters(
source_type=None,
document_set=doc_set_filter,
time_cutoff=None,
access_control_list=user_acl_filters,
)
retrieved_chunks = danswer_chat_retrieval(
query_message=last_message,
history=previous_messages,
llm=llm,
user=user,
db_session=db_session,
filters=final_filters,
)
yield retrieved_chunks
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
last_user_msg_text = form_tool_less_followup_text(
@ -449,14 +458,24 @@ def llm_tools_enabled_chat_answer(
retrieval_enabled
and final_result.action.lower() == DANSWER_TOOL_NAME.lower()
):
user_acl_filters = build_access_filters_for_user(user, db_session)
doc_set_filter = [doc_set.name for doc_set in persona.document_sets] or None
final_filters = IndexFilters(
source_type=None,
document_set=doc_set_filter,
time_cutoff=None,
access_control_list=user_acl_filters,
)
retrieved_chunks = danswer_chat_retrieval(
query_message=last_message,
history=previous_messages,
llm=llm,
user=user,
db_session=db_session,
filters=final_filters,
)
yield retrieved_chunks
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
else:
tool_result_str = call_tool(final_result)

View File

@ -6,7 +6,9 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import PERSONAS_YAML
from danswer.db.chat import upsert_persona
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
from danswer.db.models import ToolInfo
@ -49,6 +51,18 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
for persona in all_personas:
tools = [validate_tool_info(tool) for tool in persona["tools"]]
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] | None = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
if not doc_sets:
doc_sets = None
upsert_persona(
name=persona["name"],
retrieval_enabled=persona.get("retrieval_enabled", True),
@ -62,5 +76,6 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
tools=tools,
hint_text=persona.get("hint"),
default_persona=True,
document_sets=doc_sets,
db_session=db_session,
)

View File

@ -5,17 +5,30 @@ personas:
You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
Your responses are as INFORMATIVE and DETAILED as possible.
Cite relevant statements using the format [1], [2], etc to reference the document number, do not provide any links following the citation.
# Document Sets that this persona has access to, specified as a list of names here.
# If left empty, the persona has access to all and only public docs
# If the document set by the name exists, it will be attached to the persona
# If the document set by the name does not exist, it will be created as an empty document set with no connectors
# The admin can then use the UI to add new connectors to the document set
# Example:
# document_sets:
# - "HR Resources"
# - "Engineer Onboarding"
# - "Benefits"
document_sets: []
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
retrieval_enabled: true
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
# Format looks like: "October 16, 2023 14:30"
datetime_aware: true
# Personas can be given tools for Agentifying Danswer; however, each tool call must be implemented in the code.
# Once implemented, the tool can be given to personas via the config.
# Example of adding tools, it must follow this structure:
# tools:
# - name: "Calculator"
# description: "Use this tool to accurately process math equations, counting, etc."
# - name: "Current Time"
# description: "Call this to get the current date and time."
# - name: "Current Weather"
# description: "Call this to get the current weather info."
tools: []
# Short tip to pass near the end of the prompt to emphasize some requirement
hint: "Try to be as informative as possible!"

View File

@ -13,6 +13,7 @@ from danswer.configs.app_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
from danswer.db.models import ToolInfo
@ -275,6 +276,21 @@ def fetch_default_persona_by_name(
return result
def fetch_persona_by_name(persona_name: str, db_session: Session) -> Persona | None:
    """Look up a persona by name, preferring the default persona.

    The default persona carrying this name is returned when one exists.
    Otherwise the first persona of any kind with a matching name is returned,
    or None when there is no match at all. Names are only guaranteed to be
    unique among default personas, so the fallback may be one of several rows.
    """
    default_match = fetch_default_persona_by_name(persona_name, db_session)
    if default_match is None:
        # No default persona with this name; fall back to any persona sharing it.
        name_query = select(Persona).where(Persona.name == persona_name)
        first_row = db_session.execute(name_query).first()
        return first_row[0] if first_row else None

    return default_match
def upsert_persona(
name: str,
retrieval_enabled: bool,
@ -285,6 +301,7 @@ def upsert_persona(
db_session: Session,
persona_id: int | None = None,
default_persona: bool = False,
document_sets: list[DocumentSetDBModel] | None = None,
commit: bool = True,
) -> Persona:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
@ -301,6 +318,13 @@ def upsert_persona(
persona.tools = tools
persona.hint_text = hint_text
persona.default_persona = default_persona
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
persona.document_sets.clear()
persona.document_sets = document_sets
else:
persona = Persona(
name=name,
@ -310,6 +334,7 @@ def upsert_persona(
tools=tools,
hint_text=hint_text,
default_persona=default_persona,
document_sets=document_sets if document_sets else [],
)
db_session.add(persona)

View File

@ -49,6 +49,14 @@ def get_document_set_by_id(
)
def get_document_set_by_name(
    db_session: Session, document_set_name: str
) -> DocumentSetDBModel | None:
    """Return the document set named `document_set_name`, or None if there is none."""
    name_matches = DocumentSetDBModel.name == document_set_name
    lookup_stmt = select(DocumentSetDBModel).where(name_matches)
    return db_session.scalar(lookup_stmt)
def get_document_sets_by_ids(
db_session: Session, document_set_ids: list[int]
) -> Sequence[DocumentSetDBModel]:
@ -363,3 +371,28 @@ def fetch_document_sets_for_documents(
.group_by(Document.id)
)
return db_session.execute(stmt).all() # type: ignore
def get_or_create_document_set_by_name(
    db_session: Session,
    document_set_name: str,
    document_set_description: str = "Default Persona created Document-Set, "
    "please update description",
) -> DocumentSetDBModel:
    """Fetch the document set with the given name, creating it when absent.

    Default personas rely on this at server startup to obtain the document
    sets they attach to. A newly created set has no owning user, is marked as
    up to date, and is committed before being returned.
    """
    existing = get_document_set_by_name(db_session, document_set_name)
    if existing is None:
        # Nothing with this name yet: persist a placeholder set right away.
        created = DocumentSetDBModel(
            name=document_set_name,
            description=document_set_description,
            user_id=None,
            is_up_to_date=True,
        )
        db_session.add(created)
        db_session.commit()
        return created

    return existing

View File

@ -90,7 +90,7 @@ if __name__ == "__main__":
"-c",
"--contextual",
action="store_true",
help="If this flag is set, the chat is able to call tools.",
help="If this flag is set, the chat is able to use retrieval",
)
args = parser.parse_args()