feat(sqlalchemy): remove session reference from router

This commit is contained in:
Jonathan Rohde 2024-06-21 14:58:57 +02:00
parent df09d0830a
commit bee835cb65
34 changed files with 1231 additions and 1211 deletions

View File

@ -31,7 +31,6 @@ from typing import Optional, List, Union
from starlette.background import BackgroundTask from starlette.background import BackgroundTask
from apps.webui.internal.db import get_db
from apps.webui.models.models import Models from apps.webui.models.models import Models
from apps.webui.models.users import Users from apps.webui.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -712,7 +711,6 @@ async def generate_chat_completion(
form_data: GenerateChatCompletionForm, form_data: GenerateChatCompletionForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db=Depends(get_db),
): ):
log.debug( log.debug(
@ -726,7 +724,7 @@ async def generate_chat_completion(
} }
model_id = form_data.model model_id = form_data.model
model_info = Models.get_model_by_id(db, model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
@ -885,7 +883,6 @@ async def generate_openai_chat_completion(
form_data: dict, form_data: dict,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db=Depends(get_db),
): ):
form_data = OpenAIChatCompletionForm(**form_data) form_data = OpenAIChatCompletionForm(**form_data)
@ -894,7 +891,7 @@ async def generate_openai_chat_completion(
} }
model_id = form_data.model model_id = form_data.model
model_info = Models.get_model_by_id(db, model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:

View File

@ -11,7 +11,6 @@ import logging
from pydantic import BaseModel from pydantic import BaseModel
from starlette.background import BackgroundTask from starlette.background import BackgroundTask
from apps.webui.internal.db import get_db
from apps.webui.models.models import Models from apps.webui.models.models import Models
from apps.webui.models.users import Users from apps.webui.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -354,13 +353,12 @@ async def generate_chat_completion(
form_data: dict, form_data: dict,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db=Depends(get_db),
): ):
idx = 0 idx = 0
payload = {**form_data} payload = {**form_data}
model_id = form_data.get("model") model_id = form_data.get("model")
model_info = Models.get_model_by_id(db, model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:

View File

@ -1,6 +1,7 @@
import os import os
import logging import logging
import json import json
from contextlib import contextmanager
from typing import Optional, Any from typing import Optional, Any
from typing_extensions import Self from typing_extensions import Self
@ -52,11 +53,12 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
) )
else: else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
Base = declarative_base() Base = declarative_base()
def get_db(): @contextmanager
def get_session():
db = SessionLocal() db = SessionLocal()
try: try:
yield db yield db
@ -64,5 +66,4 @@ def get_db():
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise e raise e
finally:
db.close()

View File

@ -114,8 +114,8 @@ async def get_status():
} }
async def get_pipe_models(db: Session): async def get_pipe_models():
pipes = Functions.get_functions_by_type(db, "pipe", active_only=True) pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = [] pipe_models = []
for pipe in pipes: for pipe in pipes:

View File

