Better Naming for API Keys (#76)

This commit is contained in:
Yuhong Sun
2024-04-22 18:56:58 -07:00
committed by Chris Weaver
parent 9a9b89f073
commit 336c046e5d
3 changed files with 31 additions and 14 deletions

View File

@@ -5,22 +5,20 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.models import ApiKey
from danswer.db.models import User
from ee.danswer.auth.api_key import ApiKeyDescriptor
from ee.danswer.auth.api_key import build_displayable_api_key
from ee.danswer.auth.api_key import generate_api_key
from ee.danswer.auth.api_key import hash_api_key
from ee.danswer.db.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from ee.danswer.server.api_key.models import APIKeyArgs
_DANSWER_API_KEY = "danswer_api_key"
def is_api_key_email_address(email: str) -> bool:
return email.endswith(
f"@{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}"
) and email.startswith(_DANSWER_API_KEY)
return email.endswith(f"@{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}")
def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
@@ -46,6 +44,13 @@ def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | N
return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore
def get_api_key_fake_email(
name: str,
unique_id: str,
) -> str:
return f"{DANSWER_API_KEY_PREFIX}{name}@{unique_id}{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}"
def insert_api_key(
db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None
) -> ApiKeyDescriptor:
@@ -53,9 +58,10 @@ def insert_api_key(
api_key = generate_api_key()
api_key_user_id = uuid.uuid4()
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
api_key_user_row = User(
id=api_key_user_id,
email=f"{_DANSWER_API_KEY}__{api_key_user_id}@{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}",
email=get_api_key_fake_email(display_name, str(api_key_user_id)),
# a random password for the "user"
hashed_password=std_password_helper.hash(std_password_helper.generate()),
is_active=True,
@@ -91,7 +97,16 @@ def update_api_key(
if existing_api_key is None:
raise ValueError(f"API key with id {api_key_id} does not exist")
existing_api_key.name = api_key_args.name
new_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
existing_api_key.name = new_name
api_key_user = db_session.scalar(
select(User).where(User.id == existing_api_key.user_id) # type: ignore
)
if api_key_user is None:
raise RuntimeError("API Key does not have associated user.")
api_key_user.email = get_api_key_fake_email(new_name, str(api_key_user.id))
db_session.commit()
return ApiKeyDescriptor(

View File

@@ -1 +0,0 @@
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "danswerapikey.ai"

View File

@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
import danswer.db.models as db_models
from danswer.auth.users import current_admin_user
from danswer.auth.users import get_display_email
from danswer.chat.chat_utils import create_chat_chain
from danswer.configs.constants import MessageType
from danswer.configs.constants import QAFeedbackType
@@ -76,7 +77,7 @@ class MessageSnapshot(BaseModel):
class ChatSessionSnapshot(BaseModel):
id: int
user_email: str | None
user_email: str
name: str | None
messages: list[MessageSnapshot]
persona_name: str
@@ -89,7 +90,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
retrieved_documents: list[AbridgedSearchDoc]
feedback: QAFeedbackType | None
persona_name: str
user_email: str | None
user_email: str
time_created: datetime
@classmethod
@@ -113,7 +114,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
retrieved_documents=ai_message.documents,
feedback=ai_message.feedback,
persona_name=chat_session_snapshot.persona_name,
user_email=chat_session_snapshot.user_email,
user_email=get_display_email(chat_session_snapshot.user_email),
time_created=user_message.time_created,
)
for user_message, ai_message in message_pairs
@@ -131,7 +132,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
),
"feedback": self.feedback.value if self.feedback else "",
"persona_name": self.persona_name,
"user_email": self.user_email if self.user_email else "",
"user_email": self.user_email,
"time_created": str(self.time_created),
}
@@ -181,7 +182,9 @@ def snapshot_from_chat_session(
return ChatSessionSnapshot(
id=chat_session.id,
user_email=chat_session.user.email if chat_session.user else None,
user_email=get_display_email(
chat_session.user.email if chat_session.user else None
),
name=chat_session.description,
messages=[
MessageSnapshot.build(message)