UX Refresh (#3687)

* add new ux

* quick nit

* additional nit

* finalize

* quick fix

* fix typing
This commit is contained in:
pablonyx
2025-01-16 00:08:01 -08:00
committed by GitHub
parent a05addec19
commit 44b70a87df
163 changed files with 9741 additions and 4396 deletions

View File

@@ -1,15 +1,12 @@
import json
import re
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from litellm import get_supported_openai_params
from sqlalchemy.orm import Session
from onyx.configs.chat_configs import NUM_PERSONA_PROMPT_GENERATION_CHUNKS
from onyx.configs.chat_configs import NUM_PERSONA_PROMPTS
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
@@ -22,8 +19,8 @@ from onyx.db.models import User
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.llm.factory import get_default_llms
from onyx.prompts.starter_messages import format_persona_starter_message_prompt
from onyx.prompts.starter_messages import PERSONA_CATEGORY_GENERATION_PROMPT
from onyx.prompts.starter_messages import PERSONA_STARTER_MESSAGE_CREATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
@@ -49,7 +46,7 @@ def get_random_chunks_from_doc_sets(
return cleanup_chunks(chunks)
def parse_categories(content: str) -> List[str]:
def parse_categories(content: str) -> List[str | None]:
"""
Parses the JSON array of categories from the LLM response.
"""
@@ -73,7 +70,7 @@ def generate_start_message_prompts(
name: str,
description: str,
instructions: str,
categories: List[str],
categories: List[str | None],
chunk_contents: str,
supports_structured_output: bool,
fast_llm: Any,
@@ -84,13 +81,11 @@ def generate_start_message_prompts(
functions = []
for category in categories:
# Create a prompt specific to the category
start_message_generation_prompt = (
PERSONA_STARTER_MESSAGE_CREATION_PROMPT.format(
name=name,
description=description,
instructions=instructions,
category=category,
)
start_message_generation_prompt = format_persona_starter_message_prompt(
name=name,
description=description,
instructions=instructions,
category=category,
)
if chunk_contents:
@@ -101,89 +96,21 @@ def generate_start_message_prompts(
"\n'''"
)
if supports_structured_output:
functions.append(
FunctionCall(
fast_llm.invoke,
(start_message_generation_prompt, None, None, StarterMessage),
)
)
else:
functions.append(
FunctionCall(
fast_llm.invoke,
(start_message_generation_prompt,),
)
functions.append(
FunctionCall(
fast_llm.invoke,
(start_message_generation_prompt,),
)
)
return functions
def parse_unstructured_output(output: str) -> Dict[str, str]:
    """Parse the assistant's free-form reply into a dict with keys:
    - 'name' (from the ``Title:`` marker)
    - 'message' (from the ``Message:`` marker)

    Scans the text line by line for case-insensitive ``Title:`` /
    ``Message:`` markers (optionally wrapped in asterisks, e.g.
    ``**Title**:``); subsequent lines are folded into the most recent
    field's value, joined with single spaces.

    Raises:
        ValueError: if either field is missing from the output.
    """
    # Debug output
    logger.debug(f"LLM Output for starter message creation: {output}")

    # Markers may be bolded with asterisks by the LLM.
    title_pattern = r"(?i)^\**Title\**\s*:\s*(.+)"
    message_pattern = r"(?i)^\**Message\**\s*:\s*(.+)"

    parsed: Dict[str, str] = {}
    # Field currently being collected and its buffered value lines.
    active_field = None
    buffer: List[str] = []

    def _flush() -> None:
        # Commit any buffered lines to the field currently being collected.
        if active_field and buffer:
            parsed[active_field] = " ".join(buffer).strip()

    for raw_line in output.strip().split("\n"):
        stripped = raw_line.strip()

        title_match = re.match(title_pattern, stripped)
        if title_match:
            _flush()
            buffer = [title_match.group(1).strip()]
            active_field = "name"
            continue

        message_match = re.match(message_pattern, stripped)
        if message_match:
            _flush()
            buffer = [message_match.group(1).strip()]
            active_field = "message"
            continue

        # Continuation line for whichever field is being collected;
        # lines before the first marker are ignored.
        if active_field:
            buffer.append(stripped)

    # Commit the final field.
    _flush()

    # Validate that both required fields were found.
    if not all(k in parsed for k in ["name", "message"]):
        raise ValueError("Failed to parse the assistant's response.")
    return parsed
def generate_starter_messages(
name: str,
description: str,
instructions: str,
document_set_ids: List[int],
generation_count: int,
db_session: Session,
user: User | None,
) -> List[StarterMessage]:
@@ -201,20 +128,26 @@ def generate_starter_messages(
isinstance(params, list) and "response_format" in params
)
# Generate categories
category_generation_prompt = PERSONA_CATEGORY_GENERATION_PROMPT.format(
name=name,
description=description,
instructions=instructions,
num_categories=NUM_PERSONA_PROMPTS,
)
categories: list[str | None] = []
category_response = fast_llm.invoke(category_generation_prompt)
categories = parse_categories(cast(str, category_response.content))
if generation_count > 1:
# Generate categories
category_generation_prompt = PERSONA_CATEGORY_GENERATION_PROMPT.format(
name=name,
description=description,
instructions=instructions,
num_categories=generation_count,
)
if not categories:
logger.error("No categories were generated.")
return []
category_response = fast_llm.invoke(category_generation_prompt)
categories = parse_categories(cast(str, category_response.content))
if not categories:
logger.error("No categories were generated.")
return []
else:
categories = [None]
# Fetch example content if document sets are provided
if document_set_ids:
@@ -254,18 +187,9 @@ def generate_starter_messages(
prompts = []
for response in results.values():
try:
if supports_structured_output:
response_dict = json.loads(response.content)
else:
response_dict = parse_unstructured_output(response.content)
starter_message = StarterMessage(
name=response_dict["name"],
message=response_dict["message"],
)
prompts.append(starter_message)
except (json.JSONDecodeError, ValueError) as e:
logger.error(f"Failed to parse starter message: {e}")
continue
starter_message = StarterMessage(
message=response.content, name=response.content
)
prompts.append(starter_message)
return prompts