lavfi/dnn: add classify support with openvino backend

Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
This commit is contained in:
Guo, Yejun 2021-03-16 13:02:56 +08:00
parent a3b74651a0
commit fc26dca64e
6 changed files with 219 additions and 20 deletions

View File

@ -29,6 +29,7 @@
#include "libavutil/avassert.h"
#include "libavutil/opt.h"
#include "libavutil/avstring.h"
#include "libavutil/detection_bbox.h"
#include "../internal.h"
#include "queue.h"
#include "safe_queue.h"
@ -74,6 +75,7 @@ typedef struct TaskItem {
// one task might have multiple inferences
typedef struct InferenceItem {
TaskItem *task;
uint32_t bbox_index;
} InferenceItem;
// one request for one call to openvino
@ -182,12 +184,23 @@ static DNNReturnType fill_model_input_ov(OVModel *ov_model, RequestItem *request
request->inferences[i] = inference;
request->inference_count = i + 1;
task = inference->task;
if (task->do_ioproc) {
if (ov_model->model->frame_pre_proc != NULL) {
ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx);
} else {
ff_proc_from_frame_to_dnn(task->in_frame, &input, ov_model->model->func_type, ctx);
switch (task->ov_model->model->func_type) {
case DFT_PROCESS_FRAME:
case DFT_ANALYTICS_DETECT:
if (task->do_ioproc) {
if (ov_model->model->frame_pre_proc != NULL) {
ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx);
} else {
ff_proc_from_frame_to_dnn(task->in_frame, &input, ov_model->model->func_type, ctx);
}
}
break;
case DFT_ANALYTICS_CLASSIFY:
ff_frame_to_dnn_classify(task->in_frame, &input, inference->bbox_index, ctx);
break;
default:
av_assert0(!"should not reach here");
break;
}
input.data = (uint8_t *)input.data
+ input.width * input.height * input.channels * get_datatype_size(input.dt);
@ -276,6 +289,13 @@ static void infer_completion_callback(void *args)
}
task->ov_model->model->detect_post_proc(task->out_frame, &output, 1, task->ov_model->model->filter_ctx);
break;
case DFT_ANALYTICS_CLASSIFY:
if (!task->ov_model->model->classify_post_proc) {
av_log(ctx, AV_LOG_ERROR, "classify filter needs to provide post proc\n");
return;
}
task->ov_model->model->classify_post_proc(task->out_frame, &output, request->inferences[i]->bbox_index, task->ov_model->model->filter_ctx);
break;
default:
av_assert0(!"should not reach here");
break;
@ -513,7 +533,44 @@ static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input
return DNN_ERROR;
}
static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue)
static int contain_valid_detection_bbox(AVFrame *frame)
{
AVFrameSideData *sd;
const AVDetectionBBoxHeader *header;
const AVDetectionBBox *bbox;
sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
if (!sd) { // this frame has nothing detected
return 0;
}
if (!sd->size) {
return 0;
}
header = (const AVDetectionBBoxHeader *)sd->data;
if (!header->nb_bboxes) {
return 0;
}
for (uint32_t i = 0; i < header->nb_bboxes; i++) {
bbox = av_get_detection_bbox(header, i);
if (bbox->x < 0 || bbox->w < 0 || bbox->x + bbox->w >= frame->width) {
return 0;
}
if (bbox->y < 0 || bbox->h < 0 || bbox->y + bbox->h >= frame->width) {
return 0;
}
if (bbox->classify_count == AV_NUM_DETECTION_BBOX_CLASSIFY) {
return 0;
}
}
return 1;
}
static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue, DNNExecBaseParams *exec_params)
{
switch (func_type) {
case DFT_PROCESS_FRAME:
@ -532,6 +589,45 @@ static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, Task
}
return DNN_SUCCESS;
}
case DFT_ANALYTICS_CLASSIFY:
{
const AVDetectionBBoxHeader *header;
AVFrame *frame = task->in_frame;
AVFrameSideData *sd;
DNNExecClassificationParams *params = (DNNExecClassificationParams *)exec_params;
task->inference_todo = 0;
task->inference_done = 0;
if (!contain_valid_detection_bbox(frame)) {
return DNN_SUCCESS;
}
sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
header = (const AVDetectionBBoxHeader *)sd->data;
for (uint32_t i = 0; i < header->nb_bboxes; i++) {
InferenceItem *inference;
const AVDetectionBBox *bbox = av_get_detection_bbox(header, i);
if (av_strncasecmp(bbox->detect_label, params->target, sizeof(bbox->detect_label)) != 0) {
continue;
}
inference = av_malloc(sizeof(*inference));
if (!inference) {
return DNN_ERROR;
}
task->inference_todo++;
inference->task = task;
inference->bbox_index = i;
if (ff_queue_push_back(inference_queue, inference) < 0) {
av_freep(&inference);
return DNN_ERROR;
}
}
return DNN_SUCCESS;
}
default:
av_assert0(!"should not reach here");
return DNN_ERROR;
@ -598,7 +694,7 @@ static DNNReturnType get_output_ov(void *model, const char *input_name, int inpu
task.out_frame = out_frame;
task.ov_model = ov_model;
if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue) != DNN_SUCCESS) {
if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue, NULL) != DNN_SUCCESS) {
av_frame_free(&out_frame);
av_frame_free(&in_frame);
av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
@ -690,6 +786,14 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams *
return DNN_ERROR;
}
if (model->func_type == DFT_ANALYTICS_CLASSIFY) {
// Once we add async support for tensorflow backend and native backend,
// we'll combine the two sync/async functions in dnn_interface.h to
// simplify the code in filter, and async will be an option within backends.
// so, do not support now, and classify filter will not call this function.
return DNN_ERROR;
}
if (ctx->options.batch_size > 1) {
avpriv_report_missing_feature(ctx, "batch mode for sync execution");
return DNN_ERROR;
@ -710,7 +814,7 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams *
task.out_frame = exec_params->out_frame ? exec_params->out_frame : exec_params->in_frame;
task.ov_model = ov_model;
if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue) != DNN_SUCCESS) {
if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue, exec_params) != DNN_SUCCESS) {
av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
return DNN_ERROR;
}
@ -730,6 +834,7 @@ DNNReturnType ff_dnn_execute_model_async_ov(const DNNModel *model, DNNExecBasePa
OVContext *ctx = &ov_model->ctx;
RequestItem *request;
TaskItem *task;
DNNReturnType ret;
if (ff_check_exec_params(ctx, DNN_OV, model->func_type, exec_params) != 0) {
return DNN_ERROR;
@ -761,23 +866,25 @@ DNNReturnType ff_dnn_execute_model_async_ov(const DNNModel *model, DNNExecBasePa
return DNN_ERROR;
}
if (extract_inference_from_task(ov_model->model->func_type, task, ov_model->inference_queue) != DNN_SUCCESS) {
if (extract_inference_from_task(model->func_type, task, ov_model->inference_queue, exec_params) != DNN_SUCCESS) {
av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
return DNN_ERROR;
}
if (ff_queue_size(ov_model->inference_queue) < ctx->options.batch_size) {
// not enough inference items queued for a batch
return DNN_SUCCESS;
while (ff_queue_size(ov_model->inference_queue) >= ctx->options.batch_size) {
request = ff_safe_queue_pop_front(ov_model->request_queue);
if (!request) {
av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
return DNN_ERROR;
}
ret = execute_model_ov(request, ov_model->inference_queue);
if (ret != DNN_SUCCESS) {
return ret;
}
}
request = ff_safe_queue_pop_front(ov_model->request_queue);
if (!request) {
av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
return DNN_ERROR;
}
return execute_model_ov(request, ov_model->inference_queue);
return DNN_SUCCESS;
}
DNNAsyncStatusType ff_dnn_get_async_result_ov(const DNNModel *model, AVFrame **in, AVFrame **out)

View File

@ -22,6 +22,7 @@
#include "libavutil/imgutils.h"
#include "libswscale/swscale.h"
#include "libavutil/avassert.h"
#include "libavutil/detection_bbox.h"
DNNReturnType ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
{
@ -175,6 +176,65 @@ static enum AVPixelFormat get_pixel_format(DNNData *data)
return AV_PIX_FMT_BGR24;
}
DNNReturnType ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index, void *log_ctx)
{
const AVPixFmtDescriptor *desc;
int offsetx[4], offsety[4];
uint8_t *bbox_data[4];
struct SwsContext *sws_ctx;
int linesizes[4];
enum AVPixelFormat fmt;
int left, top, width, height;
const AVDetectionBBoxHeader *header;
const AVDetectionBBox *bbox;
AVFrameSideData *sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
av_assert0(sd);
header = (const AVDetectionBBoxHeader *)sd->data;
bbox = av_get_detection_bbox(header, bbox_index);
left = bbox->x;
width = bbox->w;
top = bbox->y;
height = bbox->h;
fmt = get_pixel_format(input);
sws_ctx = sws_getContext(width, height, frame->format,
input->width, input->height, fmt,
SWS_FAST_BILINEAR, NULL, NULL, NULL);
if (!sws_ctx) {
av_log(log_ctx, AV_LOG_ERROR, "Failed to create scale context for the conversion "
"fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
av_get_pix_fmt_name(frame->format), width, height,
av_get_pix_fmt_name(fmt), input->width, input->height);
return DNN_ERROR;
}
if (av_image_fill_linesizes(linesizes, fmt, input->width) < 0) {
av_log(log_ctx, AV_LOG_ERROR, "unable to get linesizes with av_image_fill_linesizes");
sws_freeContext(sws_ctx);
return DNN_ERROR;
}
desc = av_pix_fmt_desc_get(frame->format);
offsetx[1] = offsetx[2] = AV_CEIL_RSHIFT(left, desc->log2_chroma_w);
offsetx[0] = offsetx[3] = left;
offsety[1] = offsety[2] = AV_CEIL_RSHIFT(top, desc->log2_chroma_h);
offsety[0] = offsety[3] = top;
for (int k = 0; frame->data[k]; k++)
bbox_data[k] = frame->data[k] + offsety[k] * frame->linesize[k] + offsetx[k];
sws_scale(sws_ctx, (const uint8_t *const *)&bbox_data, frame->linesize,
0, height,
(uint8_t *const *)(&input->data), linesizes);
sws_freeContext(sws_ctx);
return DNN_SUCCESS;
}
static DNNReturnType proc_from_frame_to_dnn_analytics(AVFrame *frame, DNNData *input, void *log_ctx)
{
struct SwsContext *sws_ctx;

View File

@ -32,5 +32,6 @@
DNNReturnType ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, DNNFunctionType func_type, void *log_ctx);
DNNReturnType ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx);
DNNReturnType ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index, void *log_ctx);
#endif

