Better Naming for API Keys (#76)

This commit is contained in:
Yuhong Sun
2024-04-22 18:56:58 -07:00
committed by Chris Weaver
parent 9a9b89f073
commit 336c046e5d
3 changed files with 31 additions and 14 deletions

View File

@@ -5,22 +5,20 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole from danswer.auth.schemas import UserRole
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.models import ApiKey from danswer.db.models import ApiKey
from danswer.db.models import User from danswer.db.models import User
from ee.danswer.auth.api_key import ApiKeyDescriptor from ee.danswer.auth.api_key import ApiKeyDescriptor
from ee.danswer.auth.api_key import build_displayable_api_key from ee.danswer.auth.api_key import build_displayable_api_key
from ee.danswer.auth.api_key import generate_api_key from ee.danswer.auth.api_key import generate_api_key
from ee.danswer.auth.api_key import hash_api_key from ee.danswer.auth.api_key import hash_api_key
from ee.danswer.db.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from ee.danswer.server.api_key.models import APIKeyArgs from ee.danswer.server.api_key.models import APIKeyArgs
_DANSWER_API_KEY = "danswer_api_key"
def is_api_key_email_address(email: str) -> bool: def is_api_key_email_address(email: str) -> bool:
return email.endswith( return email.endswith(f"@{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}")
f"@{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}"
) and email.startswith(_DANSWER_API_KEY)
def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]: def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
@@ -46,6 +44,13 @@ def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | N
return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore
def get_api_key_fake_email(
name: str,
unique_id: str,
) -> str:
return f"{DANSWER_API_KEY_PREFIX}{name}@{unique_id}{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}"
def insert_api_key( def insert_api_key(
db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None
) -> ApiKeyDescriptor: ) -> ApiKeyDescriptor:
@@ -53,9 +58,10 @@ def insert_api_key(
api_key = generate_api_key() api_key = generate_api_key()
api_key_user_id = uuid.uuid4() api_key_user_id = uuid.uuid4()
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
api_key_user_row = User( api_key_user_row = User(
id=api_key_user_id, id=api_key_user_id,
email=f"{_DANSWER_API_KEY}__{api_key_user_id}@{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}", email=get_api_key_fake_email(display_name, str(api_key_user_id)),
# a random password for the "user" # a random password for the "user"
hashed_password=std_password_helper.hash(std_password_helper.generate()), hashed_password=std_password_helper.hash(std_password_helper.generate()),
is_active=True, is_active=True,
@@ -91,7 +97,16 @@ def update_api_key(
if existing_api_key is None: if existing_api_key is None:
raise ValueError(f"API key with id {api_key_id} does not exist") raise ValueError(f"API key with id {api_key_id} does not exist")
existing_api_key.name = api_key_args.name new_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
existing_api_key.name = new_name
api_key_user = db_session.scalar(
select(User).where(User.id == existing_api_key.user_id) # type: ignore
)
if api_key_user is None:
raise RuntimeError("API Key does not have associated user.")
api_key_user.email = get_api_key_fake_email(new_name, str(api_key_user.id))
db_session.commit() db_session.commit()
return ApiKeyDescriptor( return ApiKeyDescriptor(

View File

@@ -1 +0,0 @@
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "danswerapikey.ai"

View File

@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
import danswer.db.models as db_models import danswer.db.models as db_models
from danswer.auth.users import current_admin_user from danswer.auth.users import current_admin_user
from danswer.auth.users import get_display_email
from danswer.chat.chat_utils import create_chat_chain from danswer.chat.chat_utils import create_chat_chain
from danswer.configs.constants import MessageType from danswer.configs.constants import MessageType
from danswer.configs.constants import QAFeedbackType from danswer.configs.constants import QAFeedbackType
@@ -76,7 +77,7 @@ class MessageSnapshot(BaseModel):
class ChatSessionSnapshot(BaseModel): class ChatSessionSnapshot(BaseModel):
id: int id: int
user_email: str | None user_email: str
name: str | None name: str | None
messages: list[MessageSnapshot] messages: list[MessageSnapshot]
persona_name: str persona_name: str
@@ -89,7 +90,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
retrieved_documents: list[AbridgedSearchDoc] retrieved_documents: list[AbridgedSearchDoc]
feedback: QAFeedbackType | None feedback: QAFeedbackType | None
persona_name: str persona_name: str
user_email: str | None user_email: str
time_created: datetime time_created: datetime
@classmethod @classmethod
@@ -113,7 +114,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
retrieved_documents=ai_message.documents, retrieved_documents=ai_message.documents,
feedback=ai_message.feedback, feedback=ai_message.feedback,
persona_name=chat_session_snapshot.persona_name, persona_name=chat_session_snapshot.persona_name,
user_email=chat_session_snapshot.user_email, user_email=get_display_email(chat_session_snapshot.user_email),
time_created=user_message.time_created, time_created=user_message.time_created,
) )
for user_message, ai_message in message_pairs for user_message, ai_message in message_pairs
@@ -131,7 +132,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
), ),
"feedback": self.feedback.value if self.feedback else "", "feedback": self.feedback.value if self.feedback else "",
"persona_name": self.persona_name, "persona_name": self.persona_name,
"user_email": self.user_email if self.user_email else "", "user_email": self.user_email,
"time_created": str(self.time_created), "time_created": str(self.time_created),
} }
@@ -181,7 +182,9 @@ def snapshot_from_chat_session(
return ChatSessionSnapshot( return ChatSessionSnapshot(
id=chat_session.id, id=chat_session.id,
user_email=chat_session.user.email if chat_session.user else None, user_email=get_display_email(
chat_session.user.email if chat_session.user else None
),
name=chat_session.description, name=chat_session.description,
messages=[ messages=[
MessageSnapshot.build(message) MessageSnapshot.build(message)