1use std::{env, fs::File, io::BufReader, path::PathBuf, time::Duration};
39
40use tonic::transport::{Channel, Endpoint};
41#[cfg(feature = "mtls")]
42use rustls::ClientConfig;
43#[cfg(feature = "mtls")]
44use rustls::RootCertStore;
45
46use crate::dev_log;
47
/// Mountain gRPC address used by [`MountainClientConfig::default`] and as the fallback
/// when `MOUNTAIN_ADDRESS` is unset (IPv6 loopback, port 50051).
pub const DEFAULT_MOUNTAIN_ADDRESS:&str = "[::1]:50051";

/// Fallback for `MountainClientConfig::connection_timeout_secs`, in seconds.
pub const DEFAULT_CONNECTION_TIMEOUT_SECS:u64 = 5;

/// Fallback for `MountainClientConfig::request_timeout_secs`, in seconds.
pub const DEFAULT_REQUEST_TIMEOUT_SECS:u64 = 30;
61
/// TLS / mutual-TLS settings for the Mountain connection (only with the `mtls` feature).
#[cfg(feature = "mtls")]
#[derive(Debug, Clone)]
pub struct TlsConfig {
	/// PEM bundle of trusted CA certificates; when `None`, system roots are used instead.
	pub ca_cert_path:Option<PathBuf>,

	/// PEM client certificate chain, presented only together with `client_key_path`.
	pub client_cert_path:Option<PathBuf>,

	/// PEM private key matching `client_cert_path`.
	pub client_key_path:Option<PathBuf>,

	/// Server name used for SNI / certificate matching; `None` falls back to `"localhost"`.
	pub server_name:Option<String>,

	/// Requested certificate verification. Setting this to `false` only produces a log
	/// warning in `create_tls_client_config`; no verification is actually switched off.
	pub verify_certs:bool,
}
85
#[cfg(feature = "mtls")]
impl Default for TlsConfig {
	/// No explicit certificates and no server name, with verification requested.
	fn default() -> Self {
		let verify_certs = true;
		Self {
			verify_certs,
			server_name:None,
			client_key_path:None,
			client_cert_path:None,
			ca_cert_path:None,
		}
	}
}
98
#[cfg(feature = "mtls")]
impl TlsConfig {
	/// Server-authenticated TLS: trust the CA bundle at `ca_cert_path`, present no client
	/// certificate, and expect the server name `"localhost"`.
	pub fn server_auth(ca_cert_path:PathBuf) -> Self {
		let server_name = Some(String::from("localhost"));
		Self {
			server_name,
			ca_cert_path:Some(ca_cert_path),
			client_key_path:None,
			client_cert_path:None,
			verify_certs:true,
		}
	}

	/// Mutual TLS: everything [`TlsConfig::server_auth`] sets, plus a client certificate
	/// and private key presented to the server.
	pub fn mtls(ca_cert_path:PathBuf, client_cert_path:PathBuf, client_key_path:PathBuf) -> Self {
		Self {
			client_cert_path:Some(client_cert_path),
			client_key_path:Some(client_key_path),
			..Self::server_auth(ca_cert_path)
		}
	}
}
137
/// Builds the trust-anchor store for a TLS connection.
///
/// Uses the PEM bundle at `tls_config.ca_cert_path` when set. In that case every certificate
/// in the file must be accepted, and an empty file is an error. Otherwise it falls back to the
/// platform's native root certificates.
#[cfg(feature = "mtls")]
fn load_root_store(tls_config:&TlsConfig) -> Result<RootCertStore, Box<dyn std::error::Error>> {
	let mut root_store = RootCertStore::empty();

	if let Some(ca_path) = &tls_config.ca_cert_path {
		dev_log!("grpc", "Loading CA certificate from {:?}", ca_path);
		let ca_file = File::open(ca_path).map_err(|e| format!("Failed to open CA certificate file: {}", e))?;
		let mut reader = BufReader::new(ca_file);

		let certs = rustls_pemfile::certs(&mut reader)
			.collect::<Result<Vec<_>, _>>()
			.map_err(|e| format!("Failed to parse CA certificate: {}", e))?;

		if certs.is_empty() {
			return Err("No CA certificates found in file".into());
		}

		// An explicitly configured CA bundle must be usable in full, so any rejection is fatal.
		for cert in certs {
			root_store
				.add(cert)
				.map_err(|e| format!("Failed to add CA certificate to root store: {}", e))?;
		}

		dev_log!("grpc", "Loaded CA certificate from {:?}", ca_path);
	} else {
		dev_log!("grpc", "Loading system root certificates");
		let native = rustls_native_certs::load_native_certs();

		if !native.errors.is_empty() {
			dev_log!(
				"grpc",
				"warn: Encountered errors loading system certificates: {:?}",
				native.errors
			);
		}

		// System trust stores routinely contain legacy or malformed roots that webpki rejects.
		// Skip those instead of failing the whole connection over one unusable certificate.
		let (added, ignored) = root_store.add_parsable_certificates(native.certs);

		if ignored > 0 {
			dev_log!("grpc", "warn: Ignored {} unparsable system root certificates", ignored);
		}

		if added == 0 {
			dev_log!("grpc", "warn: No system root certificates found");
		}

		dev_log!("grpc", "Loaded {} system root certificates", root_store.len());
	}

	Ok(root_store)
}

