mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-01 00:18:18 +02:00
Personas to have option to be aware of current date and time (#582)
This commit is contained in:
parent
37e9ccf864
commit
bf5844578c
@ -0,0 +1,37 @@
|
||||
"""Persona Datetime Aware
|
||||
|
||||
Revision ID: 30c1d5744104
|
||||
Revises: 7f99be1cb9f5
|
||||
Create Date: 2023-10-16 23:21:01.283424
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "30c1d5744104"
|
||||
down_revision = "7f99be1cb9f5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("datetime_aware", sa.Boolean(), nullable=True))
|
||||
op.execute("UPDATE persona SET datetime_aware = TRUE")
|
||||
op.alter_column("persona", "datetime_aware", nullable=False)
|
||||
op.create_index(
|
||||
"_default_persona_name_idx",
|
||||
"persona",
|
||||
["name"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("default_persona = true"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"_default_persona_name_idx",
|
||||
table_name="persona",
|
||||
postgresql_where=sa.text("default_persona = true"),
|
||||
)
|
||||
op.drop_column("persona", "datetime_aware")
|
@ -18,6 +18,7 @@ from danswer.chat.chat_prompts import form_user_prompt_text
|
||||
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
|
||||
from danswer.chat.chat_prompts import REQUIRE_DANSWER_SYSTEM_MSG
|
||||
from danswer.chat.chat_prompts import YES_SEARCH
|
||||
from danswer.chat.personas import build_system_text_from_persona
|
||||
from danswer.chat.tools import call_tool
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
|
||||
@ -338,7 +339,7 @@ def llm_contextual_chat_answer(
|
||||
last_user_msg_tokens = len(tokenizer(final_query_text))
|
||||
last_user_msg = HumanMessage(content=final_query_text)
|
||||
|
||||
system_text = persona.system_text
|
||||
system_text = build_system_text_from_persona(persona)
|
||||
system_msg = SystemMessage(content=system_text) if system_text else None
|
||||
system_tokens = len(tokenizer(system_text)) if system_text else 0
|
||||
|
||||
@ -351,6 +352,7 @@ def llm_contextual_chat_answer(
|
||||
final_msg_token_count=last_user_msg_tokens,
|
||||
)
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
tokens = llm.stream(prompt)
|
||||
links = [
|
||||
chunk.source_links[0] if chunk.source_links else None
|
||||
@ -371,7 +373,7 @@ def llm_tools_enabled_chat_answer(
|
||||
tokenizer: Callable,
|
||||
) -> Iterator[str | list[InferenceChunk]]:
|
||||
retrieval_enabled = persona.retrieval_enabled
|
||||
system_text = persona.system_text
|
||||
system_text = build_system_text_from_persona(persona)
|
||||
hint_text = persona.hint_text
|
||||
tool_text = form_tool_section_text(persona.tools, persona.retrieval_enabled)
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
@ -6,9 +7,25 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.app_configs import PERSONAS_YAML
|
||||
from danswer.db.chat import upsert_persona
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import ToolInfo
|
||||
|
||||
|
||||
def build_system_text_from_persona(persona: Persona) -> str | None:
|
||||
text = (persona.system_text or "").strip()
|
||||
if persona.datetime_aware:
|
||||
current_datetime = datetime.now()
|
||||
# Format looks like: "October 16, 2023 14:30"
|
||||
formatted_datetime = current_datetime.strftime("%B %d, %Y %H:%M")
|
||||
|
||||
text += (
|
||||
"\n\nAdditional Information:\n"
|
||||
f"\t- The current date and time is {formatted_datetime}."
|
||||
)
|
||||
|
||||
return text or None
|
||||
|
||||
|
||||
def validate_tool_info(item: Any) -> ToolInfo:
|
||||
if not (
|
||||
isinstance(item, dict)
|
||||
@ -33,12 +50,17 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
|
||||
tools = [validate_tool_info(tool) for tool in persona["tools"]]
|
||||
|
||||
upsert_persona(
|
||||
persona_id=persona["id"],
|
||||
name=persona["name"],
|
||||
retrieval_enabled=persona["retrieval_enabled"],
|
||||
system_text=persona["system"],
|
||||
retrieval_enabled=persona.get("retrieval_enabled", True),
|
||||
# Default to knowing the date/time if not specified, however if there is no
|
||||
# system prompt, do not interfere with the flow by adding a
|
||||
# system prompt that is ONLY the date info, this would likely not be useful
|
||||
datetime_aware=persona.get(
|
||||
"datetime_aware", bool(persona.get("system"))
|
||||
),
|
||||
system_text=persona.get("system"),
|
||||
tools=tools,
|
||||
hint_text=persona["hint"],
|
||||
hint_text=persona.get("hint"),
|
||||
default_persona=True,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
@ -1,6 +1,5 @@
|
||||
personas:
|
||||
- id: 1
|
||||
name: "Danswer"
|
||||
- name: "Danswer"
|
||||
system: |
|
||||
You are a question answering system that is constantly learning and improving.
|
||||
You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
|
||||
@ -8,6 +7,9 @@ personas:
|
||||
Cite relevant statements using the format [1], [2], etc to reference the document number, do not provide any links following the citation.
|
||||
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
|
||||
retrieval_enabled: true
|
||||
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
|
||||
# Format looks like: "October 16, 2023 14:30"
|
||||
datetime_aware: true
|
||||
# Example of adding tools, it must follow this structure:
|
||||
# tools:
|
||||
# - name: "Calculator"
|
||||
|
@ -265,22 +265,38 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
|
||||
return persona
|
||||
|
||||
|
||||
def fetch_default_persona_by_name(
|
||||
persona_name: str, db_session: Session
|
||||
) -> Persona | None:
|
||||
stmt = select(Persona).where(
|
||||
Persona.name == persona_name, Persona.default_persona == True # noqa: E712
|
||||
)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
return result
|
||||
|
||||
|
||||
def upsert_persona(
|
||||
persona_id: int | None,
|
||||
name: str,
|
||||
retrieval_enabled: bool,
|
||||
datetime_aware: bool,
|
||||
system_text: str | None,
|
||||
tools: list[ToolInfo] | None,
|
||||
hint_text: str | None,
|
||||
default_persona: bool,
|
||||
db_session: Session,
|
||||
persona_id: int | None = None,
|
||||
default_persona: bool = False,
|
||||
commit: bool = True,
|
||||
) -> Persona:
|
||||
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
||||
|
||||
# Default personas are defined via yaml files at deployment time
|
||||
if persona is None and default_persona:
|
||||
persona = fetch_default_persona_by_name(name, db_session)
|
||||
|
||||
if persona:
|
||||
persona.name = name
|
||||
persona.retrieval_enabled = retrieval_enabled
|
||||
persona.datetime_aware = datetime_aware
|
||||
persona.system_text = system_text
|
||||
persona.tools = tools
|
||||
persona.hint_text = hint_text
|
||||
@ -289,6 +305,7 @@ def upsert_persona(
|
||||
persona = Persona(
|
||||
name=name,
|
||||
retrieval_enabled=retrieval_enabled,
|
||||
datetime_aware=datetime_aware,
|
||||
system_text=system_text,
|
||||
tools=tools,
|
||||
hint_text=hint_text,
|
||||
|
@ -459,10 +459,12 @@ class ToolInfo(TypedDict):
|
||||
class Persona(Base):
|
||||
# TODO introduce user and group ownership for personas
|
||||
__tablename__ = "persona"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
# Danswer retrieval, treated as a special tool
|
||||
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
tools: Mapped[list[ToolInfo] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
@ -480,6 +482,16 @@ class Persona(Base):
|
||||
back_populates="personas",
|
||||
)
|
||||
|
||||
# Default personas loaded via yaml cannot have the same name
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"_default_persona_name_idx",
|
||||
"name",
|
||||
unique=True,
|
||||
postgresql_where=(default_persona == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(Base):
|
||||
__tablename__ = "chat_message"
|
||||
|
Loading…
x
Reference in New Issue
Block a user