Personas to have option to be aware of current date and time (#582)

This commit is contained in:
Yuhong Sun 2023-10-16 23:42:39 -07:00 committed by GitHub
parent 37e9ccf864
commit bf5844578c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 102 additions and 10 deletions

View File

@ -0,0 +1,37 @@
"""Persona Datetime Aware
Revision ID: 30c1d5744104
Revises: 7f99be1cb9f5
Create Date: 2023-10-16 23:21:01.283424
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "30c1d5744104"
down_revision = "7f99be1cb9f5"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("persona", sa.Column("datetime_aware", sa.Boolean(), nullable=True))
op.execute("UPDATE persona SET datetime_aware = TRUE")
op.alter_column("persona", "datetime_aware", nullable=False)
op.create_index(
"_default_persona_name_idx",
"persona",
["name"],
unique=True,
postgresql_where=sa.text("default_persona = true"),
)
def downgrade() -> None:
op.drop_index(
"_default_persona_name_idx",
table_name="persona",
postgresql_where=sa.text("default_persona = true"),
)
op.drop_column("persona", "datetime_aware")

View File

@ -18,6 +18,7 @@ from danswer.chat.chat_prompts import form_user_prompt_text
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.chat.chat_prompts import REQUIRE_DANSWER_SYSTEM_MSG
from danswer.chat.chat_prompts import YES_SEARCH
from danswer.chat.personas import build_system_text_from_persona
from danswer.chat.tools import call_tool
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
@ -338,7 +339,7 @@ def llm_contextual_chat_answer(
last_user_msg_tokens = len(tokenizer(final_query_text))
last_user_msg = HumanMessage(content=final_query_text)
system_text = persona.system_text
system_text = build_system_text_from_persona(persona)
system_msg = SystemMessage(content=system_text) if system_text else None
system_tokens = len(tokenizer(system_text)) if system_text else 0
@ -351,6 +352,7 @@ def llm_contextual_chat_answer(
final_msg_token_count=last_user_msg_tokens,
)
# Good Debug/Breakpoint
tokens = llm.stream(prompt)
links = [
chunk.source_links[0] if chunk.source_links else None
@ -371,7 +373,7 @@ def llm_tools_enabled_chat_answer(
tokenizer: Callable,
) -> Iterator[str | list[InferenceChunk]]:
retrieval_enabled = persona.retrieval_enabled
system_text = persona.system_text
system_text = build_system_text_from_persona(persona)
hint_text = persona.hint_text
tool_text = form_tool_section_text(persona.tools, persona.retrieval_enabled)

View File

@ -1,3 +1,4 @@
from datetime import datetime
from typing import Any
import yaml
@ -6,9 +7,25 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import PERSONAS_YAML
from danswer.db.chat import upsert_persona
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Persona
from danswer.db.models import ToolInfo
def build_system_text_from_persona(persona: Persona) -> str | None:
text = (persona.system_text or "").strip()
if persona.datetime_aware:
current_datetime = datetime.now()
# Format looks like: "October 16, 2023 14:30"
formatted_datetime = current_datetime.strftime("%B %d, %Y %H:%M")
text += (
"\n\nAdditional Information:\n"
f"\t- The current date and time is {formatted_datetime}."
)
return text or None
def validate_tool_info(item: Any) -> ToolInfo:
if not (
isinstance(item, dict)
@ -33,12 +50,17 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
tools = [validate_tool_info(tool) for tool in persona["tools"]]
upsert_persona(
persona_id=persona["id"],
name=persona["name"],
retrieval_enabled=persona["retrieval_enabled"],
system_text=persona["system"],
retrieval_enabled=persona.get("retrieval_enabled", True),
# Default to knowing the date/time if not specified, however if there is no
# system prompt, do not interfere with the flow by adding a
# system prompt that is ONLY the date info, this would likely not be useful
datetime_aware=persona.get(
"datetime_aware", bool(persona.get("system"))
),
system_text=persona.get("system"),
tools=tools,
hint_text=persona["hint"],
hint_text=persona.get("hint"),
default_persona=True,
db_session=db_session,
)

View File

@ -1,6 +1,5 @@
personas:
- id: 1
name: "Danswer"
- name: "Danswer"
system: |
You are a question answering system that is constantly learning and improving.
You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
@ -8,6 +7,9 @@ personas:
Cite relevant statements using the format [1], [2], etc to reference the document number, do not provide any links following the citation.
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
retrieval_enabled: true
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
# Format looks like: "October 16, 2023 14:30"
datetime_aware: true
# Example of adding tools, it must follow this structure:
# tools:
# - name: "Calculator"

View File

@ -265,22 +265,38 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
return persona
def fetch_default_persona_by_name(
persona_name: str, db_session: Session
) -> Persona | None:
stmt = select(Persona).where(
Persona.name == persona_name, Persona.default_persona == True # noqa: E712
)
result = db_session.execute(stmt).scalar_one_or_none()
return result
def upsert_persona(
persona_id: int | None,
name: str,
retrieval_enabled: bool,
datetime_aware: bool,
system_text: str | None,
tools: list[ToolInfo] | None,
hint_text: str | None,
default_persona: bool,
db_session: Session,
persona_id: int | None = None,
default_persona: bool = False,
commit: bool = True,
) -> Persona:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
# Default personas are defined via yaml files at deployment time
if persona is None and default_persona:
persona = fetch_default_persona_by_name(name, db_session)
if persona:
persona.name = name
persona.retrieval_enabled = retrieval_enabled
persona.datetime_aware = datetime_aware
persona.system_text = system_text
persona.tools = tools
persona.hint_text = hint_text
@ -289,6 +305,7 @@ def upsert_persona(
persona = Persona(
name=name,
retrieval_enabled=retrieval_enabled,
datetime_aware=datetime_aware,
system_text=system_text,
tools=tools,
hint_text=hint_text,

View File

@ -459,10 +459,12 @@ class ToolInfo(TypedDict):
class Persona(Base):
# TODO introduce user and group ownership for personas
__tablename__ = "persona"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String)
# Danswer retrieval, treated as a special tool
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
tools: Mapped[list[ToolInfo] | None] = mapped_column(
postgresql.JSONB(), nullable=True
@ -480,6 +482,16 @@ class Persona(Base):
back_populates="personas",
)
# Default personas loaded via yaml cannot have the same name
__table_args__ = (
Index(
"_default_persona_name_idx",
"name",
unique=True,
postgresql_where=(default_persona == True), # noqa: E712
),
)
class ChatMessage(Base):
__tablename__ = "chat_message"