diff --git a/Dockerfile b/Dockerfile index 20e5cce24..737d478ec 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ # use build args in the docker build commmand with --build-arg="BUILDARG=true" ARG USE_CUDA=false ARG USE_OLLAMA=false -# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default) +# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default) ARG USE_CUDA_VER=cu121 # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 818b53d93..b89d7bf52 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -81,6 +81,12 @@ async def check_url(request: Request, call_next): return response +@app.head("/") +@app.get("/") +async def get_status(): + return {"status": True} + + @app.get("/urls") async def get_ollama_api_urls(user=Depends(get_admin_user)): return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} diff --git a/backend/apps/web/internal/db.py b/backend/apps/web/internal/db.py index 554f8002d..fad566ce9 100644 --- a/backend/apps/web/internal/db.py +++ b/backend/apps/web/internal/db.py @@ -1,4 +1,5 @@ from peewee import * +from peewee_migrate import Router from config import SRC_LOG_LEVELS, DATA_DIR import os import logging @@ -16,4 +17,6 @@ else: DB = SqliteDatabase(f"{DATA_DIR}/webui.db") -DB.connect() +router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log) +router.run() +DB.connect(reuse_if_open=True) diff --git a/backend/apps/web/internal/migrations/001_initial_schema.py b/backend/apps/web/internal/migrations/001_initial_schema.py new file mode 100644 index 000000000..24ea6d39f --- /dev/null +++ b/backend/apps/web/internal/migrations/001_initial_schema.py @@ -0,0 +1,149 @@ +"""Peewee migrations -- 001_initial_schema.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Auth(pw.Model): + id = pw.CharField(max_length=255, unique=True) + email = pw.CharField(max_length=255) + password = pw.CharField(max_length=255) + active = pw.BooleanField() + + class Meta: + table_name = "auth" + + @migrator.create_model + class Chat(pw.Model): + id = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + title = pw.CharField() + chat = pw.TextField() + timestamp = pw.DateField() + + class Meta: + table_name = "chat" + + @migrator.create_model + class ChatIdTag(pw.Model): + id = pw.CharField(max_length=255, unique=True) + tag_name = pw.CharField(max_length=255) + chat_id = pw.CharField(max_length=255) + user_id = pw.CharField(max_length=255) + timestamp = pw.DateField() + + class Meta: + table_name = "chatidtag" + + @migrator.create_model + class Document(pw.Model): + id = pw.AutoField() + collection_name = pw.CharField(max_length=255, unique=True) + name = pw.CharField(max_length=255, unique=True) + title = pw.CharField() + filename = pw.CharField() + content = pw.TextField(null=True) + user_id = pw.CharField(max_length=255) + timestamp = pw.DateField() + + class Meta: + table_name = "document" + + @migrator.create_model + class Modelfile(pw.Model): + id = pw.AutoField() + tag_name = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + modelfile = pw.TextField() + timestamp = pw.DateField() + + class Meta: + table_name = "modelfile" + + @migrator.create_model + class Prompt(pw.Model): + id = pw.AutoField() + command = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + title = pw.CharField() + content = pw.TextField() + timestamp = pw.DateField() + + class Meta: + table_name = "prompt" + + @migrator.create_model + class Tag(pw.Model): + id = pw.CharField(max_length=255, unique=True) + name = pw.CharField(max_length=255) + user_id = pw.CharField(max_length=255) + data = pw.TextField(null=True) + + class Meta: + table_name = "tag" + + @migrator.create_model + class User(pw.Model): + id = pw.CharField(max_length=255, unique=True) + name = pw.CharField(max_length=255) + email = pw.CharField(max_length=255) + role = pw.CharField(max_length=255) + profile_image_url = pw.CharField(max_length=255) + timestamp = pw.DateField() + + class Meta: + table_name = "user" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("user") + + migrator.remove_model("tag") + + migrator.remove_model("prompt") + + migrator.remove_model("modelfile") + + migrator.remove_model("document") + + migrator.remove_model("chatidtag") + + migrator.remove_model("chat") + + migrator.remove_model("auth") diff --git a/backend/apps/web/internal/migrations/002_add_local_sharing.py b/backend/apps/web/internal/migrations/002_add_local_sharing.py new file mode 100644 index 000000000..e93501aee --- /dev/null +++ b/backend/apps/web/internal/migrations/002_add_local_sharing.py @@ -0,0 +1,48 @@ +"""Peewee migrations -- 002_add_local_sharing.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "chat", share_id=pw.CharField(max_length=255, null=True, unique=True) + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("chat", "share_id") diff --git a/backend/apps/web/internal/migrations/003_add_auth_api_key.py b/backend/apps/web/internal/migrations/003_add_auth_api_key.py new file mode 100644 index 000000000..07144f3ac --- /dev/null +++ b/backend/apps/web/internal/migrations/003_add_auth_api_key.py @@ -0,0 +1,48 @@ +"""Peewee migrations -- 002_add_local_sharing.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "user", api_key=pw.CharField(max_length=255, null=True, unique=True) + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("user", "api_key") diff --git a/backend/apps/web/internal/migrations/README.md b/backend/apps/web/internal/migrations/README.md new file mode 100644 index 000000000..63d92e802 --- /dev/null +++ b/backend/apps/web/internal/migrations/README.md @@ -0,0 +1,21 @@ +# Database Migrations + +This directory contains all the database migrations for the web app. +Migrations are done using the [`peewee-migrate`](https://github.com/klen/peewee_migrate) library. + +Migrations are automatically ran at app startup. + +## Creating a migration + +Have you made a change to the schema of an existing model? +You will need to create a migration file to ensure that existing databases are updated for backwards compatibility. + +1. Have a database file (`webui.db`) that has the old schema prior to any of your changes. +2. Make your changes to the models. +3. From the `backend` directory, run the following command: + ```bash + pw_migrate create --auto --auto-source apps.web.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME} + ``` + - `$SQLITE_DB` should be the path to the database file. + - `$MIGRATION_NAME` should be a descriptive name for the migration. +4. The migration file will be created in the `apps/web/internal/migrations` directory. diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index dd5c0c704..66cdfb3d4 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -20,6 +20,7 @@ from config import ( ENABLE_SIGNUP, USER_PERMISSIONS, WEBHOOK_URL, + WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ) app = FastAPI() @@ -34,7 +35,7 @@ app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE app.state.USER_PERMISSIONS = USER_PERMISSIONS app.state.WEBHOOK_URL = WEBHOOK_URL - +app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.add_middleware( CORSMiddleware, diff --git a/backend/apps/web/models/auths.py b/backend/apps/web/models/auths.py index 75637700d..069865036 100644 --- a/backend/apps/web/models/auths.py +++ b/backend/apps/web/models/auths.py @@ -47,6 +47,10 @@ class Token(BaseModel): token_type: str +class ApiKey(BaseModel): + api_key: Optional[str] = None + + class UserResponse(BaseModel): id: str email: str @@ -123,6 +127,28 @@ class AuthsTable: except: return None + def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + log.info(f"authenticate_user_by_api_key: {api_key}") + # if no api_key, return None + if not api_key: + return None + + try: + user = Users.get_user_by_api_key(api_key) + return user if user else None + except: + return False + + def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: + log.info(f"authenticate_user_by_trusted_header: {email}") + try: + auth = Auth.get(Auth.email == email, Auth.active == True) + if auth: + user = Users.get_user_by_id(auth.id) + return user + except: + return None + def update_user_password_by_id(self, id: str, new_password: str) -> bool: try: query = Auth.update(password=new_password).where(Auth.id == id) diff --git a/backend/apps/web/models/chats.py b/backend/apps/web/models/chats.py index c9d130044..95a673cb8 100644 --- a/backend/apps/web/models/chats.py +++ b/backend/apps/web/models/chats.py @@ -20,6 +20,7 @@ class Chat(Model): title = CharField() chat = TextField() # Save Chat JSON as Text timestamp = DateField() + share_id = CharField(null=True, unique=True) class Meta: database = DB @@ -31,6 +32,7 @@ class ChatModel(BaseModel): title: str chat: str timestamp: int # timestamp in epoch + share_id: Optional[str] = None #################### @@ -52,6 +54,7 @@ class ChatResponse(BaseModel): title: str chat: dict timestamp: int # timestamp in epoch + share_id: Optional[str] = None # id of the chat to be shared class ChatTitleIdResponse(BaseModel): @@ -95,6 +98,71 @@ class ChatTable: except: return None + def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: + # Get the existing chat to share + chat = Chat.get(Chat.id == chat_id) + # Check if the chat is already shared + if chat.share_id: + return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + # Create a new chat with the same data, but with a new ID + shared_chat = ChatModel( + **{ + "id": str(uuid.uuid4()), + "user_id": f"shared-{chat_id}", + "title": chat.title, + "chat": chat.chat, + "timestamp": int(time.time()), + } + ) + shared_result = Chat.create(**shared_chat.model_dump()) + # Update the original chat with the share_id + result = ( + Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute() + ) + + return shared_chat if (shared_result and result) else None + + def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: + try: + print("update_shared_chat_by_id") + chat = Chat.get(Chat.id == chat_id) + print(chat) + + query = Chat.update( + title=chat.title, + chat=chat.chat, + ).where(Chat.id == chat.share_id) + + query.execute() + + chat = Chat.get(Chat.id == chat.share_id) + return ChatModel(**model_to_dict(chat)) + except: + return None + + def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: + try: + query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}") + query.execute() # Remove the rows, return number of rows removed. + + return True + except: + return False + + def update_chat_share_id_by_id( + self, id: str, share_id: Optional[str] + ) -> Optional[ChatModel]: + try: + query = Chat.update( + share_id=share_id, + ).where(Chat.id == id) + query.execute() + + chat = Chat.get(Chat.id == id) + return ChatModel(**model_to_dict(chat)) + except: + return None + def get_chat_lists_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: @@ -131,6 +199,13 @@ class ChatTable: .order_by(Chat.timestamp.desc()) ] + def get_chat_by_id(self, id: str) -> Optional[ChatModel]: + try: + chat = Chat.get(Chat.id == id) + return ChatModel(**model_to_dict(chat)) + except: + return None + def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: try: chat = Chat.get(Chat.id == id, Chat.user_id == user_id) @@ -149,12 +224,15 @@ class ChatTable: query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id)) query.execute() # Remove the rows, return number of rows removed. - return True + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chats_by_user_id(self, user_id: str) -> bool: try: + + self.delete_shared_chats_by_user_id(user_id) + query = Chat.delete().where(Chat.user_id == user_id) query.execute() # Remove the rows, return number of rows removed. @@ -162,5 +240,19 @@ class ChatTable: except: return False + def delete_shared_chats_by_user_id(self, user_id: str) -> bool: + try: + shared_chat_ids = [ + f"shared-{chat.id}" + for chat in Chat.select().where(Chat.user_id == user_id) + ] + + query = Chat.delete().where(Chat.user_id << shared_chat_ids) + query.execute() # Remove the rows, return number of rows removed. + + return True + except: + return False + Chats = ChatTable(DB) diff --git a/backend/apps/web/models/users.py b/backend/apps/web/models/users.py index 255c701df..a01e595e5 100644 --- a/backend/apps/web/models/users.py +++ b/backend/apps/web/models/users.py @@ -20,6 +20,7 @@ class User(Model): role = CharField() profile_image_url = CharField() timestamp = DateField() + api_key = CharField(null=True, unique=True) class Meta: database = DB @@ -32,6 +33,7 @@ class UserModel(BaseModel): role: str = "pending" profile_image_url: str = "/user.png" timestamp: int # timestamp in epoch + api_key: Optional[str] = None #################### @@ -82,6 +84,13 @@ class UsersTable: except: return None + def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + try: + user = User.get(User.api_key == api_key) + return UserModel(**model_to_dict(user)) + except: + return None + def get_user_by_email(self, email: str) -> Optional[UserModel]: try: user = User.get(User.email == email) @@ -149,5 +158,21 @@ class UsersTable: except: return False + def update_user_api_key_by_id(self, id: str, api_key: str) -> str: + try: + query = User.update(api_key=api_key).where(User.id == id) + result = query.execute() + + return True if result == 1 else False + except: + return False + + def get_user_api_key_by_id(self, id: str) -> Optional[str]: + try: + user = User.get(User.id == id) + return user.api_key + except: + return None + Users = UsersTable(DB) diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index d881ec746..293cb55b8 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -1,13 +1,10 @@ -from fastapi import Response, Request -from fastapi import Depends, FastAPI, HTTPException, status -from datetime import datetime, timedelta -from typing import List, Union +from fastapi import Request +from fastapi import Depends, HTTPException, status -from fastapi import APIRouter, status +from fastapi import APIRouter from pydantic import BaseModel -import time -import uuid import re +import uuid from apps.web.models.auths import ( SigninForm, @@ -17,6 +14,7 @@ from apps.web.models.auths import ( UserResponse, SigninResponse, Auths, + ApiKey, ) from apps.web.models.users import Users @@ -25,10 +23,12 @@ from utils.utils import ( get_current_user, get_admin_user, create_token, + create_api_key, ) from utils.misc import parse_duration, validate_email_format from utils.webhook import post_webhook from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES +from config import WEBUI_AUTH_TRUSTED_EMAIL_HEADER router = APIRouter() @@ -79,6 +79,8 @@ async def update_profile( async def update_password( form_data: UpdatePasswordForm, session_user=Depends(get_current_user) ): + if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: + raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) if session_user: user = Auths.authenticate_user(session_user.email, form_data.password) @@ -98,7 +100,22 @@ async def update_password( @router.post("/signin", response_model=SigninResponse) async def signin(request: Request, form_data: SigninForm): - user = Auths.authenticate_user(form_data.email.lower(), form_data.password) + if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: + if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER) + + trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower() + if not Users.get_user_by_email(trusted_email.lower()): + await signup( + request, + SignupForm( + email=trusted_email, password=str(uuid.uuid4()), name=trusted_email + ), + ) + user = Auths.authenticate_user_by_trusted_header(trusted_email) + else: + user = Auths.authenticate_user(form_data.email.lower(), form_data.password) + if user: token = create_token( data={"id": user.id}, @@ -249,3 +266,40 @@ async def update_token_expires_duration( return request.app.state.JWT_EXPIRES_IN else: return request.app.state.JWT_EXPIRES_IN + + +############################ +# API Key +############################ + + +# create api key +@router.post("/api_key", response_model=ApiKey) +async def create_api_key_(user=Depends(get_current_user)): + api_key = create_api_key() + success = Users.update_user_api_key_by_id(user.id, api_key) + if success: + return { + "api_key": api_key, + } + else: + raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR) + + +# delete api key +@router.delete("/api_key", response_model=bool) +async def delete_api_key(user=Depends(get_current_user)): + success = Users.update_user_api_key_by_id(user.id, None) + return success + + +# get api key +@router.get("/api_key", response_model=ApiKey) +async def get_api_key(user=Depends(get_current_user)): + api_key = Users.get_user_api_key_by_id(user.id) + if api_key: + return { + "api_key": api_key, + } + else: + raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index 5f8c61b70..660a0d7f6 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -189,6 +189,78 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_ return result +############################ +# ShareChatById +############################ + + +@router.post("/{id}/share", response_model=Optional[ChatResponse]) +async def share_chat_by_id(id: str, user=Depends(get_current_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + if chat.share_id: + shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) + return ChatResponse( + **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} + ) + + shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id) + if not shared_chat: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + return ChatResponse( + **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + +############################ +# DeletedSharedChatById +############################ + + +@router.delete("/{id}/share", response_model=Optional[bool]) +async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + if not chat.share_id: + return False + + result = Chats.delete_shared_chat_by_chat_id(id) + update_result = Chats.update_chat_share_id_by_id(id, None) + + return result and update_result != None + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + +############################ +# GetSharedChatById +############################ + + +@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) +async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): + chat = Chats.get_chat_by_id(share_id) + + if chat: + return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + ) + + ############################ # GetChatTagsById ############################ diff --git a/backend/config.py b/backend/config.py index 7e64ebbfd..c1f0b590d 100644 --- a/backend/config.py +++ b/backend/config.py @@ -367,6 +367,9 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100") #################################### WEBUI_AUTH = True +WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( + "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None +) #################################### # WEBUI_SECRET_KEY diff --git a/backend/constants.py b/backend/constants.py index 8bcdd0789..da1ee0b3f 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -20,6 +20,7 @@ class ERROR_MESSAGES(str, Enum): ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now." CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance." DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot." + EMAIL_MISMATCH = "Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again." EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew." USERNAME_TAKEN = ( "Uh-oh! This username is already registered. Please choose another username." @@ -36,6 +37,7 @@ class ERROR_MESSAGES(str, Enum): INVALID_PASSWORD = ( "The password provided is incorrect. Please check for typos and try again." ) + INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance." UNAUTHORIZED = "401 Unauthorized" ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance." ACTION_PROHIBITED = ( @@ -58,7 +60,8 @@ class ERROR_MESSAGES(str, Enum): RATE_LIMIT_EXCEEDED = "API rate limit exceeded" MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" - OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found" + OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found" OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" + CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance." EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." diff --git a/backend/main.py b/backend/main.py index 8cc704a26..f2d2a1546 100644 --- a/backend/main.py +++ b/backend/main.py @@ -62,6 +62,21 @@ class SPAStaticFiles(StaticFiles): raise ex +print( + f""" + ___ __ __ _ _ _ ___ + / _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _| +| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || | +| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || | + \___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___| + |_| + + +v{VERSION} - building the best open-source AI user interface. +https://github.com/open-webui/open-webui +""" +) + app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED @@ -179,6 +194,7 @@ async def get_app_config(): "images": images_app.state.ENABLED, "default_models": webui_app.state.DEFAULT_MODELS, "default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS, + "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), } diff --git a/backend/requirements.txt b/backend/requirements.txt index df8fcfec3..67213e54d 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -14,6 +14,7 @@ uuid requests aiohttp peewee +peewee-migrate bcrypt litellm==1.30.7 diff --git a/backend/utils/utils.py b/backend/utils/utils.py index 32724af39..49e15789f 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -1,6 +1,8 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi import HTTPException, status, Depends + from apps.web.models.users import Users + from pydantic import BaseModel from typing import Union, Optional from constants import ERROR_MESSAGES @@ -8,6 +10,7 @@ from passlib.context import CryptContext from datetime import datetime, timedelta import requests import jwt +import uuid import logging import config @@ -58,6 +61,11 @@ def extract_token_from_auth_header(auth_header: str): return auth_header[len("Bearer ") :] +def create_api_key(): + key = str(uuid.uuid4()).replace("-", "") + return f"sk-{key}" + + def get_http_authorization_cred(auth_header: str): try: scheme, credentials = auth_header.split(" ") @@ -69,6 +77,10 @@ def get_http_authorization_cred(auth_header: str): def get_current_user( auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), ): + # auth by api key + if auth_token.credentials.startswith("sk-"): + return get_current_user_by_api_key(auth_token.credentials) + # auth by jwt token data = decode_token(auth_token.credentials) if data != None and "id" in data: user = Users.get_user_by_id(data["id"]) @@ -85,6 +97,16 @@ def get_current_user( ) +def get_current_user_by_api_key(api_key: str): + user = Users.get_user_by_api_key(api_key) + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.INVALID_TOKEN, + ) + return user + + def get_verified_user(user=Depends(get_current_user)): if user.role not in {"user", "admin"}: raise HTTPException( diff --git a/src/lib/apis/auths/index.ts b/src/lib/apis/auths/index.ts index 169998726..548a9418d 100644 --- a/src/lib/apis/auths/index.ts +++ b/src/lib/apis/auths/index.ts @@ -318,3 +318,78 @@ export const updateJWTExpiresDuration = async (token: string, duration: string) return res; }; + +export const createAPIKey = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/api_key`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + if (error) { + throw error; + } + return res.api_key; +}; + +export const getAPIKey = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/api_key`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + if (error) { + throw error; + } + return res.api_key; +}; + +export const deleteAPIKey = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/api_key`, { + method: 'DELETE', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + if (error) { + throw error; + } + return res; +}; diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index 35b259d56..28b3d4be5 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -218,6 +218,102 @@ export const getChatById = async (token: string, id: string) => { return res; }; +export const getChatByShareId = async (token: string, share_id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/share/${share_id}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const shareChatById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/share`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteSharedChatById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/share`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const updateChatById = async (token: string, id: string, chat: object) => { let error = null; diff --git a/src/lib/components/chat/Messages.svelte b/src/lib/components/chat/Messages.svelte index 7afb5c376..4cd97ca80 100644 --- a/src/lib/components/chat/Messages.svelte +++ b/src/lib/components/chat/Messages.svelte @@ -16,6 +16,7 @@ const i18n = getContext('i18n'); export let chatId = ''; + export let readOnly = false; export let sendPrompt: Function; export let continueGeneration: Function; export let regenerateResponse: Function; @@ -317,6 +318,7 @@ messageDeleteHandler(message.id)} user={$user} + {readOnly} {message} isFirstMessage={messageIdx === 0} siblings={message.parentId !== null @@ -335,6 +337,7 @@ modelfiles={selectedModelfiles} siblings={history.messages[message.parentId]?.childrenIds ?? []} isLastMessage={messageIdx + 1 === messages.length} + {readOnly} {confirmEditResponseMessage} {showPreviousMessage} {showNextMessage} diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index 91bb35c47..3888d764e 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -33,6 +33,8 @@ export let isLastMessage = true; + export let readOnly = false; + export let confirmEditResponseMessage: Function; export let showPreviousMessage: Function; export let showNextMessage: Function; @@ -128,7 +130,7 @@ // • auto-render specific keys, e.g.: delimiters: [ { left: '$$', right: '$$', display: false }, - { left: '$', right: '$', display: false }, + { left: '$ ', right: ' $', display: false }, { left: '\\(', right: '\\)', display: false }, { left: '\\[', right: '\\]', display: false }, { left: '[ ', right: ' ]', display: false } @@ -422,7 +424,7 @@ class=" flex justify-start space-x-1 overflow-x-auto buttons text-gray-700 dark:text-gray-500" > {#if siblings.length > 1} -
+
{/if} - - - + + + + + + {/if} - + + + - - - + + + + {/if}
{/if} - - - + + + + + + {/if} + + +
+
+
{$i18n.t('API Key')}
+
- +
+ + + + + + {:else} - { + createAPIKeyHandler(); + }} + > + + + Create new secret key - - - {/if} - + diff --git a/src/lib/components/chat/ShareChatModal.svelte b/src/lib/components/chat/ShareChatModal.svelte index aef098f04..14945ab9d 100644 --- a/src/lib/components/chat/ShareChatModal.svelte +++ b/src/lib/components/chat/ShareChatModal.svelte @@ -1,41 +1,180 @@ - -
- - -
-
{$i18n.t('or')}
- + +
+
+
{$i18n.t('Share Chat')}
+
+ + {#if chat} +
+
+ {#if chat.share_id} + You have shared this chat before. + Click here to + and create a new shared link. + {:else} + Messages you send after creating your link won't be shared. Users with the URL will be + able to view the shared chat. + {/if} +
+ +
+
+
+ + + +
+
+
{$i18n.t('or')}
+ +
+
+
+
+ {/if}
diff --git a/src/lib/components/common/Spinner.svelte b/src/lib/components/common/Spinner.svelte index 206c7f5ce..4b7f5e396 100644 --- a/src/lib/components/common/Spinner.svelte +++ b/src/lib/components/common/Spinner.svelte @@ -1,24 +1,25 @@
- - - -
diff --git a/src/lib/components/icons/Link.svelte b/src/lib/components/icons/Link.svelte new file mode 100644 index 000000000..9f1a72311 --- /dev/null +++ b/src/lib/components/icons/Link.svelte @@ -0,0 +1,16 @@ + + + + + + diff --git a/src/lib/components/icons/Plus.svelte b/src/lib/components/icons/Plus.svelte new file mode 100644 index 000000000..bcfe4a8b2 --- /dev/null +++ b/src/lib/components/icons/Plus.svelte @@ -0,0 +1,15 @@ + + + + + diff --git a/src/lib/components/layout/Navbar.svelte b/src/lib/components/layout/Navbar.svelte index a98689618..6bff2ed80 100644 --- a/src/lib/components/layout/Navbar.svelte +++ b/src/lib/components/layout/Navbar.svelte @@ -1,11 +1,9 @@ - +