Line data Source code
1 : /* SPDX-License-Identifier: LGPL-2.1-only */
2 : /**
3 : * Copyright (C) 2022 Samsung Electronics Co., Ltd.
4 : *
5 : * @file gsttensor_trainer.c
6 : * @date 20 October 2022
7 : * @brief GStreamer plugin to train tensor data using NN Frameworks
8 : * @see https://github.com/nnstreamer/nnstreamer
9 : * @author Hyunil Park <hyunil46.park@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : * ## Example launch line
13 : * |[
14 : * gst-launch-1.0 datareposrc location=mnist_trainingSet.dat json=mnist.json start-sample-index=3 stop-sample-index=202 epochs=5 ! \
15 : * tensor_trainer framework=nntrainer model-config=mnist.ini model-save-path=model.bin \
16 : * num-inputs=1 num-labels=1 num-training-samples=100 num-validation-samples=100 epochs=5 ! \
17 : * tensor_sink
18 : * ]|
19 : *
 * Total number of samples to be received is 1000, i.e. (num-training-samples + num-validation-samples) * epochs.
21 : *
22 : * output tensors : dimensions=1:1:4, types=float64.
23 : * values are training loss, training accuracy, validation loss and validation accuracy.
24 : * -INFINITY value is stored if the value fetched from the sub-plugin is not greater than 0.
25 : */
26 :
27 : #ifdef HAVE_CONFIG_H
28 : #include "config.h"
29 : #endif
30 : #include <stdlib.h>
31 : #include <nnstreamer_subplugin.h>
32 : #include <nnstreamer_util.h>
33 : #include "gsttensor_trainer.h"
34 : #include <unistd.h>
35 : #include <math.h>
36 :
/**
 * @brief Default caps string for sink
 * The sink accepts both static and flexible tensor streams.
 */
#define SINK_CAPS_STRING GST_TENSORS_CAP_MAKE ("{ static, flexible }")

/**
 * @brief Default caps string for src
 * The src only produces static tensors (the model statistics buffer).
 */
#define SRC_CAPS_STRING GST_TENSORS_CAP_MAKE ("{ static}")

/**
 * @brief The capabilities of the sink pad
 */
static GstStaticPadTemplate sink_template = GST_STATIC_PAD_TEMPLATE ("sink",
    GST_PAD_SINK,
    GST_PAD_ALWAYS,
    GST_STATIC_CAPS (SINK_CAPS_STRING));

/**
 * @brief The capabilities of the src pad
 */
static GstStaticPadTemplate src_template = GST_STATIC_PAD_TEMPLATE ("src",
    GST_PAD_SRC,
    GST_PAD_ALWAYS,
    GST_STATIC_CAPS (SRC_CAPS_STRING));

GST_DEBUG_CATEGORY_STATIC (gst_tensor_trainer_debug);
#define GST_CAT_DEFAULT gst_tensor_trainer_debug
/* Alias so the parent-class pointer generated by G_DEFINE_TYPE below can be
 * referred to simply as parent_class in this file. */
#define gst_tensor_trainer_parent_class parent_class
G_DEFINE_TYPE (GstTensorTrainer, gst_tensor_trainer, GST_TYPE_ELEMENT);
67 :
/**
 * @brief Statistics from the model being trained.
 * Each enum value is the index, within the output tensor, at which the named
 * value is stored. MODEL_STATS_SIZE is the number of entries and must be kept
 * in sync with this enum.
 */
enum
{
  TRAINING_LOSS,
  TRAINING_ACCURACY,
  VALIDATION_LOSS,
  VALIDATION_ACCURACY
};
#define MODEL_STATS_SIZE 4

/**
 * @brief Default values of the numeric properties
 * (number of inputs, labels, training/validation samples and epochs).
 */
#define DEFAULT_PROP_INPUT_LIST 1
#define DEFAULT_PROP_LABEL_LIST 1
#define DEFAULT_PROP_TRAIN_SAMPLES 0
#define DEFAULT_PROP_VALID_SAMPLES 0
#define DEFAULT_PROP_EPOCHS 1
/**
 * @brief Default string property value.
 * A string property still equal to this empty value is treated as "not set"
 * by the parameter validation.
 */
#define DEFAULT_STR_PROP_VALUE ""
93 :
/**
 * @brief tensor_trainer properties (GObject property ids; PROP_0 is reserved)
 */
enum
{
  PROP_0,
  PROP_FRAMEWORK,               /* sub-plugin name, e.g. "nntrainer" */
  PROP_MODEL_CONFIG,            /* model configuration file path */
  PROP_MODEL_SAVE_PATH,         /* where the trained model is saved */
  PROP_MODEL_LOAD_PATH,         /* optional model to load before training */
  PROP_NUM_INPUTS, /* number of input list */
  PROP_NUM_LABELS, /* number of label list */
  PROP_NUM_TRAINING_SAMPLES, /* number of training data */
  PROP_NUM_VALIDATION_SAMPLES, /* number of validation data */
  PROP_EPOCHS, /* Repetitions of training */
};
110 :
/* GObject / GstElement virtual functions and pad handlers */
static void gst_tensor_trainer_set_property (GObject * object, guint prop_id,
    const GValue * value, GParamSpec * pspec);
static void gst_tensor_trainer_get_property (GObject * object, guint prop_id,
    GValue * value, GParamSpec * pspec);
static void gst_tensor_trainer_finalize (GObject * object);
static gboolean gst_tensor_trainer_sink_event (GstPad * sinkpad,
    GstObject * parent, GstEvent * event);
static gboolean gst_tensor_trainer_sink_query (GstPad * sinkpad,
    GstObject * parent, GstQuery * query);
static gboolean gst_tensor_trainer_src_query (GstPad * srcpad,
    GstObject * parent, GstQuery * query);
static GstFlowReturn gst_tensor_trainer_chain (GstPad * sinkpad,
    GstObject * parent, GstBuffer * inbuf);
static GstCaps *gst_tensor_trainer_query_caps (GstTensorTrainer * trainer,
    GstPad * pad, GstCaps * filter);
static GstStateChangeReturn gst_tensor_trainer_change_state (GstElement *
    element, GstStateChange transition);

/* Property setters and framework / training helpers */
static void gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer,
    const GValue * value);
static void gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer
    * trainer, const GValue * value);
static void gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer,
    const GValue * value);
static void gst_tensor_trainer_set_model_load_path (GstTensorTrainer * trainer,
    const GValue * value);
static gboolean gst_tensor_trainer_find_framework (GstTensorTrainer * trainer,
    const char *name);
static gboolean gst_tensor_trainer_create_framework (GstTensorTrainer *
    trainer);
static gsize gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer,
    guint index, gboolean is_input);
static gboolean gst_tensor_trainer_create_model (GstTensorTrainer * trainer);
static void gst_tensor_trainer_create_event_notifier (GstTensorTrainer *
    trainer);
static void gst_tensor_trainer_start_model_training (GstTensorTrainer *
    trainer);
