enh: access control

This commit is contained in:
Timothy Jaeryang Baek 2024-11-16 17:09:15 -08:00
parent 227cca35e8
commit 73fe77c2da
9 changed files with 304 additions and 277 deletions

View File

@ -7,6 +7,8 @@ from open_webui.apps.webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
####################
# Prompts DB Schema
####################
@ -107,57 +109,11 @@ class PromptsTable:
) -> list[PromptModel]:
prompts = self.get_prompts()
groups = Groups.get_groups_by_member_id(user_id)
group_ids = [group.id for group in groups]
if permission == "write":
return [
prompt
for prompt in prompts
if prompt.user_id == user_id
or (
prompt.access_control
and (
any(
group_id
in prompt.access_control.get(permission, {}).get(
"group_ids", []
)
for group_id in group_ids
)
or (
user_id
in prompt.access_control.get(permission, {}).get(
"user_ids", []
)
)
)
)
]
elif permission == "read":
return [
prompt
for prompt in prompts
if prompt.user_id == user_id
or prompt.access_control is None
or (
prompt.access_control
and (
any(
prompt.access_control.get(permission, {}).get(
"group_ids", []
)
in group_id
for group_id in group_ids
)
or (
user_id
in prompt.access_control.get(permission, {}).get(
"user_ids", []
)
)
)
)
or has_access(user_id, permission, prompt.access_control)
]
def update_prompt_by_command(

View File

@ -8,6 +8,9 @@ from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -133,6 +136,18 @@ class ToolsTable:
with get_db() as db:
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
def get_tools_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ToolModel]:
tools = self.get_tools()
return [
tool
for tool in tools
if tool.user_id == user_id
or has_access(tool.access_control, user_id, permission)
]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try:
with get_db() as db:

View File

