1use std::{collections::HashMap, sync::Arc};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11use crate::dev_log;
12use wasmtime::{Caller, Linker};
13
14use crate::WASM::HostBridge::{FunctionSignature, HostBridgeImpl, HostBridgeImpl as HostBridge, HostFunctionCallback, ParamType, ReturnType};
15
/// Shared store of the host functions that have been registered for export.
pub struct HostFunctionRegistry {
	/// Registered functions keyed by the name they were registered under,
	/// guarded by an async read/write lock.
	functions:Arc<RwLock<HashMap<String, RegisteredHostFunction>>>,
	/// Host bridge handle supplied by the caller at construction time.
	// NOTE(review): stored but never read by the code visible here, which is
	// why the dead-code lint is suppressed — confirm whether it is still needed.
	#[allow(dead_code)]
	bridge:Arc<HostBridge>,
}
24
/// One entry of the registry: a host callback plus its bookkeeping data.
#[derive(Debug, Clone)]
struct RegisteredHostFunction {
	/// Name the function was registered under (same as its registry key).
	#[allow(dead_code)]
	name:String,
	/// Declared signature. Stored on registration; the export code visible
	/// here does not consult it.
	#[allow(dead_code)]
	signature:FunctionSignature,
	/// Host callback to invoke; export fails when this is `None`.
	callback:Option<HostFunctionCallback>,
	/// Registration time in whole seconds since the Unix epoch.
	#[allow(dead_code)]
	registered_at:u64,
	/// Per-function statistics, initialised to `FunctionStats::default()`.
	stats:FunctionStats,
}
42
/// Call statistics kept for each registered host function.
///
/// NOTE(review): the code visible in this file only ever creates these with
/// `Default` and clones them out; nothing here updates the counters — confirm
/// where (or whether) they are recorded.
#[derive(Debug, Clone, Default)]
pub struct FunctionStats {
	/// Number of invocations.
	pub call_count:u64,
	/// Accumulated execution time; presumably nanoseconds, going by the name —
	/// TODO confirm.
	pub total_execution_ns:u64,
	/// Time of the most recent call, `None` if never called. Unit and epoch
	/// are not established by this file — TODO confirm.
	pub last_call_at:Option<u64>,
	/// Number of failed invocations.
	pub error_count:u64,
}
55
/// Settings that control how host functions are registered and exported.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportConfig {
	/// NOTE(review): not read by the code visible in this file — confirm what
	/// consumes it.
	pub auto_export:bool,
	/// NOTE(review): not read by the code visible in this file — confirm what
	/// consumes it.
	pub enable_stats:bool,
	/// Upper bound on the number of registered functions; registration fails
	/// once the registry holds this many entries.
	pub max_functions:usize,
	/// When set, prepended to a function's registered name to form the name it
	/// is exported under.
	pub name_prefix:Option<String>,
}
68
69impl Default for ExportConfig {
70 fn default() -> Self {
71 Self {
72 auto_export:true,
73 enable_stats:true,
74 max_functions:1000,
75 name_prefix:Some("host_".to_string()),
76 }
77 }
78}
79
/// Registers host callbacks and exports them to a wasmtime `Linker`.
pub struct FunctionExportImpl {
	/// Shared registry holding every registered host function.
	registry:Arc<HostFunctionRegistry>,
	/// Limits and naming rules applied during registration and export.
	config:ExportConfig,
}
85
86impl FunctionExportImpl {
87 pub fn new(bridge:Arc<HostBridge>) -> Self {
89 Self {
90 registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
91 config:ExportConfig::default(),
92 }
93 }
94
95 pub fn with_config(bridge:Arc<HostBridge>, config:ExportConfig) -> Self {
97 Self {
98 registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
99 config,
100 }
101 }
102
103 pub async fn register_function(
105 &self,
106 name:&str,
107 signature:FunctionSignature,
108 callback:HostFunctionCallback,
109 ) -> Result<()> {
110 dev_log!("wasm", "Registering host function for export: {}", name);
111
112 let functions = self.registry.functions.read().await;
113
114 if functions.len() >= self.config.max_functions {
116 return Err(anyhow::anyhow!(
117 "Maximum number of exported functions reached: {}",
118 self.config.max_functions
119 ));
120 }
121
122 drop(functions);
123
124 let mut functions = self.registry.functions.write().await;
125
126 let registered_at = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
127
128 functions.insert(
129 name.to_string(),
130 RegisteredHostFunction {
131 name:name.to_string(),
132 signature,
133 callback:Some(callback),
134 registered_at,
135 stats:FunctionStats::default(),
136 },
137 );
138
139 dev_log!("wasm", "Host function registered for WASM export: {}", name);
140 Ok(())
141 }
142
143 pub async fn register_functions(
145 &self,
146 signatures:Vec<FunctionSignature>,
147 callbacks:Vec<HostFunctionCallback>,
148 ) -> Result<()> {
149 if signatures.len() != callbacks.len() {
150 return Err(anyhow::anyhow!("Number of signatures must match number of callbacks"));
151 }
152
153 for (sig, callback) in signatures.into_iter().zip(callbacks) {
154 let name = sig.name.clone();
155 self.register_function(&name, sig, callback).await?;
156 }
157
158 Ok(())
159 }
160
161 pub async fn export_to_linker<T>(&self, linker:&mut Linker<T>) -> Result<()>
163 where
164 T: Send + 'static, {
165 dev_log!(
166 "wasm",
167 "Exporting {} host functions to linker",
168 self.registry.functions.read().await.len()
169 );
170
171 let functions = self.registry.functions.read().await;
172
173 for (name, func) in functions.iter() {
174 self.export_single_function(linker, name, func)?;
175 }
176
177 dev_log!("wasm", "All host functions exported to linker");
178 Ok(())
179 }
180
181 fn export_single_function<T>(&self, linker:&mut Linker<T>, name:&str, func:&RegisteredHostFunction) -> Result<()>
183 where
184 T: Send + 'static, {
185 dev_log!("wasm", "Exporting function: {}", name);
186
187 let callback = func
188 .callback
189 .ok_or_else(|| anyhow::anyhow!("No callback available for function: {}", name))?;
190
191 let func_name = if let Some(prefix) = &self.config.name_prefix {
192 format!("{}{}", prefix, name)
193 } else {
194 name.to_string()
195 };
196
197 let func_name_for_debug = func_name.clone();
198 let func_name_inner = func_name.clone();
199
200 let _wrapped_callback =
202 move |_caller:Caller<'_, T>, args:&[wasmtime::Val]| -> Result<Vec<wasmtime::Val>, wasmtime::Trap> {
203 let _start = std::time::Instant::now();
204
205 let args_bytes:Result<Vec<bytes::Bytes>, _> = args
207 .iter()
208 .map(|arg| {
209 match arg {
210 wasmtime::Val::I32(i) => {
211 serde_json::to_vec(i)
212 .map(bytes::Bytes::from)
213 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
214 },
215 wasmtime::Val::I64(i) => {
216 serde_json::to_vec(i)
217 .map(bytes::Bytes::from)
218 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
219 },
220 wasmtime::Val::F32(f) => {
221 serde_json::to_vec(f)
222 .map(bytes::Bytes::from)
223 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
224 },
225 wasmtime::Val::F64(f) => {
226 serde_json::to_vec(f)
227 .map(bytes::Bytes::from)
228 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
229 },
230 _ => Err(anyhow::anyhow!("Unsupported argument type")),
231 }
232 })
233 .collect();
234
235 let args_bytes = args_bytes.map_err(|_| {
236 dev_log!("wasm", "warn: error converting arguments for function '{}'", func_name_inner);
237 wasmtime::Trap::StackOverflow
238 })?;
239
240 let result = callback(args_bytes);
242
243 match result {
244 Ok(response_bytes) => {
245 let result_val:serde_json::Value = serde_json::from_slice(&response_bytes).map_err(|_| {
247 dev_log!("wasm", "warn: error deserializing response for function '{}'", func_name_inner);
248 wasmtime::Trap::StackOverflow
249 })?;
250
251 let ret_val = match result_val {
252 serde_json::Value::Number(n) => {
253 if let Some(i) = n.as_i64() {
254 wasmtime::Val::I32(i as i32)
255 } else if let Some(f) = n.as_f64() {
256 wasmtime::Val::I64(f as i64)
257 } else {
258 dev_log!("wasm", "warn: invalid number format for function '{}'", func_name_inner);
259 return Err(wasmtime::Trap::StackOverflow);
260 }
261 },
262 _ => {
263 dev_log!("wasm", "warn: unsupported response type for function '{}'", func_name_inner);
264 return Err(wasmtime::Trap::StackOverflow);
265 },
266 };
267
268 Ok(vec![ret_val])
269 },
270 Err(e) => {
271 dev_log!("wasm", "host function '{}' returned error: {}", func_name_inner, e);
273 Err(wasmtime::Trap::StackOverflow)
274 },
275 }
276 };
277
278 let _wasmparser_signature = wasmparser::FuncType::new([wasmparser::ValType::I32], [wasmparser::ValType::I32]);
280
281 let func_name_for_logging = func_name.clone();
285 linker
286 .func_wrap(
287 "_host", &func_name,
289 move |_caller:wasmtime::Caller<'_, T>, input_param:i32| -> i32 {
290 let start = std::time::Instant::now();
292
293 let args_bytes = match serde_json::to_vec(&input_param).map(bytes::Bytes::from) {
295 Ok(b) => b,
296 Err(e) => {
297 dev_log!("wasm", "warn: serialization error for function '{}': {}", func_name_for_logging, e);
298 return -1i32;
299 },
300 };
301
302 let result = callback(vec![args_bytes]);
304
305 match result {
306 Ok(response_bytes) => {
307 let result_val:serde_json::Value = match serde_json::from_slice(&response_bytes) {
309 Ok(v) => v,
310 Err(_) => {
311 dev_log!("wasm", "warn: error deserializing response for function '{}'", func_name_for_logging);
312 return -1i32;
313 },
314 };
315
316 let ret_val = match result_val {
318 serde_json::Value::Number(n) => {
319 if let Some(i) = n.as_i64() {
320 i as i32
321 } else if let Some(f) = n.as_f64() {
322 f as i32
323 } else {
324 dev_log!("wasm", "warn: invalid number format for function '{}'", func_name_for_logging);
325 -1i32
326 }
327 },
328 serde_json::Value::Bool(b) => {
329 if b {
330 1i32
331 } else {
332 0i32
333 }
334 },
335 _ => {
336 dev_log!("wasm", "warn: unsupported response type for function '{}', expected number or bool", func_name_for_logging);
337 -1i32
338 },
339 };
340
341 let duration = start.elapsed();
343 dev_log!("wasm", "[FunctionExport] Host function '{}' executed successfully in {}µs", func_name_for_logging, duration.as_micros());
344
345 ret_val
346 },
347 Err(e) => {
348 dev_log!("wasm", "[FunctionExport] Host function '{}' returned error: {}", func_name_for_logging, e);
350 -1i32
352 },
353 }
354 },
355 )
356 .map_err(|e| {
357 dev_log!("wasm", "warn: [FunctionExport] failed to wrap host function '{}': {}", func_name_for_debug, e);
358 e
359 })?;
360
361 dev_log!("wasm", "[FunctionExport] Host function '{}' registered successfully", func_name_for_debug);
362
363 Ok(())
364 }
365
366 #[allow(dead_code)]
368 fn wasmtime_signature_from_signature(&self, _sig:&FunctionSignature) -> Result<wasmparser::FuncType> {
369 Ok(wasmparser::FuncType::new([], []))
372 }
373
374 pub async fn get_function_names(&self) -> Vec<String> {
376 self.registry.functions.read().await.keys().cloned().collect()
377 }
378
379 pub async fn get_function_stats(&self, name:&str) -> Option<FunctionStats> {
381 self.registry.functions.read().await.get(name).map(|f| f.stats.clone())
382 }
383
384 pub async fn unregister_function(&self, name:&str) -> Result<bool> {
386 let mut functions = self.registry.functions.write().await;
387 let removed = functions.remove(name).is_some();
388
389 if removed {
390 dev_log!("wasm", "Unregistered host function: {}", name);
391 } else {
392 dev_log!("wasm", "warn: attempted to unregister non-existent function: {}", name);
393 }
394
395 Ok(removed)
396 }
397
398 pub async fn clear(&self) {
400 dev_log!("wasm", "Clearing all registered host functions");
401 self.registry.functions.write().await.clear();
402 }
403}
404
#[cfg(test)]
mod tests {
	use super::*;

