1use std::{collections::HashMap, path::PathBuf, sync::Arc};
7
8use async_trait::async_trait;
9use base64::Engine;
10use bytes::Bytes;
11use serde::{Deserialize, Serialize};
12use tokio::sync::RwLock;
13use crate::dev_log;
14
15use crate::{
16 Transport::{
17 Strategy::{TransportStats, TransportStrategy, TransportType},
18 TransportConfig,
19 },
20 WASM::{
21 HostBridge::HostBridgeImpl,
22 MemoryManager::{MemoryLimits, MemoryManagerImpl},
23 Runtime::{WASMConfig, WASMRuntime},
24 WASMStats,
25 },
26};
27
28#[derive(Clone, Debug)]
30pub struct WASMTransportImpl {
31 runtime:Arc<WASMRuntime>,
33 memory_manager:Arc<RwLock<MemoryManagerImpl>>,
35 bridge:Arc<HostBridgeImpl>,
37 modules:Arc<RwLock<HashMap<String, WASMModuleInfo>>>,
39 #[allow(dead_code)]
41 config:TransportConfig,
42 connected:Arc<RwLock<bool>>,
44 stats:Arc<RwLock<TransportStats>>,
46}
47
48#[derive(Debug, Clone)]
50pub struct WASMModuleInfo {
51 pub id:String,
53 pub name:Option<String>,
55 pub path:Option<PathBuf>,
57 pub loaded_at:u64,
59 pub function_stats:HashMap<String, FunctionCallStats>,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct FunctionCallStats {
66 pub call_count:u64,
68 pub total_time_us:u64,
70 pub last_call_at:Option<u64>,
72 pub error_count:u64,
74}
75
76impl FunctionCallStats {
77 pub fn record_call(&mut self, time_us:u64) {
79 self.call_count += 1;
80 self.total_time_us += time_us;
81 self.last_call_at = Some(
82 std::time::SystemTime::now()
83 .duration_since(std::time::UNIX_EPOCH)
84 .map(|d| d.as_secs())
85 .unwrap_or(0),
86 );
87 }
88
89 pub fn record_error(&mut self) { self.error_count += 1; }
91}
92
93impl Default for FunctionCallStats {
94 fn default() -> Self { Self { call_count:0, total_time_us:0, last_call_at:None, error_count:0 } }
95}
96
97impl WASMTransportImpl {
98 pub fn new(enable_wasi:bool, memory_limit_mb:u64, max_execution_time_ms:u64) -> anyhow::Result<Self> {
100 let config = WASMConfig::new(memory_limit_mb, max_execution_time_ms, enable_wasi);
101
102 let runtime_result = tokio::runtime::Runtime::new()
105 .map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
106 .block_on(WASMRuntime::new(config.clone()))
107 .map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
108 let runtime = Arc::new(runtime_result);
109
110 let memory_limits = MemoryLimits::new(memory_limit_mb, (memory_limit_mb as f64 * 0.75) as u64, 100);
111 let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
112 let bridge = Arc::new(HostBridgeImpl::new());
113
114 Ok(Self {
115 runtime,
116 memory_manager,
117 bridge,
118 modules:Arc::new(RwLock::new(HashMap::new())),
119 config:TransportConfig::default(),
120 connected:Arc::new(RwLock::new(true)), stats:Arc::new(RwLock::new(TransportStats::default())),
122 })
123 }
124
125 pub fn with_config(wasm_config:WASMConfig, transport_config:TransportConfig) -> anyhow::Result<Self> {
127 let runtime_result = tokio::runtime::Runtime::new()
128 .map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
129 .block_on(WASMRuntime::new(wasm_config.clone()))
130 .map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
131 let runtime = Arc::new(runtime_result);
132
133 let memory_limits = MemoryLimits::new(
134 wasm_config.memory_limit_mb,
135 (wasm_config.memory_limit_mb as f64 * 0.75) as u64,
136 100,
137 );
138 let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
139 let bridge = Arc::new(HostBridgeImpl::new());
140
141 Ok(Self {
142 runtime,
143 memory_manager,
144 bridge,
145 modules:Arc::new(RwLock::new(HashMap::new())),
146 config:transport_config,
147 connected:Arc::new(RwLock::new(true)),
148 stats:Arc::new(RwLock::new(TransportStats::default())),
149 })
150 }
151
152 pub fn runtime(&self) -> &Arc<WASMRuntime> { &self.runtime }
154
155 pub fn memory_manager(&self) -> &Arc<RwLock<MemoryManagerImpl>> { &self.memory_manager }
157
158 pub fn bridge(&self) -> &Arc<HostBridgeImpl> { &self.bridge }
160
161 pub async fn get_modules(&self) -> HashMap<String, WASMModuleInfo> { self.modules.read().await.clone() }
163
164 pub async fn get_wasm_stats(&self) -> WASMStats {
166 let memory_manager = self.memory_manager.read().await;
167 let managers = self.modules.read().await;
168
169 WASMStats {
170 modules_loaded:managers.len(),
171 active_instances:managers.len(), total_memory_mb:memory_manager.current_usage_mb() as u64,
173 total_execution_time_ms:0, function_calls:self.stats.read().await.messages_sent,
175 }
176 }
177
178 pub async fn call_wasm_function(
180 &self,
181 module_id:&str,
182 function_name:&str,
183 args:Vec<Bytes>,
184 ) -> anyhow::Result<Bytes> {
185 let start = std::time::Instant::now();
186
187 dev_log!("wasm", "Calling WASM function: {}::{} with {} arguments", module_id, function_name, args.len());
188
189 let modules = self.modules.read().await;
190 let _module = modules
191 .get(module_id)
192 .ok_or_else(|| anyhow::anyhow!("Module not found: {}", module_id))?;
193
194 let response = Bytes::new();
197
198 let mut modules_mut = self.modules.write().await;
200 if let Some(module) = modules_mut.get_mut(module_id) {
201 let stats = module.function_stats.entry(function_name.to_string()).or_default();
202 stats.record_call(start.elapsed().as_micros() as u64);
203 }
204
205 drop(modules_mut);
206
207 let mut stats = self.stats.write().await;
209 stats.record_sent(args.iter().map(|b| b.len() as u64).sum(), start.elapsed().as_micros() as u64);
210 stats.record_received(response.len() as u64);
211
212 Ok(response)
213 }
214}
215
216#[async_trait]
217impl TransportStrategy for WASMTransportImpl {
218 type Error = WASMTransportError;
219
220 async fn connect(&self) -> Result<(), Self::Error> {
221 dev_log!("transport", "WASM transport connecting");
222
223 *self.connected.write().await = true;
225
226 dev_log!("transport", "WASM transport connected");
227
228 Ok(())
229 }
230
231 async fn send(&self, request:&[u8]) -> Result<Vec<u8>, Self::Error> {
232 let start = std::time::Instant::now();
233
234 if !self.is_connected() {
235 return Err(WASMTransportError::NotConnected);
236 }
237
238 dev_log!("transport", "Sending WASM transport request ({} bytes)", request.len());
239
240 let request_str =
243 std::str::from_utf8(request).map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?;
244
245 let parts:Vec<&str> = request_str.splitn(3, ':').collect();
246 if parts.len() < 3 {
247 return Err(WASMTransportError::InvalidRequest("Invalid request format".to_string()));
248 }
249
250 let module_id = parts[0];
251 let function_name = parts[1];
252 let args_base64 = parts[2];
253
254 use base64::engine::general_purpose::STANDARD;
256 let args = vec![Bytes::from(
257 STANDARD
258 .decode(args_base64)
259 .map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?,
260 )];
261
262 let response = self
264 .call_wasm_function(module_id, function_name, args)
265 .await
266 .map_err(|e| WASMTransportError::FunctionCallFailed(e.to_string()))?;
267
268 let response_vec = response.to_vec();
270
271 let latency_us = start.elapsed().as_micros() as u64;
272
273 dev_log!("transport", "WASM transport request completed in {}µs", latency_us);
274
275 Ok(response_vec)
276 }
277
278 async fn send_no_response(&self, data:&[u8]) -> Result<(), Self::Error> {
279 if !self.is_connected() {
280 return Err(WASMTransportError::NotConnected);
281 }
282
283 dev_log!("transport", "Sending WASM transport request without response ({} bytes)", data.len());
284
285 self.send(data).await?;
287 Ok(())
288 }
289
290 async fn close(&self) -> Result<(), Self::Error> {
291 dev_log!("transport", "Closing WASM transport");
292
293 *self.connected.write().await = false;
294
295 dev_log!("transport", "WASM transport closed");
296
297 Ok(())
298 }
299
300 fn is_connected(&self) -> bool { self.connected.blocking_read().to_owned() }
301
302 fn transport_type(&self) -> TransportType { TransportType::WASM }
303}
304
305#[derive(Debug, thiserror::Error)]
307pub enum WASMTransportError {
308 #[error("Module not found: {0}")]
310 ModuleNotFound(String),
311
312 #[error("Function not found: {0}")]
314 FunctionNotFound(String),
315
316 #[error("Function call failed: {0}")]
318 FunctionCallFailed(String),
319
320 #[error("Memory error: {0}")]
322 MemoryError(String),
323
324 #[error("Runtime error: {0}")]
326 RuntimeError(String),
327
328 #[error("Invalid request: {0}")]
330 InvalidRequest(String),
331
332 #[error("Not connected")]
334 NotConnected,
335
336 #[error("Compilation failed: {0}")]
338 CompilationFailed(String),
339
340 #[error("Timeout")]
342 Timeout,
343}
344
#[cfg(test)]
mod tests {
	use super::*;
	use crate::Transport::Strategy::TransportStrategy;

	#[test]
	fn test_wasm_transport_creation() {
		let transport = WASMTransportImpl::new(true, 512, 30000).expect("transport should be constructible");
		assert!(transport.is_connected());
	}

	#[test]
	fn test_function_call_stats() {
		let mut stats = FunctionCallStats::default();
		stats.record_call(100);
		assert_eq!(stats.call_count, 1);
		assert_eq!(stats.total_time_us, 100);
		assert!(stats.last_call_at.is_some());
	}

	#[test]
	fn test_function_call_stats_records_errors() {
		let mut stats = FunctionCallStats::default();
		stats.record_error();
		stats.record_error();
		assert_eq!(stats.error_count, 2);
		// Errors must not be counted as successful calls.
		assert_eq!(stats.call_count, 0);
		assert!(stats.last_call_at.is_none());
	}

	#[tokio::test]
	async fn test_wasm_transport_not_connected_after_close() {
		let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();
		// Propagate a close failure instead of silently discarding it.
		transport.close().await.expect("close should succeed");
		assert!(!transport.is_connected());
	}

	#[tokio::test]
	async fn test_send_after_close_is_rejected() {
		let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();
		transport.close().await.expect("close should succeed");
		let result = transport.send(b"module:function:").await;
		assert!(matches!(result, Err(WASMTransportError::NotConnected)));
	}

	#[tokio::test]
	async fn test_send_rejects_malformed_request() {
		let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();
		let result = transport.send(b"no-separators").await;
		assert!(matches!(result, Err(WASMTransportError::InvalidRequest(_))));
	}

	#[tokio::test]
	async fn test_call_unknown_module_fails() {
		let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();
		let err = transport
			.call_wasm_function("missing", "f", Vec::new())
			.await
			.expect_err("unknown module must be rejected");
		assert!(err.to_string().contains("Module not found"));
	}

	#[tokio::test]
	async fn test_get_wasm_stats() {
		let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();
		let stats = transport.get_wasm_stats().await;
		assert_eq!(stats.modules_loaded, 0);
		assert_eq!(stats.active_instances, 0);
	}
}