1use std::{collections::HashMap, sync::Arc};
75
76use serde::{Deserialize, Serialize};
77use tokio::sync::RwLock;
78
79use crate::{AirError, Result, dev_log};
80
/// In-memory store and factory for trace spans.
///
/// Both fields sit behind `Arc<RwLock<..>>`, so clones of a `TraceGenerator`
/// share the same span map and the same sampling configuration.
#[derive(Debug, Clone)]
pub struct TraceGenerator {
    /// Recorded spans, keyed by span id.
    trace_spans:Arc<RwLock<HashMap<String, TraceSpan>>>,
    /// Sampling rates and retention limits consulted by this generator.
    sampling_config:Arc<RwLock<SamplingConfig>>,
}
87
/// Tunable limits for trace sampling and span retention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingConfig {
    /// Threshold compared against a random draw by `should_sample` for
    /// non-critical operations; `validate` requires `0.0..=1.0`.
    pub sample_rate:f64,
    /// Threshold compared against a random draw by `should_sample` for
    /// critical operations; `validate` requires `0.0..=1.0`.
    pub critical_sample_rate:f64,
    /// Upper bound on stored spans sharing one trace id; `create_span`
    /// rejects new spans once it is reached. `validate` requires non-zero.
    pub max_spans_per_trace:usize,
    /// Default age threshold used by `cleanup_old_spans` when the caller
    /// passes no explicit value. `validate` requires non-zero.
    pub trace_ttl_ms:u64,
}
100
101impl Default for SamplingConfig {
102 fn default() -> Self {
103 Self {
104 sample_rate:0.1, critical_sample_rate:1.0, max_spans_per_trace:1000,
107 trace_ttl_ms:3600000, }
109 }
110}
111
112impl SamplingConfig {
113 pub fn validate(&self) -> Result<()> {
115 if self.sample_rate < 0.0 || self.sample_rate > 1.0 {
116 return Err(crate::AirError::Internal("sample_rate must be between 0.0 and 1.0".to_string()));
117 }
118 if self.critical_sample_rate < 0.0 || self.critical_sample_rate > 1.0 {
119 return Err(crate::AirError::Internal(
120 "critical_sample_rate must be between 0.0 and 1.0".to_string(),
121 ));
122 }
123 if self.max_spans_per_trace == 0 {
124 return Err(crate::AirError::Internal(
125 "max_spans_per_trace must be greater than 0".to_string(),
126 ));
127 }
128 if self.trace_ttl_ms == 0 {
129 return Err(crate::AirError::Internal("trace_ttl_ms must be greater than 0".to_string()));
130 }
131 Ok(())
132 }
133}
134
/// A single recorded operation within a trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceSpan {
    /// Identifier of this span; also its key in the generator's span map.
    pub span_id:String,
    /// Identifier of the trace this span belongs to.
    pub trace_id:String,
    /// Identifier of the parent span; `None` marks a span without a parent.
    pub parent_span_id:Option<String>,
    /// Name of the operation being traced.
    pub operation_name:String,
    /// Value of `crate::Utility::CurrentTimestamp()` when the span was created
    /// (presumably epoch milliseconds, given `duration_ms` — TODO confirm).
    pub start_time:u64,
    /// Timestamp recorded by `complete_span`; `None` while the span is open.
    pub end_time:Option<u64>,
    /// Current lifecycle state.
    pub status:SpanStatus,
    /// Key/value annotations attached to the span.
    pub attributes:HashMap<String, String>,
    /// Events appended through `add_span_event`, in insertion order.
    pub events:Vec<SpanEvent>,
    /// Sanitized error text, set when the span is completed with an error.
    pub error:Option<String>,
    /// Saturating difference between end and start time, set by `complete_span`.
    pub duration_ms:Option<u64>,
}
150
/// Lifecycle state of a [`TraceSpan`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SpanStatus {
    /// Initial state assigned by `create_span`.
    Started,
    /// Set by `mark_span_active`.
    Active,
    /// Set by `complete_span` when no error is supplied.
    Completed,
    /// Set by `complete_span` when an error is supplied.
    Failed,
    /// The span was cancelled.
    Cancelled,
}
160
/// A timestamped, named occurrence attached to a span.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpanEvent {
    /// Value of `crate::Utility::CurrentTimestamp()` when the event was added.
    pub timestamp:u64,
    /// Event name supplied by the caller.
    pub name:String,
    /// Event attributes after passing through `sanitize_attributes`.
    pub attributes:HashMap<String, String>,
}
168
/// Summary of one trace, derived from its stored spans by `get_trace_metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceMetadata {
    /// Identifier of the summarized trace.
    pub trace_id:String,
    /// Identifier of the span chosen as the trace's root.
    pub root_span_id:String,
    /// Number of stored spans carrying this trace id.
    pub total_spans:usize,
    /// Operation name of the root span.
    pub root_operation:String,
    /// Start time of the root span.
    pub start_time:u64,
    /// Latest `end_time` among the trace's spans; `None` if none has ended.
    pub end_time:Option<u64>,
    /// Largest `duration_ms` among the trace's spans; `None` if none has one.
    pub total_duration_ms:Option<u64>,
    /// Aggregate state computed from the spans' statuses.
    pub status:TraceStatus,
}
181
/// Aggregate state of a trace, as reported in [`TraceMetadata`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TraceStatus {
    /// The trace has not reached a final state.
    InProgress,
    /// Every span of the trace completed.
    Completed,
    /// At least one span of the trace failed.
    Failed,
    /// The trace was cancelled.
    Cancelled,
}
190
/// Identifiers carried across call boundaries to tie spans to one trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PropagationContext {
    /// Identifier of the trace being propagated.
    pub TraceId:String,
    /// Span identifier generated for this context.
    pub SpanId:String,
    /// Identifier obtained from `crate::Utility::GenerateRequestId()`.
    pub CorrelationId:String,
    /// Identifier of the upstream span, when one is known.
    pub ParentSpanId:Option<String>,
}
199
200impl TraceGenerator {
201 pub fn new() -> Self {
203 Self {
204 trace_spans:Arc::new(RwLock::new(HashMap::new())),
205 sampling_config:Arc::new(RwLock::new(SamplingConfig::default())),
206 }
207 }
208
209 pub fn with_sampling(sampling_config:SamplingConfig) -> Result<Self> {
211 sampling_config
212 .validate()
213 .map_err(|e| AirError::Internal(format!("Invalid sampling config: {}", e)))?;
214
215 Ok(Self {
216 trace_spans:Arc::new(RwLock::new(HashMap::new())),
217 sampling_config:Arc::new(RwLock::new(sampling_config)),
218 })
219 }
220
221 pub fn generate_trace_id() -> String {
223 std::panic::catch_unwind(|| uuid::Uuid::new_v4().to_string()).unwrap_or_else(|e| {
224 dev_log!("air", "error: [Tracing] Panic in generate_trace_id, using fallback: {:?}", e);
225 format!("{:x}", rand::random::<u64>())
226 })
227 }
228
229 pub fn generate_span_id() -> String {
231 std::panic::catch_unwind(|| uuid::Uuid::new_v4().to_string()).unwrap_or_else(|e| {
232 dev_log!("air", "error: [Tracing] Panic in generate_span_id, using fallback: {:?}", e);
233 format!("{:x}", rand::random::<u64>())
234 })
235 }
236
237 pub async fn should_sample(&self, is_critical:bool) -> bool {
239 let config = self.sampling_config.read().await;
240 let rate = if is_critical { config.critical_sample_rate } else { config.sample_rate };
241
242 rand::random::<f64>() < rate
243 }
244
245 pub fn parse_trace_context(header:&str) -> Result<PropagationContext> {
247 let parts:Vec<&str> = header.split(';').collect();
248
249 let mut trace_id = String::new();
250 let mut parent_span_id = None;
251
252 for part in parts {
253 let kv:Vec<&str> = part.split('=').collect();
254 if kv.len() != 2 {
255 continue;
256 }
257
258 match kv[0].trim() {
259 "traceparent" => {
260 let trace_parent:Vec<&str> = kv[1].trim().split('-').collect();
261 if trace_parent.len() >= 2 {
262 trace_id = trace_parent[1].to_string();
263 if trace_parent.len() >= 3 {
264 parent_span_id = Some(trace_parent[2].to_string());
265 }
266 }
267 },
268 _ => {},
269 }
270 }
271
272 if trace_id.is_empty() {
273 return Err(AirError::Internal("Invalid trace context header".to_string()));
274 }
275
276 Ok(PropagationContext {
277 TraceId:trace_id,
278 SpanId:Self::generate_span_id(),
279 CorrelationId:crate::Utility::GenerateRequestId(),
280 ParentSpanId:parent_span_id,
281 })
282 }
283
284 pub async fn create_span(
286 &self,
287 trace_id:String,
288 operation_name:impl Into<String>,
289 parent_span_id:Option<String>,
290 attributes:Option<HashMap<String, String>>,
291 ) -> Result<TraceSpan> {
292 let span_id = Self::generate_span_id();
293 let operation_name = operation_name.into();
294
295 let span = TraceSpan {
296 span_id:span_id.clone(),
297 trace_id:trace_id.clone(),
298 parent_span_id:parent_span_id.clone(),
299 operation_name:operation_name.clone(),
300 start_time:crate::Utility::CurrentTimestamp(),
301 end_time:None,
302 status:SpanStatus::Started,
303 attributes:attributes.unwrap_or_default(),
304 events:Vec::new(),
305 error:None,
306 duration_ms:None,
307 };
308
309 let mut spans = self.trace_spans.write().await;
310
311 let trace_span_count = spans.values().filter(|s| s.trace_id == trace_id).count();
313
314 let config = self.sampling_config.read().await;
315 if trace_span_count >= config.max_spans_per_trace {
316 dev_log!(
317 "air",
318 "warn: [Tracing] Trace {} exceeds max spans ({}), dropping span {}",
319 trace_id,
320 config.max_spans_per_trace,
321 span_id
322 );
323 return Err(AirError::Internal("Max spans per trace exceeded".to_string()));
324 }
325
326 spans.insert(span_id.clone(), span.clone());
327
328 Ok(span)
329 }
330
331 pub async fn add_span_event(
333 &self,
334 span_id:&str,
335 event_name:impl Into<String>,
336 attributes:HashMap<String, String>,
337 ) -> Result<()> {
338 let event = SpanEvent {
339 timestamp:crate::Utility::CurrentTimestamp(),
340 name:event_name.into(),
341 attributes:self.sanitize_attributes(attributes),
342 };
343
344 let mut spans = self.trace_spans.write().await;
345 if let Some(span) = spans.get_mut(span_id) {
346 span.events.push(event);
347 Ok(())
348 } else {
349 Err(AirError::Internal(format!("Span not found: {}", span_id)))
350 }
351 }
352
353 pub async fn mark_span_active(&self, span_id:&str) -> Result<()> {
355 let mut spans = self.trace_spans.write().await;
356 if let Some(span) = spans.get_mut(span_id) {
357 span.status = SpanStatus::Active;
358 Ok(())
359 } else {
360 Err(AirError::Internal(format!("Span not found: {}", span_id)))
361 }
362 }
363
364 pub async fn complete_span(&self, span_id:&str, error:Option<String>) -> Result<u64> {
366 let Now = crate::Utility::CurrentTimestamp();
367 let mut spans = self.trace_spans.write().await;
368
369 if let Some(span) = spans.get_mut(span_id) {
370 span.end_time = Some(Now);
371 span.duration_ms = Some(Now.saturating_sub(span.start_time));
372 span.status = if error.is_some() { SpanStatus::Failed } else { SpanStatus::Completed };
373 span.error = error.map(|e| self.sanitize_error_message(&e));
374 Ok(span.duration_ms.unwrap_or(0))
375 } else {
376 Err(AirError::Internal(format!("Span not found: {}", span_id)))
377 }
378 }
379
380 pub async fn add_span_attribute(&self, span_id:&str, key:String, value:String) -> Result<()> {
382 self.add_span_attributes(span_id, HashMap::from([(key, value)])).await
383 }
384
385 pub async fn add_span_attributes(&self, span_id:&str, attributes:HashMap<String, String>) -> Result<()> {
387 let sanitized = self.sanitize_attributes(attributes);
388 let mut spans = self.trace_spans.write().await;
389
390 if let Some(span) = spans.get_mut(span_id) {
391 for (key, value) in sanitized {
392 span.attributes.insert(key, value);
393 }
394 Ok(())
395 } else {
396 Err(AirError::Internal(format!("Span not found: {}", span_id)))
397 }
398 }
399
400 pub async fn get_span(&self, span_id:&str) -> Result<TraceSpan> {
402 let spans = self.trace_spans.read().await;
403 spans
404 .get(span_id)
405 .cloned()
406 .ok_or_else(|| AirError::Internal(format!("Span not found: {}", span_id)))
407 }
408
409 pub async fn get_trace_spans(&self, trace_id:&str) -> Result<Vec<TraceSpan>> {
411 let spans = self.trace_spans.read().await;
412 Ok(spans.values().filter(|span| span.trace_id == trace_id).cloned().collect())
413 }
414
415 pub async fn get_trace_metadata(&self, trace_id:&str) -> Result<TraceMetadata> {
417 let trace_spans = self.get_trace_spans(trace_id).await?;
418
419 if trace_spans.is_empty() {
420 return Err(AirError::Internal(format!("Trace not found: {}", trace_id)));
421 }
422
423 let root_span = trace_spans
424 .iter()
425 .find(|s| s.parent_span_id.is_none())
426 .ok_or_else(|| AirError::Internal("No root span found".to_string()))?;
427
428 let total_duration_ms = trace_spans.iter().filter_map(|s| s.duration_ms).max();
429
430 let status = if trace_spans.iter().any(|s| s.status == SpanStatus::Failed) {
431 TraceStatus::Failed
432 } else if trace_spans
433 .iter()
434 .all(|s| s.status == SpanStatus::Completed || s.status == SpanStatus::Failed)
435 {
436 TraceStatus::Completed
437 } else {
438 TraceStatus::InProgress
439 };
440
441 let end_time = trace_spans.iter().filter_map(|s| s.end_time).max();
442
443 Ok(TraceMetadata {
444 trace_id:trace_id.to_string(),
445 root_span_id:root_span.span_id.clone(),
446 total_spans:trace_spans.len(),
447 root_operation:root_span.operation_name.clone(),
448 start_time:root_span.start_time,
449 end_time,
450 total_duration_ms,
451 status,
452 })
453 }
454
455 pub async fn export_trace(&self, trace_id:&str) -> Result<String> {
457 let spans = self.get_trace_spans(trace_id).await?;
458 let metadata = self.get_trace_metadata(trace_id).await?;
459
460 let export = serde_json::json!({
461 "metadata": metadata,
462 "spans": spans,
463 });
464
465 serde_json::to_string_pretty(&export)
466 .map_err(|e| AirError::Serialization(format!("Failed to export trace: {}", e)))
467 }
468
469 pub async fn cleanup_old_spans(&self, older_than_ms:Option<u64>) -> Result<usize> {
471 let Now = crate::Utility::CurrentTimestamp();
472 let ttl = older_than_ms.unwrap_or_else(|| {
473 tokio::task::block_in_place(|| {
474 tokio::runtime::Handle::current().block_on(async { self.sampling_config.read().await.trace_ttl_ms })
475 })
476 });
477
478 let mut spans = self.trace_spans.write().await;
479 let original_len = spans.len();
480
481 spans.retain(|_, span| span.end_time.map_or(true, |end| Now.saturating_sub(end) < ttl));
482
483 Ok(original_len.saturating_sub(spans.len()))
484 }
485
486 pub async fn get_statistics(&self) -> TraceStatistics {
488 let spans = self.trace_spans.read().await;
489
490 let total_traces = spans
491 .values()
492 .map(|s| s.trace_id.clone())
493 .collect::<std::collections::HashSet<_>>()
494 .len();
495
496 let completed_spans = spans.values().filter(|s| s.status == SpanStatus::Completed).count();
497
498 let failed_spans = spans.values().filter(|s| s.status == SpanStatus::Failed).count();
499
500 let in_progress_spans = spans
501 .values()
502 .filter(|s| s.status == SpanStatus::Started || s.status == SpanStatus::Active)
503 .count();
504
505 TraceStatistics {
506 total_traces:total_traces as u64,
507 total_spans:spans.len() as u64,
508 completed_spans:completed_spans as u64,
509 failed_spans:failed_spans as u64,
510 in_progress_spans:in_progress_spans as u64,
511 }
512 }
513
514 fn sanitize_attributes(&self, mut attributes:HashMap<String, String>) -> HashMap<String, String> {
516 let sensitive_keys = vec![
517 "password",
518 "token",
519 "secret",
520 "api_key",
521 "authorization",
522 "credential",
523 "auth",
524 "private_key",
525 "session_token",
526 ];
527
528 let attr_keys:Vec<String> = attributes.keys().cloned().collect();
530
531 for key in sensitive_keys {
532 let key_lower = key.to_lowercase();
533 for attr_key in &attr_keys {
534 if attr_key.to_lowercase().contains(&key_lower) {
535 attributes.insert(attr_key.clone(), "[REDACTED]".to_string());
536 }
537 }
538 }
539
540 attributes
541 }
542
543 fn sanitize_error_message(&self, message:&str) -> String {
545 let mut sanitized = message.to_string();
546
547 let patterns = vec![
548 (r"(?i)password[=:]\S+", "password=[REDACTED]"),
549 (r"(?i)token[=:]\S+", "token=[REDACTED]"),
550 (r"(?i)secret[=:]\S+", "secret=[REDACTED]"),
551 (r"(?i)(api|private)[_-]?key[=:]\S+", "api_key=[REDACTED]"),
552 (
553 r"(?i)authorization[=[:space:]]+Bearer[[:space:]]+\S+",
554 "Authorization: Bearer [REDACTED]",
555 ),
556 ];
557
558 for (pattern, replacement) in patterns {
559 if let Ok(re) = regex::Regex::new(pattern) {
560 sanitized = re.replace_all(&sanitized, replacement).to_string();
561 }
562 }
563
564 sanitized
565 }
566}
567
/// Snapshot of span-store counters produced by `get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceStatistics {
    /// Number of distinct trace ids among the stored spans.
    pub total_traces:u64,
    /// Number of stored spans.
    pub total_spans:u64,
    /// Spans whose status is `Completed`.
    pub completed_spans:u64,
    /// Spans whose status is `Failed`.
    pub failed_spans:u64,
    /// Spans whose status is `Started` or `Active`.
    pub in_progress_spans:u64,
}
577
578impl Default for TraceGenerator {
579 fn default() -> Self { Self::new() }
580}
581
/// Process-wide [`TraceGenerator`] slot. It can be filled at most once, either
/// by `initialize_tracing` or lazily by `get_trace_generator`.
static TRACE_GENERATOR:std::sync::OnceLock<TraceGenerator> = std::sync::OnceLock::new();
584
585pub fn get_trace_generator() -> &'static TraceGenerator { TRACE_GENERATOR.get_or_init(TraceGenerator::new) }
587
588pub fn initialize_tracing(sampling_config:Option<SamplingConfig>) -> Result<()> {
590 let generator = if let Some(config) = sampling_config {
591 TraceGenerator::with_sampling(config)?
592 } else {
593 TraceGenerator::new()
594 };
595
596 let _old = TRACE_GENERATOR.set(generator);
597 dev_log!("air", "[Tracing] Trace generator initialized with tracing");
598 Ok(())
599}
600
thread_local! {
    /// Propagation context most recently stored on the current thread via
    /// `set_propagation_context`; `None` until one is set.
    // NOTE(review): the value is per OS thread, so an async task that resumes
    // on a different worker thread would presumably not see it — confirm how
    // callers use this under a multi-threaded runtime.
    static PROPAGATION_CONTEXT: std::cell::RefCell<Option<PropagationContext>> = std::cell::RefCell::new(None);
}
604
605pub fn set_propagation_context(context:PropagationContext) {
607 PROPAGATION_CONTEXT.with(|ctx| {
608 *ctx.borrow_mut() = Some(context);
609 });
610}
611
612pub fn get_propagation_context() -> Option<PropagationContext> { PROPAGATION_CONTEXT.with(|ctx| ctx.borrow().clone()) }
614
615pub async fn create_propagation_context(TraceId:String, ParentSpanId:Option<String>) -> PropagationContext {
617 let SpanId = TraceGenerator::generate_span_id();
618 let CorrelationId = crate::Utility::GenerateRequestId();
619
620 PropagationContext { TraceId, SpanId, CorrelationId, ParentSpanId }
621}
622
623pub fn create_trace_context_header(context:&PropagationContext) -> String {
625 format!("traceparent=00-{}-{}-01", context.TraceId, context.SpanId)
626}