diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp index b70e4db600..0aa0b95269 100644 --- a/libavfilter/dnn/dnn_backend_torch.cpp +++ b/libavfilter/dnn/dnn_backend_torch.cpp @@ -37,8 +37,8 @@ extern "C" { } typedef struct THModel { + DNNModel model; DnnContext *ctx; - DNNModel *model; torch::jit::Module *jit_model; SafeQueue *request_queue; Queue *task_queue; @@ -141,7 +141,7 @@ static void dnn_free_model_th(DNNModel **model) ff_queue_destroy(th_model->task_queue); delete th_model->jit_model; av_freep(&th_model); - av_freep(model); + *model = NULL; } static int get_input_th(void *model, DNNData *input, const char *input_name) @@ -195,19 +195,19 @@ static int fill_model_input_th(THModel *th_model, THRequestItem *request) infer_request->input_tensor = new torch::Tensor(); infer_request->output = new torch::Tensor(); - switch (th_model->model->func_type) { + switch (th_model->model.func_type) { case DFT_PROCESS_FRAME: input.scale = 255; if (task->do_ioproc) { - if (th_model->model->frame_pre_proc != NULL) { - th_model->model->frame_pre_proc(task->in_frame, &input, th_model->model->filter_ctx); + if (th_model->model.frame_pre_proc != NULL) { + th_model->model.frame_pre_proc(task->in_frame, &input, th_model->model.filter_ctx); } else { ff_proc_from_frame_to_dnn(task->in_frame, &input, ctx); } } break; default: - avpriv_report_missing_feature(NULL, "model function type %d", th_model->model->func_type); + avpriv_report_missing_feature(NULL, "model function type %d", th_model->model.func_type); break; } *infer_request->input_tensor = torch::from_blob(input.data, @@ -282,13 +282,13 @@ static void infer_completion_callback(void *args) { goto err; } - switch (th_model->model->func_type) { + switch (th_model->model.func_type) { case DFT_PROCESS_FRAME: if (task->do_ioproc) { outputs.scale = 255; outputs.data = output->data_ptr(); - if (th_model->model->frame_post_proc != NULL) { - th_model->model->frame_post_proc(task->out_frame, &outputs, th_model->model->filter_ctx); + if (th_model->model.frame_post_proc != NULL) { + th_model->model.frame_post_proc(task->out_frame, &outputs, th_model->model.filter_ctx); } else { ff_proc_from_dnn_to_frame(task->out_frame, &outputs, th_model->ctx); } @@ -298,7 +298,7 @@ static void infer_completion_callback(void *args) { } break; default: - avpriv_report_missing_feature(th_model->ctx, "model function type %d", th_model->model->func_type); + avpriv_report_missing_feature(th_model->ctx, "model function type %d", th_model->model.func_type); goto err; } task->inference_done++; @@ -417,17 +417,10 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A THRequestItem *item = NULL; const char *device_name = ctx->device ? ctx->device : "cpu"; - model = (DNNModel *)av_mallocz(sizeof(DNNModel)); - if (!model) { - return NULL; - } - th_model = (THModel *)av_mallocz(sizeof(THModel)); - if (!th_model) { - av_freep(&model); + if (!th_model) return NULL; - } - th_model->model = model; + model = &th_model->model; model->model = th_model; th_model->ctx = ctx;