libavfilter/dnn: add header into native model file

Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
Signed-off-by: Pedro Arthur <bygrandao@gmail.com>
This commit is contained in:
Guo, Yejun
2019-09-02 12:35:58 +08:00
committed by Pedro Arthur
parent 3b3150c45f
commit 022f50d3fe
3 changed files with 70 additions and 2 deletions

View File

@@ -64,6 +64,10 @@ static DNNReturnType set_input_output_native(void *model, DNNInputData *input, c
DNNModel *ff_dnn_load_model_native(const char *model_filename)
{
DNNModel *model = NULL;
char header_expected[] = "FFMPEGDNNNATIVE";
char *buf;
size_t size;
int version, header_size, major_version_expected = 0;
ConvolutionalNetwork *network = NULL;
AVIOContext *model_file_context;
int file_size, dnn_size, kernel_size, i;
@@ -84,6 +88,41 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename)
}
file_size = avio_size(model_file_context);
/**
* check file header with string and version
*/
size = sizeof(header_expected);
buf = av_malloc(size);
if (!buf) {
avio_closep(&model_file_context);
av_freep(&model);
return NULL;
}
// size - 1 to skip the ending '\0' which is not saved in file
avio_get_str(model_file_context, size - 1, buf, size);
dnn_size = size - 1;
if (strncmp(buf, header_expected, size) != 0) {
av_freep(&buf);
avio_closep(&model_file_context);
av_freep(&model);
return NULL;
}
av_freep(&buf);
version = (int32_t)avio_rl32(model_file_context);
dnn_size += 4;
if (version != major_version_expected) {
avio_closep(&model_file_context);
av_freep(&model);
return NULL;
}
// currently no need to check minor version
version = (int32_t)avio_rl32(model_file_context);
dnn_size += 4;
header_size = dnn_size;
network = av_mallocz(sizeof(ConvolutionalNetwork));
if (!network){
avio_closep(&model_file_context);
@@ -95,8 +134,8 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename)
avio_seek(model_file_context, file_size - 8, SEEK_SET);
network->layers_num = (int32_t)avio_rl32(model_file_context);
network->operands_num = (int32_t)avio_rl32(model_file_context);
dnn_size = 8;
avio_seek(model_file_context, 0, SEEK_SET);
dnn_size += 8;
avio_seek(model_file_context, header_size, SEEK_SET);
network->layers = av_mallocz(network->layers_num * sizeof(Layer));
if (!network->layers){