/* SPDX-License-Identifier: LGPL-2.1-only */
/**
 * GStreamer / NNStreamer gRPC/protobuf support
 * Copyright (C) 2020 Dongju Chae <dongju.chae@samsung.com>
 */
/**
 * @file nnstreamer_grpc_protobuf.cc
 * @date 21 Oct 2020
 * @brief gRPC/Protobuf wrappers for nnstreamer
 * @see https://github.com/nnstreamer/nnstreamer
 * @author Dongju Chae <dongju.chae@samsung.com>
 * @bug No known bugs except for NYI items
 */

#include "nnstreamer_grpc_protobuf.h"

#include <nnstreamer_log.h>
#include <nnstreamer_plugin_api.h>
#include <nnstreamer_util.h>

#include <thread>

#include <grpcpp/channel.h>
#include <grpcpp/client_context.h>
#include <grpcpp/create_channel.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/security/credentials.h>

#include <gst/base/gstdataqueue.h>

using namespace grpc;
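/**
 * Overview: ServiceImplProtobuf holds the protobuf/GStreamer conversion logic
 * shared by both transports. SyncServiceImplProtobuf serves the blocking
 * streaming RPCs directly, while AsyncServiceImplProtobuf drives a gRPC
 * completion queue from a worker thread through the AsyncCallDataServer and
 * AsyncCallDataClient state machines defined below. Common plumbing (queue_,
 * cb_, cb_data_, direction_, worker_, port_, ...) is inherited from
 * NNStreamerRPC, which is declared in the accompanying gRPC support headers.
 */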
/** @brief constructor */
ServiceImplProtobuf::ServiceImplProtobuf (const grpc_config *config)
    : NNStreamerRPC (config), client_stub_ (nullptr)
{
}

/** @brief parse tensors and deliver the buffer via callback */
void
ServiceImplProtobuf::parse_tensors (Tensors &tensors)
{
  GstBuffer *buffer;

  _get_buffer_from_tensors (tensors, &buffer);

  if (cb_)
    cb_ (cb_data_, buffer);
  else
    gst_buffer_unref (buffer);
}
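/*
 * Note: parse_tensors () transfers ownership of the freshly created GstBuffer
 * to the registered callback; when no callback is set, the buffer is unreffed
 * right away so nothing is leaked.
 */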

/** @brief fill tensors from the buffer */
gboolean
ServiceImplProtobuf::fill_tensors (Tensors &tensors)
{
  GstDataQueueItem *item;

  if (!gst_data_queue_pop (queue_, &item))
    return FALSE;

  _get_tensors_from_buffer (GST_BUFFER (item->object), tensors);

  GDestroyNotify destroy = (item->destroy) ? item->destroy : g_free;
  destroy (item);

  return TRUE;
}
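/*
 * Note: gst_data_queue_pop () blocks until an item becomes available and
 * returns FALSE once the queue is flushed, so a FALSE return from
 * fill_tensors () is the signal to stop the outgoing stream.
 */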

/** @brief read tensors and invoke the registered callback */
template <typename T>
Status
ServiceImplProtobuf::_read_tensors (T reader)
{
  while (1) {
    Tensors tensors;

    if (!reader->Read (&tensors))
      break;

    parse_tensors (tensors);
  }

  return Status::OK;
}

/** @brief obtain tensors from the data queue and send them over gRPC */
template <typename T>
Status
ServiceImplProtobuf::_write_tensors (T writer)
{
  while (1) {
    Tensors tensors;

    /* fill_tensors () returns FALSE when the queue is flushed */
    if (!fill_tensors (tensors))
      break;

    writer->Write (tensors);
  }

  return Status::OK;
}
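/*
 * Note: _read_tensors () and _write_tensors () are templated so the same loop
 * can drive both the synchronous server-side streams (ServerReader /
 * ServerWriter) and the synchronous client-side streams (ClientReader /
 * ClientWriter). The asynchronous path uses the AsyncCallData classes defined
 * later in this file instead.
 */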

/** @brief convert tensors to buffer */
void
ServiceImplProtobuf::_get_buffer_from_tensors (Tensors &tensors, GstBuffer **buffer)
{
  guint num_tensor = tensors.num_tensor ();
  GstTensorInfo *_info;
  GstMemory *memory;

  *buffer = gst_buffer_new ();

  for (guint i = 0; i < num_tensor; i++) {
    const Tensor *tensor = &tensors.tensor (i);
    const void *data = tensor->data ().c_str ();
    gsize size = tensor->data ().length ();
    gpointer new_data = _g_memdup (data, size);

    _info = gst_tensors_info_get_nth_info (&config_->info, i);

    memory = gst_memory_new_wrapped (
        (GstMemoryFlags) 0, new_data, size, 0, size, new_data, g_free);
    gst_tensor_buffer_append_memory (*buffer, memory, _info);
  }
}
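/*
 * Note: each tensor payload is deep-copied out of the protobuf message with
 * _g_memdup () and wrapped into a GstMemory with g_free as its destructor, so
 * the resulting GstBuffer remains valid after the protobuf Tensors object
 * goes out of scope.
 */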

/** @brief convert buffer to tensors */
void
ServiceImplProtobuf::_get_tensors_from_buffer (GstBuffer *buffer, Tensors &tensors)
{
  Tensors::frame_rate *fr;
  GstTensorInfo *_info;
  GstMemory *mem;
  GstMapInfo map;

  tensors.set_num_tensor (config_->info.num_tensors);

  fr = tensors.mutable_fr ();
  fr->set_rate_n (config_->rate_n);
  fr->set_rate_d (config_->rate_d);

  for (guint i = 0; i < config_->info.num_tensors; i++) {
    nnstreamer::protobuf::Tensor *tensor = tensors.add_tensor ();

    _info = gst_tensors_info_get_nth_info (&config_->info, i);

    mem = gst_tensor_buffer_get_nth_memory (buffer, i);
    g_assert (gst_memory_map (mem, &map, GST_MAP_READ));

    /* set tensor info */
    tensor->set_name ("Anonymous");
    tensor->set_type ((Tensor::Tensor_type) _info->type);

    for (guint j = 0; j < NNS_TENSOR_RANK_LIMIT; j++)
      tensor->add_dimension (_info->dimension[j]);

    tensor->set_data (map.data, map.size);

    gst_memory_unmap (mem, &map);
    gst_memory_unref (mem);
  }
}
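/*
 * Note: the dimension field is always padded up to NNS_TENSOR_RANK_LIMIT
 * entries and the frame rate (rate_n / rate_d) is copied from the negotiated
 * GstTensorsConfig, so the peer can reconstruct the tensors configuration
 * without further negotiation.
 */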

/** @brief Constructor of SyncServiceImplProtobuf */
SyncServiceImplProtobuf::SyncServiceImplProtobuf (const grpc_config *config)
    : ServiceImplProtobuf (config)
{
}

/** @brief client-to-server streaming: a client sends tensors */
Status
SyncServiceImplProtobuf::SendTensors (
    ServerContext *context, ServerReader<Tensors> *reader, Empty *reply)
{
  return _read_tensors (reader);
}

/** @brief server-to-client streaming: a client receives tensors */
Status
SyncServiceImplProtobuf::RecvTensors (
    ServerContext *context, const Empty *request, ServerWriter<Tensors> *writer)
{
  return _write_tensors (writer);
}

/** @brief start gRPC server handling protobuf */
gboolean
SyncServiceImplProtobuf::start_server (std::string address)
{
  /* listen on the given address without any authentication mechanism */
  ServerBuilder builder;
  builder.AddListeningPort (address, grpc::InsecureServerCredentials (), &port_);
  builder.RegisterService (this);

  /* start the server */
  server_instance_ = builder.BuildAndStart ();
  if (server_instance_.get () == nullptr)
    return FALSE;

  return TRUE;
}
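/*
 * Note: the last argument of AddListeningPort () is gRPC's selected-port out
 * parameter; this is how port_ gets filled in when the caller binds to an
 * ephemeral port (e.g., port 0 in the address string).
 */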

/** @brief start gRPC client handling protobuf */
gboolean
SyncServiceImplProtobuf::start_client (std::string address)
{
  /* create a gRPC channel */
  std::shared_ptr<Channel> channel
      = grpc::CreateChannel (address, grpc::InsecureChannelCredentials ());

  /* connect to the server */
  try {
    client_stub_ = TensorService::NewStub (channel);
  } catch (...) {
    ml_loge ("Failed to connect to the server");
    return FALSE;
  }

  worker_ = std::thread ([this] { this->_client_thread (); });

  return TRUE;
}

/** @brief gRPC client thread */
void
SyncServiceImplProtobuf::_client_thread ()
{
  ClientContext context;
  Empty empty;

  if (direction_ == GRPC_DIRECTION_TENSORS_TO_BUFFER) {
    /* initiate the RPC call */
    std::unique_ptr<ClientWriter<Tensors>> writer (
        client_stub_->SendTensors (&context, &empty));

    _write_tensors (writer.get ());

    writer->WritesDone ();
    writer->Finish ();
  } else if (direction_ == GRPC_DIRECTION_BUFFER_TO_TENSORS) {
    Tensors tensors;

    /* initiate the RPC call */
    std::unique_ptr<ClientReader<Tensors>> reader (
        client_stub_->RecvTensors (&context, empty));

    _read_tensors (reader.get ());

    reader->Finish ();
  } else {
    g_assert (0); /* internal logic error */
  }
}
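/*
 * Note: in the synchronous client, GRPC_DIRECTION_TENSORS_TO_BUFFER maps to
 * the SendTensors RPC (tensors are popped from the data queue and streamed
 * out), whereas GRPC_DIRECTION_BUFFER_TO_TENSORS maps to RecvTensors (tensors
 * are read from the stream and delivered through the registered callback).
 */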

/** @brief Constructor of AsyncServiceImplProtobuf */
AsyncServiceImplProtobuf::AsyncServiceImplProtobuf (const grpc_config *config)
    : ServiceImplProtobuf (config), last_call_ (nullptr)
{
}

/** @brief Destructor of AsyncServiceImplProtobuf */
AsyncServiceImplProtobuf::~AsyncServiceImplProtobuf ()
{
  if (last_call_)
    delete last_call_;
}

/** @brief start gRPC server handling protobuf */
gboolean
AsyncServiceImplProtobuf::start_server (std::string address)
{
  /* listen on the given address without any authentication mechanism */
  ServerBuilder builder;
  builder.AddListeningPort (address, grpc::InsecureServerCredentials (), &port_);
  builder.RegisterService (this);

  /* need to manually handle the completion queue */
  completion_queue_ = builder.AddCompletionQueue ();

  /* start the server */
  server_instance_ = builder.BuildAndStart ();
  if (server_instance_.get () == nullptr)
    return FALSE;

  worker_ = std::thread ([this] { this->_server_thread (); });

  return TRUE;
}

/** @brief start gRPC client handling protobuf */
gboolean
AsyncServiceImplProtobuf::start_client (std::string address)
{
  /* create a gRPC channel */
  std::shared_ptr<Channel> channel
      = grpc::CreateChannel (address, grpc::InsecureChannelCredentials ());

  /* connect to the server */
  try {
    client_stub_ = TensorService::NewStub (channel);
  } catch (...) {
    ml_loge ("Failed to connect to the server");
    return FALSE;
  }

  worker_ = std::thread ([this] { this->_client_thread (); });

  return TRUE;
}

/** @brief Internal derived class for server */
class AsyncCallDataServer : public AsyncCallData
{
  public:
  /** @brief Constructor of AsyncCallDataServer */
  AsyncCallDataServer (AsyncServiceImplProtobuf *service, ServerCompletionQueue *cq)
      : AsyncCallData (service), cq_ (cq), writer_ (nullptr), reader_ (nullptr)
  {
    RunState ();
  }

  /** @brief implemented RunState () of AsyncCallDataServer */
  void RunState (bool ok = true) override
  {
    if (state_ == PROCESS && !ok) {
      if (count_ != 0) {
        if (reader_.get () != nullptr)
          service_->parse_tensors (rpc_tensors_);
        state_ = FINISH;
      } else {
        return;
      }
    }

    if (state_ == CREATE) {
      if (service_->getDirection () == GRPC_DIRECTION_BUFFER_TO_TENSORS) {
        reader_.reset (new ServerAsyncReader<Empty, Tensors> (&ctx_));
        service_->RequestSendTensors (&ctx_, reader_.get (), cq_, cq_, this);
      } else {
        writer_.reset (new ServerAsyncWriter<Tensors> (&ctx_));
        service_->RequestRecvTensors (&ctx_, &rpc_empty_, writer_.get (), cq_, cq_, this);
      }
      state_ = PROCESS;
    } else if (state_ == PROCESS) {
      if (count_ == 0) {
        /* spawn a new instance to serve new clients */
        service_->set_last_call (new AsyncCallDataServer (service_, cq_));
      }

      if (reader_.get () != nullptr) {
        if (count_ != 0)
          service_->parse_tensors (rpc_tensors_);
        reader_->Read (&rpc_tensors_, this);
        /* can't read tensors yet. use the next turn */
        count_++;
      } else if (writer_.get () != nullptr) {
        Tensors tensors;
        if (service_->fill_tensors (tensors)) {
          writer_->Write (tensors, this);
          count_++;
        } else {
          Status status;
          writer_->Finish (status, this);
          state_ = DESTROY;
        }
      }
    } else if (state_ == FINISH) {
      if (reader_.get () != nullptr)
        reader_->Finish (rpc_empty_, Status::OK, this);
      if (writer_.get () != nullptr) {
        Status status;
        writer_->Finish (status, this);
      }
      state_ = DESTROY;
    } else {
      delete this;
    }
  }

  private:
  ServerCompletionQueue *cq_;
  ServerContext ctx_;

  std::unique_ptr<ServerAsyncWriter<Tensors>> writer_;
  std::unique_ptr<ServerAsyncReader<Empty, Tensors>> reader_;
};
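/*
 * Note: AsyncCallDataServer is a small per-call state machine driven by
 * completion-queue events. CREATE registers an asynchronous SendTensors or
 * RecvTensors call, PROCESS performs one Read ()/Write () per event (spawning
 * a fresh instance for the next client on the first event), FINISH completes
 * the RPC, and DESTROY deletes the call object. The 'this' pointer is used as
 * the completion-queue tag, which is how _server_thread () locates the
 * instance to resume; AsyncCallDataClient below follows the same pattern on
 * the client side.
 */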

/** @brief Internal derived class for client */
class AsyncCallDataClient : public AsyncCallData
{
  public:
  /** @brief Constructor of AsyncCallDataClient */
  AsyncCallDataClient (AsyncServiceImplProtobuf *service,
      TensorService::Stub *stub, CompletionQueue *cq)
      : AsyncCallData (service), stub_ (stub), cq_ (cq), writer_ (nullptr),
        reader_ (nullptr)
  {
    RunState ();
  }

  /** @brief implemented RunState () of AsyncCallDataClient */
  void RunState (bool ok = true) override
  {
    if (state_ == PROCESS && !ok) {
      if (count_ != 0) {
        if (reader_.get () != nullptr)
          service_->parse_tensors (rpc_tensors_);
        state_ = FINISH;
      } else {
        return;
      }
    }

    if (state_ == CREATE) {
      if (service_->getDirection () == GRPC_DIRECTION_BUFFER_TO_TENSORS) {
        reader_ = stub_->AsyncRecvTensors (&ctx_, rpc_empty_, cq_, this);
      } else {
        writer_ = stub_->AsyncSendTensors (&ctx_, &rpc_empty_, cq_, this);
      }
      state_ = PROCESS;
    } else if (state_ == PROCESS) {
      if (reader_.get () != nullptr) {
        if (count_ != 0)
          service_->parse_tensors (rpc_tensors_);
        reader_->Read (&rpc_tensors_, this);
        /* can't read tensors yet. use the next turn */
        count_++;
      } else if (writer_.get () != nullptr) {
        Tensors tensors;
        if (service_->fill_tensors (tensors)) {
          writer_->Write (tensors, this);
          count_++;
        } else {
          writer_->WritesDone (this);
          state_ = FINISH;
        }
      }
    } else if (state_ == FINISH) {
      Status status;

      if (reader_.get () != nullptr)
        reader_->Finish (&status, this);
      if (writer_.get () != nullptr)
        writer_->Finish (&status, this);

      delete this;
    }
  }

  private:
  TensorService::Stub *stub_;
  CompletionQueue *cq_;
  ClientContext ctx_;

  std::unique_ptr<ClientAsyncWriter<Tensors>> writer_;
  std::unique_ptr<ClientAsyncReader<Tensors>> reader_;
};

/** @brief gRPC server thread */
void
AsyncServiceImplProtobuf::_server_thread ()
{
  /* spawn a new instance to serve new clients */
  set_last_call (new AsyncCallDataServer (this, completion_queue_.get ()));

  while (1) {
    void *tag;
    bool ok;

    /* 10 msec deadline to wait for the next event */
    gpr_timespec deadline = gpr_time_add (
        gpr_now (GPR_CLOCK_MONOTONIC), gpr_time_from_millis (10, GPR_TIMESPAN));

    switch (completion_queue_->AsyncNext (&tag, &ok, deadline)) {
      case CompletionQueue::GOT_EVENT:
        static_cast<AsyncCallDataServer *> (tag)->RunState (ok);
        break;
      case CompletionQueue::SHUTDOWN:
        return;
      default:
        break;
    }
  }
}

/** @brief gRPC client thread */
void
AsyncServiceImplProtobuf::_client_thread ()
{
  CompletionQueue cq;

  /* spawn a single call instance to drive this client's RPC */
  new AsyncCallDataClient (this, client_stub_.get (), &cq);

  /* loop until a stop is requested */
  while (!stop_) {
    void *tag;
    bool ok;

    /* 10 msec deadline to wait for the next event */
    gpr_timespec deadline = gpr_time_add (
        gpr_now (GPR_CLOCK_MONOTONIC), gpr_time_from_millis (10, GPR_TIMESPAN));

    switch (cq.AsyncNext (&tag, &ok, deadline)) {
      case CompletionQueue::GOT_EVENT:
        static_cast<AsyncCallDataClient *> (tag)->RunState (ok);
        if (ok == false)
          return;
        break;
      default:
        break;
    }
  }
}

/** @brief create gRPC/Protobuf instance */
extern "C" void *
create_instance (const grpc_config *config)
{
  if (config->is_blocking)
    return new SyncServiceImplProtobuf (config);
  else
    return new AsyncServiceImplProtobuf (config);
}
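/**
 * @note Usage sketch (illustrative only): this factory is presumably looked
 *       up and invoked by the gRPC-capable NNStreamer elements. Only
 *       is_blocking is shown below; the remaining grpc_config fields are
 *       declared in the common gRPC support header and are assumed to be
 *       filled in by the caller.
 * @code
 *   grpc_config config = { };
 *   config.is_blocking = TRUE;   // selects SyncServiceImplProtobuf
 *   // ... fill in the remaining fields (direction, callback, address, ...)
 *   void *instance = create_instance (&config);
 *   // the caller then drives the instance through the NNStreamerRPC interface
 * @endcode
 */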