	/// Builds a synchronous `(i32) -> i32` signature called `name`.
	fn i32_signature(name:&str) -> FunctionSignature {
		FunctionSignature {
			name:name.to_string(),
			param_types:vec![ParamType::I32],
			return_type:Some(ReturnType::I32),
			is_async:false,
		}
	}

	/// A freshly created exporter has no registered functions.
	#[tokio::test]
	async fn test_function_export_creation() {
		let bridge = Arc::new(HostBridgeImpl::new());
		let export = FunctionExportImpl::new(bridge);

		assert!(export.get_function_names().await.is_empty());
	}

	/// Registering a function makes it visible through `get_function_names`.
	#[tokio::test]
	async fn test_register_function() {
		let bridge = Arc::new(HostBridgeImpl::new());
		let export = FunctionExportImpl::new(bridge);

		let callback = |args:Vec<bytes::Bytes>| Ok(args.first().cloned().unwrap_or_default());

		let result:anyhow::Result<()> = export.register_function("echo", i32_signature("echo"), callback).await;
		assert!(result.is_ok());
		assert_eq!(export.get_function_names().await, vec!["echo".to_string()]);
	}

	/// Unregistering removes the entry and reports that something was removed.
	#[tokio::test]
	async fn test_unregister_function() {
		let bridge = Arc::new(HostBridgeImpl::new());
		let export = FunctionExportImpl::new(bridge);

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());

		// Check the setup step too; otherwise a failed registration would only
		// surface as a confusing failure of the unregister assertion below.
		let registered:anyhow::Result<()> = export.register_function("test", i32_signature("test"), callback).await;
		assert!(registered.is_ok());

		let result:bool = export.unregister_function("test").await.unwrap();
		assert!(result);
		assert!(export.get_function_names().await.is_empty());
	}

	/// The default configuration enables auto export and caps the registry at
	/// 1000 functions.
	#[test]
	fn test_export_config_default() {
		let config = ExportConfig::default();
		assert!(config.auto_export);
		assert_eq!(config.max_functions, 1000);
	}

	/// Default statistics start with zeroed counters.
	#[test]
	fn test_function_stats_default() {
		let stats = FunctionStats::default();
		assert_eq!(stats.call_count, 0);
		assert_eq!(stats.error_count, 0);
	}
}