Merge pull request #7896 from open-webui/dev

0.5.0
This commit is contained in:
Timothy Jaeryang Baek 2024-12-25 10:39:01 -08:00 committed by GitHub
commit 22132e155a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
251 changed files with 98386 additions and 8864 deletions

View File

@ -5,6 +5,35 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.5.0] - 2024-12-25
### Added
- **💬 True Asynchronous Chat Support**: Create chats, navigate away, and return anytime with responses ready. Ideal for reasoning models and multi-agent workflows, enhancing multitasking like never before.
- **🔔 Chat Completion Notifications**: Never miss a completed response. Receive instant in-UI notifications when a chat finishes in a non-active tab, keeping you updated while you work elsewhere.
- **🌐 Notification Webhook Integration**: Get alerts via webhooks even when your tab is closed! Configure your webhook URL in Settings > Account and receive timely updates for long-running chats or external integration needs.
- **📚 Channels (Beta)**: Explore Discord/Slack-style chat rooms designed for real-time collaboration between users and AIs. Build bots for channels and unlock asynchronous communication for proactive multi-agent workflows. Opt-in via Admin Settings > General. A Comprehensive Bot SDK tutorial (https://github.com/open-webui/bot) is incoming, so stay tuned!
- **🖼️ Client-Side Image Compression**: Now compress images before upload (Settings > Interface), saving bandwidth and improving performance seamlessly.
- **🛠️ OAuth Management for User Groups**: Enable group-level management via OAuth integration for enhanced control and scalability in collaborative environments.
- **✅ Structured Output for Ollama**: Pass structured data output directly to Ollama, unlocking new possibilities for streamlined automation and precise data handling.
- **📜 Offline Swagger Documentation**: Developer-friendly Swagger API docs are now available offline, ensuring full accessibility wherever you are.
- **📸 Quick Screen Capture Button**: Effortlessly capture your screen with a single click from the message input menu.
- **🌍 i18n Updates**: Improved and refined translations across several languages, including Ukrainian, German, Brazilian Portuguese, Catalan, and more, ensuring a seamless global user experience.
### Fixed
- **📋 Table Export to CSV**: Resolved issues with CSV export where headers were missing or errors occurred due to values with commas, ensuring smooth and reliable data handling.
- **🔓 BYPASS_MODEL_ACCESS_CONTROL**: Fixed an issue where users could see models but couldnt use them with 'BYPASS_MODEL_ACCESS_CONTROL=True', restoring proper functionality for environments leveraging this setting.
### Changed
- **💡 API Key Authentication Restriction**: Narrowed API key auth permissions to '/api/models' and '/api/chat/completions' for enhanced security and better API governance.
- **⚙️ Backend Overhaul for Performance**: Major backend restructuring; a heads-up that some "Functions" using internal variables may face compatibility issues. Moving forward, websocket support is mandatory to ensure Open WebUI operates seamlessly.
### Removed
- **⚠️ Legacy Functionality Clean-Up**: Deprecated outdated backend systems that were non-essential or overlapped with newer implementations, allowing for a leaner, more efficient platform.
## [0.4.8] - 2024-12-07
### Added

View File

@ -1,703 +0,0 @@
import hashlib
import json
import logging
import os
import uuid
from functools import lru_cache
from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
import aiohttp
import aiofiles
import requests
from open_webui.config import (
AUDIO_STT_ENGINE,
AUDIO_STT_MODEL,
AUDIO_STT_OPENAI_API_BASE_URL,
AUDIO_STT_OPENAI_API_KEY,
AUDIO_TTS_API_KEY,
AUDIO_TTS_ENGINE,
AUDIO_TTS_MODEL,
AUDIO_TTS_OPENAI_API_BASE_URL,
AUDIO_TTS_OPENAI_API_KEY,
AUDIO_TTS_SPLIT_ON,
AUDIO_TTS_VOICE,
AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
CACHE_DIR,
CORS_ALLOW_ORIGIN,
WHISPER_MODEL,
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
DEVICE_TYPE,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from open_webui.utils.utils import get_admin_user, get_verified_user
# Constants
MAX_FILE_SIZE_MB = 25
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL
app.state.faster_whisper_model = None
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
app.state.speech_synthesiser = None
app.state.speech_speaker_embeddings_dataset = None
app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
log.info(f"whisper_device_type: {whisper_device_type}")
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
def set_faster_whisper_model(model: str, auto_update: bool = False):
if model and app.state.config.STT_ENGINE == "":
from faster_whisper import WhisperModel
faster_whisper_kwargs = {
"model_size_or_path": model,
"device": whisper_device_type,
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not auto_update,
}
try:
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
except Exception:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
faster_whisper_kwargs["local_files_only"] = False
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
else:
app.state.faster_whisper_model = None
class TTSConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
API_KEY: str
ENGINE: str
MODEL: str
VOICE: str
SPLIT_ON: str
AZURE_SPEECH_REGION: str
AZURE_SPEECH_OUTPUT_FORMAT: str
class STTConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
ENGINE: str
MODEL: str
WHISPER_MODEL: str
class AudioConfigUpdateForm(BaseModel):
tts: TTSConfigForm
stt: STTConfigForm
from pydub import AudioSegment
from pydub.utils import mediainfo
def is_mp4_audio(file_path):
"""Check if the given file is an MP4 audio file."""
if not os.path.isfile(file_path):
print(f"File not found: {file_path}")
return False
info = mediainfo(file_path)
if (
info.get("codec_name") == "aac"
and info.get("codec_type") == "audio"
and info.get("codec_tag_string") == "mp4a"
):
return True
return False
def convert_mp4_to_wav(file_path, output_path):
"""Convert MP4 audio file to WAV format."""
audio = AudioSegment.from_file(file_path, format="mp4")
audio.export(output_path, format="wav")
print(f"Converted {file_path} to {output_path}")
@app.get("/config")
async def get_audio_config(user=Depends(get_admin_user)):
return {
"tts": {
"OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
"API_KEY": app.state.config.TTS_API_KEY,
"ENGINE": app.state.config.TTS_ENGINE,
"MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
"SPLIT_ON": app.state.config.TTS_SPLIT_ON,
"AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
"AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
},
}
@app.post("/config/update")
async def update_audio_config(
form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
app.state.config.TTS_API_KEY = form_data.tts.API_KEY
app.state.config.TTS_ENGINE = form_data.tts.ENGINE
app.state.config.TTS_MODEL = form_data.tts.MODEL
app.state.config.TTS_VOICE = form_data.tts.VOICE
app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
)
app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
app.state.config.STT_ENGINE = form_data.stt.ENGINE
app.state.config.STT_MODEL = form_data.stt.MODEL
app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
return {
"tts": {
"OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
"API_KEY": app.state.config.TTS_API_KEY,
"ENGINE": app.state.config.TTS_ENGINE,
"MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
"SPLIT_ON": app.state.config.TTS_SPLIT_ON,
"AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
"AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
},
}
def load_speech_pipeline():
from transformers import pipeline
from datasets import load_dataset
if app.state.speech_synthesiser is None:
app.state.speech_synthesiser = pipeline(
"text-to-speech", "microsoft/speecht5_tts"
)
if app.state.speech_speaker_embeddings_dataset is None:
app.state.speech_speaker_embeddings_dataset = load_dataset(
"Matthijs/cmu-arctic-xvectors", split="validation"
)
@app.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
body = await request.body()
name = hashlib.sha256(body).hexdigest()
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
# Check if the file already exists in the cache
if file_path.is_file():
return FileResponse(file_path)
if app.state.config.TTS_ENGINE == "openai":
headers = {}
headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
try:
body = body.decode("utf-8")
body = json.loads(body)
body["model"] = app.state.config.TTS_MODEL
body = json.dumps(body).encode("utf-8")
except Exception:
pass
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
data=body,
headers=headers,
) as r:
r.raise_for_status()
async with aiofiles.open(file_path, "wb") as f:
await f.write(await r.read())
async with aiofiles.open(file_body_path, "w") as f:
await f.write(json.dumps(json.loads(body.decode("utf-8"))))
return FileResponse(file_path)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
try:
if r.status != 200:
res = await r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
detail=error_detail,
)
elif app.state.config.TTS_ENGINE == "elevenlabs":
try:
payload = json.loads(body.decode("utf-8"))
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
voice_id = payload.get("voice", "")
if voice_id not in get_available_voices():
raise HTTPException(
status_code=400,
detail="Invalid voice id",
)
url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
headers = {
"Accept": "audio/mpeg",
"Content-Type": "application/json",
"xi-api-key": app.state.config.TTS_API_KEY,
}
data = {
"text": payload["input"],
"model_id": app.state.config.TTS_MODEL,
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=data, headers=headers) as r:
r.raise_for_status()
async with aiofiles.open(file_path, "wb") as f:
await f.write(await r.read())
async with aiofiles.open(file_body_path, "w") as f:
await f.write(json.dumps(json.loads(body.decode("utf-8"))))
return FileResponse(file_path)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
try:
if r.status != 200:
res = await r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
detail=error_detail,
)
elif app.state.config.TTS_ENGINE == "azure":
try:
payload = json.loads(body.decode("utf-8"))
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
region = app.state.config.TTS_AZURE_SPEECH_REGION
language = app.state.config.TTS_VOICE
locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
headers = {
"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": output_format,
}
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
<voice name="{language}">{payload["input"]}</voice>
</speak>"""
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, data=data) as response:
if response.status == 200:
async with aiofiles.open(file_path, "wb") as f:
await f.write(await response.read())
return FileResponse(file_path)
else:
error_msg = f"Error synthesizing speech - {response.reason}"
log.error(error_msg)
raise HTTPException(status_code=500, detail=error_msg)
except Exception as e:
log.exception(e)
raise HTTPException(status_code=500, detail=str(e))
elif app.state.config.TTS_ENGINE == "transformers":
payload = None
try:
payload = json.loads(body.decode("utf-8"))
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
import torch
import soundfile as sf
load_speech_pipeline()
embeddings_dataset = app.state.speech_speaker_embeddings_dataset
speaker_index = 6799
try:
speaker_index = embeddings_dataset["filename"].index(
app.state.config.TTS_MODEL
)
except Exception:
pass
speaker_embedding = torch.tensor(
embeddings_dataset[speaker_index]["xvector"]
).unsqueeze(0)
speech = app.state.speech_synthesiser(
payload["input"],
forward_params={"speaker_embeddings": speaker_embedding},
)
sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
with open(file_body_path, "w") as f:
json.dump(json.loads(body.decode("utf-8")), f)
return FileResponse(file_path)
def transcribe(file_path):
print("transcribe", file_path)
filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path)
id = filename.split(".")[0]
if app.state.config.STT_ENGINE == "":
if app.state.faster_whisper_model is None:
set_faster_whisper_model(app.state.config.WHISPER_MODEL)
model = app.state.faster_whisper_model
segments, info = model.transcribe(file_path, beam_size=5)
log.info(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
transcript = "".join([segment.text for segment in list(segments)])
data = {"text": transcript.strip()}
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
log.debug(data)
return data
elif app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path):
print("is_mp4_audio")
os.rename(file_path, file_path.replace(".wav", ".mp4"))
# Convert MP4 audio file to WAV format
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
files = {"file": (filename, open(file_path, "rb"))}
data = {"model": app.state.config.STT_MODEL}
log.debug(files, data)
r = None
try:
r = requests.post(
url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
headers=headers,
files=files,
data=data,
)
r.raise_for_status()
data = r.json()
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
print(data)
return data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except Exception:
error_detail = f"External: {e}"
raise Exception(error_detail)
@app.post("/transcriptions")
def transcription(
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}")
if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
)
try:
ext = file.filename.split(".")[-1]
id = uuid.uuid4()
filename = f"{id}.{ext}"
contents = file.file.read()
file_dir = f"{CACHE_DIR}/audio/transcriptions"
os.makedirs(file_dir, exist_ok=True)
file_path = f"{file_dir}/{filename}"
with open(file_path, "wb") as f:
f.write(contents)
try:
if os.path.getsize(file_path) > MAX_FILE_SIZE: # file is bigger than 25MB
log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
audio = AudioSegment.from_file(file_path)
audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
compressed_path = f"{file_dir}/{id}_compressed.opus"
audio.export(compressed_path, format="opus", bitrate="32k")
log.debug(f"Compressed audio to {compressed_path}")
file_path = compressed_path
if (
os.path.getsize(file_path) > MAX_FILE_SIZE
): # Still larger than 25MB after compression
log.debug(
f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_TOO_LARGE(
size=f"{MAX_FILE_SIZE_MB}MB"
),
)
data = transcribe(file_path)
else:
data = transcribe(file_path)
file_path = file_path.split("/")[-1]
return {**data, "filename": file_path}
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
def get_available_models() -> list[dict]:
if app.state.config.TTS_ENGINE == "openai":
return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
elif app.state.config.TTS_ENGINE == "elevenlabs":
headers = {
"xi-api-key": app.state.config.TTS_API_KEY,
"Content-Type": "application/json",
}
try:
response = requests.get(
"https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
)
response.raise_for_status()
models = response.json()
return [
{"name": model["name"], "id": model["model_id"]} for model in models
]
except requests.RequestException as e:
log.error(f"Error fetching voices: {str(e)}")
return []
@app.get("/models")
async def get_models(user=Depends(get_verified_user)):
return {"models": get_available_models()}
def get_available_voices() -> dict:
"""Returns {voice_id: voice_name} dict"""
ret = {}
if app.state.config.TTS_ENGINE == "openai":
ret = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
elif app.state.config.TTS_ENGINE == "elevenlabs":
try:
ret = get_elevenlabs_voices()
except Exception:
# Avoided @lru_cache with exception
pass
elif app.state.config.TTS_ENGINE == "azure":
try:
region = app.state.config.TTS_AZURE_SPEECH_REGION
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}
response = requests.get(url, headers=headers)
response.raise_for_status()
voices = response.json()
for voice in voices:
ret[voice["ShortName"]] = (
f"{voice['DisplayName']} ({voice['ShortName']})"
)
except requests.RequestException as e:
log.error(f"Error fetching voices: {str(e)}")
return ret
@lru_cache
def get_elevenlabs_voices() -> dict:
"""
Note, set the following in your .env file to use Elevenlabs:
AUDIO_TTS_ENGINE=elevenlabs
AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key
AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
AUDIO_TTS_MODEL=eleven_multilingual_v2
"""
headers = {
"xi-api-key": app.state.config.TTS_API_KEY,
"Content-Type": "application/json",
}
try:
# TODO: Add retries
response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
response.raise_for_status()
voices_data = response.json()
voices = {}
for voice in voices_data.get("voices", []):
voices[voice["voice_id"]] = voice["name"]
except requests.RequestException as e:
# Avoid @lru_cache with exception
log.error(f"Error fetching voices: {str(e)}")
raise RuntimeError(f"Error fetching voices: {str(e)}")
return voices
@app.get("/voices")
async def get_voices(user=Depends(get_verified_user)):
return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}

File diff suppressed because it is too large Load Diff

View File

@ -1,22 +0,0 @@
from open_webui.config import VECTOR_DB
if VECTOR_DB == "milvus":
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
elif VECTOR_DB == "opensearch":
from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient
VECTOR_DB_CLIENT = OpenSearchClient()
elif VECTOR_DB == "pgvector":
from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient
VECTOR_DB_CLIENT = PgvectorClient()
else:
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
VECTOR_DB_CLIENT = ChromaClient()

View File

@ -1,506 +0,0 @@
import inspect
import json
import logging
import time
from typing import AsyncGenerator, Generator, Iterator
from open_webui.apps.socket.main import get_event_call, get_event_emitter
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.routers import (
auths,
chats,
folders,
configs,
groups,
files,
functions,
memories,
models,
knowledge,
prompts,
evaluations,
tools,
users,
utils,
)
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.config import (
ADMIN_EMAIL,
CORS_ALLOW_ORIGIN,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
DEFAULT_USER_ROLE,
MODEL_ORDER_LIST,
ENABLE_COMMUNITY_SHARING,
ENABLE_LOGIN_FORM,
ENABLE_MESSAGE_RATING,
ENABLE_SIGNUP,
ENABLE_API_KEY,
ENABLE_EVALUATION_ARENA_MODELS,
EVALUATION_ARENA_MODELS,
DEFAULT_ARENA_MODEL,
JWT_EXPIRES_IN,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES,
SHOW_ADMIN_DETAILS,
USER_PERMISSIONS,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_BANNERS,
ENABLE_LDAP,
LDAP_SERVER_LABEL,
LDAP_SERVER_HOST,
LDAP_SERVER_PORT,
LDAP_ATTRIBUTE_FOR_USERNAME,
LDAP_SEARCH_FILTERS,
LDAP_SEARCH_BASE,
LDAP_APP_DN,
LDAP_APP_PASSWORD,
LDAP_USE_TLS,
LDAP_CA_CERT_FILE,
LDAP_CIPHERS,
AppConfig,
)
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from open_webui.utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
)
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from open_webui.utils.tools import get_tools
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
app.state.config.ENABLE_LDAP = ENABLE_LDAP
app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL
app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST
app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT
app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME
app.state.config.LDAP_APP_DN = LDAP_APP_DN
app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD
app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
app.state.TOOLS = {}
app.state.FUNCTIONS = {}
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(folders.router, prefix="/folders", tags=["folders"])
app.include_router(groups.router, prefix="/groups", tags=["groups"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
@app.get("/")
async def get_status():
return {
"status": True,
"auth": WEBUI_AUTH,
"default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
}
async def get_all_models():
models = []
pipe_models = await get_pipe_models()
models = models + pipe_models
if app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
arena_models = []
if len(app.state.config.EVALUATION_ARENA_MODELS) > 0:
arena_models = [
{
"id": model["id"],
"name": model["name"],
"info": {
"meta": model["meta"],
},
"object": "model",
"created": int(time.time()),
"owned_by": "arena",
"arena": True,
}
for model in app.state.config.EVALUATION_ARENA_MODELS
]
else:
# Add default arena model
arena_models = [
{
"id": DEFAULT_ARENA_MODEL["id"],
"name": DEFAULT_ARENA_MODEL["name"],
"info": {
"meta": DEFAULT_ARENA_MODEL["meta"],
},
"object": "model",
"created": int(time.time()),
"owned_by": "arena",
"arena": True,
}
]
models = models + arena_models
return models
def get_function_module(pipe_id: str):
# Check if function is already loaded
if pipe_id not in app.state.FUNCTIONS:
function_module, _, _ = load_function_module_by_id(pipe_id)
app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
return function_module
async def get_pipe_models():
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = []
for pipe in pipes:
function_module = get_function_module(pipe.id)
# Check if function is a manifold
if hasattr(function_module, "pipes"):
sub_pipes = []
# Check if pipes is a function or a list
try:
if callable(function_module.pipes):
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
log.exception(e)
sub_pipes = []
log.debug(
f"get_pipe_models: function '{pipe.id}' is a manifold of {sub_pipes}"
)
for p in sub_pipes:
sub_pipe_id = f'{pipe.id}.{p["id"]}'
sub_pipe_name = p["name"]
if hasattr(function_module, "name"):
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
pipe_flag = {"type": pipe.type}
pipe_models.append(
{
"id": sub_pipe_id,
"name": sub_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
else:
pipe_flag = {"type": "pipe"}
log.debug(
f"get_pipe_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
)
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
return pipe_models
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
return res
if isinstance(res, Generator):
return "".join(map(str, res))
if isinstance(res, AsyncGenerator):
return "".join([str(stream) async for stream in res])
def process_line(form_data: dict, line):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except Exception:
pass
if line.startswith("data:"):
return f"{line}\n\n"
else:
line = openai_chat_chunk_message_template(form_data["model"], line)
return f"data: {json.dumps(line)}\n\n"
def get_pipe_id(form_data: dict) -> str:
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, _ = pipe_id.split(".", 1)
return pipe_id
def get_function_params(function_module, form_data, user, extra_params=None):
if extra_params is None:
extra_params = {}
pipe_id = get_pipe_id(form_data)
# Get the signature of the function
sig = inspect.signature(function_module.pipe)
params = {"body": form_data} | {
k: v for k, v in extra_params.items() if k in sig.parameters
}
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
return params
async def generate_function_chat_completion(form_data, user, models: dict = {}):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", {})
files = metadata.get("files", [])
tool_ids = metadata.get("tool_ids", [])
# Check if tool_ids is None
if tool_ids is None:
tool_ids = []
__event_emitter__ = None
__event_call__ = None
__task__ = None
__task_body__ = None
if metadata:
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
__event_emitter__ = get_event_emitter(metadata)
__event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None)
__task_body__ = metadata.get("task_body", None)
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__metadata__": metadata,
}
extra_params["__tools__"] = get_tools(
app,
tool_ids,
user,
{
**extra_params,
"__model__": models.get(form_data["model"], None),
"__messages__": form_data["messages"],
"__files__": files,
},
)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, user)
pipe_id = get_pipe_id(form_data)
function_module = get_function_module(pipe_id)
pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params)
if form_data.get("stream", False):
async def stream_content():
try:
res = await execute_pipe(pipe, params)
# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
async for data in res.body_iterator:
yield data
return
if isinstance(res, dict):
yield f"data: {json.dumps(res)}\n\n"
return
except Exception as e:
log.error(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str):
message = openai_chat_chunk_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
yield process_line(form_data, line)
if isinstance(res, AsyncGenerator):
async for line in res:
yield process_line(form_data, line)
if isinstance(res, str) or isinstance(res, Generator):
finish_message = openai_chat_chunk_message_template(
form_data["model"], ""
)
finish_message["choices"][0]["finish_reason"] = "stop"
yield f"data: {json.dumps(finish_message)}\n\n"
yield "data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
res = await execute_pipe(pipe, params)
except Exception as e:
log.error(f"Error: {e}")
return {"error": {"detail": str(e)}}
if isinstance(res, StreamingResponse) or isinstance(res, dict):
return res
if isinstance(res, BaseModel):
return res.model_dump()
message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)

View File

@ -10,7 +10,7 @@ from urllib.parse import urlparse
import chromadb
import requests
import yaml
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.env import (
OPEN_WEBUI_DIR,
DATA_DIR,
@ -21,6 +21,7 @@ from open_webui.env import (
WEBUI_NAME,
log,
DATABASE_URL,
OFFLINE_MODE,
)
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Integer, func
@ -306,6 +307,18 @@ GOOGLE_CLIENT_SECRET = PersistentConfig(
os.environ.get("GOOGLE_CLIENT_SECRET", ""),
)
GOOGLE_DRIVE_CLIENT_ID = PersistentConfig(
"GOOGLE_DRIVE_CLIENT_ID",
"google_drive.client_id",
os.environ.get("GOOGLE_DRIVE_CLIENT_ID", ""),
)
GOOGLE_DRIVE_API_KEY = PersistentConfig(
"GOOGLE_DRIVE_API_KEY",
"google_drive.api_key",
os.environ.get("GOOGLE_DRIVE_API_KEY", ""),
)
GOOGLE_OAUTH_SCOPE = PersistentConfig(
"GOOGLE_OAUTH_SCOPE",
"oauth.google.scope",
@ -402,12 +415,24 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
)
OAUTH_GROUPS_CLAIM = PersistentConfig(
"OAUTH_GROUPS_CLAIM",
"oauth.oidc.group_claim",
os.environ.get("OAUTH_GROUP_CLAIM", "groups"),
)
ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
"ENABLE_OAUTH_ROLE_MANAGEMENT",
"oauth.enable_role_mapping",
os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
)
ENABLE_OAUTH_GROUP_MANAGEMENT = PersistentConfig(
"ENABLE_OAUTH_GROUP_MANAGEMENT",
"oauth.enable_group_mapping",
os.environ.get("ENABLE_OAUTH_GROUP_MANAGEMENT", "False").lower() == "true",
)
OAUTH_ROLES_CLAIM = PersistentConfig(
"OAUTH_ROLES_CLAIM",
"oauth.roles_claim",
@ -429,6 +454,15 @@ OAUTH_ADMIN_ROLES = PersistentConfig(
[role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
)
OAUTH_ALLOWED_DOMAINS = PersistentConfig(
"OAUTH_ALLOWED_DOMAINS",
"oauth.allowed_domains",
[
domain.strip()
for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",")
],
)
def load_oauth_providers():
OAUTH_PROVIDERS.clear()
@ -686,6 +720,12 @@ OPENAI_API_BASE_URL = "https://api.openai.com/v1"
# WEBUI
####################################
WEBUI_URL = PersistentConfig(
"WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "http://localhost:3000")
)
ENABLE_SIGNUP = PersistentConfig(
"ENABLE_SIGNUP",
"ui.enable_signup",
@ -813,6 +853,12 @@ USER_PERMISSIONS = PersistentConfig(
},
)
ENABLE_CHANNELS = PersistentConfig(
"ENABLE_CHANNELS",
"channels.enable",
os.environ.get("ENABLE_CHANNELS", "False").lower() == "true",
)
ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig(
"ENABLE_EVALUATION_ARENA_MODELS",
@ -948,12 +994,45 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
)
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights
<chat_history>
{{MESSAGES:END:2}}
</chat_history>"""
TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"TAGS_GENERATION_PROMPT_TEMPLATE",
"task.tags.prompt_template",
os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
)
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE = """### Task:
Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags.
### Guidelines:
- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation
- If content is too short (less than 3 messages) or too diverse, use only ["General"]
- Use the chat's primary language; default to English if multilingual
- Prioritize accuracy over specificity
### Output:
JSON format: { "tags": ["tag1", "tag2", "tag3"] }
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>"""
ENABLE_TAGS_GENERATION = PersistentConfig(
"ENABLE_TAGS_GENERATION",
"task.tags.enable",
@ -1072,6 +1151,19 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
)
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
Message: ```{{prompt}}```"""
DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
Responses from models: {{responses}}"""
####################################
# Vector Database
####################################
@ -1123,6 +1215,15 @@ if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
# Information Retrieval (RAG)
####################################
# If configured, Google Drive will be available as an upload option.
ENABLE_GOOGLE_DRIVE_INTEGRATION = PersistentConfig(
"ENABLE_GOOGLE_DRIVE_INTEGRATION",
"google_drive.enable",
os.getenv("ENABLE_GOOGLE_DRIVE_INTEGRATION", "False").lower() == "true",
)
# RAG Content Extraction
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
@ -1197,7 +1298,8 @@ RAG_EMBEDDING_MODEL = PersistentConfig(
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
not OFFLINE_MODE
and os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
)
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
@ -1222,7 +1324,8 @@ if RAG_RERANKING_MODEL.value != "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
RAG_RERANKING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
not OFFLINE_MODE
and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
)
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
@ -1356,6 +1459,7 @@ RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
],
)
SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL",
"rag.web.search.searxng_query_url",
@ -1380,6 +1484,12 @@ BRAVE_SEARCH_API_KEY = PersistentConfig(
os.getenv("BRAVE_SEARCH_API_KEY", ""),
)
KAGI_SEARCH_API_KEY = PersistentConfig(
"KAGI_SEARCH_API_KEY",
"rag.web.search.kagi_search_api_key",
os.getenv("KAGI_SEARCH_API_KEY", ""),
)
MOJEEK_SEARCH_API_KEY = PersistentConfig(
"MOJEEK_SEARCH_API_KEY",
"rag.web.search.mojeek_search_api_key",
@ -1525,6 +1635,12 @@ COMFYUI_BASE_URL = PersistentConfig(
os.getenv("COMFYUI_BASE_URL", ""),
)
COMFYUI_API_KEY = PersistentConfig(
"COMFYUI_API_KEY",
"image_generation.comfyui.api_key",
os.getenv("COMFYUI_API_KEY", ""),
)
COMFYUI_DEFAULT_WORKFLOW = """
{
"3": {
@ -1686,7 +1802,8 @@ WHISPER_MODEL = PersistentConfig(
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
not OFFLINE_MODE
and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)

View File

@ -103,8 +103,6 @@ WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
@ -376,7 +374,7 @@ else:
AIOHTTP_CLIENT_TIMEOUT = 300
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "5"
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
)
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":

View File

@ -0,0 +1,316 @@
import logging
import sys
import inspect
import json
from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
from fastapi import (
Depends,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
status,
)
from starlette.responses import Response, StreamingResponse
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
)
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
from open_webui.utils.misc import (
add_or_update_system_message,
get_last_user_message,
prepend_to_first_user_message_content,
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
)
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module_by_id(request: Request, pipe_id: str):
# Check if function is already loaded
if pipe_id not in request.app.state.FUNCTIONS:
function_module, _, _ = load_function_module_by_id(pipe_id)
request.app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = request.app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
return function_module
async def get_function_models(request):
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = []
for pipe in pipes:
function_module = get_function_module_by_id(request, pipe.id)
# Check if function is a manifold
if hasattr(function_module, "pipes"):
sub_pipes = []
# Check if pipes is a function or a list
try:
if callable(function_module.pipes):
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
log.exception(e)
sub_pipes = []
log.debug(
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
)
for p in sub_pipes:
sub_pipe_id = f'{pipe.id}.{p["id"]}'
sub_pipe_name = p["name"]
if hasattr(function_module, "name"):
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
pipe_flag = {"type": pipe.type}
pipe_models.append(
{
"id": sub_pipe_id,
"name": sub_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
else:
pipe_flag = {"type": "pipe"}
log.debug(
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
)
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
return pipe_models
async def generate_function_chat_completion(
request, form_data, user, models: dict = {}
):
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
return res
if isinstance(res, Generator):
return "".join(map(str, res))
if isinstance(res, AsyncGenerator):
return "".join([str(stream) async for stream in res])
def process_line(form_data: dict, line):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except Exception:
pass
if line.startswith("data:"):
return f"{line}\n\n"
else:
line = openai_chat_chunk_message_template(form_data["model"], line)
return f"data: {json.dumps(line)}\n\n"
def get_pipe_id(form_data: dict) -> str:
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, _ = pipe_id.split(".", 1)
return pipe_id
def get_function_params(function_module, form_data, user, extra_params=None):
if extra_params is None:
extra_params = {}
pipe_id = get_pipe_id(form_data)
# Get the signature of the function
sig = inspect.signature(function_module.pipe)
params = {"body": form_data} | {
k: v for k, v in extra_params.items() if k in sig.parameters
}
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
return params
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", {})
files = metadata.get("files", [])
tool_ids = metadata.get("tool_ids", [])
# Check if tool_ids is None
if tool_ids is None:
tool_ids = []
__event_emitter__ = None
__event_call__ = None
__task__ = None
__task_body__ = None
if metadata:
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
__event_emitter__ = get_event_emitter(metadata)
__event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None)
__task_body__ = metadata.get("task_body", None)
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__metadata__": metadata,
"__request__": request,
}
extra_params["__tools__"] = get_tools(
request,
tool_ids,
user,
{
**extra_params,
"__model__": models.get(form_data["model"], None),
"__messages__": form_data["messages"],
"__files__": files,
},
)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, user)
pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(request, pipe_id)
pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params)
if form_data.get("stream", False):
async def stream_content():
try:
res = await execute_pipe(pipe, params)
# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
async for data in res.body_iterator:
yield data
return
if isinstance(res, dict):
yield f"data: {json.dumps(res)}\n\n"
return
except Exception as e:
log.error(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str):
message = openai_chat_chunk_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
yield process_line(form_data, line)
if isinstance(res, AsyncGenerator):
async for line in res:
yield process_line(form_data, line)
if isinstance(res, str) or isinstance(res, Generator):
finish_message = openai_chat_chunk_message_template(
form_data["model"], ""
)
finish_message["choices"][0]["finish_reason"] = "stop"
yield f"data: {json.dumps(finish_message)}\n\n"
yield "data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
res = await execute_pipe(pipe, params)
except Exception as e:
log.error(f"Error: {e}")
return {"error": {"detail": str(e)}}
if isinstance(res, StreamingResponse) or isinstance(res, dict):
return res
if isinstance(res, BaseModel):
return res.model_dump()
message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)

View File

@ -3,7 +3,7 @@ import logging
from contextlib import contextmanager
from typing import Any, Optional
from open_webui.apps.webui.internal.wrappers import register_connection
from open_webui.internal.wrappers import register_connection
from open_webui.env import (
OPEN_WEBUI_DIR,
DATABASE_URL,
@ -54,7 +54,7 @@ def handle_peewee_migration(DATABASE_URL):
try:
# Replace the postgresql:// with postgres:// to handle the peewee migration
db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
migrate_dir = OPEN_WEBUI_DIR / "apps" / "webui" / "internal" / "migrations"
migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
router = Router(db, logger=log, migrate_dir=migrate_dir)
router.run()
db.close()

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,7 @@
from logging.config import fileConfig
from alembic import context
from open_webui.apps.webui.models.auths import Auth
from open_webui.models.auths import Auth
from open_webui.env import DATABASE_URL
from sqlalchemy import engine_from_config, pool

View File

@ -9,7 +9,7 @@ from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.apps.webui.internal.db
import open_webui.internal.db
${imports if imports else ""}
# revision identifiers, used by Alembic.

View File

@ -0,0 +1,48 @@
"""Add channel table
Revision ID: 57c599a3cb57
Revises: 922e7a387820
Create Date: 2024-12-22 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "57c599a3cb57"
down_revision = "922e7a387820"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"channel",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text()),
sa.Column("name", sa.Text()),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("access_control", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
)
op.create_table(
"message",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text()),
sa.Column("channel_id", sa.Text(), nullable=True),
sa.Column("content", sa.Text()),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
)
def downgrade():
op.drop_table("channel")
op.drop_table("message")