View File

@ -77,6 +77,12 @@ int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc)
return 0;
}
int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc)
{
ctx->model->classify_post_proc = post_proc;
return 0;
}
DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input)
{
return ctx->model->get_input(ctx->model->model, input, ctx->model_inputname);
@ -112,6 +118,21 @@ DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVF
return (ctx->dnn_module->execute_model_async)(ctx->model, &exec_params);
}
DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, char *target)
{
DNNExecClassificationParams class_params = {
{
.input_name = ctx->model_inputname,
.output_names = (const char **)&ctx->model_outputname,
.nb_output = 1,
.in_frame = in_frame,
.out_frame = out_frame,
},
.target = target,
};
return (ctx->dnn_module->execute_model_async)(ctx->model, &class_params.base);
}
DNNAsyncStatusType ff_dnn_get_async_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame)
{
return (ctx->dnn_module->get_async_result)(ctx->model, in_frame, out_frame);

View File

@ -50,10 +50,12 @@ typedef struct DnnContext {
int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx);
int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePostProc post_proc);
int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc);
int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc);
DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input);
DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height);
DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);
DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);
DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, char *target);
DNNAsyncStatusType ff_dnn_get_async_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame);
DNNReturnType ff_dnn_flush(DnnContext *ctx);
void ff_dnn_uninit(DnnContext *ctx);

