trace2e_core/traceability/core/
consent.rs1use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
3
4use dashmap::{DashMap, Entry};
5use tokio::{sync::broadcast, time::Duration};
6use tower::Service;
7#[cfg(feature = "trace2e_tracing")]
8use tracing::info;
9
10use crate::traceability::{
11 api::{ConsentRequest, ConsentResponse},
12 error::TraceabilityError,
13 naming::Resource,
14};
15
16#[derive(Debug, Clone, Eq, PartialEq, Hash)]
17struct ConsentKey(Resource, Option<String>, Resource);
18
19#[derive(Debug, Clone)]
20enum ConsentState {
21 Pending { tx: broadcast::Sender<bool> },
22 Decided(bool),
23}
24
25#[derive(Default, Debug, Clone)]
26pub struct ConsentService {
27 timeout: u64,
28 states: Arc<DashMap<ConsentKey, ConsentState>>,
30}
31
32impl ConsentService {
33 pub fn new(timeout_ms: u64) -> Self {
37 Self { timeout: timeout_ms, states: Arc::new(DashMap::new()) }
38 }
39
40 async fn get_consent(
42 &self,
43 source: Resource,
44 destination: (Option<String>, Resource),
45 ) -> Result<bool, TraceabilityError> {
46 let key = ConsentKey(source, destination.0, destination.1);
47 match self.states.entry(key) {
48 Entry::Occupied(occ) => match occ.get() {
49 ConsentState::Decided(consent) => Ok(*consent),
50 ConsentState::Pending { tx } => {
51 let mut rx = tx.subscribe();
52 drop(occ);
54 if self.timeout == 0 {
55 rx.recv().await.map_err(|_| TraceabilityError::InternalTrace2eError)
56 } else {
57 match tokio::time::timeout(Duration::from_millis(self.timeout), rx.recv())
58 .await
59 {
60 Ok(res) => res.map_err(|_| TraceabilityError::InternalTrace2eError),
61 Err(_) => Err(TraceabilityError::ConsentRequestTimeout),
62 }
63 }
64 }
65 },
66 Entry::Vacant(vac) => {
67 let (tx, mut rx) = broadcast::channel(16);
68 vac.insert(ConsentState::Pending { tx });
69 if self.timeout == 0 {
70 rx.recv().await.map_err(|_| TraceabilityError::InternalTrace2eError)
71 } else {
72 match tokio::time::timeout(Duration::from_millis(self.timeout), rx.recv()).await
73 {
74 Ok(res) => res.map_err(|_| TraceabilityError::InternalTrace2eError),
75 Err(_) => Err(TraceabilityError::ConsentRequestTimeout),
76 }
77 }
78 }
79 }
80 }
81
82 #[allow(clippy::type_complexity)]
84 fn pending_requests(
85 &self,
86 ) -> Vec<((Resource, Option<String>, Resource), broadcast::Sender<bool>)> {
87 self.states
88 .iter()
89 .filter_map(|entry| match entry.value() {
90 ConsentState::Pending { tx } => {
91 let ConsentKey(src, node, dst) = entry.key().clone();
92 Some(((src, node, dst), tx.clone()))
93 }
94 ConsentState::Decided(_) => None,
95 })
96 .collect()
97 }
98
99 fn set_consent(
101 &self,
102 source: Resource,
103 destination: (Option<String>, Resource),
104 consent: bool,
105 ) {
106 let key = ConsentKey(source, destination.0, destination.1);
107 match self.states.entry(key.clone()) {
109 Entry::Occupied(entry) => {
110 match entry.get() {
111 ConsentState::Decided(prev_consent) if *prev_consent != consent => {
112 if let Some(mut entry) = self.states.get_mut(&key) {
113 *entry = ConsentState::Decided(consent);
114 }
115 }
116 ConsentState::Pending { tx } => {
117 if let Some(mut entry) = self.states.get_mut(&key) {
118 tx.send(consent).unwrap();
120 *entry = ConsentState::Decided(consent);
121 }
122 }
123 _ => {
124 }
126 }
127 }
128 Entry::Vacant(vac) => {
129 vac.insert(ConsentState::Decided(consent));
131 }
132 }
133 }
134}
135
136impl Service<ConsentRequest> for ConsentService {
137 type Response = ConsentResponse;
138 type Error = TraceabilityError;
139 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
140
141 fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
142 Poll::Ready(Ok(()))
143 }
144
145 fn call(&mut self, request: ConsentRequest) -> Self::Future {
146 let this = self.clone();
147 Box::pin(async move {
148 match request {
149 ConsentRequest::RequestConsent { source, destination } => {
150 #[cfg(feature = "trace2e_tracing")]
151 info!(
152 "[consent] RequestConsent: source: {:?}, destination: {:?}",
153 source, destination
154 );
155 let consent = this.get_consent(source, destination).await?;
156 Ok(ConsentResponse::Consent(consent))
157 }
158 ConsentRequest::PendingRequests => {
159 #[cfg(feature = "trace2e_tracing")]
160 info!("[consent] PendingRequests");
161 Ok(ConsentResponse::PendingRequests(this.pending_requests()))
162 }
163 ConsentRequest::SetConsent { source, destination, consent } => {
164 #[cfg(feature = "trace2e_tracing")]
165 info!(
166 "[consent] SetConsent: source: {:?}, destination: {:?}, consent: {:?}",
167 source, destination, consent
168 );
169 this.set_consent(source, destination, consent);
170 Ok(ConsentResponse::Ack)
171 }
172 }
173 })
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tower::{Service, ServiceBuilder, timeout::TimeoutLayer};

    // NOTE(review): this test is ignored and no reason is recorded here. It
    // drives `SetConsent` against a request that is still pending; confirm
    // that this path returns, then re-enable it.
    /// A pending request is listed, then resolved by a later `SetConsent`.
    #[tokio::test]
    #[ignore]
    async fn unit_consent_service_pending_then_set() {
        #[cfg(feature = "trace2e_tracing")]
        crate::trace2e_tracing::init();
        let mut service = ConsentService::default();
        let source = Resource::new_process_mock(0);
        let dest = (Some("10.0.0.1".to_string()), Resource::new_file("/tmp/x".to_string()));

        // Issue the consent request from a background task; it stays pending
        // until a decision is set below.
        let request = ConsentRequest::RequestConsent {
            source: source.clone(),
            destination: dest.clone(),
        };
        let mut requester = service.clone();
        let waiter = tokio::spawn(async move { requester.call(request).await });

        // Give the background task time to register its pending request.
        tokio::time::sleep(Duration::from_millis(5)).await;

        let pending =
            service.call(ConsentRequest::PendingRequests).await.expect("pending requests call ok");
        let ConsentResponse::PendingRequests(list) = pending else {
            panic!("Expected PendingRequests response");
        };
        assert_eq!(list.len(), 1);
        let ((s, n, d), _tx) = &list[0];
        assert_eq!(s, &source);
        assert_eq!(n, &dest.0);
        assert_eq!(d, &dest.1);

        let decision = ConsentRequest::SetConsent {
            source: source.clone(),
            destination: dest.clone(),
            consent: true,
        };
        let res = service.call(decision).await.expect("set consent ok");
        assert!(matches!(res, ConsentResponse::Ack));

        let waited = waiter.await.unwrap().unwrap();
        assert!(matches!(waited, ConsentResponse::Consent(true)));
    }

    // NOTE(review): this test is ignored and no reason is recorded here. It
    // drives `SetConsent` against a request that is still pending; confirm
    // that this path returns, then re-enable it.
    /// Once a decision is stored, later requests get it without waiting.
    #[tokio::test]
    #[ignore]
    async fn unit_consent_service_decided_returns_immediately() {
        #[cfg(feature = "trace2e_tracing")]
        crate::trace2e_tracing::init();
        let mut service = ConsentService::default();
        let source = Resource::new_process_mock(1);
        let dest = (None, Resource::new_file("/tmp/y".to_string()));

        // First request: pending until the decision is set.
        let request = ConsentRequest::RequestConsent {
            source: source.clone(),
            destination: dest.clone(),
        };
        let mut requester = service.clone();
        let waiter = tokio::spawn(async move { requester.call(request).await });
        tokio::time::sleep(Duration::from_millis(5)).await;

        let decision = ConsentRequest::SetConsent {
            source: source.clone(),
            destination: dest.clone(),
            consent: false,
        };
        service.call(decision).await.expect("set consent ok");
        assert!(matches!(waiter.await.unwrap().unwrap(), ConsentResponse::Consent(false)));

        // Second request for the same flow: the stored decision is returned.
        let immediate = service
            .call(ConsentRequest::RequestConsent { source, destination: dest })
            .await
            .unwrap();
        assert!(matches!(immediate, ConsentResponse::Consent(false)));
    }

    /// With nobody deciding, the service's own timeout (1 ms) must surface as
    /// `ConsentRequestTimeout` rather than the outer `TimeoutLayer` (2 ms) firing.
    #[tokio::test]
    async fn unit_consent_service_times_out_with_layer() {
        #[cfg(feature = "trace2e_tracing")]
        crate::trace2e_tracing::init();

        let mut svc = ServiceBuilder::new()
            .layer(TimeoutLayer::new(Duration::from_millis(2)))
            .service(ConsentService::new(1));

        let source = Resource::new_process_mock(2);
        let destination = (None, Resource::new_file("/tmp/timeout.txt".to_string()));

        let result = svc.call(ConsentRequest::RequestConsent { source, destination }).await;

        let Err(err) = result else {
            panic!("expected timeout error, got Ok");
        };
        // The layered service boxes its errors, so downcast to tell the inner
        // consent timeout apart from the outer layer's timeout.
        if let Some(te) = err.downcast_ref::<TraceabilityError>() {
            assert_eq!(*te, TraceabilityError::ConsentRequestTimeout);
        } else if err.is::<tower::timeout::error::Elapsed>() {
            panic!("outer TimeoutLayer elapsed before inner consent timeout");
        } else {
            panic!("unexpected error: {err}");
        }
    }
}