Adds runtime validation for operation inputs.

This commit is contained in:
Mihovil Ilakovac
2025-02-20 22:22:38 +01:00
parent 40f6b72290
commit ae3782bd15
10 changed files with 233 additions and 107 deletions

View File

@@ -1,3 +1,4 @@
import * as z from 'zod';
import type { Task, GptResponse } from 'wasp/entities';
import type {
GenerateGptResponse,
@@ -10,6 +11,7 @@ import type {
import { HttpError } from 'wasp/server';
import { GeneratedSchedule } from './schedule';
import OpenAI from 'openai';
import { ensureArgsSchemaOrThrowHttpError } from '../server/validation';
const openai = setupOpenAI();
function setupOpenAI() {
@@ -20,15 +22,23 @@ function setupOpenAI() {
}
//#region Actions
type GptPayload = {
hours: string;
};
export const generateGptResponse: GenerateGptResponse<GptPayload, GeneratedSchedule> = async ({ hours }, context) => {
const generateGptResponseInputSchema = z.object({
hours: z.string().regex(/^\d+(\.\d+)?$/, 'Hours must be a number'),
});
type GenerateGptResponseInput = z.infer<typeof generateGptResponseInputSchema>;
export const generateGptResponse: GenerateGptResponse<GenerateGptResponseInput, GeneratedSchedule> = async (
rawArgs: unknown,
context
) => {
if (!context.user) {
throw new HttpError(401);
}
const args = ensureArgsSchemaOrThrowHttpError(generateGptResponseInputSchema, rawArgs);
const tasks = await context.entities.Task.findMany({
where: {
user: {
@@ -79,7 +89,9 @@ export const generateGptResponse: GenerateGptResponse<GptPayload, GeneratedSched
},
{
role: 'user',
content: `I will work ${hours} hours today. Here are the tasks I have to complete: ${JSON.stringify(
content: `I will work ${
args.hours
} hours today. Here are the tasks I have to complete: ${JSON.stringify(
parsedTasks
)}. Please help me plan my day by breaking the tasks down into actionable subtasks with time and priority status.`,
},
@@ -181,14 +193,22 @@ export const generateGptResponse: GenerateGptResponse<GptPayload, GeneratedSched
}
};
export const createTask: CreateTask<Pick<Task, 'description'>, Task> = async ({ description }, context) => {
const createTaskInputSchema = z.object({
description: z.string().nonempty(),
});
type CreateTaskInput = z.infer<typeof createTaskInputSchema>;
export const createTask: CreateTask<CreateTaskInput, Task> = async (rawArgs: unknown, context) => {
if (!context.user) {
throw new HttpError(401);
}
const args = ensureArgsSchemaOrThrowHttpError(createTaskInputSchema, rawArgs);
const task = await context.entities.Task.create({
data: {
description,
description: args.description,
user: { connect: { id: context.user.id } },
},
});
@@ -196,32 +216,50 @@ export const createTask: CreateTask<Pick<Task, 'description'>, Task> = async ({
return task;
};
export const updateTask: UpdateTask<Partial<Task>, Task> = async ({ id, isDone, time }, context) => {
const updateTaskInputSchema = z.object({
id: z.string().nonempty(),
isDone: z.boolean().optional(),
time: z.string().optional(),
});
type UpdateTaskInput = z.infer<typeof updateTaskInputSchema>;
export const updateTask: UpdateTask<UpdateTaskInput, Task> = async (rawArgs: unknown, context) => {
if (!context.user) {
throw new HttpError(401);
}
const args = ensureArgsSchemaOrThrowHttpError(updateTaskInputSchema, rawArgs);
const task = await context.entities.Task.update({
where: {
id,
id: args.id,
},
data: {
isDone,
time,
isDone: args.isDone,
time: args.time,
},
});
return task;
};
export const deleteTask: DeleteTask<Pick<Task, 'id'>, Task> = async ({ id }, context) => {
const deleteTaskInputSchema = z.object({
id: z.string().nonempty(),
});
type DeleteTaskInput = z.infer<typeof deleteTaskInputSchema>;
export const deleteTask: DeleteTask<DeleteTaskInput, Task> = async (rawArgs: unknown, context) => {
if (!context.user) {
throw new HttpError(401);
}
const args = ensureArgsSchemaOrThrowHttpError(deleteTaskInputSchema, rawArgs);
const task = await context.entities.Task.delete({
where: {
id,
id: args.id,
},
});

View File

@@ -2,7 +2,8 @@ import { cn } from '../client/cn';
import { useState, useEffect, FormEvent } from 'react';
import type { File } from 'wasp/entities';
import { useQuery, getAllFilesByUser, getDownloadFileSignedURL } from 'wasp/client/operations';
import { type FileUploadError, uploadFileWithProgress, validateFile, ALLOWED_FILE_TYPES } from './fileUploading';
import { type FileUploadError, parseValidFile, uploadFileWithProgress } from './fileUploading';
import { ALLOWED_FILE_TYPES } from './validation';
export default function FileUploadPage() {
const [fileKeyForS3, setFileKeyForS3] = useState<File['key']>('');
@@ -64,13 +65,12 @@ export default function FileUploadPage() {
return;
}
const validationError = validateFile(file);
if (validationError) {
setUploadError(validationError);
const validFileResult = parseValidFile(file);
if (validFileResult.kind === 'error') {
setUploadError(validFileResult.error);
return;
}
await uploadFileWithProgress({ file, setUploadProgressPercent });
await uploadFileWithProgress({ file: validFileResult.file, setUploadProgressPercent });
formElement.reset();
allUserFiles.refetch();
} catch (error) {
@@ -117,11 +117,11 @@ export default function FileUploadPage() {
<>
<span>Uploading {uploadProgressPercent}%</span>
<div
role="progressbar"
role='progressbar'
aria-valuenow={uploadProgressPercent}
aria-valuemin={0}
aria-valuemax={100}
className="absolute bottom-0 left-0 h-1 bg-yellow-500 transition-all duration-300 ease-in-out rounded-b-md"
className='absolute bottom-0 left-0 h-1 bg-yellow-500 transition-all duration-300 ease-in-out rounded-b-md'
style={{ width: `${uploadProgressPercent}%` }}
></div>
</>

View File

@@ -1,30 +1,16 @@
import { Dispatch, SetStateAction } from 'react';
import { createFile } from 'wasp/client/operations';
import axios from 'axios';
import { ALLOWED_FILE_TYPES, MAX_FILE_SIZE } from './validation';
interface FileUploadProgress {
file: File;
file: FileWithValidType;
setUploadProgressPercent: Dispatch<SetStateAction<number>>;
}
export interface FileUploadError {
message: string;
code: 'NO_FILE' | 'INVALID_FILE_TYPE' | 'FILE_TOO_LARGE' | 'UPLOAD_FAILED';
}
export const MAX_FILE_SIZE = 5 * 1024 * 1024; // Set this to the max file size you want to allow (currently 5MB).
export const ALLOWED_FILE_TYPES = [
'image/jpeg',
'image/png',
'application/pdf',
'text/*',
'video/quicktime',
'video/mp4',
];
export async function uploadFileWithProgress({ file, setUploadProgressPercent }: FileUploadProgress) {
const { uploadUrl } = await createFile({ fileType: file.type, name: file.name });
return await axios.put(uploadUrl, file, {
const { uploadUrl } = await createFile({ fileType: file.type, fileName: file.name });
return axios.put(uploadUrl, file, {
headers: {
'Content-Type': file.type,
},
@@ -37,18 +23,48 @@ export async function uploadFileWithProgress({ file, setUploadProgressPercent }:
});
}
export function validateFile(file: File): FileUploadError | null {
export interface FileUploadError {
message: string;
code: 'NO_FILE' | 'INVALID_FILE_TYPE' | 'FILE_TOO_LARGE' | 'UPLOAD_FAILED';
}
type AllowedFileType = (typeof ALLOWED_FILE_TYPES)[number];
type FileWithValidType = Omit<File, 'type'> & { type: AllowedFileType };
type FileParseResult =
| { kind: 'success'; file: FileWithValidType }
| {
kind: 'error';
error: { message: string; code: 'INVALID_FILE_TYPE' | 'FILE_TOO_LARGE' };
};
export function parseValidFile(file: File): FileParseResult {
if (file.size > MAX_FILE_SIZE) {
return {
kind: 'error',
error: {
message: `File size exceeds ${MAX_FILE_SIZE / 1024 / 1024}MB limit.`,
code: 'FILE_TOO_LARGE',
},
};
}
if (!ALLOWED_FILE_TYPES.includes(file.type)) {
if (!isAllowedFileType(file.type)) {
return {
kind: 'error',
error: {
message: `File type '${file.type}' is not supported.`,
code: 'INVALID_FILE_TYPE',
},
};
}
return null;
return {
kind: 'success',
file: file as FileWithValidType,
};
}
function isAllowedFileType(fileType: string): fileType is AllowedFileType {
return (ALLOWED_FILE_TYPES as readonly string[]).includes(fileType);
}

View File

@@ -1,3 +1,4 @@
import * as z from 'zod';
import { HttpError } from 'wasp/server';
import { type File } from 'wasp/entities';
import {
@@ -7,27 +8,35 @@ import {
} from 'wasp/server/operations';
import { getUploadFileSignedURLFromS3, getDownloadFileSignedURLFromS3 } from './s3Utils';
import { ensureArgsSchemaOrThrowHttpError } from '../server/validation';
import { ALLOWED_FILE_TYPES } from './validation';
type FileDescription = {
fileType: string;
name: string;
};
const createFileInputSchema = z.object({
fileType: z.enum(ALLOWED_FILE_TYPES),
fileName: z.string().nonempty(),
});
export const createFile: CreateFile<FileDescription, File> = async ({ fileType, name }, context) => {
type CreateFileInput = z.infer<typeof createFileInputSchema>;
export const createFile: CreateFile<CreateFileInput, File> = async (rawArgs: unknown, context) => {
if (!context.user) {
throw new HttpError(401);
}
const userInfo = context.user.id;
const args = ensureArgsSchemaOrThrowHttpError(createFileInputSchema, rawArgs);
const { uploadUrl, key } = await getUploadFileSignedURLFromS3({ fileType, userInfo });
const { uploadUrl, key } = await getUploadFileSignedURLFromS3({
fileType: args.fileType,
fileName: args.fileName,
userId: context.user.id,
});
return await context.entities.File.create({
data: {
name,
name: args.fileName,
key,
uploadUrl,
type: fileType,
type: args.fileType,
user: { connect: { id: context.user.id } },
},
});
@@ -49,9 +58,14 @@ export const getAllFilesByUser: GetAllFilesByUser<void, File[]> = async (_args,
});
};
export const getDownloadFileSignedURL: GetDownloadFileSignedURL<{ key: string }, string> = async (
{ key },
_context
) => {
return await getDownloadFileSignedURLFromS3({ key });
const getDownloadFileSignedURLInputSchema = z.object({ key: z.string().nonempty() });
type GetDownloadFileSignedURLInput = z.infer<typeof getDownloadFileSignedURLInputSchema>;
export const getDownloadFileSignedURL: GetDownloadFileSignedURL<
GetDownloadFileSignedURLInput,
string
> = async (rawArgs: unknown, _context) => {
const args = ensureArgsSchemaOrThrowHttpError(getDownloadFileSignedURLInputSchema, rawArgs);
return await getDownloadFileSignedURLFromS3({ key: args.key });
};

View File

@@ -1,3 +1,4 @@
import * as path from 'path';
import { randomUUID } from 'crypto';
import { S3Client } from '@aws-sdk/client-s3';
import { GetObjectCommand, PutObjectCommand } from '@aws-sdk/client-s3';
@@ -13,27 +14,30 @@ const s3Client = new S3Client({
type S3Upload = {
fileType: string;
userInfo: string;
}
export const getUploadFileSignedURLFromS3 = async ({fileType, userInfo}: S3Upload) => {
const ex = fileType.split('/')[1];
const Key = `${userInfo}/${randomUUID()}.${ex}`;
const s3Params = {
Bucket: process.env.AWS_S3_FILES_BUCKET,
Key,
ContentType: `${fileType}`,
fileName: string;
userId: string;
};
const command = new PutObjectCommand(s3Params);
const uploadUrl = await getSignedUrl(s3Client, command, { expiresIn: 3600,});
return { uploadUrl, key: Key };
}
export const getDownloadFileSignedURLFromS3 = async ({ key }: { key: string }) => {
const s3Params = {
export const getUploadFileSignedURLFromS3 = async ({ fileName, fileType, userId }: S3Upload) => {
const key = getS3Key(fileName, userId);
const command = new PutObjectCommand({
Bucket: process.env.AWS_S3_FILES_BUCKET,
Key: key,
ContentType: fileType,
});
const uploadUrl = await getSignedUrl(s3Client, command, { expiresIn: 3600 });
return { uploadUrl, key };
};
const command = new GetObjectCommand(s3Params);
export const getDownloadFileSignedURLFromS3 = async ({ key }: { key: string }) => {
const command = new GetObjectCommand({
Bucket: process.env.AWS_S3_FILES_BUCKET,
Key: key,
});
return await getSignedUrl(s3Client, command, { expiresIn: 3600 });
};
function getS3Key(fileName: string, userId: string) {
const ext = path.extname(fileName).slice(1);
return `${userId}/${randomUUID()}.${ext}`;
}

View File

@@ -0,0 +1,9 @@
export const MAX_FILE_SIZE = 5 * 1024 * 1024; // Set this to the max file size you want to allow (currently 5MB).
export const ALLOWED_FILE_TYPES = [
'image/jpeg',
'image/png',
'application/pdf',
'text/*',
'video/quicktime',
'video/mp4',
] as const;

View File

@@ -1,20 +1,27 @@
import * as z from 'zod';
import type { GenerateCheckoutSession, GetCustomerPortalUrl } from 'wasp/server/operations';
import { PaymentPlanId, paymentPlans } from '../payment/plans';
import { paymentProcessor } from './paymentProcessor';
import { HttpError } from 'wasp/server';
import { ensureArgsSchemaOrThrowHttpError } from '../server/validation';
export type CheckoutSession = {
sessionUrl: string | null;
sessionId: string;
};
export const generateCheckoutSession: GenerateCheckoutSession<PaymentPlanId, CheckoutSession> = async (
paymentPlanId,
context
) => {
const generateCheckoutSessionSchema = z.nativeEnum(PaymentPlanId);
type GenerateCheckoutSessionInput = z.infer<typeof generateCheckoutSessionSchema>;
export const generateCheckoutSession: GenerateCheckoutSession<
GenerateCheckoutSessionInput,
CheckoutSession
> = async (rawPaymentPlanId, context) => {
if (!context.user) {
throw new HttpError(401);
}
const paymentPlanId = ensureArgsSchemaOrThrowHttpError(generateCheckoutSessionSchema, rawPaymentPlanId);
const userId = context.user.id;
const userEmail = context.user.email;
if (!userEmail) {
@@ -29,7 +36,7 @@ export const generateCheckoutSession: GenerateCheckoutSession<PaymentPlanId, Che
userId,
userEmail,
paymentPlan,
prismaUserDelegate: context.entities.User
prismaUserDelegate: context.entities.User,
});
return {

View File

@@ -1,6 +1,13 @@
import * as z from 'zod';
import { requireNodeEnvVar } from '../server/utils';
export type SubscriptionStatus = 'past_due' | 'cancel_at_period_end' | 'active' | 'deleted';
export const subscriptionStatusSchema = z
.literal('past_due')
.or(z.literal('cancel_at_period_end'))
.or(z.literal('active'))
.or(z.literal('deleted'));
export type SubscriptionStatus = z.infer<typeof subscriptionStatusSchema>;
export enum PaymentPlanId {
Hobby = 'hobby',

View File

@@ -0,0 +1,14 @@
import { HttpError } from 'wasp/server';
import * as z from 'zod';
export function ensureArgsSchemaOrThrowHttpError<Schema extends z.ZodType<any, any>>(
schema: Schema,
rawArgs: unknown
): z.infer<Schema> {
const parseResult = schema.safeParse(rawArgs);
if (!parseResult.success) {
console.error(parseResult.error);
throw new HttpError(400, 'Operation arguments validation failed', { errors: parseResult.error.errors });
}
return parseResult.data;
}

View File

@@ -1,15 +1,25 @@
import {
type UpdateIsUserAdminById,
type GetPaginatedUsers,
} from 'wasp/server/operations';
import * as z from 'zod';
import { type UpdateIsUserAdminById, type GetPaginatedUsers } from 'wasp/server/operations';
import { type User } from 'wasp/entities';
import { HttpError } from 'wasp/server';
import { type SubscriptionStatus } from '../payment/plans';
import { subscriptionStatusSchema, type SubscriptionStatus } from '../payment/plans';
import { ensureArgsSchemaOrThrowHttpError } from '../server/validation';
export const updateIsUserAdminById: UpdateIsUserAdminById<{ id: string; data: Pick<User, 'isAdmin'> }, User> = async (
{ id, data },
const updateUserAdminByIdInputSchema = z.object({
id: z.string().nonempty(),
data: z.object({
isAdmin: z.boolean(),
}),
});
type UpdateUserAdminByIdInput = z.infer<typeof updateUserAdminByIdInputSchema>;
export const updateIsUserAdminById: UpdateIsUserAdminById<UpdateUserAdminByIdInput, User> = async (
rawArgs: unknown,
context
) => {
const args = ensureArgsSchemaOrThrowHttpError(updateUserAdminByIdInputSchema, rawArgs);
if (!context.user) {
throw new HttpError(401);
}
@@ -20,39 +30,46 @@ export const updateIsUserAdminById: UpdateIsUserAdminById<{ id: string; data: Pi
const updatedUser = await context.entities.User.update({
where: {
id,
id: args.id,
},
data: {
isAdmin: data.isAdmin,
isAdmin: args.data.isAdmin,
},
});
return updatedUser;
};
type GetPaginatedUsersInput = {
skip: number;
cursor?: number | undefined;
emailContains?: string;
isAdmin?: boolean;
subscriptionStatus?: SubscriptionStatus[];
};
type GetPaginatedUsersOutput = {
users: Pick<User, 'id' | 'email' | 'username' | 'subscriptionStatus' | 'paymentProcessorUserId'>[];
totalPages: number;
};
const getPaginatorArgsSchema = z.object({
skip: z.number(),
cursor: z.number().optional(),
emailContains: z.string().nonempty().optional(),
isAdmin: z.boolean().optional(),
subscriptionStatus: z.array(subscriptionStatusSchema).optional(),
});
type GetPaginatedUsersInput = z.infer<typeof getPaginatorArgsSchema>;
export const getPaginatedUsers: GetPaginatedUsers<GetPaginatedUsersInput, GetPaginatedUsersOutput> = async (
args,
rawArgs: unknown,
context
) => {
const args = ensureArgsSchemaOrThrowHttpError(getPaginatorArgsSchema, rawArgs);
if (!context.user?.isAdmin) {
throw new HttpError(401);
}
const allSubscriptionStatusOptions = args.subscriptionStatus as Array<string | null> | undefined;
const hasNotSubscribed = allSubscriptionStatusOptions?.find((status) => status === null)
let subscriptionStatusStrings = allSubscriptionStatusOptions?.filter((status) => status !== null) as string[] | undefined
const allSubscriptionStatusOptions = args.subscriptionStatus;
const hasNotSubscribed = allSubscriptionStatusOptions?.find((status) => status === null);
let subscriptionStatusStrings = allSubscriptionStatusOptions?.filter((status) => status !== null) as
| string[]
| undefined;
const queryResults = await context.entities.User.findMany({
skip: args.skip,