Skip to main content

Mountain/IPC/Enhanced/
SecureMessageChannel.rs

//! # Secure Message Channel
//!
//! Encryption and authentication for IPC messages: AES-256-GCM encryption,
//! HMAC-SHA256 authentication, and periodic key rotation.
5
6use std::{
7	collections::HashMap,
8	marker::PhantomData,
9	sync::Arc,
10	time::{Duration, SystemTime},
11};
12
13use ring::{
14	aead::{self, AES_256_GCM, LessSafeKey, NONCE_LEN, UnboundKey},
15	hmac,
16	rand::{SecureRandom, SystemRandom},
17};
18use serde::{Deserialize, Serialize};
19use tokio::sync::RwLock;
20use bincode::serde::{decode_from_slice, encode_to_vec};
21
22use crate::dev_log;
23
/// Security configuration
///
/// Tunables for [`SecureMessageChannel`]. Within this file only `key_rotation_interval_hours`,
/// `nonce_size_bytes` and `max_message_size_bytes` influence behavior. The algorithm names and
/// the tag size are descriptive: the channel always uses AES-256-GCM and HMAC-SHA256.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
	/// Name of the encryption algorithm; only used in log output here.
	pub encryption_algorithm:String,
	/// Rotation period, in hours, of the background task started by `start()`.
	/// Retired keys are dropped once they are twice this old (measured from creation).
	pub key_rotation_interval_hours:u64,
	/// Name of the MAC algorithm; not read by this file.
	pub hmac_algorithm:String,
	/// Nonce size in bytes. Encryption fails for any value other than ring's `NONCE_LEN` (12).
	pub nonce_size_bytes:usize,
	/// AEAD authentication-tag size in bytes; informational, not read by this file.
	pub auth_tag_size_bytes:usize,
	/// Largest serialized plaintext, in bytes, that `encrypt_message` accepts.
	pub max_message_size_bytes:usize,
}
34
35impl Default for SecurityConfig {
36	fn default() -> Self {
37		Self {
38			encryption_algorithm:"AES-256-GCM".to_string(),
39			// Rotate encryption keys every 24 hours for forward secrecy.
40			key_rotation_interval_hours:24,
41			hmac_algorithm:"HMAC-SHA256".to_string(),
42			nonce_size_bytes:NONCE_LEN,
43			auth_tag_size_bytes:AES_256_GCM.tag_len(),
44			// Maximum message size: 10MB to prevent memory exhaustion attacks.
45			max_message_size_bytes:10 * 1024 * 1024,
46		}
47	}
48}
49
/// Encryption key with metadata
///
/// An AES-256-GCM key plus the bookkeeping the channel uses for rotation and statistics.
#[derive(Debug, Clone)]
struct EncryptionKey {
	/// ring key object used to seal and open messages.
	key:LessSafeKey,
	/// Construction time; retired keys are expired relative to this instant.
	created_at:SystemTime,
	/// Random identifier (8 bytes, hex-encoded) carried in every `EncryptedMessage`.
	key_id:String,
	/// Number of `encrypt_message` calls that used this key.
	usage_count:usize,
}
58
59impl EncryptionKey {
60	fn new(key_bytes:&[u8]) -> Result<Self, String> {
61		let unbound_key =
62			UnboundKey::new(&AES_256_GCM, key_bytes).map_err(|e| format!("Failed to create unbound key: {}", e))?;
63
64		Ok(Self {
65			key:LessSafeKey::new(unbound_key),
66			created_at:SystemTime::now(),
67			key_id:Self::generate_key_id(),
68			usage_count:0,
69		})
70	}
71
72	fn generate_key_id() -> String {
73		let rng = SystemRandom::new();
74		let mut id_bytes = [0u8; 8];
75		rng.fill(&mut id_bytes).unwrap();
76		hex::encode(id_bytes)
77	}
78
79	fn is_expired(&self, rotation_interval:Duration) -> bool {
80		self.created_at.elapsed().unwrap_or_default() > rotation_interval
81	}
82
83	fn increment_usage(&mut self) { self.usage_count += 1; }
84}
85
/// Secure message channel with encryption and authentication
///
/// Holds the following state:
/// - the current AES-256-GCM key;
/// - recently retired keys, so messages encrypted just before a rotation still decrypt;
/// - the HMAC-SHA256 key;
/// - the handle of the background rotation task.
///
/// Key state lives behind `Arc<RwLock<_>>`, and the `Clone` impl clones those `Arc`s, so
/// clones share keys.
///
/// NOTE(review): every field, including the raw HMAC key bytes, is `pub`. Any code holding a
/// channel can therefore read or replace the keys. Consider narrowing visibility.
pub struct SecureMessageChannel {
	/// Settings the channel was created with.
	pub config:SecurityConfig,
	/// Key used by `encrypt_message` for new messages.
	pub current_key:Arc<RwLock<EncryptionKey>>,
	/// Retired keys indexed by `key_id`, consulted when decrypting older messages.
	pub previous_keys:Arc<RwLock<HashMap<String, EncryptionKey>>>,
	/// Raw HMAC-SHA256 key bytes (32 bytes when generated by `new`).
	pub hmac_key:Arc<RwLock<Vec<u8>>>,
	/// Source of nonces and key material.
	pub rng:SystemRandom,
	/// Handle of the rotation task spawned by `start`, if one is running.
	pub key_rotation_task:Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
}
95
96impl SecureMessageChannel {
97	/// Create a new secure message channel
98	pub fn new(config:SecurityConfig) -> Result<Self, String> {
99		let rng = SystemRandom::new();
100
101		// Generate encryption key
102		let mut encryption_key_bytes = vec![0u8; 32];
103		rng.fill(&mut encryption_key_bytes)
104			.map_err(|e| format!("Failed to generate encryption key: {}", e))?;
105
106		let encryption_key = EncryptionKey::new(&encryption_key_bytes)?;
107
108		// Generate HMAC key
109		let mut hmac_key = vec![0u8; 32];
110		rng.fill(&mut hmac_key)
111			.map_err(|e| format!("Failed to generate HMAC key: {}", e))?;
112
113		let channel = Self {
114			config,
115			current_key:Arc::new(RwLock::new(encryption_key)),
116			previous_keys:Arc::new(RwLock::new(HashMap::new())),
117			hmac_key:Arc::new(RwLock::new(hmac_key)),
118			rng,
119			key_rotation_task:Arc::new(RwLock::new(None)),
120		};
121
122		dev_log!(
123			"ipc",
124			"[SecureMessageChannel] Created secure channel with {} encryption",
125			channel.config.encryption_algorithm
126		);
127
128		Ok(channel)
129	}
130
131	/// Start the secure channel with automatic key rotation
132	pub async fn start(&self) -> Result<(), String> {
133		// Start key rotation task
134		self.start_key_rotation().await;
135
136		dev_log!("ipc", "[SecureMessageChannel] Secure channel started");
137		Ok(())
138	}
139
140	/// Stop the secure channel
141	pub async fn stop(&self) -> Result<(), String> {
142		// Stop key rotation task
143		{
144			let mut rotation_task = self.key_rotation_task.write().await;
145			if let Some(task) = rotation_task.take() {
146				task.abort();
147			}
148		}
149
150		// Clear all cryptographic keys from memory.
151		{
152			let mut current_key = self.current_key.write().await;
153			// Replace with a zeroized key to overwrite sensitive material.
154			*current_key = EncryptionKey::new(&[0u8; 32]).unwrap();
155		}
156
157		{
158			let mut previous_keys = self.previous_keys.write().await;
159			previous_keys.clear();
160		}
161
162		{
163			let mut hmac_key = self.hmac_key.write().await;
164			// Zero out the HMAC key material to prevent leakage.
165			hmac_key.fill(0);
166		}
167
168		dev_log!("ipc", "[SecureMessageChannel] Secure channel stopped");
169		Ok(())
170	}
171
172	/// Encrypt and authenticate a message
173	pub async fn encrypt_message<T:Serialize>(&self, message:&T) -> Result<EncryptedMessage, String> {
174		// Serialize message
175		let serialized_data = encode_to_vec(message, bincode::config::standard())
176			.map_err(|e| format!("Failed to serialize message: {}", e))?;
177
178		// Check message size
179		if serialized_data.len() > self.config.max_message_size_bytes {
180			return Err(format!("Message too large: {} bytes", serialized_data.len()));
181		}
182
183		// Get current encryption key
184		let mut current_key = self.current_key.write().await;
185		current_key.increment_usage();
186
187		// Generate nonce
188		let mut nonce = vec![0u8; self.config.nonce_size_bytes];
189		self.rng
190			.fill(&mut nonce)
191			.map_err(|e| format!("Failed to generate nonce: {}", e))?;
192
193		// Encrypt message
194		let mut in_out = serialized_data.clone();
195		let nonce_slice:&[u8] = &nonce;
196		let nonce_array:[u8; NONCE_LEN] = nonce_slice.try_into().map_err(|_| "Invalid nonce length".to_string())?;
197
198		let aead_nonce = aead::Nonce::assume_unique_for_key(nonce_array);
199
200		current_key
201			.key
202			.seal_in_place_append_tag(aead_nonce, aead::Aad::empty(), &mut in_out)
203			.map_err(|e| format!("Encryption failed: {}", e))?;
204
205		// Create HMAC
206		let hmac_key = self.hmac_key.read().await;
207		let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &hmac_key);
208		let hmac_tag = hmac::sign(&hmac_key, &in_out);
209
210		let encrypted_message = EncryptedMessage {
211			key_id:current_key.key_id.clone(),
212			nonce:nonce.to_vec(),
213			ciphertext:in_out,
214			hmac_tag:hmac_tag.as_ref().to_vec(),
215			timestamp:SystemTime::now()
216				.duration_since(SystemTime::UNIX_EPOCH)
217				.unwrap_or_default()
218				.as_millis() as u64,
219		};
220
221		dev_log!(
222			"ipc",
223			"[SecureMessageChannel] Message encrypted (size: {} bytes)",
224			encrypted_message.ciphertext.len()
225		);
226
227		Ok(encrypted_message)
228	}
229
230	/// Decrypt and verify a message
231	pub async fn decrypt_message<T:for<'de> Deserialize<'de>>(&self, encrypted:&EncryptedMessage) -> Result<T, String> {
232		// Verify HMAC
233		let hmac_key = self.hmac_key.read().await;
234		let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &hmac_key);
235
236		hmac::verify(&hmac_key, &encrypted.ciphertext, &encrypted.hmac_tag)
237			.map_err(|_| "HMAC verification failed".to_string())?;
238
239		// Get encryption key
240		let encryption_key = self.get_encryption_key(&encrypted.key_id).await?;
241
242		// Decrypt message
243		let mut in_out = encrypted.ciphertext.clone();
244		let nonce_slice:&[u8] = &encrypted.nonce;
245		let nonce_array:[u8; NONCE_LEN] = nonce_slice.try_into().map_err(|_| "Invalid nonce length".to_string())?;
246
247		let nonce = aead::Nonce::assume_unique_for_key(nonce_array);
248
249		encryption_key
250			.key
251			.open_in_place(nonce, aead::Aad::empty(), &mut in_out)
252			.map_err(|e| format!("Decryption failed: {}", e))?;
253
254		// Remove authentication tag
255		let plaintext_len = in_out.len() - AES_256_GCM.tag_len();
256		in_out.truncate(plaintext_len);
257
258		// Deserialize message
259		let (message, _) = decode_from_slice(&in_out, bincode::config::standard())
260			.map_err(|e| format!("Failed to deserialize message: {}", e))?;
261
262		dev_log!("ipc", "[SecureMessageChannel] Message decrypted successfully");
263
264		Ok(message)
265	}
266
267	/// Rotate encryption keys
268	pub async fn rotate_keys(&self) -> Result<(), String> {
269		dev_log!("ipc", "[SecureMessageChannel] Rotating encryption keys");
270
271		// Generate new encryption key
272		let mut new_key_bytes = vec![0u8; 32];
273		self.rng
274			.fill(&mut new_key_bytes)
275			.map_err(|e| format!("Failed to generate new encryption key: {}", e))?;
276
277		let new_key = EncryptionKey::new(&new_key_bytes)?;
278
279		// Move current key to previous keys
280		{
281			let mut current_key = self.current_key.write().await;
282			let mut previous_keys = self.previous_keys.write().await;
283
284			previous_keys.insert(current_key.key_id.clone(), current_key.clone());
285			*current_key = new_key;
286		}
287
288		// Clean up old keys
289		self.cleanup_old_keys().await;
290
291		dev_log!("ipc", "[SecureMessageChannel] Key rotation completed");
292		Ok(())
293	}
294
295	/// Get encryption key by ID
296	async fn get_encryption_key(&self, key_id:&str) -> Result<EncryptionKey, String> {
297		// Check current key first
298		let current_key = self.current_key.read().await;
299		if current_key.key_id == key_id {
300			return Ok(current_key.clone());
301		}
302
303		// Check previous keys
304		let previous_keys = self.previous_keys.read().await;
305		if let Some(key) = previous_keys.get(key_id) {
306			return Ok(key.clone());
307		}
308
309		Err(format!("Encryption key not found: {}", key_id))
310	}
311
312	/// Start automatic key rotation
313	async fn start_key_rotation(&self) {
314		let channel = Arc::new(self.clone());
315
316		let rotation_interval = Duration::from_secs(self.config.key_rotation_interval_hours * 3600);
317
318		let task = tokio::spawn(async move {
319			let mut interval = tokio::time::interval(rotation_interval);
320
321			loop {
322				interval.tick().await;
323
324				if let Err(e) = channel.rotate_keys().await {
325					dev_log!("ipc", "error: [SecureMessageChannel] Automatic key rotation failed: {}", e);
326				}
327			}
328		});
329
330		{
331			let mut rotation_task = self.key_rotation_task.write().await;
332			*rotation_task = Some(task);
333		}
334	}
335
336	/// Cleanup old keys
337	async fn cleanup_old_keys(&self) {
338		let rotation_interval = Duration::from_secs(self.config.key_rotation_interval_hours * 3600);
339		// Keep previous keys for 2 rotation cycles to support key rollover.
340		let max_age = rotation_interval * 2;
341
342		let mut previous_keys = self.previous_keys.write().await;
343
344		previous_keys.retain(|_, key| !key.is_expired(max_age));
345
346		dev_log!("ipc", "[SecureMessageChannel] Cleaned up {} old keys", previous_keys.len());
347	}
348
349	/// Get security statistics
350	pub async fn get_stats(&self) -> SecurityStats {
351		let current_key = self.current_key.read().await;
352		let previous_keys = self.previous_keys.read().await;
353
354		SecurityStats {
355			current_key_id:current_key.key_id.clone(),
356			current_key_age_seconds:current_key.created_at.elapsed().unwrap_or_default().as_secs(),
357			current_key_usage_count:current_key.usage_count,
358			previous_keys_count:previous_keys.len(),
359			config:self.config.clone(),
360		}
361	}
362
363	/// Validate message integrity
364	pub async fn validate_message_integrity(&self, encrypted:&EncryptedMessage) -> Result<bool, String> {
365		// Check timestamp (prevent replay attacks)
366		let message_time = SystemTime::UNIX_EPOCH + Duration::from_millis(encrypted.timestamp);
367		let current_time = SystemTime::now();
368
369		if current_time.duration_since(message_time).unwrap_or_default() > Duration::from_secs(300) {
370			// Message is older than 5 minutes
371			return Ok(false);
372		}
373
374		// Verify HMAC
375		let hmac_key = self.hmac_key.read().await;
376		let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &hmac_key);
377
378		match hmac::verify(&hmac_key, &encrypted.ciphertext, &encrypted.hmac_tag) {
379			Ok(_) => Ok(true),
380			Err(_) => Ok(false),
381		}
382	}
383
384	/// Create a secure channel with default configuration
385	pub fn default_channel() -> Result<Self, String> { Self::new(SecurityConfig::default()) }
386
387	/// Create a high-security channel
388	pub fn high_security_channel() -> Result<Self, String> {
389		Self::new(SecurityConfig {
390			// Rotate keys hourly for maximum security.
391			key_rotation_interval_hours:1,
392			// Smaller message size limit: 1MB for stricter controls.
393			max_message_size_bytes:1 * 1024 * 1024,
394			..Default::default()
395		})
396	}
397}
398
399impl Clone for SecureMessageChannel {
400	fn clone(&self) -> Self {
401		Self {
402			config:self.config.clone(),
403			current_key:self.current_key.clone(),
404			previous_keys:self.previous_keys.clone(),
405			hmac_key:self.hmac_key.clone(),
406			rng:SystemRandom::new(),
407			key_rotation_task:Arc::new(RwLock::new(None)),
408		}
409	}
410}
411
/// Encrypted message structure
///
/// Wire envelope produced by `SecureMessageChannel::encrypt_message`.
///
/// NOTE(review): the HMAC covers only `ciphertext`. Tampering with `nonce` or `key_id` still
/// makes decryption fail through AES-GCM, but `timestamp` can be altered undetected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedMessage {
	/// ID of the key that encrypted this message.
	pub key_id:String,
	/// AES-GCM nonce; `encrypt_message` generates `NONCE_LEN` (12) random bytes.
	pub nonce:Vec<u8>,
	/// Bincode-serialized plaintext encrypted with AES-256-GCM, with the GCM tag appended.
	pub ciphertext:Vec<u8>,
	/// HMAC-SHA256 tag computed over `ciphertext`.
	pub hmac_tag:Vec<u8>,
	/// Creation time in milliseconds since the Unix epoch.
	pub timestamp:u64,
}
421
/// Security statistics
///
/// Snapshot returned by `SecureMessageChannel::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityStats {
	/// `key_id` of the key currently used for encryption.
	pub current_key_id:String,
	/// Whole seconds since the current key was created (0 if the clock went backwards).
	pub current_key_age_seconds:u64,
	/// Number of messages encrypted with the current key so far.
	pub current_key_usage_count:usize,
	/// Number of retired keys still kept for decrypting older messages.
	pub previous_keys_count:usize,
	/// Copy of the channel's configuration.
	pub config:SecurityConfig,
}
431
432/// Utility functions for secure messaging
433impl SecureMessageChannel {
434	/// Generate a secure random key
435	pub fn generate_secure_key(key_size_bytes:usize) -> Result<Vec<u8>, String> {
436		let rng = SystemRandom::new();
437		let mut key = vec![0u8; key_size_bytes];
438
439		rng.fill(&mut key)
440			.map_err(|e| format!("Failed to generate secure key: {}", e))?;
441
442		Ok(key)
443	}
444
445	/// Calculate message overhead for encryption
446	pub fn calculate_encryption_overhead(message_size:usize) -> usize {
447		// Nonce + HMAC tag + encryption overhead + additional padding.
448		NONCE_LEN + AES_256_GCM.tag_len() + 16
449	}
450
451	/// Estimate encrypted message size
452	pub fn estimate_encrypted_size(original_size:usize) -> usize {
453		original_size + Self::calculate_encryption_overhead(original_size)
454	}
455
456	/// Create message with secure headers
457	pub async fn create_secure_message<T:Serialize>(
458		&self,
459		message:&T,
460		additional_headers:HashMap<String, String>,
461	) -> Result<SecureMessage<T>, String> {
462		let encrypted = self.encrypt_message(message).await?;
463
464		Ok(SecureMessage::<T> {
465			encrypted,
466			headers:additional_headers,
467			version:"1.0".to_string(),
468			_marker:PhantomData,
469		})
470	}
471}
472
/// Secure message with headers
///
/// Wraps an [`EncryptedMessage`] together with plaintext headers and a format version.
/// It is built by `SecureMessageChannel::create_secure_message`. `T` records the payload
/// type at the type level only; no value of type `T` is stored.
///
/// NOTE(review): the derived `Debug`/`Clone` impls require `T: Debug`/`T: Clone` even though
/// `T` is only phantom here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecureMessage<T> {
	/// Encrypted payload.
	pub encrypted:EncryptedMessage,
	/// Caller-supplied headers, stored unencrypted.
	pub headers:HashMap<String, String>,
	/// Envelope format version; `create_secure_message` sets "1.0".
	pub version:String,
	/// Payload-type marker; skipped during (de)serialization.
	#[serde(skip)]
	_marker:PhantomData<T>,
}
482
#[cfg(test)]
mod tests {
	use super::*;

