// trace2e_core/traceability/services/provenance.rs
use std::{collections::HashSet, pin::Pin, sync::Arc, task::Poll};

use dashmap::{DashMap, mapref::entry::Entry};
use tower::Service;
use tracing::info;

use crate::traceability::{
    api::types::{ProvenanceRequest, ProvenanceResponse},
    error::TraceabilityError,
    infrastructure::naming::{DisplayableResource, LocalizedResource, NodeId, Resource},
};

18type ProvenanceMap = DashMap<Resource, HashSet<LocalizedResource>>;
19
20#[derive(Debug, Default, Clone)]
22pub struct ProvenanceService {
23 node_id: String,
24 provenance: Arc<ProvenanceMap>,
25}
26
27impl ProvenanceService {
28 pub fn new(node_id: String) -> Self {
29 Self { node_id, provenance: Arc::new(DashMap::new()) }
30 }
31
32 pub fn set_references(&self, resource: Resource, references: HashSet<LocalizedResource>) {
33 self.provenance.insert(resource, references);
34 }
35
36 fn init_provenance(&self, resource: &Resource) -> HashSet<LocalizedResource> {
37 if !resource.is_stream() {
38 HashSet::from([LocalizedResource::new(self.node_id.clone(), resource.to_owned())])
39 } else {
40 HashSet::new()
41 }
42 }
43
44 fn get_prov(&self, resource: &Resource) -> HashSet<LocalizedResource> {
49 if let Some(prov) = self.provenance.get(resource) {
50 prov.to_owned()
51 } else {
52 self.init_provenance(resource)
53 }
54 }
55
56 fn update(&mut self, source: &Resource, destination: &Resource) -> ProvenanceResponse {
61 self.update_raw(self.get_prov(source), destination)
63 }
64
65 fn update_raw(
70 &mut self,
71 source_prov: HashSet<LocalizedResource>,
72 destination: &Resource,
73 ) -> ProvenanceResponse {
74 let mut destination_prov = self.get_prov(destination);
75 if source_prov.is_subset(&destination_prov) {
76 info!(
77 "[provenance-raw] Provenance not updated: source_prov is subset of destination_prov"
78 );
79 ProvenanceResponse::ProvenanceNotUpdated
80 } else {
81 destination_prov.extend(source_prov);
82 info!(
83 destination_prov = %DisplayableResource::from(&destination_prov),
84 "[provenance-raw] Provenance updated"
85 );
86 self.provenance.insert(destination.to_owned(), destination_prov);
87 ProvenanceResponse::ProvenanceUpdated
88 }
89 }
90}
91
92impl NodeId for ProvenanceService {
93 fn node_id(&self) -> String {
94 self.node_id.to_owned()
95 }
96}
97
98impl Service<ProvenanceRequest> for ProvenanceService {
99 type Response = ProvenanceResponse;
100 type Error = TraceabilityError;
101 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
102
103 fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
104 Poll::Ready(Ok(()))
105 }
106
107 fn call(&mut self, request: ProvenanceRequest) -> Self::Future {
108 let mut this = self.clone();
109 Box::pin(async move {
110 match request {
111 ProvenanceRequest::GetReferences(resource) => {
112 info!(node_id = %this.node_id, resource = %resource, "[provenance] GetReferences");
113 Ok(ProvenanceResponse::Provenance(this.get_prov(&resource)))
114 }
115 ProvenanceRequest::UpdateProvenance { source, destination } => {
116 info!(
117 node_id = %this.node_id,
118 source = %source,
119 destination = %destination,
120 "[provenance] UpdateProvenance"
121 );
122 Ok(this.update(&source, &destination))
123 }
124 ProvenanceRequest::UpdateProvenanceRaw { source_prov, destination } => {
125 info!(
126 node_id = %this.node_id,
127 source_prov = %DisplayableResource::from(&source_prov),
128 destination = %destination,
129 "[provenance] UpdateProvenanceRaw"
130 );
131 Ok(this.update_raw(source_prov, &destination))
132 }
133 }
134 })
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the fixtures shared by every test: a mocked process and the file
    /// `/tmp/test`, both localized on the node of `service`.
    fn process_and_file(service: &ProvenanceService) -> (LocalizedResource, LocalizedResource) {
        let node = service.node_id();
        let process = LocalizedResource::new(node.clone(), Resource::new_process_mock(0));
        let file = LocalizedResource::new(node, Resource::new_file("/tmp/test".to_string()));
        (process, file)
    }

    #[test]
    fn unit_provenance_update_simple() {
        crate::trace2e_tracing::init();
        let mut provenance = ProvenanceService::default();
        let (process, file) = process_and_file(&provenance);

        // file -> process: the process now derives from the file as well as itself.
        let response = provenance.update(file.resource(), process.resource());
        assert_eq!(response, ProvenanceResponse::ProvenanceUpdated);

        let process_prov = provenance.get_prov(process.resource());
        assert_eq!(process_prov, HashSet::from([file, process]));
    }

    #[test]
    fn unit_provenance_update_circular() {
        crate::trace2e_tracing::init();
        let mut provenance = ProvenanceService::default();
        let (process, file) = process_and_file(&provenance);

        let forward = provenance.update(process.resource(), file.resource());
        assert_eq!(forward, ProvenanceResponse::ProvenanceUpdated);
        let backward = provenance.update(file.resource(), process.resource());
        assert_eq!(backward, ProvenanceResponse::ProvenanceUpdated);

        // After a round trip both resources carry the same provenance.
        assert_eq!(provenance.get_prov(file.resource()), provenance.get_prov(process.resource()));
    }

    #[tokio::test]
    async fn unit_provenance_service_flow_simple() {
        crate::trace2e_tracing::init();
        let mut provenance = ProvenanceService::default();
        let (process, file) = process_and_file(&provenance);

        let initial = provenance
            .call(ProvenanceRequest::GetReferences(process.resource().clone()))
            .await
            .unwrap();
        assert_eq!(initial, ProvenanceResponse::Provenance(HashSet::from([process.clone()])));

        // The first propagation adds the file; repeating it changes nothing.
        for expected in
            [ProvenanceResponse::ProvenanceUpdated, ProvenanceResponse::ProvenanceNotUpdated]
        {
            let response = provenance
                .call(ProvenanceRequest::UpdateProvenance {
                    source: file.resource().clone(),
                    destination: process.resource().clone(),
                })
                .await
                .unwrap();
            assert_eq!(response, expected);
        }

        let final_refs = provenance
            .call(ProvenanceRequest::GetReferences(process.resource().clone()))
            .await
            .unwrap();
        assert_eq!(final_refs, ProvenanceResponse::Provenance(HashSet::from([file, process])));
    }
}