mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-13 13:20:15 +02:00
* Added ability to use a tag to insert the current datetime in prompts * made tagging logic more robust * rename * k --------- Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
336 lines
13 KiB
Python
336 lines
13 KiB
Python
from uuid import UUID
|
|
from uuid import uuid4
|
|
|
|
import requests
|
|
|
|
from onyx.context.search.enums import RecencyBiasSetting
|
|
from onyx.server.features.persona.models import PersonaSnapshot
|
|
from onyx.server.features.persona.models import PersonaUpsertRequest
|
|
from tests.integration.common_utils.constants import API_SERVER_URL
|
|
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
|
from tests.integration.common_utils.test_models import DATestPersona
|
|
from tests.integration.common_utils.test_models import DATestPersonaLabel
|
|
from tests.integration.common_utils.test_models import DATestUser
|
|
|
|
|
|
class PersonaManager:
    """Static helpers that exercise the persona HTTP API in integration tests.

    Each method sends a single request to ``API_SERVER_URL`` with ``requests``.
    The request headers are ``user_performing_action.headers`` when a user is
    given and ``GENERAL_HEADERS`` otherwise.
    """

    @staticmethod
    def create(
        name: str | None = None,
        description: str | None = None,
        system_prompt: str | None = None,
        task_prompt: str | None = None,
        include_citations: bool = False,
        num_chunks: float = 5,
        llm_relevance_filter: bool = True,
        is_public: bool = True,
        llm_filter_extraction: bool = True,
        recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO,
        datetime_aware: bool = False,
        prompt_ids: list[int] | None = None,
        document_set_ids: list[int] | None = None,
        tool_ids: list[int] | None = None,
        llm_model_provider_override: str | None = None,
        llm_model_version_override: str | None = None,
        users: list[str] | None = None,
        groups: list[int] | None = None,
        label_ids: list[int] | None = None,
        user_performing_action: DATestUser | None = None,
    ) -> DATestPersona:
        """Create a persona with ``POST /persona`` and describe it locally.

        Missing text fields are filled in: ``name`` becomes
        ``test-persona-<uuid4>``, and ``description``, ``system_prompt`` and
        ``task_prompt`` become strings derived from that name. Each entry of
        ``users`` is parsed with ``UUID(...)`` before being sent.

        Returns:
            A ``DATestPersona`` whose ``id`` is read from the response body.
            All other fields echo the arguments passed in, not the response.

        Raises:
            requests.HTTPError: if the server answers with an error status
                (raised by ``response.raise_for_status()``).
            ValueError: if an entry of ``users`` is not a valid UUID string.
        """
        name = name or f"test-persona-{uuid4()}"
        description = description or f"Description for {name}"
        system_prompt = system_prompt or f"System prompt for {name}"
        task_prompt = task_prompt or f"Task prompt for {name}"

        persona_creation_request = PersonaUpsertRequest(
            name=name,
            description=description,
            system_prompt=system_prompt,
            task_prompt=task_prompt,
            datetime_aware=datetime_aware,
            include_citations=include_citations,
            num_chunks=num_chunks,
            llm_relevance_filter=llm_relevance_filter,
            is_public=is_public,
            llm_filter_extraction=llm_filter_extraction,
            recency_bias=recency_bias,
            # NOTE(review): the request falls back to prompt id 0 while the
            # returned DATestPersona below records an empty list - confirm
            # this asymmetry is intended, since verify() compares prompt ids.
            prompt_ids=prompt_ids or [0],
            document_set_ids=document_set_ids or [],
            tool_ids=tool_ids or [],
            llm_model_provider_override=llm_model_provider_override,
            llm_model_version_override=llm_model_version_override,
            users=[UUID(user) for user in (users or [])],
            groups=groups or [],
            label_ids=label_ids or [],
        )

        response = requests.post(
            f"{API_SERVER_URL}/persona",
            json=persona_creation_request.model_dump(),
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        persona_data = response.json()

        return DATestPersona(
            id=persona_data["id"],
            name=name,
            description=description,
            num_chunks=num_chunks,
            llm_relevance_filter=llm_relevance_filter,
            is_public=is_public,
            llm_filter_extraction=llm_filter_extraction,
            recency_bias=recency_bias,
            prompt_ids=prompt_ids or [],
            document_set_ids=document_set_ids or [],
            tool_ids=tool_ids or [],
            llm_model_provider_override=llm_model_provider_override,
            llm_model_version_override=llm_model_version_override,
            # NOTE(review): users are stored here as the UUID strings passed
            # in, whereas edit() stores emails and verify() compares against
            # fetched emails - confirm which representation is expected.
            users=users or [],
            groups=groups or [],
            label_ids=label_ids or [],
        )

    @staticmethod
    def edit(
        persona: DATestPersona,
        name: str | None = None,
        description: str | None = None,
        system_prompt: str | None = None,
        task_prompt: str | None = None,
        include_citations: bool = False,
        num_chunks: float | None = None,
        llm_relevance_filter: bool | None = None,
        is_public: bool | None = None,
        llm_filter_extraction: bool | None = None,
        recency_bias: RecencyBiasSetting | None = None,
        datetime_aware: bool = False,
        prompt_ids: list[int] | None = None,
        document_set_ids: list[int] | None = None,
        tool_ids: list[int] | None = None,
        llm_model_provider_override: str | None = None,
        llm_model_version_override: str | None = None,
        users: list[str] | None = None,
        groups: list[int] | None = None,
        label_ids: list[int] | None = None,
        user_performing_action: DATestUser | None = None,
    ) -> DATestPersona:
        """Update ``persona`` with ``PATCH /persona/{id}``.

        Parameters left as ``None`` fall back to the value stored on
        ``persona``. For the boolean, numeric, enum and list parameters an
        explicit falsy value (``False``, ``0``, ``[]``) is honoured rather
        than being replaced by the persona's current value. ``name``,
        ``description`` and the two LLM override strings still fall back when
        empty. ``system_prompt`` and ``task_prompt`` default to strings built
        from ``persona.name`` (the current name, not the new one), and
        ``include_citations`` / ``datetime_aware`` are sent exactly as given.

        Returns:
            A ``DATestPersona`` built from the response body, except for
            ``recency_bias`` which is taken from the argument or ``persona``.

        Raises:
            requests.HTTPError: if the server answers with an error status
                (raised by ``response.raise_for_status()``).
            ValueError: if an entry of the effective user list is not a valid
                UUID string.
        """
        system_prompt = system_prompt or f"System prompt for {persona.name}"
        task_prompt = task_prompt or f"Task prompt for {persona.name}"

        # Resolve once so the request and the returned object agree.
        effective_recency_bias = (
            recency_bias if recency_bias is not None else persona.recency_bias
        )

        persona_update_request = PersonaUpsertRequest(
            name=name or persona.name,
            description=description or persona.description,
            system_prompt=system_prompt,
            task_prompt=task_prompt,
            datetime_aware=datetime_aware,
            include_citations=include_citations,
            # `is not None` (instead of `or`) so that explicit False / 0 / []
            # overrides are not silently replaced by the stored value.
            num_chunks=(
                num_chunks if num_chunks is not None else persona.num_chunks
            ),
            llm_relevance_filter=(
                llm_relevance_filter
                if llm_relevance_filter is not None
                else persona.llm_relevance_filter
            ),
            is_public=is_public if is_public is not None else persona.is_public,
            llm_filter_extraction=(
                llm_filter_extraction
                if llm_filter_extraction is not None
                else persona.llm_filter_extraction
            ),
            recency_bias=effective_recency_bias,
            prompt_ids=(
                prompt_ids if prompt_ids is not None else persona.prompt_ids
            ),
            document_set_ids=(
                document_set_ids
                if document_set_ids is not None
                else persona.document_set_ids
            ),
            tool_ids=tool_ids if tool_ids is not None else persona.tool_ids,
            llm_model_provider_override=llm_model_provider_override
            or persona.llm_model_provider_override,
            llm_model_version_override=llm_model_version_override
            or persona.llm_model_version_override,
            # NOTE(review): the DATestPersona returned below stores emails in
            # `users`; feeding such a persona back in here without `users`
            # would pass emails to UUID(...) - confirm the intended format.
            users=[
                UUID(user)
                for user in (users if users is not None else persona.users)
            ],
            groups=groups if groups is not None else persona.groups,
            label_ids=label_ids if label_ids is not None else persona.label_ids,
        )

        response = requests.patch(
            f"{API_SERVER_URL}/persona/{persona.id}",
            json=persona_update_request.model_dump(),
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        updated_persona_data = response.json()

        return DATestPersona(
            id=updated_persona_data["id"],
            name=updated_persona_data["name"],
            description=updated_persona_data["description"],
            num_chunks=updated_persona_data["num_chunks"],
            llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
            is_public=updated_persona_data["is_public"],
            llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
            recency_bias=effective_recency_bias,
            prompt_ids=[prompt["id"] for prompt in updated_persona_data["prompts"]],
            document_set_ids=updated_persona_data["document_sets"],
            tool_ids=updated_persona_data["tools"],
            llm_model_provider_override=updated_persona_data[
                "llm_model_provider_override"
            ],
            llm_model_version_override=updated_persona_data[
                "llm_model_version_override"
            ],
            users=[user["email"] for user in updated_persona_data["users"]],
            groups=updated_persona_data["groups"],
            label_ids=updated_persona_data["labels"],
        )

    @staticmethod
    def get_all(
        user_performing_action: DATestUser | None = None,
    ) -> list[PersonaSnapshot]:
        """Fetch ``GET /admin/persona`` and parse each item as a snapshot.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        response = requests.get(
            f"{API_SERVER_URL}/admin/persona",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        return [PersonaSnapshot(**persona) for persona in response.json()]

    @staticmethod
    def get_one(
        persona_id: int,
        user_performing_action: DATestUser | None = None,
    ) -> list[PersonaSnapshot]:
        """Fetch ``GET /persona/{persona_id}``.

        Returns:
            A single-element list wrapping the parsed ``PersonaSnapshot``.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        response = requests.get(
            f"{API_SERVER_URL}/persona/{persona_id}",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        return [PersonaSnapshot(**response.json())]

    @staticmethod
    def verify(
        persona: DATestPersona,
        user_performing_action: DATestUser | None = None,
    ) -> bool:
        """Check that the server's copy of ``persona`` matches the local one.

        The persona is fetched through ``get_one`` and compared field by
        field; id collections are compared as sets, so order is ignored.

        Returns:
            ``True`` when every compared field matches, ``False`` when any
            differs or no fetched persona carries ``persona.id``.

        Raises:
            requests.HTTPError: propagated from ``get_one`` on an error status.
        """
        fetched_personas = PersonaManager.get_one(
            persona_id=persona.id,
            user_performing_action=user_performing_action,
        )
        for fetched_persona in fetched_personas:
            if fetched_persona.id != persona.id:
                continue
            return (
                fetched_persona.name == persona.name
                and fetched_persona.description == persona.description
                and fetched_persona.num_chunks == persona.num_chunks
                and fetched_persona.llm_relevance_filter
                == persona.llm_relevance_filter
                and fetched_persona.is_public == persona.is_public
                and fetched_persona.llm_filter_extraction
                == persona.llm_filter_extraction
                and fetched_persona.llm_model_provider_override
                == persona.llm_model_provider_override
                and fetched_persona.llm_model_version_override
                == persona.llm_model_version_override
                and {prompt.id for prompt in fetched_persona.prompts}
                == set(persona.prompt_ids)
                and {
                    document_set.id
                    for document_set in fetched_persona.document_sets
                }
                == set(persona.document_set_ids)
                and {tool.id for tool in fetched_persona.tools}
                == set(persona.tool_ids)
                and {user.email for user in fetched_persona.users}
                == set(persona.users)
                and set(fetched_persona.groups) == set(persona.groups)
                and set(fetched_persona.labels) == set(persona.label_ids)
            )
        return False

    @staticmethod
    def delete(
        persona: DATestPersona,
        user_performing_action: DATestUser | None = None,
    ) -> bool:
        """Send ``DELETE /persona/{id}`` and report success.

        Unlike the other methods this does not raise on an error status; it
        returns ``response.ok`` so callers can assert on the outcome.
        """
        response = requests.delete(
            f"{API_SERVER_URL}/persona/{persona.id}",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        return response.ok
|
|
|
|
|
|
class PersonaLabelManager:
    """Static helpers that exercise the persona-label HTTP API in tests.

    Requests carry ``user_performing_action.headers`` when a user is given
    and ``GENERAL_HEADERS`` otherwise.
    """

    @staticmethod
    def create(
        label: DATestPersonaLabel,
        user_performing_action: DATestUser | None = None,
    ) -> DATestPersonaLabel:
        """Create ``label`` with ``POST /persona/labels``.

        The ``id`` from the response body is written back onto ``label``,
        and that same object is returned.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        request_headers = GENERAL_HEADERS
        if user_performing_action:
            request_headers = user_performing_action.headers

        payload = {"name": label.name}
        response = requests.post(
            f"{API_SERVER_URL}/persona/labels",
            json=payload,
            headers=request_headers,
        )
        response.raise_for_status()

        created = response.json()
        label.id = created["id"]
        return label

    @staticmethod
    def get_all(
        user_performing_action: DATestUser | None = None,
    ) -> list[DATestPersonaLabel]:
        """Fetch ``GET /persona/labels`` and parse every returned item.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        request_headers = GENERAL_HEADERS
        if user_performing_action:
            request_headers = user_performing_action.headers

        response = requests.get(
            f"{API_SERVER_URL}/persona/labels",
            headers=request_headers,
        )
        response.raise_for_status()

        raw_labels = response.json()
        return [DATestPersonaLabel(**raw_label) for raw_label in raw_labels]

    @staticmethod
    def update(
        label: DATestPersonaLabel,
        user_performing_action: DATestUser | None = None,
    ) -> DATestPersonaLabel:
        """Send ``label.name`` via ``PATCH /admin/persona/label/{id}``.

        Returns the ``label`` argument unchanged.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        request_headers = GENERAL_HEADERS
        if user_performing_action:
            request_headers = user_performing_action.headers

        payload = {"label_name": label.name}
        response = requests.patch(
            f"{API_SERVER_URL}/admin/persona/label/{label.id}",
            json=payload,
            headers=request_headers,
        )
        response.raise_for_status()
        return label

    @staticmethod
    def delete(
        label: DATestPersonaLabel,
        user_performing_action: DATestUser | None = None,
    ) -> bool:
        """Send ``DELETE /admin/persona/label/{id}`` and return ``response.ok``.

        No exception is raised for an error status.
        """
        request_headers = GENERAL_HEADERS
        if user_performing_action:
            request_headers = user_performing_action.headers

        response = requests.delete(
            f"{API_SERVER_URL}/admin/persona/label/{label.id}",
            headers=request_headers,
        )
        return response.ok

    @staticmethod
    def verify(
        label: DATestPersonaLabel,
        user_performing_action: DATestUser | None = None,
    ) -> bool:
        """Check that the server has a label with this id and the same name.

        Returns:
            ``True`` when the first fetched label with ``label.id`` has the
            same ``name``; ``False`` when the name differs or no fetched
            label has that id.

        Raises:
            requests.HTTPError: propagated from ``get_all`` on an error status.
        """
        candidates = PersonaLabelManager.get_all(user_performing_action)
        # Only the first label with a matching id is considered.
        matching_label = next(
            (candidate for candidate in candidates if candidate.id == label.id),
            None,
        )
        if matching_label is None:
            return False
        return matching_label.name == label.name
|