mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-28 12:58:41 +02:00
Add model choice to Persona
This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
"""Add llm_model_version_override to Persona
|
||||
|
||||
Revision ID: baf71f781b9e
|
||||
Revises: 50b683a8295c
|
||||
Create Date: 2023-12-06 21:56:50.286158
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "baf71f781b9e"
|
||||
down_revision = "50b683a8295c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("llm_model_version_override", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "llm_model_version_override")
|
@@ -324,6 +324,7 @@ def upsert_persona(
|
||||
persona_id: int | None = None,
|
||||
default_persona: bool = False,
|
||||
document_sets: list[DocumentSetDBModel] | None = None,
|
||||
llm_model_version_override: str | None = None,
|
||||
commit: bool = True,
|
||||
overwrite_duplicate_named_persona: bool = False,
|
||||
) -> Persona:
|
||||
@@ -355,6 +356,7 @@ def upsert_persona(
|
||||
persona.num_chunks = num_chunks
|
||||
persona.apply_llm_relevance_filter = apply_llm_relevance_filter
|
||||
persona.default_persona = default_persona
|
||||
persona.llm_model_version_override = llm_model_version_override
|
||||
|
||||
# Do not delete any associations manually added unless
|
||||
# a new updated list is provided
|
||||
@@ -375,6 +377,7 @@ def upsert_persona(
|
||||
apply_llm_relevance_filter=apply_llm_relevance_filter,
|
||||
default_persona=default_persona,
|
||||
document_sets=document_sets if document_sets else [],
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
)
|
||||
db_session.add(persona)
|
||||
|
||||
|
@@ -603,6 +603,13 @@ class Persona(Base):
|
||||
apply_llm_relevance_filter: Mapped[bool | None] = mapped_column(
|
||||
Boolean, nullable=True
|
||||
)
|
||||
# allows the Persona to specify a different LLM version than is controlled
|
||||
# globablly via env variables. For flexibility, validity is not currently enforced
|
||||
# NOTE: only is applied on the actual response generation - is not used for things like
|
||||
# auto-detected time filters, relevance filters, etc.
|
||||
llm_model_version_override: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
)
|
||||
# Default personas are configured via backend during deployment
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
@@ -54,7 +54,11 @@ def get_qa_model_for_persona(
|
||||
timeout: int = QA_TIMEOUT,
|
||||
) -> QAModel:
|
||||
return QABlock(
|
||||
llm=get_default_llm(api_key=api_key, timeout=timeout),
|
||||
llm=get_default_llm(
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
gen_ai_model_version_override=persona.llm_model_version_override,
|
||||
),
|
||||
qa_handler=PersonaBasedQAHandler(
|
||||
system_prompt=persona.system_text or "", task_prompt=persona.hint_text or ""
|
||||
),
|
||||
|
@@ -14,10 +14,16 @@ def get_default_llm(
|
||||
api_key: str | None = None,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
use_fast_llm: bool = False,
|
||||
gen_ai_model_version_override: str | None = None,
|
||||
) -> LLM:
|
||||
"""A single place to fetch the configured LLM for Danswer
|
||||
Also allows overriding certain LLM defaults"""
|
||||
model_version = FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION
|
||||
if gen_ai_model_version_override:
|
||||
model_version = gen_ai_model_version_override
|
||||
else:
|
||||
model_version = (
|
||||
FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION
|
||||
)
|
||||
if api_key is None:
|
||||
api_key = get_gen_ai_api_key()
|
||||
|
||||
|
@@ -5,6 +5,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.db.chat import fetch_persona_by_id
|
||||
from danswer.db.chat import fetch_personas
|
||||
from danswer.db.chat import mark_persona_as_deleted
|
||||
@@ -50,6 +52,7 @@ def create_persona(
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
apply_llm_relevance_filter=create_persona_request.apply_llm_relevance_filter,
|
||||
document_sets=document_sets,
|
||||
llm_model_version_override=create_persona_request.llm_model_version_override,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to update persona")
|
||||
@@ -84,6 +87,7 @@ def update_persona(
|
||||
num_chunks=update_persona_request.num_chunks,
|
||||
apply_llm_relevance_filter=update_persona_request.apply_llm_relevance_filter,
|
||||
document_sets=document_sets,
|
||||
llm_model_version_override=update_persona_request.llm_model_version_override,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
@@ -134,3 +138,47 @@ def build_final_template_prompt(
|
||||
system_prompt=system_prompt, task_prompt=task_prompt
|
||||
).build_dummy_prompt()
|
||||
)
|
||||
|
||||
|
||||
"""Utility endpoints for selecting which model to use for a persona.
|
||||
Putting here for now, since we have no other flows which use this."""
|
||||
|
||||
GPT_4_MODEL_VERSIONS = [
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
]
|
||||
GPT_3_5_TURBO_MODEL_VERSIONS = [
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-0301",
|
||||
]
|
||||
|
||||
|
||||
@router.get("/persona-utils/list-available-models")
|
||||
def list_available_model_versions(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[str]:
|
||||
# currently only support selecting different models for OpenAI
|
||||
if GEN_AI_MODEL_PROVIDER != "openai":
|
||||
return []
|
||||
|
||||
return GPT_4_MODEL_VERSIONS + GPT_3_5_TURBO_MODEL_VERSIONS
|
||||
|
||||
|
||||
@router.get("/persona-utils/default-model")
|
||||
def get_default_model(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> str:
|
||||
# currently only support selecting different models for OpenAI
|
||||
if GEN_AI_MODEL_PROVIDER != "openai":
|
||||
return ""
|
||||
|
||||
return GEN_AI_MODEL_VERSION
|
||||
|
@@ -12,6 +12,7 @@ class CreatePersonaRequest(BaseModel):
|
||||
task_prompt: str
|
||||
num_chunks: int | None = None
|
||||
apply_llm_relevance_filter: bool | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
|
||||
class PersonaSnapshot(BaseModel):
|
||||
@@ -21,6 +22,7 @@ class PersonaSnapshot(BaseModel):
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
document_sets: list[DocumentSet]
|
||||
llm_model_version_override: str | None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, persona: Persona) -> "PersonaSnapshot":
|
||||
@@ -34,6 +36,7 @@ class PersonaSnapshot(BaseModel):
|
||||
DocumentSet.from_model(document_set_model)
|
||||
for document_set_model in persona.document_sets
|
||||
],
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user