trace2e_core/traceability/core/
consent.rs

1//! Consent service for managing user/operator consent for outgoing data flows.
2use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
3
4use dashmap::{DashMap, Entry};
5use tokio::{sync::broadcast, time::Duration};
6use tower::Service;
7#[cfg(feature = "trace2e_tracing")]
8use tracing::info;
9
10use crate::traceability::{
11    api::{ConsentRequest, ConsentResponse},
12    error::TraceabilityError,
13    naming::Resource,
14};
15
/// Map key identifying one consent decision: `(source, node_id, destination)`.
///
/// The middle `Option<String>` is the first element of the `destination` tuple handed to the
/// service; the last field is the destination resource.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct ConsentKey(Resource, Option<String>, Resource);
18
/// State of the consent decision stored for one [`ConsentKey`].
#[derive(Debug, Clone)]
enum ConsentState {
    /// A request is waiting for a decision; `tx` broadcasts the decision to every waiter
    /// that subscribed while the request was pending.
    Pending { tx: broadcast::Sender<bool> },
    /// A decision has been recorded; holds the consent value that was set.
    Decided(bool),
}
24
/// Tower service that records consent decisions for outgoing data flows and lets
/// requesters wait for decisions that are still pending.
///
/// Clones share the same state, because the state map sits behind an `Arc`. `Default` produces
/// a service whose timeout is 0, i.e. disabled: requests wait indefinitely for a decision.
#[derive(Default, Debug, Clone)]
pub struct ConsentService {
    /// Maximum time, in milliseconds, a consent request waits for a decision; 0 disables it.
    timeout: u64,
    /// Unified store of consent states keyed by (source, node_id, destination)
    states: Arc<DashMap<ConsentKey, ConsentState>>,
}
31
32impl ConsentService {
33    /// Create a new `ConsentService` with the specified timeout.
34    ///
35    /// Timeout is disabled if set to 0.
36    pub fn new(timeout_ms: u64) -> Self {
37        Self { timeout: timeout_ms, states: Arc::new(DashMap::new()) }
38    }
39
40    /// Internal method to get consent
41    async fn get_consent(
42        &self,
43        source: Resource,
44        destination: (Option<String>, Resource),
45    ) -> Result<bool, TraceabilityError> {
46        let key = ConsentKey(source, destination.0, destination.1);
47        match self.states.entry(key) {
48            Entry::Occupied(occ) => match occ.get() {
49                ConsentState::Decided(consent) => Ok(*consent),
50                ConsentState::Pending { tx } => {
51                    let mut rx = tx.subscribe();
52                    // Drop map guard before awaiting
53                    drop(occ);
54                    if self.timeout == 0 {
55                        rx.recv().await.map_err(|_| TraceabilityError::InternalTrace2eError)
56                    } else {
57                        match tokio::time::timeout(Duration::from_millis(self.timeout), rx.recv())
58                            .await
59                        {
60                            Ok(res) => res.map_err(|_| TraceabilityError::InternalTrace2eError),
61                            Err(_) => Err(TraceabilityError::ConsentRequestTimeout),
62                        }
63                    }
64                }
65            },
66            Entry::Vacant(vac) => {
67                let (tx, mut rx) = broadcast::channel(16);
68                vac.insert(ConsentState::Pending { tx });
69                if self.timeout == 0 {
70                    rx.recv().await.map_err(|_| TraceabilityError::InternalTrace2eError)
71                } else {
72                    match tokio::time::timeout(Duration::from_millis(self.timeout), rx.recv()).await
73                    {
74                        Ok(res) => res.map_err(|_| TraceabilityError::InternalTrace2eError),
75                        Err(_) => Err(TraceabilityError::ConsentRequestTimeout),
76                    }
77                }
78            }
79        }
80    }
81
82    /// Internal method to list pending requests
83    #[allow(clippy::type_complexity)]
84    fn pending_requests(
85        &self,
86    ) -> Vec<((Resource, Option<String>, Resource), broadcast::Sender<bool>)> {
87        self.states
88            .iter()
89            .filter_map(|entry| match entry.value() {
90                ConsentState::Pending { tx } => {
91                    let ConsentKey(src, node, dst) = entry.key().clone();
92                    Some(((src, node, dst), tx.clone()))
93                }
94                ConsentState::Decided(_) => None,
95            })
96            .collect()
97    }
98
99    /// Internal method to set consent
100    fn set_consent(
101        &self,
102        source: Resource,
103        destination: (Option<String>, Resource),
104        consent: bool,
105    ) {
106        let key = ConsentKey(source, destination.0, destination.1);
107        // Read-first: only upgrade to mutable if the state is Pending.
108        match self.states.entry(key.clone()) {
109            Entry::Occupied(entry) => {
110                match entry.get() {
111                    ConsentState::Decided(prev_consent) if *prev_consent != consent => {
112                        if let Some(mut entry) = self.states.get_mut(&key) {
113                            *entry = ConsentState::Decided(consent);
114                        }
115                    }
116                    ConsentState::Pending { tx } => {
117                        if let Some(mut entry) = self.states.get_mut(&key) {
118                            // Notify waiters, then transition to Decided.
119                            tx.send(consent).unwrap();
120                            *entry = ConsentState::Decided(consent);
121                        }
122                    }
123                    _ => {
124                        // Already decided; ignore
125                    }
126                }
127            }
128            Entry::Vacant(vac) => {
129                // No pending request; just insert the decided state.
130                vac.insert(ConsentState::Decided(consent));
131            }
132        }
133    }
134}
135
136impl Service<ConsentRequest> for ConsentService {
137    type Response = ConsentResponse;
138    type Error = TraceabilityError;
139    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
140
141    fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
142        Poll::Ready(Ok(()))
143    }
144
145    fn call(&mut self, request: ConsentRequest) -> Self::Future {
146        let this = self.clone();
147        Box::pin(async move {
148            match request {
149                ConsentRequest::RequestConsent { source, destination } => {
150                    #[cfg(feature = "trace2e_tracing")]
151                    info!(
152                        "[consent] RequestConsent: source: {:?}, destination: {:?}",
153                        source, destination
154                    );
155                    let consent = this.get_consent(source, destination).await?;
156                    Ok(ConsentResponse::Consent(consent))
157                }
158                ConsentRequest::PendingRequests => {
159                    #[cfg(feature = "trace2e_tracing")]
160                    info!("[consent] PendingRequests");
161                    Ok(ConsentResponse::PendingRequests(this.pending_requests()))
162                }
163                ConsentRequest::SetConsent { source, destination, consent } => {
164                    #[cfg(feature = "trace2e_tracing")]
165                    info!(
166                        "[consent] SetConsent: source: {:?}, destination: {:?}, consent: {:?}",
167                        source, destination, consent
168                    );
169                    this.set_consent(source, destination, consent);
170                    Ok(ConsentResponse::Ack)
171                }
172            }
173        })
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tower::{Service, ServiceBuilder, timeout::TimeoutLayer};

    /// A RequestConsent on an undecided key is listed by PendingRequests and resolves to the
    /// value later supplied through SetConsent.
    #[tokio::test]
    // NOTE(review): presumably ignored because SetConsent on a pending key can block:
    // set_consent calls `states.get_mut` while still holding the `states.entry` guard for
    // the same key. Re-enable once that path is confirmed not to block.
    #[ignore] // TODO: fix this test
    async fn unit_consent_service_pending_then_set() {
        #[cfg(feature = "trace2e_tracing")]
        crate::trace2e_tracing::init();
        // Default service: timeout 0, so the request waits until a decision arrives.
        let mut service = ConsentService::default();
        let source = Resource::new_process_mock(0);
        let dest = (Some("10.0.0.1".to_string()), Resource::new_file("/tmp/x".to_string()));

        // Start a RequestConsent that will await a decision
        let mut svc_for_req = service.clone();
        let src_clone = source.clone();
        let dest_clone = dest.clone();
        let waiter = tokio::spawn(async move {
            svc_for_req
                .call(ConsentRequest::RequestConsent { source: src_clone, destination: dest_clone })
                .await
        });

        // Give time for the pending entry to be created (best-effort, timing-based)
        tokio::time::sleep(Duration::from_millis(5)).await;

        // Check that we have exactly one pending request
        let pending =
            service.call(ConsentRequest::PendingRequests).await.expect("pending requests call ok");
        if let ConsentResponse::PendingRequests(list) = pending {
            assert_eq!(list.len(), 1);
            let ((s, n, d), _tx) = &list[0];
            assert_eq!(s, &source);
            assert_eq!(n, &dest.0);
            assert_eq!(d, &dest.1);
        } else {
            panic!("Expected PendingRequests response");
        }

        // Decide consent and ensure waiter completes
        let res = service
            .call(ConsentRequest::SetConsent {
                source: source.clone(),
                destination: dest.clone(),
                consent: true,
            })
            .await
            .expect("set consent ok");
        assert!(matches!(res, ConsentResponse::Ack));

        let waited = waiter.await.unwrap().unwrap();
        assert!(matches!(waited, ConsentResponse::Consent(true)));
    }

    /// Once a decision is stored, later RequestConsent calls return it without waiting.
    #[tokio::test]
    // NOTE(review): presumably ignored for the same reason as the test above (SetConsent on
    // a pending key re-locks the map entry it is still holding).
    #[ignore] // TODO: fix this test
    async fn unit_consent_service_decided_returns_immediately() {
        #[cfg(feature = "trace2e_tracing")]
        crate::trace2e_tracing::init();
        let mut service = ConsentService::default();
        let source = Resource::new_process_mock(1);
        let dest = (None, Resource::new_file("/tmp/y".to_string()));

        // First request will create pending, so set consent before awaiting a second request
        let mut svc_for_req = service.clone();
        let src_clone = source.clone();
        let dest_clone = dest.clone();
        let waiter = tokio::spawn(async move {
            svc_for_req
                .call(ConsentRequest::RequestConsent { source: src_clone, destination: dest_clone })
                .await
        });
        // Best-effort wait for the spawned request to register its pending entry
        tokio::time::sleep(Duration::from_millis(5)).await;
        service
            .call(ConsentRequest::SetConsent {
                source: source.clone(),
                destination: dest.clone(),
                consent: false,
            })
            .await
            .expect("set consent ok");
        // First waiter should get false
        assert!(matches!(waiter.await.unwrap().unwrap(), ConsentResponse::Consent(false)));

        // Second request should return immediately with the decided value
        let immediate = service
            .call(ConsentRequest::RequestConsent { source, destination: dest })
            .await
            .unwrap();
        assert!(matches!(immediate, ConsentResponse::Consent(false)));
    }

    /// An undecided request fails with the service's own ConsentRequestTimeout rather than
    /// the outer tower timeout.
    #[tokio::test]
    async fn unit_consent_service_times_out_with_layer() {
        #[cfg(feature = "trace2e_tracing")]
        crate::trace2e_tracing::init();

        // Wrap with a larger timeout tower layer: the outer limit is 2 ms, the service's own
        // consent timeout is 1 ms, so the inner timeout is expected to fire first.
        let mut svc = ServiceBuilder::new()
            .layer(TimeoutLayer::new(Duration::from_millis(2)))
            .service(ConsentService::new(1));

        let source = Resource::new_process_mock(2);
        let destination = (None, Resource::new_file("/tmp/timeout.txt".to_string()));

        // This call should pend internally and then error with ConsentRequestTimeout
        let result = svc.call(ConsentRequest::RequestConsent { source, destination }).await;

        match result {
            Ok(_) => panic!("expected timeout error, got Ok"),
            Err(err) => {
                // TimeoutLayer boxes inner errors; downcast to our TraceabilityError
                if let Some(te) = err.downcast_ref::<TraceabilityError>() {
                    assert_eq!(*te, TraceabilityError::ConsentRequestTimeout);
                } else if err.is::<tower::timeout::error::Elapsed>() {
                    panic!("outer TimeoutLayer elapsed before inner consent timeout");
                } else {
                    panic!("unexpected error: {err}");
                }
            }
        }
    }
}