/// Builds a rustls [`ClientConfig`] (ALPN `h2`) from `tls_config`.
///
/// The trust anchors come from the configured CA bundle, or from the system roots if none is
/// set. When both `client_cert_path` and `client_key_path` are set, the client certificate
/// chain and private key are loaded for mutual TLS.
///
/// # Errors
/// Returns an error if any configured file cannot be opened or parsed, or if a file contains
/// no certificates or key. It also errors if only one of `client_cert_path` / `client_key_path`
/// is set, or if rustls rejects the client identity.
#[cfg(feature = "mtls")]
pub fn create_tls_client_config(tls_config:&TlsConfig) -> Result<ClientConfig, Box<dyn std::error::Error>> {
	dev_log!("grpc", "Creating TLS client configuration");
	let root_store = load_root_store(tls_config)?;

	let client_certs = match (&tls_config.client_cert_path, &tls_config.client_key_path) {
		(Some(cert_path), Some(key_path)) => {
			dev_log!("grpc", "Loading client certificate from {:?}", cert_path);
			let cert_file =
				File::open(cert_path).map_err(|e| format!("Failed to open client certificate file: {}", e))?;
			let mut cert_reader = BufReader::new(cert_file);

			let certs = rustls_pemfile::certs(&mut cert_reader)
				.collect::<Result<Vec<_>, _>>()
				.map_err(|e| format!("Failed to parse client certificate: {}", e))?;

			if certs.is_empty() {
				return Err("No client certificates found in file".into());
			}

			dev_log!("grpc", "Loading client private key from {:?}", key_path);
			let key_file = File::open(key_path).map_err(|e| format!("Failed to open private key file: {}", e))?;
			let mut key_reader = BufReader::new(key_file);

			let key = rustls_pemfile::private_key(&mut key_reader)
				.map_err(|e| format!("Failed to parse private key: {}", e))?
				.ok_or("No private key found in file")?;

			Some((certs, key))
		},
		(None, None) => None,
		// Only half of the client identity was configured. Refuse rather than silently
		// downgrading to server-only TLS when mTLS was evidently intended.
		_ => return Err("Both client certificate and client key paths must be set for mTLS".into()),
	};

	let mut config = match client_certs {
		Some((certs, key)) => {
			let client_config = ClientConfig::builder()
				.with_root_certificates(root_store)
				.with_client_auth_cert(certs, key)
				.map_err(|e| format!("Failed to configure client authentication: {}", e))?;

			dev_log!("grpc", "Configured mTLS with client certificate");
			client_config
		},
		None => {
			let client_config = ClientConfig::builder().with_root_certificates(root_store).with_no_client_auth();

			dev_log!("grpc", "Configured TLS with server authentication only");
			client_config
		},
	};

	// gRPC runs over HTTP/2, so advertise only `h2` during ALPN.
	config.alpn_protocols = vec![b"h2".to_vec()];

	if !tls_config.verify_certs {
		// No custom certificate verifier is installed, so this flag cannot take effect. Report
		// that accurately instead of claiming verification was turned off.
		dev_log!(
			"grpc",
			"warn: verify_certs=false is not supported; certificate verification remains enabled"
		);
	}

	dev_log!("grpc", "TLS client configuration created successfully");
	Ok(config)
}
272
/// Settings used by [`MountainClient::connect`] to reach the Mountain gRPC server.
#[derive(Debug, Clone)]
pub struct MountainClientConfig {
	/// Server address, e.g. `[::1]:50051`.
	pub address:String,

	/// Seconds allowed for establishing the connection.
	pub connection_timeout_secs:u64,

	/// Seconds allowed per request; used as the timeout in `MountainClient::health_check`.
	pub request_timeout_secs:u64,

