mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-20 21:33:56 +02:00
Personas to have option to be aware of current date and time (#582)
This commit is contained in:
@@ -0,0 +1,37 @@
|
|||||||
|
"""Persona Datetime Aware
|
||||||
|
|
||||||
|
Revision ID: 30c1d5744104
|
||||||
|
Revises: 7f99be1cb9f5
|
||||||
|
Create Date: 2023-10-16 23:21:01.283424
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "30c1d5744104"
|
||||||
|
down_revision = "7f99be1cb9f5"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column("persona", sa.Column("datetime_aware", sa.Boolean(), nullable=True))
|
||||||
|
op.execute("UPDATE persona SET datetime_aware = TRUE")
|
||||||
|
op.alter_column("persona", "datetime_aware", nullable=False)
|
||||||
|
op.create_index(
|
||||||
|
"_default_persona_name_idx",
|
||||||
|
"persona",
|
||||||
|
["name"],
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=sa.text("default_persona = true"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index(
|
||||||
|
"_default_persona_name_idx",
|
||||||
|
table_name="persona",
|
||||||
|
postgresql_where=sa.text("default_persona = true"),
|
||||||
|
)
|
||||||
|
op.drop_column("persona", "datetime_aware")
|
@@ -18,6 +18,7 @@ from danswer.chat.chat_prompts import form_user_prompt_text
|
|||||||
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
|
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
|
||||||
from danswer.chat.chat_prompts import REQUIRE_DANSWER_SYSTEM_MSG
|
from danswer.chat.chat_prompts import REQUIRE_DANSWER_SYSTEM_MSG
|
||||||
from danswer.chat.chat_prompts import YES_SEARCH
|
from danswer.chat.chat_prompts import YES_SEARCH
|
||||||
|
from danswer.chat.personas import build_system_text_from_persona
|
||||||
from danswer.chat.tools import call_tool
|
from danswer.chat.tools import call_tool
|
||||||
from danswer.chunking.models import InferenceChunk
|
from danswer.chunking.models import InferenceChunk
|
||||||
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
|
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
|
||||||
@@ -338,7 +339,7 @@ def llm_contextual_chat_answer(
|
|||||||
last_user_msg_tokens = len(tokenizer(final_query_text))
|
last_user_msg_tokens = len(tokenizer(final_query_text))
|
||||||
last_user_msg = HumanMessage(content=final_query_text)
|
last_user_msg = HumanMessage(content=final_query_text)
|
||||||
|
|
||||||
system_text = persona.system_text
|
system_text = build_system_text_from_persona(persona)
|
||||||
system_msg = SystemMessage(content=system_text) if system_text else None
|
system_msg = SystemMessage(content=system_text) if system_text else None
|
||||||
system_tokens = len(tokenizer(system_text)) if system_text else 0
|
system_tokens = len(tokenizer(system_text)) if system_text else 0
|
||||||
|
|
||||||
@@ -351,6 +352,7 @@ def llm_contextual_chat_answer(
|
|||||||
final_msg_token_count=last_user_msg_tokens,
|
final_msg_token_count=last_user_msg_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Good Debug/Breakpoint
|
||||||
tokens = llm.stream(prompt)
|
tokens = llm.stream(prompt)
|
||||||
links = [
|
links = [
|
||||||
chunk.source_links[0] if chunk.source_links else None
|
chunk.source_links[0] if chunk.source_links else None
|
||||||
@@ -371,7 +373,7 @@ def llm_tools_enabled_chat_answer(
|
|||||||
tokenizer: Callable,
|
tokenizer: Callable,
|
||||||
) -> Iterator[str | list[InferenceChunk]]:
|
) -> Iterator[str | list[InferenceChunk]]:
|
||||||
retrieval_enabled = persona.retrieval_enabled
|
retrieval_enabled = persona.retrieval_enabled
|
||||||
system_text = persona.system_text
|
system_text = build_system_text_from_persona(persona)
|
||||||
hint_text = persona.hint_text
|
hint_text = persona.hint_text
|
||||||
tool_text = form_tool_section_text(persona.tools, persona.retrieval_enabled)
|
tool_text = form_tool_section_text(persona.tools, persona.retrieval_enabled)
|
||||||
|
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
@@ -6,9 +7,25 @@ from sqlalchemy.orm import Session
|
|||||||
from danswer.configs.app_configs import PERSONAS_YAML
|
from danswer.configs.app_configs import PERSONAS_YAML
|
||||||
from danswer.db.chat import upsert_persona
|
from danswer.db.chat import upsert_persona
|
||||||
from danswer.db.engine import get_sqlalchemy_engine
|
from danswer.db.engine import get_sqlalchemy_engine
|
||||||
|
from danswer.db.models import Persona
|
||||||
from danswer.db.models import ToolInfo
|
from danswer.db.models import ToolInfo
|
||||||
|
|
||||||
|
|
||||||
|
def build_system_text_from_persona(persona: Persona) -> str | None:
|
||||||
|
text = (persona.system_text or "").strip()
|
||||||
|
if persona.datetime_aware:
|
||||||
|
current_datetime = datetime.now()
|
||||||
|
# Format looks like: "October 16, 2023 14:30"
|
||||||
|
formatted_datetime = current_datetime.strftime("%B %d, %Y %H:%M")
|
||||||
|
|
||||||
|
text += (
|
||||||
|
"\n\nAdditional Information:\n"
|
||||||
|
f"\t- The current date and time is {formatted_datetime}."
|
||||||
|
)
|
||||||
|
|
||||||
|
return text or None
|
||||||
|
|
||||||
|
|
||||||
def validate_tool_info(item: Any) -> ToolInfo:
|
def validate_tool_info(item: Any) -> ToolInfo:
|
||||||
if not (
|
if not (
|
||||||
isinstance(item, dict)
|
isinstance(item, dict)
|
||||||
@@ -33,12 +50,17 @@ def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
|
|||||||
tools = [validate_tool_info(tool) for tool in persona["tools"]]
|
tools = [validate_tool_info(tool) for tool in persona["tools"]]
|
||||||
|
|
||||||
upsert_persona(
|
upsert_persona(
|
||||||
persona_id=persona["id"],
|
|
||||||
name=persona["name"],
|
name=persona["name"],
|
||||||
retrieval_enabled=persona["retrieval_enabled"],
|
retrieval_enabled=persona.get("retrieval_enabled", True),
|
||||||
system_text=persona["system"],
|
# Default to knowing the date/time if not specified, however if there is no
|
||||||
|
# system prompt, do not interfere with the flow by adding a
|
||||||
|
# system prompt that is ONLY the date info, this would likely not be useful
|
||||||
|
datetime_aware=persona.get(
|
||||||
|
"datetime_aware", bool(persona.get("system"))
|
||||||
|
),
|
||||||
|
system_text=persona.get("system"),
|
||||||
tools=tools,
|
tools=tools,
|
||||||
hint_text=persona["hint"],
|
hint_text=persona.get("hint"),
|
||||||
default_persona=True,
|
default_persona=True,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
)
|
)
|
||||||
|
@@ -1,6 +1,5 @@
|
|||||||
personas:
|
personas:
|
||||||
- id: 1
|
- name: "Danswer"
|
||||||
name: "Danswer"
|
|
||||||
system: |
|
system: |
|
||||||
You are a question answering system that is constantly learning and improving.
|
You are a question answering system that is constantly learning and improving.
|
||||||
You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
|
You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
|
||||||
@@ -8,6 +7,9 @@ personas:
|
|||||||
Cite relevant statements using the format [1], [2], etc to reference the document number, do not provide any links following the citation.
|
Cite relevant statements using the format [1], [2], etc to reference the document number, do not provide any links following the citation.
|
||||||
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
|
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
|
||||||
retrieval_enabled: true
|
retrieval_enabled: true
|
||||||
|
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
|
||||||
|
# Format looks like: "October 16, 2023 14:30"
|
||||||
|
datetime_aware: true
|
||||||
# Example of adding tools, it must follow this structure:
|
# Example of adding tools, it must follow this structure:
|
||||||
# tools:
|
# tools:
|
||||||
# - name: "Calculator"
|
# - name: "Calculator"
|
||||||
|
@@ -265,22 +265,38 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
|
|||||||
return persona
|
return persona
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_default_persona_by_name(
|
||||||
|
persona_name: str, db_session: Session
|
||||||
|
) -> Persona | None:
|
||||||
|
stmt = select(Persona).where(
|
||||||
|
Persona.name == persona_name, Persona.default_persona == True # noqa: E712
|
||||||
|
)
|
||||||
|
result = db_session.execute(stmt).scalar_one_or_none()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def upsert_persona(
|
def upsert_persona(
|
||||||
persona_id: int | None,
|
|
||||||
name: str,
|
name: str,
|
||||||
retrieval_enabled: bool,
|
retrieval_enabled: bool,
|
||||||
|
datetime_aware: bool,
|
||||||
system_text: str | None,
|
system_text: str | None,
|
||||||
tools: list[ToolInfo] | None,
|
tools: list[ToolInfo] | None,
|
||||||
hint_text: str | None,
|
hint_text: str | None,
|
||||||
default_persona: bool,
|
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
|
persona_id: int | None = None,
|
||||||
|
default_persona: bool = False,
|
||||||
commit: bool = True,
|
commit: bool = True,
|
||||||
) -> Persona:
|
) -> Persona:
|
||||||
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
||||||
|
|
||||||
|
# Default personas are defined via yaml files at deployment time
|
||||||
|
if persona is None and default_persona:
|
||||||
|
persona = fetch_default_persona_by_name(name, db_session)
|
||||||
|
|
||||||
if persona:
|
if persona:
|
||||||
persona.name = name
|
persona.name = name
|
||||||
persona.retrieval_enabled = retrieval_enabled
|
persona.retrieval_enabled = retrieval_enabled
|
||||||
|
persona.datetime_aware = datetime_aware
|
||||||
persona.system_text = system_text
|
persona.system_text = system_text
|
||||||
persona.tools = tools
|
persona.tools = tools
|
||||||
persona.hint_text = hint_text
|
persona.hint_text = hint_text
|
||||||
@@ -289,6 +305,7 @@ def upsert_persona(
|
|||||||
persona = Persona(
|
persona = Persona(
|
||||||
name=name,
|
name=name,
|
||||||
retrieval_enabled=retrieval_enabled,
|
retrieval_enabled=retrieval_enabled,
|
||||||
|
datetime_aware=datetime_aware,
|
||||||
system_text=system_text,
|
system_text=system_text,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
hint_text=hint_text,
|
hint_text=hint_text,
|
||||||
|
@@ -459,10 +459,12 @@ class ToolInfo(TypedDict):
|
|||||||
class Persona(Base):
|
class Persona(Base):
|
||||||
# TODO introduce user and group ownership for personas
|
# TODO introduce user and group ownership for personas
|
||||||
__tablename__ = "persona"
|
__tablename__ = "persona"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
name: Mapped[str] = mapped_column(String)
|
name: Mapped[str] = mapped_column(String)
|
||||||
# Danswer retrieval, treated as a special tool
|
# Danswer retrieval, treated as a special tool
|
||||||
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
|
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
tools: Mapped[list[ToolInfo] | None] = mapped_column(
|
tools: Mapped[list[ToolInfo] | None] = mapped_column(
|
||||||
postgresql.JSONB(), nullable=True
|
postgresql.JSONB(), nullable=True
|
||||||
@@ -480,6 +482,16 @@ class Persona(Base):
|
|||||||
back_populates="personas",
|
back_populates="personas",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Default personas loaded via yaml cannot have the same name
|
||||||
|
__table_args__ = (
|
||||||
|
Index(
|
||||||
|
"_default_persona_name_idx",
|
||||||
|
"name",
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=(default_persona == True), # noqa: E712
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(Base):
|
class ChatMessage(Base):
|
||||||
__tablename__ = "chat_message"
|
__tablename__ = "chat_message"
|
||||||
|
Reference in New Issue
Block a user