LCOV - code coverage report
Current view: top level - nnstreamer-2.4.2/gst/nnstreamer/elements - gsttensor_trainer.c (source / functions) Coverage Total Hit
Test: nnstreamer 2.4.2-0 nnstreamer/nnstreamer#eca68b8d050408568af95d831a8eef62aaee7784 Lines: 49.5 % 626 310
Test Date: 2025-03-14 05:36:58 Functions: 52.5 % 40 21

            Line data    Source code
       1              : /* SPDX-License-Identifier: LGPL-2.1-only */
       2              : /**
       3              :  * Copyright (C) 2022 Samsung Electronics Co., Ltd.
       4              :  *
       5              :  * @file        gsttensor_trainer.c
       6              :  * @date        20 October 2022
       7              :  * @brief       GStreamer plugin to train tensor data using NN Frameworks
       8              :  * @see         https://github.com/nnstreamer/nnstreamer
       9              :  * @author      Hyunil Park <hyunil46.park@samsung.com>
      10              :  * @bug         No known bugs except for NYI items
      11              :  *
      12              :  * ## Example launch line
      13              :  * |[
      14              :  * gst-launch-1.0 datareposrc location=mnist_trainingSet.dat json=mnist.json start-sample-index=3 stop-sample-index=202 epochs=5 ! \
      15              :  * tensor_trainer framework=nntrainer model-config=mnist.ini model-save-path=model.bin \
      16              :  * num-inputs=1 num-labels=1 num-training-samples=100 num-validation-samples=100 epochs=5 ! \
      17              :  * tensor_sink
      18              :  * ]|
      19              :  *
      20              :  * The total number of samples to be received is (num-training-samples + num-validation-samples) * epochs, i.e. 1000 in this example.
      21              :  *
      22              :  * output tensors : dimensions=1:1:4, types=float64.
      23              :  * values are training loss, training accuracy, validation loss and validation accuracy.
      24              :  * -INFINITY value is stored if the value fetched from the sub-plugin is not greater than 0.
      25              :  */
      26              : 
      27              : #ifdef HAVE_CONFIG_H
      28              : #include "config.h"
      29              : #endif
      30              : #include <stdlib.h>
      31              : #include <nnstreamer_subplugin.h>
      32              : #include <nnstreamer_util.h>
      33              : #include "gsttensor_trainer.h"
      34              : #include <unistd.h>
      35              : #include <math.h>
      36              : 
/**
 * @brief Default caps string for sink
 * The sink accepts tensor streams in either static or flexible format.
 */
#define SINK_CAPS_STRING GST_TENSORS_CAP_MAKE ("{ static, flexible }")

/**
 * @brief Default caps string for src
 * The src only produces tensor streams in static format.
 */
#define SRC_CAPS_STRING GST_TENSORS_CAP_MAKE ("{ static}")

/**
 * @brief The capabilities of the sink pad
 */
static GstStaticPadTemplate sink_template = GST_STATIC_PAD_TEMPLATE ("sink",
    GST_PAD_SINK,
    GST_PAD_ALWAYS,
    GST_STATIC_CAPS (SINK_CAPS_STRING));

/**
 * @brief The capabilities of the src pad
 */
static GstStaticPadTemplate src_template = GST_STATIC_PAD_TEMPLATE ("src",
    GST_PAD_SRC,
    GST_PAD_ALWAYS,
    GST_STATIC_CAPS (SRC_CAPS_STRING));

GST_DEBUG_CATEGORY_STATIC (gst_tensor_trainer_debug);
/* Route GST_DEBUG/GST_INFO_OBJECT/... in this file to the category above. */
#define GST_CAT_DEFAULT gst_tensor_trainer_debug
/* G_DEFINE_TYPE declares a static "gst_tensor_trainer_parent_class" pointer;
 * this alias (defined before the macro expands) makes that variable available
 * to the rest of the file under the shorter name "parent_class". */
#define gst_tensor_trainer_parent_class parent_class
G_DEFINE_TYPE (GstTensorTrainer, gst_tensor_trainer, GST_TYPE_ELEMENT);
      67              : 
/**
 * @brief Statistics reported by the model being trained.
 * Each enum value is the index, within the output tensor, at which the
 * corresponding statistic is stored.
 */
enum
{
  TRAINING_LOSS,
  TRAINING_ACCURACY,
  VALIDATION_LOSS,
  VALIDATION_ACCURACY
};
/* Number of statistics listed in the enum above; keep the two in sync. */
#define MODEL_STATS_SIZE 4

/**
 * @brief Default framework property value
 * Initial values for the numeric properties, applied in the instance init.
 */
#define DEFAULT_PROP_INPUT_LIST 1
#define DEFAULT_PROP_LABEL_LIST 1
#define DEFAULT_PROP_TRAIN_SAMPLES 0
#define DEFAULT_PROP_VALID_SAMPLES 0
#define DEFAULT_PROP_EPOCHS 1
/**
 * @brief Default string property value
 * An empty string is treated as "not set" by the parameter validation.
 */
#define DEFAULT_STR_PROP_VALUE ""

/**
 * @brief tensor_trainer properties
 * GObject property IDs; PROP_0 is reserved because GObject IDs start at 1.
 */
enum
{
  PROP_0,
  PROP_FRAMEWORK,
  PROP_MODEL_CONFIG,
  PROP_MODEL_SAVE_PATH,
  PROP_MODEL_LOAD_PATH,
  PROP_NUM_INPUTS,              /* number of input list */
  PROP_NUM_LABELS,              /* number of label list */
  PROP_NUM_TRAINING_SAMPLES,    /* number of training data */
  PROP_NUM_VALIDATION_SAMPLES,  /* number of validation data */
  PROP_EPOCHS,                  /* Repetitions of training */
};
     110              : 
/* Forward declarations: GObject/GstElement virtual methods and pad handlers. */
static void gst_tensor_trainer_set_property (GObject * object, guint prop_id,
    const GValue * value, GParamSpec * pspec);
static void gst_tensor_trainer_get_property (GObject * object, guint prop_id,
    GValue * value, GParamSpec * pspec);
static void gst_tensor_trainer_finalize (GObject * object);
static gboolean gst_tensor_trainer_sink_event (GstPad * sinkpad,
    GstObject * parent, GstEvent * event);
static gboolean gst_tensor_trainer_sink_query (GstPad * sinkpad,
    GstObject * parent, GstQuery * query);
static gboolean gst_tensor_trainer_src_query (GstPad * srcpad,
    GstObject * parent, GstQuery * query);
static GstFlowReturn gst_tensor_trainer_chain (GstPad * sinkpad,
    GstObject * parent, GstBuffer * inbuf);
static GstCaps *gst_tensor_trainer_query_caps (GstTensorTrainer * trainer,
    GstPad * pad, GstCaps * filter);
static GstStateChangeReturn gst_tensor_trainer_change_state (GstElement *
    element, GstStateChange transition);

/* Forward declarations: property setters and framework/training helpers. */
static void gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer,
    const GValue * value);
static void gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer
    * trainer, const GValue * value);
static void gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer,
    const GValue * value);
static void gst_tensor_trainer_set_model_load_path (GstTensorTrainer * trainer,
    const GValue * value);
static gboolean gst_tensor_trainer_find_framework (GstTensorTrainer * trainer,
    const char *name);
static gboolean gst_tensor_trainer_create_framework (GstTensorTrainer *
    trainer);
static gsize gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer,
    guint index, gboolean is_input);
static gboolean gst_tensor_trainer_create_model (GstTensorTrainer * trainer);
static void gst_tensor_trainer_create_event_notifier (GstTensorTrainer *
    trainer);
static void gst_tensor_trainer_start_model_training (GstTensorTrainer *
    trainer);