	/// TLS / mTLS settings; `None` means the connection is made without TLS.
	#[cfg(feature = "mtls")]
	pub tls_config:Option<TlsConfig>,
}
289
290impl Default for MountainClientConfig {
291 fn default() -> Self {
292 Self {
293 address:DEFAULT_MOUNTAIN_ADDRESS.to_string(),
294 connection_timeout_secs:DEFAULT_CONNECTION_TIMEOUT_SECS,
295 request_timeout_secs:DEFAULT_REQUEST_TIMEOUT_SECS,
296 #[cfg(feature = "mtls")]
297 tls_config:None,
298 }
299 }
300}
301
302impl MountainClientConfig {
303 pub fn new(address:impl Into<String>) -> Self { Self { address:address.into(), ..Default::default() } }
311
312 pub fn from_env() -> Self {
332 let address = env::var("MOUNTAIN_ADDRESS").unwrap_or_else(|_| DEFAULT_MOUNTAIN_ADDRESS.to_string());
333
334 let connection_timeout_secs = env::var("MOUNTAIN_CONNECTION_TIMEOUT_SECS")
335 .ok()
336 .and_then(|s| s.parse().ok())
337 .unwrap_or(DEFAULT_CONNECTION_TIMEOUT_SECS);
338
339 let request_timeout_secs = env::var("MOUNTAIN_REQUEST_TIMEOUT_SECS")
340 .ok()
341 .and_then(|s| s.parse().ok())
342 .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS);
343
344 #[cfg(feature = "mtls")]
345 let tls_config = if env::var("MOUNTAIN_TLS_ENABLED")
346 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
347 .unwrap_or(false)
348 {
349 Some(TlsConfig {
350 ca_cert_path:env::var("MOUNTAIN_CA_CERT").ok().map(PathBuf::from),
351 client_cert_path:env::var("MOUNTAIN_CLIENT_CERT").ok().map(PathBuf::from),
352 client_key_path:env::var("MOUNTAIN_CLIENT_KEY").ok().map(PathBuf::from),
353 server_name:env::var("MOUNTAIN_SERVER_NAME").ok(),
354 verify_certs:env::var("MOUNTAIN_VERIFY_CERTS")
355 .map(|v| v != "0" && !v.eq_ignore_ascii_case("false"))
356 .unwrap_or(true),
357 })
358 } else {
359 None
360 };
361
362 #[cfg(not(feature = "mtls"))]
363 let tls_config = None;
364
365 Self {
366 address,
367 connection_timeout_secs,
368 request_timeout_secs,
369 #[cfg(feature = "mtls")]
370 tls_config,
371 }
372 }
373
374 pub fn with_connection_timeout(mut self, timeout_secs:u64) -> Self {
382 self.connection_timeout_secs = timeout_secs;
383 self
384 }
385
386 pub fn with_request_timeout(mut self, timeout_secs:u64) -> Self {
394 self.request_timeout_secs = timeout_secs;
395 self
396 }
397
398 #[cfg(feature = "mtls")]
406 pub fn with_tls(mut self, tls_config:TlsConfig) -> Self {
407 self.tls_config = Some(tls_config);
408 self
409 }
410}
411
412#[derive(Debug, Clone)]
418pub struct MountainClient {
419 channel:Channel,
421
422 config:MountainClientConfig,
424}
425
426impl MountainClient {
427 pub async fn connect(config:MountainClientConfig) -> Result<Self, Box<dyn std::error::Error>> {
438 dev_log!("grpc", "Connecting to Mountain at {}", config.address);
439 let endpoint = Endpoint::from_shared(config.address.clone())?
440 .connect_timeout(Duration::from_secs(config.connection_timeout_secs));
441
442 #[cfg(feature = "mtls")]
444 if let Some(tls_config) = &config.tls_config {
445 dev_log!("grpc", "TLS configuration provided, configuring secure connection");
446 let _client_config = create_tls_client_config(tls_config).map_err(|e| {
447 dev_log!("grpc", "error: Failed to create TLS client configuration: {}", e);
448 format!("TLS configuration error: {}", e)
449 })?;
450
451 let domain_name = tls_config.server_name.clone().unwrap_or_else(|| "localhost".to_string());
453 dev_log!("grpc", "Setting server name for SNI: {}", domain_name);
454 let tls = tonic::transport::ClientTlsConfig::new().domain_name(domain_name.clone());
456 let channel = endpoint
457 .tcp_keepalive(Some(Duration::from_secs(60)))
458 .tls_config(tls)?
459 .connect()
460 .await
461 .map_err(|e| format!("Failed to connect with TLS: {}", e))?;
462
463 dev_log!("grpc", "Successfully connected to Mountain at {} with TLS", config.address);
464 return Ok(Self { channel, config });
465 }
466
467 dev_log!("grpc", "Using unencrypted connection");
469 let channel = endpoint.connect().await?;
470 dev_log!("grpc", "Successfully connected to Mountain at {}", config.address);
471 Ok(Self { channel, config })
472 }
473
474 pub fn channel(&self) -> &Channel { &self.channel }
479
480 pub fn config(&self) -> &MountainClientConfig { &self.config }
485
486 pub async fn health_check(&self) -> Result<bool, Box<dyn std::error::Error>> {
493 dev_log!("grpc", "Checking Mountain health");
494 match tokio::time::timeout(Duration::from_secs(self.config.request_timeout_secs), async {
496 Ok::<(), Box<dyn std::error::Error>>(())
499 })
500 .await
501 {
502 Ok(Ok(())) => {
503 dev_log!("grpc", "Mountain health check: healthy");
504 Ok(true)
505 },
506 Ok(Err(e)) => {
507 dev_log!("grpc", "warn: Mountain health check: disconnected - {}", e);
508 Ok(false)
509 },
510 Err(_) => {
511 dev_log!("grpc", "warn: Mountain health check: timeout");
512 Ok(false)
513 },
514 }
515 }
516
517 pub async fn get_status(&self) -> Result<String, Box<dyn std::error::Error>> {
525 dev_log!("grpc", "Getting Mountain status");
526 Ok("connected".to_string())
529 }
530
531 pub async fn get_config(&self, key:&str) -> Result<Option<String>, Box<dyn std::error::Error>> {
542 dev_log!("grpc", "Getting Mountain config: {}", key);
543 Ok(None)
546 }
547
548 pub async fn set_config(&self, key:&str, value:&str) -> Result<(), Box<dyn std::error::Error>> {
560 dev_log!("grpc", "Setting Mountain config: {} = {}", key, value);
561 Ok(())
564 }
565}
566
567pub async fn connect_to_mountain() -> Result<MountainClient, Box<dyn std::error::Error>> {
572 MountainClient::connect(MountainClientConfig::default()).await
573}
574
575pub async fn connect_to_mountain_at(address:impl Into<String>) -> Result<MountainClient, Box<dyn std::error::Error>> {
583 MountainClient::connect(MountainClientConfig::new(address)).await
584}
585
#[cfg(test)]
mod tests {
	use super::*;