	#[tokio::test]
	async fn test_secure_channel_creation() {
		let channel = SecureMessageChannel::default_channel().unwrap();
		assert_eq!(channel.config.encryption_algorithm, "AES-256-GCM");
	}

	#[tokio::test]
	async fn test_message_encryption_decryption() {
		let channel = SecureMessageChannel::default_channel().unwrap();
		channel.start().await.unwrap();

		let test_message = "Hello, secure world!";
		let encrypted = channel.encrypt_message(&test_message).await.unwrap();

		assert!(!encrypted.ciphertext.is_empty());
		assert!(!encrypted.hmac_tag.is_empty());
		assert!(!encrypted.nonce.is_empty());

		let decrypted:String = channel.decrypt_message(&encrypted).await.unwrap();
		assert_eq!(decrypted, test_message);

		channel.stop().await.unwrap();
	}

	#[tokio::test]
	async fn test_message_validation() {
		let channel = SecureMessageChannel::default_channel().unwrap();
		channel.start().await.unwrap();

		let test_message = "Test validation";
		let encrypted = channel.encrypt_message(&test_message).await.unwrap();

		let is_valid = channel.validate_message_integrity(&encrypted).await.unwrap();
		assert!(is_valid);

		channel.stop().await.unwrap();
	}

	#[tokio::test]
	async fn test_tampered_ciphertext_is_rejected() {
		let channel = SecureMessageChannel::default_channel().unwrap();

		let mut encrypted = channel.encrypt_message(&"tamper me").await.unwrap();
		encrypted.ciphertext[0] ^= 0x01;

		assert!(!channel.validate_message_integrity(&encrypted).await.unwrap());
		assert!(channel.decrypt_message::<String>(&encrypted).await.is_err());
	}

