// trace2e_core/traceability/core/provenance.rs

1use std::{
2    collections::{HashMap, HashSet},
3    pin::Pin,
4    sync::Arc,
5    task::Poll,
6};
7
8use dashmap::DashMap;
9use tower::Service;
10#[cfg(feature = "trace2e_tracing")]
11use tracing::{debug, info};
12
13use crate::traceability::{
14    api::{ProvenanceRequest, ProvenanceResponse},
15    error::TraceabilityError,
16    naming::{NodeId, Resource},
17};
18
19type ProvenanceMap = DashMap<Resource, HashMap<String, HashSet<Resource>>>;
20
21/// Provenance service for tracking resources provenance
22#[derive(Debug, Default, Clone)]
23pub struct ProvenanceService {
24    node_id: String,
25    provenance: Arc<ProvenanceMap>,
26}
27
28impl ProvenanceService {
29    pub fn new(node_id: String) -> Self {
30        Self { node_id, provenance: Arc::new(DashMap::new()) }
31    }
32
33    fn init_provenance(&self, resource: &Resource) -> HashMap<String, HashSet<Resource>> {
34        if resource.is_stream().is_none() {
35            HashMap::from([(self.node_id.clone(), HashSet::from([resource.to_owned()]))])
36        } else {
37            HashMap::new()
38        }
39    }
40
41    /// Get the provenance of a resource
42    ///
43    /// This function returns a map of node IDs to the provenance of the resource for that node.
44    /// If the resource is not found, it returns an empty map.
45    async fn get_prov(&self, resource: &Resource) -> HashMap<String, HashSet<Resource>> {
46        if let Some(prov) = self.provenance.get(resource) {
47            prov.to_owned()
48        } else {
49            self.init_provenance(resource)
50        }
51    }
52
53    /// Set the provenance of a resource
54    async fn set_prov(&mut self, resource: Resource, prov: HashMap<String, HashSet<Resource>>) {
55        self.provenance.insert(resource, prov);
56    }
57
58    /// Update the provenance of the destination with the source
59    ///
60    /// Note that this function does not guarantee sequential consistency,
61    /// this is the role of the sequencer.
62    async fn update(
63        &mut self,
64        source: &Resource,
65        destination: &Resource,
66    ) -> Result<ProvenanceResponse, TraceabilityError> {
67        // Record the node IDs to which local resources propagate
68        if destination.is_stream().is_none() {
69            // Update the provenance of the destination with the source provenance
70            self.update_raw(self.get_prov(source).await, destination).await
71        } else {
72            Ok(ProvenanceResponse::ProvenanceNotUpdated)
73        }
74    }
75
76    /// Update the provenance of the destination with the raw source provenance
77    ///
78    /// Note that this function does not guarantee sequential consistency,
79    /// this is the role of the sequencer.
80    async fn update_raw(
81        &mut self,
82        source_prov: HashMap<String, HashSet<Resource>>,
83        destination: &Resource,
84    ) -> Result<ProvenanceResponse, TraceabilityError> {
85        let mut updated = false;
86        let mut destination_prov = self.get_prov(destination).await;
87        #[cfg(feature = "trace2e_tracing")]
88        debug!("[provenance-raw] Previous {:?} provenance: {:?}", destination, destination_prov);
89        for (node_id, node_source_prov) in source_prov {
90            if let Some(node_destination_prov) = destination_prov.get_mut(&node_id) {
91                if !node_destination_prov.is_superset(&node_source_prov) {
92                    node_destination_prov.extend(node_source_prov);
93                    updated = true;
94                }
95            } else {
96                destination_prov.insert(node_id, node_source_prov);
97                updated = true;
98            }
99        }
100        if updated {
101            #[cfg(feature = "trace2e_tracing")]
102            debug!("[provenance-raw] Updated {:?} provenance: {:?}", destination, destination_prov);
103            self.set_prov(destination.to_owned(), destination_prov).await;
104            Ok(ProvenanceResponse::ProvenanceUpdated)
105        } else {
106            Ok(ProvenanceResponse::ProvenanceNotUpdated)
107        }
108    }
109}
110
111impl NodeId for ProvenanceService {
112    fn node_id(&self) -> String {
113        self.node_id.to_owned()
114    }
115}
116
117impl Service<ProvenanceRequest> for ProvenanceService {
118    type Response = ProvenanceResponse;
119    type Error = TraceabilityError;
120    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
121
122    fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
123        Poll::Ready(Ok(()))
124    }
125
126    fn call(&mut self, request: ProvenanceRequest) -> Self::Future {
127        let mut this = self.clone();
128        Box::pin(async move {
129            match request {
130                ProvenanceRequest::GetReferences(resource) => {
131                    #[cfg(feature = "trace2e_tracing")]
132                    info!("[provenance-{}] GetReferences: {:?}", this.node_id, resource);
133                    Ok(ProvenanceResponse::Provenance(this.get_prov(&resource).await))
134                }
135                ProvenanceRequest::UpdateProvenance { source, destination } => {
136                    #[cfg(feature = "trace2e_tracing")]
137                    info!(
138                        "[provenance-{}] UpdateProvenance: source: {:?}, destination: {:?}",
139                        this.node_id, source, destination
140                    );
141                    this.update(&source, &destination).await
142                }
143                ProvenanceRequest::UpdateProvenanceRaw { source_prov, destination } => {
144                    #[cfg(feature = "trace2e_tracing")]
145                    info!(
146                        "[provenance-{}] UpdateProvenanceRaw: source_prov: {:?}, destination: {:?}",
147                        this.node_id, source_prov, destination
148                    );
149                    this.update_raw(source_prov, &destination).await
150                }
151            }
152        })
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the service used by every test: default node ID (empty string)
    /// and an empty store. Tracing is initialised first when the feature is on.
    fn fresh_service() -> ProvenanceService {
        #[cfg(feature = "trace2e_tracing")]
        crate::trace2e_tracing::init();
        ProvenanceService::default()
    }

    /// A single file -> process flow adds the file to the process provenance.
    #[tokio::test]
    async fn unit_provenance_update_simple() {
        let mut service = fresh_service();
        let process = Resource::new_process_mock(0);
        let file = Resource::new_file("/tmp/test".to_string());

        service.update(&file, &process).await.unwrap();

        // The process must now be derived from both itself and the file,
        // filed under the default (empty) node ID.
        let actual = service.get_prov(&process).await;
        let expected = HashMap::from([(String::new(), HashSet::from([file, process]))]);
        assert_eq!(actual, expected);
    }

    /// Flows in both directions leave both resources with the same provenance.
    #[tokio::test]
    async fn unit_provenance_update_circular() {
        let mut service = fresh_service();
        let process = Resource::new_process_mock(0);
        let file = Resource::new_file("/tmp/test".to_string());

        service.update(&process, &file).await.unwrap();
        service.update(&file, &process).await.unwrap();

        // Circular dependencies must converge to identical provenance.
        let file_prov = service.get_prov(&file).await;
        let process_prov = service.get_prov(&process).await;
        assert_eq!(file_prov, process_prov);
    }

    /// Raw updates coming from several nodes are merged node by node.
    #[tokio::test]
    async fn unit_provenance_update_multiple_nodes() {
        let mut service = fresh_service();
        let process0 = Resource::new_process_mock(0);
        let process1 = Resource::new_process_mock(1);
        let file0 = Resource::new_file("/tmp/test0".to_string());
        let first_node = "10.0.0.1".to_string();
        let second_node = "10.0.0.2".to_string();

        let first_batch = HashMap::from([
            (first_node.clone(), HashSet::from([process0.clone()])),
            (second_node.clone(), HashSet::from([process0.clone()])),
        ]);
        service.update_raw(first_batch, &process0).await.unwrap();

        let second_batch = HashMap::from([
            (first_node.clone(), HashSet::from([process1.clone()])),
            (second_node.clone(), HashSet::from([file0.clone(), process1.clone()])),
        ]);
        service.update_raw(second_batch, &process0).await.unwrap();

        // The local (empty) node keeps the initial self-reference, while each
        // remote node accumulates the union of both batches.
        let actual = service.get_prov(&process0).await;
        let expected = HashMap::from([
            (String::new(), HashSet::from([process0.clone()])),
            (first_node, HashSet::from([process0.clone(), process1.clone()])),
            (second_node, HashSet::from([file0, process0, process1])),
        ]);
        assert_eq!(actual, expected);
    }

    /// Exercises the same flow through the `Service` interface.
    #[tokio::test]
    async fn unit_provenance_service_flow_simple() {
        let mut service = fresh_service();
        let process = Resource::new_process_mock(0);
        let file = Resource::new_file("/tmp/test".to_string());

        // Before any update, the process only references itself.
        let initial =
            service.call(ProvenanceRequest::GetReferences(process.clone())).await.unwrap();
        let expected_initial = ProvenanceResponse::Provenance(HashMap::from([(
            String::new(),
            HashSet::from([process.clone()]),
        )]));
        assert_eq!(initial, expected_initial);

        // The first flow changes the provenance of the process.
        let first_update = service
            .call(ProvenanceRequest::UpdateProvenance {
                source: file.clone(),
                destination: process.clone(),
            })
            .await
            .unwrap();
        assert_eq!(first_update, ProvenanceResponse::ProvenanceUpdated);

        // Repeating the same flow brings nothing new.
        let second_update = service
            .call(ProvenanceRequest::UpdateProvenance {
                source: file.clone(),
                destination: process.clone(),
            })
            .await
            .unwrap();
        assert_eq!(second_update, ProvenanceResponse::ProvenanceNotUpdated);

        // The process now references both itself and the file.
        let final_refs =
            service.call(ProvenanceRequest::GetReferences(process.clone())).await.unwrap();
        let expected_final = ProvenanceResponse::Provenance(HashMap::from([(
            String::new(),
            HashSet::from([file, process]),
        )]));
        assert_eq!(final_refs, expected_final);
    }
}