chore: format

This commit is contained in:
Timothy J. Baek
2024-10-20 18:38:06 -07:00
parent 768b7e139c
commit 9936583477
6 changed files with 85 additions and 72 deletions

View File

@@ -547,7 +547,7 @@ class GenerateEmbeddingsForm(BaseModel):
class GenerateEmbedForm(BaseModel): class GenerateEmbedForm(BaseModel):
model: str model: str
input: list[str]|str input: list[str] | str
truncate: Optional[bool] = None truncate: Optional[bool] = None
options: Optional[dict] = None options: Optional[dict] = None
keep_alive: Optional[Union[int, str]] = None keep_alive: Optional[Union[int, str]] = None

View File

@@ -110,9 +110,8 @@ class ChromaClient:
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created. # Insert the items into the collection, if the collection does not exist, it will be created.
collection = self.client.get_or_create_collection( collection = self.client.get_or_create_collection(
name=collection_name, name=collection_name, metadata={"hnsw:space": "cosine"}
metadata={"hnsw:space": "cosine"} )
)
ids = [item["id"] for item in items] ids = [item["id"] for item in items]
documents = [item["text"] for item in items] documents = [item["text"] for item in items]
@@ -131,9 +130,8 @@ class ChromaClient:
def upsert(self, collection_name: str, items: list[VectorItem]): def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection = self.client.get_or_create_collection( collection = self.client.get_or_create_collection(
name=collection_name, name=collection_name, metadata={"hnsw:space": "cosine"}
metadata={"hnsw:space": "cosine"} )
)
ids = [item["id"] for item in items] ids = [item["id"] for item in items]
documents = [item["text"] for item in items] documents = [item["text"] for item in items]

View File

@@ -9,6 +9,7 @@ from open_webui.config import QDRANT_URI
NO_LIMIT = 999999999 NO_LIMIT = 999999999
class QdrantClient: class QdrantClient:
def __init__(self): def __init__(self):
self.collection_prefix = "open-webui" self.collection_prefix = "open-webui"
@@ -38,15 +39,15 @@ class QdrantClient:
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}" collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
self.client.create_collection( self.client.create_collection(
collection_name=collection_name_with_prefix, collection_name=collection_name_with_prefix,
vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE), vectors_config=models.VectorParams(
size=dimension, distance=models.Distance.COSINE
),
) )
print(f"collection {collection_name_with_prefix} successfully created!") print(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension): def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection( if not self.has_collection(collection_name=collection_name):
collection_name=collection_name
):
self._create_collection( self._create_collection(
collection_name=collection_name, dimension=dimension collection_name=collection_name, dimension=dimension
) )
@@ -56,22 +57,23 @@ class QdrantClient:
PointStruct( PointStruct(
id=item["id"], id=item["id"],
vector=item["vector"], vector=item["vector"],
payload={ payload={"text": item["text"], "metadata": item["metadata"]},
"text": item["text"],
"metadata": item["metadata"]
},
) )
for item in items for item in items
] ]
def has_collection(self, collection_name: str) -> bool: def has_collection(self, collection_name: str) -> bool:
return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}") return self.client.collection_exists(
f"{self.collection_prefix}_{collection_name}"
)
def delete_collection(self, collection_name: str): def delete_collection(self, collection_name: str):
return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}") return self.client.delete_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def search( def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]: ) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
if limit is None: if limit is None:
@@ -87,7 +89,7 @@ class QdrantClient:
ids=get_result.ids, ids=get_result.ids,
documents=get_result.documents, documents=get_result.documents,
metadatas=get_result.metadatas, metadatas=get_result.metadatas,
distances=[[point.score for point in query_response.points]] distances=[[point.score for point in query_response.points]],
) )
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
@@ -101,7 +103,10 @@ class QdrantClient:
field_conditions = [] field_conditions = []
for key, value in filter.items(): for key, value in filter.items():
field_conditions.append( field_conditions.append(
models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value))) models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
)
points = self.client.query_points( points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
@@ -117,7 +122,7 @@ class QdrantClient:
# Get all the items in the collection. # Get all the items in the collection.
points = self.client.query_points( points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
limit=NO_LIMIT # otherwise qdrant would set limit to 10! limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
) )
return self._result_to_get_result(points.points) return self._result_to_get_result(points.points)
@@ -134,10 +139,10 @@ class QdrantClient:
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points) return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
def delete( def delete(
self, self,
collection_name: str, collection_name: str,
ids: Optional[list[str]] = None, ids: Optional[list[str]] = None,
filter: Optional[dict] = None, filter: Optional[dict] = None,
): ):
# Delete the items from the collection based on the ids. # Delete the items from the collection based on the ids.
field_conditions = [] field_conditions = []
@@ -162,9 +167,7 @@ class QdrantClient:
return self.client.delete( return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
points_selector=models.FilterSelector( points_selector=models.FilterSelector(
filter=models.Filter( filter=models.Filter(must=field_conditions)
must=field_conditions
)
), ),
) )