	#[tokio::test]
	async fn test_key_rotation() {
		// Rotation is driven manually. `start()` is deliberately not called, so a background
		// rotation cannot race with the key-count assertion below.
		let channel = SecureMessageChannel::default_channel().unwrap();

		let stats_before = channel.get_stats().await;

		channel.rotate_keys().await.unwrap();

		let stats_after = channel.get_stats().await;
		assert_ne!(stats_before.current_key_id, stats_after.current_key_id);
		assert_eq!(stats_after.previous_keys_count, 1);
	}

	#[tokio::test]
	async fn test_decrypt_after_rotation() {
		let channel = SecureMessageChannel::default_channel().unwrap();

		let encrypted = channel.encrypt_message(&"before rotation").await.unwrap();
		channel.rotate_keys().await.unwrap();

		let decrypted:String = channel.decrypt_message(&encrypted).await.unwrap();
		assert_eq!(decrypted, "before rotation");
	}

	#[tokio::test]
	async fn test_stop_discards_old_keys() {
		let channel = SecureMessageChannel::default_channel().unwrap();

		let encrypted = channel.encrypt_message(&"secret").await.unwrap();
		channel.stop().await.unwrap();

		assert!(channel.decrypt_message::<String>(&encrypted).await.is_err());
	}

	#[test]
	fn test_secure_key_generation() {
		let key = SecureMessageChannel::generate_secure_key(32).unwrap();
		assert_eq!(key.len(), 32);
	}

	#[test]
	fn test_encryption_overhead_calculation() {
		let overhead = SecureMessageChannel::calculate_encryption_overhead(100);
		assert!(overhead > 0);

		let estimated_size = SecureMessageChannel::estimate_encrypted_size(100);
		assert!(estimated_size > 100);
		assert_eq!(estimated_size, 100 + overhead);
	}
}