@ -14,7 +14,22 @@ router = APIRouter()
@router.get("/", response_model=list[PromptModel])
async def get_prompts(user=Depends(get_verified_user)):
return Prompts.get_prompts()
if user.role == "admin":
prompts = Prompts.get_prompts()
else:
prompts = Prompts.get_prompts_by_user_id(user.id, "read")
return prompts
@router.get("/list", response_model=list[PromptModel])
async def get_prompt_list(user=Depends(get_verified_user)):
if user.role == "admin":
prompts = Prompts.get_prompts()
else:
prompts = Prompts.get_prompts_by_user_id(user.id, "write")
return prompts
############################
@ -23,7 +38,7 @@ async def get_prompts(user=Depends(get_verified_user)):
@router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
async def create_new_prompt(form_data: PromptForm, user=Depends(get_verified_user)):
prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt is None:
prompt = Prompts.insert_new_prompt(user.id, form_data)
@ -67,7 +82,7 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
async def update_prompt_by_command(
command: str,
form_data: PromptForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt:
@ -85,6 +100,6 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
result = Prompts.delete_prompt_by_command(f"/{command}")
return result

View File

@ -14,37 +14,54 @@ from open_webui.utils.utils import get_admin_user, get_verified_user
router = APIRouter()
############################
# GetToolkits
# GetTools
############################
@router.get("/", response_model=list[ToolResponse])
async def get_toolkits(user=Depends(get_verified_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
async def get_tools(user=Depends(get_verified_user)):
if user.role == "admin":
tools = Tools.get_tools()
else:
tools = Tools.get_tools_by_user_id(user.id, "read")
return tools
############################
# ExportToolKits
# GetToolList
############################
@router.get("/list", response_model=list[ToolResponse])
async def get_tool_list(user=Depends(get_verified_user)):
if user.role == "admin":
tools = Tools.get_tools()
else:
tools = Tools.get_tools_by_user_id(user.id, "write")
return tools
############################
# ExportTools
############################
@router.get("/export", response_model=list[ToolModel])
async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
async def export_tools(user=Depends(get_admin_user)):
tools = Tools.get_tools()
return tools
############################
# CreateNewToolKit
# CreateNewTools
############################
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
async def create_new_tools(
request: Request,
form_data: ToolForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
if not form_data.id.isidentifier():
raise HTTPException(
@ -93,12 +110,12 @@ async def create_new_toolkit(
############################
# GetToolkitById
# GetToolsById
############################
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
@ -111,16 +128,16 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
############################
# UpdateToolkitById
# UpdateToolsById
############################
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
async def update_tools_by_id(
request: Request,
id: str,
form_data: ToolForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
try:
form_data.content = replace_imports(form_data.content)
@ -158,12 +175,14 @@ async def update_toolkit_by_id(
############################
# DeleteToolkitById
# DeleteToolsById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
async def delete_tools_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
result = Tools.delete_tool_by_id(id)
if result:
@ -180,7 +199,7 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin
@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
try:
@ -204,8 +223,8 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_toolkit_valves_spec_by_id(
request: Request, id: str, user=Depends(get_admin_user)
async def get_tools_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
@ -232,8 +251,8 @@ async def get_toolkit_valves_spec_by_id(
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_toolkit_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
async def update_tools_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
@ -276,7 +295,7 @@ async def update_toolkit_valves_by_id(
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
try:
@ -295,7 +314,7 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_toolkit_user_valves_spec_by_id(
async def get_tools_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
@ -318,7 +337,7 @@ async def get_toolkit_user_valves_spec_by_id(
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_toolkit_user_valves_by_id(
async def update_tools_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)

View File

@ -69,6 +69,39 @@ export const getPrompts = async (token: string = '') => {
return res;
};
export const getPromptList = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/list`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getPromptByCommand = async (token: string, command: string) => {
let error = null;

View File

@ -3,11 +3,17 @@
import fileSaver from 'file-saver';
const { saveAs } = fileSaver;
import { onMount, getContext } from 'svelte';
import { WEBUI_NAME, config, prompts } from '$lib/stores';
import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts';
import { error } from '@sveltejs/kit';
import { goto } from '$app/navigation';
import { onMount, getContext } from 'svelte';
import { WEBUI_NAME, config, prompts as _prompts, user } from '$lib/stores';
import {
createNewPrompt,
deletePromptByCommand,
getPrompts,
getPromptList
} from '$lib/apis/prompts';
import PromptMenu from './Prompts/PromptMenu.svelte';
import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte';
import DeleteConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
@ -16,16 +22,18 @@
import ChevronRight from '../icons/ChevronRight.svelte';
const i18n = getContext('i18n');
let promptsImportInputElement: HTMLInputElement;
let importFiles = '';
let query = '';
let promptsImportInputElement: HTMLInputElement;
let prompts = [];
let showDeleteConfirm = false;
let deletePrompt = null;
let filteredItems = [];
$: filteredItems = $prompts.filter((p) => query === '' || p.command.includes(query));
$: filteredItems = prompts.filter((p) => query === '' || p.command.includes(query));
const shareHandler = async (prompt) => {
toast.success($i18n.t('Redirecting you to OpenWebUI Community'));
@ -60,8 +68,17 @@
const deleteHandler = async (prompt) => {
const command = prompt.command;
await deletePromptByCommand(localStorage.token, command);
await prompts.set(await getPrompts(localStorage.token));
await init();
};
const init = async () => {
prompts = await getPromptList(localStorage.token);
await _prompts.set(await getPrompts(localStorage.token));
};
onMount(async () => {
await init();
});
</script>
<svelte:head>
@ -181,7 +198,8 @@
{/each}
</div>
<div class=" flex justify-end w-full mb-3">
{#if $user?.role === 'admin'}
<div class=" flex justify-end w-full mb-3">
<div class="flex space-x-2">
<input
id="prompts-import-input"
@ -210,7 +228,8 @@
});
}
await prompts.set(await getPrompts(localStorage.token));
prompts = await getPromptList(localStorage.token);
await _prompts.set(await getPrompts(localStorage.token));
};
reader.readAsText(importFiles[0]);
@ -245,7 +264,7 @@
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
on:click={async () => {
// promptsImportInputElement.click();
let blob = new Blob([JSON.stringify($prompts)], {
let blob = new Blob([JSON.stringify(prompts)], {
type: 'application/json'
});
saveAs(blob, `prompts-export-${Date.now()}.json`);
@ -268,16 +287,9 @@
</svg>
</div>
</button>
<!-- <button
on:click={() => {
loadDefaultPrompts();
}}
>
dd
</button> -->
</div>
</div>
</div>
{/if}
{#if $config?.features.enable_community_sharing}
<div class=" my-16">

View File

@ -4,7 +4,7 @@
const { saveAs } = fileSaver;
import { onMount, getContext } from 'svelte';
import { WEBUI_NAME, config, prompts, tools } from '$lib/stores';
import { WEBUI_NAME, config, prompts, tools as _tools, user } from '$lib/stores';
import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts';
import { goto } from '$app/navigation';
@ -45,8 +45,9 @@
let showDeleteConfirm = false;
let tools = [];
let filteredItems = [];
$: filteredItems = $tools.filter(
$: filteredItems = tools.filter(
(t) =>
query === '' ||
t.name.toLowerCase().includes(query.toLowerCase()) ||
@ -118,7 +119,7 @@
if (res) {
toast.success($i18n.t('Tool deleted successfully'));
tools.set(await getTools(localStorage.token));
_tools.set(await getTools(localStorage.token));
}
};
@ -324,7 +325,8 @@
{/each}
</div>
<div class=" flex justify-end w-full mb-2">
{#if $user?.role === 'admin'}
<div class=" flex justify-end w-full mb-2">
<div class="flex space-x-2">
<input
id="documents-import-input"
@ -397,7 +399,8 @@
</div>
</button>
</div>
</div>
</div>
{/if}
{#if $config?.features.enable_community_sharing}
<div class=" my-16">

View File

@ -1,19 +1,5 @@
<script>
import { onMount } from 'svelte';
import { prompts } from '$lib/stores';
import { getPrompts } from '$lib/apis/prompts';
import Prompts from '$lib/components/workspace/Prompts.svelte';
onMount(async () => {
await Promise.all([
(async () => {
prompts.set(await getPrompts(localStorage.token));
})()
]);
});
</script>
{#if $prompts !== null}
<Prompts />
{/if}
<Prompts />

View File

@ -1,19 +1,7 @@
<script>
import { onMount } from 'svelte';
import { tools } from '$lib/stores';
import { getTools } from '$lib/apis/tools';
import Tools from '$lib/components/workspace/Tools.svelte';
onMount(async () => {
await Promise.all([
(async () => {
tools.set(await getTools(localStorage.token));
})()
]);
});
</script>
{#if $tools !== null}
<Tools />
{/if}
<Tools />