mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 20:08:36 +02:00
Prompting (#3372)
* auto generate start prompts * post rebase clean up * update for clarity
This commit is contained in:
parent
1df6a506ec
commit
2847ab003e
@ -63,6 +63,10 @@ LANGUAGE_CHAT_NAMING_HINT = (
|
||||
or "The name of the conversation must be in the same language as the user query."
|
||||
)
|
||||
|
||||
# Number of prompts each persona should have
|
||||
NUM_PERSONA_PROMPTS = 4
|
||||
NUM_PERSONA_PROMPT_GENERATION_CHUNKS = 5
|
||||
|
||||
# Agentic search takes significantly more tokens and therefore has much higher cost.
|
||||
# This configuration allows users to get a search-only experience with instant results
|
||||
# and no involvement from the LLM.
|
||||
|
@ -5,6 +5,8 @@ from typing import Literal
|
||||
from typing import NotRequired
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
from uuid import UUID
|
||||
|
||||
@ -1344,6 +1346,11 @@ class StarterMessage(TypedDict):
|
||||
message: str
|
||||
|
||||
|
||||
class StarterMessageModel(BaseModel):
|
||||
name: str
|
||||
message: str
|
||||
|
||||
|
||||
class Persona(Base):
|
||||
__tablename__ = "persona"
|
||||
|
||||
|
@ -369,6 +369,19 @@ class AdminCapable(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomCapable(abc.ABC):
|
||||
"""Class must implement random document retrieval capability"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int = 10,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""Retrieve random chunks matching the filters"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseIndex(
|
||||
Verifiable,
|
||||
Indexable,
|
||||
@ -376,6 +389,7 @@ class BaseIndex(
|
||||
Deletable,
|
||||
AdminCapable,
|
||||
IdRetrievalCapable,
|
||||
RandomCapable,
|
||||
abc.ABC,
|
||||
):
|
||||
"""
|
||||
|
@ -218,4 +218,10 @@ schema DANSWER_CHUNK_NAME {
|
||||
expression: bm25(content) + (5 * bm25(title))
|
||||
}
|
||||
}
|
||||
|
||||
rank-profile random_ {
|
||||
first-phase {
|
||||
expression: random.match
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ import concurrent.futures
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import urllib
|
||||
@ -903,6 +904,32 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
logger.info("Batch deletion completed")
|
||||
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int = 10,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""Retrieve random chunks matching the filters using Vespa's random ranking
|
||||
|
||||
This method is currently used for random chunk retrieval in the context of
|
||||
assistant starter message creation (passed as sample context for usage by the assistant).
|
||||
"""
|
||||
vespa_where_clauses = build_vespa_filters(filters, remove_trailing_and=True)
|
||||
|
||||
yql = YQL_BASE.format(index_name=self.index_name) + vespa_where_clauses
|
||||
|
||||
random_seed = random.randint(0, 1000000)
|
||||
|
||||
params: dict[str, str | int | float] = {
|
||||
"yql": yql,
|
||||
"hits": num_to_retrieve,
|
||||
"timeout": VESPA_TIMEOUT,
|
||||
"ranking.profile": "random_",
|
||||
"ranking.properties.random.seed": random_seed,
|
||||
}
|
||||
|
||||
return query_vespa(params)
|
||||
|
||||
|
||||
class _VespaDeleteRequest:
|
||||
def __init__(self, document_id: str, index_name: str) -> None:
|
||||
|
@ -19,7 +19,12 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) -> str:
|
||||
def build_vespa_filters(
|
||||
filters: IndexFilters,
|
||||
*,
|
||||
include_hidden: bool = False,
|
||||
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
|
||||
) -> str:
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
if vals is None:
|
||||
return ""
|
||||
@ -78,6 +83,9 @@ def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->
|
||||
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5] # We remove the trailing " and "
|
||||
|
||||
return filter_str
|
||||
|
||||
|
||||
|
46
backend/onyx/prompts/starter_messages.py
Normal file
46
backend/onyx/prompts/starter_messages.py
Normal file
@ -0,0 +1,46 @@
|
||||
PERSONA_CATEGORY_GENERATION_PROMPT = """
|
||||
Based on the assistant's name, description, and instructions, generate a list of {num_categories}
|
||||
**unique and diverse** categories that represent different types of starter messages a user
|
||||
might send to initiate a conversation with this chatbot assistant.
|
||||
|
||||
**Ensure that the categories are varied and cover a wide range of topics related to the assistant's capabilities.**
|
||||
|
||||
Provide the categories as a JSON array of strings **without any code fences or additional text**.
|
||||
|
||||
**Context about the assistant:**
|
||||
- **Name**: {name}
|
||||
- **Description**: {description}
|
||||
- **Instructions**: {instructions}
|
||||
""".strip()
|
||||
|
||||
PERSONA_STARTER_MESSAGE_CREATION_PROMPT = """
|
||||
Create a starter message that a **user** might send to initiate a conversation with a chatbot assistant.
|
||||
|
||||
**Category**: {category}
|
||||
|
||||
Your response should include two parts:
|
||||
|
||||
1. **Title**: A short, engaging title that reflects the user's intent
|
||||
(e.g., 'Need Travel Advice', 'Question About Coding', 'Looking for Book Recommendations').
|
||||
|
||||
2. **Message**: The actual message that the user would send to the assistant.
|
||||
This should be natural, engaging, and encourage a helpful response from the assistant.
|
||||
**Avoid overly specific details; keep the message general and broadly applicable.**
|
||||
|
||||
For example:
|
||||
- Instead of "I've just adopted a 6-month-old Labrador puppy who's pulling on the leash,"
|
||||
write "I'm having trouble training my new puppy to walk nicely on a leash."
|
||||
|
||||
Ensure each part is clearly labeled and separated as shown above.
|
||||
Do not provide any additional text or explanation and be extremely concise
|
||||
|
||||
**Context about the assistant:**
|
||||
- **Name**: {name}
|
||||
- **Description**: {description}
|
||||
- **Instructions**: {instructions}
|
||||
""".strip()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(PERSONA_CATEGORY_GENERATION_PROMPT)
|
||||
print(PERSONA_STARTER_MESSAGE_CREATION_PROMPT)
|
271
backend/onyx/secondary_llm_flows/starter_message_creation.py
Normal file
271
backend/onyx/secondary_llm_flows/starter_message_creation.py
Normal file
@ -0,0 +1,271 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
from litellm import get_supported_openai_params
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import NUM_PERSONA_PROMPT_GENERATION_CHUNKS
|
||||
from onyx.configs.chat_configs import NUM_PERSONA_PROMPTS
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
|
||||
from onyx.context.search.preprocessing.access_filters import (
|
||||
build_access_filters_for_user,
|
||||
)
|
||||
from onyx.db.document_set import get_document_sets_by_ids
|
||||
from onyx.db.models import StarterMessageModel as StarterMessage
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.prompts.starter_messages import PERSONA_CATEGORY_GENERATION_PROMPT
|
||||
from onyx.prompts.starter_messages import PERSONA_STARTER_MESSAGE_CREATION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_random_chunks_from_doc_sets(
|
||||
doc_sets: List[str], db_session: Session, user: User | None = None
|
||||
) -> List[InferenceChunk]:
|
||||
"""
|
||||
Retrieves random chunks from the specified document sets.
|
||||
"""
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(curr_ind_name, sec_ind_name)
|
||||
|
||||
acl_filters = build_access_filters_for_user(user, db_session)
|
||||
filters = IndexFilters(document_set=doc_sets, access_control_list=acl_filters)
|
||||
|
||||
chunks = document_index.random_retrieval(
|
||||
filters=filters, num_to_retrieve=NUM_PERSONA_PROMPT_GENERATION_CHUNKS
|
||||
)
|
||||
return cleanup_chunks(chunks)
|
||||
|
||||
|
||||
def parse_categories(content: str) -> List[str]:
|
||||
"""
|
||||
Parses the JSON array of categories from the LLM response.
|
||||
"""
|
||||
# Clean the response to remove code fences and extra whitespace
|
||||
content = content.strip().strip("```").strip()
|
||||
if content.startswith("json"):
|
||||
content = content[4:].strip()
|
||||
|
||||
try:
|
||||
categories = json.loads(content)
|
||||
if not isinstance(categories, list):
|
||||
logger.error("Categories are not a list.")
|
||||
return []
|
||||
return categories
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse categories: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def generate_start_message_prompts(
|
||||
name: str,
|
||||
description: str,
|
||||
instructions: str,
|
||||
categories: List[str],
|
||||
chunk_contents: str,
|
||||
supports_structured_output: bool,
|
||||
fast_llm: Any,
|
||||
) -> List[FunctionCall]:
|
||||
"""
|
||||
Generates the list of FunctionCall objects for starter message generation.
|
||||
"""
|
||||
functions = []
|
||||
for category in categories:
|
||||
# Create a prompt specific to the category
|
||||
start_message_generation_prompt = (
|
||||
PERSONA_STARTER_MESSAGE_CREATION_PROMPT.format(
|
||||
name=name,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
category=category,
|
||||
)
|
||||
)
|
||||
|
||||
if chunk_contents:
|
||||
start_message_generation_prompt += (
|
||||
"\n\nExample content this assistant has access to:\n"
|
||||
"'''\n"
|
||||
f"{chunk_contents}"
|
||||
"\n'''"
|
||||
)
|
||||
|
||||
if supports_structured_output:
|
||||
functions.append(
|
||||
FunctionCall(
|
||||
fast_llm.invoke,
|
||||
(start_message_generation_prompt, None, None, StarterMessage),
|
||||
)
|
||||
)
|
||||
else:
|
||||
functions.append(
|
||||
FunctionCall(
|
||||
fast_llm.invoke,
|
||||
(start_message_generation_prompt,),
|
||||
)
|
||||
)
|
||||
return functions
|
||||
|
||||
|
||||
def parse_unstructured_output(output: str) -> Dict[str, str]:
|
||||
"""
|
||||
Parses the assistant's unstructured output into a dictionary with keys:
|
||||
- 'name' (Title)
|
||||
- 'message' (Message)
|
||||
"""
|
||||
|
||||
# Debug output
|
||||
logger.debug(f"LLM Output for starter message creation: {output}")
|
||||
|
||||
# Patterns to match
|
||||
title_pattern = r"(?i)^\**Title\**\s*:\s*(.+)"
|
||||
message_pattern = r"(?i)^\**Message\**\s*:\s*(.+)"
|
||||
|
||||
# Initialize the response dictionary
|
||||
response_dict = {}
|
||||
|
||||
# Split the output into lines
|
||||
lines = output.strip().split("\n")
|
||||
|
||||
# Variables to keep track of the current key being processed
|
||||
current_key = None
|
||||
current_value_lines = []
|
||||
|
||||
for line in lines:
|
||||
# Check for title
|
||||
title_match = re.match(title_pattern, line.strip())
|
||||
if title_match:
|
||||
# Save previous key-value pair if any
|
||||
if current_key and current_value_lines:
|
||||
response_dict[current_key] = " ".join(current_value_lines).strip()
|
||||
current_value_lines = []
|
||||
current_key = "name"
|
||||
current_value_lines.append(title_match.group(1).strip())
|
||||
continue
|
||||
|
||||
# Check for message
|
||||
message_match = re.match(message_pattern, line.strip())
|
||||
if message_match:
|
||||
if current_key and current_value_lines:
|
||||
response_dict[current_key] = " ".join(current_value_lines).strip()
|
||||
current_value_lines = []
|
||||
current_key = "message"
|
||||
current_value_lines.append(message_match.group(1).strip())
|
||||
continue
|
||||
|
||||
# If the line doesn't match a new key, append it to the current value
|
||||
if current_key:
|
||||
current_value_lines.append(line.strip())
|
||||
|
||||
# Add the last key-value pair
|
||||
if current_key and current_value_lines:
|
||||
response_dict[current_key] = " ".join(current_value_lines).strip()
|
||||
|
||||
# Validate that the necessary keys are present
|
||||
if not all(k in response_dict for k in ["name", "message"]):
|
||||
raise ValueError("Failed to parse the assistant's response.")
|
||||
|
||||
return response_dict
|
||||
|
||||
|
||||
def generate_starter_messages(
|
||||
name: str,
|
||||
description: str,
|
||||
instructions: str,
|
||||
document_set_ids: List[int],
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
) -> List[StarterMessage]:
|
||||
"""
|
||||
Generates starter messages by first obtaining categories and then generating messages for each category.
|
||||
On failure, returns an empty list (or list with processed starter messages if some messages are processed successfully).
|
||||
"""
|
||||
_, fast_llm = get_default_llms(temperature=0.5)
|
||||
|
||||
provider = fast_llm.config.model_provider
|
||||
model = fast_llm.config.model_name
|
||||
|
||||
params = get_supported_openai_params(model=model, custom_llm_provider=provider)
|
||||
supports_structured_output = (
|
||||
isinstance(params, list) and "response_format" in params
|
||||
)
|
||||
|
||||
# Generate categories
|
||||
category_generation_prompt = PERSONA_CATEGORY_GENERATION_PROMPT.format(
|
||||
name=name,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
num_categories=NUM_PERSONA_PROMPTS,
|
||||
)
|
||||
|
||||
category_response = fast_llm.invoke(category_generation_prompt)
|
||||
categories = parse_categories(cast(str, category_response.content))
|
||||
|
||||
if not categories:
|
||||
logger.error("No categories were generated.")
|
||||
return []
|
||||
|
||||
# Fetch example content if document sets are provided
|
||||
if document_set_ids:
|
||||
document_sets = get_document_sets_by_ids(
|
||||
document_set_ids=document_set_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
chunks = get_random_chunks_from_doc_sets(
|
||||
doc_sets=[doc_set.name for doc_set in document_sets],
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Add example content context
|
||||
chunk_contents = "\n".join(chunk.content.strip() for chunk in chunks)
|
||||
else:
|
||||
chunk_contents = ""
|
||||
|
||||
# Generate prompts for starter messages
|
||||
functions = generate_start_message_prompts(
|
||||
name,
|
||||
description,
|
||||
instructions,
|
||||
categories,
|
||||
chunk_contents,
|
||||
supports_structured_output,
|
||||
fast_llm,
|
||||
)
|
||||
|
||||
# Run LLM calls in parallel
|
||||
if not functions:
|
||||
logger.error("No functions to execute for starter message generation.")
|
||||
return []
|
||||
|
||||
results = run_functions_in_parallel(function_calls=functions)
|
||||
prompts = []
|
||||
|
||||
for response in results.values():
|
||||
try:
|
||||
if supports_structured_output:
|
||||
response_dict = json.loads(response.content)
|
||||
else:
|
||||
response_dict = parse_unstructured_output(response.content)
|
||||
starter_message = StarterMessage(
|
||||
name=response_dict["name"],
|
||||
message=response_dict["message"],
|
||||
)
|
||||
prompts.append(starter_message)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.error(f"Failed to parse starter message: {e}")
|
||||
continue
|
||||
|
||||
return prompts
|
@ -19,6 +19,7 @@ from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import StarterMessageModel as StarterMessage
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.persona import create_assistant_category
|
||||
@ -36,7 +37,11 @@ from onyx.db.persona import update_persona_shared_users
|
||||
from onyx.db.persona import update_persona_visibility
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.secondary_llm_flows.starter_message_creation import (
|
||||
generate_starter_messages,
|
||||
)
|
||||
from onyx.server.features.persona.models import CreatePersonaRequest
|
||||
from onyx.server.features.persona.models import GenerateStarterMessageRequest
|
||||
from onyx.server.features.persona.models import ImageGenerationToolStatus
|
||||
from onyx.server.features.persona.models import PersonaCategoryCreate
|
||||
from onyx.server.features.persona.models import PersonaCategoryResponse
|
||||
@ -377,3 +382,26 @@ def build_final_template_prompt(
|
||||
retrieval_disabled=retrieval_disabled,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@basic_router.post("/assistant-prompt-refresh")
|
||||
def build_assistant_prompts(
|
||||
generate_persona_prompt_request: GenerateStarterMessageRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(current_user),
|
||||
) -> list[StarterMessage]:
|
||||
try:
|
||||
logger.info(
|
||||
"Generating starter messages for user: %s", user.id if user else "Anonymous"
|
||||
)
|
||||
return generate_starter_messages(
|
||||
name=generate_persona_prompt_request.name,
|
||||
description=generate_persona_prompt_request.description,
|
||||
instructions=generate_persona_prompt_request.instructions,
|
||||
document_set_ids=generate_persona_prompt_request.document_set_ids,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate starter messages")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
@ -17,6 +17,14 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# More minimal request for generating a persona prompt
|
||||
class GenerateStarterMessageRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
instructions: str
|
||||
document_set_ids: list[int]
|
||||
|
||||
|
||||
class CreatePersonaRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
|
@ -75,7 +75,8 @@ export default function Page() {
|
||||
},
|
||||
{} as Record<SourceCategory, SourceMetadata[]>
|
||||
);
|
||||
}, [sources, searchTerm]);
|
||||
}, [sources, filterSources, searchTerm]);
|
||||
|
||||
const handleKeyPress = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === "Enter") {
|
||||
const filteredCategories = Object.entries(categorizedSources).filter(
|
||||
|
@ -1,7 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { Option } from "@/components/Dropdown";
|
||||
import { generateRandomIconShape, createSVG } from "@/lib/assistantIconUtils";
|
||||
|
||||
import { CCPairBasicInfo, DocumentSet, User } from "@/lib/types";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Button } from "@/components/ui/button";
|
||||
@ -9,12 +9,11 @@ import { Textarea } from "@/components/ui/textarea";
|
||||
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
|
||||
import {
|
||||
ArrayHelpers,
|
||||
ErrorMessage,
|
||||
Field,
|
||||
FieldArray,
|
||||
Form,
|
||||
Formik,
|
||||
FormikProps,
|
||||
useFormikContext,
|
||||
} from "formik";
|
||||
|
||||
import {
|
||||
@ -27,7 +26,6 @@ import {
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { getDisplayNameForModel, useCategories } from "@/lib/hooks";
|
||||
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
|
||||
import { Option } from "@/components/Dropdown";
|
||||
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
|
||||
import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils";
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
@ -41,10 +39,9 @@ import {
|
||||
} from "@/components/ui/tooltip";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect, useState } from "react";
|
||||
import { FiInfo, FiX } from "react-icons/fi";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { FiInfo, FiRefreshCcw } from "react-icons/fi";
|
||||
import * as Yup from "yup";
|
||||
import { FullLLMProvider } from "../configuration/llm/interfaces";
|
||||
import CollapsibleSection from "./CollapsibleSection";
|
||||
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
|
||||
import { Persona, PersonaCategory, StarterMessage } from "./interfaces";
|
||||
@ -66,6 +63,9 @@ import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||
import { buildImgUrl } from "@/app/chat/files/images/utils";
|
||||
import { LlmList } from "@/components/llm/LLMList";
|
||||
import { useAssistants } from "@/components/context/AssistantsContext";
|
||||
import { debounce } from "lodash";
|
||||
import { FullLLMProvider } from "../configuration/llm/interfaces";
|
||||
import StarterMessagesList from "./StarterMessageList";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { CategoryCard } from "./CategoryCard";
|
||||
|
||||
@ -129,12 +129,14 @@ export function AssistantEditor({
|
||||
];
|
||||
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
const [hasEditedStarterMessage, setHasEditedStarterMessage] = useState(false);
|
||||
const [showPersonaCategory, setShowPersonaCategory] = useState(!admin);
|
||||
|
||||
// state to persist across formik reformatting
|
||||
const [defautIconColor, _setDeafultIconColor] = useState(
|
||||
colorOptions[Math.floor(Math.random() * colorOptions.length)]
|
||||
);
|
||||
const [isRefreshing, setIsRefreshing] = useState(false);
|
||||
|
||||
const [defaultIconShape, setDefaultIconShape] = useState<any>(null);
|
||||
|
||||
@ -148,6 +150,10 @@ export function AssistantEditor({
|
||||
|
||||
const [removePersonaImage, setRemovePersonaImage] = useState(false);
|
||||
|
||||
const autoStarterMessageEnabled = useMemo(
|
||||
() => llmProviders.length > 0,
|
||||
[llmProviders.length]
|
||||
);
|
||||
const isUpdate = existingPersona !== undefined && existingPersona !== null;
|
||||
const existingPrompt = existingPersona?.prompts[0] ?? null;
|
||||
const defaultProvider = llmProviders.find(
|
||||
@ -217,7 +223,24 @@ export function AssistantEditor({
|
||||
existingPersona?.llm_model_provider_override ?? null,
|
||||
llm_model_version_override:
|
||||
existingPersona?.llm_model_version_override ?? null,
|
||||
starter_messages: existingPersona?.starter_messages ?? [],
|
||||
starter_messages: existingPersona?.starter_messages ?? [
|
||||
{
|
||||
name: "",
|
||||
message: "",
|
||||
},
|
||||
{
|
||||
name: "",
|
||||
message: "",
|
||||
},
|
||||
{
|
||||
name: "",
|
||||
message: "",
|
||||
},
|
||||
{
|
||||
name: "",
|
||||
message: "",
|
||||
},
|
||||
],
|
||||
enabled_tools_map: enabledToolsMap,
|
||||
icon_color: existingPersona?.icon_color ?? defautIconColor,
|
||||
icon_shape: existingPersona?.icon_shape ?? defaultIconShape,
|
||||
@ -228,6 +251,44 @@ export function AssistantEditor({
|
||||
groups: existingPersona?.groups ?? [],
|
||||
};
|
||||
|
||||
interface AssistantPrompt {
|
||||
message: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
const debouncedRefreshPrompts = debounce(
|
||||
async (values: any, setFieldValue: any) => {
|
||||
if (!autoStarterMessageEnabled) {
|
||||
return;
|
||||
}
|
||||
setIsRefreshing(true);
|
||||
try {
|
||||
const response = await fetch("/api/persona/assistant-prompt-refresh", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
name: values.name,
|
||||
description: values.description,
|
||||
document_set_ids: values.document_set_ids,
|
||||
instructions: values.system_prompt || values.task_prompt,
|
||||
}),
|
||||
});
|
||||
|
||||
const data: AssistantPrompt = await response.json();
|
||||
if (response.ok) {
|
||||
setFieldValue("starter_messages", data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to refresh prompts:", error);
|
||||
} finally {
|
||||
setIsRefreshing(false);
|
||||
}
|
||||
},
|
||||
1000
|
||||
);
|
||||
|
||||
const [isRequestSuccessful, setIsRequestSuccessful] = useState(false);
|
||||
|
||||
return (
|
||||
@ -421,6 +482,8 @@ export function AssistantEditor({
|
||||
isSubmitting,
|
||||
values,
|
||||
setFieldValue,
|
||||
errors,
|
||||
|
||||
...formikProps
|
||||
}: FormikProps<any>) => {
|
||||
function toggleToolInValues(toolId: number) {
|
||||
@ -445,6 +508,7 @@ export function AssistantEditor({
|
||||
|
||||
return (
|
||||
<Form className="w-full text-text-950">
|
||||
{/* Refresh starter messages when name or description changes */}
|
||||
<div className="w-full flex gap-x-2 justify-center">
|
||||
<Popover
|
||||
open={isIconDropdownOpen}
|
||||
@ -984,6 +1048,91 @@ export function AssistantEditor({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mb-6 w-full flex flex-col">
|
||||
<div className="flex gap-x-2 items-center">
|
||||
<div className="block font-medium text-base">
|
||||
Starter Messages
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<SubLabel>
|
||||
Pre-configured messages that help users understand what this
|
||||
assistant can do and how to interact with it effectively.
|
||||
</SubLabel>
|
||||
<div className="relative w-fit">
|
||||
<TooltipProvider delayDuration={50}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div>
|
||||
<Button
|
||||
type="button"
|
||||
size="sm"
|
||||
onClick={() =>
|
||||
debouncedRefreshPrompts(values, setFieldValue)
|
||||
}
|
||||
disabled={
|
||||
!autoStarterMessageEnabled ||
|
||||
isRefreshing ||
|
||||
(Object.keys(errors).length > 0 &&
|
||||
Object.keys(errors).some(
|
||||
(key) => !key.startsWith("starter_messages")
|
||||
))
|
||||
}
|
||||
className={`
|
||||
px-3 py-2
|
||||
mr-auto
|
||||
my-2
|
||||
flex gap-x-2
|
||||
text-sm font-medium
|
||||
rounded-lg shadow-sm
|
||||
items-center gap-2
|
||||
transition-colors duration-200
|
||||
${
|
||||
isRefreshing || !autoStarterMessageEnabled
|
||||
? "bg-gray-100 text-gray-400 cursor-not-allowed"
|
||||
: "bg-blue-50 text-blue-600 hover:bg-blue-100 active:bg-blue-200"
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center gap-x-2">
|
||||
{isRefreshing ? (
|
||||
<FiRefreshCcw className="w-4 h-4 animate-spin text-gray-400" />
|
||||
) : (
|
||||
<SwapIcon className="w-4 h-4 text-blue-600" />
|
||||
)}
|
||||
Generate
|
||||
</div>
|
||||
</Button>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{!autoStarterMessageEnabled && (
|
||||
<TooltipContent side="top" align="center">
|
||||
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
|
||||
No LLM providers configured. Generation is not
|
||||
available.
|
||||
</p>
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<FieldArray
|
||||
name="starter_messages"
|
||||
render={(arrayHelpers: ArrayHelpers) => (
|
||||
<StarterMessagesList
|
||||
isRefreshing={isRefreshing}
|
||||
values={values.starter_messages}
|
||||
arrayHelpers={arrayHelpers}
|
||||
touchStarterMessages={() => {
|
||||
setHasEditedStarterMessage(true);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{admin && (
|
||||
<AdvancedOptionsToggle
|
||||
title="Categories"
|
||||
@ -1190,136 +1339,12 @@ export function AssistantEditor({
|
||||
</>
|
||||
)}
|
||||
|
||||
<div className="mb-6 flex flex-col">
|
||||
<div className="flex gap-x-2 items-center">
|
||||
<div className="block font-medium text-base">
|
||||
Starter Messages (Optional){" "}
|
||||
</div>
|
||||
</div>
|
||||
<SubLabel>
|
||||
Add pre-defined messages to help users get started. Only
|
||||
the first 4 will be displayed.
|
||||
</SubLabel>
|
||||
<FieldArray
|
||||
name="starter_messages"
|
||||
render={(
|
||||
arrayHelpers: ArrayHelpers<StarterMessage[]>
|
||||
) => (
|
||||
<div>
|
||||
{values.starter_messages &&
|
||||
values.starter_messages.length > 0 &&
|
||||
values.starter_messages.map(
|
||||
(
|
||||
starterMessage: StarterMessage,
|
||||
index: number
|
||||
) => {
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
className={index === 0 ? "mt-2" : "mt-6"}
|
||||
>
|
||||
<div className="flex">
|
||||
<div className="w-full mr-6 border border-border p-3 rounded">
|
||||
<div>
|
||||
<Label small>Name</Label>
|
||||
<SubLabel>
|
||||
Shows up as the "title"
|
||||
for this Starter Message. For
|
||||
example, "Write an email".
|
||||
</SubLabel>
|
||||
<Field
|
||||
name={`starter_messages[${index}].name`}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
mr-4
|
||||
`}
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`starter_messages[${index}].name`}
|
||||
component="div"
|
||||
className="text-error text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="mt-3">
|
||||
<Label small>Message</Label>
|
||||
<SubLabel>
|
||||
The actual message to be sent as the
|
||||
initial user message if a user
|
||||
selects this starter prompt. For
|
||||
example, "Write me an email to
|
||||
a client about a new billing feature
|
||||
we just released."
|
||||
</SubLabel>
|
||||
<Field
|
||||
name={`starter_messages[${index}].message`}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
min-h-12
|
||||
mr-4
|
||||
line-clamp-
|
||||
`}
|
||||
as="textarea"
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`starter_messages[${index}].message`}
|
||||
component="div"
|
||||
className="text-error text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="my-auto">
|
||||
<FiX
|
||||
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
|
||||
onClick={() =>
|
||||
arrayHelpers.remove(index)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
)}
|
||||
|
||||
<Button
|
||||
onClick={() => {
|
||||
arrayHelpers.push({
|
||||
name: "",
|
||||
description: "",
|
||||
message: "",
|
||||
});
|
||||
}}
|
||||
className="mt-3"
|
||||
size="sm"
|
||||
variant="next"
|
||||
>
|
||||
Add New
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<IsPublicGroupSelector
|
||||
formikProps={{
|
||||
values,
|
||||
isSubmitting,
|
||||
setFieldValue,
|
||||
errors,
|
||||
...formikProps,
|
||||
}}
|
||||
objectName="assistant"
|
||||
|
198
web/src/app/admin/assistants/StarterMessageList.tsx
Normal file
198
web/src/app/admin/assistants/StarterMessageList.tsx
Normal file
@ -0,0 +1,198 @@
|
||||
"use client";
|
||||
|
||||
import { ArrayHelpers, ErrorMessage, Field, useFormikContext } from "formik";
|
||||
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@radix-ui/react-tooltip";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { FiInfo, FiTrash2, FiPlus } from "react-icons/fi";
|
||||
import { StarterMessage } from "./interfaces";
|
||||
import { Label } from "@/components/admin/connectors/Field";
|
||||
|
||||
export default function StarterMessagesList({
|
||||
values,
|
||||
arrayHelpers,
|
||||
isRefreshing,
|
||||
touchStarterMessages,
|
||||
}: {
|
||||
values: StarterMessage[];
|
||||
arrayHelpers: ArrayHelpers;
|
||||
isRefreshing: boolean;
|
||||
touchStarterMessages: () => void;
|
||||
}) {
|
||||
const { handleChange } = useFormikContext();
|
||||
|
||||
// Group starter messages into rows of 2 for display purposes
|
||||
const rows = values.reduce((acc: StarterMessage[][], curr, i) => {
|
||||
if (i % 2 === 0) acc.push([curr]);
|
||||
else acc[acc.length - 1].push(curr);
|
||||
return acc;
|
||||
}, []);
|
||||
|
||||
const canAddMore = values.length <= 6;
|
||||
|
||||
return (
|
||||
<div className="mt-4 flex flex-col gap-6">
|
||||
{rows.map((row, rowIndex) => (
|
||||
<div key={rowIndex} className="flex items-start gap-4">
|
||||
<div className="grid grid-cols-2 gap-6 w-full xl:w-fit">
|
||||
{row.map((starterMessage, colIndex) => (
|
||||
<div
|
||||
key={rowIndex * 2 + colIndex}
|
||||
className="bg-white max-w-full w-full xl:w-[500px] border border-border rounded-lg shadow-md transition-shadow duration-200 p-6"
|
||||
>
|
||||
<div className="space-y-5">
|
||||
{isRefreshing ? (
|
||||
<div className="w-full">
|
||||
<div className="w-full">
|
||||
<div className="h-4 w-24 bg-gray-200 rounded animate-pulse mb-2" />
|
||||
<div className="h-10 w-full bg-gray-200 rounded animate-pulse" />
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div className="h-4 w-24 bg-gray-200 rounded animate-pulse mb-2" />
|
||||
<div className="h-10 w-full bg-gray-200 rounded animate-pulse" />
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div className="h-4 w-24 bg-gray-200 rounded animate-pulse mb-2" />
|
||||
<div className="h-24 w-full bg-gray-200 rounded animate-pulse" />
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div>
|
||||
<div className="flex w-full items-center gap-x-1">
|
||||
<Label
|
||||
small
|
||||
className="text-sm font-medium text-gray-700"
|
||||
>
|
||||
Name
|
||||
</Label>
|
||||
<TooltipProvider delayDuration={50}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger>
|
||||
<FiInfo size={12} />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="top" align="center">
|
||||
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
|
||||
Shows up as the "title" for this
|
||||
Starter Message. For example, "Write an
|
||||
email."
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
<Field
|
||||
name={`starter_messages.${
|
||||
rowIndex * 2 + colIndex
|
||||
}.name`}
|
||||
className="mt-1 w-full px-4 py-2.5 bg-background border border-border rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent transition"
|
||||
autoComplete="off"
|
||||
placeholder="Enter a name..."
|
||||
onChange={(e: any) => {
|
||||
touchStarterMessages();
|
||||
handleChange(e);
|
||||
}}
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`starter_messages.${
|
||||
rowIndex * 2 + colIndex
|
||||
}.name`}
|
||||
component="div"
|
||||
className="text-red-500 text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div className="flex w-full items-center gap-x-1">
|
||||
<Label
|
||||
small
|
||||
className="text-sm font-medium text-gray-700"
|
||||
>
|
||||
Message
|
||||
</Label>
|
||||
<TooltipProvider delayDuration={50}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger>
|
||||
<FiInfo size={12} />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="top" align="center">
|
||||
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
|
||||
The actual message to be sent as the initial
|
||||
user message.
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
<Field
|
||||
name={`starter_messages.${
|
||||
rowIndex * 2 + colIndex
|
||||
}.message`}
|
||||
className="mt-1 text-sm w-full px-4 py-2.5 bg-background border border-border rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent transition min-h-[100px] resize-y"
|
||||
as="textarea"
|
||||
autoComplete="off"
|
||||
placeholder="Enter the message..."
|
||||
onChange={(e: any) => {
|
||||
touchStarterMessages();
|
||||
handleChange(e);
|
||||
}}
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`starter_messages.${
|
||||
rowIndex * 2 + colIndex
|
||||
}.message`}
|
||||
component="div"
|
||||
className="text-red-500 text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
arrayHelpers.remove(rowIndex * 2 + 1);
|
||||
arrayHelpers.remove(rowIndex * 2);
|
||||
}}
|
||||
className="p-1.5 bg-white border border-gray-200 rounded-full text-gray-400 hover:text-red-500 hover:border-red-200 transition-colors mt-2"
|
||||
aria-label="Delete row"
|
||||
>
|
||||
<FiTrash2 size={14} />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{canAddMore && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
arrayHelpers.push({
|
||||
name: "",
|
||||
message: "",
|
||||
});
|
||||
arrayHelpers.push({
|
||||
name: "",
|
||||
message: "",
|
||||
});
|
||||
}}
|
||||
className="self-start flex items-center gap-2 px-4 py-2 bg-white border border-gray-200 rounded-lg text-gray-600 hover:bg-gray-50 hover:border-gray-300 transition-colors"
|
||||
>
|
||||
<FiPlus size={16} />
|
||||
<span>Add Row</span>
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user