libavfilter/dnn_interface: use dims to represent shapes
For detect and classify outputs, width and height make no sense, so replace width, height and channels with a dims array that represents the tensor shape. Width, height and channel are now derived from layout and dims.

Signed-off-by: Wenbin Chen <wenbin.chen@intel.com>
Reviewed-by: Guo Yejun <yejun.guo@intel.com>
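A minimal, self-contained sketch (not part of the patch) of how width, height and channel are recovered from dims[] plus layout after this change; the three index helpers are the ones this commit adds to the DNN interface header, while the simplified DNNLayout enum and the sample tensor values are illustrative assumptions.

    /* Standalone sketch: reading geometry from a dims[] array via layout helpers. */
    #include <stdio.h>

    typedef enum { DL_NCHW, DL_NHWC } DNNLayout;   /* simplified for the example */

    static inline int dnn_get_width_idx_by_layout(DNNLayout layout)
    {
        return layout == DL_NHWC ? 2 : 3;
    }

    static inline int dnn_get_height_idx_by_layout(DNNLayout layout)
    {
        return layout == DL_NHWC ? 1 : 2;
    }

    static inline int dnn_get_channel_idx_by_layout(DNNLayout layout)
    {
        return layout == DL_NHWC ? 3 : 1;
    }

    int main(void)
    {
        int dims[4] = { 1, 3, 224, 224 };  /* example NCHW tensor: N, C, H, W */
        DNNLayout layout = DL_NCHW;

        printf("width=%d height=%d channels=%d\n",
               dims[dnn_get_width_idx_by_layout(layout)],
               dims[dnn_get_height_idx_by_layout(layout)],
               dims[dnn_get_channel_idx_by_layout(layout)]);
        return 0;
    }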
@@ -253,9 +253,9 @@ static int fill_model_input_ov(OVModel *ov_model, OVRequestItem *request)
         ov_shape_free(&input_shape);
         return ov2_map_error(status, NULL);
     }
-    input.height = dims[1];
-    input.width = dims[2];
-    input.channels = dims[3];
+    for (int i = 0; i < input_shape.rank; i++)
+        input.dims[i] = dims[i];
+    input.layout = DL_NHWC;
     input.dt = precision_to_datatype(precision);
 #else
     status = ie_infer_request_get_blob(request->infer_request, task->input_name, &input_blob);
@@ -278,9 +278,9 @@ static int fill_model_input_ov(OVModel *ov_model, OVRequestItem *request)
         av_log(ctx, AV_LOG_ERROR, "Failed to get input blob buffer\n");
         return DNN_GENERIC_ERROR;
     }
-    input.height = dims.dims[2];
-    input.width = dims.dims[3];
-    input.channels = dims.dims[1];
+    for (int i = 0; i < input_shape.rank; i++)
+        input.dims[i] = dims[i];
+    input.layout = DL_NCHW;
     input.data = blob_buffer.buffer;
     input.dt = precision_to_datatype(precision);
 #endif
@@ -339,8 +339,8 @@ static int fill_model_input_ov(OVModel *ov_model, OVRequestItem *request)
             av_assert0(!"should not reach here");
             break;
         }
-        input.data = (uint8_t *)input.data
-                     + input.width * input.height * input.channels * get_datatype_size(input.dt);
+        input.data = (uint8_t *)input.data +
+                     input.dims[1] * input.dims[2] * input.dims[3] * get_datatype_size(input.dt);
     }
 #if HAVE_OPENVINO2
     ov_tensor_free(tensor);
@@ -403,10 +403,11 @@ static void infer_completion_callback(void *args)
             goto end;
         }
         outputs[i].dt = precision_to_datatype(precision);
-
-        outputs[i].channels = output_shape.rank > 2 ? dims[output_shape.rank - 3] : 1;
-        outputs[i].height = output_shape.rank > 1 ? dims[output_shape.rank - 2] : 1;
-        outputs[i].width = output_shape.rank > 0 ? dims[output_shape.rank - 1] : 1;
+        outputs[i].layout = DL_NCHW;
+        outputs[i].dims[0] = 1;
+        outputs[i].dims[1] = output_shape.rank > 2 ? dims[output_shape.rank - 3] : 1;
+        outputs[i].dims[2] = output_shape.rank > 1 ? dims[output_shape.rank - 2] : 1;
+        outputs[i].dims[3] = output_shape.rank > 0 ? dims[output_shape.rank - 1] : 1;
         av_assert0(request->lltask_count <= dims[0]);
         outputs[i].layout = ctx->options.layout;
         outputs[i].scale = ctx->options.scale;
@@ -445,9 +446,9 @@ static void infer_completion_callback(void *args)
         return;
     }
     output.data = blob_buffer.buffer;
-    output.channels = dims.dims[1];
-    output.height = dims.dims[2];
-    output.width = dims.dims[3];
+    output.layout = DL_NCHW;
+    for (int i = 0; i < 4; i++)
+        output.dims[i] = dims.dims[i];
     av_assert0(request->lltask_count <= dims.dims[0]);
     output.dt = precision_to_datatype(precision);
     output.layout = ctx->options.layout;
@@ -469,8 +470,10 @@ static void infer_completion_callback(void *args)
                 ff_proc_from_dnn_to_frame(task->out_frame, outputs, ctx);
             }
         } else {
-            task->out_frame->width = outputs[0].width;
-            task->out_frame->height = outputs[0].height;
+            task->out_frame->width =
+                outputs[0].dims[dnn_get_width_idx_by_layout(outputs[0].layout)];
+            task->out_frame->height =
+                outputs[0].dims[dnn_get_height_idx_by_layout(outputs[0].layout)];
         }
         break;
     case DFT_ANALYTICS_DETECT:
@@ -501,7 +504,8 @@ static void infer_completion_callback(void *args)
         av_freep(&request->lltasks[i]);
         for (int i = 0; i < ov_model->nb_outputs; i++)
             outputs[i].data = (uint8_t *)outputs[i].data +
-                    outputs[i].width * outputs[i].height * outputs[i].channels * get_datatype_size(outputs[i].dt);
+                    outputs[i].dims[1] * outputs[i].dims[2] * outputs[i].dims[3] *
+                    get_datatype_size(outputs[i].dt);
     }
 end:
 #if HAVE_OPENVINO2
@@ -1085,7 +1089,6 @@ static int get_input_ov(void *model, DNNData *input, const char *input_name)
 #if HAVE_OPENVINO2
     ov_shape_t input_shape = {0};
     ov_element_type_e precision;
-    int64_t* dims;
     ov_status_e status;
     if (input_name)
         status = ov_model_const_input_by_name(ov_model->ov_model, input_name, &ov_model->input_port);
@@ -1105,16 +1108,18 @@ static int get_input_ov(void *model, DNNData *input, const char *input_name)
         av_log(ctx, AV_LOG_ERROR, "Failed to get input port shape.\n");
         return ov2_map_error(status, NULL);
     }
-    dims = input_shape.dims;
-    if (dims[1] <= 3) { // NCHW
-        input->channels = dims[1];
-        input->height = input_resizable ? -1 : dims[2];
-        input->width = input_resizable ? -1 : dims[3];
-    } else { // NHWC
-        input->height = input_resizable ? -1 : dims[1];
-        input->width = input_resizable ? -1 : dims[2];
-        input->channels = dims[3];
+    for (int i = 0; i < 4; i++)
+        input->dims[i] = input_shape.dims[i];
+    if (input_resizable) {
+        input->dims[dnn_get_width_idx_by_layout(input->layout)] = -1;
+        input->dims[dnn_get_height_idx_by_layout(input->layout)] = -1;
     }
+
+    if (input_shape.dims[1] <= 3) // NCHW
+        input->layout = DL_NCHW;
+    else // NHWC
+        input->layout = DL_NHWC;
+
     input->dt = precision_to_datatype(precision);
     ov_shape_free(&input_shape);
     return 0;
@@ -1144,15 +1149,18 @@ static int get_input_ov(void *model, DNNData *input, const char *input_name)
         return DNN_GENERIC_ERROR;
     }
 
-    if (dims[1] <= 3) { // NCHW
-        input->channels = dims[1];
-        input->height = input_resizable ? -1 : dims[2];
-        input->width = input_resizable ? -1 : dims[3];
-    } else { // NHWC
-        input->height = input_resizable ? -1 : dims[1];
-        input->width = input_resizable ? -1 : dims[2];
-        input->channels = dims[3];
+    for (int i = 0; i < 4; i++)
+        input->dims[i] = input_shape.dims[i];
+    if (input_resizable) {
+        input->dims[dnn_get_width_idx_by_layout(input->layout)] = -1;
+        input->dims[dnn_get_height_idx_by_layout(input->layout)] = -1;
     }
+
+    if (input_shape.dims[1] <= 3) // NCHW
+        input->layout = DL_NCHW;
+    else // NHWC
+        input->layout = DL_NHWC;
+
     input->dt = precision_to_datatype(precision);
     return 0;
 }
@@ -251,7 +251,12 @@ static TF_Tensor *allocate_input_tensor(const DNNData *input)
 {
     TF_DataType dt;
     size_t size;
-    int64_t input_dims[] = {1, input->height, input->width, input->channels};
+    int64_t input_dims[4] = { 0 };
+
+    input_dims[0] = 1;
+    input_dims[1] = input->dims[dnn_get_height_idx_by_layout(input->layout)];
+    input_dims[2] = input->dims[dnn_get_width_idx_by_layout(input->layout)];
+    input_dims[3] = input->dims[dnn_get_channel_idx_by_layout(input->layout)];
     switch (input->dt) {
     case DNN_FLOAT:
         dt = TF_FLOAT;
@@ -310,9 +315,9 @@ static int get_input_tf(void *model, DNNData *input, const char *input_name)
 
     // currently only NHWC is supported
     av_assert0(dims[0] == 1 || dims[0] == -1);
-    input->height = dims[1];
-    input->width = dims[2];
-    input->channels = dims[3];
+    for (int i = 0; i < 4; i++)
+        input->dims[i] = dims[i];
+    input->layout = DL_NHWC;
 
     return 0;
 }
@@ -640,8 +645,8 @@ static int fill_model_input_tf(TFModel *tf_model, TFRequestItem *request) {
     }
 
     infer_request = request->infer_request;
-    input.height = task->in_frame->height;
-    input.width = task->in_frame->width;
+    input.dims[1] = task->in_frame->height;
+    input.dims[2] = task->in_frame->width;
 
     infer_request->tf_input = av_malloc(sizeof(TF_Output));
     if (!infer_request->tf_input) {
@@ -731,9 +736,12 @@ static void infer_completion_callback(void *args) {
     }
 
     for (uint32_t i = 0; i < task->nb_output; ++i) {
-        outputs[i].height = TF_Dim(infer_request->output_tensors[i], 1);
-        outputs[i].width = TF_Dim(infer_request->output_tensors[i], 2);
-        outputs[i].channels = TF_Dim(infer_request->output_tensors[i], 3);
+        outputs[i].dims[dnn_get_height_idx_by_layout(outputs[i].layout)] =
+            TF_Dim(infer_request->output_tensors[i], 1);
+        outputs[i].dims[dnn_get_width_idx_by_layout(outputs[i].layout)] =
+            TF_Dim(infer_request->output_tensors[i], 2);
+        outputs[i].dims[dnn_get_channel_idx_by_layout(outputs[i].layout)] =
+            TF_Dim(infer_request->output_tensors[i], 3);
         outputs[i].data = TF_TensorData(infer_request->output_tensors[i]);
         outputs[i].dt = (DNNDataType)TF_TensorType(infer_request->output_tensors[i]);
     }
@@ -747,8 +755,10 @@ static void infer_completion_callback(void *args) {
                 ff_proc_from_dnn_to_frame(task->out_frame, outputs, ctx);
             }
         } else {
-            task->out_frame->width = outputs[0].width;
-            task->out_frame->height = outputs[0].height;
+            task->out_frame->width =
+                outputs[0].dims[dnn_get_width_idx_by_layout(outputs[0].layout)];
+            task->out_frame->height =
+                outputs[0].dims[dnn_get_height_idx_by_layout(outputs[0].layout)];
         }
         break;
     case DFT_ANALYTICS_DETECT:
@@ -70,7 +70,7 @@ int ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
     dst_data = (void **)frame->data;
     linesize[0] = frame->linesize[0];
     if (output->layout == DL_NCHW) {
-        middle_data = av_malloc(plane_size * output->channels);
+        middle_data = av_malloc(plane_size * output->dims[1]);
         if (!middle_data) {
             ret = AVERROR(ENOMEM);
             goto err;
@@ -209,7 +209,7 @@ int ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx)
     src_data = (void **)frame->data;
     linesize[0] = frame->linesize[0];
     if (input->layout == DL_NCHW) {
-        middle_data = av_malloc(plane_size * input->channels);
+        middle_data = av_malloc(plane_size * input->dims[1]);
         if (!middle_data) {
             ret = AVERROR(ENOMEM);
             goto err;
@@ -346,6 +346,7 @@ int ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index
     int ret = 0;
     enum AVPixelFormat fmt;
     int left, top, width, height;
+    int width_idx, height_idx;
     const AVDetectionBBoxHeader *header;
     const AVDetectionBBox *bbox;
     AVFrameSideData *sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
@@ -364,6 +365,9 @@ int ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index
         return AVERROR(ENOSYS);
     }
 
+    width_idx = dnn_get_width_idx_by_layout(input->layout);
+    height_idx = dnn_get_height_idx_by_layout(input->layout);
+
     header = (const AVDetectionBBoxHeader *)sd->data;
     bbox = av_get_detection_bbox(header, bbox_index);
 
@@ -374,17 +378,20 @@ int ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index
 
     fmt = get_pixel_format(input);
     sws_ctx = sws_getContext(width, height, frame->format,
-                             input->width, input->height, fmt,
+                             input->dims[width_idx],
+                             input->dims[height_idx], fmt,
                              SWS_FAST_BILINEAR, NULL, NULL, NULL);
     if (!sws_ctx) {
         av_log(log_ctx, AV_LOG_ERROR, "Failed to create scale context for the conversion "
                "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
                av_get_pix_fmt_name(frame->format), width, height,
-               av_get_pix_fmt_name(fmt), input->width, input->height);
+               av_get_pix_fmt_name(fmt),
+               input->dims[width_idx],
+               input->dims[height_idx]);
         return AVERROR(EINVAL);
     }
 
-    ret = av_image_fill_linesizes(linesizes, fmt, input->width);
+    ret = av_image_fill_linesizes(linesizes, fmt, input->dims[width_idx]);
     if (ret < 0) {
         av_log(log_ctx, AV_LOG_ERROR, "unable to get linesizes with av_image_fill_linesizes");
         sws_freeContext(sws_ctx);
@@ -414,7 +421,7 @@ int ff_frame_to_dnn_detect(AVFrame *frame, DNNData *input, void *log_ctx)
 {
     struct SwsContext *sws_ctx;
     int linesizes[4];
-    int ret = 0;
+    int ret = 0, width_idx, height_idx;
     enum AVPixelFormat fmt = get_pixel_format(input);
 
     /* (scale != 1 and scale != 0) or mean != 0 */
@@ -430,18 +437,23 @@ int ff_frame_to_dnn_detect(AVFrame *frame, DNNData *input, void *log_ctx)
         return AVERROR(ENOSYS);
     }
 
+    width_idx = dnn_get_width_idx_by_layout(input->layout);
+    height_idx = dnn_get_height_idx_by_layout(input->layout);
+
     sws_ctx = sws_getContext(frame->width, frame->height, frame->format,
-                             input->width, input->height, fmt,
+                             input->dims[width_idx],
+                             input->dims[height_idx], fmt,
                              SWS_FAST_BILINEAR, NULL, NULL, NULL);
     if (!sws_ctx) {
         av_log(log_ctx, AV_LOG_ERROR, "Impossible to create scale context for the conversion "
                "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
                av_get_pix_fmt_name(frame->format), frame->width, frame->height,
-               av_get_pix_fmt_name(fmt), input->width, input->height);
+               av_get_pix_fmt_name(fmt), input->dims[width_idx],
+               input->dims[height_idx]);
         return AVERROR(EINVAL);
     }
 
-    ret = av_image_fill_linesizes(linesizes, fmt, input->width);
+    ret = av_image_fill_linesizes(linesizes, fmt, input->dims[width_idx]);
     if (ret < 0) {
         av_log(log_ctx, AV_LOG_ERROR, "unable to get linesizes with av_image_fill_linesizes");
         sws_freeContext(sws_ctx);
@@ -64,7 +64,7 @@ typedef enum {
 
 typedef struct DNNData{
     void *data;
-    int width, height, channels;
+    int dims[4];
     // dt and order together decide the color format
     DNNDataType dt;
     DNNColorOrder order;
@@ -134,4 +134,19 @@ typedef struct DNNModule{
 // Initializes DNNModule depending on chosen backend.
 const DNNModule *ff_get_dnn_module(DNNBackendType backend_type, void *log_ctx);
 
+static inline int dnn_get_width_idx_by_layout(DNNLayout layout)
+{
+    return layout == DL_NHWC ? 2 : 3;
+}
+
+static inline int dnn_get_height_idx_by_layout(DNNLayout layout)
+{
+    return layout == DL_NHWC ? 1 : 2;
+}
+
+static inline int dnn_get_channel_idx_by_layout(DNNLayout layout)
+{
+    return layout == DL_NHWC ? 3 : 1;
+}
+
 #endif
@@ -68,8 +68,8 @@ static int dnn_classify_post_proc(AVFrame *frame, DNNData *output, uint32_t bbox
     uint32_t label_id;
     float confidence;
     AVFrameSideData *sd;
-
-    if (output->channels <= 0) {
+    int output_size = output->dims[3] * output->dims[2] * output->dims[1];
+    if (output_size <= 0) {
         return -1;
     }
 
@@ -88,7 +88,7 @@ static int dnn_classify_post_proc(AVFrame *frame, DNNData *output, uint32_t bbox
     classifications = output->data;
     label_id = 0;
     confidence= classifications[0];
-    for (int i = 1; i < output->channels; i++) {
+    for (int i = 1; i < output_size; i++) {
         if (classifications[i] > confidence) {
             label_id = i;
             confidence= classifications[i];
@@ -166,14 +166,14 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out
         scale_w = cell_w;
         scale_h = cell_h;
     } else {
-        if (output[output_index].height != output[output_index].width &&
-            output[output_index].height == output[output_index].channels) {
+        if (output[output_index].dims[2] != output[output_index].dims[3] &&
+            output[output_index].dims[2] == output[output_index].dims[1]) {
             is_NHWC = 1;
-            cell_w = output[output_index].height;
-            cell_h = output[output_index].channels;
+            cell_w = output[output_index].dims[2];
+            cell_h = output[output_index].dims[1];
         } else {
-            cell_w = output[output_index].width;
-            cell_h = output[output_index].height;
+            cell_w = output[output_index].dims[3];
+            cell_h = output[output_index].dims[2];
         }
         scale_w = ctx->scale_width;
         scale_h = ctx->scale_height;
@@ -205,14 +205,14 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out
         return AVERROR(EINVAL);
     }
 
-    if (output[output_index].channels * output[output_index].width *
-        output[output_index].height % (box_size * cell_w * cell_h)) {
+    if (output[output_index].dims[1] * output[output_index].dims[2] *
+        output[output_index].dims[3] % (box_size * cell_w * cell_h)) {
         av_log(filter_ctx, AV_LOG_ERROR, "wrong cell_w, cell_h or nb_classes\n");
         return AVERROR(EINVAL);
     }
-    detection_boxes = output[output_index].channels *
-                      output[output_index].height *
-                      output[output_index].width / box_size / cell_w / cell_h;
+    detection_boxes = output[output_index].dims[1] *
+                      output[output_index].dims[2] *
+                      output[output_index].dims[3] / box_size / cell_w / cell_h;
 
     anchors = anchors + (detection_boxes * output_index * 2);
     /**
@@ -373,18 +373,18 @@ static int dnn_detect_post_proc_ssd(AVFrame *frame, DNNData *output, int nb_outp
     int scale_w = ctx->scale_width;
     int scale_h = ctx->scale_height;
 
-    if (nb_outputs == 1 && output->width == 7) {
-        proposal_count = output->height;
-        detect_size = output->width;
+    if (nb_outputs == 1 && output->dims[3] == 7) {
+        proposal_count = output->dims[2];
+        detect_size = output->dims[3];
         detections = output->data;
-    } else if (nb_outputs == 2 && output[0].width == 5) {
-        proposal_count = output[0].height;
-        detect_size = output[0].width;
+    } else if (nb_outputs == 2 && output[0].dims[3] == 5) {
+        proposal_count = output[0].dims[2];
+        detect_size = output[0].dims[3];
         detections = output[0].data;
         labels = output[1].data;
-    } else if (nb_outputs == 2 && output[1].width == 5) {
-        proposal_count = output[1].height;
-        detect_size = output[1].width;
+    } else if (nb_outputs == 2 && output[1].dims[3] == 5) {
+        proposal_count = output[1].dims[2];
+        detect_size = output[1].dims[3];
         detections = output[1].data;
         labels = output[0].data;
     } else {
@@ -821,15 +821,19 @@ static int config_input(AVFilterLink *inlink)
     AVFilterContext *context = inlink->dst;
     DnnDetectContext *ctx = context->priv;
     DNNData model_input;
-    int ret;
+    int ret, width_idx, height_idx;
 
     ret = ff_dnn_get_input(&ctx->dnnctx, &model_input);
     if (ret != 0) {
         av_log(ctx, AV_LOG_ERROR, "could not get input from the model\n");
         return ret;
     }
-    ctx->scale_width = model_input.width == -1 ? inlink->w : model_input.width;
-    ctx->scale_height = model_input.height == -1 ? inlink->h : model_input.height;
+    width_idx = dnn_get_width_idx_by_layout(model_input.layout);
+    height_idx = dnn_get_height_idx_by_layout(model_input.layout);
+    ctx->scale_width = model_input.dims[width_idx] == -1 ? inlink->w :
+                       model_input.dims[width_idx];
+    ctx->scale_height = model_input.dims[height_idx] == -1 ? inlink->h :
+                        model_input.dims[height_idx];
 
     return 0;
 }
@@ -77,22 +77,29 @@ static const enum AVPixelFormat pix_fmts[] = {
                "the frame's format %s does not match " \
                "the model input channel %d\n", \
                av_get_pix_fmt_name(fmt), \
-               model_input->channels);
+               model_input->dims[dnn_get_channel_idx_by_layout(model_input->layout)]);
 
 static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLink *inlink)
 {
     AVFilterContext *ctx = inlink->dst;
     enum AVPixelFormat fmt = inlink->format;
+    int width_idx, height_idx;
 
+    width_idx = dnn_get_width_idx_by_layout(model_input->layout);
+    height_idx = dnn_get_height_idx_by_layout(model_input->layout);
     // the design is to add explicit scale filter before this filter
-    if (model_input->height != -1 && model_input->height != inlink->h) {
+    if (model_input->dims[height_idx] != -1 &&
+        model_input->dims[height_idx] != inlink->h) {
         av_log(ctx, AV_LOG_ERROR, "the model requires frame height %d but got %d\n",
-               model_input->height, inlink->h);
+               model_input->dims[height_idx],
+               inlink->h);
         return AVERROR(EIO);
     }
-    if (model_input->width != -1 && model_input->width != inlink->w) {
+    if (model_input->dims[width_idx] != -1 &&
+        model_input->dims[width_idx] != inlink->w) {
         av_log(ctx, AV_LOG_ERROR, "the model requires frame width %d but got %d\n",
-               model_input->width, inlink->w);
+               model_input->dims[width_idx],
+               inlink->w);
         return AVERROR(EIO);
     }
     if (model_input->dt != DNN_FLOAT) {
@@ -103,7 +110,7 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
     switch (fmt) {
     case AV_PIX_FMT_RGB24:
     case AV_PIX_FMT_BGR24:
-        if (model_input->channels != 3) {
+        if (model_input->dims[dnn_get_channel_idx_by_layout(model_input->layout)] != 3) {
             LOG_FORMAT_CHANNEL_MISMATCH();
             return AVERROR(EIO);
         }
@@ -116,7 +123,7 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
     case AV_PIX_FMT_YUV410P:
     case AV_PIX_FMT_YUV411P:
     case AV_PIX_FMT_NV12:
-        if (model_input->channels != 1) {
+        if (model_input->dims[dnn_get_channel_idx_by_layout(model_input->layout)] != 1) {
             LOG_FORMAT_CHANNEL_MISMATCH();
             return AVERROR(EIO);
         }