static void gst_tensor_trainer_stop_model_training (GstTensorTrainer * trainer);
static void gst_tensor_trainer_set_output_meta (GstTensorTrainer * trainer);
150 :
151 : /**
152 : * @brief initialize the tensor_trainer's class
153 : */
154 : static void
155 1 : gst_tensor_trainer_class_init (GstTensorTrainerClass * klass)
156 : {
157 : GObjectClass *gobject_class;
158 : GstElementClass *gstelement_class;
159 :
160 1 : GST_DEBUG_CATEGORY_INIT (GST_CAT_DEFAULT, "tensor_trainer", 0,
161 : "Tensor trainer to train neural network model");
162 :
163 1 : gobject_class = G_OBJECT_CLASS (klass);
164 1 : gstelement_class = GST_ELEMENT_CLASS (klass);
165 :
166 1 : gobject_class->set_property =
167 1 : GST_DEBUG_FUNCPTR (gst_tensor_trainer_set_property);
168 1 : gobject_class->get_property =
169 1 : GST_DEBUG_FUNCPTR (gst_tensor_trainer_get_property);
170 1 : gobject_class->finalize = GST_DEBUG_FUNCPTR (gst_tensor_trainer_finalize);
171 :
172 : /* Called when the element's state changes */
173 1 : gstelement_class->change_state =
174 1 : GST_DEBUG_FUNCPTR (gst_tensor_trainer_change_state);
175 :
176 : /* Install properties for tensor_trainer */
177 1 : g_object_class_install_property (gobject_class, PROP_FRAMEWORK,
178 : g_param_spec_string ("framework", "Framework",
179 : "(not nullable) Neural network framework to be used for model training, ",
180 : DEFAULT_STR_PROP_VALUE,
181 : G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
182 : G_PARAM_STATIC_STRINGS));
183 :
184 1 : g_object_class_install_property (gobject_class, PROP_MODEL_CONFIG,
185 : g_param_spec_string ("model-config", "Model configuration file path",
186 : "(not nullable) Model configuration file is used to configure the model "
187 : "to be trained in neural network framework, set the file path",
188 : DEFAULT_STR_PROP_VALUE,
189 : G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
190 : G_PARAM_STATIC_STRINGS));
191 :
192 1 : g_object_class_install_property (gobject_class, PROP_MODEL_SAVE_PATH,
193 : g_param_spec_string ("model-save-path", "Model save path",
194 : "(not nullable) Path to save the trained model in framework, if model-config "
195 : "contains information about the save file, it is ignored",
196 : DEFAULT_STR_PROP_VALUE,
197 : G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
198 : G_PARAM_STATIC_STRINGS));
199 :
200 1 : g_object_class_install_property (gobject_class, PROP_MODEL_LOAD_PATH,
201 : g_param_spec_string ("model-load-path", "Model load path",
202 : "(nullable) Path to a model file to be loaded for the given training session.",
203 : DEFAULT_STR_PROP_VALUE,
204 : G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
205 : G_PARAM_STATIC_STRINGS));
206 :
207 1 : g_object_class_install_property (gobject_class, PROP_NUM_INPUTS,
208 : g_param_spec_uint ("num-inputs", "Number of inputs",
209 : "An input in a tensor can have one or more features data,"
210 : "set how many inputs are received", 0, NNS_TENSOR_SIZE_LIMIT, 1,
211 : G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
212 : G_PARAM_STATIC_STRINGS));
213 :
214 1 : g_object_class_install_property (gobject_class, PROP_NUM_LABELS,
215 : g_param_spec_uint ("num-labels", "Number of labels",
216 : "A label in a tensor can have one or more classes data,"
217 : "set how many labels are received", 0, NNS_TENSOR_SIZE_LIMIT, 1,
218 : G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
219 : G_PARAM_STATIC_STRINGS));
220 :
221 1 : g_object_class_install_property (gobject_class, PROP_NUM_TRAINING_SAMPLES,
222 : g_param_spec_uint ("num-training-samples", "Number of training samples",
223 : "A sample can consist of multiple inputs and labels in tensors of a gstbuffer"
224 : ", set how many samples are taken for training model",
225 : 0, G_MAXINT, 0,
226 : G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
227 : G_PARAM_STATIC_STRINGS));
228 :
229 1 : g_object_class_install_property (gobject_class, PROP_NUM_VALIDATION_SAMPLES,
230 : g_param_spec_uint ("num-validation-samples",
231 : "Number of validation samples",
232 : "A sample can consist of multiple inputs and labels in tensors of a gstbuffer"
233 : ", set how many samples are taken for validation model",
234 : 0, G_MAXINT, 0,
235 : G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
236 : G_PARAM_STATIC_STRINGS));
237 :
238 1 : g_object_class_install_property (gobject_class, PROP_EPOCHS,
239 : g_param_spec_uint ("epochs", "Number of epoch",
240 : "Epochs are repetitions of training samples and validation samples, "
241 : "number of samples received for model training is "
242 : "(num-training-samples+num-validation-samples)*epochs", 0, G_MAXINT,
243 : DEFAULT_PROP_EPOCHS,
244 : G_PARAM_READWRITE | GST_PARAM_MUTABLE_READY |
245 : G_PARAM_STATIC_STRINGS));
246 :
247 1 : gst_element_class_set_details_simple (gstelement_class, "TensorTrainer",
248 : "Trainer/Tensor", "Train tensor data using NN Frameworks",
249 : "Samsung Electronics Co., Ltd.");
250 :
251 : /* Add pad template */
252 1 : gst_element_class_add_pad_template (gstelement_class,
253 : gst_static_pad_template_get (&src_template));
254 1 : gst_element_class_add_pad_template (gstelement_class,
255 : gst_static_pad_template_get (&sink_template));
256 1 : }
257 :
258 : /**
259 : * @brief Initialize tensor_trainer.
260 : */
261 : static void
262 11 : gst_tensor_trainer_init (GstTensorTrainer * trainer)
263 : {
264 11 : GST_DEBUG ("<ENTER>");
265 : /** setup sink pad */
266 11 : trainer->sinkpad = gst_pad_new_from_static_template (&sink_template, "sink");
267 11 : gst_pad_set_event_function (trainer->sinkpad,
268 : GST_DEBUG_FUNCPTR (gst_tensor_trainer_sink_event));
269 11 : gst_pad_set_query_function (trainer->sinkpad,
270 : GST_DEBUG_FUNCPTR (gst_tensor_trainer_sink_query));
271 11 : gst_pad_set_chain_function (trainer->sinkpad,
272 : GST_DEBUG_FUNCPTR (gst_tensor_trainer_chain));
273 11 : GST_PAD_SET_PROXY_CAPS (trainer->sinkpad);
274 11 : gst_element_add_pad (GST_ELEMENT (trainer), trainer->sinkpad);
275 :
276 : /** setup src pad */
277 11 : trainer->srcpad = gst_pad_new_from_static_template (&src_template, "src");
278 11 : gst_pad_set_query_function (trainer->srcpad,
279 : GST_DEBUG_FUNCPTR (gst_tensor_trainer_src_query));
280 11 : GST_PAD_SET_PROXY_CAPS (trainer->srcpad);
281 11 : gst_element_add_pad (GST_ELEMENT (trainer), trainer->srcpad);
282 :
283 : /** init properties */
284 11 : trainer->fw_name = g_strdup (DEFAULT_STR_PROP_VALUE);
285 11 : trainer->prop.model_config = g_strdup (DEFAULT_STR_PROP_VALUE);
286 11 : trainer->prop.model_save_path = g_strdup (DEFAULT_STR_PROP_VALUE);
287 11 : trainer->prop.model_load_path = NULL;
288 11 : trainer->prop.num_inputs = DEFAULT_PROP_INPUT_LIST;
289 11 : trainer->prop.num_labels = DEFAULT_PROP_LABEL_LIST;
290 11 : trainer->prop.num_training_samples = DEFAULT_PROP_TRAIN_SAMPLES;
291 11 : trainer->prop.num_validation_samples = DEFAULT_PROP_VALID_SAMPLES;
292 11 : trainer->prop.num_epochs = DEFAULT_PROP_EPOCHS;
293 :
294 11 : trainer->fw = NULL;
295 11 : trainer->fw_created = FALSE;
296 11 : trainer->is_training_complete = FALSE;
297 11 : trainer->is_epoch_complete = FALSE;
298 11 : trainer->cur_epoch_data_cnt = 0;
299 11 : trainer->required_sample = 0;
300 :
301 11 : gst_tensors_config_init (&trainer->in_config);
302 11 : gst_tensors_config_init (&trainer->out_config);
303 :
304 11 : g_cond_init (&trainer->training_completion_cond);
305 11 : g_mutex_init (&trainer->training_completion_lock);
306 11 : g_cond_init (&trainer->epoch_completion_cond);
307 11 : g_mutex_init (&trainer->epoch_completion_lock);
308 :
309 11 : gst_tensor_trainer_set_output_meta (trainer);
310 11 : }
311 :
312 : /**
313 : * @brief Function to finalize instance.
314 : */
315 : static void
316 8 : gst_tensor_trainer_finalize (GObject * object)
317 : {
318 : GstTensorTrainer *trainer;
319 8 : trainer = GST_TENSOR_TRAINER (object);
320 :
321 8 : g_free (trainer->fw_name);
322 8 : g_free ((char *) trainer->prop.model_config);
323 8 : g_free ((char *) trainer->prop.model_save_path);
324 8 : g_free ((char *) trainer->prop.model_load_path);
325 :
326 8 : gst_tensors_config_free (&trainer->in_config);
327 8 : gst_tensors_config_free (&trainer->out_config);
328 :
329 8 : g_cond_clear (&trainer->training_completion_cond);
330 8 : g_mutex_clear (&trainer->training_completion_lock);
331 8 : g_cond_clear (&trainer->epoch_completion_cond);
332 8 : g_mutex_clear (&trainer->epoch_completion_lock);
333 :
334 8 : if (trainer->dummy_data_thread) {
335 0 : g_thread_join (trainer->dummy_data_thread);
336 0 : trainer->dummy_data_thread = NULL;
337 : }
338 :
339 8 : if (trainer->fw_created && trainer->fw) {
340 0 : trainer->fw->destroy (trainer->fw, &trainer->prop, &trainer->privateData);
341 : }
342 :
343 8 : G_OBJECT_CLASS (parent_class)->finalize (object);
344 8 : }
345 :
346 : /**
347 : * @brief Setter for tensor_trainsink properties.
348 : */
349 : static void
350 81 : gst_tensor_trainer_set_property (GObject * object, guint prop_id,
351 : const GValue * value, GParamSpec * pspec)
352 : {
353 : GstTensorTrainer *trainer;
354 :
355 81 : trainer = GST_TENSOR_TRAINER (object);
356 :
357 81 : switch (prop_id) {
358 11 : case PROP_FRAMEWORK:
359 11 : gst_tensor_trainer_set_prop_framework (trainer, value);
360 11 : break;
361 11 : case PROP_MODEL_CONFIG:
362 11 : gst_tensor_trainer_set_prop_model_config_file_path (trainer, value);
363 11 : break;
364 8 : case PROP_MODEL_SAVE_PATH:
365 8 : gst_tensor_trainer_set_model_save_path (trainer, value);
366 8 : break;
367 2 : case PROP_MODEL_LOAD_PATH:
368 2 : gst_tensor_trainer_set_model_load_path (trainer, value);
369 2 : break;
370 10 : case PROP_NUM_INPUTS:
371 10 : trainer->prop.num_inputs = g_value_get_uint (value);
372 10 : break;
373 10 : case PROP_NUM_LABELS:
374 10 : trainer->prop.num_labels = g_value_get_uint (value);
375 10 : break;
376 9 : case PROP_NUM_TRAINING_SAMPLES:
377 9 : trainer->prop.num_training_samples = g_value_get_uint (value);
378 9 : break;
379 10 : case PROP_NUM_VALIDATION_SAMPLES:
380 10 : trainer->prop.num_validation_samples = g_value_get_uint (value);
381 10 : break;
382 10 : case PROP_EPOCHS:
383 10 : trainer->prop.num_epochs = g_value_get_uint (value);
384 10 : break;
385 0 : default:
386 0 : G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
387 0 : break;
388 : }
389 81 : }
390 :
391 : /**
392 : * @brief Getter tensor_trainsink properties.
393 : */
394 : static void
395 14 : gst_tensor_trainer_get_property (GObject * object, guint prop_id,
396 : GValue * value, GParamSpec * pspec)
397 : {
398 : GstTensorTrainer *trainer;
399 :
400 14 : trainer = GST_TENSOR_TRAINER (object);
401 :
402 14 : switch (prop_id) {
403 0 : case PROP_FRAMEWORK:
404 0 : g_value_set_string (value, trainer->fw_name);
405 0 : break;
406 1 : case PROP_MODEL_CONFIG:
407 1 : g_value_set_string (value, trainer->prop.model_config);
408 1 : break;
409 1 : case PROP_MODEL_SAVE_PATH:
410 1 : g_value_set_string (value, trainer->prop.model_save_path);
411 1 : break;
412 2 : case PROP_MODEL_LOAD_PATH:
413 2 : g_value_set_string (value, trainer->prop.model_load_path);
414 2 : break;
415 2 : case PROP_NUM_INPUTS:
416 2 : g_value_set_uint (value, trainer->prop.num_inputs);
417 2 : break;
418 2 : case PROP_NUM_LABELS:
419 2 : g_value_set_uint (value, trainer->prop.num_labels);
420 2 : break;
421 2 : case PROP_NUM_TRAINING_SAMPLES:
422 2 : g_value_set_uint (value, trainer->prop.num_training_samples);
423 2 : break;
424 2 : case PROP_NUM_VALIDATION_SAMPLES:
425 2 : g_value_set_uint (value, trainer->prop.num_validation_samples);
426 2 : break;
427 2 : case PROP_EPOCHS:
428 2 : g_value_set_uint (value, trainer->prop.num_epochs);
429 2 : break;
430 0 : default:
431 0 : G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
432 0 : break;
433 : }
434 14 : }
435 :
436 : /**
437 : * @brief Check invalid param
438 : */
439 : static gboolean
440 3 : gst_tensor_trainer_check_invalid_param (GstTensorTrainer * trainer)
441 : {
442 3 : g_return_val_if_fail (trainer != NULL, FALSE);
443 :
444 : /* Parameters that can be retrieved from caps will be removed */
445 3 : if (!trainer->fw_name
446 2 : || (g_ascii_strcasecmp (trainer->prop.model_config,
447 : DEFAULT_STR_PROP_VALUE) == 0)
448 2 : || (g_ascii_strcasecmp (trainer->prop.model_save_path,
449 : DEFAULT_STR_PROP_VALUE) == 0)
450 1 : || trainer->prop.num_epochs <= 0 || trainer->prop.num_inputs <= 0
451 1 : || trainer->prop.num_labels <= 0) {
452 2 : GST_ERROR_OBJECT (trainer, "Check for invalid param value");
453 :
454 2 : return FALSE;
455 : }
456 :
457 1 : if (!g_file_test (trainer->prop.model_config,
458 : (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
459 0 : GST_ERROR_OBJECT (trainer, "Model config file does not exist. [%s]",
460 : trainer->prop.model_config);
461 0 : return FALSE;
462 : }
463 :
464 1 : return TRUE;
465 : }
466 :
467 : /**
468 : * @brief Dummy data generation thread
469 : */
470 : static gpointer
471 0 : gst_tensor_trainer_dummy_data_generation_func (GstTensorTrainer * trainer)
472 : {
473 : guint i;
474 0 : gint ret = -1;
475 0 : gpointer dummy_data[NNS_TENSOR_SIZE_LIMIT] = { NULL };
476 0 : g_return_val_if_fail (trainer != NULL, NULL);
477 :
478 0 : gst_tensor_trainer_stop_model_training (trainer);
479 :
480 0 : for (i = 0; i < trainer->output_meta.num_tensors; i++) {
481 0 : dummy_data[i] = g_malloc (trainer->input_tensors[i].size);
482 0 : memset (dummy_data[i], 1, trainer->input_tensors[i].size);
483 0 : trainer->input_tensors[i].data = dummy_data[i];
484 : }
485 :
486 : do {
487 0 : GST_INFO_OBJECT (trainer, "cur_epoch_data_cnt=%u",
488 : trainer->cur_epoch_data_cnt);
489 0 : GST_INFO_OBJECT (trainer, "num_tensors=%d",
490 : trainer->prop.input_meta.num_tensors);
491 :
492 : ret =
493 0 : trainer->fw->push_data (trainer->fw, &trainer->prop,
494 0 : trainer->privateData, trainer->input_tensors);
495 :
496 0 : if (ret < 0) {
497 0 : GST_ERROR_OBJECT (trainer, "Failed to push dummy data");
498 : } else {
499 0 : trainer->cur_epoch_data_cnt++;
500 : }
501 0 : } while (trainer->required_sample > trainer->cur_epoch_data_cnt);
502 :
503 0 : for (i = 0; i < trainer->output_meta.num_tensors; i++)
504 0 : g_free (dummy_data[i]);
505 :
506 0 : return NULL;
507 : }
508 :
509 : /**
510 : * @brief Change state of tensor_trainsink.
511 : */
512 : static GstStateChangeReturn
513 20 : gst_tensor_trainer_change_state (GstElement * element,
514 : GstStateChange transition)
515 : {
516 20 : GstTensorTrainer *trainer = GST_TENSOR_TRAINER (element);
517 20 : GstStateChangeReturn ret = GST_STATE_CHANGE_SUCCESS;
518 :
519 20 : switch (transition) {
520 6 : case GST_STATE_CHANGE_NULL_TO_READY:
521 6 : GST_INFO_OBJECT (trainer, "NULL_TO_READY");
522 : /* currently not used */
523 6 : trainer->is_training_complete = FALSE;
524 6 : break;
525 :
526 4 : case GST_STATE_CHANGE_READY_TO_PAUSED:
527 4 : GST_INFO_OBJECT (trainer, "READY_TO_PAUSED");
528 4 : break;
529 :
530 3 : case GST_STATE_CHANGE_PAUSED_TO_PLAYING:
531 3 : GST_INFO_OBJECT (trainer, "PAUSED_TO_PLAYING");
532 3 : if (!gst_tensor_trainer_check_invalid_param (trainer))
533 2 : goto state_change_failed;
534 1 : if (!trainer->fw_created) {
535 1 : if (!gst_tensor_trainer_create_model (trainer))
536 1 : goto state_change_failed;
537 : }
538 0 : gst_tensor_trainer_create_event_notifier (trainer);
539 0 : gst_tensor_trainer_start_model_training (trainer);
540 0 : break;
541 :
542 7 : default:
543 7 : break;
544 : }
545 :
546 17 : ret = GST_ELEMENT_CLASS (parent_class)->change_state (element, transition);
547 :
548 17 : switch (transition) {
549 0 : case GST_STATE_CHANGE_PLAYING_TO_PAUSED:
550 0 : GST_INFO_OBJECT (trainer, "PLAYING_TO_PAUSED");
551 : /* need to generate dummy data */
552 0 : if (!trainer->is_training_complete) {
553 0 : if (!g_strcmp0 (trainer->fw_name, "nntrainer")) {
554 0 : GST_INFO_OBJECT (trainer, "cur_epoch_data_cnt=%u",
555 : trainer->cur_epoch_data_cnt);
556 0 : trainer->dummy_data_thread =
557 0 : g_thread_new ("dumy_data_generation_func",
558 : (GThreadFunc) gst_tensor_trainer_dummy_data_generation_func,
559 : trainer);
560 : }
561 : }
562 0 : break;
563 :
564 1 : case GST_STATE_CHANGE_PAUSED_TO_READY:
565 1 : GST_INFO_OBJECT (trainer, "PAUSED_TO_READY");
566 : /* stop model train ? */
567 1 : break;
568 :
569 3 : case GST_STATE_CHANGE_READY_TO_NULL:
570 3 : GST_INFO_OBJECT (trainer, "READY_TO_NULL");
571 : /* destroy or reset model ? */
572 3 : break;
573 :
574 13 : default:
575 13 : break;
576 : }
577 :
578 17 : return ret;
579 :
580 3 : state_change_failed:
581 3 : GST_ERROR_OBJECT (trainer, "state change failed");
582 :
583 3 : return GST_STATE_CHANGE_FAILURE;
584 : }
585 :
586 : /**
587 : * @brief Wait for epoch eompletion
588 : */
589 : static void
590 0 : gst_tensor_trainer_wait_for_epoch_completion (GstTensorTrainer * trainer)
591 : {
592 0 : g_return_if_fail (trainer != NULL);
593 :
594 0 : g_mutex_lock (&trainer->epoch_completion_lock);
595 0 : while (!trainer->is_epoch_complete) {
596 0 : GST_INFO_OBJECT (trainer, "wait for epoch_completion_cond signal");
597 0 : g_cond_wait (&trainer->epoch_completion_cond,
598 : &trainer->epoch_completion_lock);
599 : }
600 0 : trainer->is_epoch_complete = FALSE;
601 0 : g_mutex_unlock (&trainer->epoch_completion_lock);
602 : }
603 :
604 : /**
605 : * @brief Check if current epochs is complete,
606 : * tensor_trainer wait for one of epochs to complete before getting the results from the subplugin
607 : */
608 : static gboolean
609 0 : gst_tensor_trainer_epochs_is_complete (GstTensorTrainer * trainer)
610 : {
611 0 : g_return_val_if_fail (trainer != NULL, FALSE);
612 0 : g_return_val_if_fail (trainer->fw != NULL, FALSE);
613 0 : g_return_val_if_fail (&trainer->prop != NULL, FALSE);
614 :
615 0 : trainer->required_sample =
616 0 : trainer->prop.num_training_samples + trainer->prop.num_validation_samples;
617 0 : if (trainer->cur_epoch_data_cnt != trainer->required_sample)
618 0 : return FALSE;
619 :
620 0 : gst_tensor_trainer_wait_for_epoch_completion (trainer);
621 0 : trainer->cur_epoch_data_cnt = 0;
622 0 : return TRUE;
623 : }
624 :
625 : /**
626 : * @brief Check buffer drop conditions. If condition is met, drop the buffer.
627 : */
628 : static gboolean
629 0 : gst_tensor_trainer_check_buffer_drop_conditions (GstTensorTrainer * trainer)
630 : {
631 0 : if (trainer->is_training_complete == TRUE) {
632 : /** app need to send gst_element_send_event(tensor_trainer, gst_event_new_eos())
633 : after training_complete or set eos to datareposrc */
634 0 : GST_WARNING_OBJECT (trainer,
635 : "Training is completed, buffer is dropped, please change state of pipeline");
636 0 : return TRUE;
637 : }
638 0 : return FALSE;
639 : }
640 :
641 : /**
642 : * @brief Check chain conditions. If all conditions are met, proceed to next step.
643 : */
644 : static gboolean
645 0 : gst_tensor_trainer_check_chain_conditions (GstTensorTrainer * trainer,
646 : guint num_tensors)
647 : {
648 0 : if (!trainer->fw_created) {
649 0 : if (!gst_tensor_trainer_check_invalid_param (trainer))
650 0 : return FALSE;;
651 0 : if (!gst_tensor_trainer_create_model (trainer))
652 0 : return FALSE;
653 : }
654 :
655 0 : if (num_tensors >= NNS_TENSOR_SIZE_LIMIT)
656 0 : return FALSE;
657 :
658 0 : return TRUE;
659 : }
660 :
661 : /**
662 : * @brief Convert tensor meta and get the size of tensor header.
663 : */
664 : static gsize
665 0 : gst_tensor_trainer_convert_meta (GstTensorTrainer * trainer,
666 : GstTensorMetaInfo * meta, GstTensorInfo * info, void *data)
667 : {
668 0 : gsize header_size = 0;
669 :
670 0 : if (!gst_tensor_meta_info_parse_header (meta, data)) {
671 0 : GST_ERROR_OBJECT (trainer, "Invalid Flexible tensors");
672 0 : return 0;
673 : }
674 :
675 0 : if (gst_tensor_meta_info_convert (meta, info)) {
676 0 : header_size = gst_tensor_meta_info_get_header_size (meta);
677 0 : GST_INFO ("flexible header size:%zd", header_size);
678 : }
679 :
680 0 : return header_size;
681 : }
682 :
683 : /**
684 : * @brief Create input tensors from the buffer and push it into trainer fw.
685 : */
686 : static gboolean
687 0 : gst_tensor_trainer_push_input (GstTensorTrainer * trainer, GstBuffer * inbuf,
688 : gboolean in_flexible)
689 : {
690 : guint i, n;
691 0 : GstMemory *in_mem[NNS_TENSOR_SIZE_LIMIT] = { 0, };
692 : GstMapInfo in_info[NNS_TENSOR_SIZE_LIMIT];
693 : GstTensorMetaInfo in_meta[NNS_TENSOR_SIZE_LIMIT];
694 : GstTensorInfo *info;
695 0 : gsize header_size = 0, expected;
696 0 : gint ret = -1;
697 :
698 0 : n = gst_tensor_buffer_get_count (inbuf);
699 :
700 0 : if (in_flexible)
701 0 : trainer->prop.input_meta.num_tensors = n;
702 : else {
703 0 : GST_DEBUG_OBJECT (trainer, "num_tensors: %u",
704 : trainer->prop.input_meta.num_tensors);
705 0 : if (n != trainer->prop.input_meta.num_tensors) {
706 0 : GST_ERROR_OBJECT (trainer,
707 : "Invalid memory blocks (%u), number of input tensors may be (%u)",
708 : n, trainer->prop.input_meta.num_tensors);
709 0 : goto error;
710 : }
711 : }
712 :
713 0 : for (i = 0; i < n; i++) {
714 0 : in_mem[i] = gst_tensor_buffer_get_nth_memory (inbuf, i);
715 0 : if (!gst_memory_map (in_mem[i], &in_info[i], GST_MAP_READ)) {
716 0 : GST_ERROR_OBJECT (trainer, "Could not map in_mem[%u] GstMemory", i);
717 0 : goto error;
718 : }
719 :
720 0 : if (in_flexible) {
721 0 : info = gst_tensors_info_get_nth_info (&trainer->prop.input_meta, i);
722 0 : header_size = gst_tensor_trainer_convert_meta (trainer,
723 0 : &in_meta[i], info, in_info[i].data);
724 0 : if (header_size == 0)
725 0 : goto error;
726 : }
727 :
728 0 : trainer->input_tensors[i].data = in_info[i].data + header_size;
729 0 : trainer->input_tensors[i].size = in_info[i].size - header_size;
730 0 : GST_INFO ("input_tensors[%u].size= %zd", i, trainer->input_tensors[i].size);
731 0 : GST_INFO ("input_tensors[%u].data: %p", i, trainer->input_tensors[i].data);
732 :
733 : /* Check input tensor size */
734 0 : expected = gst_tensor_trainer_get_tensor_size (trainer, i, TRUE);
735 0 : if (expected != trainer->input_tensors[i].size) {
736 0 : GST_ERROR_OBJECT (trainer,
737 : "Invalid tensor size (%u'th memory chunk: %zd), expected size (%zd)",
738 : i, trainer->input_tensors[i].size, expected);
739 0 : goto error;
740 : }
741 : }
742 :
743 0 : ret = trainer->fw->push_data (trainer->fw, &trainer->prop,
744 0 : trainer->privateData, trainer->input_tensors);
745 :
746 0 : if (ret < 0)
747 0 : GST_ERROR_OBJECT (trainer, "push error");
748 : else
749 0 : trainer->cur_epoch_data_cnt++;
750 :
751 0 : error:
752 0 : for (i = 0; i < n; i++) {
753 0 : if (in_mem[i]) {
754 0 : gst_memory_unmap (in_mem[i], &in_info[i]);
755 0 : gst_memory_unref (in_mem[i]);
756 : }
757 :
758 0 : trainer->input_tensors[i].data = NULL;
759 0 : trainer->input_tensors[i].size = 0;
760 : }
761 :
762 0 : return (ret == 0);
763 : }
764 :
765 : /**
766 : * @brief Get the model statistics from the sub-plugin.
767 : */
768 : static gboolean
769 0 : gst_tensor_trainer_get_model_stats (GstTensorTrainer * trainer,
770 : double *model_stats)
771 : {
772 0 : gint ret = -1;
773 :
774 : ret =
775 0 : trainer->fw->getStatus (trainer->fw, &trainer->prop,
776 : trainer->privateData);
777 0 : if (ret < 0) {
778 0 : GST_ERROR_OBJECT (trainer, "Failed to Get status from sub-plugin.(%s).",
779 : trainer->fw_name);
780 0 : return FALSE;
781 : }
782 : /* If the value is invalid, it is already set by -INFINITY. */
783 0 : if (trainer->prop.training_loss > 0)
784 0 : model_stats[TRAINING_LOSS] = trainer->prop.training_loss;
785 0 : if (trainer->prop.training_accuracy > 0)
786 0 : model_stats[TRAINING_ACCURACY] = trainer->prop.training_accuracy;
787 0 : if (trainer->prop.validation_loss > 0)
788 0 : model_stats[VALIDATION_LOSS] = trainer->prop.validation_loss;
789 0 : if (trainer->prop.validation_accuracy > 0)
790 0 : model_stats[VALIDATION_ACCURACY] = trainer->prop.validation_accuracy;
791 :
792 0 : GST_DEBUG_OBJECT (trainer,
793 : "#%u/%u epochs [training_loss: %f, training_accuracy: %f, validation_loss: %f, validation_accuracy: %f]",
794 : trainer->prop.epoch_count, trainer->prop.num_epochs,
795 : model_stats[TRAINING_LOSS], model_stats[TRAINING_ACCURACY],
796 : model_stats[VALIDATION_LOSS], model_stats[VALIDATION_ACCURACY]);
797 :
798 0 : return TRUE;
799 : }
800 :
801 : /**
802 : * @brief Create output tensors.
803 : */
804 : static GstBuffer *
805 0 : gst_tensor_trainer_create_output (GstTensorTrainer * trainer)
806 : {
807 : guint i;
808 : size_t data_size;
809 0 : double model_stats[MODEL_STATS_SIZE] =
810 : { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
811 : GstBuffer *outbuf;
812 : GstMemory *out_mem;
813 : GstMapInfo out_info;
814 : GstTensorInfo *info;
815 0 : gboolean created = FALSE;
816 :
817 0 : if (trainer->output_meta.num_tensors > NNS_TENSOR_SIZE_LIMIT) {
818 0 : GST_ERROR_OBJECT (trainer,
819 : "The number of output tensors (%u) exceeds limit (%d)",
820 : trainer->output_meta.num_tensors, NNS_TENSOR_SIZE_LIMIT);
821 0 : return NULL;
822 : }
823 :
824 0 : outbuf = gst_buffer_new ();
825 :
826 0 : for (i = 0; i < trainer->output_meta.num_tensors; i++) {
827 0 : if (!gst_tensor_trainer_get_model_stats (trainer, model_stats))
828 0 : goto error;
829 :
830 0 : data_size = gst_tensor_trainer_get_tensor_size (trainer, i, FALSE);
831 0 : info = gst_tensors_info_get_nth_info (&trainer->output_meta, i);
832 :
833 0 : out_mem = gst_allocator_alloc (NULL, data_size, NULL);
834 0 : if (!out_mem) {
835 0 : GST_ERROR_OBJECT (trainer, "Failed to allocate memory");
836 0 : goto error;
837 : }
838 :
839 0 : if (!gst_memory_map (out_mem, &out_info, GST_MAP_WRITE)) {
840 0 : GST_ERROR_OBJECT (trainer, "Could not map out_mem[%u] GstMemory", i);
841 0 : gst_memory_unref (out_mem);
842 0 : goto error;
843 : }
844 :
845 0 : memcpy (out_info.data, model_stats, sizeof (model_stats));
846 0 : gst_memory_unmap (out_mem, &out_info);
847 :
848 0 : gst_tensor_buffer_append_memory (outbuf, out_mem, info);
849 : }
850 :
851 0 : created = TRUE;
852 :
853 0 : error:
854 0 : if (created) {
855 0 : GST_INFO ("out_buffer size : %zd", gst_buffer_get_size (outbuf));
856 : } else {
857 0 : gst_buffer_unref (outbuf);
858 0 : outbuf = NULL;
859 : }
860 :
861 0 : return outbuf;
862 : }
863 :
864 : /**
865 : * @brief Chain function, this function does the actual processing.
866 : */
867 : static GstFlowReturn
868 0 : gst_tensor_trainer_chain (GstPad * sinkpad, GstObject * parent,
869 : GstBuffer * inbuf)
870 : {
871 : GstTensorTrainer *trainer;
872 0 : GstBuffer *outbuf = NULL;
873 0 : GstFlowReturn ret = GST_FLOW_ERROR;
874 : guint num_tensors;
875 : gboolean in_flexible;
876 :
877 0 : trainer = GST_TENSOR_TRAINER (parent);
878 0 : in_flexible = gst_tensor_pad_caps_is_flexible (sinkpad);
879 0 : num_tensors = gst_tensor_buffer_get_count (inbuf);
880 :
881 0 : if (!gst_tensor_trainer_check_chain_conditions (trainer, num_tensors)) {
882 0 : goto error;
883 : }
884 :
885 0 : if (gst_tensor_trainer_check_buffer_drop_conditions (trainer)) {
886 0 : ret = GST_FLOW_OK;
887 0 : goto error;
888 : }
889 :
890 0 : if (!gst_tensor_trainer_push_input (trainer, inbuf, in_flexible)) {
891 0 : goto error;
892 : }
893 :
894 : /**
895 : * Update result if one of epochs is complete,
896 : * push one outbuf is necessary to change pipeline state.
897 : * Scheduling with subplugin does not work.
898 : */
899 0 : if (trainer->cur_epoch_data_cnt == 1
900 0 : || gst_tensor_trainer_epochs_is_complete (trainer)) {
901 0 : outbuf = gst_tensor_trainer_create_output (trainer);
902 :
903 0 : if (outbuf)
904 0 : ret = gst_pad_push (trainer->srcpad, outbuf);
905 : } else {
906 : /* Run flow, need more data? */
907 0 : ret = GST_FLOW_OK;
908 : }
909 :
910 0 : error:
911 0 : gst_buffer_unref (inbuf);
912 0 : return ret;
913 : }
914 :
915 : /**
916 : * @brief Get pad caps for caps negotiation.
917 : */
918 : static GstCaps *
919 47 : gst_tensor_trainer_query_caps (GstTensorTrainer * trainer,
920 : GstPad * pad, GstCaps * filter)
921 : {
922 : GstCaps *caps;
923 : GstTensorsConfig *config;
924 :
925 47 : g_return_val_if_fail (trainer != NULL, NULL);
926 47 : g_return_val_if_fail (pad != NULL, NULL);
927 :
928 : /* tensor config info for given pad */
929 47 : if (pad == trainer->sinkpad) {
930 25 : config = &trainer->in_config;
931 : } else {
932 22 : config = &trainer->out_config;
933 : }
934 :
935 47 : caps = gst_tensor_pad_possible_caps_from_config (pad, config);
936 47 : GST_DEBUG_OBJECT (trainer, "caps %" GST_PTR_FORMAT, caps);
937 47 : GST_DEBUG_OBJECT (trainer, "filter %" GST_PTR_FORMAT, filter);
938 :
939 47 : if (caps && filter) {
940 : GstCaps *result;
941 3 : result = gst_caps_intersect_full (filter, caps, GST_CAPS_INTERSECT_FIRST);
942 3 : gst_caps_unref (caps);
943 3 : caps = result;
944 : }
945 :
946 47 : GST_DEBUG_OBJECT (trainer, "result caps %" GST_PTR_FORMAT, caps);
947 :
948 47 : return caps;
949 : }
950 :
951 : /**
952 : * @brief Wait for training completion
953 : */
954 : static void
955 0 : gst_tensor_trainer_wait_for_training_completion (GstTensorTrainer * trainer)
956 : {
957 0 : g_return_if_fail (trainer != NULL);
958 :
959 0 : g_mutex_lock (&trainer->training_completion_lock);
960 0 : while (!trainer->is_training_complete) {
961 0 : GST_INFO_OBJECT (trainer,
962 : "got GST_EVENT_EOS event but training is not completed, state is %d, "
963 : "wait for training_completion_cond signal", GST_STATE (trainer));
964 0 : g_cond_wait (&trainer->training_completion_cond,
965 : &trainer->training_completion_lock);
966 : }
967 0 : g_mutex_unlock (&trainer->training_completion_lock);
968 :
969 0 : GST_DEBUG_OBJECT (trainer, "training is completed in sub-plugin[%s]",
970 : trainer->fw_name);
971 : }
972 :
973 : /**
974 : * @brief Event handler for sink pad of tensor_trainer
975 : */
976 : static gboolean
977 6 : gst_tensor_trainer_sink_event (GstPad * sinkpad, GstObject * parent,
978 : GstEvent * event)
979 : {
980 : GstTensorTrainer *trainer;
981 6 : trainer = GST_TENSOR_TRAINER (parent);
982 :
983 6 : GST_DEBUG_OBJECT (trainer, "Received %s event: %" GST_PTR_FORMAT,
984 : GST_EVENT_TYPE_NAME (event), event);
985 :
986 6 : switch (GST_EVENT_TYPE (event)) {
987 0 : case GST_EVENT_EOS:
988 0 : if (!trainer->is_training_complete)
989 0 : gst_tensor_trainer_wait_for_training_completion (trainer);
990 0 : break;
991 0 : case GST_EVENT_FLUSH_START:
992 0 : GST_INFO_OBJECT (trainer, "get GST_EVENT_FLUSH_START event");
993 0 : break;
994 0 : case GST_EVENT_FLUSH_STOP:
995 0 : GST_INFO_OBJECT (trainer, "get GST_EVENT_FLUSH_STOP event");
996 0 : break;
997 3 : case GST_EVENT_CAPS:
998 : {
999 : GstCaps *in_caps;
1000 : GstCaps *out_caps;
1001 : GstStructure *structure;
1002 : GstTensorsConfig config;
1003 3 : gboolean ret = FALSE;
1004 :
1005 3 : gst_event_parse_caps (event, &in_caps);
1006 3 : GST_INFO_OBJECT (trainer, "[in-caps] : %" GST_PTR_FORMAT, in_caps);
1007 :
1008 3 : structure = gst_caps_get_structure (in_caps, 0);
1009 3 : if (!gst_tensors_config_from_structure (&config, structure) ||
1010 3 : !gst_tensors_config_validate (&config)) {
1011 0 : gst_tensors_config_free (&config);
1012 0 : gst_event_unref (event);
1013 3 : return FALSE;
1014 : }
1015 :
1016 : /* copy TensorsInfo from negotiated caps to GstTensorTrainerProperties's input_meta */
1017 3 : gst_tensors_info_copy (&trainer->prop.input_meta, &config.info);
1018 :
1019 : /* set tensor-config and out caps */
1020 3 : trainer->in_config = config;
1021 3 : trainer->out_config.rate_n = config.rate_n;
1022 3 : trainer->out_config.rate_d = config.rate_d;
1023 3 : gst_tensors_info_copy (&trainer->out_config.info, &trainer->output_meta);
1024 :
1025 : out_caps =
1026 6 : gst_tensor_pad_caps_from_config (trainer->srcpad,
1027 3 : &trainer->out_config);
1028 3 : GST_INFO_OBJECT (trainer, "[out-caps] : %" GST_PTR_FORMAT, out_caps);
1029 :
1030 3 : ret = gst_pad_set_caps (trainer->srcpad, out_caps);
1031 :
1032 3 : gst_event_unref (event);
1033 3 : gst_caps_unref (out_caps);
1034 3 : return ret;
1035 : }
1036 3 : default:
1037 3 : break;
1038 : }
1039 3 : return gst_pad_event_default (sinkpad, parent, event);
1040 : }
1041 :
1042 : /**
1043 : * @brief This function handles sink pad query.
1044 : */
1045 : static gboolean
1046 31 : gst_tensor_trainer_sink_query (GstPad * sinkpad, GstObject * parent,
1047 : GstQuery * query)
1048 : {
1049 : GstTensorTrainer *trainer;
1050 31 : trainer = GST_TENSOR_TRAINER (parent);
1051 :
1052 31 : GST_DEBUG_OBJECT (trainer, "Received '%s' query: %" GST_PTR_FORMAT,
1053 : GST_QUERY_TYPE_NAME (query), query);
1054 :
1055 31 : switch (GST_QUERY_TYPE (query)) {
1056 25 : case GST_QUERY_CAPS:
1057 : {
1058 : GstCaps *caps;
1059 : GstCaps *filter;
1060 :
1061 25 : GST_DEBUG_OBJECT (trainer, "[GST_QUERY_CAPS]");
1062 25 : gst_query_parse_caps (query, &filter);
1063 25 : GST_DEBUG_OBJECT (trainer, "Caps from query : %" GST_PTR_FORMAT, filter);
1064 :
1065 25 : caps = gst_tensor_trainer_query_caps (trainer, sinkpad, filter);
1066 :
1067 25 : GST_INFO_OBJECT (trainer, "[GST_QUERY_CAPS] : %" GST_PTR_FORMAT, caps);
1068 25 : gst_query_set_caps_result (query, caps);
1069 25 : gst_caps_unref (caps);
1070 :
1071 25 : return TRUE;
1072 : }
1073 3 : case GST_QUERY_ACCEPT_CAPS:
1074 : {
1075 : GstCaps *caps;
1076 : GstCaps *template_caps;
1077 3 : gboolean result = FALSE;
1078 :
1079 3 : GST_DEBUG_OBJECT (trainer, "[GST_QUERY_ACCEPT_CAPS]");
1080 3 : gst_query_parse_accept_caps (query, &caps);
1081 3 : GST_INFO_OBJECT (trainer, "Accept caps from query : %" GST_PTR_FORMAT,
1082 : caps);
1083 :
1084 3 : if (gst_caps_is_fixed (caps)) {
1085 3 : template_caps = gst_pad_get_pad_template_caps (sinkpad);
1086 3 : GST_DEBUG_OBJECT (trainer, "sinkpad template_caps : %" GST_PTR_FORMAT,
1087 : template_caps);
1088 :
1089 3 : result = gst_caps_can_intersect (template_caps, caps);
1090 3 : gst_caps_unref (template_caps);
1091 :
1092 3 : GST_DEBUG_OBJECT (trainer, "intersect caps : %" GST_PTR_FORMAT, caps);
1093 : }
1094 :
1095 3 : gst_query_set_accept_caps_result (query, result);
1096 3 : return TRUE;
1097 : }
1098 3 : default:
1099 3 : break;
1100 : }
1101 :
1102 3 : return gst_pad_query_default (sinkpad, parent, query);
1103 : }
1104 :
1105 : /**
1106 : * @brief This function handles src pad query.
1107 : */
1108 : static gboolean
1109 22 : gst_tensor_trainer_src_query (GstPad * srcpad, GstObject * parent,
1110 : GstQuery * query)
1111 : {
1112 : GstTensorTrainer *trainer;
1113 22 : trainer = GST_TENSOR_TRAINER (parent);
1114 :
1115 22 : GST_DEBUG_OBJECT (trainer, "Received %s query: %" GST_PTR_FORMAT,
1116 : GST_QUERY_TYPE_NAME (query), query);
1117 :
1118 22 : switch (GST_QUERY_TYPE (query)) {
1119 22 : case GST_QUERY_CAPS:
1120 : {
1121 : GstCaps *caps;
1122 : GstCaps *filter;
1123 22 : GST_DEBUG_OBJECT (trainer, "[GST_QUERY_CAPS]");
1124 22 : gst_query_parse_caps (query, &filter);
1125 22 : GST_DEBUG_OBJECT (trainer, "Caps from query : %" GST_PTR_FORMAT, filter);
1126 22 : caps = gst_tensor_trainer_query_caps (trainer, srcpad, filter);
1127 :
1128 22 : GST_INFO_OBJECT (trainer, "[GST_QUERY_CAPS] : %" GST_PTR_FORMAT, caps);
1129 22 : gst_query_set_caps_result (query, caps);
1130 22 : gst_caps_unref (caps);
1131 22 : return TRUE;
1132 : }
1133 0 : default:
1134 0 : break;
1135 : }
1136 0 : return gst_pad_query_default (srcpad, parent, query);
1137 : }
1138 :
1139 : /**
1140 : * @brief Handle "PROP_FRAMEWORK" for set-property
1141 : */
1142 : static void
1143 11 : gst_tensor_trainer_set_prop_framework (GstTensorTrainer * trainer,
1144 : const GValue * value)
1145 : {
1146 11 : g_free (trainer->fw_name);
1147 11 : trainer->fw_name = g_value_dup_string (value);
1148 11 : GST_INFO_OBJECT (trainer, "Framework: %s", trainer->fw_name);
1149 :
1150 : /** @todo Check valid framework */
1151 11 : }
1152 :
1153 : /**
1154 : * @brief Handle "PROP_MODEL_CONFIG" for set-property
1155 : */
1156 : static void
1157 11 : gst_tensor_trainer_set_prop_model_config_file_path (GstTensorTrainer *
1158 : trainer, const GValue * value)
1159 : {
1160 11 : g_free ((char *) trainer->prop.model_config);
1161 11 : trainer->prop.model_config = g_value_dup_string (value);
1162 11 : GST_INFO_OBJECT (trainer, "Model configuration file path: %s",
1163 : trainer->prop.model_config);
1164 11 : }
1165 :
1166 : /**
1167 : * @brief Handle "PROP_MODEL_SAVE_PATH" for set-property
1168 : */
1169 : static void
1170 8 : gst_tensor_trainer_set_model_save_path (GstTensorTrainer * trainer,
1171 : const GValue * value)
1172 : {
1173 8 : g_free ((char *) trainer->prop.model_save_path);
1174 8 : trainer->prop.model_save_path = g_value_dup_string (value);
1175 8 : GST_INFO_OBJECT (trainer, "File path to save the model: %s",
1176 : trainer->prop.model_save_path);
1177 8 : }
1178 :
1179 : /**
1180 : * @brief Handle "PROP_MODEL_LOAD_PATH" for set-property
1181 : */
1182 : static void
1183 2 : gst_tensor_trainer_set_model_load_path (GstTensorTrainer * trainer,
1184 : const GValue * value)
1185 : {
1186 2 : g_free ((char *) trainer->prop.model_load_path);
1187 2 : trainer->prop.model_load_path = g_value_dup_string (value);
1188 2 : GST_INFO_OBJECT (trainer, "File path to load the model: %s",
1189 : trainer->prop.model_load_path);
1190 2 : }
1191 :
1192 : /**
1193 : * @brief Find Trainer sub-plugin with the name.
1194 : */
1195 : static gboolean
1196 1 : gst_tensor_trainer_find_framework (GstTensorTrainer * trainer, const char *name)
1197 : {
1198 1 : const GstTensorTrainerFramework *fw = NULL;
1199 :
1200 1 : g_return_val_if_fail (name != NULL, FALSE);
1201 1 : g_return_val_if_fail (trainer != NULL, FALSE);
1202 :
1203 1 : GST_INFO_OBJECT (trainer, "Try to find framework: %s", name);
1204 :
1205 1 : fw = get_subplugin (NNS_SUBPLUGIN_TRAINER, name);
1206 1 : if (!fw) {
1207 1 : GST_ERROR_OBJECT (trainer, "Can not find framework(%s)", trainer->fw_name);
1208 1 : return FALSE;
1209 : }
1210 :
1211 0 : GST_INFO_OBJECT (trainer, "Find framework %s:%p", trainer->fw_name, fw);
1212 0 : trainer->fw = fw;
1213 :
1214 0 : return TRUE;
1215 : }
1216 :
1217 : /**
1218 : * @brief Create NN framework.
1219 : */
1220 : static gboolean
1221 0 : gst_tensor_trainer_create_framework (GstTensorTrainer * trainer)
1222 : {
1223 0 : g_return_val_if_fail (trainer != NULL, FALSE);
1224 :
1225 0 : if (!trainer->fw || trainer->fw_created) {
1226 0 : GST_ERROR_OBJECT (trainer, "fw is not opened(%d) or fw is not null(%p)",
1227 : trainer->fw_created, trainer->fw);
1228 0 : return FALSE;
1229 : }
1230 :
1231 0 : if (!trainer->fw->create) {
1232 0 : GST_ERROR_OBJECT (trainer, "Could not create framework");
1233 0 : return FALSE;
1234 : }
1235 :
1236 0 : GST_DEBUG_OBJECT (trainer, "%p", trainer->privateData);
1237 0 : if (trainer->fw->create (trainer->fw, &trainer->prop,
1238 : &trainer->privateData) >= 0) {
1239 0 : trainer->fw_created = TRUE;
1240 0 : GST_DEBUG_OBJECT (trainer, "Success, Framework: %p", trainer->privateData);
1241 0 : return TRUE;
1242 : }
1243 0 : return FALSE;
1244 : }
1245 :
1246 : /**
1247 : * @brief Calculate tensor buffer size
1248 : */
1249 : gsize
1250 0 : gst_tensor_trainer_get_tensor_size (GstTensorTrainer * trainer,
1251 : guint index, gboolean is_input)
1252 : {
1253 : GstTensorsInfo *info;
1254 :
1255 0 : if (is_input)
1256 0 : info = &trainer->prop.input_meta;
1257 : else
1258 0 : info = &trainer->output_meta;
1259 :
1260 : /* Internal Logic Error: out of bound */
1261 0 : if (index >= info->num_tensors) {
1262 0 : GST_ERROR_OBJECT (trainer, "has inconsistent data");
1263 0 : return 0;
1264 : }
1265 :
1266 0 : return gst_tensors_info_get_size (info, index);
1267 : }
1268 :
1269 : /**
1270 : * @brief Create model
1271 : */
1272 : static gboolean
1273 1 : gst_tensor_trainer_create_model (GstTensorTrainer * trainer)
1274 : {
1275 1 : gboolean ret = TRUE;
1276 :
1277 1 : g_return_val_if_fail (trainer != NULL, FALSE);
1278 1 : g_return_val_if_fail (trainer->fw_name != NULL, FALSE);
1279 :
1280 1 : ret = gst_tensor_trainer_find_framework (trainer, trainer->fw_name);
1281 1 : if (!ret)
1282 1 : return ret;
1283 :
1284 0 : if (trainer->fw) {
1285 : /* model create and compile */
1286 0 : ret = gst_tensor_trainer_create_framework (trainer);
1287 : }
1288 :
1289 0 : return ret;
1290 : }
1291 :
1292 : /**
1293 : * @brief Create a event notifier
1294 : */
1295 : static void
1296 0 : gst_tensor_trainer_create_event_notifier (GstTensorTrainer * trainer)
1297 : {
1298 0 : g_return_if_fail (trainer != NULL);
1299 0 : g_return_if_fail (trainer->fw != NULL);
1300 :
1301 0 : trainer->notifier.notifier = (void *) trainer;
1302 : }
1303 :
1304 : /**
1305 : * @brief Start model training
1306 : */
1307 : static void
1308 0 : gst_tensor_trainer_start_model_training (GstTensorTrainer * trainer)
1309 : {
1310 0 : gint ret = -1;
1311 0 : g_return_if_fail (trainer != NULL);
1312 0 : g_return_if_fail (trainer->fw != NULL);
1313 0 : g_return_if_fail (trainer->fw->start != NULL);
1314 :
1315 0 : GST_DEBUG_OBJECT (trainer, "Start model training");
1316 : ret =
1317 0 : trainer->fw->start (trainer->fw, &trainer->prop, &trainer->notifier,
1318 : trainer->privateData);
1319 0 : if (ret != 0) {
1320 0 : GST_ERROR_OBJECT (trainer, "Model training is failed");
1321 : }
1322 : }
1323 :
1324 : /**
1325 : * @brief Stop model training
1326 : */
1327 : static void
1328 0 : gst_tensor_trainer_stop_model_training (GstTensorTrainer * trainer)
1329 : {
1330 0 : gint ret = -1;
1331 :
1332 0 : g_return_if_fail (trainer != NULL);
1333 0 : g_return_if_fail (trainer->fw != NULL);
1334 0 : g_return_if_fail (trainer->fw->stop != NULL);
1335 :
1336 0 : GST_DEBUG_OBJECT (trainer, "Stop model training");
1337 0 : ret = trainer->fw->stop (trainer->fw, &trainer->prop, &trainer->privateData);
1338 0 : if (ret != 0) {
1339 0 : GST_ERROR_OBJECT (trainer, "Stopping model training is failed");
1340 : }
1341 : }
1342 :
1343 : /**
1344 : * @brief initialize the output tensor dimension
1345 : */
1346 : static void
1347 11 : gst_tensor_trainer_set_output_meta (GstTensorTrainer * trainer)
1348 : {
1349 : GstTensorInfo *info;
1350 :
1351 11 : g_return_if_fail (trainer != NULL);
1352 :
1353 11 : gst_tensors_info_init (&trainer->output_meta);
1354 11 : info = gst_tensors_info_get_nth_info (&trainer->output_meta, 0);
1355 :
1356 11 : info->type = _NNS_FLOAT64;
1357 11 : info->dimension[0] = 1;
1358 11 : info->dimension[1] = 1;
1359 11 : info->dimension[2] = 4; /** loss, accuracy, val_loss, val_accuracy */
1360 11 : info->dimension[3] = 1;
1361 :
1362 11 : trainer->output_meta.num_tensors = 1;
1363 : }
1364 :
1365 : /**
1366 : * @brief Trainer's sub-plugin should call this function to register itself.
1367 : * @param[in] ttsp tensor_trainer sub-plugin to be registered.
1368 : * @return TRUE if registered. FALSE is failed or duplicated.
1369 : */
1370 : int
1371 0 : nnstreamer_trainer_probe (GstTensorTrainerFramework * ttsp)
1372 : {
1373 : GstTensorTrainerFrameworkInfo info;
1374 : GstTensorTrainerProperties prop;
1375 0 : const char *name = NULL;
1376 0 : int ret = 0;
1377 :
1378 0 : g_return_val_if_fail (ttsp != NULL, 0);
1379 :
1380 0 : memset (&prop, 0, sizeof (GstTensorTrainerProperties));
1381 0 : gst_tensors_info_init (&prop.input_meta);
1382 :
1383 0 : if (ret != ttsp->getFrameworkInfo (ttsp, &prop, NULL, &info)) {
1384 0 : GST_ERROR ("getFrameworkInfo() failed");
1385 0 : return FALSE;
1386 : }
1387 0 : name = info.name;
1388 :
1389 0 : return register_subplugin (NNS_SUBPLUGIN_TRAINER, name, ttsp);
1390 : }
1391 :
1392 : /**
1393 : * @brief Trainer's sub-plugin may call this to unregister itself.
1394 : * @param[in] ttsp tensor_trainer sub-plugin to be unregistered.
1395 : * @return TRUE if unregistered. FALSE is failed.
1396 : */
1397 : int
1398 0 : nnstreamer_trainer_exit (GstTensorTrainerFramework * ttsp)
1399 : {
1400 : GstTensorTrainerFrameworkInfo info;
1401 : GstTensorTrainerProperties prop;
1402 0 : const char *name = NULL;
1403 0 : int ret = 0;
1404 :
1405 0 : g_return_val_if_fail (ttsp != NULL, 0);
1406 :
1407 0 : memset (&prop, 0, sizeof (GstTensorTrainerProperties));
1408 0 : gst_tensors_info_init (&prop.input_meta);
1409 :
1410 0 : if (ret != ttsp->getFrameworkInfo (ttsp, &prop, NULL, &info)) {
1411 0 : GST_ERROR ("getFrameworkInfo() failed");
1412 0 : return FALSE;
1413 : }
1414 0 : name = info.name;
1415 :
1416 0 : return unregister_subplugin (NNS_SUBPLUGIN_TRAINER, name);
1417 : }
1418 :
1419 : /**
1420 : * @brief Trainer's sub-plugin may call this to send event.
1421 : * @param[in] notifier event notifier, sub-plugin must send events with this.
1422 : * @param[in] type event type
1423 : */
1424 : void
1425 0 : nnstreamer_trainer_notify_event (GstTensorTrainerEventNotifier * notifier,
1426 : GstTensorTrainerEventType type, void *data)
1427 : {
1428 : GstTensorTrainer *trainer;
1429 0 : g_return_if_fail (notifier != NULL);
1430 0 : g_return_if_fail (type < TRAINER_EVENT_UNKNOWN || type > 0);
1431 : UNUSED (data);
1432 :
1433 0 : trainer = (GstTensorTrainer *) notifier->notifier;
1434 0 : g_return_if_fail (GST_IS_TENSOR_TRAINER (trainer));
1435 :
1436 0 : GST_DEBUG ("Received GstTensorTrainerEvent(%d)", type);
1437 :
1438 0 : switch (type) {
1439 0 : case TRAINER_EVENT_EPOCH_COMPLETION:
1440 0 : g_mutex_lock (&trainer->epoch_completion_lock);
1441 0 : trainer->is_epoch_complete = TRUE;
1442 0 : GST_DEBUG ("send epoch_completion_cond signal");
1443 0 : g_cond_signal (&trainer->epoch_completion_cond);
1444 0 : g_mutex_unlock (&trainer->epoch_completion_lock);
1445 0 : break;
1446 0 : case TRAINER_EVENT_TRAINING_COMPLETION:
1447 0 : g_mutex_lock (&trainer->training_completion_lock);
1448 0 : trainer->is_training_complete = TRUE;
1449 0 : GST_DEBUG ("send training_completion_cond signal");
1450 0 : g_cond_signal (&trainer->training_completion_cond);
1451 0 : g_mutex_unlock (&trainer->training_completion_lock);
1452 0 : break;
1453 0 : default:
1454 0 : break;
1455 : }
1456 : }
|