LCOV - code coverage report
Current view: top level - nnstreamer-2.4.2/ext/nnstreamer/extra - nnstreamer_grpc_protobuf.cc (source / functions) Coverage Total Hit
Test: nnstreamer 2.4.2-0 nnstreamer/nnstreamer#eca68b8d050408568af95d831a8eef62aaee7784 Lines: 82.7 % 243 201
Test Date: 2025-03-14 05:36:58 Functions: 100.0 % 30 30

            Line data    Source code
       1              : /* SPDX-License-Identifier: LGPL-2.1-only */
       2              : /**
       3              :  * GStreamer / NNStreamer gRPC/protobuf support
       4              :  * Copyright (C) 2020 Dongju Chae <dongju.chae@samsung.com>
       5              :  */
       6              : /**
       7              :  * @file    nnstreamer_grpc_protobuf.cc
       8              :  * @date    21 Oct 2020
       9              :  * @brief   gRPC/Protobuf wrappers for nnstreamer
      10              :  * @see     https://github.com/nnstreamer/nnstreamer
      11              :  * @author  Dongju Chae <dongju.chae@samsung.com>
      12              :  * @bug     No known bugs except for NYI items
      13              :  */
      14              : 
      15              : #include "nnstreamer_grpc_protobuf.h"
      16              : 
      17              : #include <nnstreamer_log.h>
      18              : #include <nnstreamer_plugin_api.h>
      19              : #include <nnstreamer_util.h>
      20              : 
      21              : #include <thread>
      22              : 
      23              : #include <grpcpp/channel.h>
      24              : #include <grpcpp/client_context.h>
      25              : #include <grpcpp/create_channel.h>
      26              : #include <grpcpp/grpcpp.h>
      27              : #include <grpcpp/security/credentials.h>
      28              : 
      29              : #include <gst/base/gstdataqueue.h>
      30              : 
      31              : using namespace grpc;
      32              : 
      33              : /** @brief constructor */
      34           14 : ServiceImplProtobuf::ServiceImplProtobuf (const grpc_config *config)
      35           14 :     : NNStreamerRPC (config), client_stub_ (nullptr)
      36              : {
      37           14 : }
      38              : 
      39              : /** @brief parse tensors and deliver the buffer via callback */
      40              : void
      41           62 : ServiceImplProtobuf::parse_tensors (Tensors &tensors)
      42              : {
      43              :   GstBuffer *buffer;
      44              : 
      45           62 :   _get_buffer_from_tensors (tensors, &buffer);
      46              : 
      47           62 :   if (cb_)
      48           62 :     cb_ (cb_data_, buffer);
      49              :   else
      50            0 :     gst_buffer_unref (buffer);
      51           62 : }
      52              : 
      53              : /** @brief fill tensors from the buffer */
      54              : gboolean
      55           51 : ServiceImplProtobuf::fill_tensors (Tensors &tensors)
      56              : {
      57              :   GstDataQueueItem *item;
      58              : 
      59           51 :   if (!gst_data_queue_pop (queue_, &item))
      60            6 :     return FALSE;
      61              : 
      62           45 :   _get_tensors_from_buffer (GST_BUFFER (item->object), tensors);
      63              : 
      64           45 :   GDestroyNotify destroy = (item->destroy) ? item->destroy : g_free;
      65           45 :   destroy (item);
      66              : 
      67           45 :   return TRUE;
      68              : }
      69              : 
      70              : /** @brief read tensors and invoke the registered callback */
      71              : template <typename T>
      72              : Status
      73            4 : ServiceImplProtobuf::_read_tensors (T reader)
      74              : {
      75           30 :   while (1) {
      76           34 :     Tensors tensors;
      77              : 
      78           34 :     if (!reader->Read (&tensors))
      79            4 :       break;
      80              : 
      81           30 :     parse_tensors (tensors);
      82              :   }
      83              : 
      84            4 :   return Status::OK;
      85              : }
      86              : 
      87              : /** @brief obtain tensors from data queue and send them over gRPC */
      88              : template <typename T>
      89              : Status
      90            4 : ServiceImplProtobuf::_write_tensors (T writer)
      91              : {
      92           30 :   while (1) {
      93           34 :     Tensors tensors;
      94              : 
      95              :     /* until flushing */
      96           34 :     if (!fill_tensors (tensors))
      97            4 :       break;
      98              : 
      99           30 :     writer->Write (tensors);
     100              :   }
     101              : 
     102            4 :   return Status::OK;
     103              : }
     104              : 
     105              : /** @brief convert tensors to buffer */
     106              : void
     107           62 : ServiceImplProtobuf::_get_buffer_from_tensors (Tensors &tensors, GstBuffer **buffer)
     108              : {
     109           62 :   guint num_tensor = tensors.num_tensor ();
     110              :   GstTensorInfo *_info;
     111              :   GstMemory *memory;
     112              : 
     113           62 :   *buffer = gst_buffer_new ();
     114              : 
     115          124 :   for (guint i = 0; i < num_tensor; i++) {
     116           62 :     const Tensor *tensor = &tensors.tensor (i);
     117           62 :     const void *data = tensor->data ().c_str ();
     118           62 :     gsize size = tensor->data ().length ();
     119           62 :     gpointer new_data = _g_memdup (data, size);
     120              : 
     121           62 :     _info = gst_tensors_info_get_nth_info (&config_->info, i);
     122              : 
     123           62 :     memory = gst_memory_new_wrapped (
     124              :         (GstMemoryFlags) 0, new_data, size, 0, size, new_data, g_free);
     125           62 :     gst_tensor_buffer_append_memory (*buffer, memory, _info);
     126              :   }
     127           62 : }
     128              : 
     129              : /** @brief convert buffer to tensors */
     130              : void
     131           45 : ServiceImplProtobuf::_get_tensors_from_buffer (GstBuffer *buffer, Tensors &tensors)
     132              : {
     133              :   Tensors::frame_rate *fr;
     134              :   GstTensorInfo *_info;
     135              :   GstMemory *mem;
     136              :   GstMapInfo map;
     137              : 
     138           45 :   tensors.set_num_tensor (config_->info.num_tensors);
     139              : 
     140           45 :   fr = tensors.mutable_fr ();
     141           45 :   fr->set_rate_n (config_->rate_n);
     142           45 :   fr->set_rate_d (config_->rate_d);
     143              : 
     144           90 :   for (guint i = 0; i < config_->info.num_tensors; i++) {
     145           45 :     nnstreamer::protobuf::Tensor *tensor = tensors.add_tensor ();
     146              : 
     147           45 :     _info = gst_tensors_info_get_nth_info (&config_->info, i);
     148              : 
     149           45 :     mem = gst_tensor_buffer_get_nth_memory (buffer, i);
     150           45 :     g_assert (gst_memory_map (mem, &map, GST_MAP_READ));
     151              : 
     152              :     /* set tensor info */
     153              :     tensor->set_name ("Anonymous");
     154           45 :     tensor->set_type ((Tensor::Tensor_type) _info->type);
     155              : 
     156          765 :     for (guint j = 0; j < NNS_TENSOR_RANK_LIMIT; j++)
     157          720 :       tensor->add_dimension (_info->dimension[j]);
     158              : 
     159           45 :     tensor->set_data (map.data, map.size);
     160              : 
     161           45 :     gst_memory_unmap (mem, &map);
     162           45 :     gst_memory_unref (mem);
     163              :   }
     164           45 : }
     165              : 
     166              : /** @brief Constructor of SyncServiceImplProtobuf */
     167            8 : SyncServiceImplProtobuf::SyncServiceImplProtobuf (const grpc_config *config)
     168            8 :     : ServiceImplProtobuf (config)
     169              : {
     170            8 : }
     171              : 
     172              : /** @brief client-to-server streaming: a client sends tensors */
     173              : Status
     174            2 : SyncServiceImplProtobuf::SendTensors (
     175              :     ServerContext *context, ServerReader<Tensors> *reader, Empty *reply)
     176              : {
     177            2 :   return _read_tensors (reader);
     178              : }
     179              : 
     180              : /** @brief server-to-client streaming: a client receives tensors */
     181              : Status
     182            2 : SyncServiceImplProtobuf::RecvTensors (
     183              :     ServerContext *context, const Empty *request, ServerWriter<Tensors> *writer)
     184              : {
     185            2 :   return _write_tensors (writer);
     186              : }
     187              : 
     188              : /** @brief start gRPC server handling protobuf */
     189              : gboolean
     190            4 : SyncServiceImplProtobuf::start_server (std::string address)
     191              : {
     192              :   /* listen on the given address without any authentication mechanism */
     193            4 :   ServerBuilder builder;
     194            4 :   builder.AddListeningPort (address, grpc::InsecureServerCredentials (), &port_);
     195            4 :   builder.RegisterService (this);
     196              : 
     197              :   /* start the server */
     198            4 :   server_instance_ = builder.BuildAndStart ();
     199            4 :   if (server_instance_.get () == nullptr)
     200            0 :     return FALSE;
     201              : 
     202            4 :   return TRUE;
     203            4 : }
     204              : 
     205              : /** @brief start gRPC client handling protobuf */
     206              : gboolean
     207            4 : SyncServiceImplProtobuf::start_client (std::string address)
     208              : {
     209              :   /* create a gRPC channel */
     210              :   std::shared_ptr<Channel> channel
     211            4 :       = grpc::CreateChannel (address, grpc::InsecureChannelCredentials ());
     212              : 
     213              :   /* connect the server */
     214              :   try {
     215            4 :     client_stub_ = TensorService::NewStub (channel);
     216            0 :   } catch (...) {
     217            0 :     ml_loge ("Failed to connect the server");
     218            0 :     return FALSE;
     219            0 :   }
     220              : 
     221            8 :   worker_ = std::thread ([this] { this->_client_thread (); });
     222              : 
     223            4 :   return TRUE;
     224            4 : }
     225              : 
     226              : /** @brief gRPC client thread */
     227              : void
     228            4 : SyncServiceImplProtobuf::_client_thread ()
     229              : {
     230            4 :   ClientContext context;
     231            4 :   Empty empty;
     232              : 
     233            4 :   if (direction_ == GRPC_DIRECTION_TENSORS_TO_BUFFER) {
     234              :     /* initiate the RPC call */
     235              :     std::unique_ptr<ClientWriter<Tensors>> writer (
     236            2 :         client_stub_->SendTensors (&context, &empty));
     237              : 
     238            2 :     _write_tensors (writer.get ());
     239              : 
     240            2 :     writer->WritesDone ();
     241            2 :     writer->Finish ();
     242            4 :   } else if (direction_ == GRPC_DIRECTION_BUFFER_TO_TENSORS) {
     243            2 :     Tensors tensors;
     244              : 
     245              :     /* initiate the RPC call */
     246              :     std::unique_ptr<ClientReader<Tensors>> reader (
     247            2 :         client_stub_->RecvTensors (&context, empty));
     248              : 
     249            2 :     _read_tensors (reader.get ());
     250              : 
     251            2 :     reader->Finish ();
     252            2 :   } else {
     253            0 :     g_assert (0); /* internal logic error */
     254              :   }
     255            8 : }
     256              : 
     257              : /** @brief Constructor of AsyncServiceImplProtobuf */
     258            6 : AsyncServiceImplProtobuf::AsyncServiceImplProtobuf (const grpc_config *config)
     259            6 :     : ServiceImplProtobuf (config), last_call_ (nullptr)
     260              : {
     261            6 : }
     262              : 
     263              : /** @brief Destructor of AsyncServiceImplProtobuf */
     264            4 : AsyncServiceImplProtobuf::~AsyncServiceImplProtobuf ()
     265              : {
     266            2 :   if (last_call_)
     267            0 :     delete last_call_;
     268            4 : }
     269              : 
     270              : /** @brief start gRPC server handling protobuf */
     271              : gboolean
     272            2 : AsyncServiceImplProtobuf::start_server (std::string address)
     273              : {
     274              :   /* listen on the given address without any authentication mechanism */
     275            2 :   ServerBuilder builder;
     276            2 :   builder.AddListeningPort (address, grpc::InsecureServerCredentials (), &port_);
     277            2 :   builder.RegisterService (this);
     278              : 
     279              :   /* need to manually handle the completion queue */
     280            2 :   completion_queue_ = builder.AddCompletionQueue ();
     281              : 
     282              :   /* start the server */
     283            2 :   server_instance_ = builder.BuildAndStart ();
     284            2 :   if (server_instance_.get () == nullptr)
     285            0 :     return FALSE;
     286              : 
     287            4 :   worker_ = std::thread ([this] { this->_server_thread (); });
     288              : 
     289            2 :   return TRUE;
     290            2 : }
     291              : 
     292              : /** @brief start gRPC client handling protobuf */
     293              : gboolean
     294            4 : AsyncServiceImplProtobuf::start_client (std::string address)
     295              : {
     296              :   /* create a gRPC channel */
     297              :   std::shared_ptr<Channel> channel
     298            4 :       = grpc::CreateChannel (address, grpc::InsecureChannelCredentials ());
     299              : 
     300              :   /* connect the server */
     301              :   try {
     302            4 :     client_stub_ = TensorService::NewStub (channel);
     303            0 :   } catch (...) {
     304            0 :     ml_loge ("Failed to connect the server");
     305            0 :     return FALSE;
     306            0 :   }
     307              : 
     308            8 :   worker_ = std::thread ([this] { this->_client_thread (); });
     309              : 
     310            4 :   return TRUE;
     311            4 : }
     312              : 
     313              : /** @brief Internal derived class for server */
     314              : class AsyncCallDataServer : public AsyncCallData
     315              : {
     316              :   public:
     317              :   /** @brief Constructor of AsyncCallDataServer */
     318            4 :   AsyncCallDataServer (AsyncServiceImplProtobuf *service, ServerCompletionQueue *cq)
     319            4 :       : AsyncCallData (service), cq_ (cq), writer_ (nullptr), reader_ (nullptr)
     320              :   {
     321            4 :     RunState ();
     322            4 :   }
     323              : 
     324              :   /** @brief implemented RunState () of AsyncCallDataServer */
     325           27 :   void RunState (bool ok = true) override
     326              :   {
     327           27 :     if (state_ == PROCESS && !ok) {
     328            4 :       if (count_ != 0) {
     329            2 :         if (reader_.get () != nullptr)
     330            2 :           service_->parse_tensors (rpc_tensors_);
     331            2 :         state_ = FINISH;
     332              :       } else {
     333            2 :         return;
     334              :       }
     335              :     }
     336              : 
     337           25 :     if (state_ == CREATE) {
     338            4 :       if (service_->getDirection () == GRPC_DIRECTION_BUFFER_TO_TENSORS) {
     339            4 :         reader_.reset (new ServerAsyncReader<Empty, Tensors> (&ctx_));
     340            4 :         service_->RequestSendTensors (&ctx_, reader_.get (), cq_, cq_, this);
     341              :       } else {
     342            0 :         writer_.reset (new ServerAsyncWriter<Tensors> (&ctx_));
     343            0 :         service_->RequestRecvTensors (&ctx_, &rpc_empty_, writer_.get (), cq_, cq_, this);
     344              :       }
     345            4 :       state_ = PROCESS;
     346           21 :     } else if (state_ == PROCESS) {
     347           17 :       if (count_ == 0) {
     348              :         /* spawn a new instance to serve new clients */
     349            2 :         service_->set_last_call (new AsyncCallDataServer (service_, cq_));
     350              :       }
     351              : 
     352           17 :       if (reader_.get () != nullptr) {
     353           17 :         if (count_ != 0)
     354           15 :           service_->parse_tensors (rpc_tensors_);
     355           17 :         reader_->Read (&rpc_tensors_, this);
     356              :         /* can't read tensors yet. use the next turn */
     357           17 :         count_++;
     358            0 :       } else if (writer_.get () != nullptr) {
     359            0 :         Tensors tensors;
     360            0 :         if (service_->fill_tensors (tensors)) {
     361            0 :           writer_->Write (tensors, this);
     362            0 :           count_++;
     363              :         } else {
     364            0 :           Status status;
     365            0 :           writer_->Finish (status, this);
     366            0 :           state_ = DESTROY;
     367            0 :         }
     368            0 :       }
     369            4 :     } else if (state_ == FINISH) {
     370            2 :       if (reader_.get () != nullptr)
     371            2 :         reader_->Finish (rpc_empty_, Status::OK, this);
     372            2 :       if (writer_.get () != nullptr) {
     373            0 :         Status status;
     374            0 :         writer_->Finish (status, this);
     375            0 :       }
     376            2 :       state_ = DESTROY;
     377              :     } else {
     378            2 :       delete this;
     379              :     }
     380              :   }
     381              : 
     382              :   private:
     383              :   ServerCompletionQueue *cq_;
     384              :   ServerContext ctx_;
     385              : 
     386              :   std::unique_ptr<ServerAsyncWriter<Tensors>> writer_;
     387              :   std::unique_ptr<ServerAsyncReader<Empty, Tensors>> reader_;
     388              : };
     389              : 
     390              : /** @brief Internal derived class for client */
     391              : class AsyncCallDataClient : public AsyncCallData
     392              : {
     393              :   public:
     394              :   /** @brief Constructor of AsyncCallDataClient */
     395            4 :   AsyncCallDataClient (AsyncServiceImplProtobuf *service,
     396              :       TensorService::Stub *stub, CompletionQueue *cq)
     397            4 :       : AsyncCallData (service), stub_ (stub), cq_ (cq), writer_ (nullptr),
     398            8 :         reader_ (nullptr)
     399              :   {
     400            4 :     RunState ();
     401            4 :   }
     402              : 
     403              :   /** @brief implemented RunState () of AsyncCallDataClient */
     404           38 :   void RunState (bool ok = true) override
     405              :   {
     406           38 :     if (state_ == PROCESS && !ok) {
     407            0 :       if (count_ != 0) {
     408            0 :         if (reader_.get () != nullptr)
     409            0 :           service_->parse_tensors (rpc_tensors_);
     410            0 :         state_ = FINISH;
     411              :       } else {
     412            0 :         return;
     413              :       }
     414              :     }
     415              : 
     416           38 :     if (state_ == CREATE) {
     417            4 :       if (service_->getDirection () == GRPC_DIRECTION_BUFFER_TO_TENSORS) {
     418            2 :         reader_ = stub_->AsyncRecvTensors (&ctx_, rpc_empty_, cq_, this);
     419              :       } else {
     420            2 :         writer_ = stub_->AsyncSendTensors (&ctx_, &rpc_empty_, cq_, this);
     421              :       }
     422            4 :       state_ = PROCESS;
     423           34 :     } else if (state_ == PROCESS) {
     424           34 :       if (reader_.get () != nullptr) {
     425           17 :         if (count_ != 0)
     426           15 :           service_->parse_tensors (rpc_tensors_);
     427           17 :         reader_->Read (&rpc_tensors_, this);
     428              :         /* can't read tensors yet. use the next turn */
     429           17 :         count_++;
     430           17 :       } else if (writer_.get () != nullptr) {
     431           17 :         Tensors tensors;
     432           17 :         if (service_->fill_tensors (tensors)) {
     433           15 :           writer_->Write (tensors, this);
     434           15 :           count_++;
     435              :         } else {
     436            2 :           writer_->WritesDone (this);
     437            2 :           state_ = FINISH;
     438              :         }
     439           17 :       }
     440            0 :     } else if (state_ == FINISH) {
     441            0 :       Status status;
     442              : 
     443            0 :       if (reader_.get () != nullptr)
     444            0 :         reader_->Finish (&status, this);
     445            0 :       if (writer_.get () != nullptr)
     446            0 :         writer_->Finish (&status, this);
     447              : 
     448            0 :       delete this;
     449            0 :     }
     450              :   }
     451              : 
     452              :   private:
     453              :   TensorService::Stub *stub_;
     454              :   CompletionQueue *cq_;
     455              :   ClientContext ctx_;
     456              : 
     457              :   std::unique_ptr<ClientAsyncWriter<Tensors>> writer_;
     458              :   std::unique_ptr<ClientAsyncReader<Tensors>> reader_;
     459              : };
     460              : 
     461              : /** @brief gRPC client thread */
     462              : void
     463            2 : AsyncServiceImplProtobuf::_server_thread ()
     464              : {
     465              :   /* spawn a new instance to server new clients */
     466            2 :   set_last_call (new AsyncCallDataServer (this, completion_queue_.get ()));
     467              : 
     468              :   while (1) {
     469              :     void *tag;
     470              :     bool ok;
     471              : 
     472              :     /* 10 msec deadline to wait the next event */
     473          567 :     gpr_timespec deadline = gpr_time_add (
     474              :         gpr_now (GPR_CLOCK_MONOTONIC), gpr_time_from_millis (10, GPR_TIMESPAN));
     475              : 
     476          567 :     switch (completion_queue_->AsyncNext (&tag, &ok, deadline)) {
     477           23 :       case CompletionQueue::GOT_EVENT:
     478           23 :         static_cast<AsyncCallDataServer *> (tag)->RunState (ok);
     479           23 :         break;
     480            2 :       case CompletionQueue::SHUTDOWN:
     481            2 :         return;
     482          542 :       default:
     483          542 :         break;
     484              :     }
     485          565 :   }
     486              : }
     487              : 
     488              : /** @brief gRPC client thread */
     489              : void
     490            4 : AsyncServiceImplProtobuf::_client_thread ()
     491              : {
     492            4 :   CompletionQueue cq;
     493              : 
     494              :   /* spawn a new instance to serve new clients */
     495            4 :   new AsyncCallDataClient (this, client_stub_.get (), &cq);
     496              : 
     497              :   /* until the stop is called */
     498          160 :   while (!stop_) {
     499              :     void *tag;
     500              :     bool ok;
     501              : 
     502              :     /* 10 msec deadline to wait the next event */
     503          156 :     gpr_timespec deadline = gpr_time_add (
     504              :         gpr_now (GPR_CLOCK_MONOTONIC), gpr_time_from_millis (10, GPR_TIMESPAN));
     505              : 
     506          156 :     switch (cq.AsyncNext (&tag, &ok, deadline)) {
     507           34 :       case CompletionQueue::GOT_EVENT:
     508           34 :         static_cast<AsyncCallDataClient *> (tag)->RunState (ok);
     509           34 :         if (ok == false)
     510            0 :           return;
     511           34 :         break;
     512          122 :       default:
     513          122 :         break;
     514              :     }
     515              :   }
     516            4 : }
     517              : 
     518              : /** @brief create gRPC/Protobuf instance */
     519              : extern "C" void *
     520           14 : create_instance (const grpc_config *config)
     521              : {
     522           14 :   if (config->is_blocking)
     523            8 :     return new SyncServiceImplProtobuf (config);
     524              :   else
     525            6 :     return new AsyncServiceImplProtobuf (config);
     526              : }
        

Generated by: LCOV version 2.0-1