static void gst_tensor_trainer_stop_model_training (GstTensorTrainer * trainer);
static void gst_tensor_trainer_set_output_meta (GstTensorTrainer * trainer);
     150              : 
     151              : /**
     152              :  * @brief initialize the tensor_trainer's class
     153              :  */
     154              : static void
     155            1 : gst_tensor_trainer_class_init (GstTensorTrainerClass * klass)
     156              : {
     157              :   GObjectClass *gobject_class;
     158              :   GstElementClass *gstelement_class;
     159              : 
     160            1 :   GST_DEBUG_CATEGORY_INIT (GST_CAT_DEFAULT, "tensor_trainer", 0,
     161              :       "Tensor trainer to train neural network model");
     162              : 
     163            1 :   gobject_class = G_OBJECT_CLASS (klass);
     164            1 :   gstelement_class = GST_ELEMENT_CLASS (klass);
     165              : 
     166            1 :   gobject_class->set_property =
     167            1 :       GST_DEBUG_FUNCPTR (gst_tensor_trainer_set_property);
     168            1 :   gobject_class->get_property =
     169            1 :       GST_DEBUG_FUNCPTR (gst_tensor_trainer_get_property);
     170            1 :   gobject_class->finalize = GST_DEBUG_FUNCPTR (gst_tensor_trainer_finalize);
     171              : 
     172              :   /* Called when the element's state changes */
     173            1 :   gstelement_class->change_state =
     174            1 :       GST_DEBUG_FUNCPTR (gst_tensor_trainer_change_state);
     175              : 
     176              :   /* Install properties for tensor_trainer */
     177            1 :   g_object_class_install_property (gobject_class, PROP_FRAMEWORK,
     178              :       g_param_spec_string ("framework", "Framework",
     179              :           "(not nullable) Neural network framework to be used for model training, ",
     180              :           DEFAULT_STR_PROP_VALUE,
     181              :           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
     182              :           G_PARAM_STATIC_STRINGS));
     183              : 
     184            1 :   g_object_class_install_property (gobject_class, PROP_MODEL_CONFIG,
     185              :       g_param_spec_string ("model-config", "Model configuration file path",
     186              :           "(not nullable) Model configuration file is used to configure the model "
     187              :           "to be trained in neural network framework, set the file path",
     188              :           DEFAULT_STR_PROP_VALUE,
     189              :           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
     190              :           G_PARAM_STATIC_STRINGS));
     191              : 
     192            1 :   g_object_class_install_property (gobject_class, PROP_MODEL_SAVE_PATH,
     193              :       g_param_spec_string ("model-save-path", "Model save path",
     194              :           "(not nullable) Path to save the trained model in framework, if model-config "
     195              :           "contains information about the save file, it is ignored",
     196              :           DEFAULT_STR_PROP_VALUE,
     197              :           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
     198              :           G_PARAM_STATIC_STRINGS));
     199              : 
     200            1 :   g_object_class_install_property (gobject_class, PROP_MODEL_LOAD_PATH,
     201              :       g_param_spec_string ("model-load-path", "Model load path",
     202              :           "(nullable) Path to a model file to be loaded for the given training session.",
     203              :           DEFAULT_STR_PROP_VALUE,
     204              :           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
     205              :           G_PARAM_STATIC_STRINGS));
     206              : 
     207            1 :   g_object_class_install_property (gobject_class, PROP_NUM_INPUTS,
     208              :       g_param_spec_uint ("num-inputs", "Number of inputs",
     209              :           "An input in a tensor can have one or more features data,"
     210              :           "set how many inputs are received", 0, NNS_TENSOR_SIZE_LIMIT, 1,
     211              :           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
     212              :           G_PARAM_STATIC_STRINGS));
     213              : 
     214            1 :   g_object_class_install_property (gobject_class, PROP_NUM_LABELS,
     215              :       g_param_spec_uint ("num-labels", "Number of labels",
     216              :           "A label in a tensor can have one or more classes data,"
     217              :           "set how many labels are received", 0, NNS_TENSOR_SIZE_LIMIT, 1,
     218              :           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
     219              :           G_PARAM_STATIC_STRINGS));
     220              : 
     221            1 :   g_object_class_install_property (gobject_class, PROP_NUM_TRAINING_SAMPLES,
     222              :       g_param_spec_uint ("num-training-samples", "Number of training samples",
     223              :           "A sample can consist of multiple inputs and labels in tensors of a gstbuffer"
     224              :           ", set how many samples are taken for training model",
     225              :           0, G_MAXINT, 0,
     226              :           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
     227              :           G_PARAM_STATIC_STRINGS));
     228              : 
     229            1 :   g_object_class_install_property (gobject_class, PROP_NUM_VALIDATION_SAMPLES,
     230              :       g_param_spec_uint ("num-validation-samples",
     231              :           "Number of validation samples",
     232              :           "A sample can consist of multiple inputs and labels in tensors of a gstbuffer"
     233              :           ", set how many samples are taken for validation model",
     234              :           0, G_MAXINT, 0,
     235              :           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
     236              :           G_PARAM_STATIC_STRINGS));
     237              : 
     238            1 :   g_object_class_install_property (gobject_class, PROP_EPOCHS,
     239              :       g_param_spec_uint ("epochs", "Number of epoch",
     240              :           "Epochs are repetitions of training samples and validation samples, "
     241              :           "number of samples received for model training is "
     242              :           "(num-training-samples+num-validation-samples)*epochs", 0, G_MAXINT,
     243              :           DEFAULT_PROP_EPOCHS,
     244              :           G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
     245              :           G_PARAM_STATIC_STRINGS));
     246              : 
     247            1 :   gst_element_class_set_details_simple (gstelement_class, "TensorTrainer",
     248              :       "Trainer/Tensor", "Train tensor data using NN Frameworks",
     249              :       "Samsung Electronics Co., Ltd.");
     250              : 
     251              :   /* Add pad template */
     252            1 :   gst_element_class_add_pad_template (gstelement_class,
     253              :       gst_static_pad_template_get (&src_template));
     254            1 :   gst_element_class_add_pad_template (gstelement_class,
     255              :       gst_static_pad_template_get (&sink_template));
     256            1 : }
     257              : 
     258              : /**
     259              :  * @brief Initialize tensor_trainer.
     260              :  */
     261              : static void
     262           11 : gst_tensor_trainer_init (GstTensorTrainer * trainer)
     263              : {
     264           11 :   GST_DEBUG ("<ENTER>");
     265              :   /** setup sink pad */
     266           11 :   trainer->sinkpad = gst_pad_new_from_static_template (&sink_template, "sink");
     267           11 :   gst_pad_set_event_function (trainer->sinkpad,
     268              :       GST_DEBUG_FUNCPTR (gst_tensor_trainer_sink_event));
     269           11 :   gst_pad_set_query_function (trainer->sinkpad,
     270              :       GST_DEBUG_FUNCPTR (gst_tensor_trainer_sink_query));
     271           11 :   gst_pad_set_chain_function (trainer->sinkpad,
     272              :       GST_DEBUG_FUNCPTR (gst_tensor_trainer_chain));
     273           11 :   GST_PAD_SET_PROXY_CAPS (trainer->sinkpad);
     274           11 :   gst_element_add_pad (GST_ELEMENT (trainer), trainer->sinkpad);
     275              : 
     276              :   /** setup src pad */
     277           11 :   trainer->srcpad = gst_pad_new_from_static_template (&src_template, "src");
     278           11 :   gst_pad_set_query_function (trainer->srcpad,
     279              :       GST_DEBUG_FUNCPTR (gst_tensor_trainer_src_query));
     280           11 :   GST_PAD_SET_PROXY_CAPS (trainer->srcpad);
     281           11 :   gst_element_add_pad (GST_ELEMENT (trainer), trainer->srcpad);
     282              : 
     283              :   /** init properties */
     284           11 :   trainer->fw_name = g_strdup (DEFAULT_STR_PROP_VALUE);
     285           11 :   trainer->prop.model_config = g_strdup (DEFAULT_STR_PROP_VALUE);
     286           11 :   trainer->prop.model_save_path = g_strdup (DEFAULT_STR_PROP_VALUE);
     287           11 :   trainer->prop.model_load_path = NULL;
     288           11 :   trainer->prop.num_inputs = DEFAULT_PROP_INPUT_LIST;
     289           11 :   trainer->prop.num_labels = DEFAULT_PROP_LABEL_LIST;
     290           11 :   trainer->prop.num_training_samples = DEFAULT_PROP_TRAIN_SAMPLES;
     291           11 :   trainer->prop.num_validation_samples = DEFAULT_PROP_VALID_SAMPLES;
     292           11 :   trainer->prop.num_epochs = DEFAULT_PROP_EPOCHS;
     293              : 
     294           11 :   trainer->fw = NULL;
     295           11 :   trainer->fw_created = FALSE;
     296           11 :   trainer->is_training_complete = FALSE;
     297           11 :   trainer->is_epoch_complete = FALSE;
     298           11 :   trainer->cur_epoch_data_cnt = 0;
     299           11 :   trainer->required_sample = 0;
     300              : 
     301           11 :   gst_tensors_config_init (&trainer->in_config);
     302           11 :   gst_tensors_config_init (&trainer->out_config);
     303              : 
     304           11 :   g_cond_init (&trainer->training_completion_cond);
     305           11 :   g_mutex_init (&trainer->training_completion_lock);
     306           11 :   g_cond_init (&trainer->epoch_completion_cond);
     307           11 :   g_mutex_init (&trainer->epoch_completion_lock);
     308              : 
     309           11 :   gst_tensor_trainer_set_output_meta (trainer);
     310           11 : }
     311              : 
     312              : /**
     313              :  * @brief Function to finalize instance.
     314              :  */
     315              : static void
     316            8 : gst_tensor_trainer_finalize (GObject * object)
     317              : {
     318              :   GstTensorTrainer *trainer;
     319            8 :   trainer = GST_TENSOR_TRAINER (object);
     320              : 
     321            8 :   g_free (trainer->fw_name);
     322            8 :   g_free ((char *) trainer->prop.model_config);
     323            8 :   g_free ((char *) trainer->prop.model_save_path);
     324            8 :   g_free ((char *) trainer->prop.model_load_path);
     325              : 
     326            8 :   gst_tensors_config_free (&trainer->in_config);
     327            8 :   gst_tensors_config_free (&trainer->out_config);
     328              : 
     329            8 :   g_cond_clear (&trainer->training_completion_cond);
     330            8 :   g_mutex_clear (&trainer->training_completion_lock);
     331            8 :   g_cond_clear (&trainer->epoch_completion_cond);
     332            8 :   g_mutex_clear (&trainer->epoch_completion_lock);
     333              : 
     334            8 :   if (trainer->dummy_data_thread) {
     335            0 :     g_thread_join (trainer->dummy_data_thread);
     336            0 :     trainer->dummy_data_thread = NULL;
     337              :   }
     338              : 
     339            8 :   if (trainer->fw_created && trainer->fw) {
     340            0 :     trainer->fw->destroy (trainer->fw, &trainer->prop, &trainer->privateData);
     341              :   }
     342              : 
     343            8 :   G_OBJECT_CLASS (parent_class)->finalize (object);
     344            8 : }
     345              : 
     346              : /**
     347              :  * @brief Setter for tensor_trainsink properties.
     348              :  */
     349              : static void
     350           81 : gst_tensor_trainer_set_property (GObject * object, guint prop_id,
     351              :     const GValue * value, GParamSpec * pspec)
     352              : {
     353              :   GstTensorTrainer *trainer;
     354              : 
     355           81 :   trainer = GST_TENSOR_TRAINER (object);
     356              : 
     357           81 :   switch (prop_id) {
     358           11 :     case PROP_FRAMEWORK:
     359           11 :       gst_tensor_trainer_set_prop_framework (trainer, value);
     360           11 :       break;
     361           11 :     case PROP_MODEL_CONFIG:
     362           11 :       gst_tensor_trainer_set_prop_model_config_file_path (trainer, value);
     363           11 :       break;
     364            8 :     case PROP_MODEL_SAVE_PATH:
     365            8 :       gst_tensor_trainer_set_model_save_path (trainer, value);
     366            8 :       break;
     367            2 :     case PROP_MODEL_LOAD_PATH:
     368            2 :       gst_tensor_trainer_set_model_load_path (trainer, value);
     369            2 :       break;
     370           10 :     case PROP_NUM_INPUTS:
     371           10 :       trainer->prop.num_inputs = g_value_get_uint (value);
     372           10 :       break;
     373           10 :     case PROP_NUM_LABELS:
     374           10 :       trainer->prop.num_labels = g_value_get_uint (value);
     375           10 :       break;
     376            9 :     case PROP_NUM_TRAINING_SAMPLES:
     377            9 :       trainer->prop.num_training_samples = g_value_get_uint (value);
     378            9 :       break;
     379           10 :     case PROP_NUM_VALIDATION_SAMPLES:
     380           10 :       trainer->prop.num_validation_samples = g_value_get_uint (value);
     381           10 :       break;
     382           10 :     case PROP_EPOCHS:
     383           10 :       trainer->prop.num_epochs = g_value_get_uint (value);
     384           10 :       break;
     385            0 :     default:
     386            0 :       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
     387            0 :       break;
     388              :   }
     389           81 : }
     390              : 
     391              : /**
     392              :  * @brief Getter tensor_trainsink properties.
     393              :  */
     394              : static void
     395           14 : gst_tensor_trainer_get_property (GObject * object, guint prop_id,
     396              :     GValue * value, GParamSpec * pspec)
     397              : {
     398              :   GstTensorTrainer *trainer;
     399              : 
     400           14 :   trainer = GST_TENSOR_TRAINER (object);
     401              : 
     402           14 :   switch (prop_id) {
     403            0 :     case PROP_FRAMEWORK:
     404            0 :       g_value_set_string (value, trainer->fw_name);
     405            0 :       break;
     406            1 :     case PROP_MODEL_CONFIG:
     407            1 :       g_value_set_string (value, trainer->prop.model_config);
     408            1 :       break;
     409            1 :     case PROP_MODEL_SAVE_PATH:
     410            1 :       g_value_set_string (value, trainer->prop.model_save_path);
     411            1 :       break;
     412            2 :     case PROP_MODEL_LOAD_PATH:
     413            2 :       g_value_set_string (value, trainer->prop.model_load_path);
     414            2 :       break;
     415            2 :     case PROP_NUM_INPUTS:
     416            2 :       g_value_set_uint (value, trainer->prop.num_inputs);
     417            2 :       break;
     418            2 :     case PROP_NUM_LABELS:
     419            2 :       g_value_set_uint (value, trainer->prop.num_labels);
     420            2 :       break;
     421            2 :     case PROP_NUM_TRAINING_SAMPLES:
     422            2 :       g_value_set_uint (value, trainer->prop.num_training_samples);
     423            2 :       break;
     424            2 :     case PROP_NUM_VALIDATION_SAMPLES:
     425            2 :       g_value_set_uint (value, trainer->prop.num_validation_samples);
     426            2 :       break;
     427            2 :     case PROP_EPOCHS:
     428            2 :       g_value_set_uint (value, trainer->prop.num_epochs);
     429            2 :       break;
     430            0 :     default:
     431            0 :       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
     432            0 :       break;
     433              :   }
     434           14 : }
     435              : 
     436              : /**
     437              :  * @brief Check invalid param
     438              :  */
     439              : static gboolean
     440            3 : gst_tensor_trainer_check_invalid_param (GstTensorTrainer * trainer)
     441              : {
     442            3 :   g_return_val_if_fail (trainer != NULL, FALSE);
     443              : 
     444              :   /* Parameters that can be retrieved from caps will be removed */
     445            3 :   if (!trainer->fw_name
     446            2 :       || (g_ascii_strcasecmp (trainer->prop.model_config,
     447              :               DEFAULT_STR_PROP_VALUE) == 0)
     448            2 :       || (g_ascii_strcasecmp (trainer->prop.model_save_path,
     449              :               DEFAULT_STR_PROP_VALUE) == 0)
     450            1 :       || trainer->prop.num_epochs <= 0 || trainer->prop.num_inputs <= 0
     451            1 :       || trainer->prop.num_labels <= 0) {
     452            2 :     GST_ERROR_OBJECT (trainer, "Check for invalid param value");
     453              : 
     454            2 :     return FALSE;
     455              :   }
     456              : 
     457            1 :   if (!g_file_test (trainer->prop.model_config,
     458              :           (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
     459            0 :     GST_ERROR_OBJECT (trainer, "Model config file does not exist. [%s]",
     460              :         trainer->prop.model_config);
     461            0 :     return FALSE;
     462              :   }
     463              : 
     464            1 :   return TRUE;
     465              : }
     466              : 
     467              : /**
     468              :  * @brief Dummy data generation thread
     469              :  */
     470              : static gpointer
     471            0 : gst_tensor_trainer_dummy_data_generation_func (GstTensorTrainer * trainer)
     472              : {
     473              :   guint i;
     474            0 :   gint ret = -1;
     475            0 :   gpointer dummy_data[NNS_TENSOR_SIZE_LIMIT] = { NULL };
     476            0 :   g_return_val_if_fail (trainer != NULL, NULL);
     477              : 
     478            0 :   gst_tensor_trainer_stop_model_training (trainer);
     479              : 
     480            0 :   for (i = 0; i < trainer->output_meta.num_tensors; i++) {
     481            0 :     dummy_data[i] = g_malloc (trainer->input_tensors[i].size);
     482            0 :     memset (dummy_data[i], 1, trainer->input_tensors[i].size);
     483            0 :     trainer->input_tensors[i].data = dummy_data[i];
     484              :   }
     485              : 
     486              :   do {
     487            0 :     GST_INFO_OBJECT (trainer, "cur_epoch_data_cnt=%u",
     488              :         trainer->cur_epoch_data_cnt);
     489            0 :     GST_INFO_OBJECT (trainer, "num_tensors=%d",
     490              :         trainer->prop.input_meta.num_tensors);
     491              : 
     492              :     ret =
     493            0 :         trainer->fw->push_data (trainer->fw, &trainer->prop,
     494            0 :         trainer->privateData, trainer->input_tensors);
     495              : 
     496            0 :     if (ret < 0) {
     497            0 :       GST_ERROR_OBJECT (trainer, "Failed to push dummy data");
     498              :     } else {
     499            0 :       trainer->cur_epoch_data_cnt++;
     500              :     }
     501            0 :   } while (trainer->required_sample > trainer->cur_epoch_data_cnt);
     502              : 
     503            0 :   for (i = 0; i < trainer->output_meta.num_tensors; i++)
     504            0 :     g_free (dummy_data[i]);
     505              : 
     506            0 :   return NULL;
     507              : }
     508              : 
     509              : /**
     510              :  * @brief Change state of tensor_trainsink.
     511              :  */
     512              : static GstStateChangeReturn
     513           20 : gst_tensor_trainer_change_state (GstElement * element,
     514              :     GstStateChange transition)
     515              : {
     516           20 :   GstTensorTrainer *trainer = GST_TENSOR_TRAINER (element);
     517           20 :   GstStateChangeReturn ret = GST_STATE_CHANGE_SUCCESS;
     518              : 
     519           20 :   switch (transition) {
     520            6 :     case GST_STATE_CHANGE_NULL_TO_READY:
     521            6 :       GST_INFO_OBJECT (trainer, "NULL_TO_READY");
     522              :       /* currently not used */
     523            6 :       trainer->is_training_complete = FALSE;
     524            6 :       break;
     525              : 
     526            4 :     case GST_STATE_CHANGE_READY_TO_PAUSED:
     527            4 :       GST_INFO_OBJECT (trainer, "READY_TO_PAUSED");
     528            4 :       break;
     529              : 
     530            3 :     case GST_STATE_CHANGE_PAUSED_TO_PLAYING:
     531            3 :       GST_INFO_OBJECT (trainer, "PAUSED_TO_PLAYING");
     532            3 :       if (!gst_tensor_trainer_check_invalid_param (trainer))
     533            2 :         goto state_change_failed;
     534            1 :       if (!trainer->fw_created) {
     535            1 :         if (!gst_tensor_trainer_create_model (trainer))
     536            1 :           goto state_change_failed;
     537              :       }
     538            0 :       gst_tensor_trainer_create_event_notifier (trainer);
     539            0 :       gst_tensor_trainer_start_model_training (trainer);
     540            0 :       break;
     541              : 
     542            7 :     default:
     543            7 :       break;
     544              :   }
     545              : 
     546           17 :   ret = GST_ELEMENT_CLASS (parent_class)->change_state (element, transition);
     547              : 
     548           17 :   switch (transition) {
     549            0 :     case GST_STATE_CHANGE_PLAYING_TO_PAUSED:
     550            0 :       GST_INFO_OBJECT (trainer, "PLAYING_TO_PAUSED");
     551              :       /* need to generate dummy data */
     552            0 :       if (!trainer->is_training_complete) {
     553            0 :         if (!g_strcmp0 (trainer->fw_name, "nntrainer")) {
     554            0 :           GST_INFO_OBJECT (trainer, "cur_epoch_data_cnt=%u",
     555              :               trainer->cur_epoch_data_cnt);
     556            0 :           trainer->dummy_data_thread =
     557            0 :               g_thread_new ("dumy_data_generation_func",
     558              :               (GThreadFunc) gst_tensor_trainer_dummy_data_generation_func,
     559              :               trainer);
     560              :         }
     561              :       }
     562            0 :       break;
     563              : 
     564            1 :     case GST_STATE_CHANGE_PAUSED_TO_READY:
     565            1 :       GST_INFO_OBJECT (trainer, "PAUSED_TO_READY");
     566              :       /* stop model train ? */
     567            1 :       break;
     568              : 
     569            3 :     case GST_STATE_CHANGE_READY_TO_NULL:
     570            3 :       GST_INFO_OBJECT (trainer, "READY_TO_NULL");
     571              :       /* destroy or reset model ? */
     572            3 :       break;
     573              : 
     574           13 :     default:
     575           13 :       break;
     576              :   }
     577              : 
     578           17 :   return ret;
     579              : 
     580            3 : state_change_failed:
     581            3 :   GST_ERROR_OBJECT (trainer, "state change failed");
     582              : 
     583            3 :   return GST_STATE_CHANGE_FAILURE;
     584              : }
     585              : 
     586              : /**
     587              :  * @brief Wait for epoch eompletion
     588              :  */
     589              : static void
     590            0 : gst_tensor_trainer_wait_for_epoch_completion (GstTensorTrainer * trainer)
     591              : {
     592            0 :   g_return_if_fail (trainer != NULL);
     593              : 
     594            0 :   g_mutex_lock (&trainer->epoch_completion_lock);
     595            0 :   while (!trainer->is_epoch_complete) {
     596            0 :     GST_INFO_OBJECT (trainer, "wait for epoch_completion_cond signal");
     597            0 :     g_cond_wait (&trainer->epoch_completion_cond,
     598              :         &trainer->epoch_completion_lock);
     599              :   }
     600            0 :   trainer->is_epoch_complete = FALSE;
     601            0 :   g_mutex_unlock (&trainer->epoch_completion_lock);
     602              : }
     603              : 
     604              : /**
     605              :  * @brief Check if current epochs is complete,
     606              :  * tensor_trainer wait for one of epochs to complete before getting the results from the subplugin
     607              :  */
     608              : static gboolean
     609            0 : gst_tensor_trainer_epochs_is_complete (GstTensorTrainer * trainer)
     610              : {
     611            0 :   g_return_val_if_fail (trainer != NULL, FALSE);
     612            0 :   g_return_val_if_fail (trainer->fw != NULL, FALSE);
     613            0 :   g_return_val_if_fail (&trainer->prop != NULL, FALSE);
     614              : 
     615            0 :   trainer->required_sample =
     616            0 :       trainer->prop.num_training_samples + trainer->prop.num_validation_samples;
     617            0 :   if (trainer->cur_epoch_data_cnt != trainer->required_sample)
     618            0 :     return FALSE;
     619              : 
     620            0 :   gst_tensor_trainer_wait_for_epoch_completion (trainer);
     621            0 :   trainer->cur_epoch_data_cnt = 0;
     622            0 :   return TRUE;
     623              : }
     624              : 
     625              : /**
     626              :  * @brief Check buffer drop conditions. If condition is met, drop the buffer.
     627              :  */
     628              : static gboolean
     629            0 : gst_tensor_trainer_check_buffer_drop_conditions (GstTensorTrainer * trainer)
     630              : {
     631            0 :   if (trainer->is_training_complete == TRUE) {
     632              :     /** app need to send gst_element_send_event(tensor_trainer, gst_event_new_eos())
     633              :         after training_complete or set eos to datareposrc */
     634            0 :     GST_WARNING_OBJECT (trainer,
     635              :         "Training is completed, buffer is dropped, please change state of pipeline");
     636            0 :     return TRUE;
     637              :   }
     638            0 :   return FALSE;
     639              : }
     640              : 
     641              : /**
     642              :  * @brief  Check chain conditions. If all conditions are met, proceed to next step.
     643              :  */
     644              : static gboolean
     645            0 : gst_tensor_trainer_check_chain_conditions (GstTensorTrainer * trainer,
     646              :     guint num_tensors)
     647              : {
     648            0 :   if (!trainer->fw_created) {
     649            0 :     if (!gst_tensor_trainer_check_invalid_param (trainer))
     650            0 :       return FALSE;;
     651            0 :     if (!gst_tensor_trainer_create_model (trainer))
     652            0 :       return FALSE;
     653              :   }
     654              : 
     655            0 :   if (num_tensors >= NNS_TENSOR_SIZE_LIMIT)
     656            0 :     return FALSE;
     657              : 
     658            0 :   return TRUE;
     659              : }
     660              : 
     661              : /**
     662              :  * @brief Convert tensor meta and get the size of tensor header.
     663              :  */
     664              : static gsize
     665            0 : gst_tensor_trainer_convert_meta (GstTensorTrainer * trainer,
     666              :     GstTensorMetaInfo * meta, GstTensorInfo * info, void *data)
     667              : {
     668            0 :   gsize header_size = 0;
     669              : 
     670            0 :   if (!gst_tensor_meta_info_parse_header (meta, data)) {
     671            0 :     GST_ERROR_OBJECT (trainer, "Invalid Flexible tensors");
     672            0 :     return 0;
     673              :   }
     674              : 
     675            0 :   if (gst_tensor_meta_info_convert (meta, info)) {
     676            0 :     header_size = gst_tensor_meta_info_get_header_size (meta);
     677            0 :     GST_INFO ("flexible header size:%zd", header_size);
     678              :   }
     679              : 
     680            0 :   return header_size;
     681              : }
     682              : 
     683              : /**
     684              :  * @brief Create input tensors from the buffer and push it into trainer fw.
     685              :  */
     686              : static gboolean
     687            0 : gst_tensor_trainer_push_input (GstTensorTrainer * trainer, GstBuffer * inbuf,
     688              :     gboolean in_flexible)
     689              : {
     690              :   guint i, n;
     691            0 :   GstMemory *in_mem[NNS_TENSOR_SIZE_LIMIT] = { 0, };
     692              :   GstMapInfo in_info[NNS_TENSOR_SIZE_LIMIT];
     693              :   GstTensorMetaInfo in_meta[NNS_TENSOR_SIZE_LIMIT];
     694              :   GstTensorInfo *info;
     695            0 :   gsize header_size = 0, expected;
     696            0 :   gint ret = -1;
     697              : 
     698            0 :   n = gst_tensor_buffer_get_count (inbuf);
     699              : 
     700            0 :   if (in_flexible)
     701            0 :     trainer->prop.input_meta.num_tensors = n;
     702              :   else {
     703            0 :     GST_DEBUG_OBJECT (trainer, "num_tensors: %u",
     704              :         trainer->prop.input_meta.num_tensors);
     705            0 :     if (n != trainer->prop.input_meta.num_tensors) {
     706            0 :       GST_ERROR_OBJECT (trainer,
     707              :           "Invalid memory blocks (%u), number of input tensors may be (%u)",
     708              :           n, trainer->prop.input_meta.num_tensors);
     709            0 :       goto error;
     710              :     }
     711              :   }
     712              : 
     713            0 :   for (i = 0; i < n; i++) {
     714            0 :     in_mem[i] = gst_tensor_buffer_get_nth_memory (inbuf, i);
     715            0 :     if (!gst_memory_map (in_mem[i], &in_info[i], GST_MAP_READ)) {
     716            0 :       GST_ERROR_OBJECT (trainer, "Could not map in_mem[%u] GstMemory", i);
     717            0 :       goto error;
     718              :     }
     719              : 
     720            0 :     if (in_flexible) {
     721            0 :       info = gst_tensors_info_get_nth_info (&trainer->prop.input_meta, i);
     722            0 :       header_size = gst_tensor_trainer_convert_meta (trainer,
     723            0 :           &in_meta[i], info, in_info[i].data);
     724            0 :       if (header_size == 0)
     725            0 :         goto error;
     726              :     }
     727              : 
     728            0 :     trainer->input_tensors[i].data = in_info[i].data + header_size;
     729            0 :     trainer->input_tensors[i].size = in_info[i].size - header_size;
     730            0 :     GST_INFO ("input_tensors[%u].size= %zd", i, trainer->input_tensors[i].size);
     731            0 :     GST_INFO ("input_tensors[%u].data: %p", i, trainer->input_tensors[i].data);
     732              : 
     733              :     /* Check input tensor size */
     734            0 :     expected = gst_tensor_trainer_get_tensor_size (trainer, i, TRUE);
     735            0 :     if (expected != trainer->input_tensors[i].size) {
     736            0 :       GST_ERROR_OBJECT (trainer,
     737              :           "Invalid tensor size (%u'th memory chunk: %zd), expected size (%zd)",
     738              :           i, trainer->input_tensors[i].size, expected);
     739            0 :       goto error;
     740              :     }
     741              :   }
     742              : 
     743            0 :   ret = trainer->fw->push_data (trainer->fw, &trainer->prop,
     744            0 :       trainer->privateData, trainer->input_tensors);
     745              : 
     746            0 :   if (ret < 0)
     747            0 :     GST_ERROR_OBJECT (trainer, "push error");
     748              :   else
     749            0 :     trainer->cur_epoch_data_cnt++;
     750              : 
     751            0 : error:
     752            0 :   for (i = 0; i < n; i++) {
     753            0 :     if (in_mem[i]) {
     754            0 :       gst_memory_unmap (in_mem[i], &in_info[i]);
     755            0 :       gst_memory_unref (in_mem[i]);
     756              :     }
     757              : 
     758            0 :     trainer->input_tensors[i].data = NULL;
     759            0 :     trainer->input_tensors[i].size = 0;
     760              :   }
     761              : 
     762            0 :   return (ret == 0);
     763              : }
     764              : 
     765              : /**
     766              :  * @brief Get the model statistics from the sub-plugin.
     767              :  */
     768              : static gboolean
     769            0 : gst_tensor_trainer_get_model_stats (GstTensorTrainer * trainer,
     770              :     double *model_stats)
     771              : {
     772            0 :   gint ret = -1;
     773              : 
     774              :   ret =
     775            0 :       trainer->fw->getStatus (trainer->fw, &trainer->prop,
     776              :       trainer->privateData);
     777            0 :   if (ret < 0) {
     778            0 :     GST_ERROR_OBJECT (trainer, "Failed to Get status from sub-plugin.(%s).",
     779              :         trainer->fw_name);
     780            0 :     return FALSE;
     781              :   }
     782              :   /* If the value is invalid, it is already set by -INFINITY. */
     783            0 :   if (trainer->prop.training_loss > 0)
     784            0 :     model_stats[TRAINING_LOSS] = trainer->prop.training_loss;
     785            0 :   if (trainer->prop.training_accuracy > 0)
     786            0 :     model_stats[TRAINING_ACCURACY] = trainer->prop.training_accuracy;
     787            0 :   if (trainer->prop.validation_loss > 0)
     788            0 :     model_stats[VALIDATION_LOSS] = trainer->prop.validation_loss;
     789            0 :   if (trainer->prop.validation_accuracy > 0)
     790            0 :     model_stats[VALIDATION_ACCURACY] = trainer->prop.validation_accuracy;
     791              : 
     792            0 :   GST_DEBUG_OBJECT (trainer,
     793              :       "#%u/%u epochs [training_loss: %f, training_accuracy: %f, validation_loss: %f, validation_accuracy: %f]",
     794              :       trainer->prop.epoch_count, trainer->prop.num_epochs,
     795              :       model_stats[TRAINING_LOSS], model_stats[TRAINING_ACCURACY],
     796              :       model_stats[VALIDATION_LOSS], model_stats[VALIDATION_ACCURACY]);
     797              : 
     798            0 :   return TRUE;
     799              : }
     800              : 
     801              : /**
     802              :  * @brief Create output tensors.
     803              :  */
     804              : static GstBuffer *
     805            0 : gst_tensor_trainer_create_output (GstTensorTrainer * trainer)
     806              : {
     807              :   guint i;
     808              :   size_t data_size;
     809            0 :   double model_stats[MODEL_STATS_SIZE] =
     810              :       { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
     811              :   GstBuffer *outbuf;
     812              :   GstMemory *out_mem;
     813              :   GstMapInfo out_info;
     814              :   GstTensorInfo *info;
     815            0 :   gboolean created = FALSE;
     816              : 
     817            0 :   if (trainer->output_meta.num_tensors > NNS_TENSOR_SIZE_LIMIT) {
     818            0 :     GST_ERROR_OBJECT (trainer,
     819              :         "The number of output tensors (%u) exceeds limit (%d)",
     820              :         trainer->output_meta.num_tensors, NNS_TENSOR_SIZE_LIMIT);
     821            0 :     return NULL;
     822              :   }
     823              : 
     824            0 :   outbuf = gst_buffer_new ();
     825              : 
     826            0 :   for (i = 0; i < trainer->output_meta.num_tensors; i++) {
     827            0 :     if (!gst_tensor_trainer_get_model_stats (trainer, model_stats))
     828            0 :       goto error;
     829              : 
     830            0 :     data_size = gst_tensor_trainer_get_tensor_size (trainer, i, FALSE);
     831            0 :     info = gst_tensors_info_get_nth_info (&trainer->output_meta, i);
     832              : 
     833            0 :     out_mem = gst_allocator_alloc (NULL, data_size, NULL);
     834            0 :     if (!out_mem) {
     835            0 :       GST_ERROR_OBJECT (trainer, "Failed to allocate memory");
     836            0 :       goto error;
     837              :     }
     838              : 
     839            0 :     if (!gst_memory_map (out_mem, &out_info, GST_MAP_WRITE)) {
     840            0 :       GST_ERROR_OBJECT (trainer, "Could not map out_mem[%u] GstMemory", i);
     841            0 :       gst_memory_unref (out_mem);
     842            0 :       goto error;
     843              :     }
     844              : 
     845            0 :     memcpy (out_info.data, model_stats, sizeof (model_stats));
     846            0 :     gst_memory_unmap (out_mem, &out_info);
     847              : 
     848            0 :     gst_tensor_buffer_append_memory (outbuf, out_mem, info);
     849              :   }
     850              : 
     851            0 :   created = TRUE;
     852              : 
     853            0 : error:
     854            0 :   if (created) {
     855            0 :     GST_INFO ("out_buffer size : %zd", gst_buffer_get_size (outbuf));
     856              :   } else {
     857            0 :     gst_buffer_unref (outbuf);
     858            0 :     outbuf = NULL;
     859              :   }
     860              : 
     861            0 :   return outbuf;
     862              : }
     863              : 
     864              : /**
     865              :  * @brief Chain function, this function does the actual processing.
     866              :  */
     867              : static GstFlowReturn
     868            0 : gst_tensor_trainer_chain (GstPad * sinkpad, GstObject * parent,
     869              :     GstBuffer * inbuf)
     870              : {
     871              :   GstTensorTrainer *trainer;
     872            0 :   GstBuffer *outbuf = NULL;
     873            0 :   GstFlowReturn ret = GST_FLOW_ERROR;
     874              :   guint num_tensors;
     875              :   gboolean in_flexible;
     876              : 
     877            0 :   trainer = GST_TENSOR_TRAINER (parent);
     878            0 :   in_flexible = gst_tensor_pad_caps_is_flexible (sinkpad);
     879            0 :   num_tensors = gst_tensor_buffer_get_count (inbuf);
     880              : 
     881            0 :   if (!gst_tensor_trainer_check_chain_conditions (trainer, num_tensors)) {
     882            0 :     goto error;
     883              :   }
     884              : 
     885            0 :   if (gst_tensor_trainer_check_buffer_drop_conditions (trainer)) {
     886            0 :     ret = GST_FLOW_OK;
     887            0 :     goto error;
     888              :   }
     889              : 
     890            0 :   if (!gst_tensor_trainer_push_input (trainer, inbuf, in_flexible)) {
     891            0 :     goto error;
     892              :   }
     893              : 
     894              :   /**
     895              :    * Update result if one of epochs is complete,
     896              :    * push one outbuf is necessary to change pipeline state.
     897              :    * Scheduling with subplugin does not work.
     898              :    */
     899            0 :   if (trainer->cur_epoch_data_cnt == 1
     900            0 :       || gst_tensor_trainer_epochs_is_complete (trainer)) {
     901            0 :     outbuf = gst_tensor_trainer_create_output (trainer);
     902              : 
     903            0 :     if (outbuf)
     904            0 :       ret = gst_pad_push (trainer->srcpad, outbuf);
     905              :   } else {
     906              :     /* Run flow, need more data? */
     907            0 :     ret = GST_FLOW_OK;
     908              :   }
     909              : 
     910            0 : error:
     911            0 :   gst_buffer_unref (inbuf);
     912            0 :   return ret;
     913              : }
     914              : 
     915              : /**
     916              :  * @brief Get pad caps for caps negotiation.
     917              :  */
     918              : static GstCaps *
     919           47 : gst_tensor_trainer_query_caps (GstTensorTrainer * trainer,
     920              :     GstPad * pad, GstCaps * filter)
     921              : {
     922              :   GstCaps *caps;
     923              :   GstTensorsConfig *config;
     924              : 
     925           47 :   g_return_val_if_fail (trainer != NULL, NULL);
     926           47 :   g_return_val_if_fail (pad != NULL, NULL);
     927              : 
     928              :   /* tensor config info for given pad */
     929           47 :   if (pad == trainer->sinkpad) {
     930           25 :     config = &trainer->in_config;
     931              :   } else {
     932           22 :     config = &trainer->out_config;
     933              :   }
     934              : 
     935           47 :   caps = gst_tensor_pad_possible_caps_from_config (pad, config);
     936           47 :   GST_DEBUG_OBJECT (trainer, "caps %" GST_PTR_FORMAT, caps);
     937           47 :   GST_DEBUG_OBJECT (trainer, "filter %" GST_PTR_FORMAT, filter);
     938              : 
     939           47 :   if (caps && filter) {
     940              :     GstCaps *result;
     941            3 :     result = gst_caps_intersect_full (filter, caps, GST_CAPS_INTERSECT_FIRST);
     942            3 :     gst_caps_unref (caps);
     943            3 :     caps = result;
     944              :   }
     945              : 
     946           47 :   GST_DEBUG_OBJECT (trainer, "result caps %" GST_PTR_FORMAT, caps);
     947              : 
     948           47 :   return caps;
     949              : }
     950              : 
     951              : /**
     952              :  * @brief Wait for training completion
     953              :  */
     954              : static void
     955            0 : gst_tensor_trainer_wait_for_training_completion (GstTensorTrainer * trainer)
     956              : {
     957            0 :   g_return_if_fail (trainer != NULL);
     958              : 
     959            0 :   g_mutex_lock (&trainer->training_completion_lock);
     960            0 :   while (!trainer->is_training_complete) {
     961            0 :     GST_INFO_OBJECT (trainer,
     962              :         "got GST_EVENT_EOS event but training is not completed, state is %d, "
     963              :         "wait for training_completion_cond signal", GST_STATE (trainer));
     964            0 :     g_cond_wait (&trainer->training_completion_cond,
     965              :         &trainer->training_completion_lock);
     966              :   }
     967            0 :   g_mutex_unlock (&trainer->training_completion_lock);
     968              : 
     969            0 :   GST_DEBUG_OBJECT (trainer, "training is completed in sub-plugin[%s]",
     970              :       trainer->fw_name);
     971              : }
     972              : 
     973              : /**
     974              :  * @brief Event handler for sink pad of tensor_trainer
     975              :  */
     976              : static gboolean
     977            6 : gst_tensor_trainer_sink_event (GstPad * sinkpad, GstObject * parent,
     978              :     GstEvent * event)
     979              : {
     980              :   GstTensorTrainer *trainer;
     981            6 :   trainer = GST_TENSOR_TRAINER (parent);
     982              : 
     983            6 :   GST_DEBUG_OBJECT (trainer, "Received %s event: %" GST_PTR_FORMAT,
     984              :       GST_EVENT_TYPE_NAME (event), event);
     985              : 
     986            6 :   switch (GST_EVENT_TYPE (event)) {
     987            0 :     case GST_EVENT_EOS:
     988            0 :       if (!trainer->is_training_complete)
     989            0 :         gst_tensor_trainer_wait_for_training_completion (trainer);
     990            0 :       break;
     991            0 :     case GST_EVENT_FLUSH_START:
     992            0 :       GST_INFO_OBJECT (trainer, "get GST_EVENT_FLUSH_START event");
     993            0 :       break;
     994            0 :     case GST_EVENT_FLUSH_STOP:
     995            0 :       GST_INFO_OBJECT (trainer, "get GST_EVENT_FLUSH_STOP event");
     996            0 :       break;
     997            3 :     case GST_EVENT_CAPS:
     998              :     {
     999              :       GstCaps *in_caps;
    1000              :       GstCaps *out_caps;
    1001              :       GstStructure *structure;
    1002              :       GstTensorsConfig config;
    1003            3 :       gboolean ret = FALSE;
    1004              : 
    1005            3 :       gst_event_parse_caps (event, &in_caps);
    1006            3 :       GST_INFO_OBJECT (trainer, "[in-caps] : %" GST_PTR_FORMAT, in_caps);
    1007              : 
    1008            3 :       structure = gst_caps_get_structure (in_caps, 0);
    1009            3 :       if (!gst_tensors_config_from_structure (&config, structure) ||
    1010            3 :           !gst_tensors_config_validate (&config)) {
    1011            0 :         gst_tensors_config_free (&config);
    1012            0 :         gst_event_unref (event);
    1013            3 :         return FALSE;
    1014              :       }
    1015              : 
    1016              :       /* copy TensorsInfo from negotiated caps to GstTensorTrainerProperties's input_meta */
    1017            3 :       gst_tensors_info_copy (&trainer->prop.input_meta, &config.info);
    1018              : 
    1019              :       /* set tensor-config and out caps */
    1020            3 :       trainer->in_config = config;
    1021            3 :       trainer->out_config.rate_n = config.rate_n;
    1022            3 :       trainer->out_config.rate_d = config.rate_d;
    1023            3 :       gst_tensors_info_copy (&trainer->out_config.info, &trainer->output_meta);
    1024              : 
    1025              :       out_caps =
    1026            6 :           gst_tensor_pad_caps_from_config (trainer->srcpad,
    1027            3 :           &trainer->out_config);
    1028            3 :       GST_INFO_OBJECT (trainer, "[out-caps] : %" GST_PTR_FORMAT, out_caps);
    1029              : 
    1030            3 :       ret = gst_pad_set_caps (trainer->srcpad, out_caps);
    1031              : 
    1032            3 :       gst_event_unref (event);
    1033            3 :       gst_caps_unref (out_caps);
    1034            3 :       return ret;
    1035              :     }
    1036            3 :     default:
    1037            3 :       break;
    1038              :   }
    1039            3 :   return gst_pad_event_default (sinkpad, parent, event);
    1040              : }
    1041              : 
    1042              : /**
    1043              :  * @brief This function handles sink pad query.
    1044              :  */
    1045              : static gboolean
    1046           31 : gst_tensor_trainer_sink_query (GstPad * sinkpad, GstObject * parent,
    1047              :     GstQuery * query)
    1048              : {
    1049              :   GstTensorTrainer *trainer;
    1050           31 :   trainer = GST_TENSOR_TRAINER (parent);
    1051              : 
    1052           31 :   GST_DEBUG_OBJECT (trainer, "Received '%s' query: %" GST_PTR_FORMAT,
    1053              :       GST_QUERY_TYPE_NAME (query), query);
    1054              : 
    1055           31 :   switch (GST_QUERY_TYPE (query)) {
    1056           25 :     case GST_QUERY_CAPS:
    1057              :     {
    1058              :       GstCaps *caps;
    1059              :       GstCaps *filter;
    1060              : 
    1061           25 :       GST_DEBUG_OBJECT (trainer, "[GST_QUERY_CAPS]");
    1062           25 :       gst_query_parse_caps (query, &filter);
    1063           25 :       GST_DEBUG_OBJECT (trainer, "Caps from query : %" GST_PTR_FORMAT, filter);
    1064              : 
    1065           25 :       caps = gst_tensor_trainer_query_caps (trainer, sinkpad, filter);
    1066              : 
    1067           25 :       GST_INFO_OBJECT (trainer, "[GST_QUERY_CAPS] : %" GST_PTR_FORMAT, caps);
    1068           25 :       gst_query_set_caps_result (query, caps);
    1069           25 :       gst_caps_unref (caps);
    1070              : 
    1071           25 :       return TRUE;
    1072              :     }
    1073            3 :     case GST_QUERY_ACCEPT_CAPS:
    1074              :     {
    1075              :       GstCaps *caps;
    1076              :       GstCaps *template_caps;
    1077            3 :       gboolean result = FALSE;
    1078              : 
    1079            3 :       GST_DEBUG_OBJECT (trainer, "[GST_QUERY_ACCEPT_CAPS]");
    1080            3 :       gst_query_parse_accept_caps (query, &caps);
    1081            3 :       GST_INFO_OBJECT (trainer, "Accept caps from query : %" GST_PTR_FORMAT,
    1082              :           caps);
    1083              : 
    1084            3 :       if (gst_caps_is_fixed (caps)) {
    1085            3 :         template_caps = gst_pad_get_pad_template_caps (sinkpad);
    1086            3 :         GST_DEBUG_OBJECT (trainer, "sinkpad template_caps : %" GST_PTR_FORMAT,
    1087              :             template_caps);
    1088              : 
    1089            3 :         result = gst_caps_can_intersect (template_caps, caps);
    1090            3 :         gst_caps_unref (template_caps);
    1091              : 
    1092            3 :         GST_DEBUG_OBJECT (trainer, "intersect caps : %" GST_PTR_FORMAT, caps);
    1093              :       }
    1094              : 
    1095            3 :       gst_query_set_accept_caps_result (query, result);
    1096            3 :       return TRUE;
    1097              :     }
    1098            3 :     default:
    1099            3 :       break;
    1100              :   }
    1101              : 
    1102            3 :   return gst_pad_query_default (sinkpad, parent, query);
    1103              : }
    1104              : 
    1105              : /**
    1106              :  * @brief This function handles src pad query.
    1107              :  */
    1108              : static gboolean
    1109           22 : gst_tensor_trainer_src_query (GstPad * srcpad, GstObject * parent,
    1110              :     GstQuery * query)
    1111              : {
    1112              :   GstTensorTrainer *trainer;
    1113           22 :   trainer = GST_TENSOR_TRAINER (parent);
    1114              : 
    1115           22 :   GST_DEBUG_OBJECT (trainer, "Received %s query: %" GST_PTR_FORMAT,
    1116              :       GST_QUERY_TYPE_NAME (query), query);
    1117              : 
    1118           22 :   switch (GST_QUERY_TYPE (query)) {
    1119           22 :     case GST_QUERY_CAPS:
    1120              :     {
    1121              :       GstCaps *caps;
    1122              :       GstCaps *filter;
    1123           22 :       GST_DEBUG_OBJECT (trainer, "[GST_QUERY_CAPS]");
    1124           22 :       gst_query_parse_caps (query, &filter);
    1125           22 :       GST_DEBUG_OBJECT (trainer, "Caps from query : %" GST_PTR_FORMAT, filter);
    1126           22 :       caps = gst_tensor_trainer_query_caps (trainer, srcpad, filter);
    1127              : 
    1128           22 :       GST_INFO_OBJECT (trainer, "[GST_QUERY_CAPS] : %" GST_PTR_FORMAT, caps);
    1129           22 :       gst_query_set_caps_result (query, caps);
    1130           22 :       gst_caps_unref (caps);
    1131           22 :       return TRUE;
    1132              :     }
    1133            0 :     default:
    1134            0 :       break;
    1135              :   }
    1136            0 :   return gst_pad_query_default (srcpad, parent, query);
    1137              : }
    1138              : 
    1139              : /**
    1140              :  * @brief Handle "PROP_FRAMEWORK" for set-property
    1141              :  */
    1142              : static void
    1143           11 : gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer,
    1144              :     const GValue * value)
    1145              : {
    1146           11 :   g_free (trainer->fw_name);
    1147           11 :   trainer->fw_name = g_value_dup_string (value);
    1148           11 :   GST_INFO_OBJECT (trainer, "Framework: %s", trainer->fw_name);
    1149              : 
    1150              :   /** @todo Check valid framework */
    1151           11 : }
    1152              : 
    1153              : /**
    1154              :  * @brief Handle "PROP_MODEL_CONFIG" for set-property
    1155              :  */
    1156              : static void
    1157           11 : gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer *
    1158              :     trainer, const GValue * value)
    1159              : {
    1160           11 :   g_free ((char *) trainer->prop.model_config);
    1161           11 :   trainer->prop.model_config = g_value_dup_string (value);
    1162           11 :   GST_INFO_OBJECT (trainer, "Model configuration file path: %s",
    1163              :       trainer->prop.model_config);
    1164           11 : }
    1165              : 
    1166              : /**
    1167              :  * @brief Handle "PROP_MODEL_SAVE_PATH" for set-property
    1168              :  */
    1169              : static void
    1170            8 : gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer,
    1171              :     const GValue * value)
    1172              : {
    1173            8 :   g_free ((char *) trainer->prop.model_save_path);
    1174            8 :   trainer->prop.model_save_path = g_value_dup_string (value);
    1175            8 :   GST_INFO_OBJECT (trainer, "File path to save the model: %s",
    1176              :       trainer->prop.model_save_path);
    1177            8 : }
    1178              : 
    1179              : /**
    1180              :  * @brief Handle "PROP_MODEL_LOAD_PATH" for set-property
    1181              :  */
    1182              : static void
    1183            2 : gst_tensor_trainer_set_model_load_path (GstTensorTrainer * trainer,
    1184              :     const GValue * value)
    1185              : {
    1186            2 :   g_free ((char *) trainer->prop.model_load_path);
    1187            2 :   trainer->prop.model_load_path = g_value_dup_string (value);
    1188            2 :   GST_INFO_OBJECT (trainer, "File path to load the model: %s",
    1189              :       trainer->prop.model_load_path);
    1190            2 : }
    1191              : 
    1192              : /**
    1193              :  * @brief Find Trainer sub-plugin with the name.
    1194              :  */
    1195              : static gboolean
    1196            1 : gst_tensor_trainer_find_framework (GstTensorTrainer * trainer, const char *name)
    1197              : {
    1198            1 :   const GstTensorTrainerFramework *fw = NULL;
    1199              : 
    1200            1 :   g_return_val_if_fail (name != NULL, FALSE);
    1201            1 :   g_return_val_if_fail (trainer != NULL, FALSE);
    1202              : 
    1203            1 :   GST_INFO_OBJECT (trainer, "Try to find framework: %s", name);
    1204              : 
    1205            1 :   fw = get_subplugin (NNS_SUBPLUGIN_TRAINER, name);
    1206            1 :   if (!fw) {
    1207            1 :     GST_ERROR_OBJECT (trainer, "Can not find framework(%s)", trainer->fw_name);
    1208            1 :     return FALSE;
    1209              :   }
    1210              : 
    1211            0 :   GST_INFO_OBJECT (trainer, "Find framework %s:%p", trainer->fw_name, fw);
    1212            0 :   trainer->fw = fw;
    1213              : 
    1214            0 :   return TRUE;
    1215              : }
    1216              : 
    1217              : /**
    1218              :  * @brief Create NN framework.
    1219              :  */
    1220              : static gboolean
    1221            0 : gst_tensor_trainer_create_framework (GstTensorTrainer * trainer)
    1222              : {
    1223            0 :   g_return_val_if_fail (trainer != NULL, FALSE);
    1224              : 
    1225            0 :   if (!trainer->fw || trainer->fw_created) {
    1226            0 :     GST_ERROR_OBJECT (trainer, "fw is not opened(%d) or fw is not null(%p)",
    1227              :         trainer->fw_created, trainer->fw);
    1228            0 :     return FALSE;
    1229              :   }
    1230              : 
    1231            0 :   if (!trainer->fw->create) {
    1232            0 :     GST_ERROR_OBJECT (trainer, "Could not create framework");
    1233            0 :     return FALSE;
    1234              :   }
    1235              : 
    1236            0 :   GST_DEBUG_OBJECT (trainer, "%p", trainer->privateData);
    1237            0 :   if (trainer->fw->create (trainer->fw, &trainer->prop,
    1238              :           &trainer->privateData) >= 0) {
    1239            0 :     trainer->fw_created = TRUE;
    1240            0 :     GST_DEBUG_OBJECT (trainer, "Success, Framework: %p", trainer->privateData);
    1241            0 :     return TRUE;
    1242              :   }
    1243            0 :   return FALSE;
    1244              : }
    1245              : 
    1246              : /**
    1247              :  * @brief Calculate tensor buffer size
    1248              :  */
    1249              : gsize
    1250            0 : gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer,
    1251              :     guint index, gboolean is_input)
    1252              : {
    1253              :   GstTensorsInfo *info;
    1254              : 
    1255            0 :   if (is_input)
    1256            0 :     info = &trainer->prop.input_meta;
    1257              :   else
    1258            0 :     info = &trainer->output_meta;
    1259              : 
    1260              :   /* Internal Logic Error: out of bound */
    1261            0 :   if (index >= info->num_tensors) {
    1262            0 :     GST_ERROR_OBJECT (trainer, "has inconsistent data");
    1263            0 :     return 0;
    1264              :   }
    1265              : 
    1266            0 :   return gst_tensors_info_get_size (info, index);
    1267              : }
    1268              : 
    1269              : /**
    1270              :  * @brief Create model
    1271              :  */
    1272              : static gboolean
    1273            1 : gst_tensor_trainer_create_model (GstTensorTrainer * trainer)
    1274              : {
    1275            1 :   gboolean ret = TRUE;
    1276              : 
    1277            1 :   g_return_val_if_fail (trainer != NULL, FALSE);
    1278            1 :   g_return_val_if_fail (trainer->fw_name != NULL, FALSE);
    1279              : 
    1280            1 :   ret = gst_tensor_trainer_find_framework (trainer, trainer->fw_name);
    1281            1 :   if (!ret)
    1282            1 :     return ret;
    1283              : 
    1284            0 :   if (trainer->fw) {
    1285              :     /* model create and compile */
    1286            0 :     ret = gst_tensor_trainer_create_framework (trainer);
    1287              :   }
    1288              : 
    1289            0 :   return ret;
    1290              : }
    1291              : 
    1292              : /**
    1293              :  * @brief Create a event notifier
    1294              :  */
    1295              : static void
    1296            0 : gst_tensor_trainer_create_event_notifier (GstTensorTrainer * trainer)
    1297              : {
    1298            0 :   g_return_if_fail (trainer != NULL);
    1299            0 :   g_return_if_fail (trainer->fw != NULL);
    1300              : 
    1301            0 :   trainer->notifier.notifier = (void *) trainer;
    1302              : }
    1303              : 
    1304              : /**
    1305              :  * @brief Start model training
    1306              :  */
    1307              : static void
    1308            0 : gst_tensor_trainer_start_model_training (GstTensorTrainer * trainer)
    1309              : {
    1310            0 :   gint ret = -1;
    1311            0 :   g_return_if_fail (trainer != NULL);
    1312            0 :   g_return_if_fail (trainer->fw != NULL);
    1313            0 :   g_return_if_fail (trainer->fw->start != NULL);
    1314              : 
    1315            0 :   GST_DEBUG_OBJECT (trainer, "Start model training");
    1316              :   ret =
    1317            0 :       trainer->fw->start (trainer->fw, &trainer->prop, &trainer->notifier,
    1318              :       trainer->privateData);
    1319            0 :   if (ret != 0) {
    1320            0 :     GST_ERROR_OBJECT (trainer, "Model training is failed");
    1321              :   }
    1322              : }
    1323              : 
    1324              : /**
    1325              :  * @brief Stop model training
    1326              :  */
    1327              : static void
    1328            0 : gst_tensor_trainer_stop_model_training (GstTensorTrainer * trainer)
    1329              : {
    1330            0 :   gint ret = -1;
    1331              : 
    1332            0 :   g_return_if_fail (trainer != NULL);
    1333            0 :   g_return_if_fail (trainer->fw != NULL);
    1334            0 :   g_return_if_fail (trainer->fw->stop != NULL);
    1335              : 
    1336            0 :   GST_DEBUG_OBJECT (trainer, "Stop model training");
    1337            0 :   ret = trainer->fw->stop (trainer->fw, &trainer->prop, &trainer->privateData);
    1338            0 :   if (ret != 0) {
    1339            0 :     GST_ERROR_OBJECT (trainer, "Stopping model training is failed");
    1340              :   }
    1341              : }
    1342              : 
    1343              : /**
    1344              :  * @brief initialize the output tensor dimension
    1345              :  */
    1346              : static void
    1347           11 : gst_tensor_trainer_set_output_meta (GstTensorTrainer * trainer)
    1348              : {
    1349              :   GstTensorInfo *info;
    1350              : 
    1351           11 :   g_return_if_fail (trainer != NULL);
    1352              : 
    1353           11 :   gst_tensors_info_init (&trainer->output_meta);
    1354           11 :   info = gst_tensors_info_get_nth_info (&trainer->output_meta, 0);
    1355              : 
    1356           11 :   info->type = _NNS_FLOAT64;
    1357           11 :   info->dimension[0] = 1;
    1358           11 :   info->dimension[1] = 1;
    1359           11 :   info->dimension[2] = 4; /** loss, accuracy, val_loss, val_accuracy */
    1360           11 :   info->dimension[3] = 1;
    1361              : 
    1362           11 :   trainer->output_meta.num_tensors = 1;
    1363              : }
    1364              : 
    1365              : /**
    1366              :  * @brief Trainer's sub-plugin should call this function to register itself.
    1367              :  * @param[in] ttsp tensor_trainer sub-plugin to be registered.
    1368              :  * @return TRUE if registered. FALSE is failed or duplicated.
    1369              :  */
    1370              : int
    1371            0 : nnstreamer_trainer_probe (GstTensorTrainerFramework * ttsp)
    1372              : {
    1373              :   GstTensorTrainerFrameworkInfo info;
    1374              :   GstTensorTrainerProperties prop;
    1375            0 :   const char *name = NULL;
    1376            0 :   int ret = 0;
    1377              : 
    1378            0 :   g_return_val_if_fail (ttsp != NULL, 0);
    1379              : 
    1380            0 :   memset (&prop, 0, sizeof (GstTensorTrainerProperties));
    1381            0 :   gst_tensors_info_init (&prop.input_meta);
    1382              : 
    1383            0 :   if (ret != ttsp->getFrameworkInfo (ttsp, &prop, NULL, &info)) {
    1384            0 :     GST_ERROR ("getFrameworkInfo() failed");
    1385            0 :     return FALSE;
    1386              :   }
    1387            0 :   name = info.name;
    1388              : 
    1389            0 :   return register_subplugin (NNS_SUBPLUGIN_TRAINER, name, ttsp);
    1390              : }
    1391              : 
    1392              : /**
    1393              :  * @brief Trainer's sub-plugin may call this to unregister itself.
    1394              :  * @param[in] ttsp tensor_trainer sub-plugin to be unregistered.
    1395              :  * @return TRUE if unregistered. FALSE is failed.
    1396              :  */
    1397              : int
    1398            0 : nnstreamer_trainer_exit (GstTensorTrainerFramework * ttsp)
    1399              : {
    1400              :   GstTensorTrainerFrameworkInfo info;
    1401              :   GstTensorTrainerProperties prop;
    1402            0 :   const char *name = NULL;
    1403            0 :   int ret = 0;
    1404              : 
    1405            0 :   g_return_val_if_fail (ttsp != NULL, 0);
    1406              : 
    1407            0 :   memset (&prop, 0, sizeof (GstTensorTrainerProperties));
    1408            0 :   gst_tensors_info_init (&prop.input_meta);
    1409              : 
    1410            0 :   if (ret != ttsp->getFrameworkInfo (ttsp, &prop, NULL, &info)) {
    1411            0 :     GST_ERROR ("getFrameworkInfo() failed");
    1412            0 :     return FALSE;
    1413              :   }
    1414            0 :   name = info.name;
    1415              : 
    1416            0 :   return unregister_subplugin (NNS_SUBPLUGIN_TRAINER, name);
    1417              : }
    1418              : 
    1419              : /**
    1420              :  * @brief Trainer's sub-plugin may call this to send event.
    1421              :  * @param[in] notifier event notifier, sub-plugin must send events with this.
    1422              :  * @param[in] type event type
    1423              :  */
    1424              : void
    1425            0 : nnstreamer_trainer_notify_event (GstTensorTrainerEventNotifier * notifier,
    1426              :     GstTensorTrainerEventType type, void *data)
    1427              : {
    1428              :   GstTensorTrainer *trainer;
    1429            0 :   g_return_if_fail (notifier != NULL);
    1430            0 :   g_return_if_fail (type < TRAINER_EVENT_UNKNOWN || type > 0);
    1431              :   UNUSED (data);
    1432              : 
    1433            0 :   trainer = (GstTensorTrainer *) notifier->notifier;
    1434            0 :   g_return_if_fail (GST_IS_TENSOR_TRAINER (trainer));
    1435              : 
    1436            0 :   GST_DEBUG ("Received GstTensorTrainerEvent(%d)", type);
    1437              : 
    1438            0 :   switch (type) {
    1439            0 :     case TRAINER_EVENT_EPOCH_COMPLETION:
    1440            0 :       g_mutex_lock (&trainer->epoch_completion_lock);
    1441            0 :       trainer->is_epoch_complete = TRUE;
    1442            0 :       GST_DEBUG ("send epoch_completion_cond signal");
    1443            0 :       g_cond_signal (&trainer->epoch_completion_cond);
    1444            0 :       g_mutex_unlock (&trainer->epoch_completion_lock);
    1445            0 :       break;
    1446            0 :     case TRAINER_EVENT_TRAINING_COMPLETION:
    1447            0 :       g_mutex_lock (&trainer->training_completion_lock);
    1448            0 :       trainer->is_training_complete = TRUE;
    1449            0 :       GST_DEBUG ("send training_completion_cond signal");
    1450            0 :       g_cond_signal (&trainer->training_completion_cond);
    1451            0 :       g_mutex_unlock (&trainer->training_completion_lock);
    1452            0 :       break;
    1453            0 :     default:
    1454            0 :       break;
    1455              :   }
    1456              : }
        

Generated by: LCOV version 2.0-1