	/// Serialises every test in this module that reads or writes process environment
	/// variables. The test harness runs tests on parallel threads, so unsynchronised
	/// `set_var`/`remove_var` calls race with each other and with `from_env`.
	static ENV_LOCK:std::sync::Mutex<()> = std::sync::Mutex::new(());

	/// Holds `ENV_LOCK`, applies environment overrides, and removes the touched variables on
	/// drop. Cleanup also runs when an assertion panics, so no test leaks state into another.
	struct EnvGuard {
		keys:Vec<&'static str>,
		_lock:std::sync::MutexGuard<'static, ()>,
	}

	impl EnvGuard {
		/// `Some(value)` sets the variable; `None` removes it.
		fn new(vars:&[(&'static str, Option<&str>)]) -> Self {
			// A panicking test only poisons the lock; the guarded data is `()`, so continue.
			let lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
			for &(key, value) in vars {
				// SAFETY: ENV_LOCK is held, so no other test in this module touches the
				// environment concurrently. Tests elsewhere in the crate are assumed not to
				// read or write these MOUNTAIN_* variables.
				unsafe {
					match value {
						Some(value) => env::set_var(key, value),
						None => env::remove_var(key),
					}
				}
			}
			Self { keys:vars.iter().map(|&(key, _)| key).collect(), _lock:lock }
		}
	}

	impl Drop for EnvGuard {
		fn drop(&mut self) {
			for key in &self.keys {
				// SAFETY: `_lock` is still held here; fields are dropped only after `drop` returns.
				unsafe { env::remove_var(key) };
			}
		}
	}

	#[test]
	fn test_default_config() {
		let config = MountainClientConfig::default();
		assert_eq!(config.address, DEFAULT_MOUNTAIN_ADDRESS);
		assert_eq!(config.connection_timeout_secs, DEFAULT_CONNECTION_TIMEOUT_SECS);
		assert_eq!(config.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
	}

	#[test]
	fn test_config_builder() {
		let config = MountainClientConfig::new("[::1]:50060")
			.with_connection_timeout(10)
			.with_request_timeout(60);

		assert_eq!(config.address, "[::1]:50060");
		assert_eq!(config.connection_timeout_secs, 10);
		assert_eq!(config.request_timeout_secs, 60);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_tls_config_server_auth() {
		let tls = TlsConfig::server_auth(std::path::PathBuf::from("/path/to/ca.pem"));
		assert_eq!(tls.server_name, Some("localhost".to_string()));
		assert!(tls.client_cert_path.is_none());
		assert!(tls.client_key_path.is_none());
		assert!(tls.ca_cert_path.is_some());
		assert!(tls.verify_certs);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_tls_config_mtls() {
		let tls = TlsConfig::mtls(
			std::path::PathBuf::from("/path/to/ca.pem"),
			std::path::PathBuf::from("/path/to/cert.pem"),
			std::path::PathBuf::from("/path/to/key.pem"),
		);
		assert!(tls.client_cert_path.is_some());
		assert!(tls.client_key_path.is_some());
		assert!(tls.ca_cert_path.is_some());
		assert!(tls.verify_certs);
		assert_eq!(tls.server_name, Some("localhost".to_string()));
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_tls_config_default() {
		let tls = TlsConfig::default();
		assert!(tls.ca_cert_path.is_none());
		assert!(tls.client_cert_path.is_none());
		assert!(tls.client_key_path.is_none());
		assert!(tls.server_name.is_none());
		assert!(tls.verify_certs);
	}

	#[test]
	fn test_from_env_default() {
		let _env = EnvGuard::new(&[
			("MOUNTAIN_ADDRESS", None),
			("MOUNTAIN_CONNECTION_TIMEOUT_SECS", None),
			("MOUNTAIN_REQUEST_TIMEOUT_SECS", None),
			("MOUNTAIN_TLS_ENABLED", None),
		]);

		let config = MountainClientConfig::from_env();
		assert_eq!(config.address, DEFAULT_MOUNTAIN_ADDRESS);
		assert_eq!(config.connection_timeout_secs, DEFAULT_CONNECTION_TIMEOUT_SECS);
		assert_eq!(config.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
	}

	#[test]
	fn test_from_env_custom() {
		let _env = EnvGuard::new(&[
			("MOUNTAIN_ADDRESS", Some("[::1]:50060")),
			("MOUNTAIN_CONNECTION_TIMEOUT_SECS", Some("10")),
			("MOUNTAIN_REQUEST_TIMEOUT_SECS", Some("60")),
		]);

		let config = MountainClientConfig::from_env();
		assert_eq!(config.address, "[::1]:50060");
		assert_eq!(config.connection_timeout_secs, 10);
		assert_eq!(config.request_timeout_secs, 60);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_from_env_tls() {
		let _env = EnvGuard::new(&[
			("MOUNTAIN_TLS_ENABLED", Some("1")),
			("MOUNTAIN_CA_CERT", Some("/path/to/ca.pem")),
			("MOUNTAIN_SERVER_NAME", Some("mymountain.com")),
			("MOUNTAIN_CLIENT_CERT", None),
			("MOUNTAIN_CLIENT_KEY", None),
			("MOUNTAIN_VERIFY_CERTS", None),
		]);

		let config = MountainClientConfig::from_env();
		assert!(config.tls_config.is_some());
		let tls = config.tls_config.unwrap();
		assert_eq!(tls.ca_cert_path, Some(std::path::PathBuf::from("/path/to/ca.pem")));
		assert_eq!(tls.server_name, Some("mymountain.com".to_string()));
		assert!(tls.client_cert_path.is_none());
		assert!(tls.client_key_path.is_none());
		assert!(tls.verify_certs);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_from_env_mtls() {
		let _env = EnvGuard::new(&[
			("MOUNTAIN_TLS_ENABLED", Some("true")),
			("MOUNTAIN_CA_CERT", Some("/path/to/ca.pem")),
			("MOUNTAIN_CLIENT_CERT", Some("/path/to/cert.pem")),
			("MOUNTAIN_CLIENT_KEY", Some("/path/to/key.pem")),
			("MOUNTAIN_SERVER_NAME", None),
			("MOUNTAIN_VERIFY_CERTS", None),
		]);

		let config = MountainClientConfig::from_env();
		assert!(config.tls_config.is_some());
		let tls = config.tls_config.unwrap();
		assert_eq!(tls.ca_cert_path, Some(std::path::PathBuf::from("/path/to/ca.pem")));
		assert_eq!(tls.client_cert_path, Some(std::path::PathBuf::from("/path/to/cert.pem")));
		assert_eq!(tls.client_key_path, Some(std::path::PathBuf::from("/path/to/key.pem")));
		assert!(tls.verify_certs);
	}
}