View File

@ -52,7 +52,7 @@ typedef enum {
DFT_NONE,
DFT_PROCESS_FRAME, // process the whole frame
DFT_ANALYTICS_DETECT, // detect from the whole frame
// we can add more such as detect_from_crop, classify_from_bbox, etc.
DFT_ANALYTICS_CLASSIFY, // classify for each bounding box
}DNNFunctionType;
typedef struct DNNData{
@ -71,8 +71,14 @@ typedef struct DNNExecBaseParams {
AVFrame *out_frame;
} DNNExecBaseParams;
typedef struct DNNExecClassificationParams {
DNNExecBaseParams base;
const char *target;
} DNNExecClassificationParams;
typedef int (*FramePrePostProc)(AVFrame *frame, DNNData *model, AVFilterContext *filter_ctx);
typedef int (*DetectPostProc)(AVFrame *frame, DNNData *output, uint32_t nb, AVFilterContext *filter_ctx);
typedef int (*ClassifyPostProc)(AVFrame *frame, DNNData *output, uint32_t bbox_index, AVFilterContext *filter_ctx);
typedef struct DNNModel{
// Stores model that can be different for different backends.
@ -97,6 +103,8 @@ typedef struct DNNModel{
FramePrePostProc frame_post_proc;
// set the post process to interpret detect result from DNNData
DetectPostProc detect_post_proc;
// set the post process to interpret classify result from DNNData
ClassifyPostProc classify_post_proc;
} DNNModel;
// Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.