Line data Source code
1 : /* SPDX-License-Identifier: LGPL-2.1-only */
2 : /**
3 : * GStreamer / NNStreamer gRPC/flatbuf support
4 : * Copyright (C) 2020 Dongju Chae <dongju.chae@samsung.com>
5 : */
6 : /**
7 : * @file nnstreamer_grpc_flatbuffer.cc
8 : * @date 26 Nov 2020
9 : * @brief nnstreamer gRPC/Flatbuf support
10 : * @see https://github.com/nnstreamer/nnstreamer
11 : * @author Dongju Chae <dongju.chae@samsung.com>
12 : * @bug No known bugs except for NYI items
13 : */
14 :
15 : #include "nnstreamer_grpc_flatbuf.h"
16 :
17 : #include <nnstreamer_log.h>
18 : #include <nnstreamer_plugin_api.h>
19 : #include <nnstreamer_util.h>
20 :
21 : #include <thread>
22 :
23 : #include <grpcpp/channel.h>
24 : #include <grpcpp/client_context.h>
25 : #include <grpcpp/create_channel.h>
26 : #include <grpcpp/grpcpp.h>
27 : #include <grpcpp/security/credentials.h>
28 :
29 : #include <gst/base/gstdataqueue.h>
30 :
31 : using nnstreamer::flatbuf::frame_rate;
32 : using nnstreamer::flatbuf::Tensor_format;
33 : using nnstreamer::flatbuf::Tensor_type;
34 :
35 : using flatbuffers::grpc::MessageBuilder;
36 :
37 : using namespace grpc;
38 :
39 : /** @brief Constructor of ServiceImplFlatbuf */
40 14 : ServiceImplFlatbuf::ServiceImplFlatbuf (const grpc_config *config)
41 14 : : NNStreamerRPC (config), client_stub_ (nullptr)
42 : {
43 14 : }
44 :
45 : /** @brief parse tensors and deliver the buffer via callback */
46 : void
47 62 : ServiceImplFlatbuf::parse_tensors (Message<Tensors> &tensors)
48 : {
49 : GstBuffer *buffer;
50 :
51 62 : _get_buffer_from_tensors (tensors, &buffer);
52 :
53 62 : if (cb_)
54 62 : cb_ (cb_data_, buffer);
55 : else
56 0 : gst_buffer_unref (buffer);
57 62 : }
58 :
59 : /** @brief fill tensors from the buffer */
60 : gboolean
61 51 : ServiceImplFlatbuf::fill_tensors (Message<Tensors> &tensors)
62 : {
63 : GstDataQueueItem *item;
64 :
65 51 : if (!gst_data_queue_pop (queue_, &item))
66 6 : return FALSE;
67 :
68 45 : _get_tensors_from_buffer (GST_BUFFER (item->object), tensors);
69 :
70 45 : GDestroyNotify destroy = (item->destroy) ? item->destroy : g_free;
71 45 : destroy (item);
72 :
73 45 : return TRUE;
74 : }
75 :
76 : /** @brief read tensors and invoke the registered callback */
77 : template <typename T>
78 : Status
79 4 : ServiceImplFlatbuf::_read_tensors (T reader)
80 : {
81 30 : while (1) {
82 34 : Message<Tensors> tensors;
83 :
84 34 : if (!reader->Read (&tensors))
85 4 : break;
86 :
87 30 : parse_tensors (tensors);
88 : }
89 :
90 4 : return Status::OK;
91 : }
92 :
93 : /** @brief obtain tensors from data queue and send them over gRPC */
94 : template <typename T>
95 : Status
96 4 : ServiceImplFlatbuf::_write_tensors (T writer)
97 : {
98 30 : while (1) {
99 34 : Message<Tensors> tensors;
100 :
101 : /* until flushing */
102 34 : if (!fill_tensors (tensors))
103 4 : break;
104 :
105 30 : writer->Write (tensors);
106 : }
107 :
108 4 : return Status::OK;
109 : }
110 :
111 : /** @brief convert tensors to buffer */
112 : void
113 62 : ServiceImplFlatbuf::_get_buffer_from_tensors (Message<Tensors> &msg, GstBuffer **buffer)
114 : {
115 62 : const Tensors *tensors = msg.GetRoot ();
116 62 : guint num_tensor = tensors->num_tensor ();
117 : GstTensorInfo *_info;
118 : GstMemory *memory;
119 :
120 62 : *buffer = gst_buffer_new ();
121 :
122 124 : for (guint i = 0; i < num_tensor; i++) {
123 62 : const Tensor *tensor = tensors->tensor ()->Get (i);
124 62 : const void *data = tensor->data ()->data ();
125 62 : gsize size = VectorLength (tensor->data ());
126 62 : gpointer new_data = _g_memdup (data, size);
127 :
128 62 : _info = gst_tensors_info_get_nth_info (&config_->info, i);
129 :
130 62 : memory = gst_memory_new_wrapped (
131 : (GstMemoryFlags) 0, new_data, size, 0, size, new_data, g_free);
132 62 : gst_tensor_buffer_append_memory (*buffer, memory, _info);
133 : }
134 62 : }
135 :
136 : /** @brief convert buffer to tensors */
137 : void
138 45 : ServiceImplFlatbuf::_get_tensors_from_buffer (GstBuffer *buffer, Message<Tensors> &msg)
139 : {
140 45 : MessageBuilder builder;
141 :
142 45 : flatbuffers::Offset<flatbuffers::Vector<uint32_t>> tensor_dim;
143 45 : flatbuffers::Offset<flatbuffers::String> tensor_name;
144 45 : flatbuffers::Offset<flatbuffers::Vector<unsigned char>> tensor_data;
145 45 : flatbuffers::Offset<Tensor> tensor;
146 45 : flatbuffers::Offset<Tensors> tensors;
147 45 : std::vector<flatbuffers::Offset<Tensor>> tensor_vector;
148 : Tensor_type tensor_type;
149 45 : Tensor_format format = (Tensor_format) config_->info.format;
150 45 : unsigned int num_tensors = config_->info.num_tensors;
151 45 : frame_rate fr = frame_rate (config_->rate_n, config_->rate_d);
152 : GstTensorInfo *_info;
153 : GstMemory *mem;
154 : GstMapInfo map;
155 :
156 90 : for (guint i = 0; i < num_tensors; i++) {
157 45 : _info = gst_tensors_info_get_nth_info (&config_->info, i);
158 :
159 45 : mem = gst_tensor_buffer_get_nth_memory (buffer, i);
160 45 : g_assert (gst_memory_map (mem, &map, GST_MAP_READ));
161 :
162 45 : tensor_dim = builder.CreateVector (_info->dimension, NNS_TENSOR_RANK_LIMIT);
163 45 : tensor_name = builder.CreateString ("Anonymous");
164 45 : tensor_type = (Tensor_type) _info->type;
165 45 : tensor_data = builder.CreateVector<unsigned char> (map.data, map.size);
166 :
167 45 : tensor = CreateTensor (builder, tensor_name, tensor_type, tensor_dim, tensor_data);
168 45 : tensor_vector.push_back (tensor);
169 :
170 45 : gst_memory_unmap (mem, &map);
171 45 : gst_memory_unref (mem);
172 : }
173 :
174 45 : tensors = CreateTensors (
175 : builder, num_tensors, &fr, builder.CreateVector (tensor_vector), format);
176 :
177 45 : builder.Finish (tensors);
178 45 : msg = builder.ReleaseMessage<Tensors> ();
179 90 : }
180 :
181 : /** @brief Constructor of SyncServiceImplFlatbuf */
182 8 : SyncServiceImplFlatbuf::SyncServiceImplFlatbuf (const grpc_config *config)
183 8 : : ServiceImplFlatbuf (config)
184 : {
185 8 : }
186 :
187 : /** @brief client-to-server streaming: a client sends tensors */
188 : Status
189 2 : SyncServiceImplFlatbuf::SendTensors (ServerContext *context,
190 : ServerReader<Message<Tensors>> *reader, Message<Empty> *replay)
191 : {
192 2 : return _read_tensors (reader);
193 : }
194 :
195 : /** @brief server-to-client streaming: a client receives tensors */
196 : Status
197 2 : SyncServiceImplFlatbuf::RecvTensors (ServerContext *context,
198 : const Message<Empty> *request, ServerWriter<Message<Tensors>> *writer)
199 : {
200 2 : return _write_tensors (writer);
201 : }
202 :
203 : /** @brief start gRPC server handling flatbuf */
204 : gboolean
205 4 : SyncServiceImplFlatbuf::start_server (std::string address)
206 : {
207 : /* listen on the given address without any authentication mechanism */
208 4 : ServerBuilder builder;
209 4 : builder.AddListeningPort (address, grpc::InsecureServerCredentials (), &port_);
210 4 : builder.RegisterService (this);
211 :
212 : /* start the server */
213 4 : server_instance_ = builder.BuildAndStart ();
214 4 : if (server_instance_.get () == nullptr)
215 0 : return FALSE;
216 :
217 4 : return TRUE;
218 4 : }
219 :
220 : /** @brief start gRPC client handling flatbuf */
221 : gboolean
222 4 : SyncServiceImplFlatbuf::start_client (std::string address)
223 : {
224 : /* create a gRPC channel */
225 : std::shared_ptr<Channel> channel
226 4 : = grpc::CreateChannel (address, grpc::InsecureChannelCredentials ());
227 :
228 : /* connect the server */
229 : try {
230 4 : client_stub_ = TensorService::NewStub (channel);
231 0 : } catch (...) {
232 0 : ml_loge ("Failed to connect the server");
233 0 : return FALSE;
234 0 : }
235 8 : worker_ = std::thread ([this] { this->_client_thread (); });
236 :
237 4 : return TRUE;
238 4 : }
239 :
240 : /** @brief gRPC client thread */
241 : void
242 4 : SyncServiceImplFlatbuf::_client_thread ()
243 : {
244 4 : ClientContext context;
245 :
246 4 : if (direction_ == GRPC_DIRECTION_TENSORS_TO_BUFFER) {
247 2 : Message<Empty> empty;
248 :
249 : /* initiate the RPC call */
250 : std::unique_ptr<ClientWriter<Message<Tensors>>> writer (
251 2 : client_stub_->SendTensors (&context, &empty));
252 :
253 2 : _write_tensors (writer.get ());
254 :
255 2 : writer->WritesDone ();
256 : /**
257 : * TODO: The below incurs assertion failure but it seems like a bug.
258 : * Let's check it later with the latest gRPC version.
259 : *
260 : * writer->Finish ();
261 : */
262 2 : g_usleep (G_USEC_PER_SEC / 100);
263 4 : } else if (direction_ == GRPC_DIRECTION_BUFFER_TO_TENSORS) {
264 2 : MessageBuilder builder;
265 :
266 2 : auto empty_offset = nnstreamer::flatbuf::CreateEmpty (builder);
267 2 : builder.Finish (empty_offset);
268 :
269 : /* initiate the RPC call */
270 : std::unique_ptr<ClientReader<Message<Tensors>>> reader (
271 2 : client_stub_->RecvTensors (&context, builder.ReleaseMessage<Empty> ()));
272 :
273 2 : _read_tensors (reader.get ());
274 :
275 2 : reader->Finish ();
276 2 : } else {
277 0 : g_assert (0); /* internal logic error */
278 : }
279 4 : }
280 :
281 : /** @brief Constructor of AsyncServiceImplFlatbuf */
282 6 : AsyncServiceImplFlatbuf::AsyncServiceImplFlatbuf (const grpc_config *config)
283 6 : : ServiceImplFlatbuf (config), last_call_ (nullptr)
284 : {
285 6 : }
286 :
287 : /** @brief Destructor of AsyncServiceImplFlatbuf */
288 4 : AsyncServiceImplFlatbuf::~AsyncServiceImplFlatbuf ()
289 : {
290 2 : if (last_call_)
291 0 : delete last_call_;
292 4 : }
293 :
294 :
295 : /** @brief start gRPC server handling flatbuf */
296 : gboolean
297 2 : AsyncServiceImplFlatbuf::start_server (std::string address)
298 : {
299 : /* listen on the given address without any authentication mechanism */
300 2 : ServerBuilder builder;
301 2 : builder.AddListeningPort (address, grpc::InsecureServerCredentials (), &port_);
302 2 : builder.RegisterService (this);
303 :
304 : /* need to manually handle the completion queue */
305 2 : completion_queue_ = builder.AddCompletionQueue ();
306 :
307 : /* start the server */
308 2 : server_instance_ = builder.BuildAndStart ();
309 2 : if (server_instance_.get () == nullptr)
310 0 : return FALSE;
311 :
312 4 : worker_ = std::thread ([this] { this->_server_thread (); });
313 :
314 2 : return TRUE;
315 2 : }
316 :
317 : /** @brief start gRPC client handling flatbuf */
318 : gboolean
319 4 : AsyncServiceImplFlatbuf::start_client (std::string address)
320 : {
321 : /* create a gRPC channel */
322 : std::shared_ptr<Channel> channel
323 4 : = grpc::CreateChannel (address, grpc::InsecureChannelCredentials ());
324 :
325 : /* connect the server */
326 : try {
327 4 : client_stub_ = TensorService::NewStub (channel);
328 0 : } catch (...) {
329 0 : ml_loge ("Failed to connect the server");
330 0 : return FALSE;
331 0 : }
332 :
333 8 : worker_ = std::thread ([this] { this->_client_thread (); });
334 :
335 4 : return TRUE;
336 4 : }
337 :
338 : /** @brief Internal derived class for server */
339 : class AsyncCallDataServer : public AsyncCallData
340 : {
341 : public:
342 : /** @brief Constructor of AsyncCallDataServer */
343 4 : AsyncCallDataServer (AsyncServiceImplFlatbuf *service, ServerCompletionQueue *cq)
344 4 : : AsyncCallData (service), cq_ (cq), writer_ (nullptr), reader_ (nullptr)
345 : {
346 4 : RunState ();
347 4 : }
348 :
349 : /** @brief implemented RunState () of AsyncCallDataServer */
350 27 : void RunState (bool ok = true) override
351 : {
352 27 : if (state_ == PROCESS && !ok) {
353 4 : if (count_ != 0) {
354 2 : if (reader_.get () != nullptr)
355 2 : service_->parse_tensors (rpc_tensors_);
356 2 : state_ = FINISH;
357 : } else {
358 2 : return;
359 : }
360 : }
361 :
362 25 : if (state_ == CREATE) {
363 4 : if (service_->getDirection () == GRPC_DIRECTION_BUFFER_TO_TENSORS) {
364 4 : reader_.reset (new ServerAsyncReader<Message<Empty>, Message<Tensors>> (&ctx_));
365 4 : service_->RequestSendTensors (&ctx_, reader_.get (), cq_, cq_, this);
366 : } else {
367 0 : writer_.reset (new ServerAsyncWriter<Message<Tensors>> (&ctx_));
368 0 : service_->RequestRecvTensors (&ctx_, &rpc_empty_, writer_.get (), cq_, cq_, this);
369 : }
370 4 : state_ = PROCESS;
371 21 : } else if (state_ == PROCESS) {
372 17 : if (count_ == 0) {
373 : /* spawn a new instance to serve new clients */
374 2 : service_->set_last_call (new AsyncCallDataServer (service_, cq_));
375 : }
376 :
377 17 : if (reader_.get () != nullptr) {
378 17 : if (count_ != 0)
379 15 : service_->parse_tensors (rpc_tensors_);
380 17 : reader_->Read (&rpc_tensors_, this);
381 : /* can't read tensors yet. use the next turn */
382 17 : count_++;
383 0 : } else if (writer_.get () != nullptr) {
384 0 : Message<Tensors> tensors;
385 0 : if (service_->fill_tensors (tensors)) {
386 0 : writer_->Write (tensors, this);
387 0 : count_++;
388 : } else {
389 0 : Status status;
390 0 : writer_->Finish (status, this);
391 0 : state_ = DESTROY;
392 0 : }
393 0 : }
394 4 : } else if (state_ == FINISH) {
395 2 : if (reader_.get () != nullptr) {
396 2 : MessageBuilder builder;
397 :
398 2 : auto empty_offset = nnstreamer::flatbuf::CreateEmpty (builder);
399 2 : builder.Finish (empty_offset);
400 :
401 2 : reader_->Finish (builder.ReleaseMessage<Empty> (), Status::OK, this);
402 2 : }
403 2 : if (writer_.get () != nullptr) {
404 0 : Status status;
405 0 : writer_->Finish (status, this);
406 0 : }
407 2 : state_ = DESTROY;
408 : } else {
409 2 : delete this;
410 : }
411 : }
412 :
413 : private:
414 : ServerCompletionQueue *cq_;
415 : ServerContext ctx_;
416 :
417 : std::unique_ptr<ServerAsyncWriter<Message<Tensors>>> writer_;
418 : std::unique_ptr<ServerAsyncReader<Message<Empty>, Message<Tensors>>> reader_;
419 : };
420 :
421 : /** @brief Internal derived class for client */
422 : class AsyncCallDataClient : public AsyncCallData
423 : {
424 : public:
425 : /** @brief Constructor of AsyncCallDataClient */
426 4 : AsyncCallDataClient (AsyncServiceImplFlatbuf *service,
427 : TensorService::Stub *stub, CompletionQueue *cq)
428 4 : : AsyncCallData (service), stub_ (stub), cq_ (cq), writer_ (nullptr),
429 8 : reader_ (nullptr)
430 : {
431 4 : RunState ();
432 4 : }
433 :
434 : /** @brief implemented RunState () of AsyncCallDataClient */
435 38 : void RunState (bool ok = true) override
436 : {
437 38 : if (state_ == PROCESS && !ok) {
438 0 : if (count_ != 0) {
439 0 : if (reader_.get () != nullptr)
440 0 : service_->parse_tensors (rpc_tensors_);
441 0 : state_ = FINISH;
442 : } else {
443 0 : return;
444 : }
445 : }
446 :
447 38 : if (state_ == CREATE) {
448 4 : if (service_->getDirection () == GRPC_DIRECTION_BUFFER_TO_TENSORS) {
449 2 : MessageBuilder builder;
450 :
451 2 : auto empty_offset = nnstreamer::flatbuf::CreateEmpty (builder);
452 2 : builder.Finish (empty_offset);
453 :
454 4 : reader_ = stub_->AsyncRecvTensors (
455 6 : &ctx_, builder.ReleaseMessage<Empty> (), cq_, this);
456 2 : } else {
457 2 : writer_ = stub_->AsyncSendTensors (&ctx_, &rpc_empty_, cq_, this);
458 : }
459 4 : state_ = PROCESS;
460 34 : } else if (state_ == PROCESS) {
461 34 : if (reader_.get () != nullptr) {
462 17 : if (count_ != 0)
463 15 : service_->parse_tensors (rpc_tensors_);
464 17 : reader_->Read (&rpc_tensors_, this);
465 : /* can't read tensors yet. use the next turn */
466 17 : count_++;
467 17 : } else if (writer_.get () != nullptr) {
468 17 : Message<Tensors> tensors;
469 17 : if (service_->fill_tensors (tensors)) {
470 15 : writer_->Write (tensors, this);
471 15 : count_++;
472 : } else {
473 2 : writer_->WritesDone (this);
474 2 : state_ = FINISH;
475 : }
476 17 : }
477 0 : } else if (state_ == FINISH) {
478 0 : Status status;
479 :
480 0 : if (reader_.get () != nullptr)
481 0 : reader_->Finish (&status, this);
482 0 : if (writer_.get () != nullptr)
483 0 : writer_->Finish (&status, this);
484 :
485 0 : delete this;
486 0 : }
487 : }
488 :
489 : private:
490 : TensorService::Stub *stub_;
491 : CompletionQueue *cq_;
492 : ClientContext ctx_;
493 :
494 : std::unique_ptr<ClientAsyncWriter<Message<Tensors>>> writer_;
495 : std::unique_ptr<ClientAsyncReader<Message<Tensors>>> reader_;
496 : };
497 :
498 : /** @brief gRPC client thread */
499 : void
500 2 : AsyncServiceImplFlatbuf::_server_thread ()
501 : {
502 : /* spawn a new instance to server new clients */
503 2 : set_last_call (new AsyncCallDataServer (this, completion_queue_.get ()));
504 :
505 : while (1) {
506 : void *tag;
507 : bool ok;
508 :
509 : /* 10 msec deadline to wait the next event */
510 569 : gpr_timespec deadline = gpr_time_add (
511 : gpr_now (GPR_CLOCK_MONOTONIC), gpr_time_from_millis (10, GPR_TIMESPAN));
512 :
513 569 : switch (completion_queue_->AsyncNext (&tag, &ok, deadline)) {
514 23 : case CompletionQueue::GOT_EVENT:
515 23 : static_cast<AsyncCallDataServer *> (tag)->RunState (ok);
516 23 : break;
517 2 : case CompletionQueue::SHUTDOWN:
518 2 : return;
519 544 : default:
520 544 : break;
521 : }
522 567 : }
523 : }
524 :
525 : /** @brief gRPC client thread */
526 : void
527 4 : AsyncServiceImplFlatbuf::_client_thread ()
528 : {
529 4 : CompletionQueue cq;
530 :
531 : /* spawn a new instance to serve new clients */
532 4 : new AsyncCallDataClient (this, client_stub_.get (), &cq);
533 :
534 : /* until the stop is called */
535 163 : while (!stop_) {
536 : void *tag;
537 : bool ok;
538 :
539 : /* 10 msec deadline to wait the next event */
540 159 : gpr_timespec deadline = gpr_time_add (
541 : gpr_now (GPR_CLOCK_MONOTONIC), gpr_time_from_millis (10, GPR_TIMESPAN));
542 :
543 159 : switch (cq.AsyncNext (&tag, &ok, deadline)) {
544 34 : case CompletionQueue::GOT_EVENT:
545 34 : static_cast<AsyncCallDataClient *> (tag)->RunState (ok);
546 34 : if (ok == false)
547 0 : return;
548 34 : break;
549 125 : default:
550 125 : break;
551 : }
552 : }
553 4 : }
554 :
555 : /** @brief create gRPC/Flatbuf instance */
556 : extern "C" NNStreamerRPC *
557 14 : create_instance (const grpc_config *config)
558 : {
559 14 : if (config->is_blocking)
560 8 : return new SyncServiceImplFlatbuf (config);
561 : else
562 6 : return new AsyncServiceImplFlatbuf (config);
563 : }
|