Add hiding + re-ordering to personas

This commit is contained in:
Weves 2023-12-22 22:06:54 -08:00 committed by Chris Weaver
parent 8b7d01fb3b
commit d9fbd7ffe2
22 changed files with 840 additions and 88 deletions

View File

@ -0,0 +1,34 @@
"""Add is_visible to Persona
Revision ID: 891cd83c87a8
Revises: b156fa702355
Create Date: 2023-12-21 11:55:54.132279
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "891cd83c87a8"
down_revision = "b156fa702355"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"persona",
sa.Column("is_visible", sa.Boolean(), nullable=True),
)
op.execute("UPDATE persona SET is_visible = true")
op.alter_column("persona", "is_visible", nullable=False)
op.add_column(
"persona",
sa.Column("display_priority", sa.Integer(), nullable=True),
)
def downgrade() -> None:
op.drop_column("persona", "is_visible")
op.drop_column("persona", "display_priority")

View File

@ -303,14 +303,14 @@ def get_prompt_by_id(
def get_persona_by_id(
persona_id: int,
# if user_id is `None` assume the user is an admin or auth is disabled
user_id: UUID | None,
db_session: Session,
include_deleted: bool = False,
) -> Persona:
stmt = select(Persona).where(
Persona.id == persona_id,
or_(Persona.user_id == user_id, Persona.user_id.is_(None)),
)
stmt = select(Persona).where(Persona.id == persona_id)
if user_id is not None:
stmt = stmt.where(or_(Persona.user_id == user_id, Persona.user_id.is_(None)))
if not include_deleted:
stmt = stmt.where(Persona.deleted.is_(False))
@ -534,6 +534,34 @@ def mark_persona_as_deleted(
db_session.commit()
def update_persona_visibility(
persona_id: int,
is_visible: bool,
db_session: Session,
) -> None:
persona = get_persona_by_id(
persona_id=persona_id, user_id=None, db_session=db_session
)
persona.is_visible = is_visible
db_session.commit()
def update_all_personas_display_priority(
display_priority_map: dict[int, int],
db_session: Session,
) -> None:
"""Updates the display priority of all lives Personas"""
personas = get_personas(user_id=None, db_session=db_session)
available_persona_ids = {persona.id for persona in personas}
if available_persona_ids != set(display_priority_map.keys()):
raise ValueError("Invalid persona IDs provided")
for persona in personas:
persona.display_priority = display_priority_map[persona.id]
db_session.commit()
def get_prompts(
user_id: UUID | None,
db_session: Session,
@ -553,15 +581,16 @@ def get_prompts(
def get_personas(
# if user_id is `None` assume the user is an admin or auth is disabled
user_id: UUID | None,
db_session: Session,
include_default: bool = True,
include_slack_bot_personas: bool = False,
include_deleted: bool = False,
) -> Sequence[Persona]:
stmt = select(Persona).where(
or_(Persona.user_id == user_id, Persona.user_id.is_(None))
)
stmt = select(Persona)
if user_id is not None:
stmt = stmt.where(or_(Persona.user_id == user_id, Persona.user_id.is_(None)))
if not include_default:
stmt = stmt.where(Persona.default_persona.is_(False))

View File

@ -642,6 +642,12 @@ class Persona(Base):
# Default personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
# controls whether the persona is available to be selected by users
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
# controls the ordering of personas in the UI
# higher priority personas are displayed first, ties are resolved by the ID,
# where lower value IDs (e.g. created earlier) are displayed first
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
# These are only defaults, users can select from all if desired

View File

@ -1,6 +1,7 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
@ -11,6 +12,8 @@ from danswer.db.chat import get_persona_by_id
from danswer.db.chat import get_personas
from danswer.db.chat import get_prompts_by_ids
from danswer.db.chat import mark_persona_as_deleted
from danswer.db.chat import update_all_personas_display_priority
from danswer.db.chat import update_persona_visibility
from danswer.db.chat import upsert_persona
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.engine import get_session
@ -101,6 +104,41 @@ def update_persona(
)
class IsVisibleRequest(BaseModel):
is_visible: bool
@admin_router.patch("/{persona_id}/visible")
def patch_persona_visibility(
persona_id: int,
is_visible_request: IsVisibleRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_visibility(
persona_id=persona_id,
is_visible=is_visible_request.is_visible,
db_session=db_session,
)
class DisplayPriorityRequest(BaseModel):
# maps persona id to display priority
display_priority_map: dict[int, int]
@admin_router.put("/display-priority")
def patch_persona_display_priority(
display_priority_request: DisplayPriorityRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_all_personas_display_priority(
display_priority_map=display_priority_request.display_priority_map,
db_session=db_session,
)
@admin_router.delete("/{persona_id}")
def delete_persona(
persona_id: int,

View File

@ -23,6 +23,8 @@ class PersonaSnapshot(BaseModel):
id: int
name: str
shared: bool
is_visible: bool
display_priority: int | None
description: str
num_chunks: float | None
llm_relevance_filter: bool
@ -41,6 +43,8 @@ class PersonaSnapshot(BaseModel):
id=persona.id,
name=persona.name,
shared=persona.user_id is None,
is_visible=persona.is_visible,
display_priority=persona.display_priority,
description=persona.description,
num_chunks=persona.num_chunks,
llm_relevance_filter=persona.llm_relevance_filter,

65
web/package-lock.json generated
View File

@ -8,6 +8,9 @@
"name": "qa",
"version": "0.2-dev",
"dependencies": {
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/modifiers": "^7.0.0",
"@dnd-kit/sortable": "^8.0.0",
"@phosphor-icons/react": "^2.0.8",
"@tremor/react": "^3.9.2",
"@types/js-cookie": "^3.0.3",
@ -559,6 +562,68 @@
"node": ">=6.9.0"
}
},
"node_modules/@dnd-kit/accessibility": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/@dnd-kit/accessibility/-/accessibility-3.1.0.tgz",
"integrity": "sha512-ea7IkhKvlJUv9iSHJOnxinBcoOI3ppGnnL+VDJ75O45Nss6HtZd8IdN8touXPDtASfeI2T2LImb8VOZcL47wjQ==",
"dependencies": {
"tslib": "^2.0.0"
},
"peerDependencies": {
"react": ">=16.8.0"
}
},
"node_modules/@dnd-kit/core": {
"version": "6.1.0",
"resolved": "https://registry.npmjs.org/@dnd-kit/core/-/core-6.1.0.tgz",
"integrity": "sha512-J3cQBClB4TVxwGo3KEjssGEXNJqGVWx17aRTZ1ob0FliR5IjYgTxl5YJbKTzA6IzrtelotH19v6y7uoIRUZPSg==",
"dependencies": {
"@dnd-kit/accessibility": "^3.1.0",
"@dnd-kit/utilities": "^3.2.2",
"tslib": "^2.0.0"
},
"peerDependencies": {
"react": ">=16.8.0",
"react-dom": ">=16.8.0"
}
},
"node_modules/@dnd-kit/modifiers": {
"version": "7.0.0",
"resolved": "https://registry.npmjs.org/@dnd-kit/modifiers/-/modifiers-7.0.0.tgz",
"integrity": "sha512-BG/ETy3eBjFap7+zIti53f0PCLGDzNXyTmn6fSdrudORf+OH04MxrW4p5+mPu4mgMk9kM41iYONjc3DOUWTcfg==",
"dependencies": {
"@dnd-kit/utilities": "^3.2.2",
"tslib": "^2.0.0"
},
"peerDependencies": {
"@dnd-kit/core": "^6.1.0",
"react": ">=16.8.0"
}
},
"node_modules/@dnd-kit/sortable": {
"version": "8.0.0",
"resolved": "https://registry.npmjs.org/@dnd-kit/sortable/-/sortable-8.0.0.tgz",
"integrity": "sha512-U3jk5ebVXe1Lr7c2wU7SBZjcWdQP+j7peHJfCspnA81enlu88Mgd7CC8Q+pub9ubP7eKVETzJW+IBAhsqbSu/g==",
"dependencies": {
"@dnd-kit/utilities": "^3.2.2",
"tslib": "^2.0.0"
},
"peerDependencies": {
"@dnd-kit/core": "^6.1.0",
"react": ">=16.8.0"
}
},
"node_modules/@dnd-kit/utilities": {
"version": "3.2.2",
"resolved": "https://registry.npmjs.org/@dnd-kit/utilities/-/utilities-3.2.2.tgz",
"integrity": "sha512-+MKAJEOfaBe5SmV6t34p80MMKhjvUz0vRrvVJbPT0WElzaOJ/1xs+D+KDv+tD/NE5ujfrChEcshd4fLn0wpiqg==",
"dependencies": {
"tslib": "^2.0.0"
},
"peerDependencies": {
"react": ">=16.8.0"
}
},
"node_modules/@emotion/is-prop-valid": {
"version": "1.2.1",
"resolved": "https://registry.npmjs.org/@emotion/is-prop-valid/-/is-prop-valid-1.2.1.tgz",

View File

@ -9,6 +9,9 @@
"lint": "next lint"
},
"dependencies": {
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/modifiers": "^7.0.0",
"@dnd-kit/sortable": "^8.0.0",
"@phosphor-icons/react": "^2.0.8",
"@tremor/react": "^3.9.2",
"@types/js-cookie": "^3.0.3",

View File

@ -1,8 +1,6 @@
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { useState } from "react";
import { updateBoost } from "./lib";
import { CheckmarkIcon, EditIcon } from "@/components/icons/icons";
import { FiEdit } from "react-icons/fi";
import { EditableValue } from "@/components/EditableValue";
export const ScoreSection = ({
documentId,
@ -17,17 +15,14 @@ export const ScoreSection = ({
refresh: () => void;
consistentWidth?: boolean;
}) => {
const [isOpen, setIsOpen] = useState(false);
const [score, setScore] = useState(initialScore.toString());
const onSubmit = async () => {
const numericScore = Number(score);
const onSubmit = async (value: string) => {
const numericScore = Number(value);
if (isNaN(numericScore)) {
setPopup({
message: "Score must be a number",
type: "error",
});
return;
return false;
}
const errorMsg = await updateBoost(documentId, numericScore);
@ -36,55 +31,23 @@ export const ScoreSection = ({
message: errorMsg,
type: "error",
});
return false;
} else {
setPopup({
message: "Updated score!",
type: "success",
});
refresh();
setIsOpen(false);
}
return true;
};
if (isOpen) {
return (
<div className="my-auto h-full flex">
<input
value={score}
onChange={(e) => {
setScore(e.target.value);
}}
onKeyDown={(e) => {
if (e.key === "Enter") {
onSubmit();
}
if (e.key === "Escape") {
setIsOpen(false);
setScore(initialScore.toString());
}
}}
className="border bg-background-strong border-gray-300 rounded py-1 px-1 w-12 h-4 my-auto"
/>
<div onClick={onSubmit} className="cursor-pointer my-auto ml-2">
<CheckmarkIcon size={16} className="text-green-700" />
</div>
</div>
);
}
return (
<div className="h-full flex flex-col">
<div
className="flex my-auto cursor-pointer hover:bg-hover rounded"
onClick={() => setIsOpen(true)}
>
<div className={"flex " + (consistentWidth && " w-6")}>
<div className="ml-auto my-auto">{initialScore}</div>
</div>
<div className="cursor-pointer ml-2 my-auto h-4">
<FiEdit size={16} />
</div>
</div>
</div>
<EditableValue
initialValue={initialScore.toString()}
onSubmit={onSubmit}
consistentWidth={consistentWidth}
/>
);
};

View File

@ -1,44 +1,248 @@
"use client";
import {
Table,
TableHead,
TableRow,
TableHeaderCell,
TableBody,
TableCell,
} from "@tremor/react";
import { Divider, Text } from "@tremor/react";
import { Persona } from "./interfaces";
import { EditButton } from "@/components/EditButton";
import { useRouter } from "next/navigation";
import { FiInfo } from "react-icons/fi";
import { CustomCheckbox } from "@/components/CustomCheckbox";
import { usePopup } from "@/components/admin/connectors/Popup";
import { useState } from "react";
import { UniqueIdentifier } from "@dnd-kit/core";
import { DraggableTable } from "@/components/table/DraggableTable";
import { personaComparator } from "./lib";
export function PersonasTable({ personas }: { personas: Persona[] }) {
const router = useRouter();
const { popup, setPopup } = usePopup();
const sortedPersonas = [...personas];
sortedPersonas.sort((a, b) => (a.id > b.id ? 1 : -1));
sortedPersonas.sort(personaComparator);
const [finalPersonas, setFinalPersonas] = useState<UniqueIdentifier[]>(
sortedPersonas.map((persona) => persona.id.toString())
);
const finalPersonaValues = finalPersonas.map((id) => {
return sortedPersonas.find(
(persona) => persona.id.toString() === id
) as Persona;
});
const updatePersonaOrder = async (orderedPersonaIds: UniqueIdentifier[]) => {
setFinalPersonas(orderedPersonaIds);
const displayPriorityMap = new Map<UniqueIdentifier, number>();
orderedPersonaIds.forEach((personaId, ind) => {
displayPriorityMap.set(personaId, ind);
});
const response = await fetch("/api/admin/persona/display-priority", {
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
display_priority_map: Object.fromEntries(displayPriorityMap),
}),
});
if (!response.ok) {
setPopup({
type: "error",
message: `Failed to update persona order - ${await response.text()}`,
});
router.refresh();
}
};
return (
<div>
<Table className="overflow-visible">
<TableHead>
<TableRow>
<TableHeaderCell>Name</TableHeaderCell>
<TableHeaderCell>Description</TableHeaderCell>
<TableHeaderCell>Built-In</TableHeaderCell>
<TableHeaderCell></TableHeaderCell>
</TableRow>
</TableHead>
<TableBody>
{popup}
<Text className="my-2">
Personas will be displayed as options on the Chat / Search interfaces in
the order they are displayed below. Personas marked as hidden will not
be displayed.
</Text>
<DraggableTable
headers={["Name", "Description", "Built-In", "Is Visible", ""]}
rows={finalPersonaValues.map((persona) => {
return {
id: persona.id.toString(),
cells: [
<p
key="name"
className="text font-medium whitespace-normal break-none"
>
{persona.name}
</p>,
<p
key="description"
className="whitespace-normal break-all max-w-2xl"
>
{persona.description}
</p>,
persona.default_persona ? "Yes" : "No",
<div
key="is_visible"
onClick={async () => {
const response = await fetch(
`/api/admin/persona/${persona.id}/visible`,
{
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
is_visible: !persona.is_visible,
}),
}
);
if (response.ok) {
router.refresh();
} else {
setPopup({
type: "error",
message: `Failed to update persona - ${await response.text()}`,
});
}
}}
className="px-1 py-0.5 hover:bg-hover-light rounded flex cursor-pointer select-none w-fit"
>
<div className="my-auto w-12">
{!persona.is_visible ? (
<div className="text-error">Hidden</div>
) : (
"Visible"
)}
</div>
<div className="ml-1 my-auto">
<CustomCheckbox checked={persona.is_visible} />
</div>
</div>,
<div key="edit" className="flex">
<div className="mx-auto">
{!persona.default_persona ? (
<EditButton
onClick={() =>
router.push(`/admin/personas/${persona.id}`)
}
/>
) : (
"-"
)}
</div>
</div>,
],
staticModifiers: [[1, "lg:w-[300px] xl:w-[400px] 2xl:w-[550px]"]],
};
})}
setRows={updatePersonaOrder}
/>
<Divider />
{/* <TableBody>
{sortedPersonas.map((persona) => {
return (
<TableRow key={persona.id}>
<TableCell className="whitespace-normal break-all">
<DraggableRow key={persona.id}>
<TableCell className="whitespace-normal break-none">
<p className="text font-medium">{persona.name}</p>
</TableCell>
<TableCell>{persona.description}</TableCell>
<TableCell className="whitespace-normal break-all max-w-2xl">
{persona.description}
</TableCell>
<TableCell>{persona.default_persona ? "Yes" : "No"}</TableCell>
<TableCell>
{" "}
<div
onClick={async () => {
const response = await fetch(
`/api/admin/persona/${persona.id}/visible`,
{
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
is_visible: !persona.is_visible,
}),
}
);
if (response.ok) {
router.refresh();
} else {
setPopup({
type: "error",
message: `Failed to update persona - ${await response.text()}`,
});
}
}}
className="px-1 py-0.5 hover:bg-hover-light rounded flex cursor-pointer select-none w-fit"
>
<div className="my-auto w-12">
{!persona.is_visible ? (
<div className="text-error">Hidden</div>
) : (
"Visible"
)}
</div>
<div className="ml-1 my-auto">
<CustomCheckbox checked={persona.is_visible} />
</div>
</div>
</TableCell>
<TableCell>
{persona.is_visible ? (
<EditableValue
emptyDisplay="-"
initialValue={
persona.display_priority !== null
? persona.display_priority.toString()
: ""
}
onSubmit={async (value) => {
if (
value === (persona.display_priority || "").toString()
) {
return true;
}
const numericDisplayPriority = Number(value);
if (isNaN(numericDisplayPriority)) {
setPopup({
message: "Display priority must be a number",
type: "error",
});
return false;
}
const response = await fetch(
`/api/admin/persona/${persona.id}/display-priority`,
{
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
display_priority: numericDisplayPriority,
}),
}
);
if (!response.ok) {
setPopup({
message: `Failed to update display priority - ${await response.text()}`,
type: "error",
});
}
router.refresh();
return true;
}}
/>
) : (
"-"
)}
</TableCell>
<TableCell>
<div className="flex">
<div className="mx-auto">
@ -54,11 +258,10 @@ export function PersonasTable({ personas }: { personas: Persona[] }) {
</div>
</div>
</TableCell>
</TableRow>
</DraggableRow>
);
})}
</TableBody>
</Table>
</TableBody> */}
</div>
);
}

View File

@ -16,6 +16,8 @@ export interface Persona {
id: number;
name: string;
shared: boolean;
is_visible: boolean;
display_priority: number | null;
description: string;
document_sets: DocumentSet[];
prompts: Prompt[];

View File

@ -1,4 +1,4 @@
import { Prompt } from "./interfaces";
import { Persona, Prompt } from "./interfaces";
interface PersonaCreationRequest {
name: string;
@ -198,3 +198,26 @@ export function buildFinalPrompt(
return fetch(`/api/persona/utils/prompt-explorer?${queryString}`);
}
function smallerNumberFirstComparator(a: number, b: number) {
return a > b ? 1 : -1;
}
export function personaComparator(a: Persona, b: Persona) {
if (a.display_priority === null && b.display_priority === null) {
return smallerNumberFirstComparator(a.id, b.id);
}
if (a.display_priority !== b.display_priority) {
if (a.display_priority === null) {
return 1;
}
if (b.display_priority === null) {
return -1;
}
return smallerNumberFirstComparator(a.display_priority, b.display_priority);
}
return smallerNumberFirstComparator(a.id, b.id);
}

View File

@ -90,7 +90,7 @@ export const Chat = ({
? availablePersonas.find(
(persona) => persona.id === existingChatSessionPersonaId
)
: availablePersonas.find((persona) => persona.name === "Default")
: availablePersonas[0]
);
const filterManager = useFilters();

View File

@ -17,6 +17,7 @@ import { WelcomeModal } from "@/components/WelcomeModal";
import { ApiKeyModal } from "@/components/openai/ApiKeyModal";
import { cookies } from "next/headers";
import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/contants";
import { personaComparator } from "../admin/personas/lib";
export default async function ChatPage({
chatId,
@ -112,6 +113,10 @@ export default async function ChatPage({
} else {
console.log(`Failed to fetch personas - ${personasResponse?.status}`);
}
// remove those marked as hidden by an admin
personas = personas.filter((persona) => persona.is_visible);
// sort them in priority order
personas.sort(personaComparator);
let messages: Message[] = [];
if (chatSessionMessagesResponse?.ok) {

View File

@ -12,6 +12,7 @@ import { Persona } from "../admin/personas/interfaces";
import { WelcomeModal } from "@/components/WelcomeModal";
import { unstable_noStore as noStore } from "next/cache";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import { personaComparator } from "../admin/personas/lib";
export default async function Home() {
// Disable caching so we always get the up to date connector / document set / persona info
@ -68,6 +69,10 @@ export default async function Home() {
} else {
console.log(`Failed to fetch personas - ${personaResponse?.status}`);
}
// remove those marked as hidden by an admin
personas = personas.filter((persona) => persona.is_visible);
// sort them in priority order
personas.sort(personaComparator);
// needs to be done in a non-client side component due to nextjs
const storedSearchType = cookies().get("searchType")?.value as

View File

@ -0,0 +1,73 @@
"use client";
import { useState } from "react";
import { FiEdit } from "react-icons/fi";
import { CheckmarkIcon } from "./icons/icons";
export function EditableValue({
initialValue,
onSubmit,
emptyDisplay,
consistentWidth = true,
}: {
initialValue: string;
onSubmit: (value: string) => Promise<boolean>;
emptyDisplay?: string;
consistentWidth?: boolean;
}) {
const [isOpen, setIsOpen] = useState(false);
const [editedValue, setEditedValue] = useState(initialValue);
if (isOpen) {
return (
<div className="my-auto h-full flex">
<input
value={editedValue}
onChange={(e) => {
setEditedValue(e.target.value);
}}
onKeyDown={async (e) => {
if (e.key === "Enter") {
const success = await onSubmit(editedValue);
if (success) {
setIsOpen(false);
}
}
if (e.key === "Escape") {
setIsOpen(false);
onSubmit(initialValue);
}
}}
className="border bg-background-strong border-gray-300 rounded py-1 px-1 w-12 h-4 my-auto"
/>
<div
onClick={async () => {
const success = await onSubmit(editedValue);
if (success) {
setIsOpen(false);
}
}}
className="cursor-pointer my-auto ml-2"
>
<CheckmarkIcon size={16} className="text-green-700" />
</div>
</div>
);
}
return (
<div className="h-full flex flex-col">
<div
className="flex my-auto cursor-pointer hover:bg-hover rounded"
onClick={() => setIsOpen(true)}
>
<div className={"flex " + (consistentWidth && " w-6")}>
<div className="ml-auto my-auto">{initialValue || emptyDisplay}</div>
</div>
<div className="cursor-pointer ml-2 my-auto h-4">
<FiEdit size={16} />
</div>
</div>
</div>
);
}

View File

@ -66,11 +66,8 @@ export const SearchSection = ({
const [selectedSearchType, setSelectedSearchType] =
useState<SearchType>(defaultSearchType);
const defaultPersona = personas.find(
(persona) => persona.name === "Danswer" && persona.default_persona
);
const [selectedPersona, setSelectedPersona] = useState<number>(
defaultPersona?.id || 0
personas[0]?.id || 0
);
// Overrides for default behavior that only last a single query

View File

@ -0,0 +1,15 @@
import React from "react";
import { MdDragIndicator } from "react-icons/md";
export const DragHandle = (props: any) => {
return (
<div
className={
props.isDragging ? "hover:cursor-grabbing" : "hover:cursor-grab"
}
{...props}
>
<MdDragIndicator />
</div>
);
};

View File

@ -0,0 +1,47 @@
import { useSortable } from "@dnd-kit/sortable";
import { TableCell, TableRow } from "@tremor/react";
import { CSS } from "@dnd-kit/utilities";
import { DragHandle } from "./DragHandle";
import { Row } from "./interfaces";
export function DraggableRow({
row,
forceDragging,
}: {
row: Row;
forceDragging?: boolean;
}) {
const {
attributes,
listeners,
transform,
transition,
setNodeRef,
isDragging,
} = useSortable({
id: row.id,
});
const style = {
transform: CSS.Transform.toString(transform),
transition: transition,
};
return (
<TableRow
ref={setNodeRef}
style={style}
className={isDragging ? "invisible" : "bg-background"}
>
<TableCell>
<DragHandle
isDragging={isDragging || forceDragging}
{...attributes}
{...listeners}
/>
</TableCell>
{row.cells.map((column, ind) => (
<TableCell key={ind}>{column}</TableCell>
))}
</TableRow>
);
}

View File

@ -0,0 +1,117 @@
import {
Table,
TableHead,
TableRow,
TableHeaderCell,
TableBody,
TableCell,
} from "@tremor/react";
import { DraggableTableBody } from "./DraggableTableBody";
import React, { useMemo, useState } from "react";
import {
closestCenter,
DndContext,
DragEndEvent,
DragOverlay,
DragStartEvent,
KeyboardSensor,
MouseSensor,
TouchSensor,
UniqueIdentifier,
useSensor,
useSensors,
} from "@dnd-kit/core";
import { restrictToVerticalAxis } from "@dnd-kit/modifiers";
import {
arrayMove,
SortableContext,
verticalListSortingStrategy,
} from "@dnd-kit/sortable";
import { DraggableRow } from "./DraggableRow";
import { Row } from "./interfaces";
import { StaticRow } from "./StaticRow";
export function DraggableTable({
headers,
rows,
setRows,
}: {
headers: (string | JSX.Element | null)[];
rows: Row[];
setRows: (newRows: UniqueIdentifier[]) => void | Promise<void>;
}) {
const [activeId, setActiveId] = useState<UniqueIdentifier | null>();
const items = useMemo(() => rows?.map(({ id }) => id), [rows]);
const sensors = useSensors(
useSensor(MouseSensor, {}),
useSensor(TouchSensor, {}),
useSensor(KeyboardSensor, {})
);
function handleDragStart(event: DragStartEvent) {
setActiveId(event.active.id);
}
function handleDragEnd(event: DragEndEvent) {
const { active, over } = event;
if (over !== null && active.id !== over.id) {
const oldIndex = items.indexOf(active.id);
const newIndex = items.indexOf(over.id);
setRows(arrayMove(rows, oldIndex, newIndex).map((row) => row.id));
}
setActiveId(null);
}
function handleDragCancel() {
setActiveId(null);
}
const selectedRow = useMemo(() => {
if (activeId === null || activeId === undefined) {
return null;
}
const row = rows.find(({ id }) => id === activeId);
return row;
}, [activeId, rows]);
return (
<DndContext
sensors={sensors}
onDragEnd={handleDragEnd}
onDragStart={handleDragStart}
onDragCancel={handleDragCancel}
collisionDetection={closestCenter}
modifiers={[restrictToVerticalAxis]}
>
<Table className="overflow-y-visible">
<TableHead>
<TableRow>
<TableHeaderCell></TableHeaderCell>
{headers.map((header, ind) => (
<TableHeaderCell key={ind}>{header}</TableHeaderCell>
))}
</TableRow>
</TableHead>
<TableBody>
<SortableContext items={items} strategy={verticalListSortingStrategy}>
{rows.map((row) => {
return <DraggableRow key={row.id} row={row} />;
})}
</SortableContext>
<DragOverlay>
{selectedRow && (
<Table className="overflow-y-visible">
<TableBody>
<StaticRow key={selectedRow.id} row={selectedRow} />
</TableBody>
</Table>
)}
</DragOverlay>
</TableBody>
</Table>
</DndContext>
);
}

View File

@ -0,0 +1,93 @@
import React, { useMemo, useState } from "react";
import {
closestCenter,
DndContext,
DragEndEvent,
DragOverlay,
DragStartEvent,
KeyboardSensor,
MouseSensor,
TouchSensor,
UniqueIdentifier,
useSensor,
useSensors,
} from "@dnd-kit/core";
import { restrictToVerticalAxis } from "@dnd-kit/modifiers";
import {
arrayMove,
SortableContext,
verticalListSortingStrategy,
} from "@dnd-kit/sortable";
import { TableBody } from "@tremor/react";
import { DraggableRow } from "./DraggableRow";
import { Row } from "./interfaces";
export function DraggableTableBody({
rows,
setRows,
}: {
rows: Row[];
setRows: React.Dispatch<React.SetStateAction<UniqueIdentifier[]>>;
}) {
const [activeId, setActiveId] = useState<UniqueIdentifier | null>();
const items = useMemo(() => rows?.map(({ id }) => id), [rows]);
const sensors = useSensors(
useSensor(MouseSensor, {}),
useSensor(TouchSensor, {}),
useSensor(KeyboardSensor, {})
);
function handleDragStart(event: DragStartEvent) {
setActiveId(event.active.id);
}
function handleDragEnd(event: DragEndEvent) {
const { active, over } = event;
if (over !== null && active.id !== over.id) {
setRows((oldRows) => {
const oldIndex = items.indexOf(active.id);
const newIndex = items.indexOf(over.id);
return arrayMove(oldRows, oldIndex, newIndex);
});
}
setActiveId(null);
}
function handleDragCancel() {
setActiveId(null);
}
const selectedRow = useMemo(() => {
if (activeId === null || activeId === undefined) {
return null;
}
const row = rows.find(({ id }) => id === activeId);
return row;
}, [activeId, rows]);
// Render the UI for your table
return (
<DndContext
sensors={sensors}
onDragEnd={handleDragEnd}
onDragStart={handleDragStart}
onDragCancel={handleDragCancel}
collisionDetection={closestCenter}
modifiers={[restrictToVerticalAxis]}
>
<TableBody>
<SortableContext items={items} strategy={verticalListSortingStrategy}>
{rows.map((row) => {
return <DraggableRow key={row.id} row={row} />;
})}
</SortableContext>
<DragOverlay>
{selectedRow && (
<DraggableRow key={selectedRow.id} row={selectedRow} />
)}
</DragOverlay>
</TableBody>
</DndContext>
);
}

View File

@ -0,0 +1,23 @@
import { TableCell, TableRow } from "@tremor/react";
import { DragHandle } from "./DragHandle";
import { Row } from "./interfaces";
export function StaticRow({ row }: { row: Row }) {
return (
<TableRow className="bg-background border-b border-border">
<TableCell>
<DragHandle isDragging />
</TableCell>
{row.cells.map((column, ind) => {
const rowModifier =
row.staticModifiers &&
row.staticModifiers.find((mod) => mod[0] === ind);
return (
<TableCell key={ind} className={rowModifier && rowModifier[1]}>
{column}
</TableCell>
);
})}
</TableRow>
);
}

View File

@ -0,0 +1,7 @@
import { UniqueIdentifier } from "@dnd-kit/core";
export interface Row {
id: UniqueIdentifier;
cells: (JSX.Element | string)[];
staticModifiers?: [number, string][];
}