View File

@ -0,0 +1,26 @@
"""Update file table
Revision ID: 7826ab40b532
Revises: 57c599a3cb57
Create Date: 2024-12-23 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "7826ab40b532"
down_revision = "57c599a3cb57"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
"file",
sa.Column("access_control", sa.JSON(), nullable=True),
)
def downgrade():
op.drop_column("file", "access_control")

View File

@ -11,8 +11,8 @@ from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
import open_webui.apps.webui.internal.db
from open_webui.apps.webui.internal.db import JSONField
import open_webui.internal.db
from open_webui.internal.db import JSONField
from open_webui.migrations.util import get_existing_tables
# revision identifiers, used by Alembic.

View File

@ -2,12 +2,12 @@ import logging
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.internal.db import Base, get_db
from open_webui.models.users import UserModel, Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel
from sqlalchemy import Boolean, Column, String, Text
from open_webui.utils.utils import verify_password
from open_webui.utils.auth import verify_password
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -0,0 +1,132 @@
import json
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.utils.access_control import has_access
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
# Channel DB Schema
####################
class Channel(Base):
__tablename__ = "channel"
id = Column(Text, primary_key=True)
user_id = Column(Text)
name = Column(Text)
description = Column(Text, nullable=True)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class ChannelModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
description: Optional[str] = None
name: str
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
# Forms
####################
class ChannelForm(BaseModel):
name: str
description: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
class ChannelTable:
def insert_new_channel(
self, form_data: ChannelForm, user_id: str
) -> Optional[ChannelModel]:
with get_db() as db:
channel = ChannelModel(
**{
**form_data.model_dump(),
"name": form_data.name.lower(),
"id": str(uuid.uuid4()),
"user_id": user_id,
"created_at": int(time.time_ns()),
"updated_at": int(time.time_ns()),
}
)
new_channel = Channel(**channel.model_dump())
db.add(new_channel)
db.commit()
return channel
def get_channels(self) -> list[ChannelModel]:
with get_db() as db:
channels = db.query(Channel).all()
return [ChannelModel.model_validate(channel) for channel in channels]
def get_channels_by_user_id(
self, user_id: str, permission: str = "read"
) -> list[ChannelModel]:
channels = self.get_channels()
return [
channel
for channel in channels
if channel.user_id == user_id
or has_access(user_id, permission, channel.access_control)
]
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
with get_db() as db:
channel = db.query(Channel).filter(Channel.id == id).first()
return ChannelModel.model_validate(channel) if channel else None
def update_channel_by_id(
self, id: str, form_data: ChannelForm
) -> Optional[ChannelModel]:
with get_db() as db:
channel = db.query(Channel).filter(Channel.id == id).first()
if not channel:
return None
channel.name = form_data.name
channel.data = form_data.data
channel.meta = form_data.meta
channel.access_control = form_data.access_control
channel.updated_at = int(time.time_ns())
db.commit()
return ChannelModel.model_validate(channel) if channel else None
def delete_channel_by_id(self, id: str):
with get_db() as db:
db.query(Channel).filter(Channel.id == id).delete()
db.commit()
return True
Channels = ChannelTable()

View File

@ -3,8 +3,8 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from pydantic import BaseModel, ConfigDict
@ -168,6 +168,91 @@ class ChatTable:
except Exception:
return None
def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
chat["title"] = title
return self.update_chat_by_id(id, chat)
def update_chat_tags_by_id(
self, id: str, tags: list[str], user
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
self.delete_all_tags_by_id_and_user_id(id, user.id)
for tag in chat.meta.get("tags", []):
if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
for tag_name in tags:
if tag_name.lower() == "none":
continue
self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)
return self.get_chat_by_id(id)
def get_chat_title_by_id(self, id: str) -> Optional[str]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
return chat.chat.get("title", "New Chat")
def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
return chat.chat.get("history", {}).get("messages", {}) or {}
def upsert_message_to_chat_by_id_and_message_id(
self, id: str, message_id: str, message: dict
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
history = chat.get("history", {})
if message_id in history.get("messages", {}):
history["messages"][message_id] = {
**history["messages"][message_id],
**message,
}
else:
history["messages"][message_id] = message
history["currentId"] = message_id
chat["history"] = history
return self.update_chat_by_id(id, chat)
def add_message_status_to_chat_by_id_and_message_id(
self, id: str, message_id: str, status: dict
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
history = chat.get("history", {})
if message_id in history.get("messages", {}):
status_history = history["messages"][message_id].get("statusHistory", [])
status_history.append(status)
history["messages"][message_id]["statusHistory"] = status_history
chat["history"] = history
return self.update_chat_by_id(id, chat)
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db:
# Get the existing chat to share

View File

@ -3,8 +3,8 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.internal.db import Base, get_db
from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict

View File

@ -2,7 +2,7 @@ import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
@ -27,6 +27,8 @@ class File(Base):
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
@ -44,6 +46,8 @@ class FileModel(BaseModel):
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
created_at: Optional[int] # timestamp in epoch
updated_at: Optional[int] # timestamp in epoch
@ -90,6 +94,7 @@ class FileForm(BaseModel):
path: str
data: dict = {}
meta: dict = {}
access_control: Optional[dict] = None
class FilesTable:

View File

@ -3,8 +3,8 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.internal.db import Base, get_db
from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict

View File

@ -2,8 +2,8 @@ import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.apps.webui.models.users import Users
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text

View File

@ -4,10 +4,10 @@ import time
from typing import Optional
import uuid
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.files import FileMetadataResponse
from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict
@ -146,6 +146,13 @@ class GroupTable:
except Exception:
return None
def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
group = self.get_group_by_id(id)
if group:
return group.user_ids
else:
return None
def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
) -> Optional[GroupModel]:

View File

@ -4,11 +4,11 @@ import time
from typing import Optional
import uuid
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.files import FileMetadataResponse
from open_webui.apps.webui.models.users import Users, UserResponse
from open_webui.models.files import FileMetadataResponse
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict

View File

@ -2,7 +2,7 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text

View File

@ -0,0 +1,141 @@
import json
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
# Message DB Schema
####################
class Message(Base):
__tablename__ = "message"
id = Column(Text, primary_key=True)
user_id = Column(Text)
channel_id = Column(Text, nullable=True)
content = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
created_at = Column(BigInteger) # time_ns
updated_at = Column(BigInteger) # time_ns
class MessageModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
channel_id: Optional[str] = None
content: str
data: Optional[dict] = None
meta: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
# Forms
####################
class MessageForm(BaseModel):
content: str
data: Optional[dict] = None
meta: Optional[dict] = None
class MessageTable:
def insert_new_message(
self, form_data: MessageForm, channel_id: str, user_id: str
) -> Optional[MessageModel]:
with get_db() as db:
id = str(uuid.uuid4())
ts = int(time.time_ns())
message = MessageModel(
**{
"id": id,
"user_id": user_id,
"channel_id": channel_id,
"content": form_data.content,
"data": form_data.data,
"meta": form_data.meta,
"created_at": ts,
"updated_at": ts,
}
)
result = Message(**message.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return MessageModel.model_validate(result) if result else None
def get_message_by_id(self, id: str) -> Optional[MessageModel]:
with get_db() as db:
message = db.get(Message, id)
return MessageModel.model_validate(message) if message else None
def get_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50
) -> list[MessageModel]:
with get_db() as db:
all_messages = (
db.query(Message)
.filter_by(channel_id=channel_id)
.order_by(Message.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return [MessageModel.model_validate(message) for message in all_messages]
def get_messages_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> list[MessageModel]:
with get_db() as db:
all_messages = (
db.query(Message)
.filter_by(user_id=user_id)
.order_by(Message.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return [MessageModel.model_validate(message) for message in all_messages]
def update_message_by_id(
self, id: str, form_data: MessageForm
) -> Optional[MessageModel]:
with get_db() as db:
message = db.get(Message, id)
message.content = form_data.content
message.data = form_data.data
message.meta = form_data.meta
message.updated_at = int(time.time_ns())
db.commit()
db.refresh(message)
return MessageModel.model_validate(message) if message else None
def delete_message_by_id(self, id: str) -> bool:
with get_db() as db:
db.query(Message).filter_by(id=id).delete()
db.commit()
return True
Messages = MessageTable()

View File

@ -2,10 +2,10 @@ import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.users import Users, UserResponse
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict

View File

@ -1,8 +1,8 @@
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.users import Users, UserResponse
from open_webui.internal.db import Base, get_db
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON

View File

@ -3,7 +3,7 @@ import time
import uuid
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS

View File

@ -2,8 +2,8 @@ import logging
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.apps.webui.models.users import Users, UserResponse
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users, UserResponse
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON

View File

@ -1,8 +1,8 @@
import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.apps.webui.models.chats import Chats
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.chats import Chats
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
@ -70,6 +70,13 @@ class UserResponse(BaseModel):
profile_image_url: str
class UserNameResponse(BaseModel):
id: str
name: str
role: str
profile_image_url: str
class UserRoleUpdateForm(BaseModel):
id: str
role: str
@ -147,13 +154,25 @@ class UsersTable:
except Exception:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> list[UserModel]:
def get_users(
self, skip: Optional[int] = None, limit: Optional[int] = None
) -> list[UserModel]:
with get_db() as db:
users = (
db.query(User)
# .offset(skip).limit(limit)
.all()
)
query = db.query(User).order_by(User.created_at.desc())
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
users = query.all()
return [UserModel.model_validate(user) for user in users]
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
with get_db() as db:
users = db.query(User).filter(User.id.in_(user_ids)).all()
return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]:
@ -168,6 +187,22 @@ class UsersTable:
except Exception:
return None
def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
try:
with get_db() as db:
user = db.query(User).filter_by(id=id).first()
if user.settings is None:
return None
else:
return (
user.settings.get("ui", {})
.get("notifications", {})
.get("webhook_url", None)
)
except Exception:
return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try:
with get_db() as db:

View File

@ -1,6 +1,7 @@
import requests
import logging
import ftfy
import sys
from langchain_community.document_loaders import (
BSHTMLLoader,
@ -18,8 +19,9 @@ from langchain_community.document_loaders import (
YoutubeLoader,
)
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -106,7 +108,7 @@ class TikaLoader:
if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"]
log.info("Tika extracted text: %s", text)
log.debug("Tika extracted text: %s", text)
return [Document(page_content=text, metadata=headers)]
else:

View File

@ -11,7 +11,7 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.env import SRC_LOG_LEVELS
@ -70,7 +70,9 @@ def query_doc(
limit=k,
)
log.info(f"query_doc:result {result.ids} {result.metadatas}")
if result:
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
except Exception as e:
print(e)

View File

@ -0,0 +1,22 @@
from open_webui.config import VECTOR_DB
if VECTOR_DB == "milvus":
from open_webui.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
elif VECTOR_DB == "opensearch":
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
VECTOR_DB_CLIENT = OpenSearchClient()
elif VECTOR_DB == "pgvector":
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
VECTOR_DB_CLIENT = PgvectorClient()
else:
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
VECTOR_DB_CLIENT = ChromaClient()

View File

@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,

View File

@ -4,7 +4,7 @@ import json
from typing import Optional
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
MILVUS_URI,
)

View File

@ -1,7 +1,7 @@
from opensearchpy import OpenSearch
from typing import Optional
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
OPENSEARCH_URI,
OPENSEARCH_SSL,

View File

@ -18,7 +18,7 @@ from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import PGVECTOR_DB_URL
VECTOR_LENGTH = 1536
@ -40,7 +40,7 @@ class PgvectorClient:
# if no pgvector uri, use the existing database connection
if not PGVECTOR_DB_URL:
from open_webui.apps.webui.internal.db import Session
from open_webui.internal.db import Session
self.session = Session
else:

View File

@ -4,7 +4,7 @@ from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
NO_LIMIT = 999999999

View File

@ -3,7 +3,7 @@ import os
from pprint import pprint
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
import argparse

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -1,7 +1,7 @@
import logging
from typing import Optional
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
from open_webui.env import SRC_LOG_LEVELS

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -1,7 +1,7 @@
import logging
import requests
from open_webui.apps.retrieval.web.main import SearchResult
from open_webui.retrieval.web.main import SearchResult
from open_webui.env import SRC_LOG_LEVELS
from yarl import URL

View File

@ -0,0 +1,48 @@
import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_kagi(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""Search using Kagi's Search API and return the results as a list of SearchResult objects.
The Search API will inherit the settings in your account, including results personalization and snippet length.
Args:
api_key (str): A Kagi Search API key
query (str): The query to search for
count (int): The number of results to return
"""
url = "https://kagi.com/api/v0/search"
headers = {
"Authorization": f"Bot {api_key}",
}
params = {"q": query, "limit": count}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
search_results = json_response.get("data", [])
results = [
SearchResult(
link=result["url"], title=result["title"], snippet=result.get("snippet")
)
for result in search_results
if result["t"] == 0
]
print(results)
if filter_list:
results = get_filtered_results(results, filter_list)
return results

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -3,7 +3,7 @@ from typing import Optional
from urllib.parse import urlencode
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -3,7 +3,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -3,7 +3,7 @@ from typing import Optional
from urllib.parse import urlencode
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -1,7 +1,7 @@
import logging
import requests
from open_webui.apps.retrieval.web.main import SearchResult
from open_webui.retrieval.web.main import SearchResult
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)