@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from apps.webui.models.users import UserModel, Users from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password from utils.utils import verify_password
from apps.webui.internal.db import Base from apps.webui.internal.db import Base, get_session
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
@ -96,7 +96,6 @@ class AuthsTable:
def insert_new_auth( def insert_new_auth(
self, self,
db: Session,
email: str, email: str,
password: str, password: str,
name: str, name: str,
@ -104,6 +103,7 @@ class AuthsTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_session() as db:
log.info("insert_new_auth") log.info("insert_new_auth")
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@ -115,7 +115,7 @@ class AuthsTable:
db.add(result) db.add(result)
user = Users.insert_new_user( user = Users.insert_new_user(
db, id, name, email, profile_image_url, role, oauth_sub id, name, email, profile_image_url, role, oauth_sub
) )
db.commit() db.commit()
@ -127,14 +127,15 @@ class AuthsTable:
return None return None
def authenticate_user( def authenticate_user(
self, db: Session, email: str, password: str self, email: str, password: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}") log.info(f"authenticate_user: {email}")
with get_session() as db:
try: try:
auth = db.query(Auth).filter_by(email=email, active=True).first() auth = db.query(Auth).filter_by(email=email, active=True).first()
if auth: if auth:
if verify_password(password, auth.password): if verify_password(password, auth.password):
user = Users.get_user_by_id(db, auth.id) user = Users.get_user_by_id(auth.id)
return user return user
else: else:
return None return None
@ -144,23 +145,25 @@ class AuthsTable:
return None return None
def authenticate_user_by_api_key( def authenticate_user_by_api_key(
self, db: Session, api_key: str self, api_key: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}") log.info(f"authenticate_user_by_api_key: {api_key}")
with get_session() as db:
# if no api_key, return None # if no api_key, return None
if not api_key: if not api_key:
return None return None
try: try:
user = Users.get_user_by_api_key(db, api_key) user = Users.get_user_by_api_key(api_key)
return user if user else None return user if user else None
except: except:
return False return False
def authenticate_user_by_trusted_header( def authenticate_user_by_trusted_header(
self, db: Session, email: str self, email: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}") log.info(f"authenticate_user_by_trusted_header: {email}")
with get_session() as db:
try: try:
auth = db.query(Auth).filter(email=email, active=True).first() auth = db.query(Auth).filter(email=email, active=True).first()
if auth: if auth:
@ -170,25 +173,28 @@ class AuthsTable:
return None return None
def update_user_password_by_id( def update_user_password_by_id(
self, db: Session, id: str, new_password: str self, id: str, new_password: str
) -> bool: ) -> bool:
with get_session() as db:
try: try:
result = db.query(Auth).filter_by(id=id).update({"password": new_password}) result = db.query(Auth).filter_by(id=id).update({"password": new_password})
return True if result == 1 else False return True if result == 1 else False
except: except:
return False return False
def update_email_by_id(self, db: Session, id: str, email: str) -> bool: def update_email_by_id(self, id: str, email: str) -> bool:
with get_session() as db:
try: try:
result = db.query(Auth).filter_by(id=id).update({"email": email}) result = db.query(Auth).filter_by(id=id).update({"email": email})
return True if result == 1 else False return True if result == 1 else False
except: except:
return False return False
def delete_auth_by_id(self, db: Session, id: str) -> bool: def delete_auth_by_id(self, id: str) -> bool:
with get_session() as db:
try: try:
# Delete User # Delete User
result = Users.delete_user_by_id(db, id) result = Users.delete_user_by_id(id)
if result: if result:
db.query(Auth).filter_by(id=id).delete() db.query(Auth).filter_by(id=id).delete()

View File

@ -8,7 +8,7 @@ import time
from sqlalchemy import Column, String, BigInteger, Boolean from sqlalchemy import Column, String, BigInteger, Boolean
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base from apps.webui.internal.db import Base, get_session
#################### ####################
@ -80,8 +80,9 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable: class ChatTable:
def insert_new_chat( def insert_new_chat(
self, db: Session, user_id: str, form_data: ChatForm self, user_id: str, form_data: ChatForm
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
with get_session() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chat = ChatModel( chat = ChatModel(
**{ **{
@ -103,29 +104,30 @@ class ChatTable:
return ChatModel.model_validate(result) if result else None return ChatModel.model_validate(result) if result else None
def update_chat_by_id( def update_chat_by_id(
self, db: Session, id: str, chat: dict self, id: str, chat: dict
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
with get_session() as db:
try: try:
db.query(Chat).filter_by(id=id).update( chat_obj = db.get(Chat, id)
{ chat_obj.chat = json.dumps(chat)
"chat": json.dumps(chat), chat_obj.title = chat["title"] if "title" in chat else "New Chat"
"title": chat["title"] if "title" in chat else "New Chat", chat_obj.updated_at = int(time.time())
"updated_at": int(time.time()), db.commit()
} db.refresh(chat_obj)
)
return self.get_chat_by_id(db, id) return ChatModel.model_validate(chat_obj)
except: except Exception as e:
return None return None
def insert_shared_chat_by_chat_id( def insert_shared_chat_by_chat_id(
self, db: Session, chat_id: str self, chat_id: str
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
with get_session() as db:
# Get the existing chat to share # Get the existing chat to share
chat = db.get(Chat, chat_id) chat = db.get(Chat, chat_id)
# Check if the chat is already shared # Check if the chat is already shared
if chat.share_id: if chat.share_id:
return self.get_chat_by_id_and_user_id(db, chat.share_id, "shared") return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
# Create a new chat with the same data, but with a new ID # Create a new chat with the same data, but with a new ID
shared_chat = ChatModel( shared_chat = ChatModel(
**{ **{
@ -149,49 +151,56 @@ class ChatTable:
return shared_chat if (shared_result and result) else None return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id( def update_shared_chat_by_chat_id(
self, db: Session, chat_id: str self, chat_id: str
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
with get_session() as db:
try: try:
print("update_shared_chat_by_id") print("update_shared_chat_by_id")
chat = db.get(Chat, chat_id) chat = db.get(Chat, chat_id)
print(chat) print(chat)
chat.title = chat.title
chat.chat = chat.chat
db.commit()
db.refresh(chat)
db.query(Chat).filter_by(id=chat.share_id).update( return self.get_chat_by_id(chat.share_id)
{"title": chat.title, "chat": chat.chat}
)
return self.get_chat_by_id(db, chat.share_id)
except: except:
return None return None
def delete_shared_chat_by_chat_id(self, db: Session, chat_id: str) -> bool: def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try: try:
with get_session() as db:
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
return True return True
except: except:
return False return False
def update_chat_share_id_by_id( def update_chat_share_id_by_id(
self, db: Session, id: str, share_id: Optional[str] self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: try:
db.query(Chat).filter_by(id=id).update({"share_id": share_id}) with get_session() as db:
chat = db.get(Chat, id)
return self.get_chat_by_id(db, id) chat.share_id = share_id
db.commit()
db.refresh(chat)
return chat
except: except:
return None return None
def toggle_chat_archive_by_id(self, db: Session, id: str) -> Optional[ChatModel]: def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
chat = self.get_chat_by_id(db, id) with get_session() as db:
chat = self.get_chat_by_id(id)
db.query(Chat).filter_by(id=id).update({"archived": not chat.archived}) db.query(Chat).filter_by(id=id).update({"archived": not chat.archived})
return self.get_chat_by_id(db, id) return self.get_chat_by_id(id)
except: except:
return None return None
def archive_all_chats_by_user_id(self, db: Session, user_id: str) -> bool: def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_session() as db:
db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
return True return True
@ -199,8 +208,9 @@ class ChatTable:
return False return False
def get_archived_chat_list_by_user_id( def get_archived_chat_list_by_user_id(
self, db: Session, user_id: str, skip: int = 0, limit: int = 50 self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
with get_session() as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
.filter_by(user_id=user_id, archived=True) .filter_by(user_id=user_id, archived=True)
@ -212,12 +222,12 @@ class ChatTable:
def get_chat_list_by_user_id( def get_chat_list_by_user_id(
self, self,
db: Session,
user_id: str, user_id: str,
include_archived: bool = False, include_archived: bool = False,
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
) -> List[ChatModel]: ) -> List[ChatModel]:
with get_session() as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id)
if not include_archived: if not include_archived:
query = query.filter_by(archived=False) query = query.filter_by(archived=False)
@ -229,8 +239,9 @@ class ChatTable:
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_chat_ids( def get_chat_list_by_chat_ids(
self, db: Session, chat_ids: List[str], skip: int = 0, limit: int = 50 self, chat_ids: List[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
with get_session() as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
.filter(Chat.id.in_(chat_ids)) .filter(Chat.id.in_(chat_ids))
@ -240,34 +251,38 @@ class ChatTable:
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_by_id(self, db: Session, id: str) -> Optional[ChatModel]: def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_session() as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except: except:
return None return None
def get_chat_by_share_id(self, db: Session, id: str) -> Optional[ChatModel]: def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_session() as db:
chat = db.query(Chat).filter_by(share_id=id).first() chat = db.query(Chat).filter_by(share_id=id).first()
if chat: if chat:
return self.get_chat_by_id(db, id) return self.get_chat_by_id(id)
else: else:
return None return None
except Exception as e: except Exception as e:
return None return None
def get_chat_by_id_and_user_id( def get_chat_by_id_and_user_id(
self, db: Session, id: str, user_id: str self, id: str, user_id: str
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: try:
with get_session() as db:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except: except:
return None return None
def get_chats(self, db: Session, skip: int = 0, limit: int = 50) -> List[ChatModel]: def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
with get_session() as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
@ -275,15 +290,17 @@ class ChatTable:
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id(self, db: Session, user_id: str) -> List[ChatModel]: def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
with get_session() as db:
all_chats = ( all_chats = (
db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc()) db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id( def get_archived_chats_by_user_id(
self, db: Session, user_id: str self, user_id: str
) -> List[ChatModel]: ) -> List[ChatModel]:
with get_session() as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
.filter_by(user_id=user_id, archived=True) .filter_by(user_id=user_id, archived=True)
@ -291,34 +308,37 @@ class ChatTable:
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def delete_chat_by_id(self, db: Session, id: str) -> bool: def delete_chat_by_id(self, id: str) -> bool:
try: try:
with get_session() as db:
db.query(Chat).filter_by(id=id).delete() db.query(Chat).filter_by(id=id).delete()
return True and self.delete_shared_chat_by_chat_id(db, id) return True and self.delete_shared_chat_by_chat_id(id)
except: except:
return False return False
def delete_chat_by_id_and_user_id(self, db: Session, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
with get_session() as db:
db.query(Chat).filter_by(id=id, user_id=user_id).delete() db.query(Chat).filter_by(id=id, user_id=user_id).delete()
return True and self.delete_shared_chat_by_chat_id(db, id) return True and self.delete_shared_chat_by_chat_id(id)
except: except:
return False return False
def delete_chats_by_user_id(self, db: Session, user_id: str) -> bool: def delete_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_session() as db:
self.delete_shared_chats_by_user_id(db, user_id) self.delete_shared_chats_by_user_id(user_id)
db.query(Chat).filter_by(user_id=user_id).delete() db.query(Chat).filter_by(user_id=user_id).delete()
return True return True
except: except:
return False return False
def delete_shared_chats_by_user_id(self, db: Session, user_id: str) -> bool: def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_session() as db:
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]

View File

@ -6,7 +6,7 @@ import logging
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base from apps.webui.internal.db import Base, get_session
import json import json
@ -73,7 +73,7 @@ class DocumentForm(DocumentUpdateForm):
class DocumentsTable: class DocumentsTable:
def insert_new_doc( def insert_new_doc(
self, db: Session, user_id: str, form_data: DocumentForm self, user_id: str, form_data: DocumentForm
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
document = DocumentModel( document = DocumentModel(
**{ **{
@ -84,6 +84,7 @@ class DocumentsTable:
) )
try: try:
with get_session() as db:
result = Document(**document.model_dump()) result = Document(**document.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
@ -95,20 +96,23 @@ class DocumentsTable:
except: except:
return None return None
def get_doc_by_name(self, db: Session, name: str) -> Optional[DocumentModel]: def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
try: try:
with get_session() as db:
document = db.query(Document).filter_by(name=name).first() document = db.query(Document).filter_by(name=name).first()
return DocumentModel.model_validate(document) if document else None return DocumentModel.model_validate(document) if document else None
except: except:
return None return None
def get_docs(self, db: Session) -> List[DocumentModel]: def get_docs(self) -> List[DocumentModel]:
with get_session() as db:
return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()] return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
def update_doc_by_name( def update_doc_by_name(
self, db: Session, name: str, form_data: DocumentUpdateForm self, name: str, form_data: DocumentUpdateForm
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
try: try:
with get_session() as db:
db.query(Document).filter_by(name=name).update( db.query(Document).filter_by(name=name).update(
{ {
"title": form_data.title, "title": form_data.title,
@ -116,16 +120,18 @@ class DocumentsTable:
"timestamp": int(time.time()), "timestamp": int(time.time()),
} }
) )
return self.get_doc_by_name(db, form_data.name) db.commit()
return self.get_doc_by_name(form_data.name)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
def update_doc_content_by_name( def update_doc_content_by_name(
self, db: Session, name: str, updated: dict self, name: str, updated: dict
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
try: try:
doc = self.get_doc_by_name(db, name) with get_session() as db:
doc = self.get_doc_by_name(name)
doc_content = json.loads(doc.content if doc.content else "{}") doc_content = json.loads(doc.content if doc.content else "{}")
doc_content = {**doc_content, **updated} doc_content = {**doc_content, **updated}
@ -135,14 +141,15 @@ class DocumentsTable:
"timestamp": int(time.time()), "timestamp": int(time.time()),
} }
) )
db.commit()
return self.get_doc_by_name(db, name) return self.get_doc_by_name(name)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
def delete_doc_by_name(self, db: Session, name: str) -> bool: def delete_doc_by_name(self, name: str) -> bool:
try: try:
with get_session() as db:
db.query(Document).filter_by(name=name).delete() db.query(Document).filter_by(name=name).delete()
return True return True
except: except:

View File

@ -6,7 +6,7 @@ import logging
from sqlalchemy import Column, String, BigInteger from sqlalchemy import Column, String, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import JSONField, Base from apps.webui.internal.db import JSONField, Base, get_session
import json import json
@ -60,7 +60,7 @@ class FileForm(BaseModel):
class FilesTable: class FilesTable:
def insert_new_file(self, db: Session, user_id: str, form_data: FileForm) -> Optional[FileModel]: def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
file = FileModel( file = FileModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -70,6 +70,7 @@ class FilesTable:
) )
try: try:
with get_session() as db:
result = File(**file.model_dump()) result = File(**file.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
@ -82,26 +83,32 @@ class FilesTable:
print(f"Error creating tool: {e}") print(f"Error creating tool: {e}")
return None return None
def get_file_by_id(self, db: Session, id: str) -> Optional[FileModel]: def get_file_by_id(self, id: str) -> Optional[FileModel]:
try: try:
with get_session() as db:
file = db.get(File, id) file = db.get(File, id)
return FileModel.model_validate(file) return FileModel.model_validate(file)
except: except:
return None return None
def get_files(self, db: Session) -> List[FileModel]: def get_files(self) -> List[FileModel]:
with get_session() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()] return [FileModel.model_validate(file) for file in db.query(File).all()]
def delete_file_by_id(self, db: Session, id: str) -> bool: def delete_file_by_id(self, id: str) -> bool:
try: try:
with get_session() as db:
db.query(File).filter_by(id=id).delete() db.query(File).filter_by(id=id).delete()
db.commit()
return True return True
except: except:
return False return False
def delete_all_files(self, db: Session) -> bool: def delete_all_files(self) -> bool:
try: try:
with get_session() as db:
db.query(File).delete() db.query(File).delete()
db.commit()
return True return True
except: except:
return False return False

View File

@ -6,7 +6,7 @@ import logging
from sqlalchemy import Column, String, Text, BigInteger, Boolean from sqlalchemy import Column, String, Text, BigInteger, Boolean
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import JSONField, Base from apps.webui.internal.db import JSONField, Base, get_session
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json import json
@ -87,7 +87,7 @@ class FunctionValves(BaseModel):
class FunctionsTable: class FunctionsTable:
def insert_new_function( def insert_new_function(
self, db: Session, user_id: str, type: str, form_data: FunctionForm self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]: ) -> Optional[FunctionModel]:
function = FunctionModel( function = FunctionModel(
**{ **{
@ -100,6 +100,7 @@ class FunctionsTable:
) )
try: try:
with get_session() as db:
result = Function(**function.model_dump()) result = Function(**function.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
@ -112,8 +113,9 @@ class FunctionsTable:
print(f"Error creating tool: {e}") print(f"Error creating tool: {e}")
return None return None
def get_function_by_id(self, db: Session, id: str) -> Optional[FunctionModel]: def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try: try:
with get_session() as db:
function = db.get(Function, id) function = db.get(Function, id)
return FunctionModel.model_validate(function) return FunctionModel.model_validate(function)
except: except:
@ -121,35 +123,40 @@ class FunctionsTable:
def get_functions(self, active_only=False) -> List[FunctionModel]: def get_functions(self, active_only=False) -> List[FunctionModel]:
if active_only: if active_only:
with get_session() as db:
return [ return [
FunctionModel(**model_to_dict(function)) FunctionModel.model_validate(function)
for function in Function.select().where(Function.is_active == True) for function in db.query(Function).filter_by(is_active=True).all()
] ]
else: else:
with get_session() as db:
return [ return [
FunctionModel(**model_to_dict(function)) FunctionModel.model_validate(function)
for function in Function.select() for function in db.query(Function).all()
] ]
def get_functions_by_type( def get_functions_by_type(
self, type: str, active_only=False self, type: str, active_only=False
) -> List[FunctionModel]: ) -> List[FunctionModel]:
if active_only: if active_only:
with get_session() as db:
return [ return [
FunctionModel(**model_to_dict(function)) FunctionModel.model_validate(function)
for function in Function.select().where( for function in db.query(Function).filter_by(
Function.type == type, Function.is_active == True type=type, is_active=True
) ).all()
] ]
else: else:
with get_session() as db:
return [ return [
FunctionModel(**model_to_dict(function)) FunctionModel.model_validate(function)
for function in Function.select().where(Function.type == type) for function in db.query(Function).filter_by(type=type).all()
] ]
def get_function_valves_by_id(self, id: str) -> Optional[dict]: def get_function_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
function = Function.get(Function.id == id) with get_session() as db:
function = db.get(Function, id)
return function.valves if function.valves else {} return function.valves if function.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
@ -159,14 +166,12 @@ class FunctionsTable:
self, id: str, valves: dict self, id: str, valves: dict
) -> Optional[FunctionValves]: ) -> Optional[FunctionValves]:
try: try:
query = Function.update( with get_session() as db:
**{"valves": valves}, db.query(Function).filter_by(id=id).update(
updated_at=int(time.time()), {"valves": valves, "updated_at": int(time.time())}
).where(Function.id == id) )
query.execute() db.commit()
return self.get_function_by_id(id)
function = Function.get(Function.id == id)
return FunctionValves(**model_to_dict(function))
except: except:
return None return None
@ -214,29 +219,31 @@ class FunctionsTable:
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
try: try:
with get_session() as db:
db.query(Function).filter_by(id=id).update({ db.query(Function).filter_by(id=id).update({
**updated, **updated,
"updated_at": int(time.time()), "updated_at": int(time.time()),
}) })
return self.get_function_by_id(db, id) db.commit()
return self.get_function_by_id(id)
except: except:
return None return None
def deactivate_all_functions(self) -> Optional[bool]: def deactivate_all_functions(self) -> Optional[bool]:
try: try:
query = Function.update( with get_session() as db:
**{"is_active": False}, db.query(Function).update({
updated_at=int(time.time()), "is_active": False,
) "updated_at": int(time.time()),
})
query.execute() db.commit()
return True return True
except: except:
return None return None
def delete_function_by_id(self, db: Session, id: str) -> bool: def delete_function_by_id(self, id: str) -> bool:
try: try:
with get_session() as db:
db.query(Function).filter_by(id=id).delete() db.query(Function).filter_by(id=id).delete()
return True return True
except: except:

View File

@ -4,7 +4,7 @@ from typing import List, Union, Optional
from sqlalchemy import Column, String, BigInteger from sqlalchemy import Column, String, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base from apps.webui.internal.db import Base, get_session
from apps.webui.models.chats import Chats from apps.webui.models.chats import Chats
import time import time
@ -44,7 +44,6 @@ class MemoriesTable:
def insert_new_memory( def insert_new_memory(
self, self,
db: Session,
user_id: str, user_id: str,
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
@ -59,7 +58,8 @@ class MemoriesTable:
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
result = Memory(**memory.dict()) with get_session() as db:
result = Memory(**memory.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
db.refresh(result) db.refresh(result)
@ -70,41 +70,46 @@ class MemoriesTable:
def update_memory_by_id( def update_memory_by_id(
self, self,
db: Session,
id: str, id: str,
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
try: try:
with get_session() as db:
db.query(Memory).filter_by(id=id).update( db.query(Memory).filter_by(id=id).update(
{"content": content, "updated_at": int(time.time())} {"content": content, "updated_at": int(time.time())}
) )
return self.get_memory_by_id(db, id) db.commit()
return self.get_memory_by_id(id)
except: except:
return None return None
def get_memories(self, db: Session) -> List[MemoryModel]: def get_memories(self) -> List[MemoryModel]:
try: try:
with get_session() as db:
memories = db.query(Memory).all() memories = db.query(Memory).all()
return [MemoryModel.model_validate(memory) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
except: except:
return None return None
def get_memories_by_user_id(self, db: Session, user_id: str) -> List[MemoryModel]: def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
try: try:
with get_session() as db:
memories = db.query(Memory).filter_by(user_id=user_id).all() memories = db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel.model_validate(memory) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
except: except:
return None return None
def get_memory_by_id(self, db: Session, id: str) -> Optional[MemoryModel]: def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
try: try:
with get_session() as db:
memory = db.get(Memory, id) memory = db.get(Memory, id)
return MemoryModel.model_validate(memory) return MemoryModel.model_validate(memory)
except: except:
return None return None
def delete_memory_by_id(self, db: Session, id: str) -> bool: def delete_memory_by_id(self, id: str) -> bool:
try: try:
with get_session() as db:
db.query(Memory).filter_by(id=id).delete() db.query(Memory).filter_by(id=id).delete()
return True return True
@ -113,6 +118,7 @@ class MemoriesTable:
def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool: def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool:
try: try:
with get_session() as db:
db.query(Memory).filter_by(user_id=user_id).delete() db.query(Memory).filter_by(user_id=user_id).delete()
return True return True
except: except:
@ -122,6 +128,7 @@ class MemoriesTable:
self, db: Session, id: str, user_id: str self, db: Session, id: str, user_id: str
) -> bool: ) -> bool:
try: try:
with get_session() as db:
db.query(Memory).filter_by(id=id, user_id=user_id).delete() db.query(Memory).filter_by(id=id, user_id=user_id).delete()
return True return True
except: except:

View File

@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base, JSONField from apps.webui.internal.db import Base, JSONField, get_session
from typing import List, Union, Optional from typing import List, Union, Optional
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
@ -78,8 +78,6 @@ class Model(Base):
class ModelModel(BaseModel): class ModelModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str id: str
user_id: str user_id: str
base_model_id: Optional[str] = None base_model_id: Optional[str] = None
@ -91,6 +89,8 @@ class ModelModel(BaseModel):
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
@ -116,7 +116,7 @@ class ModelForm(BaseModel):
class ModelsTable: class ModelsTable:
def insert_new_model( def insert_new_model(
self, db: Session, form_data: ModelForm, user_id: str self, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]: ) -> Optional[ModelModel]:
model = ModelModel( model = ModelModel(
**{ **{
@ -127,7 +127,8 @@ class ModelsTable:
} }
) )
try: try:
result = Model(**model.dict()) with get_session() as db:
result = Model(**model.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
db.refresh(result) db.refresh(result)
@ -140,21 +141,24 @@ class ModelsTable:
print(e) print(e)
return None return None
def get_all_models(self, db: Session) -> List[ModelModel]: def get_all_models(self) -> List[ModelModel]:
with get_session() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()] return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, db: Session, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
with get_session() as db:
model = db.get(Model, id) model = db.get(Model, id)
return ModelModel.model_validate(model) return ModelModel.model_validate(model)
except: except:
return None return None
def update_model_by_id( def update_model_by_id(
self, db: Session, id: str, model: ModelForm self, id: str, model: ModelForm
) -> Optional[ModelModel]: ) -> Optional[ModelModel]:
try: try:
# update only the fields that are present in the model # update only the fields that are present in the model
with get_session() as db:
model = db.query(Model).get(id) model = db.query(Model).get(id)
model.update(**model.model_dump()) model.update(**model.model_dump())
db.commit() db.commit()
@ -165,8 +169,9 @@ class ModelsTable:
return None return None
def delete_model_by_id(self, db: Session, id: str) -> bool: def delete_model_by_id(self, id: str) -> bool:
try: try:
with get_session() as db:
db.query(Model).filter_by(id=id).delete() db.query(Model).filter_by(id=id).delete()
return True return True
except: except:

View File

@ -5,7 +5,7 @@ import time
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base from apps.webui.internal.db import Base, get_session
import json import json
@ -48,8 +48,9 @@ class PromptForm(BaseModel):
class PromptsTable: class PromptsTable:
def insert_new_prompt( def insert_new_prompt(
self, db: Session, user_id: str, form_data: PromptForm self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
with get_session() as db:
prompt = PromptModel( prompt = PromptModel(
**{ **{
"user_id": user_id, "user_id": user_id,
@ -72,32 +73,35 @@ class PromptsTable:
except Exception as e: except Exception as e:
return None return None
def get_prompt_by_command(self, db: Session, command: str) -> Optional[PromptModel]: def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
with get_session() as db:
try: try:
prompt = db.query(Prompt).filter_by(command=command).first() prompt = db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt) return PromptModel.model_validate(prompt)
except: except:
return None return None
def get_prompts(self, db: Session) -> List[PromptModel]: def get_prompts(self) -> List[PromptModel]:
with get_session() as db:
return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()] return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
def update_prompt_by_command( def update_prompt_by_command(
self, db: Session, command: str, form_data: PromptForm self, command: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
with get_session() as db:
try: try:
db.query(Prompt).filter_by(command=command).update( prompt = db.query(Prompt).filter_by(command=command).first()
{ prompt.title = form_data.title
"title": form_data.title, prompt.content = form_data.content
"content": form_data.content, prompt.timestamp = int(time.time())
"timestamp": int(time.time()), db.commit()
} return prompt
) # return self.get_prompt_by_command(command)
return self.get_prompt_by_command(db, command)
except: except:
return None return None
def delete_prompt_by_command(self, db: Session, command: str) -> bool: def delete_prompt_by_command(self, command: str) -> bool:
with get_session() as db:
try: try:
db.query(Prompt).filter_by(command=command).delete() db.query(Prompt).filter_by(command=command).delete()
return True return True

View File

@ -9,7 +9,7 @@ import logging
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base from apps.webui.internal.db import Base, get_session
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
@ -80,12 +80,13 @@ class ChatTagsResponse(BaseModel):
class TagTable: class TagTable:
def insert_new_tag( def insert_new_tag(
self, db: Session, name: str, user_id: str self, name: str, user_id: str
) -> Optional[TagModel]: ) -> Optional[TagModel]:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: try:
result = Tag(**tag.dict()) with get_session() as db:
result = Tag(**tag.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
db.refresh(result) db.refresh(result)
@ -97,20 +98,21 @@ class TagTable:
return None return None
def get_tag_by_name_and_user_id( def get_tag_by_name_and_user_id(
self, db: Session, name: str, user_id: str self, name: str, user_id: str
) -> Optional[TagModel]: ) -> Optional[TagModel]:
try: try:
with get_session() as db:
tag = db.query(Tag).filter(name=name, user_id=user_id).first() tag = db.query(Tag).filter(name=name, user_id=user_id).first()
return TagModel.model_validate(tag) return TagModel.model_validate(tag)
except Exception as e: except Exception as e:
return None return None
def add_tag_to_chat( def add_tag_to_chat(
self, db: Session, user_id: str, form_data: ChatIdTagForm self, user_id: str, form_data: ChatIdTagForm
) -> Optional[ChatIdTagModel]: ) -> Optional[ChatIdTagModel]:
tag = self.get_tag_by_name_and_user_id(db, form_data.tag_name, user_id) tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
if tag == None: if tag == None:
tag = self.insert_new_tag(db, form_data.tag_name, user_id) tag = self.insert_new_tag(form_data.tag_name, user_id)
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chatIdTag = ChatIdTagModel( chatIdTag = ChatIdTagModel(
@ -123,7 +125,8 @@ class TagTable:
} }
) )
try: try:
result = ChatIdTag(**chatIdTag.dict()) with get_session() as db:
result = ChatIdTag(**chatIdTag.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
db.refresh(result) db.refresh(result)
@ -134,7 +137,8 @@ class TagTable:
except: except:
return None return None
def get_tags_by_user_id(self, db: Session, user_id: str) -> List[TagModel]: def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
with get_session() as db:
tag_names = [ tag_names = [
chat_id_tag.tag_name chat_id_tag.tag_name
for chat_id_tag in ( for chat_id_tag in (
@ -156,8 +160,9 @@ class TagTable:
] ]
def get_tags_by_chat_id_and_user_id( def get_tags_by_chat_id_and_user_id(
self, db: Session, chat_id: str, user_id: str self, chat_id: str, user_id: str
) -> List[TagModel]: ) -> List[TagModel]:
with get_session() as db:
tag_names = [ tag_names = [
chat_id_tag.tag_name chat_id_tag.tag_name
for chat_id_tag in ( for chat_id_tag in (
@ -179,8 +184,9 @@ class TagTable:
] ]
def get_chat_ids_by_tag_name_and_user_id( def get_chat_ids_by_tag_name_and_user_id(
self, db: Session, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> List[ChatIdTagModel]: ) -> List[ChatIdTagModel]:
with get_session() as db:
return [ return [
ChatIdTagModel.model_validate(chat_id_tag) ChatIdTagModel.model_validate(chat_id_tag)
for chat_id_tag in ( for chat_id_tag in (
@ -192,23 +198,26 @@ class TagTable:
] ]
def count_chat_ids_by_tag_name_and_user_id( def count_chat_ids_by_tag_name_and_user_id(
self, db: Session, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> int: ) -> int:
with get_session() as db:
return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count() return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
def delete_tag_by_tag_name_and_user_id( def delete_tag_by_tag_name_and_user_id(
self, db: Session, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> bool: ) -> bool:
try: try:
with get_session() as db:
res = ( res = (
db.query(ChatIdTag) db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id) .filter_by(tag_name=tag_name, user_id=user_id)
.delete() .delete()
) )
log.debug(f"res: {res}") log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id( tag_count = self.count_chat_ids_by_tag_name_and_user_id(
db, tag_name, user_id tag_name, user_id
) )
if tag_count == 0: if tag_count == 0:
# Remove tag item from Tag col as well # Remove tag item from Tag col as well
@ -219,18 +228,20 @@ class TagTable:
return False return False
def delete_tag_by_tag_name_and_chat_id_and_user_id( def delete_tag_by_tag_name_and_chat_id_and_user_id(
self, db: Session, tag_name: str, chat_id: str, user_id: str self, tag_name: str, chat_id: str, user_id: str
) -> bool: ) -> bool:
try: try:
with get_session() as db:
res = ( res = (
db.query(ChatIdTag) db.query(ChatIdTag)
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
.delete() .delete()
) )
log.debug(f"res: {res}") log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id( tag_count = self.count_chat_ids_by_tag_name_and_user_id(
db, tag_name, user_id tag_name, user_id
) )
if tag_count == 0: if tag_count == 0:
# Remove tag item from Tag col as well # Remove tag item from Tag col as well
@ -242,13 +253,13 @@ class TagTable:
return False return False
def delete_tags_by_chat_id_and_user_id( def delete_tags_by_chat_id_and_user_id(
self, db: Session, chat_id: str, user_id: str self, chat_id: str, user_id: str
) -> bool: ) -> bool:
tags = self.get_tags_by_chat_id_and_user_id(db, chat_id, user_id) tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
for tag in tags: for tag in tags:
self.delete_tag_by_tag_name_and_chat_id_and_user_id( self.delete_tag_by_tag_name_and_chat_id_and_user_id(
db, tag.tag_name, chat_id, user_id tag.tag_name, chat_id, user_id
) )
return True return True

View File

@ -5,7 +5,7 @@ import logging
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base, JSONField from apps.webui.internal.db import Base, JSONField, get_session
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json import json
@ -82,7 +82,7 @@ class ToolValves(BaseModel):
class ToolsTable: class ToolsTable:
def insert_new_tool( def insert_new_tool(
self, db: Session, user_id: str, form_data: ToolForm, specs: List[dict] self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]: ) -> Optional[ToolModel]:
tool = ToolModel( tool = ToolModel(
**{ **{
@ -95,7 +95,8 @@ class ToolsTable:
) )
try: try:
result = Tool(**tool.dict()) with get_session() as db:
result = Tool(**tool.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
db.refresh(result) db.refresh(result)
@ -107,19 +108,22 @@ class ToolsTable:
print(f"Error creating tool: {e}") print(f"Error creating tool: {e}")
return None return None
def get_tool_by_id(self, db: Session, id: str) -> Optional[ToolModel]: def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try: try:
with get_session() as db:
tool = db.get(Tool, id) tool = db.get(Tool, id)
return ToolModel.model_validate(tool) return ToolModel.model_validate(tool)
except: except:
return None return None
def get_tools(self, db: Session) -> List[ToolModel]: def get_tools(self) -> List[ToolModel]:
with get_session() as db:
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
tool = Tool.get(Tool.id == id) with get_session() as db:
tool = db.get(Tool, id)
return tool.valves if tool.valves else {} return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
@ -127,14 +131,12 @@ class ToolsTable:
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try: try:
query = Tool.update( with get_session() as db:
**{"valves": valves}, db.query(Tool).filter_by(id=id).update(
updated_at=int(time.time()), {"valves": valves, "updated_at": int(time.time())}
).where(Tool.id == id) )
query.execute() db.commit()
return self.get_tool_by_id(id)
tool = Tool.get(Tool.id == id)
return ToolValves(**model_to_dict(tool))
except: except:
return None return None
@ -172,8 +174,7 @@ class ToolsTable:
user_settings["tools"]["valves"][id] = valves user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database # Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings}) Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
return user_settings["tools"]["valves"][id] return user_settings["tools"]["valves"][id]
except Exception as e: except Exception as e:
@ -182,15 +183,18 @@ class ToolsTable:
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try: try:
with get_session() as db:
db.query(Tool).filter_by(id=id).update( db.query(Tool).filter_by(id=id).update(
{**updated, "updated_at": int(time.time())} {**updated, "updated_at": int(time.time())}
) )
return self.get_tool_by_id(db, id) db.commit()
return self.get_tool_by_id(id)
except: except:
return None return None
def delete_tool_by_id(self, db: Session, id: str) -> bool: def delete_tool_by_id(self, id: str) -> bool:
try: try:
with get_session() as db:
db.query(Tool).filter_by(id=id).delete() db.query(Tool).filter_by(id=id).delete()
return True return True
except: except:

View File

@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.webui.internal.db import Base, JSONField from apps.webui.internal.db import Base, JSONField, get_session
from apps.webui.models.chats import Chats from apps.webui.models.chats import Chats
#################### ####################
@ -42,8 +42,6 @@ class UserSettings(BaseModel):
class UserModel(BaseModel): class UserModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str id: str
name: str name: str
email: str email: str
@ -60,6 +58,8 @@ class UserModel(BaseModel):
oauth_sub: Optional[str] = None oauth_sub: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
@ -82,7 +82,6 @@ class UsersTable:
def insert_new_user( def insert_new_user(
self, self,
db: Session,
id: str, id: str,
name: str, name: str,
email: str, email: str,
@ -90,6 +89,7 @@ class UsersTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_session() as db:
user = UserModel( user = UserModel(
**{ **{
"id": id, "id": id,
@ -112,21 +112,24 @@ class UsersTable:
else: else:
return None return None
def get_user_by_id(self, db: Session, id: str) -> Optional[UserModel]: def get_user_by_id(self, id: str) -> Optional[UserModel]:
with get_session() as db:
try: try:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception as e: except Exception as e:
return None return None
def get_user_by_api_key(self, db: Session, api_key: str) -> Optional[UserModel]: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
with get_session() as db:
try: try:
user = db.query(User).filter_by(api_key=api_key).first() user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def get_user_by_email(self, db: Session, email: str) -> Optional[UserModel]: def get_user_by_email(self, email: str) -> Optional[UserModel]:
with get_session() as db:
try: try:
user = db.query(User).filter_by(email=email).first() user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
@ -134,13 +137,15 @@ class UsersTable:
return None return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
with get_session() as db:
try: try:
user = User.get(User.oauth_sub == sub) user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel(**model_to_dict(user)) return UserModel.model_validate(user)
except: except:
return None return None
def get_users(self, db: Session, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
with get_session() as db:
users = ( users = (
db.query(User) db.query(User)
# .offset(skip).limit(limit) # .offset(skip).limit(limit)
@ -148,10 +153,12 @@ class UsersTable:
) )
return [UserModel.model_validate(user) for user in users] return [UserModel.model_validate(user) for user in users]
def get_num_users(self, db: Session) -> Optional[int]: def get_num_users(self) -> Optional[int]:
with get_session() as db:
return db.query(User).count() return db.query(User).count()
def get_first_user(self, db: Session) -> UserModel: def get_first_user(self) -> UserModel:
with get_session() as db:
try: try:
user = db.query(User).order_by(User.created_at).first() user = db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
@ -159,8 +166,9 @@ class UsersTable:
return None return None
def update_user_role_by_id( def update_user_role_by_id(
self, db: Session, id: str, role: str self, id: str, role: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_session() as db:
try: try:
db.query(User).filter_by(id=id).update({"role": role}) db.query(User).filter_by(id=id).update({"role": role})
db.commit() db.commit()
@ -171,8 +179,9 @@ class UsersTable:
return None return None
def update_user_profile_image_url_by_id( def update_user_profile_image_url_by_id(
self, db: Session, id: str, profile_image_url: str self, id: str, profile_image_url: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_session() as db:
try: try:
db.query(User).filter_by(id=id).update( db.query(User).filter_by(id=id).update(
{"profile_image_url": profile_image_url} {"profile_image_url": profile_image_url}
@ -185,8 +194,9 @@ class UsersTable:
return None return None
def update_user_last_active_by_id( def update_user_last_active_by_id(
self, db: Session, id: str self, id: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_session() as db:
try: try:
db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())}) db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
@ -196,8 +206,9 @@ class UsersTable:
return None return None
def update_user_oauth_sub_by_id( def update_user_oauth_sub_by_id(
self, db: Session, id: str, oauth_sub: str self, id: str, oauth_sub: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_session() as db:
try: try:
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
@ -207,8 +218,9 @@ class UsersTable:
return None return None
def update_user_by_id( def update_user_by_id(
self, db: Session, id: str, updated: dict self, id: str, updated: dict
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_session() as db:
try: try:
db.query(User).filter_by(id=id).update(updated) db.query(User).filter_by(id=id).update(updated)
db.commit() db.commit()
@ -219,10 +231,11 @@ class UsersTable:
except Exception as e: except Exception as e:
return None return None
def delete_user_by_id(self, db: Session, id: str) -> bool: def delete_user_by_id(self, id: str) -> bool:
with get_session() as db:
try: try:
# Delete User Chats # Delete User Chats
result = Chats.delete_chats_by_user_id(db, id) result = Chats.delete_chats_by_user_id(id)
if result: if result:
# Delete User # Delete User
@ -235,7 +248,8 @@ class UsersTable:
except: except:
return False return False
def update_user_api_key_by_id(self, db: Session, id: str, api_key: str) -> str: def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
with get_session() as db:
try: try:
result = db.query(User).filter_by(id=id).update({"api_key": api_key}) result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit() db.commit()
@ -243,7 +257,8 @@ class UsersTable:
except: except:
return False return False
def get_user_api_key_by_id(self, db: Session, id: str) -> Optional[str]: def get_user_api_key_by_id(self, id: str) -> Optional[str]:
with get_session() as db:
try: try:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return user.api_key return user.api_key

View File

@ -10,7 +10,6 @@ import re
import uuid import uuid
import csv import csv
from apps.webui.internal.db import get_db
from apps.webui.models.auths import ( from apps.webui.models.auths import (
SigninForm, SigninForm,
SignupForm, SignupForm,
@ -80,12 +79,10 @@ async def get_session_user(
@router.post("/update/profile", response_model=UserResponse) @router.post("/update/profile", response_model=UserResponse)
async def update_profile( async def update_profile(
form_data: UpdateProfileForm, form_data: UpdateProfileForm,
session_user=Depends(get_current_user), session_user=Depends(get_current_user)
db=Depends(get_db),
): ):
if session_user: if session_user:
user = Users.update_user_by_id( user = Users.update_user_by_id(
db,
session_user.id, session_user.id,
{"profile_image_url": form_data.profile_image_url, "name": form_data.name}, {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
) )
@ -105,17 +102,16 @@ async def update_profile(
@router.post("/update/password", response_model=bool) @router.post("/update/password", response_model=bool)
async def update_password( async def update_password(
form_data: UpdatePasswordForm, form_data: UpdatePasswordForm,
session_user=Depends(get_current_user), session_user=Depends(get_current_user)
db=Depends(get_db),
): ):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
if session_user: if session_user:
user = Auths.authenticate_user(db, session_user.email, form_data.password) user = Auths.authenticate_user(session_user.email, form_data.password)
if user: if user:
hashed = get_password_hash(form_data.new_password) hashed = get_password_hash(form_data.new_password)
return Auths.update_user_password_by_id(db, user.id, hashed) return Auths.update_user_password_by_id(user.id, hashed)
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
else: else:
@ -128,7 +124,7 @@ async def update_password(
@router.post("/signin", response_model=SigninResponse) @router.post("/signin", response_model=SigninResponse)
async def signin(request: Request, response: Response, form_data: SigninForm, db=Depends(get_db)): async def signin(request: Request, response: Response, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
@ -139,34 +135,32 @@ async def signin(request: Request, response: Response, form_data: SigninForm, db
trusted_name = request.headers.get( trusted_name = request.headers.get(
WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
) )
if not Users.get_user_by_email(db, trusted_email.lower()): if not Users.get_user_by_email(trusted_email.lower()):
await signup( await signup(
request, request,
SignupForm( SignupForm(
email=trusted_email, password=str(uuid.uuid4()), name=trusted_name email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
), ),
db,
) )
user = Auths.authenticate_user_by_trusted_header(db, trusted_email) user = Auths.authenticate_user_by_trusted_header(trusted_email)
elif WEBUI_AUTH == False: elif WEBUI_AUTH == False:
admin_email = "admin@localhost" admin_email = "admin@localhost"
admin_password = "admin" admin_password = "admin"
if Users.get_user_by_email(db, admin_email.lower()): if Users.get_user_by_email(admin_email.lower()):
user = Auths.authenticate_user(db, admin_email.lower(), admin_password) user = Auths.authenticate_user(admin_email.lower(), admin_password)
else: else:
if Users.get_num_users(db) != 0: if Users.get_num_users() != 0:
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
await signup( await signup(
request, request,
SignupForm(email=admin_email, password=admin_password, name="User"), SignupForm(email=admin_email, password=admin_password, name="User"),
db,
) )
user = Auths.authenticate_user(db, admin_email.lower(), admin_password) user = Auths.authenticate_user(admin_email.lower(), admin_password)
else: else:
user = Auths.authenticate_user(db, form_data.email.lower(), form_data.password) user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if user: if user:
token = create_token( token = create_token(
@ -200,7 +194,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm, db
@router.post("/signup", response_model=SigninResponse) @router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, response: Response, form_data: SignupForm, db=Depends(get_db)): async def signup(request: Request, response: Response, form_data: SignupForm):
if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH: if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
raise HTTPException( raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
@ -211,18 +205,17 @@ async def signup(request: Request, response: Response, form_data: SignupForm, db
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
) )
if Users.get_user_by_email(db, form_data.email.lower()): if Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
role = ( role = (
"admin" "admin"
if Users.get_num_users(db) == 0 if Users.get_num_users() == 0
else request.app.state.config.DEFAULT_USER_ROLE else request.app.state.config.DEFAULT_USER_ROLE
) )
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
db,
form_data.email.lower(), form_data.email.lower(),
hashed, hashed,
form_data.name, form_data.name,
@ -277,7 +270,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm, db
@router.post("/add", response_model=SigninResponse) @router.post("/add", response_model=SigninResponse)
async def add_user( async def add_user(
form_data: AddUserForm, user=Depends(get_admin_user), db=Depends(get_db) form_data: AddUserForm, user=Depends(get_admin_user)
): ):
if not validate_email_format(form_data.email.lower()): if not validate_email_format(form_data.email.lower()):
@ -285,7 +278,7 @@ async def add_user(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
) )
if Users.get_user_by_email(db, form_data.email.lower()): if Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
@ -293,7 +286,6 @@ async def add_user(
print(form_data) print(form_data)
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
db,
form_data.email.lower(), form_data.email.lower(),
hashed, hashed,
form_data.name, form_data.name,
@ -325,7 +317,7 @@ async def add_user(
@router.get("/admin/details") @router.get("/admin/details")
async def get_admin_details( async def get_admin_details(
request: Request, user=Depends(get_current_user), db=Depends(get_db) request: Request, user=Depends(get_current_user)
): ):
if request.app.state.config.SHOW_ADMIN_DETAILS: if request.app.state.config.SHOW_ADMIN_DETAILS:
admin_email = request.app.state.config.ADMIN_EMAIL admin_email = request.app.state.config.ADMIN_EMAIL
@ -334,11 +326,11 @@ async def get_admin_details(
print(admin_email, admin_name) print(admin_email, admin_name)
if admin_email: if admin_email:
admin = Users.get_user_by_email(db, admin_email) admin = Users.get_user_by_email(admin_email)
if admin: if admin:
admin_name = admin.name admin_name = admin.name
else: else:
admin = Users.get_first_user(db) admin = Users.get_first_user()
if admin: if admin:
admin_email = admin.email admin_email = admin.email
admin_name = admin.name admin_name = admin.name
@ -411,9 +403,9 @@ async def update_admin_config(
# create api key # create api key
@router.post("/api_key", response_model=ApiKey) @router.post("/api_key", response_model=ApiKey)
async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)): async def create_api_key_(user=Depends(get_current_user)):
api_key = create_api_key() api_key = create_api_key()
success = Users.update_user_api_key_by_id(db, user.id, api_key) success = Users.update_user_api_key_by_id(user.id, api_key)
if success: if success:
return { return {
"api_key": api_key, "api_key": api_key,
@ -424,15 +416,15 @@ async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)):
# delete api key # delete api key
@router.delete("/api_key", response_model=bool) @router.delete("/api_key", response_model=bool)
async def delete_api_key(user=Depends(get_current_user), db=Depends(get_db)): async def delete_api_key(user=Depends(get_current_user)):
success = Users.update_user_api_key_by_id(db, user.id, None) success = Users.update_user_api_key_by_id(user.id, None)
return success return success
# get api key # get api key
@router.get("/api_key", response_model=ApiKey) @router.get("/api_key", response_model=ApiKey)
async def get_api_key(user=Depends(get_current_user), db=Depends(get_db)): async def get_api_key(user=Depends(get_current_user)):
api_key = Users.get_user_api_key_by_id(db, user.id) api_key = Users.get_user_api_key_by_id(user.id)
if api_key: if api_key:
return { return {
"api_key": api_key, "api_key": api_key,

View File

@ -2,7 +2,6 @@ from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
from apps.webui.internal.db import get_db
from utils.utils import get_current_user, get_admin_user from utils.utils import get_current_user, get_admin_user
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
@ -45,9 +44,9 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
@router.get("/list", response_model=List[ChatTitleIdResponse]) @router.get("/list", response_model=List[ChatTitleIdResponse])
async def get_session_user_chat_list( async def get_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db) user=Depends(get_current_user), skip: int = 0, limit: int = 50
): ):
return Chats.get_chat_list_by_user_id(db, user.id, skip, limit) return Chats.get_chat_list_by_user_id(user.id, skip, limit)
############################ ############################
@ -57,7 +56,7 @@ async def get_session_user_chat_list(
@router.delete("/", response_model=bool) @router.delete("/", response_model=bool)
async def delete_all_user_chats( async def delete_all_user_chats(
request: Request, user=Depends(get_current_user), db=Depends(get_db) request: Request, user=Depends(get_current_user)
): ):
if ( if (
@ -69,7 +68,7 @@ async def delete_all_user_chats(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
result = Chats.delete_chats_by_user_id(db, user.id) result = Chats.delete_chats_by_user_id(user.id)
return result return result
@ -84,10 +83,9 @@ async def get_user_chat_list_by_user_id(
user=Depends(get_admin_user), user=Depends(get_admin_user),
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
db=Depends(get_db),
): ):
return Chats.get_chat_list_by_user_id( return Chats.get_chat_list_by_user_id(
db, user_id, include_archived=True, skip=skip, limit=limit user_id, include_archived=True, skip=skip, limit=limit
) )
@ -98,10 +96,10 @@ async def get_user_chat_list_by_user_id(
@router.post("/new", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat( async def create_new_chat(
form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db) form_data: ChatForm, user=Depends(get_current_user)
): ):
try: try:
chat = Chats.insert_new_chat(db, user.id, form_data) chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -116,10 +114,10 @@ async def create_new_chat(
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)): async def get_user_chats(user=Depends(get_current_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_chats_by_user_id(db, user.id) for chat in Chats.get_chats_by_user_id(user.id)
] ]
@ -129,10 +127,10 @@ async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)):
@router.get("/all/archived", response_model=List[ChatResponse]) @router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get_db)): async def get_user_archived_chats(user=Depends(get_current_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(db, user.id) for chat in Chats.get_archived_chats_by_user_id(user.id)
] ]
@ -142,7 +140,7 @@ async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get
@router.get("/all/db", response_model=List[ChatResponse]) @router.get("/all/db", response_model=List[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_db)): async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
if not ENABLE_ADMIN_EXPORT: if not ENABLE_ADMIN_EXPORT:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -150,7 +148,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_
) )
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_chats(db) for chat in Chats.get_chats()
] ]
@ -161,9 +159,9 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_
@router.get("/archived", response_model=List[ChatTitleIdResponse]) @router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list( async def get_archived_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db) user=Depends(get_current_user), skip: int = 0, limit: int = 50
): ):
return Chats.get_archived_chat_list_by_user_id(db, user.id, skip, limit) return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
############################ ############################
@ -172,8 +170,8 @@ async def get_archived_session_user_chat_list(
@router.post("/archive/all", response_model=bool) @router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)): async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(db, user.id) return Chats.archive_all_chats_by_user_id(user.id)
############################ ############################
@ -183,7 +181,7 @@ async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)):
@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id( async def get_shared_chat_by_id(
share_id: str, user=Depends(get_current_user), db=Depends(get_db) share_id: str, user=Depends(get_current_user)
): ):
if user.role == "pending": if user.role == "pending":
raise HTTPException( raise HTTPException(
@ -191,9 +189,9 @@ async def get_shared_chat_by_id(
) )
if user.role == "user": if user.role == "user":
chat = Chats.get_chat_by_share_id(db, share_id) chat = Chats.get_chat_by_share_id(share_id)
elif user.role == "admin": elif user.role == "admin":
chat = Chats.get_chat_by_id(db, share_id) chat = Chats.get_chat_by_id(share_id)
if chat: if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@ -216,23 +214,23 @@ class TagNameForm(BaseModel):
@router.post("/tags", response_model=List[ChatTitleIdResponse]) @router.post("/tags", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name( async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_current_user), db=Depends(get_db) form_data: TagNameForm, user=Depends(get_current_user)
): ):
print(form_data) print(form_data)
chat_ids = [ chat_ids = [
chat_id_tag.chat_id chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
db, form_data.name, user.id form_data.name, user.id
) )
] ]
chats = Chats.get_chat_list_by_chat_ids( chats = Chats.get_chat_list_by_chat_ids(
db, chat_ids, form_data.skip, form_data.limit chat_ids, form_data.skip, form_data.limit
) )
if len(chats) == 0: if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(db, form_data.name, user.id) Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
return chats return chats
@ -243,9 +241,9 @@ async def get_user_chat_list_by_tag_name(
@router.get("/tags/all", response_model=List[TagModel]) @router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)): async def get_all_tags(user=Depends(get_current_user)):
try: try:
tags = Tags.get_tags_by_user_id(db, user.id) tags = Tags.get_tags_by_user_id(user.id)
return tags return tags
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -260,8 +258,8 @@ async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)):
@router.get("/{id}", response_model=Optional[ChatResponse]) @router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)): async def get_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@ -278,13 +276,13 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id( async def update_chat_by_id(
id: str, form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db) id: str, form_data: ChatForm, user=Depends(get_current_user)
): ):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat} updated_chat = {**json.loads(chat.chat), **form_data.chat}
chat = Chats.update_chat_by_id(db, id, updated_chat) chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else: else:
raise HTTPException( raise HTTPException(
@ -300,11 +298,11 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id( async def delete_chat_by_id(
request: Request, id: str, user=Depends(get_current_user), db=Depends(get_db) request: Request, id: str, user=Depends(get_current_user)
): ):
if user.role == "admin": if user.role == "admin":
result = Chats.delete_chat_by_id(db, id) result = Chats.delete_chat_by_id(id)
return result return result
else: else:
if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]: if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
@ -313,7 +311,7 @@ async def delete_chat_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
result = Chats.delete_chat_by_id_and_user_id(db, id, user.id) result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result return result
@ -323,8 +321,8 @@ async def delete_chat_by_id(
@router.get("/{id}/clone", response_model=Optional[ChatResponse]) @router.get("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)): async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
chat_body = json.loads(chat.chat) chat_body = json.loads(chat.chat)
@ -335,7 +333,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g
"title": f"Clone of {chat.title}", "title": f"Clone of {chat.title}",
} }
chat = Chats.insert_new_chat(db, user.id, ChatForm(**{"chat": updated_chat})) chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else: else:
raise HTTPException( raise HTTPException(
@ -350,11 +348,11 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g
@router.get("/{id}/archive", response_model=Optional[ChatResponse]) @router.get("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id( async def archive_chat_by_id(
id: str, user=Depends(get_current_user), db=Depends(get_db) id: str, user=Depends(get_current_user)
): ):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
chat = Chats.toggle_chat_archive_by_id(db, id) chat = Chats.toggle_chat_archive_by_id(id)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else: else:
raise HTTPException( raise HTTPException(
@ -368,16 +366,16 @@ async def archive_chat_by_id(
@router.post("/{id}/share", response_model=Optional[ChatResponse]) @router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)): async def share_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
if chat.share_id: if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(db, chat.id) shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
return ChatResponse( return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
) )
shared_chat = Chats.insert_shared_chat_by_chat_id(db, chat.id) shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
if not shared_chat: if not shared_chat:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -401,15 +399,15 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g
@router.delete("/{id}/share", response_model=Optional[bool]) @router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id( async def delete_shared_chat_by_id(
id: str, user=Depends(get_current_user), db=Depends(get_db) id: str, user=Depends(get_current_user)
): ):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
if not chat.share_id: if not chat.share_id:
return False return False
result = Chats.delete_shared_chat_by_chat_id(db, id) result = Chats.delete_shared_chat_by_chat_id(id)
update_result = Chats.update_chat_share_id_by_id(db, id, None) update_result = Chats.update_chat_share_id_by_id(id, None)
return result and update_result != None return result and update_result != None
else: else:
@ -426,9 +424,9 @@ async def delete_shared_chat_by_id(
@router.get("/{id}/tags", response_model=List[TagModel]) @router.get("/{id}/tags", response_model=List[TagModel])
async def get_chat_tags_by_id( async def get_chat_tags_by_id(
id: str, user=Depends(get_current_user), db=Depends(get_db) id: str, user=Depends(get_current_user)
): ):
tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id) tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if tags != None: if tags != None:
return tags return tags
@ -447,13 +445,12 @@ async def get_chat_tags_by_id(
async def add_chat_tag_by_id( async def add_chat_tag_by_id(
id: str, id: str,
form_data: ChatIdTagForm, form_data: ChatIdTagForm,
user=Depends(get_current_user), user=Depends(get_current_user)
db=Depends(get_db),
): ):
tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id) tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if form_data.tag_name not in tags: if form_data.tag_name not in tags:
tag = Tags.add_tag_to_chat(db, user.id, form_data) tag = Tags.add_tag_to_chat(user.id, form_data)
if tag: if tag:
return tag return tag
@ -478,10 +475,9 @@ async def delete_chat_tag_by_id(
id: str, id: str,
form_data: ChatIdTagForm, form_data: ChatIdTagForm,
user=Depends(get_current_user), user=Depends(get_current_user),
db=Depends(get_db),
): ):
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id( result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
db, form_data.tag_name, id, user.id form_data.tag_name, id, user.id
) )
if result: if result:
@ -499,9 +495,9 @@ async def delete_chat_tag_by_id(
@router.delete("/{id}/tags/all", response_model=Optional[bool]) @router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id( async def delete_all_chat_tags_by_id(
id: str, user=Depends(get_current_user), db=Depends(get_db) id: str, user=Depends(get_current_user)
): ):
result = Tags.delete_tags_by_chat_id_and_user_id(db, id, user.id) result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
if result: if result:
return result return result

View File

@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.internal.db import get_db
from apps.webui.models.documents import ( from apps.webui.models.documents import (
Documents, Documents,
DocumentForm, DocumentForm,
@ -26,7 +25,7 @@ router = APIRouter()
@router.get("/", response_model=List[DocumentResponse]) @router.get("/", response_model=List[DocumentResponse])
async def get_documents(user=Depends(get_current_user), db=Depends(get_db)): async def get_documents(user=Depends(get_current_user)):
docs = [ docs = [
DocumentResponse( DocumentResponse(
**{ **{
@ -34,7 +33,7 @@ async def get_documents(user=Depends(get_current_user), db=Depends(get_db)):
"content": json.loads(doc.content if doc.content else "{}"), "content": json.loads(doc.content if doc.content else "{}"),
} }
) )
for doc in Documents.get_docs(db) for doc in Documents.get_docs()
] ]
return docs return docs
@ -46,11 +45,11 @@ async def get_documents(user=Depends(get_current_user), db=Depends(get_db)):
@router.post("/create", response_model=Optional[DocumentResponse]) @router.post("/create", response_model=Optional[DocumentResponse])
async def create_new_doc( async def create_new_doc(
form_data: DocumentForm, user=Depends(get_admin_user), db=Depends(get_db) form_data: DocumentForm, user=Depends(get_admin_user)
): ):
doc = Documents.get_doc_by_name(db, form_data.name) doc = Documents.get_doc_by_name(form_data.name)
if doc == None: if doc == None:
doc = Documents.insert_new_doc(db, user.id, form_data) doc = Documents.insert_new_doc(user.id, form_data)
if doc: if doc:
return DocumentResponse( return DocumentResponse(
@ -78,9 +77,9 @@ async def create_new_doc(
@router.get("/doc", response_model=Optional[DocumentResponse]) @router.get("/doc", response_model=Optional[DocumentResponse])
async def get_doc_by_name( async def get_doc_by_name(
name: str, user=Depends(get_current_user), db=Depends(get_db) name: str, user=Depends(get_current_user)
): ):
doc = Documents.get_doc_by_name(db, name) doc = Documents.get_doc_by_name(name)
if doc: if doc:
return DocumentResponse( return DocumentResponse(
@ -112,10 +111,10 @@ class TagDocumentForm(BaseModel):
@router.post("/doc/tags", response_model=Optional[DocumentResponse]) @router.post("/doc/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name( async def tag_doc_by_name(
form_data: TagDocumentForm, user=Depends(get_current_user), db=Depends(get_db) form_data: TagDocumentForm, user=Depends(get_current_user)
): ):
doc = Documents.update_doc_content_by_name( doc = Documents.update_doc_content_by_name(
db, form_data.name, {"tags": form_data.tags} form_data.name, {"tags": form_data.tags}
) )
if doc: if doc:
@ -142,9 +141,8 @@ async def update_doc_by_name(
name: str, name: str,
form_data: DocumentUpdateForm, form_data: DocumentUpdateForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
doc = Documents.update_doc_by_name(db, name, form_data) doc = Documents.update_doc_by_name(name, form_data)
if doc: if doc:
return DocumentResponse( return DocumentResponse(
**{ **{
@ -166,7 +164,7 @@ async def update_doc_by_name(
@router.delete("/doc/delete", response_model=bool) @router.delete("/doc/delete", response_model=bool)
async def delete_doc_by_name( async def delete_doc_by_name(
name: str, user=Depends(get_admin_user), db=Depends(get_db) name: str, user=Depends(get_admin_user)
): ):
result = Documents.delete_doc_by_name(db, name) result = Documents.delete_doc_by_name(name)
return result return result

View File

@ -20,7 +20,6 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.internal.db import get_db
from apps.webui.models.files import ( from apps.webui.models.files import (
Files, Files,
FileForm, FileForm,
@ -53,8 +52,7 @@ router = APIRouter()
@router.post("/") @router.post("/")
def upload_file( def upload_file(
file: UploadFile = File(...), file: UploadFile = File(...),
user=Depends(get_verified_user), user=Depends(get_verified_user)
db=Depends(get_db)
): ):
log.info(f"file.content_type: {file.content_type}") log.info(f"file.content_type: {file.content_type}")
try: try:
@ -72,7 +70,6 @@ def upload_file(
f.close() f.close()
file = Files.insert_new_file( file = Files.insert_new_file(
db,
user.id, user.id,
FileForm( FileForm(
**{ **{
@ -109,8 +106,8 @@ def upload_file(
@router.get("/", response_model=List[FileModel]) @router.get("/", response_model=List[FileModel])
async def list_files(user=Depends(get_verified_user), db=Depends(get_db)): async def list_files(user=Depends(get_verified_user)):
files = Files.get_files(db) files = Files.get_files()
return files return files
@ -120,8 +117,8 @@ async def list_files(user=Depends(get_verified_user), db=Depends(get_db)):
@router.delete("/all") @router.delete("/all")
async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)): async def delete_all_files(user=Depends(get_admin_user)):
result = Files.delete_all_files(db) result = Files.delete_all_files()
if result: if result:
folder = f"{UPLOAD_DIR}" folder = f"{UPLOAD_DIR}"
@ -157,8 +154,8 @@ async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)):
@router.get("/{id}", response_model=Optional[FileModel]) @router.get("/{id}", response_model=Optional[FileModel])
async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): async def get_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(db, id) file = Files.get_file_by_id(id)
if file: if file:
return file return file
@ -175,8 +172,8 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(ge
@router.get("/{id}/content", response_model=Optional[FileModel]) @router.get("/{id}/content", response_model=Optional[FileModel])
async def get_file_content_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(db, id) file = Files.get_file_by_id(id)
if file: if file:
file_path = Path(file.meta["path"]) file_path = Path(file.meta["path"])
@ -226,11 +223,11 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.delete("/{id}") @router.delete("/{id}")
async def delete_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(db, id) file = Files.get_file_by_id(id)
if file: if file:
result = Files.delete_file_by_id(db, id) result = Files.delete_file_by_id(id)
if result: if result:
return {"message": "File deleted successfully"} return {"message": "File deleted successfully"}
else: else:

View File

@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.internal.db import get_db
from apps.webui.models.functions import ( from apps.webui.models.functions import (
Functions, Functions,
FunctionForm, FunctionForm,
@ -32,8 +31,8 @@ router = APIRouter()
@router.get("/", response_model=List[FunctionResponse]) @router.get("/", response_model=List[FunctionResponse])
async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)): async def get_functions(user=Depends(get_verified_user)):
return Functions.get_functions(db) return Functions.get_functions()
############################ ############################
@ -42,8 +41,8 @@ async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)):
@router.get("/export", response_model=List[FunctionModel]) @router.get("/export", response_model=List[FunctionModel])
async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)): async def get_functions(user=Depends(get_admin_user)):
return Functions.get_functions(db) return Functions.get_functions()
############################ ############################
@ -53,7 +52,7 @@ async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)):
@router.post("/create", response_model=Optional[FunctionResponse]) @router.post("/create", response_model=Optional[FunctionResponse])
async def create_new_function( async def create_new_function(
request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db) request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
): ):
if not form_data.id.isidentifier(): if not form_data.id.isidentifier():
raise HTTPException( raise HTTPException(
@ -63,7 +62,7 @@ async def create_new_function(
form_data.id = form_data.id.lower() form_data.id = form_data.id.lower()
function = Functions.get_function_by_id(db, form_data.id) function = Functions.get_function_by_id(form_data.id)
if function == None: if function == None:
function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py") function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
try: try:
@ -78,7 +77,7 @@ async def create_new_function(
FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module FUNCTIONS[form_data.id] = function_module
function = Functions.insert_new_function(db, user.id, function_type, form_data) function = Functions.insert_new_function(user.id, function_type, form_data)
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True) function_cache_dir.mkdir(parents=True, exist_ok=True)
@ -109,8 +108,8 @@ async def create_new_function(
@router.get("/id/{id}", response_model=Optional[FunctionModel]) @router.get("/id/{id}", response_model=Optional[FunctionModel])
async def get_function_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)): async def get_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(db, id) function = Functions.get_function_by_id(id)
if function: if function:
return function return function
@ -155,7 +154,7 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[FunctionModel]) @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
async def update_function_by_id( async def update_function_by_id(
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db) request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
): ):
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
@ -172,7 +171,7 @@ async def update_function_by_id(
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
print(updated) print(updated)
function = Functions.update_function_by_id(db, id, updated) function = Functions.update_function_by_id(id, updated)
if function: if function:
return function return function
@ -196,9 +195,9 @@ async def update_function_by_id(
@router.delete("/id/{id}/delete", response_model=bool) @router.delete("/id/{id}/delete", response_model=bool)
async def delete_function_by_id( async def delete_function_by_id(
request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db) request: Request, id: str, user=Depends(get_admin_user)
): ):
result = Functions.delete_function_by_id(db, id) result = Functions.delete_function_by_id(id)
if result: if result:
FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS = request.app.state.FUNCTIONS

View File

@ -7,7 +7,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import logging import logging
from apps.webui.internal.db import get_db
from apps.webui.models.memories import Memories, MemoryModel from apps.webui.models.memories import Memories, MemoryModel
from utils.utils import get_verified_user from utils.utils import get_verified_user
@ -32,8 +31,8 @@ async def get_embeddings(request: Request):
@router.get("/", response_model=List[MemoryModel]) @router.get("/", response_model=List[MemoryModel])
async def get_memories(user=Depends(get_verified_user), db=Depends(get_db)): async def get_memories(user=Depends(get_verified_user)):
return Memories.get_memories_by_user_id(db, user.id) return Memories.get_memories_by_user_id(user.id)
############################ ############################
@ -54,9 +53,8 @@ async def add_memory(
request: Request, request: Request,
form_data: AddMemoryForm, form_data: AddMemoryForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db=Depends(get_db),
): ):
memory = Memories.insert_new_memory(db, user.id, form_data.content) memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
@ -76,9 +74,8 @@ async def update_memory_by_id(
request: Request, request: Request,
form_data: MemoryUpdateModel, form_data: MemoryUpdateModel,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db=Depends(get_db),
): ):
memory = Memories.update_memory_by_id(db, memory_id, form_data.content) memory = Memories.update_memory_by_id(memory_id, form_data.content)
if memory is None: if memory is None:
raise HTTPException(status_code=404, detail="Memory not found") raise HTTPException(status_code=404, detail="Memory not found")
@ -129,12 +126,12 @@ async def query_memory(
############################ ############################
@router.get("/reset", response_model=bool) @router.get("/reset", response_model=bool)
async def reset_memory_from_vector_db( async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user), db=Depends(get_db) request: Request, user=Depends(get_verified_user)
): ):
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}") CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(db, user.id) memories = Memories.get_memories_by_user_id(user.id)
for memory in memories: for memory in memories:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection.upsert( collection.upsert(
@ -151,8 +148,8 @@ async def reset_memory_from_vector_db(
@router.delete("/user", response_model=bool) @router.delete("/user", response_model=bool)
async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(get_db)): async def delete_memory_by_user_id(user=Depends(get_verified_user)):
result = Memories.delete_memories_by_user_id(db, user.id) result = Memories.delete_memories_by_user_id(user.id)
if result: if result:
try: try:
@ -171,9 +168,9 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(g
@router.delete("/{memory_id}", response_model=bool) @router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id( async def delete_memory_by_id(
memory_id: str, user=Depends(get_verified_user), db=Depends(get_db) memory_id: str, user=Depends(get_verified_user)
): ):
result = Memories.delete_memory_by_id_and_user_id(db, memory_id, user.id) result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
if result: if result:
collection = CHROMA_CLIENT.get_or_create_collection( collection = CHROMA_CLIENT.get_or_create_collection(

View File

@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.internal.db import get_db
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user from utils.utils import get_verified_user, get_admin_user
@ -20,8 +19,8 @@ router = APIRouter()
@router.get("/", response_model=List[ModelResponse]) @router.get("/", response_model=List[ModelResponse])
async def get_models(user=Depends(get_verified_user), db=Depends(get_db)): async def get_models(user=Depends(get_verified_user)):
return Models.get_all_models(db) return Models.get_all_models()
############################ ############################
@ -34,7 +33,6 @@ async def add_new_model(
request: Request, request: Request,
form_data: ModelForm, form_data: ModelForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
if form_data.id in request.app.state.MODELS: if form_data.id in request.app.state.MODELS:
raise HTTPException( raise HTTPException(
@ -42,7 +40,7 @@ async def add_new_model(
detail=ERROR_MESSAGES.MODEL_ID_TAKEN, detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
) )
else: else:
model = Models.insert_new_model(db, form_data, user.id) model = Models.insert_new_model(form_data, user.id)
if model: if model:
return model return model
@ -59,8 +57,8 @@ async def add_new_model(
@router.get("/{id}", response_model=Optional[ModelModel]) @router.get("/{id}", response_model=Optional[ModelModel])
async def get_model_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(db, id) model = Models.get_model_by_id(id)
if model: if model:
return model return model
@ -82,15 +80,14 @@ async def update_model_by_id(
id: str, id: str,
form_data: ModelForm, form_data: ModelForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
model = Models.get_model_by_id(db, id) model = Models.get_model_by_id(id)
if model: if model:
model = Models.update_model_by_id(db, id, form_data) model = Models.update_model_by_id(id, form_data)
return model return model
else: else:
if form_data.id in request.app.state.MODELS: if form_data.id in request.app.state.MODELS:
model = Models.insert_new_model(db, form_data, user.id) model = Models.insert_new_model(form_data, user.id)
if model: if model:
return model return model
else: else:
@ -111,6 +108,6 @@ async def update_model_by_id(
@router.delete("/delete", response_model=bool) @router.delete("/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)): async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
result = Models.delete_model_by_id(db, id) result = Models.delete_model_by_id(id)
return result return result

View File

@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.internal.db import get_db
from apps.webui.models.prompts import Prompts, PromptForm, PromptModel from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user, get_admin_user from utils.utils import get_current_user, get_admin_user
@ -20,8 +19,8 @@ router = APIRouter()
@router.get("/", response_model=List[PromptModel]) @router.get("/", response_model=List[PromptModel])
async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)): async def get_prompts(user=Depends(get_current_user)):
return Prompts.get_prompts(db) return Prompts.get_prompts()
############################ ############################
@ -31,11 +30,11 @@ async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)):
@router.post("/create", response_model=Optional[PromptModel]) @router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt( async def create_new_prompt(
form_data: PromptForm, user=Depends(get_admin_user), db=Depends(get_db) form_data: PromptForm, user=Depends(get_admin_user)
): ):
prompt = Prompts.get_prompt_by_command(db, form_data.command) prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt == None: if prompt == None:
prompt = Prompts.insert_new_prompt(db, user.id, form_data) prompt = Prompts.insert_new_prompt(user.id, form_data)
if prompt: if prompt:
return prompt return prompt
@ -56,9 +55,9 @@ async def create_new_prompt(
@router.get("/command/{command}", response_model=Optional[PromptModel]) @router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command( async def get_prompt_by_command(
command: str, user=Depends(get_current_user), db=Depends(get_db) command: str, user=Depends(get_current_user)
): ):
prompt = Prompts.get_prompt_by_command(db, f"/{command}") prompt = Prompts.get_prompt_by_command(f"/{command}")
if prompt: if prompt:
return prompt return prompt
@ -79,9 +78,8 @@ async def update_prompt_by_command(
command: str, command: str,
form_data: PromptForm, form_data: PromptForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
prompt = Prompts.update_prompt_by_command(db, f"/{command}", form_data) prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt: if prompt:
return prompt return prompt
else: else:
@ -98,7 +96,7 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool) @router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command( async def delete_prompt_by_command(
command: str, user=Depends(get_admin_user), db=Depends(get_db) command: str, user=Depends(get_admin_user)
): ):
result = Prompts.delete_prompt_by_command(db, f"/{command}") result = Prompts.delete_prompt_by_command(f"/{command}")
return result return result

View File

@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.internal.db import get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id from apps.webui.utils import load_toolkit_module_by_id
@ -34,7 +33,7 @@ router = APIRouter()
@router.get("/", response_model=List[ToolResponse]) @router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)): async def get_toolkits(user=Depends(get_verified_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()] toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits return toolkits
@ -45,8 +44,8 @@ async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)):
@router.get("/export", response_model=List[ToolModel]) @router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user), db=Depends(get_db)): async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools(db)] toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits return toolkits
@ -60,7 +59,6 @@ async def create_new_toolkit(
request: Request, request: Request,
form_data: ToolForm, form_data: ToolForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
if not form_data.id.isidentifier(): if not form_data.id.isidentifier():
raise HTTPException( raise HTTPException(
@ -70,7 +68,7 @@ async def create_new_toolkit(
form_data.id = form_data.id.lower() form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(db, form_data.id) toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None: if toolkit == None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py") toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try: try:
@ -84,7 +82,7 @@ async def create_new_toolkit(
TOOLS[form_data.id] = toolkit_module TOOLS[form_data.id] = toolkit_module
specs = get_tools_specs(TOOLS[form_data.id]) specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(db, user.id, form_data, specs) toolkit = Tools.insert_new_tool(user.id, form_data, specs)
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True) tool_cache_dir.mkdir(parents=True, exist_ok=True)
@ -115,8 +113,8 @@ async def create_new_toolkit(
@router.get("/id/{id}", response_model=Optional[ToolModel]) @router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)): async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(db, id) toolkit = Tools.get_tool_by_id(id)
if toolkit: if toolkit:
return toolkit return toolkit
@ -138,7 +136,6 @@ async def update_toolkit_by_id(
id: str, id: str,
form_data: ToolForm, form_data: ToolForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
@ -160,7 +157,7 @@ async def update_toolkit_by_id(
} }
print(updated) print(updated)
toolkit = Tools.update_tool_by_id(db, id, updated) toolkit = Tools.update_tool_by_id(id, updated)
if toolkit: if toolkit:
return toolkit return toolkit
@ -184,9 +181,9 @@ async def update_toolkit_by_id(
@router.delete("/id/{id}/delete", response_model=bool) @router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id( async def delete_toolkit_by_id(
request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db) request: Request, id: str, user=Depends(get_admin_user)
): ):
result = Tools.delete_tool_by_id(db, id) result = Tools.delete_tool_by_id(id)
if result: if result:
TOOLS = request.app.state.TOOLS TOOLS = request.app.state.TOOLS

View File

@ -9,7 +9,6 @@ import time
import uuid import uuid
import logging import logging
from apps.webui.internal.db import get_db
from apps.webui.models.users import ( from apps.webui.models.users import (
UserModel, UserModel,
UserUpdateForm, UserUpdateForm,
@ -42,9 +41,9 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel]) @router.get("/", response_model=List[UserModel])
async def get_users( async def get_users(
skip: int = 0, limit: int = 50, user=Depends(get_admin_user), db=Depends(get_db) skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
): ):
return Users.get_users(db, skip, limit) return Users.get_users(skip, limit)
############################ ############################
@ -72,11 +71,11 @@ async def update_user_permissions(
@router.post("/update/role", response_model=Optional[UserModel]) @router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role( async def update_user_role(
form_data: UserRoleUpdateForm, user=Depends(get_admin_user), db=Depends(get_db) form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
): ):
if user.id != form_data.id and form_data.id != Users.get_first_user(db).id: if user.id != form_data.id and form_data.id != Users.get_first_user().id:
return Users.update_user_role_by_id(db, form_data.id, form_data.role) return Users.update_user_role_by_id(form_data.id, form_data.role)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@ -91,9 +90,9 @@ async def update_user_role(
@router.get("/user/settings", response_model=Optional[UserSettings]) @router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user( async def get_user_settings_by_session_user(
user=Depends(get_verified_user), db=Depends(get_db) user=Depends(get_verified_user)
): ):
user = Users.get_user_by_id(db, user.id) user = Users.get_user_by_id(user.id)
if user: if user:
return user.settings return user.settings
else: else:
@ -110,9 +109,9 @@ async def get_user_settings_by_session_user(
@router.post("/user/settings/update", response_model=UserSettings) @router.post("/user/settings/update", response_model=UserSettings)
async def update_user_settings_by_session_user( async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user), db=Depends(get_db) form_data: UserSettings, user=Depends(get_verified_user)
): ):
user = Users.update_user_by_id(db, user.id, {"settings": form_data.model_dump()}) user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
if user: if user:
return user.settings return user.settings
else: else:
@ -129,9 +128,9 @@ async def update_user_settings_by_session_user(
@router.get("/user/info", response_model=Optional[dict]) @router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user( async def get_user_info_by_session_user(
user=Depends(get_verified_user), db=Depends(get_db) user=Depends(get_verified_user)
): ):
user = Users.get_user_by_id(db, user.id) user = Users.get_user_by_id(user.id)
if user: if user:
return user.info return user.info
else: else:
@ -148,15 +147,15 @@ async def get_user_info_by_session_user(
@router.post("/user/info/update", response_model=Optional[dict]) @router.post("/user/info/update", response_model=Optional[dict])
async def update_user_info_by_session_user( async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user), db=Depends(get_db) form_data: dict, user=Depends(get_verified_user)
): ):
user = Users.get_user_by_id(db, user.id) user = Users.get_user_by_id(user.id)
if user: if user:
if user.info is None: if user.info is None:
user.info = {} user.info = {}
user = Users.update_user_by_id( user = Users.update_user_by_id(
db, user.id, {"info": {**user.info, **form_data}} user.id, {"info": {**user.info, **form_data}}
) )
if user: if user:
return user.info return user.info
@ -184,14 +183,14 @@ class UserResponse(BaseModel):
@router.get("/{user_id}", response_model=UserResponse) @router.get("/{user_id}", response_model=UserResponse)
async def get_user_by_id( async def get_user_by_id(
user_id: str, user=Depends(get_verified_user), db=Depends(get_db) user_id: str, user=Depends(get_verified_user)
): ):
# Check if user_id is a shared chat # Check if user_id is a shared chat
# If it is, get the user_id from the chat # If it is, get the user_id from the chat
if user_id.startswith("shared-"): if user_id.startswith("shared-"):
chat_id = user_id.replace("shared-", "") chat_id = user_id.replace("shared-", "")
chat = Chats.get_chat_by_id(db, chat_id) chat = Chats.get_chat_by_id(chat_id)
if chat: if chat:
user_id = chat.user_id user_id = chat.user_id
else: else:
@ -200,7 +199,7 @@ async def get_user_by_id(
detail=ERROR_MESSAGES.USER_NOT_FOUND, detail=ERROR_MESSAGES.USER_NOT_FOUND,
) )
user = Users.get_user_by_id(db, user_id) user = Users.get_user_by_id(user_id)
if user: if user:
return UserResponse(name=user.name, profile_image_url=user.profile_image_url) return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
@ -221,13 +220,12 @@ async def update_user_by_id(
user_id: str, user_id: str,
form_data: UserUpdateForm, form_data: UserUpdateForm,
session_user=Depends(get_admin_user), session_user=Depends(get_admin_user),
db=Depends(get_db),
): ):
user = Users.get_user_by_id(db, user_id) user = Users.get_user_by_id(user_id)
if user: if user:
if form_data.email.lower() != user.email: if form_data.email.lower() != user.email:
email_user = Users.get_user_by_email(db, form_data.email.lower()) email_user = Users.get_user_by_email(form_data.email.lower())
if email_user: if email_user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -237,11 +235,10 @@ async def update_user_by_id(
if form_data.password: if form_data.password:
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
log.debug(f"hashed: {hashed}") log.debug(f"hashed: {hashed}")
Auths.update_user_password_by_id(db, user_id, hashed) Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(db, user_id, form_data.email.lower()) Auths.update_email_by_id(user_id, form_data.email.lower())
updated_user = Users.update_user_by_id( updated_user = Users.update_user_by_id(
db,
user_id, user_id,
{ {
"name": form_data.name, "name": form_data.name,
@ -271,10 +268,10 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool) @router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id( async def delete_user_by_id(
user_id: str, user=Depends(get_admin_user), db=Depends(get_db) user_id: str, user=Depends(get_admin_user)
): ):
if user.id != user_id: if user.id != user_id:
result = Auths.delete_auth_by_id(db, user_id) result = Auths.delete_auth_by_id(user_id)
if result: if result:
return True return True

View File

@ -57,7 +57,7 @@ from apps.webui.main import (
get_pipe_models, get_pipe_models,
generate_function_chat_completion, generate_function_chat_completion,
) )
from apps.webui.internal.db import get_db, SessionLocal from apps.webui.internal.db import get_session, SessionLocal
from pydantic import BaseModel from pydantic import BaseModel
@ -410,7 +410,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
user = get_current_user( user = get_current_user(
request, request,
get_http_authorization_cred(request.headers.get("Authorization")), get_http_authorization_cred(request.headers.get("Authorization")),
SessionLocal(),
) )
# Flag to skip RAG completions if file_handler is present in tools/functions # Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False skip_files = False
@ -800,9 +799,7 @@ app.add_middleware(
@app.middleware("http") @app.middleware("http")
async def check_url(request: Request, call_next): async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0: if len(app.state.MODELS) == 0:
db = SessionLocal() await get_all_models()
await get_all_models(db)
db.commit()
else: else:
pass pass
@ -836,12 +833,12 @@ app.mount("/api/v1", webui_app)
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models(db: Session): async def get_all_models():
pipe_models = [] pipe_models = []
openai_models = [] openai_models = []
ollama_models = [] ollama_models = []
pipe_models = await get_pipe_models(db) pipe_models = await get_pipe_models()
if app.state.config.ENABLE_OPENAI_API: if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models() openai_models = await get_openai_models()
@ -863,7 +860,7 @@ async def get_all_models(db: Session):
models = pipe_models + openai_models + ollama_models models = pipe_models + openai_models + ollama_models
custom_models = Models.get_all_models(db) custom_models = Models.get_all_models()
for custom_model in custom_models: for custom_model in custom_models:
if custom_model.base_model_id == None: if custom_model.base_model_id == None:
for model in models: for model in models:
@ -903,8 +900,8 @@ async def get_all_models(db: Session):
@app.get("/api/models") @app.get("/api/models")
async def get_models(user=Depends(get_verified_user), db=Depends(get_db)): async def get_models(user=Depends(get_verified_user)):
models = await get_all_models(db) models = await get_all_models()
# Filter out filter pipelines # Filter out filter pipelines
models = [ models = [
@ -1608,9 +1605,8 @@ async def get_pipeline_valves(
urlIdx: Optional[int], urlIdx: Optional[int],
pipeline_id: str, pipeline_id: str,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
models = await get_all_models(db) models = await get_all_models()
r = None r = None
try: try:
@ -1649,9 +1645,8 @@ async def get_pipeline_valves_spec(
urlIdx: Optional[int], urlIdx: Optional[int],
pipeline_id: str, pipeline_id: str,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
models = await get_all_models(db) models = await get_all_models()
r = None r = None
try: try:
@ -1690,9 +1685,8 @@ async def update_pipeline_valves(
pipeline_id: str, pipeline_id: str,
form_data: dict, form_data: dict,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
models = await get_all_models(db) models = await get_all_models()
r = None r = None
try: try:
@ -2040,7 +2034,8 @@ async def healthcheck():
@app.get("/health/db") @app.get("/health/db")
async def healthcheck_with_db(db: Session = Depends(get_db)): async def healthcheck_with_db():
with get_session() as db:
result = db.execute(text("SELECT 1;")).all() result = db.execute(text("SELECT 1;")).all()
return {"status": True} return {"status": True}

View File

@ -1,188 +0,0 @@
"""init
Revision ID: 22b5ab2667b8
Revises:
Create Date: 2024-06-20 13:22:40.397002
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector
import apps.webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "22b5ab2667b8"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
con = op.get_bind()
inspector = Inspector.from_engine(con)
tables = set(inspector.get_table_names())
# ### commands auto generated by Alembic - please adjust! ###
if not "auth" in tables:
op.create_table(
"auth",
sa.Column("id", sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=True),
sa.Column("password", sa.String(), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "chat" in tables:
op.create_table(
"chat",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("chat", sa.String(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("share_id", sa.String(), nullable=True),
sa.Column("archived", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("share_id"),
)
if not "chatidtag" in tables:
op.create_table(
"chatidtag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("tag_name", sa.String(), nullable=True),
sa.Column("chat_id", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "document" in tables:
op.create_table(
"document",
sa.Column("collection_name", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("filename", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("collection_name"),
sa.UniqueConstraint("name"),
)
if not "memory" in tables:
op.create_table(
"memory",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "model" in tables:
op.create_table(
"model",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("base_model_id", sa.String(), nullable=True),
sa.Column("name", sa.String(), nullable=True),
sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "prompt" in tables:
op.create_table(
"prompt",
sa.Column("command", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("command"),
)
if not "tag" in tables:
op.create_table(
"tag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("data", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "tool" in tables:
op.create_table(
"tool",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "user" in tables:
op.create_table(
"user",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True),
sa.Column("role", sa.String(), nullable=True),
sa.Column("profile_image_url", sa.String(), nullable=True),
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("api_key"),
)
if not "file" in tables:
op.create_table('file',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('filename', sa.String(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
if not "function" in tables:
op.create_table('function',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('type', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# do nothing as we assume we had previous migrations from peewee-migrate
pass
# ### end Alembic commands ###

View File

@ -0,0 +1,161 @@
"""init
Revision ID: ba76b0bae648
Revises:
Create Date: 2024-06-24 09:09:11.636336
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db
# revision identifiers, used by Alembic.
revision: str = 'ba76b0bae648'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('auth',
sa.Column('id', sa.String(), nullable=False),
sa.Column('email', sa.String(), nullable=True),
sa.Column('password', sa.String(), nullable=True),
sa.Column('active', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('chat',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('title', sa.String(), nullable=True),
sa.Column('chat', sa.String(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('share_id', sa.String(), nullable=True),
sa.Column('archived', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('share_id')
)
op.create_table('chatidtag',
sa.Column('id', sa.String(), nullable=False),
sa.Column('tag_name', sa.String(), nullable=True),
sa.Column('chat_id', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('document',
sa.Column('collection_name', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('title', sa.String(), nullable=True),
sa.Column('filename', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('collection_name'),
sa.UniqueConstraint('name')
)
op.create_table('file',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('filename', sa.String(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('function',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('type', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('memory',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('model',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('base_model_id', sa.String(), nullable=True),
sa.Column('name', sa.String(), nullable=True),
sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('prompt',
sa.Column('command', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('title', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('command')
)
op.create_table('tag',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('data', sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('tool',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('user',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('email', sa.String(), nullable=True),
sa.Column('role', sa.String(), nullable=True),
sa.Column('profile_image_url', sa.String(), nullable=True),
sa.Column('last_active_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('api_key', sa.String(), nullable=True),
sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('api_key')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('user')
op.drop_table('tool')
op.drop_table('tag')
op.drop_table('prompt')
op.drop_table('model')
op.drop_table('memory')
op.drop_table('function')
op.drop_table('file')
op.drop_table('document')
op.drop_table('chatidtag')
op.drop_table('chat')
op.drop_table('auth')
# ### end Alembic commands ###

View File

@ -31,7 +31,6 @@ class TestAuths(AbstractPostgresTest):
from utils.utils import get_password_hash from utils.utils import get_password_hash
user = self.auths.insert_new_auth( user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password=get_password_hash("old_password"), password=get_password_hash("old_password"),
name="John Doe", name="John Doe",
@ -45,7 +44,7 @@ class TestAuths(AbstractPostgresTest):
json={"name": "John Doe 2", "profile_image_url": "/user2.png"}, json={"name": "John Doe 2", "profile_image_url": "/user2.png"},
) )
assert response.status_code == 200 assert response.status_code == 200
db_user = self.users.get_user_by_id(self.db_session, user.id) db_user = self.users.get_user_by_id(user.id)
assert db_user.name == "John Doe 2" assert db_user.name == "John Doe 2"
assert db_user.profile_image_url == "/user2.png" assert db_user.profile_image_url == "/user2.png"
@ -53,7 +52,6 @@ class TestAuths(AbstractPostgresTest):
from utils.utils import get_password_hash from utils.utils import get_password_hash
user = self.auths.insert_new_auth( user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password=get_password_hash("old_password"), password=get_password_hash("old_password"),
name="John Doe", name="John Doe",
@ -69,11 +67,11 @@ class TestAuths(AbstractPostgresTest):
assert response.status_code == 200 assert response.status_code == 200
old_auth = self.auths.authenticate_user( old_auth = self.auths.authenticate_user(
self.db_session, "john.doe@openwebui.com", "old_password" "john.doe@openwebui.com", "old_password"
) )
assert old_auth is None assert old_auth is None
new_auth = self.auths.authenticate_user( new_auth = self.auths.authenticate_user(
self.db_session, "john.doe@openwebui.com", "new_password" "john.doe@openwebui.com", "new_password"
) )
assert new_auth is not None assert new_auth is not None
@ -81,7 +79,6 @@ class TestAuths(AbstractPostgresTest):
from utils.utils import get_password_hash from utils.utils import get_password_hash
user = self.auths.insert_new_auth( user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password=get_password_hash("password"), password=get_password_hash("password"),
name="John Doe", name="John Doe",
@ -144,7 +141,6 @@ class TestAuths(AbstractPostgresTest):
def test_get_admin_details(self): def test_get_admin_details(self):
self.auths.insert_new_auth( self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password="password", password="password",
name="John Doe", name="John Doe",
@ -162,7 +158,6 @@ class TestAuths(AbstractPostgresTest):
def test_create_api_key_(self): def test_create_api_key_(self):
user = self.auths.insert_new_auth( user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password="password", password="password",
name="John Doe", name="John Doe",
@ -178,31 +173,29 @@ class TestAuths(AbstractPostgresTest):
def test_delete_api_key(self): def test_delete_api_key(self):
user = self.auths.insert_new_auth( user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password="password", password="password",
name="John Doe", name="John Doe",
profile_image_url="/user.png", profile_image_url="/user.png",
role="admin", role="admin",
) )
self.users.update_user_api_key_by_id(self.db_session, user.id, "abc") self.users.update_user_api_key_by_id(user.id, "abc")
with mock_webui_user(id=user.id): with mock_webui_user(id=user.id):
response = self.fast_api_client.delete(self.create_url("/api_key")) response = self.fast_api_client.delete(self.create_url("/api_key"))
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == True assert response.json() == True
db_user = self.users.get_user_by_id(self.db_session, user.id) db_user = self.users.get_user_by_id(user.id)
assert db_user.api_key is None assert db_user.api_key is None
def test_get_api_key(self): def test_get_api_key(self):
user = self.auths.insert_new_auth( user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password="password", password="password",
name="John Doe", name="John Doe",
profile_image_url="/user.png", profile_image_url="/user.png",
role="admin", role="admin",
) )
self.users.update_user_api_key_by_id(self.db_session, user.id, "abc") self.users.update_user_api_key_by_id(user.id, "abc")
with mock_webui_user(id=user.id): with mock_webui_user(id=user.id):
response = self.fast_api_client.get(self.create_url("/api_key")) response = self.fast_api_client.get(self.create_url("/api_key"))
assert response.status_code == 200 assert response.status_code == 200

View File

@ -18,7 +18,6 @@ class TestChats(AbstractPostgresTest):
self.chats = Chats self.chats = Chats
self.chats.insert_new_chat( self.chats.insert_new_chat(
self.db_session,
"2", "2",
ChatForm( ChatForm(
**{ **{
@ -46,7 +45,7 @@ class TestChats(AbstractPostgresTest):
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url("/")) response = self.fast_api_client.delete(self.create_url("/"))
assert response.status_code == 200 assert response.status_code == 200
assert len(self.chats.get_chats(self.db_session)) == 0 assert len(self.chats.get_chats()) == 0
def test_get_user_chat_list_by_user_id(self): def test_get_user_chat_list_by_user_id(self):
with mock_webui_user(id="3"): with mock_webui_user(id="3"):
@ -84,14 +83,13 @@ class TestChats(AbstractPostgresTest):
assert data["title"] == "New Chat" assert data["title"] == "New Chat"
assert data["updated_at"] is not None assert data["updated_at"] is not None
assert data["created_at"] is not None assert data["created_at"] is not None
assert len(self.chats.get_chats(self.db_session)) == 2 assert len(self.chats.get_chats()) == 2
def test_get_user_chats(self): def test_get_user_chats(self):
self.test_get_session_user_chat_list() self.test_get_session_user_chat_list()
def test_get_user_archived_chats(self): def test_get_user_archived_chats(self):
self.chats.archive_all_chats_by_user_id(self.db_session, "2") self.chats.archive_all_chats_by_user_id("2")
self.db_session.commit()
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/all/archived")) response = self.fast_api_client.get(self.create_url("/all/archived"))
assert response.status_code == 200 assert response.status_code == 200
@ -114,12 +112,11 @@ class TestChats(AbstractPostgresTest):
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.post(self.create_url("/archive/all")) response = self.fast_api_client.post(self.create_url("/archive/all"))
assert response.status_code == 200 assert response.status_code == 200
assert len(self.chats.get_archived_chats_by_user_id(self.db_session, "2")) == 1 assert len(self.chats.get_archived_chats_by_user_id("2")) == 1
def test_get_shared_chat_by_id(self): def test_get_shared_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
self.chats.update_chat_share_id_by_id(self.db_session, chat_id, chat_id) self.chats.update_chat_share_id_by_id(chat_id, chat_id)
self.db_session.commit()
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}")) response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
assert response.status_code == 200 assert response.status_code == 200
@ -136,7 +133,7 @@ class TestChats(AbstractPostgresTest):
assert data["title"] == "New Chat" assert data["title"] == "New Chat"
def test_get_chat_by_id(self): def test_get_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}")) response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
assert response.status_code == 200 assert response.status_code == 200
@ -153,7 +150,7 @@ class TestChats(AbstractPostgresTest):
assert data["user_id"] == "2" assert data["user_id"] == "2"
def test_update_chat_by_id(self): def test_update_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.post( response = self.fast_api_client.post(
self.create_url(f"/{chat_id}"), self.create_url(f"/{chat_id}"),
@ -181,14 +178,14 @@ class TestChats(AbstractPostgresTest):
assert data["user_id"] == "2" assert data["user_id"] == "2"
def test_delete_chat_by_id(self): def test_delete_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url(f"/{chat_id}")) response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
assert response.status_code == 200 assert response.status_code == 200
assert response.json() is True assert response.json() is True
def test_clone_chat_by_id(self): def test_clone_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone")) response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
@ -209,31 +206,30 @@ class TestChats(AbstractPostgresTest):
assert data["user_id"] == "2" assert data["user_id"] == "2"
def test_archive_chat_by_id(self): def test_archive_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive")) response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
assert response.status_code == 200 assert response.status_code == 200
chat = self.chats.get_chat_by_id(self.db_session, chat_id) chat = self.chats.get_chat_by_id(chat_id)
assert chat.archived is True assert chat.archived is True
def test_share_chat_by_id(self): def test_share_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share")) response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
assert response.status_code == 200 assert response.status_code == 200
chat = self.chats.get_chat_by_id(self.db_session, chat_id) chat = self.chats.get_chat_by_id(chat_id)
assert chat.share_id is not None assert chat.share_id is not None
def test_delete_shared_chat_by_id(self): def test_delete_shared_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
share_id = str(uuid.uuid4()) share_id = str(uuid.uuid4())
self.chats.update_chat_share_id_by_id(self.db_session, chat_id, share_id) self.chats.update_chat_share_id_by_id(chat_id, share_id)
self.db_session.commit()
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share")) response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
assert response.status_code assert response.status_code
chat = self.chats.get_chat_by_id(self.db_session, chat_id) chat = self.chats.get_chat_by_id(chat_id)
assert chat.share_id is None assert chat.share_id is None

View File

@ -14,7 +14,7 @@ class TestDocuments(AbstractPostgresTest):
def test_documents(self): def test_documents(self):
# Empty database # Empty database
assert len(self.documents.get_docs(self.db_session)) == 0 assert len(self.documents.get_docs()) == 0
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/")) response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200 assert response.status_code == 200
@ -34,7 +34,7 @@ class TestDocuments(AbstractPostgresTest):
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "doc_name" assert response.json()["name"] == "doc_name"
assert len(self.documents.get_docs(self.db_session)) == 1 assert len(self.documents.get_docs()) == 1
# Get the document # Get the document
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
@ -61,7 +61,7 @@ class TestDocuments(AbstractPostgresTest):
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "doc_name 2" assert response.json()["name"] == "doc_name 2"
assert len(self.documents.get_docs(self.db_session)) == 2 assert len(self.documents.get_docs()) == 2
# Get all documents # Get all documents
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
@ -95,7 +95,7 @@ class TestDocuments(AbstractPostgresTest):
assert data["content"] == { assert data["content"] == {
"tags": [{"name": "testing-tag"}, {"name": "another-tag"}] "tags": [{"name": "testing-tag"}, {"name": "another-tag"}]
} }
assert len(self.documents.get_docs(self.db_session)) == 2 assert len(self.documents.get_docs()) == 2
# Delete the first document # Delete the first document
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
@ -103,4 +103,4 @@ class TestDocuments(AbstractPostgresTest):
self.create_url("/doc/delete?name=doc_name rework") self.create_url("/doc/delete?name=doc_name rework")
) )
assert response.status_code == 200 assert response.status_code == 200
assert len(self.documents.get_docs(self.db_session)) == 1 assert len(self.documents.get_docs()) == 1

View File

@ -68,6 +68,16 @@ class TestPrompts(AbstractPostgresTest):
assert data["content"] == "description Updated" assert data["content"] == "description Updated"
assert data["user_id"] == "3" assert data["user_id"] == "3"
# Get prompt by command
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/command/my-command2"))
assert response.status_code == 200
data = response.json()
assert data["command"] == "/my-command2"
assert data["title"] == "Hello World Updated"
assert data["content"] == "description Updated"
assert data["user_id"] == "3"
# Delete prompt # Delete prompt
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.delete( response = self.fast_api_client.delete(

View File

@ -33,7 +33,6 @@ class TestUsers(AbstractPostgresTest):
def setup_method(self): def setup_method(self):
super().setup_method() super().setup_method()
self.users.insert_new_user( self.users.insert_new_user(
self.db_session,
id="1", id="1",
name="user 1", name="user 1",
email="user1@openwebui.com", email="user1@openwebui.com",
@ -41,7 +40,6 @@ class TestUsers(AbstractPostgresTest):
role="user", role="user",
) )
self.users.insert_new_user( self.users.insert_new_user(
self.db_session,
id="2", id="2",
name="user 2", name="user 2",
email="user2@openwebui.com", email="user2@openwebui.com",

View File

@ -2,7 +2,6 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends, Request from fastapi import HTTPException, status, Depends, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
from pydantic import BaseModel from pydantic import BaseModel
@ -79,7 +78,6 @@ def get_http_authorization_cred(auth_header: str):
def get_current_user( def get_current_user(
request: Request, request: Request,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
db=Depends(get_db),
): ):
token = None token = None
@ -94,19 +92,19 @@ def get_current_user(
# auth by api key # auth by api key
if token.startswith("sk-"): if token.startswith("sk-"):
return get_current_user_by_api_key(db, token) return get_current_user_by_api_key(token)
# auth by jwt token # auth by jwt token
data = decode_token(token) data = decode_token(token)
if data != None and "id" in data: if data != None and "id" in data:
user = Users.get_user_by_id(db, data["id"]) user = Users.get_user_by_id(data["id"])
if user is None: if user is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
else: else:
Users.update_user_last_active_by_id(db, user.id) Users.update_user_last_active_by_id(user.id)
return user return user
else: else:
raise HTTPException( raise HTTPException(