1use std::{
40 collections::{HashMap, HashSet},
41 future::Future,
42 pin::Pin,
43 sync::Arc,
44 task::Poll,
45};
46
47use dashmap::DashMap;
48use tonic::{Request, Response, Status, transport::Channel};
49use tower::Service;
50
/// Port of the trace2e gRPC API.
///
/// Outgoing machine-to-machine connections target this port on the remote host.
pub const DEFAULT_GRPC_PORT: u16 = 50_051;
53
54pub mod proto {
56 tonic::include_proto!("trace2e");
57 pub mod primitives {
58 tonic::include_proto!("trace2e.primitives");
59 }
60 pub mod messages {
61 tonic::include_proto!("trace2e.messages");
62 }
63
64 pub const MIDDLEWARE_DESCRIPTOR_SET: &[u8] = include_bytes!("../../trace2e_descriptor.bin");
66}
67
68use crate::{
69 traceability::{
70 api::{M2mRequest, M2mResponse, P2mRequest, P2mResponse},
71 core::compliance::{ConfidentialityPolicy, Policy},
72 error::TraceabilityError,
73 naming::{Fd, File, Process, Resource, Stream},
74 },
75 transport::eval_remote_ip,
76};
77
78impl From<TraceabilityError> for Status {
80 fn from(error: TraceabilityError) -> Self {
81 Status::internal(error.to_string())
82 }
83}
84
85#[derive(Default, Clone)]
103pub struct M2mGrpc {
104 connected_remotes: Arc<DashMap<String, proto::trace2e_grpc_client::Trace2eGrpcClient<Channel>>>,
106}
107
108impl M2mGrpc {
109 async fn connect_remote(
126 &self,
127 remote_ip: String,
128 ) -> Result<proto::trace2e_grpc_client::Trace2eGrpcClient<Channel>, TraceabilityError> {
129 match proto::trace2e_grpc_client::Trace2eGrpcClient::connect(format!(
130 "{remote_ip}:{DEFAULT_GRPC_PORT}"
131 ))
132 .await
133 {
134 Ok(client) => {
135 self.connected_remotes.insert(remote_ip, client.clone());
136 Ok(client)
137 }
138 Err(_) => Err(TraceabilityError::TransportFailedToContactRemote(remote_ip)),
139 }
140 }
141
142 async fn get_client(
152 &self,
153 remote_ip: String,
154 ) -> Option<proto::trace2e_grpc_client::Trace2eGrpcClient<Channel>> {
155 self.connected_remotes.get(&remote_ip).map(|c| c.to_owned())
156 }
157
158 async fn get_client_or_connect(
176 &self,
177 remote_ip: String,
178 ) -> Result<proto::trace2e_grpc_client::Trace2eGrpcClient<Channel>, TraceabilityError> {
179 match self.get_client(remote_ip.clone()).await {
180 Some(client) => Ok(client),
181 None => self.connect_remote(remote_ip).await,
182 }
183 }
184}
185
186impl Service<M2mRequest> for M2mGrpc {
187 type Response = M2mResponse;
188 type Error = TraceabilityError;
189 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
190
191 fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
192 Poll::Ready(Ok(()))
193 }
194
195 fn call(&mut self, request: M2mRequest) -> Self::Future {
196 let this = self.clone();
197 Box::pin(async move {
198 let remote_ip = eval_remote_ip(request.clone())?;
199 let mut client = this.get_client_or_connect(remote_ip.clone()).await?;
200 match request {
201 M2mRequest::GetDestinationCompliance { source, destination } => {
202 let proto_req = proto::messages::GetDestinationCompliance {
204 source: Some(source.into()),
205 destination: Some(destination.into()),
206 };
207
208 let response = client
210 .m2m_destination_compliance(Request::new(proto_req))
211 .await
212 .map_err(|_| TraceabilityError::TransportFailedToContactRemote(remote_ip))?
213 .into_inner();
214 Ok(M2mResponse::DestinationCompliance(
215 response.policy.map(|policy| policy.into()).unwrap_or_default(),
216 ))
217 }
218 M2mRequest::GetSourceCompliance { resources, .. } => {
219 let proto_req = proto::messages::GetSourceCompliance {
221 resources: resources.into_iter().map(|r| r.into()).collect(),
222 };
223
224 let response = client
226 .m2m_source_compliance(Request::new(proto_req))
227 .await
228 .map_err(|_| TraceabilityError::TransportFailedToContactRemote(remote_ip))?
229 .into_inner();
230 Ok(M2mResponse::SourceCompliance(
231 response
232 .policies
233 .into_iter()
234 .map(|policy| {
235 (
236 policy.resource.map(|r| r.into()).unwrap_or_default(),
237 policy.policy.map(|p| p.into()).unwrap_or_default(),
238 )
239 })
240 .collect(),
241 ))
242 }
243 M2mRequest::UpdateProvenance { source_prov, destination } => {
244 let proto_req = proto::messages::UpdateProvenance {
246 source_prov: source_prov.into_iter().map(|s| s.into()).collect(),
247 destination: Some(destination.into()),
248 };
249
250 client.m2m_update_provenance(Request::new(proto_req)).await.map_err(|_| {
252 TraceabilityError::TransportFailedToContactRemote(remote_ip)
253 })?;
254
255 Ok(M2mResponse::Ack)
256 }
257 }
258 })
259 }
260}
261
/// Server-side router of the trace2e gRPC API.
///
/// Process-to-middleware (P2M) calls are dispatched to `p2m`; machine-to-machine (M2M)
/// calls are dispatched to `m2m`.
pub struct Trace2eRouter<P2mApi, M2mApi> {
    /// Service answering P2M requests.
    p2m: P2mApi,
    /// Service answering M2M requests.
    m2m: M2mApi,
}
284
285impl<P2mApi, M2mApi> Trace2eRouter<P2mApi, M2mApi> {
286 pub fn new(p2m: P2mApi, m2m: M2mApi) -> Self {
293 Self { p2m, m2m }
294 }
295}
296
297#[tonic::async_trait]
304impl<P2mApi, M2mApi> proto::trace2e_grpc_server::Trace2eGrpc for Trace2eRouter<P2mApi, M2mApi>
305where
306 P2mApi: Service<P2mRequest, Response = P2mResponse, Error = TraceabilityError>
307 + Clone
308 + Sync
309 + Send
310 + 'static,
311 P2mApi::Future: Send,
312 M2mApi: Service<M2mRequest, Response = M2mResponse, Error = TraceabilityError>
313 + Clone
314 + Sync
315 + Send
316 + 'static,
317 M2mApi::Future: Send,
318{
319 async fn p2m_local_enroll(
324 &self,
325 request: Request<proto::messages::LocalCt>,
326 ) -> Result<Response<proto::messages::Ack>, Status> {
327 let req = request.into_inner();
328 let mut p2m = self.p2m.clone();
329 match p2m
330 .call(P2mRequest::LocalEnroll {
331 pid: req.process_id,
332 fd: req.file_descriptor,
333 path: req.path,
334 })
335 .await?
336 {
337 P2mResponse::Ack => Ok(Response::new(proto::messages::Ack {})),
338 _ => Err(Status::internal("Internal traceability API error")),
339 }
340 }
341
342 async fn p2m_remote_enroll(
347 &self,
348 request: Request<proto::messages::RemoteCt>,
349 ) -> Result<Response<proto::messages::Ack>, Status> {
350 let req = request.into_inner();
351 let mut p2m = self.p2m.clone();
352 match p2m
353 .call(P2mRequest::RemoteEnroll {
354 pid: req.process_id,
355 fd: req.file_descriptor,
356 local_socket: req.local_socket,
357 peer_socket: req.peer_socket,
358 })
359 .await?
360 {
361 P2mResponse::Ack => Ok(Response::new(proto::messages::Ack {})),
362 _ => Err(Status::internal("Internal traceability API error")),
363 }
364 }
365
366 async fn p2m_io_request(
371 &self,
372 request: Request<proto::messages::IoInfo>,
373 ) -> Result<Response<proto::messages::Grant>, Status> {
374 let req = request.into_inner();
375 let mut p2m = self.p2m.clone();
376 match p2m
377 .call(P2mRequest::IoRequest {
378 pid: req.process_id,
379 fd: req.file_descriptor,
380 output: req.flow == proto::primitives::Flow::Output as i32,
381 })
382 .await?
383 {
384 P2mResponse::Grant(id) => {
385 Ok(Response::new(proto::messages::Grant { id: id.to_string() }))
386 }
387 _ => Err(Status::internal("Internal traceability API error")),
388 }
389 }
390
391 async fn p2m_io_report(
396 &self,
397 request: Request<proto::messages::IoResult>,
398 ) -> Result<Response<proto::messages::Ack>, Status> {
399 let req = request.into_inner();
400 let mut p2m = self.p2m.clone();
401 match p2m
402 .call(P2mRequest::IoReport {
403 pid: req.process_id,
404 fd: req.file_descriptor,
405 grant_id: req.grant_id.parse::<u128>().unwrap_or_default(),
406 result: req.result,
407 })
408 .await?
409 {
410 P2mResponse::Ack => Ok(Response::new(proto::messages::Ack {})),
411 _ => Err(Status::internal("Internal traceability API error")),
412 }
413 }
414
415 async fn m2m_destination_compliance(
420 &self,
421 request: Request<proto::messages::GetDestinationCompliance>,
422 ) -> Result<Response<proto::messages::DestinationCompliance>, Status> {
423 let req = request.into_inner();
424 let mut m2m = self.m2m.clone();
425 match m2m.call(req.into()).await? {
426 M2mResponse::DestinationCompliance(policy) => Ok(Response::new(policy.into())),
427 _ => Err(Status::internal("Internal traceability API error")),
428 }
429 }
430
431 async fn m2m_source_compliance(
436 &self,
437 request: Request<proto::messages::GetSourceCompliance>,
438 ) -> Result<Response<proto::messages::SourceCompliance>, Status> {
439 let req = request.into_inner();
440 let mut m2m = self.m2m.clone();
441 match m2m.call(req.into()).await? {
442 M2mResponse::SourceCompliance(policies) => Ok(Response::new(policies.into())),
443 _ => Err(Status::internal("Internal traceability API error")),
444 }
445 }
446
447 async fn m2m_update_provenance(
453 &self,
454 request: Request<proto::messages::UpdateProvenance>,
455 ) -> Result<Response<proto::messages::Ack>, Status> {
456 let req = request.into_inner();
457 let mut m2m = self.m2m.clone();
458 match m2m.call(req.into()).await? {
459 M2mResponse::Ack => Ok(Response::new(proto::messages::Ack {})),
460 _ => Err(Status::internal("Internal traceability API error")),
461 }
462 }
463}
464
465impl From<proto::messages::GetDestinationCompliance> for M2mRequest {
473 fn from(req: proto::messages::GetDestinationCompliance) -> Self {
474 M2mRequest::GetDestinationCompliance {
475 source: req.source.map(|s| s.into()).unwrap_or_default(),
476 destination: req.destination.map(|d| d.into()).unwrap_or_default(),
477 }
478 }
479}
480
481impl From<proto::messages::GetSourceCompliance> for M2mRequest {
483 fn from(req: proto::messages::GetSourceCompliance) -> Self {
484 M2mRequest::GetSourceCompliance {
485 authority_ip: String::new(), resources: req.resources.into_iter().map(|r| r.into()).collect(),
487 }
488 }
489}
490
491impl From<proto::primitives::MappedPolicy> for (Resource, Policy) {
493 fn from(policy: proto::primitives::MappedPolicy) -> Self {
494 (
495 policy.resource.map(|r| r.into()).unwrap_or_default(),
496 policy.policy.map(|p| p.into()).unwrap_or_default(),
497 )
498 }
499}
500
501impl From<proto::primitives::References> for (String, HashSet<Resource>) {
503 fn from(references: proto::primitives::References) -> Self {
504 (references.node, references.resources.into_iter().map(|r| r.into()).collect())
505 }
506}
507
508impl From<proto::messages::UpdateProvenance> for M2mRequest {
510 fn from(req: proto::messages::UpdateProvenance) -> Self {
511 M2mRequest::UpdateProvenance {
512 source_prov: req.source_prov.into_iter().map(|s| s.into()).collect(),
513 destination: req.destination.map(|d| d.into()).unwrap_or_default(),
514 }
515 }
516}
517
518impl From<Policy> for proto::messages::DestinationCompliance {
520 fn from(policy: Policy) -> Self {
521 proto::messages::DestinationCompliance { policy: Some(policy.into()) }
522 }
523}
524
525impl From<HashMap<Resource, Policy>> for proto::messages::SourceCompliance {
527 fn from(policies: HashMap<Resource, Policy>) -> Self {
528 proto::messages::SourceCompliance {
529 policies: policies
530 .into_iter()
531 .map(|(resource, policy)| proto::primitives::MappedPolicy {
532 resource: Some(resource.into()),
533 policy: Some(policy.into()),
534 })
535 .collect(),
536 }
537 }
538}
539
540impl From<(String, HashSet<Resource>)> for proto::primitives::References {
542 fn from((node, resources): (String, HashSet<Resource>)) -> Self {
543 proto::primitives::References {
544 node,
545 resources: resources.into_iter().map(|r| r.into()).collect(),
546 }
547 }
548}
549
550impl From<(Resource, Policy)> for proto::primitives::MappedPolicy {
552 fn from((resource, policy): (Resource, Policy)) -> Self {
553 proto::primitives::MappedPolicy {
554 resource: Some(resource.into()),
555 policy: Some(policy.into()),
556 }
557 }
558}
559
560impl From<proto::primitives::Resource> for Resource {
562 fn from(proto_resource: proto::primitives::Resource) -> Self {
563 match proto_resource.resource {
564 Some(proto::primitives::resource::Resource::Fd(fd)) => Resource::Fd(fd.into()),
565 Some(proto::primitives::resource::Resource::Process(process)) => {
566 Resource::Process(process.into())
567 }
568 None => Resource::None,
569 }
570 }
571}
572
573impl From<Resource> for proto::primitives::Resource {
575 fn from(resource: Resource) -> Self {
576 match resource {
577 Resource::Fd(fd) => proto::primitives::Resource {
578 resource: Some(proto::primitives::resource::Resource::Fd(fd.into())),
579 },
580 Resource::Process(process) => proto::primitives::Resource {
581 resource: Some(proto::primitives::resource::Resource::Process(process.into())),
582 },
583 Resource::None => proto::primitives::Resource { resource: None },
584 }
585 }
586}
587
588impl From<proto::primitives::Fd> for Fd {
589 fn from(proto_fd: proto::primitives::Fd) -> Self {
590 match proto_fd.fd {
591 Some(proto::primitives::fd::Fd::File(file)) => Fd::File(file.into()),
592 Some(proto::primitives::fd::Fd::Stream(stream)) => Fd::Stream(stream.into()),
593 None => Fd::File(File { path: String::new() }), }
595 }
596}
597
598impl From<Fd> for proto::primitives::Fd {
599 fn from(fd: Fd) -> Self {
600 match fd {
601 Fd::File(file) => {
602 proto::primitives::Fd { fd: Some(proto::primitives::fd::Fd::File(file.into())) }
603 }
604 Fd::Stream(stream) => {
605 proto::primitives::Fd { fd: Some(proto::primitives::fd::Fd::Stream(stream.into())) }
606 }
607 }
608 }
609}
610
611impl From<proto::primitives::File> for File {
612 fn from(proto_file: proto::primitives::File) -> Self {
613 File { path: proto_file.path }
614 }
615}
616
617impl From<File> for proto::primitives::File {
618 fn from(file: File) -> Self {
619 proto::primitives::File { path: file.path }
620 }
621}
622
623impl From<proto::primitives::Stream> for Stream {
624 fn from(proto_stream: proto::primitives::Stream) -> Self {
625 Stream { local_socket: proto_stream.local_socket, peer_socket: proto_stream.peer_socket }
626 }
627}
628
629impl From<Stream> for proto::primitives::Stream {
630 fn from(stream: Stream) -> Self {
631 proto::primitives::Stream {
632 local_socket: stream.local_socket,
633 peer_socket: stream.peer_socket,
634 }
635 }
636}
637
638impl From<proto::primitives::Process> for Process {
639 fn from(proto_process: proto::primitives::Process) -> Self {
640 Process {
641 pid: proto_process.pid,
642 starttime: proto_process.starttime,
643 exe_path: proto_process.exe_path,
644 }
645 }
646}
647
648impl From<Process> for proto::primitives::Process {
649 fn from(process: Process) -> Self {
650 proto::primitives::Process {
651 pid: process.pid,
652 starttime: process.starttime,
653 exe_path: process.exe_path,
654 }
655 }
656}
657
658impl From<Policy> for proto::primitives::Policy {
659 fn from(policy: Policy) -> Self {
660 proto::primitives::Policy {
661 confidentiality: match policy.is_confidential() {
662 false => proto::primitives::Confidentiality::Public as i32,
663 true => proto::primitives::Confidentiality::Secret as i32,
664 },
665 integrity: policy.get_integrity(),
666 deleted: policy.is_deleted(),
667 consent: policy.get_consent(),
668 }
669 }
670}
671
672impl From<proto::primitives::Policy> for Policy {
673 fn from(proto_policy: proto::primitives::Policy) -> Self {
674 Policy::new(
675 match proto_policy.confidentiality {
676 x if x == proto::primitives::Confidentiality::Secret as i32 => {
677 ConfidentialityPolicy::Secret
678 }
679 _ => ConfidentialityPolicy::Public,
680 },
681 proto_policy.integrity,
682 proto_policy.deleted.into(),
683 proto_policy.consent,
684 )
685 }
686}