View File

@@ -409,7 +409,10 @@ OAUTH_ROLES_CLAIM = PersistentConfig(
OAUTH_ALLOWED_ROLES = PersistentConfig( OAUTH_ALLOWED_ROLES = PersistentConfig(
"OAUTH_ALLOWED_ROLES", "OAUTH_ALLOWED_ROLES",
"oauth.allowed_roles", "oauth.allowed_roles",
[role.strip() for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",")], [
role.strip()
for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",")
],
) )
OAUTH_ADMIN_ROLES = PersistentConfig( OAUTH_ADMIN_ROLES = PersistentConfig(
@@ -418,6 +421,7 @@ OAUTH_ADMIN_ROLES = PersistentConfig(
[role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")], [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
) )
def load_oauth_providers(): def load_oauth_providers():
OAUTH_PROVIDERS.clear() OAUTH_PROVIDERS.clear()
if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value: if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:

View File

@@ -208,8 +208,6 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
app.state.MODELS = {} app.state.MODELS = {}
################################## ##################################
# #
# ChatCompletion Middleware # ChatCompletion Middleware
@@ -223,14 +221,14 @@ def get_task_model_id(default_model_id):
# Check if the user has a custom task model and use that model # Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama": if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if ( if (
app.state.config.TASK_MODEL app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS and app.state.config.TASK_MODEL in app.state.MODELS
): ):
task_model_id = app.state.config.TASK_MODEL task_model_id = app.state.config.TASK_MODEL
else: else:
if ( if (
app.state.config.TASK_MODEL_EXTERNAL app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
): ):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL task_model_id = app.state.config.TASK_MODEL_EXTERNAL
@@ -367,7 +365,7 @@ async def get_content_from_response(response) -> Optional[str]:
async def chat_completion_tools_handler( async def chat_completion_tools_handler(
body: dict, user: UserModel, extra_params: dict body: dict, user: UserModel, extra_params: dict
) -> tuple[dict, dict]: ) -> tuple[dict, dict]:
# If tool_ids field is present, call the functions # If tool_ids field is present, call the functions
metadata = body.get("metadata", {}) metadata = body.get("metadata", {})
@@ -681,15 +679,15 @@ def get_sorted_filters(model_id):
model model
for model in app.state.MODELS.values() for model in app.state.MODELS.values()
if "pipeline" in model if "pipeline" in model
and "type" in model["pipeline"] and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter" and model["pipeline"]["type"] == "filter"
and ( and (
model["pipeline"]["pipelines"] == ["*"] model["pipeline"]["pipelines"] == ["*"]
or any( or any(
model_id == target_model_id model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"] for target_model_id in model["pipeline"]["pipelines"]
) )
) )
] ]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
return sorted_filters return sorted_filters
@@ -875,8 +873,8 @@ async def update_embedding_function(request: Request, call_next):
@app.middleware("http") @app.middleware("http")
async def inspect_websocket(request: Request, call_next): async def inspect_websocket(request: Request, call_next):
if ( if (
"/ws/socket.io" in request.url.path "/ws/socket.io" in request.url.path
and request.query_params.get("transport") == "websocket" and request.query_params.get("transport") == "websocket"
): ):
upgrade = (request.headers.get("Upgrade") or "").lower() upgrade = (request.headers.get("Upgrade") or "").lower()
connection = (request.headers.get("Connection") or "").lower().split(",") connection = (request.headers.get("Connection") or "").lower().split(",")
@@ -945,8 +943,8 @@ async def get_all_models():
if custom_model.base_model_id is None: if custom_model.base_model_id is None:
for model in models: for model in models:
if ( if (
custom_model.id == model["id"] custom_model.id == model["id"]
or custom_model.id == model["id"].split(":")[0] or custom_model.id == model["id"].split(":")[0]
): ):
model["name"] = custom_model.name model["name"] = custom_model.name
model["info"] = custom_model.model_dump() model["info"] = custom_model.model_dump()
@@ -963,8 +961,8 @@ async def get_all_models():
for model in models: for model in models:
if ( if (
custom_model.base_model_id == model["id"] custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0] or custom_model.base_model_id == model["id"].split(":")[0]
): ):
owned_by = model["owned_by"] owned_by = model["owned_by"]
if "pipe" in model: if "pipe" in model:
@@ -1840,7 +1838,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
@app.post("/api/pipelines/upload") @app.post("/api/pipelines/upload")
async def upload_pipeline( async def upload_pipeline(
urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
): ):
print("upload_pipeline", urlIdx, file.filename) print("upload_pipeline", urlIdx, file.filename)
# Check if the uploaded file is a python file # Check if the uploaded file is a python file
@@ -2017,9 +2015,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
@app.get("/api/pipelines/{pipeline_id}/valves") @app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves( async def get_pipeline_valves(
urlIdx: Optional[int], urlIdx: Optional[int],
pipeline_id: str, pipeline_id: str,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
r = None r = None
try: try:
@@ -2055,9 +2053,9 @@ async def get_pipeline_valves(
@app.get("/api/pipelines/{pipeline_id}/valves/spec") @app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec( async def get_pipeline_valves_spec(
urlIdx: Optional[int], urlIdx: Optional[int],
pipeline_id: str, pipeline_id: str,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
r = None r = None
try: try:
@@ -2092,10 +2090,10 @@ async def get_pipeline_valves_spec(
@app.post("/api/pipelines/{pipeline_id}/valves/update") @app.post("/api/pipelines/{pipeline_id}/valves/update")
async def update_pipeline_valves( async def update_pipeline_valves(
urlIdx: Optional[int], urlIdx: Optional[int],
pipeline_id: str, pipeline_id: str,
form_data: dict, form_data: dict,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
r = None r = None
try: try:
@@ -2219,7 +2217,7 @@ class ModelFilterConfigForm(BaseModel):
@app.post("/api/config/model/filter") @app.post("/api/config/model/filter")
async def update_model_filter_config( async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user) form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
): ):
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
app.state.config.MODEL_FILTER_LIST = form_data.models app.state.config.MODEL_FILTER_LIST = form_data.models
@@ -2274,7 +2272,7 @@ async def get_app_latest_release_version():
timeout = aiohttp.ClientTimeout(total=1) timeout = aiohttp.ClientTimeout(total=1)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get( async with session.get(
"https://api.github.com/repos/open-webui/open-webui/releases/latest" "https://api.github.com/repos/open-webui/open-webui/releases/latest"
) as response: ) as response:
response.raise_for_status() response.raise_for_status()
data = await response.json() data = await response.json()

View File

@@ -25,7 +25,10 @@ from open_webui.config import (
OAUTH_PICTURE_CLAIM, OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM, OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES, OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig, OAUTH_ADMIN_ROLES,
WEBHOOK_URL,
JWT_EXPIRES_IN,
AppConfig,
) )
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE
@@ -170,7 +173,9 @@ class OAuthManager:
# If the user does not exist, check if signups are enabled # If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP.value: if auth_manager_config.ENABLE_OAUTH_SIGNUP.value:
# Check if an existing user with the same email already exists # Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) existing_user = Users.get_user_by_email(
user_data.get("email", "").lower()
)
if existing_user: if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
@@ -182,16 +187,18 @@ class OAuthManager:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(picture_url) as resp: async with session.get(picture_url) as resp:
picture = await resp.read() picture = await resp.read()
base64_encoded_picture = base64.b64encode(picture).decode( base64_encoded_picture = base64.b64encode(
"utf-8" picture
) ).decode("utf-8")
guessed_mime_type = mimetypes.guess_type(picture_url)[0] guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None: if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats # assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg" guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
except Exception as e: except Exception as e:
log.error(f"Error downloading profile image '{picture_url}': {e}") log.error(
f"Error downloading profile image '{picture_url}': {e}"
)
picture_url = "" picture_url = ""
if not picture_url: if not picture_url:
picture_url = "/user.png" picture_url = "/user.png"
@@ -216,7 +223,9 @@ class OAuthManager:
auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name), auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{ {
"action": "signup", "action": "signup",
"message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name), "message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(
user.name
),
"user": user.model_dump_json(exclude_none=True), "user": user.model_dump_json(exclude_none=True),
}, },
) )
@@ -243,4 +252,5 @@ class OAuthManager:
redirect_url = f"{request.base_url}auth#token={jwt_token}" redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url) return RedirectResponse(url=redirect_url)
oauth_manager = OAuthManager() oauth_manager = OAuthManager()