trace2e_core/traceability/core/
provenance.rs1use std::{
2 collections::{HashMap, HashSet},
3 pin::Pin,
4 sync::Arc,
5 task::Poll,
6};
7
8use dashmap::DashMap;
9use tower::Service;
10#[cfg(feature = "trace2e_tracing")]
11use tracing::{debug, info};
12
13use crate::traceability::{
14 api::{ProvenanceRequest, ProvenanceResponse},
15 error::TraceabilityError,
16 naming::{NodeId, Resource},
17};
18
/// Provenance store: for each resource, the resources recorded as having flowed into it,
/// grouped by the identifier of the node they belong to.
type ProvenanceMap = DashMap<Resource, HashMap<String, HashSet<Resource>>>;

/// Service that records flows between resources and answers provenance queries.
///
/// The map lives behind an `Arc`, so every clone of the service (e.g. the one made per
/// request in `Service::call`) reads and writes the same underlying store.
#[derive(Debug, Default, Clone)]
pub struct ProvenanceService {
    /// Identifier of the local node; used as the key for provenance entries of local,
    /// non-stream resources. `Default` leaves it empty.
    node_id: String,
    /// Shared provenance store (see [`ProvenanceMap`]).
    provenance: Arc<ProvenanceMap>,
}
27
28impl ProvenanceService {
29 pub fn new(node_id: String) -> Self {
30 Self { node_id, provenance: Arc::new(DashMap::new()) }
31 }
32
33 fn init_provenance(&self, resource: &Resource) -> HashMap<String, HashSet<Resource>> {
34 if resource.is_stream().is_none() {
35 HashMap::from([(self.node_id.clone(), HashSet::from([resource.to_owned()]))])
36 } else {
37 HashMap::new()
38 }
39 }
40
41 async fn get_prov(&self, resource: &Resource) -> HashMap<String, HashSet<Resource>> {
46 if let Some(prov) = self.provenance.get(resource) {
47 prov.to_owned()
48 } else {
49 self.init_provenance(resource)
50 }
51 }
52
53 async fn set_prov(&mut self, resource: Resource, prov: HashMap<String, HashSet<Resource>>) {
55 self.provenance.insert(resource, prov);
56 }
57
58 async fn update(
63 &mut self,
64 source: &Resource,
65 destination: &Resource,
66 ) -> Result<ProvenanceResponse, TraceabilityError> {
67 if destination.is_stream().is_none() {
69 self.update_raw(self.get_prov(source).await, destination).await
71 } else {
72 Ok(ProvenanceResponse::ProvenanceNotUpdated)
73 }
74 }
75
76 async fn update_raw(
81 &mut self,
82 source_prov: HashMap<String, HashSet<Resource>>,
83 destination: &Resource,
84 ) -> Result<ProvenanceResponse, TraceabilityError> {
85 let mut updated = false;
86 let mut destination_prov = self.get_prov(destination).await;
87 #[cfg(feature = "trace2e_tracing")]
88 debug!("[provenance-raw] Previous {:?} provenance: {:?}", destination, destination_prov);
89 for (node_id, node_source_prov) in source_prov {
90 if let Some(node_destination_prov) = destination_prov.get_mut(&node_id) {
91 if !node_destination_prov.is_superset(&node_source_prov) {
92 node_destination_prov.extend(node_source_prov);
93 updated = true;
94 }
95 } else {
96 destination_prov.insert(node_id, node_source_prov);
97 updated = true;
98 }
99 }
100 if updated {
101 #[cfg(feature = "trace2e_tracing")]
102 debug!("[provenance-raw] Updated {:?} provenance: {:?}", destination, destination_prov);
103 self.set_prov(destination.to_owned(), destination_prov).await;
104 Ok(ProvenanceResponse::ProvenanceUpdated)
105 } else {
106 Ok(ProvenanceResponse::ProvenanceNotUpdated)
107 }
108 }
109}
110
111impl NodeId for ProvenanceService {
112 fn node_id(&self) -> String {
113 self.node_id.to_owned()
114 }
115}
116
117impl Service<ProvenanceRequest> for ProvenanceService {
118 type Response = ProvenanceResponse;
119 type Error = TraceabilityError;
120 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
121
122 fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
123 Poll::Ready(Ok(()))
124 }
125
126 fn call(&mut self, request: ProvenanceRequest) -> Self::Future {
127 let mut this = self.clone();
128 Box::pin(async move {
129 match request {
130 ProvenanceRequest::GetReferences(resource) => {
131 #[cfg(feature = "trace2e_tracing")]
132 info!("[provenance-{}] GetReferences: {:?}", this.node_id, resource);
133 Ok(ProvenanceResponse::Provenance(this.get_prov(&resource).await))
134 }
135 ProvenanceRequest::UpdateProvenance { source, destination } => {
136 #[cfg(feature = "trace2e_tracing")]
137 info!(
138 "[provenance-{}] UpdateProvenance: source: {:?}, destination: {:?}",
139 this.node_id, source, destination
140 );
141 this.update(&source, &destination).await
142 }
143 ProvenanceRequest::UpdateProvenanceRaw { source_prov, destination } => {
144 #[cfg(feature = "trace2e_tracing")]
145 info!(
146 "[provenance-{}] UpdateProvenanceRaw: source_prov: {:?}, destination: {:?}",
147 this.node_id, source_prov, destination
148 );
149 this.update_raw(source_prov, &destination).await
150 }
151 }
152 })
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Initializes tracing (when enabled) and returns a default service, whose node id is
    /// the empty string, so local provenance entries are keyed by `String::new()`.
    fn fresh_service() -> ProvenanceService {
        #[cfg(feature = "trace2e_tracing")]
        crate::trace2e_tracing::init();
        ProvenanceService::default()
    }

    #[tokio::test]
    async fn unit_provenance_update_simple() {
        let mut provenance = fresh_service();
        let process = Resource::new_process_mock(0);
        let file = Resource::new_file("/tmp/test".to_string());

        provenance.update(&file, &process).await.unwrap();

        let actual = provenance.get_prov(&process).await;
        assert_eq!(actual, HashMap::from([(String::new(), HashSet::from([file, process]))]));
    }

    #[tokio::test]
    async fn unit_provenance_update_circular() {
        let mut provenance = fresh_service();
        let process = Resource::new_process_mock(0);
        let file = Resource::new_file("/tmp/test".to_string());

        // process -> file, then file -> process: both end up with the same provenance.
        for (source, destination) in [(&process, &file), (&file, &process)] {
            provenance.update(source, destination).await.unwrap();
        }

        let file_prov = provenance.get_prov(&file).await;
        let process_prov = provenance.get_prov(&process).await;
        assert_eq!(file_prov, process_prov);
    }

    #[tokio::test]
    async fn unit_provenance_update_multiple_nodes() {
        let mut provenance = fresh_service();
        let process0 = Resource::new_process_mock(0);
        let process1 = Resource::new_process_mock(1);
        let file0 = Resource::new_file("/tmp/test0".to_string());

        let first_flow = HashMap::from([
            ("10.0.0.1".to_string(), HashSet::from([process0.clone()])),
            ("10.0.0.2".to_string(), HashSet::from([process0.clone()])),
        ]);
        let second_flow = HashMap::from([
            ("10.0.0.1".to_string(), HashSet::from([process1.clone()])),
            ("10.0.0.2".to_string(), HashSet::from([file0.clone(), process1.clone()])),
        ]);
        for source_prov in [first_flow, second_flow] {
            provenance.update_raw(source_prov, &process0).await.unwrap();
        }

        let actual = provenance.get_prov(&process0).await;
        assert_eq!(
            actual,
            HashMap::from([
                (String::new(), HashSet::from([process0.clone()])),
                ("10.0.0.1".to_string(), HashSet::from([process0.clone(), process1.clone()])),
                ("10.0.0.2".to_string(), HashSet::from([file0, process0, process1])),
            ])
        );
    }

    #[tokio::test]
    async fn unit_provenance_service_flow_simple() {
        let mut provenance = fresh_service();
        let process = Resource::new_process_mock(0);
        let file = Resource::new_file("/tmp/test".to_string());

        let initial =
            provenance.call(ProvenanceRequest::GetReferences(process.clone())).await.unwrap();
        assert_eq!(
            initial,
            ProvenanceResponse::Provenance(HashMap::from([(
                String::new(),
                HashSet::from([process.clone()])
            )]))
        );

        // The first file -> process flow extends the provenance; replaying it changes nothing.
        let expected_responses =
            [ProvenanceResponse::ProvenanceUpdated, ProvenanceResponse::ProvenanceNotUpdated];
        for expected in expected_responses {
            let response = provenance
                .call(ProvenanceRequest::UpdateProvenance {
                    source: file.clone(),
                    destination: process.clone(),
                })
                .await
                .unwrap();
            assert_eq!(response, expected);
        }

        let after =
            provenance.call(ProvenanceRequest::GetReferences(process.clone())).await.unwrap();
        assert_eq!(
            after,
            ProvenanceResponse::Provenance(HashMap::from([(
                String::new(),
                HashSet::from([file, process])
            )]))
        );
    }
}