View File

@ -82,15 +82,15 @@ class SafeWebBaseLoader(WebBaseLoader):
def get_web_loader(
url: Union[str, Sequence[str]],
urls: Union[str, Sequence[str]],
verify_ssl: bool = True,
requests_per_second: int = 2,
):
# Check if the URL is valid
if not validate_url(url):
if not validate_url(urls):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return SafeWebBaseLoader(
url,
urls,
verify_ssl=verify_ssl,
requests_per_second=requests_per_second,
continue_on_failure=True,

View File

@ -0,0 +1,713 @@
import hashlib
import json
import logging
import os
import uuid
from functools import lru_cache
from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
import aiohttp
import aiofiles
import requests
from fastapi import (
Depends,
FastAPI,
File,
HTTPException,
Request,
UploadFile,
status,
APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import (
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
CACHE_DIR,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
DEVICE_TYPE,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
router = APIRouter()
# Constants
MAX_FILE_SIZE_MB = 25
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
##########################################
#
# Utility functions
#
##########################################
from pydub import AudioSegment
from pydub.utils import mediainfo
def is_mp4_audio(file_path):
"""Check if the given file is an MP4 audio file."""
if not os.path.isfile(file_path):
print(f"File not found: {file_path}")
return False
info = mediainfo(file_path)
if (
info.get("codec_name") == "aac"
and info.get("codec_type") == "audio"
and info.get("codec_tag_string") == "mp4a"
):
return True
return False
def convert_mp4_to_wav(file_path, output_path):
"""Convert MP4 audio file to WAV format."""
audio = AudioSegment.from_file(file_path, format="mp4")
audio.export(output_path, format="wav")
print(f"Converted {file_path} to {output_path}")
def set_faster_whisper_model(model: str, auto_update: bool = False):
whisper_model = None
if model:
from faster_whisper import WhisperModel
faster_whisper_kwargs = {
"model_size_or_path": model,
"device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not auto_update,
}
try:
whisper_model = WhisperModel(**faster_whisper_kwargs)
except Exception:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
faster_whisper_kwargs["local_files_only"] = False
whisper_model = WhisperModel(**faster_whisper_kwargs)
return whisper_model
##########################################
#
# Audio API
#
##########################################
class TTSConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
API_KEY: str
ENGINE: str
MODEL: str
VOICE: str
SPLIT_ON: str
AZURE_SPEECH_REGION: str
AZURE_SPEECH_OUTPUT_FORMAT: str
class STTConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
ENGINE: str
MODEL: str
WHISPER_MODEL: str
class AudioConfigUpdateForm(BaseModel):
tts: TTSConfigForm
stt: STTConfigForm
@router.get("/config")
async def get_audio_config(request: Request, user=Depends(get_admin_user)):
return {
"tts": {
"OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
"API_KEY": request.app.state.config.TTS_API_KEY,
"ENGINE": request.app.state.config.TTS_ENGINE,
"MODEL": request.app.state.config.TTS_MODEL,
"VOICE": request.app.state.config.TTS_VOICE,
"SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
"AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
"AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
},
"stt": {
"OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
"ENGINE": request.app.state.config.STT_ENGINE,
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
},
}
@router.post("/config/update")
async def update_audio_config(
request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
request.app.state.config.TTS_MODEL = form_data.tts.MODEL
request.app.state.config.TTS_VOICE = form_data.tts.VOICE
request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
)
request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
request.app.state.config.STT_MODEL = form_data.stt.MODEL
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
if request.app.state.config.STT_ENGINE == "":
request.app.state.faster_whisper_model = set_faster_whisper_model(
form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
)
return {
"tts": {
"OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
"API_KEY": request.app.state.config.TTS_API_KEY,
"ENGINE": request.app.state.config.TTS_ENGINE,
"MODEL": request.app.state.config.TTS_MODEL,
"VOICE": request.app.state.config.TTS_VOICE,
"SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
"AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
"AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
},
"stt": {
"OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
"ENGINE": request.app.state.config.STT_ENGINE,
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
},
}
def load_speech_pipeline(request):
from transformers import pipeline
from datasets import load_dataset
if request.app.state.speech_synthesiser is None:
request.app.state.speech_synthesiser = pipeline(
"text-to-speech", "microsoft/speecht5_tts"
)
if request.app.state.speech_speaker_embeddings_dataset is None:
request.app.state.speech_speaker_embeddings_dataset = load_dataset(
"Matthijs/cmu-arctic-xvectors", split="validation"
)
@router.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
body = await request.body()
name = hashlib.sha256(
body
+ str(request.app.state.config.TTS_ENGINE).encode("utf-8")
+ str(request.app.state.config.TTS_MODEL).encode("utf-8")
).hexdigest()
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
# Check if the file already exists in the cache
if file_path.is_file():
return FileResponse(file_path)
payload = None
try:
payload = json.loads(body.decode("utf-8"))
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
if request.app.state.config.TTS_ENGINE == "openai":
payload["model"] = request.app.state.config.TTS_MODEL
try:
# print(payload)
async with aiohttp.ClientSession() as session:
async with session.post(
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
json=payload,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
) as r:
r.raise_for_status()
async with aiofiles.open(file_path, "wb") as f:
await f.write(await r.read())
async with aiofiles.open(file_body_path, "w") as f:
await f.write(json.dumps(payload))
return FileResponse(file_path)
except Exception as e:
log.exception(e)
detail = None
try:
if r.status != 200:
res = await r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
detail=detail if detail else "Open WebUI: Server Connection Error",
)
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
voice_id = payload.get("voice", "")
if voice_id not in get_available_voices():
raise HTTPException(
status_code=400,
detail="Invalid voice id",
)
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
json={
"text": payload["input"],
"model_id": request.app.state.config.TTS_MODEL,
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
},
headers={
"Accept": "audio/mpeg",
"Content-Type": "application/json",
"xi-api-key": request.app.state.config.TTS_API_KEY,
},
) as r:
r.raise_for_status()
async with aiofiles.open(file_path, "wb") as f:
await f.write(await r.read())
async with aiofiles.open(file_body_path, "w") as f:
await f.write(json.dumps(payload))
return FileResponse(file_path)
except Exception as e:
log.exception(e)
detail = None
try:
if r.status != 200:
res = await r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
detail=detail if detail else "Open WebUI: Server Connection Error",
)
elif request.app.state.config.TTS_ENGINE == "azure":
try:
payload = json.loads(body.decode("utf-8"))
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
region = request.app.state.config.TTS_AZURE_SPEECH_REGION
language = request.app.state.config.TTS_VOICE
locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
try:
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
<voice name="{language}">{payload["input"]}</voice>
</speak>"""
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
headers={
"Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": output_format,
},
data=data,
) as r:
r.raise_for_status()
async with aiofiles.open(file_path, "wb") as f:
await f.write(await r.read())
async with aiofiles.open(file_body_path, "w") as f:
await f.write(json.dumps(payload))
return FileResponse(file_path)
except Exception as e:
log.exception(e)
detail = None
try:
if r.status != 200:
res = await r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
detail=detail if detail else "Open WebUI: Server Connection Error",
)
elif request.app.state.config.TTS_ENGINE == "transformers":
payload = None
try:
payload = json.loads(body.decode("utf-8"))
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
import torch
import soundfile as sf
load_speech_pipeline(request)
embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
speaker_index = 6799
try:
speaker_index = embeddings_dataset["filename"].index(
request.app.state.config.TTS_MODEL
)
except Exception:
pass
speaker_embedding = torch.tensor(
embeddings_dataset[speaker_index]["xvector"]
).unsqueeze(0)
speech = request.app.state.speech_synthesiser(
payload["input"],
forward_params={"speaker_embeddings": speaker_embedding},
)
sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
async with aiofiles.open(file_body_path, "w") as f:
await f.write(json.dumps(payload))
return FileResponse(file_path)
def transcribe(request: Request, file_path):
print("transcribe", file_path)
filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path)
id = filename.split(".")[0]
if request.app.state.config.STT_ENGINE == "":
if request.app.state.faster_whisper_model is None:
request.app.state.faster_whisper_model = set_faster_whisper_model(
request.app.state.config.WHISPER_MODEL
)
model = request.app.state.faster_whisper_model
segments, info = model.transcribe(file_path, beam_size=5)
log.info(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
transcript = "".join([segment.text for segment in list(segments)])
data = {"text": transcript.strip()}
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
log.debug(data)
return data
elif request.app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path):
os.rename(file_path, file_path.replace(".wav", ".mp4"))
# Convert MP4 audio file to WAV format
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
r = None
try:
r = requests.post(
url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
headers={
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
},
files={"file": (filename, open(file_path, "rb"))},
data={"model": request.app.state.config.STT_MODEL},
)
r.raise_for_status()
data = r.json()
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
return data
except Exception as e:
log.exception(e)
detail = None
if r is not None:
try:
res = r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
detail = f"External: {e}"
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
def compress_audio(file_path):
if os.path.getsize(file_path) > MAX_FILE_SIZE:
file_dir = os.path.dirname(file_path)
audio = AudioSegment.from_file(file_path)
audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
compressed_path = f"{file_dir}/{id}_compressed.opus"
audio.export(compressed_path, format="opus", bitrate="32k")
log.debug(f"Compressed audio to {compressed_path}")
if (
os.path.getsize(compressed_path) > MAX_FILE_SIZE
): # Still larger than MAX_FILE_SIZE after compression
raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
return compressed_path
else:
return file_path
@router.post("/transcriptions")
def transcription(
request: Request,
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}")
if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
)
try:
ext = file.filename.split(".")[-1]
id = uuid.uuid4()
filename = f"{id}.{ext}"
contents = file.file.read()
file_dir = f"{CACHE_DIR}/audio/transcriptions"
os.makedirs(file_dir, exist_ok=True)
file_path = f"{file_dir}/{filename}"
with open(file_path, "wb") as f:
f.write(contents)
try:
try:
file_path = compress_audio(file_path)
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
data = transcribe(request, file_path)
file_path = file_path.split("/")[-1]
return {**data, "filename": file_path}
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
def get_available_models(request: Request) -> list[dict]:
available_models = []
if request.app.state.config.TTS_ENGINE == "openai":
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try:
response = requests.get(
"https://api.elevenlabs.io/v1/models",
headers={
"xi-api-key": request.app.state.config.TTS_API_KEY,
"Content-Type": "application/json",
},
timeout=5,
)
response.raise_for_status()
models = response.json()
available_models = [
{"name": model["name"], "id": model["model_id"]} for model in models
]
except requests.RequestException as e:
log.error(f"Error fetching voices: {str(e)}")
return available_models
@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
return {"models": get_available_models(request)}
def get_available_voices(request) -> dict:
"""Returns {voice_id: voice_name} dict"""
available_voices = {}
if request.app.state.config.TTS_ENGINE == "openai":
available_voices = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try:
available_voices = get_elevenlabs_voices(
api_key=request.app.state.config.TTS_API_KEY
)
except Exception:
# Avoided @lru_cache with exception
pass
elif request.app.state.config.TTS_ENGINE == "azure":
try:
region = request.app.state.config.TTS_AZURE_SPEECH_REGION
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
headers = {
"Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
}
response = requests.get(url, headers=headers)
response.raise_for_status()
voices = response.json()
for voice in voices:
available_voices[voice["ShortName"]] = (
f"{voice['DisplayName']} ({voice['ShortName']})"
)
except requests.RequestException as e:
log.error(f"Error fetching voices: {str(e)}")
return available_voices
@lru_cache
def get_elevenlabs_voices(api_key: str) -> dict:
"""
Note, set the following in your .env file to use Elevenlabs:
AUDIO_TTS_ENGINE=elevenlabs
AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key
AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
AUDIO_TTS_MODEL=eleven_multilingual_v2
"""
try:
# TODO: Add retries
response = requests.get(
"https://api.elevenlabs.io/v1/voices",
headers={
"xi-api-key": api_key,
"Content-Type": "application/json",
},
)
response.raise_for_status()
voices_data = response.json()
voices = {}
for voice in voices_data.get("voices", []):
voices[voice["voice_id"]] = voice["name"]
except requests.RequestException as e:
# Avoid @lru_cache with exception
log.error(f"Error fetching voices: {str(e)}")
raise RuntimeError(f"Error fetching voices: {str(e)}")
return voices
@router.get("/voices")
async def get_voices(request: Request, user=Depends(get_verified_user)):
return {
"voices": [
{"id": k, "name": v} for k, v in get_available_voices(request).items()
]
}

View File

@ -3,8 +3,9 @@ import uuid
import time
import datetime
import logging
from aiohttp import ClientSession
from open_webui.apps.webui.models.auths import (
from open_webui.models.auths import (
AddUserForm,
ApiKey,
Auths,
@ -17,7 +18,7 @@ from open_webui.apps.webui.models.auths import (
UpdateProfileForm,
UserResponse,
)
from open_webui.apps.webui.models.users import Users
from open_webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import (
@ -29,10 +30,14 @@ from open_webui.env import (
SRC_LOG_LEVELS,
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import Response
from fastapi.responses import RedirectResponse, Response
from open_webui.config import (
OPENID_PROVIDER_URL,
ENABLE_OAUTH_SIGNUP,
)
from pydantic import BaseModel
from open_webui.utils.misc import parse_duration, validate_email_format
from open_webui.utils.utils import (
from open_webui.utils.auth import (
create_api_key,
create_token,
get_admin_user,
@ -498,8 +503,31 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@router.get("/signout")
async def signout(response: Response):
async def signout(request: Request, response: Response):
response.delete_cookie("token")
if ENABLE_OAUTH_SIGNUP.value:
oauth_id_token = request.cookies.get("oauth_id_token")
if oauth_id_token:
try:
async with ClientSession() as session:
async with session.get(OPENID_PROVIDER_URL.value) as resp:
if resp.status == 200:
openid_data = await resp.json()
logout_url = openid_data.get("end_session_endpoint")
if logout_url:
response.delete_cookie("oauth_id_token")
return RedirectResponse(
url=f"{logout_url}?id_token_hint={oauth_id_token}"
)
else:
raise HTTPException(
status_code=resp.status,
detail="Failed to fetch OpenID configuration",
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
return {"status": True}
@ -519,7 +547,6 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try:
print(form_data)
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(
form_data.email.lower(),
@ -586,8 +613,10 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
async def get_admin_config(request: Request, user=Depends(get_admin_user)):
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"WEBUI_URL": request.app.state.config.WEBUI_URL,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
@ -597,8 +626,10 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
class AdminConfig(BaseModel):
SHOW_ADMIN_DETAILS: bool
WEBUI_URL: str
ENABLE_SIGNUP: bool
ENABLE_API_KEY: bool
ENABLE_CHANNELS: bool
DEFAULT_USER_ROLE: str
JWT_EXPIRES_IN: str
ENABLE_COMMUNITY_SHARING: bool
@ -610,8 +641,10 @@ async def update_admin_config(
request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
):
request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
request.app.state.config.WEBUI_URL = form_data.WEBUI_URL
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
@ -629,8 +662,10 @@ async def update_admin_config(
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"WEBUI_URL": request.app.state.config.WEBUI_URL,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,

View File

@ -0,0 +1,394 @@
import json
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
from pydantic import BaseModel
from open_webui.socket.main import sio, get_user_ids_from_room
from open_webui.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels, ChannelModel, ChannelForm
from open_webui.models.messages import Messages, MessageModel, MessageForm
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, get_users_with_access
from open_webui.utils.webhook import post_webhook
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
# GetChatList
############################
@router.get("/", response_model=list[ChannelModel])
async def get_channels(user=Depends(get_verified_user)):
if user.role == "admin":
return Channels.get_channels()
else:
return Channels.get_channels_by_user_id(user.id)
############################
# CreateNewChannel
############################
@router.post("/create", response_model=Optional[ChannelModel])
async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)):
try:
channel = Channels.insert_new_channel(form_data, user.id)
return ChannelModel(**channel.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChannelById
############################
@router.get("/{id}", response_model=Optional[ChannelModel])
async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
return ChannelModel(**channel.model_dump())
############################
# UpdateChannelById
############################
@router.post("/{id}/update", response_model=Optional[ChannelModel])
async def update_channel_by_id(
id: str, form_data: ChannelForm, user=Depends(get_admin_user)
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
try:
channel = Channels.update_channel_by_id(id, form_data)
return ChannelModel(**channel.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# DeleteChannelById
############################
@router.delete("/{id}/delete", response_model=bool)
async def delete_channel_by_id(id: str, user=Depends(get_admin_user)):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
try:
Channels.delete_channel_by_id(id)
return True
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChannelMessages
############################
class MessageUserModel(MessageModel):
user: UserNameResponse
@router.get("/{id}/messages", response_model=list[MessageUserModel])
async def get_channel_messages(
id: str, skip: int = 0, limit: int = 50, user=Depends(get_verified_user)
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message_list = Messages.get_messages_by_channel_id(id, skip, limit)
users = {}
messages = []
for message in message_list:
if message.user_id not in users:
user = Users.get_user_by_id(message.user_id)
users[message.user_id] = user
messages.append(
MessageUserModel(
**{
**message.model_dump(),
"user": UserNameResponse(**users[message.user_id].model_dump()),
}
)
)
return messages
############################
# PostNewMessage
############################
async def send_notification(webui_url, channel, message, active_user_ids):
users = get_users_with_access("read", channel.access_control)
for user in users:
if user.id in active_user_ids:
continue
else:
if user.settings:
webhook_url = user.settings.ui.get("notifications", {}).get(
"webhook_url", None
)
if webhook_url:
post_webhook(
webhook_url,
f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
{
"action": "channel",
"message": message.content,
"title": channel.name,
"url": f"{webui_url}/channels/{channel.id}",
},
)
@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
async def post_new_message(
request: Request,
id: str,
form_data: MessageForm,
background_tasks: BackgroundTasks,
user=Depends(get_verified_user),
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try:
message = Messages.insert_new_message(form_data, channel.id, user.id)
if message:
event_data = {
"channel_id": channel.id,
"message_id": message.id,
"data": {
"type": "message",
"data": {
**message.model_dump(),
"user": UserNameResponse(**user.model_dump()).model_dump(),
},
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
}
await sio.emit(
"channel-events",
event_data,
to=f"channel:{channel.id}",
)
active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
background_tasks.add_task(
send_notification,
request.app.state.config.WEBUI_URL,
channel,
message,
active_user_ids,
)
return MessageModel(**message.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# UpdateMessageById
############################
@router.post(
"/{id}/messages/{message_id}/update", response_model=Optional[MessageModel]
)
async def update_message_by_id(
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if message.channel_id != id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
try:
message = Messages.update_message_by_id(message_id, form_data)
if message:
await sio.emit(
"channel-events",
{
"channel_id": channel.id,
"message_id": message.id,
"data": {
"type": "message:update",
"data": {
**message.model_dump(),
"user": UserNameResponse(**user.model_dump()).model_dump(),
},
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
},
to=f"channel:{channel.id}",
)
return MessageModel(**message.model_dump())
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# DeleteMessageById
############################
@router.delete("/{id}/messages/{message_id}/delete", response_model=bool)
async def delete_message_by_id(
id: str, message_id: str, user=Depends(get_verified_user)
):
channel = Channels.get_channel_by_id(id)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if message.channel_id != id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
try:
Messages.delete_message_by_id(message_id)
await sio.emit(
"channel-events",
{
"channel_id": channel.id,
"message_id": message.id,
"data": {
"type": "message:delete",
"data": {
**message.model_dump(),
"user": UserNameResponse(**user.model_dump()).model_dump(),
},
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
},
to=f"channel:{channel.id}",
)
return True
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)

View File

@ -2,15 +2,15 @@ import json
import logging
from typing import Optional
from open_webui.apps.webui.models.chats import (
from open_webui.models.chats import (
ChatForm,
ChatImportForm,
ChatResponse,
Chats,
ChatTitleIdResponse,
)
from open_webui.apps.webui.models.tags import TagModel, Tags
from open_webui.apps.webui.models.folders import Folders
from open_webui.models.tags import TagModel, Tags
from open_webui.models.folders import Folders
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
@ -19,7 +19,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_permission
log = logging.getLogger(__name__)
@ -607,7 +607,6 @@ async def add_tag_by_id_and_tag_name(
detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
)
print(tags, tag_id)
if tag_id not in tags:
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
id, user.id, form_data.name

View File

@ -3,7 +3,7 @@ from pydantic import BaseModel
from typing import Optional
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import get_config, save_config
from open_webui.config import BannerModel

View File

@ -2,8 +2,8 @@ from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from pydantic import BaseModel
from open_webui.apps.webui.models.users import Users, UserModel
from open_webui.apps.webui.models.feedbacks import (
from open_webui.models.users import Users, UserModel
from open_webui.models.feedbacks import (
FeedbackModel,
FeedbackResponse,
FeedbackForm,
@ -11,7 +11,7 @@ from open_webui.apps.webui.models.feedbacks import (
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
router = APIRouter()

View File

@ -5,27 +5,28 @@ from pathlib import Path
from typing import Optional
from pydantic import BaseModel
import mimetypes
from urllib.parse import quote
from open_webui.storage.provider import Storage
from open_webui.apps.webui.models.files import (
from open_webui.models.files import (
FileForm,
FileModel,
FileModelResponse,
Files,
)
from open_webui.apps.retrieval.main import process_file, ProcessFileForm
from open_webui.routers.retrieval import process_file, ProcessFileForm
from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -39,7 +40,9 @@ router = APIRouter()
@router.post("/", response_model=FileModelResponse)
def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
def upload_file(
request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
):
log.info(f"file.content_type: {file.content_type}")
try:
unsanitized_filename = file.filename
@ -68,7 +71,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
)
try:
process_file(ProcessFileForm(file_id=id))
process_file(request, ProcessFileForm(file_id=id))
file_item = Files.get_file_by_id(id=id)
except Exception as e:
log.exception(e)
@ -183,13 +186,15 @@ class ContentForm(BaseModel):
@router.post("/{id}/data/content/update")
async def update_file_data_content_by_id(
id: str, form_data: ContentForm, user=Depends(get_verified_user)
request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user)
):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
try:
process_file(ProcessFileForm(file_id=id, content=form_data.content))
process_file(
request, ProcessFileForm(file_id=id, content=form_data.content)
)
file = Files.get_file_by_id(id=id)
except Exception as e:
log.exception(e)
@ -218,11 +223,22 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
headers = {
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
}
# Handle Unicode filenames
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename) # RFC5987 encoding
headers = {}
if file.meta.get("content_type") not in [
"application/pdf",
"text/plain",
]:
headers = {
**headers,
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
}
return FileResponse(file_path, headers=headers)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -279,16 +295,20 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
if file and (file.user_id == user.id or user.role == "admin"):
file_path = file.path
# Handle Unicode filenames
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename) # RFC5987 encoding
headers = {
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
}
if file_path:
file_path = Storage.get_file(file_path)
file_path = Path(file_path)
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
headers = {
"Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
}
return FileResponse(file_path, headers=headers)
else:
raise HTTPException(
@ -307,7 +327,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
return StreamingResponse(
generator(),
media_type="text/plain",
headers={"Content-Disposition": f"attachment; filename={file_name}"},
headers=headers,
)
else:
raise HTTPException(

View File

@ -8,12 +8,12 @@ from pydantic import BaseModel
import mimetypes
from open_webui.apps.webui.models.folders import (
from open_webui.models.folders import (
FolderForm,
FolderModel,
Folders,
)
from open_webui.apps.webui.models.chats import Chats
from open_webui.models.chats import Chats
from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS
@ -24,7 +24,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View File

@ -2,17 +2,17 @@ import os
from pathlib import Path
from typing import Optional
from open_webui.apps.webui.models.functions import (
from open_webui.models.functions import (
FunctionForm,
FunctionModel,
FunctionResponse,
Functions,
)
from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
from open_webui.utils.plugin import load_function_module_by_id, replace_imports
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
router = APIRouter()

View File

@ -2,7 +2,7 @@ import os
from pathlib import Path
from typing import Optional
from open_webui.apps.webui.models.groups import (
from open_webui.models.groups import (
Groups,
GroupForm,
GroupUpdateForm,
@ -12,7 +12,7 @@ from open_webui.apps.webui.models.groups import (
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
router = APIRouter()

View File

@ -9,38 +9,24 @@ from pathlib import Path
from typing import Optional
import requests
from open_webui.apps.images.utils.comfyui import (
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import (
ComfyUIGenerateImageForm,
ComfyUIWorkflow,
comfyui_generate_image,
)
from open_webui.config import (
AUTOMATIC1111_API_AUTH,
AUTOMATIC1111_BASE_URL,
AUTOMATIC1111_CFG_SCALE,
AUTOMATIC1111_SAMPLER,
AUTOMATIC1111_SCHEDULER,
CACHE_DIR,
COMFYUI_BASE_URL,
COMFYUI_WORKFLOW,
COMFYUI_WORKFLOW_NODES,
CORS_ALLOW_ORIGIN,
ENABLE_IMAGE_GENERATION,
IMAGE_GENERATION_ENGINE,
IMAGE_GENERATION_MODEL,
IMAGE_SIZE,
IMAGE_STEPS,
IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_KEY,
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from open_webui.utils.utils import get_admin_user, get_verified_user
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@ -48,63 +34,31 @@ log.setLevel(SRC_LOG_LEVELS["IMAGES"])
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.MODEL = IMAGE_GENERATION_MODEL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS
router = APIRouter()
@app.get("/config")
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {
"enabled": app.state.config.ENABLED,
"engine": app.state.config.ENGINE,
"enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
"engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
"openai": {
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
},
"automatic1111": {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
"AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
"AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
"AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
"AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
},
"comfyui": {
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
"COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
"COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
"COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
},
}
@ -117,13 +71,14 @@ class OpenAIConfigForm(BaseModel):
class Automatic1111ConfigForm(BaseModel):
AUTOMATIC1111_BASE_URL: str
AUTOMATIC1111_API_AUTH: str
AUTOMATIC1111_CFG_SCALE: Optional[str]
AUTOMATIC1111_CFG_SCALE: Optional[str | float | int]
AUTOMATIC1111_SAMPLER: Optional[str]
AUTOMATIC1111_SCHEDULER: Optional[str]
class ComfyUIConfigForm(BaseModel):
COMFYUI_BASE_URL: str
COMFYUI_API_KEY: str
COMFYUI_WORKFLOW: str
COMFYUI_WORKFLOW_NODES: list[dict]
@ -136,133 +91,157 @@ class ConfigForm(BaseModel):
comfyui: ComfyUIConfigForm
@app.post("/config/update")
async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
app.state.config.ENGINE = form_data.engine
app.state.config.ENABLED = form_data.enabled
@router.post("/config/update")
async def update_config(
request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
form_data.openai.OPENAI_API_BASE_URL
)
request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
app.state.config.AUTOMATIC1111_BASE_URL = (
request.app.state.config.AUTOMATIC1111_BASE_URL = (
form_data.automatic1111.AUTOMATIC1111_BASE_URL
)
app.state.config.AUTOMATIC1111_API_AUTH = (
request.app.state.config.AUTOMATIC1111_API_AUTH = (
form_data.automatic1111.AUTOMATIC1111_API_AUTH
)
app.state.config.AUTOMATIC1111_CFG_SCALE = (
request.app.state.config.AUTOMATIC1111_CFG_SCALE = (
float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
else None
)
app.state.config.AUTOMATIC1111_SAMPLER = (
request.app.state.config.AUTOMATIC1111_SAMPLER = (
form_data.automatic1111.AUTOMATIC1111_SAMPLER
if form_data.automatic1111.AUTOMATIC1111_SAMPLER
else None
)
app.state.config.AUTOMATIC1111_SCHEDULER = (
request.app.state.config.AUTOMATIC1111_SCHEDULER = (
form_data.automatic1111.AUTOMATIC1111_SCHEDULER
if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
else None
)
app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
request.app.state.config.COMFYUI_BASE_URL = (
form_data.comfyui.COMFYUI_BASE_URL.strip("/")
)
request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
request.app.state.config.COMFYUI_WORKFLOW_NODES = (
form_data.comfyui.COMFYUI_WORKFLOW_NODES
)
return {
"enabled": app.state.config.ENABLED,
"engine": app.state.config.ENGINE,
"enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
"engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
"openai": {
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
},
"automatic1111": {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
"AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
"AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
"AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
"AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
"AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
},
"comfyui": {
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
"COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
"COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
"COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
},
}
def get_automatic1111_api_auth():
if app.state.config.AUTOMATIC1111_API_AUTH is None:
def get_automatic1111_api_auth(request: Request):
if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
return ""
else:
auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
"utf-8"
)
auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
return f"Basic {auth1111_base64_encoded_string}"
@app.get("/config/url/verify")
async def verify_url(user=Depends(get_admin_user)):
if app.state.config.ENGINE == "automatic1111":
@router.get("/config/url/verify")
async def verify_url(request: Request, user=Depends(get_admin_user)):
if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
try:
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": get_automatic1111_api_auth()},
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": get_automatic1111_api_auth(request)},
)
r.raise_for_status()
return True
except Exception:
app.state.config.ENABLED = False
request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
elif app.state.config.ENGINE == "comfyui":
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
try:
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
)
r.raise_for_status()
return True
except Exception:
app.state.config.ENABLED = False
request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
else:
return True
def set_image_model(model: str):
def set_image_model(request: Request, model: str):
log.info(f"Setting image model to {model}")
app.state.config.MODEL = model
if app.state.config.ENGINE in ["", "automatic1111"]:
request.app.state.config.IMAGE_GENERATION_MODEL = model
if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
api_auth = get_automatic1111_api_auth()
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": api_auth},
)
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
json=options,
headers={"authorization": api_auth},
)
return app.state.config.MODEL
return request.app.state.config.IMAGE_GENERATION_MODEL
def get_image_model():
if app.state.config.ENGINE == "openai":
return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
elif app.state.config.ENGINE == "comfyui":
return app.state.config.MODEL if app.state.config.MODEL else ""
elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
def get_image_model(request):
if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
return (
request.app.state.config.IMAGE_GENERATION_MODEL
if request.app.state.config.IMAGE_GENERATION_MODEL
else "dall-e-2"
)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
return (
request.app.state.config.IMAGE_GENERATION_MODEL
if request.app.state.config.IMAGE_GENERATION_MODEL
else ""
)
elif (
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
):
try:
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": get_automatic1111_api_auth()},
)
options = r.json()
return options["sd_model_checkpoint"]
except Exception as e:
app.state.config.ENABLED = False
request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@ -272,23 +251,25 @@ class ImageConfigForm(BaseModel):
IMAGE_STEPS: int
@app.get("/image/config")
async def get_image_config(user=Depends(get_admin_user)):
@router.get("/image/config")
async def get_image_config(request: Request, user=Depends(get_admin_user)):
return {
"MODEL": app.state.config.MODEL,
"IMAGE_SIZE": app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": app.state.config.IMAGE_STEPS,
"MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
"IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
}
@app.post("/image/config/update")
async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
@router.post("/image/config/update")
async def update_image_config(
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
):
set_image_model(form_data.MODEL)
set_image_model(request, form_data.MODEL)
pattern = r"^\d+x\d+$"
if re.match(pattern, form_data.IMAGE_SIZE):
app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
else:
raise HTTPException(
status_code=400,
@ -296,7 +277,7 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
)
if form_data.IMAGE_STEPS >= 0:
app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
else:
raise HTTPException(
status_code=400,
@ -304,29 +285,35 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
)
return {
"MODEL": app.state.config.MODEL,
"IMAGE_SIZE": app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": app.state.config.IMAGE_STEPS,
"MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
"IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
}
@app.get("/models")
def get_models(user=Depends(get_verified_user)):
@router.get("/models")
def get_models(request: Request, user=Depends(get_verified_user)):
try:
if app.state.config.ENGINE == "openai":
if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
return [
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
]
elif app.state.config.ENGINE == "comfyui":
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
# TODO - get models from comfyui
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
headers = {
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
}
r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
headers=headers,
)
info = r.json()
workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
model_node_id = None
for node in app.state.config.COMFYUI_WORKFLOW_NODES:
for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
if node["type"] == "model":
if node["node_ids"]:
model_node_id = node["node_ids"][0]
@ -362,10 +349,11 @@ def get_models(user=Depends(get_verified_user)):
)
)
elif (
app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
):
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
headers={"authorization": get_automatic1111_api_auth()},
)
models = r.json()
@ -376,7 +364,7 @@ def get_models(user=Depends(get_verified_user)):
)
)
except Exception as e:
app.state.config.ENABLED = False
request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@ -448,18 +436,21 @@ def save_url_image(url):
return None
@app.post("/generations")
@router.post("/generations")
async def image_generations(
request: Request,
form_data: GenerateImageForm,
user=Depends(get_verified_user),
):
width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
r = None
try:
if app.state.config.ENGINE == "openai":
if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
headers = {}
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Authorization"] = (
f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}"
)
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
@ -470,14 +461,16 @@ async def image_generations(
data = {
"model": (
app.state.config.MODEL
if app.state.config.MODEL != ""
request.app.state.config.IMAGE_GENERATION_MODEL
if request.app.state.config.IMAGE_GENERATION_MODEL != ""
else "dall-e-2"
),
"prompt": form_data.prompt,
"n": form_data.n,
"size": (
form_data.size if form_data.size else app.state.config.IMAGE_SIZE
form_data.size
if form_data.size
else request.app.state.config.IMAGE_SIZE
),
"response_format": "b64_json",
}
@ -485,7 +478,7 @@ async def image_generations(
# Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread(
requests.post,
url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations",
json=data,
headers=headers,
)
@ -505,7 +498,7 @@ async def image_generations(
return images
elif app.state.config.ENGINE == "comfyui":
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
data = {
"prompt": form_data.prompt,
"width": width,
@ -513,8 +506,8 @@ async def image_generations(
"n": form_data.n,
}
if app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.config.IMAGE_STEPS
if request.app.state.config.IMAGE_STEPS is not None:
data["steps"] = request.app.state.config.IMAGE_STEPS
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
@ -523,18 +516,19 @@ async def image_generations(
**{
"workflow": ComfyUIWorkflow(
**{
"workflow": app.state.config.COMFYUI_WORKFLOW,
"nodes": app.state.config.COMFYUI_WORKFLOW_NODES,
"workflow": request.app.state.config.COMFYUI_WORKFLOW,
"nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
}
),
**data,
}
)
res = await comfyui_generate_image(
app.state.config.MODEL,
request.app.state.config.IMAGE_GENERATION_MODEL,
form_data,
user.id,
app.state.config.COMFYUI_BASE_URL,
request.app.state.config.COMFYUI_BASE_URL,
request.app.state.config.COMFYUI_API_KEY,
)
log.debug(f"res: {res}")
@ -551,7 +545,8 @@ async def image_generations(
log.debug(f"images: {images}")
return images
elif (
app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
):
if form_data.model:
set_image_model(form_data.model)
@ -563,25 +558,25 @@ async def image_generations(
"height": height,
}
if app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.config.IMAGE_STEPS
if request.app.state.config.IMAGE_STEPS is not None:
data["steps"] = request.app.state.config.IMAGE_STEPS
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
if app.state.config.AUTOMATIC1111_CFG_SCALE:
data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
if request.app.state.config.AUTOMATIC1111_CFG_SCALE:
data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE
if app.state.config.AUTOMATIC1111_SAMPLER:
data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
if request.app.state.config.AUTOMATIC1111_SAMPLER:
data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER
if app.state.config.AUTOMATIC1111_SCHEDULER:
data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER
if request.app.state.config.AUTOMATIC1111_SCHEDULER:
data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER
# Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread(
requests.post,
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
headers={"authorization": get_automatic1111_api_auth()},
)

View File

@ -1,22 +1,26 @@
import json
from typing import Optional, Union
from typing import List, Optional
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, status, Request
import logging
from open_webui.apps.webui.models.knowledge import (
from open_webui.models.knowledge import (
Knowledges,
KnowledgeForm,
KnowledgeResponse,
KnowledgeUserResponse,
)
from open_webui.apps.webui.models.files import Files, FileModel
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.apps.retrieval.main import process_file, ProcessFileForm
from open_webui.models.files import Files, FileModel
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.routers.retrieval import (
process_file,
ProcessFileForm,
process_files_batch,
BatchProcessFilesForm,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_verified_user
from open_webui.utils.access_control import has_access, has_permission
@ -242,6 +246,7 @@ class KnowledgeFileIdForm(BaseModel):
@router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse])
def add_file_to_knowledge_by_id(
request: Request,
id: str,
form_data: KnowledgeFileIdForm,
user=Depends(get_verified_user),
@ -274,7 +279,9 @@ def add_file_to_knowledge_by_id(
# Add content to the vector database
try:
process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id))
process_file(
request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
)
except Exception as e:
log.debug(e)
raise HTTPException(
@ -318,6 +325,7 @@ def add_file_to_knowledge_by_id(
@router.post("/{id}/file/update", response_model=Optional[KnowledgeFilesResponse])
def update_file_from_knowledge_by_id(
request: Request,
id: str,
form_data: KnowledgeFileIdForm,
user=Depends(get_verified_user),
@ -349,7 +357,9 @@ def update_file_from_knowledge_by_id(
# Add content to the vector database
try:
process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id))
process_file(
request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -508,3 +518,83 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
return knowledge
############################
# AddFilesToKnowledge
############################
@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
def add_files_to_knowledge_batch(
id: str,
form_data: list[KnowledgeFileIdForm],
user=Depends(get_verified_user),
):
"""
Add multiple files to a knowledge base
"""
knowledge = Knowledges.get_knowledge_by_id(id=id)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
# Get files content
print(f"files/batch/add - {len(form_data)} files")
files: List[FileModel] = []
for form in form_data:
file = Files.get_file_by_id(form.file_id)
if not file:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File {form.file_id} not found",
)
files.append(file)
# Process files
try:
result = process_files_batch(
BatchProcessFilesForm(files=files, collection_name=id)
)
except Exception as e:
log.error(
f"add_files_to_knowledge_batch: Exception occurred: {e}", exc_info=True
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
# Add successful files to knowledge base
data = knowledge.data or {}
existing_file_ids = data.get("file_ids", [])
# Only add files that were successfully processed
successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
for file_id in successful_file_ids:
if file_id not in existing_file_ids:
existing_file_ids.append(file_id)
data["file_ids"] = existing_file_ids
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
# If there were any errors, include them in the response
if result.errors:
error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Files.get_files_by_ids(existing_file_ids),
warnings={
"message": "Some files failed to process",
"errors": error_details,
},
)
return KnowledgeFilesResponse(
**knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
)

View File

@ -3,9 +3,9 @@ from pydantic import BaseModel
import logging
from typing import Optional
from open_webui.apps.webui.models.memories import Memories, MemoryModel
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.utils import get_verified_user
from open_webui.models.memories import Memories, MemoryModel
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.auth import get_verified_user
from open_webui.env import SRC_LOG_LEVELS

View File

@ -1,6 +1,6 @@
from typing import Optional
from open_webui.apps.webui.models.models import (
from open_webui.models.models import (
ModelForm,
ModelModel,
ModelResponse,
@ -11,7 +11,7 @@ from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission

View File

@ -10,15 +10,15 @@ from aiocache import cached
import requests
from open_webui.apps.webui.models.models import Models
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask
from open_webui.models.models import Models
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
ENABLE_OPENAI_API,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
OPENAI_API_CONFIGS,
AppConfig,
)
from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT,
@ -29,18 +29,14 @@ from open_webui.env import (
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access
@ -48,36 +44,69 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
##########################################
#
# Utility functions
#
##########################################
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
async def send_get_request(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
) as response:
return await response.json()
except Exception as e:
# Handle connection error here
log.error(f"Connection error: {e}")
return None
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
def openai_o1_handler(payload):
"""
Handle O1 specific parameters
"""
if "max_tokens" in payload:
# Remove "max_tokens" from the payload
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
if payload["messages"][0]["role"] == "system":
payload["messages"][0]["role"] = "user"
return payload
##########################################
#
# API routes
#
##########################################
router = APIRouter()
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
"OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
"ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API,
"OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS,
"OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS,
"OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS,
}
@ -88,50 +117,56 @@ class OpenAIConfigForm(BaseModel):
OPENAI_API_CONFIGS: dict
@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
@router.post("/config/update")
async def update_config(
request: Request, form_data: OpenAIConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
request.app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
request.app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
# Check if API KEYS length is same than API URLS length
if len(app.state.config.OPENAI_API_KEYS) != len(
app.state.config.OPENAI_API_BASE_URLS
if len(request.app.state.config.OPENAI_API_KEYS) != len(
request.app.state.config.OPENAI_API_BASE_URLS
):
if len(app.state.config.OPENAI_API_KEYS) > len(
app.state.config.OPENAI_API_BASE_URLS
if len(request.app.state.config.OPENAI_API_KEYS) > len(
request.app.state.config.OPENAI_API_BASE_URLS
):
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
: len(app.state.config.OPENAI_API_BASE_URLS)
]
request.app.state.config.OPENAI_API_KEYS = (
request.app.state.config.OPENAI_API_KEYS[
: len(request.app.state.config.OPENAI_API_BASE_URLS)
]
)
else:
app.state.config.OPENAI_API_KEYS += [""] * (
len(app.state.config.OPENAI_API_BASE_URLS)
- len(app.state.config.OPENAI_API_KEYS)
request.app.state.config.OPENAI_API_KEYS += [""] * (
len(request.app.state.config.OPENAI_API_BASE_URLS)
- len(request.app.state.config.OPENAI_API_KEYS)
)
app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
# Remove any extra configs
config_urls = app.state.config.OPENAI_API_CONFIGS.keys()
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
config_urls = request.app.state.config.OPENAI_API_CONFIGS.keys()
for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
if url not in config_urls:
app.state.config.OPENAI_API_CONFIGS.pop(url, None)
request.app.state.config.OPENAI_API_CONFIGS.pop(url, None)
return {
"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
"OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
"ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API,
"OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS,
"OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS,
"OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS,
}
@app.post("/audio/speech")
@router.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
idx = None
try:
idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
idx = request.app.state.config.OPENAI_API_BASE_URLS.index(
"https://api.openai.com/v1"
)
body = await request.body()
name = hashlib.sha256(body).hexdigest()
@ -144,23 +179,35 @@ async def speech(request: Request, user=Depends(get_verified_user)):
if file_path.is_file():
return FileResponse(file_path)
headers = {}
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
headers["Content-Type"] = "application/json"
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
headers["HTTP-Referer"] = "https://openwebui.com/"
headers["X-Title"] = "Open WebUI"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
r = None
try:
r = requests.post(
url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
url=f"{url}/audio/speech",
data=body,
headers=headers,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}",
**(
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
stream=True,
)
@ -179,115 +226,62 @@ async def speech(request: Request, user=Depends(get_verified_user)):
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
detail = f"External: {res['error']}"
except Exception:
error_detail = f"External: {e}"
detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
status_code=r.status_code if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
except ValueError:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def aiohttp_get(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
headers = {"Authorization": f"Bearer {key}"} if key else {}
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(url, headers=headers) as response:
return await response.json()
except Exception as e:
# Handle connection error here
log.error(f"Connection error: {e}")
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{
**model,
"name": model.get("name", model["id"]),
"owned_by": "openai",
"openai": model,
"urlIdx": idx,
}
for model in models
if "api.openai.com"
not in app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
)
return merged_list
async def get_all_models_responses() -> list:
if not app.state.config.ENABLE_OPENAI_API:
async def get_all_models_responses(request: Request) -> list:
if not request.app.state.config.ENABLE_OPENAI_API:
return []
# Check if API KEYS length is same than API URLS length
num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
num_keys = len(app.state.config.OPENAI_API_KEYS)
num_urls = len(request.app.state.config.OPENAI_API_BASE_URLS)
num_keys = len(request.app.state.config.OPENAI_API_KEYS)
if num_keys != num_urls:
# if there are more keys than urls, remove the extra keys
if num_keys > num_urls:
new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
app.state.config.OPENAI_API_KEYS = new_keys
new_keys = request.app.state.config.OPENAI_API_KEYS[:num_urls]
request.app.state.config.OPENAI_API_KEYS = new_keys
# if there are more urls than keys, add empty keys
else:
app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
request.app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
tasks = []
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
if url not in app.state.config.OPENAI_API_CONFIGS:
tasks.append(
aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
request_tasks = []
for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
if url not in request.app.state.config.OPENAI_API_CONFIGS:
request_tasks.append(
send_get_request(
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
)
)
else:
api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {})
enable = api_config.get("enable", True)
model_ids = api_config.get("model_ids", [])
if enable:
if len(model_ids) == 0:
tasks.append(
aiohttp_get(
f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]
request_tasks.append(
send_get_request(
f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx],
)
)
else:
@ -305,16 +299,18 @@ async def get_all_models_responses() -> list:
],
}
tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list)))
request_tasks.append(
asyncio.ensure_future(asyncio.sleep(0, model_list))
)
else:
tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
responses = await asyncio.gather(*tasks)
responses = await asyncio.gather(*request_tasks)
for idx, response in enumerate(responses):
if response:
url = app.state.config.OPENAI_API_BASE_URLS[idx]
api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None)
@ -325,18 +321,30 @@ async def get_all_models_responses() -> list:
model["id"] = f"{prefix_id}.{model['id']}"
log.debug(f"get_all_models:responses() {responses}")
return responses
async def get_filtered_models(models, user):
# Filter models based on user access control
filtered_models = []
for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
):
filtered_models.append(model)
return filtered_models
@cached(ttl=3)
async def get_all_models() -> dict[str, list]:
async def get_all_models(request: Request) -> dict[str, list]:
log.info("get_all_models()")
if not app.state.config.ENABLE_OPENAI_API:
if not request.app.state.config.ENABLE_OPENAI_API:
return {"data": []}
responses = await get_all_models_responses()
responses = await get_all_models_responses(request)
def extract_data(response):
if response and "data" in response:
@ -345,41 +353,86 @@ async def get_all_models() -> dict[str, list]:
return response
return None
def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{
**model,
"name": model.get("name", model["id"]),
"owned_by": "openai",
"openai": model,
"urlIdx": idx,
}
for model in models
if "api.openai.com"
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
)
return merged_list
models = {"data": merge_models_lists(map(extract_data, responses))}
log.debug(f"models: {models}")
request.app.state.OPENAI_MODELS = {model["id"]: model for model in models["data"]}
return models
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
@router.get("/models")
@router.get("/models/{url_idx}")
async def get_models(
request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
models = {
"data": [],
}
if url_idx is None:
models = await get_all_models()
models = await get_all_models(request)
else:
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = app.state.config.OPENAI_API_KEYS[url_idx]
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
r = None
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
)
) as session:
try:
async with session.get(f"{url}/models", headers=headers) as r:
async with session.get(
f"{url}/models",
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
@ -413,27 +466,16 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues
log.exception(f"Client error: {str(e)}")
# Handle aiohttp-specific connection issues, timeout etc.
raise HTTPException(
status_code=500, detail="Open WebUI: Server Connection Error"
)
except Exception as e:
log.exception(f"Unexpected error: {e}")
# Generic error handler in case parsing JSON or other steps fail
error_detail = f"Unexpected error: {str(e)}"
raise HTTPException(status_code=500, detail=error_detail)
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
# Filter models based on user access control
filtered_models = []
for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
):
filtered_models.append(model)
models["data"] = filtered_models
models["data"] = get_filtered_models(models, user)
return models
@ -443,21 +485,24 @@ class ConnectionVerificationForm(BaseModel):
key: str
@app.post("/verify")
@router.post("/verify")
async def verify_connection(
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
):
url = form_data.url
key = form_data.key
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
) as session:
try:
async with session.get(f"{url}/models", headers=headers) as r:
async with session.get(
f"{url}/models",
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
},
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
@ -472,26 +517,27 @@ async def verify_connection(
except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues
log.exception(f"Client error: {str(e)}")
# Handle aiohttp-specific connection issues, timeout etc.
raise HTTPException(
status_code=500, detail="Open WebUI: Server Connection Error"
)
except Exception as e:
log.exception(f"Unexpected error: {e}")
# Generic error handler in case parsing JSON or other steps fail
error_detail = f"Unexpected error: {str(e)}"
raise HTTPException(status_code=500, detail=error_detail)
@app.post("/chat/completions")
@router.post("/chat/completions")
async def generate_chat_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user),
bypass_filter: Optional[bool] = False,
):
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
idx = 0
payload = {**form_data}
if "metadata" in payload:
del payload["metadata"]
@ -502,6 +548,7 @@ async def generate_chat_completion(
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_id = model_info.base_model_id
params = model_info.params.model_dump()
payload = apply_model_params_to_body_openai(params, payload)
@ -526,15 +573,7 @@ async def generate_chat_completion(
detail="Model not found",
)
# Attemp to get urlIdx from the model
models = await get_all_models()
# Find the model from the list
model = next(
(model for model in models["data"] if model["id"] == payload.get("model")),
None,
)
model = request.app.state.OPENAI_MODELS.get(model_id)
if model:
idx = model["urlIdx"]
else:
@ -544,11 +583,11 @@ async def generate_chat_completion(
)
# Get the API config for the model
api_config = app.state.config.OPENAI_API_CONFIGS.get(
app.state.config.OPENAI_API_BASE_URLS[idx], {}
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
)
prefix_id = api_config.get("prefix_id", None)
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
@ -561,43 +600,26 @@ async def generate_chat_completion(
"role": user.role,
}
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
# Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
is_o1 = payload["model"].lower().startswith("o1-")
# Change max_completion_tokens to max_tokens (Backward compatible)
if "api.openai.com" not in url and not is_o1:
if is_o1:
payload = openai_o1_handler(payload)
elif "api.openai.com" not in url:
# Remove "max_completion_tokens" from the payload for backward compatibility
if "max_completion_tokens" in payload:
# Remove "max_completion_tokens" from the payload
payload["max_tokens"] = payload["max_completion_tokens"]
del payload["max_completion_tokens"]
else:
if is_o1 and "max_tokens" in payload:
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
if "max_tokens" in payload and "max_completion_tokens" in payload:
del payload["max_tokens"]
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
if is_o1 and payload["messages"][0]["role"] == "system":
payload["messages"][0]["role"] = "user"
# TODO: check if below is needed
# if "max_tokens" in payload and "max_completion_tokens" in payload:
# del payload["max_tokens"]
# Convert the modified body back to JSON
payload = json.dumps(payload)
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
headers["HTTP-Referer"] = "https://openwebui.com/"
headers["X-Title"] = "Open WebUI"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
r = None
session = None
streaming = False
@ -607,11 +629,33 @@ async def generate_chat_completion(
session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
r = await session.request(
method="POST",
url=f"{url}/chat/completions",
data=payload,
headers=headers,
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
)
# Check if response is SSE
@ -636,14 +680,18 @@ async def generate_chat_completion(
return response
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if isinstance(response, dict):
if "error" in response:
error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
elif isinstance(response, str):
error_detail = response
detail = response
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
raise HTTPException(
status_code=r.status if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
finally:
if not streaming and session:
if r:
@ -651,25 +699,17 @@ async def generate_chat_completion(
await session.close()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
idx = 0
"""
Deprecated: proxy all requests to OpenAI API
"""
body = await request.body()
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}"
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
idx = 0
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
r = None
session = None
@ -679,11 +719,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method=request.method,
url=target_url,
url=f"{url}/{path}",
data=body,
headers=headers,
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
)
r.raise_for_status()
# Check if response is SSE
@ -700,18 +752,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
else:
response_data = await r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if r is not None:
try:
res = await r.json()
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except Exception:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
detail = f"External: {e}"
raise HTTPException(
status_code=r.status if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
finally:
if not streaming and session:
if r:

Some files were not shown because too many files have changed in this diff Show More