Skip to main content

Grove/WASM/
FunctionExport.rs

1//! Function Export Module
2//!
3//! Handles exporting host functions to WASM modules.
4//! Provides registration and management of functions that WASM can call.
5
6use std::{collections::HashMap, sync::Arc};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11use crate::dev_log;
12use wasmtime::{Caller, Linker};
13
14use crate::WASM::HostBridge::{FunctionSignature, HostBridgeImpl, HostBridgeImpl as HostBridge, HostFunctionCallback, ParamType, ReturnType};
15
16/// Host function registry for WASM exports
17pub struct HostFunctionRegistry {
18	/// Registered host functions
19	functions:Arc<RwLock<HashMap<String, RegisteredHostFunction>>>,
20	/// Associated host bridge
21	#[allow(dead_code)]
22	bridge:Arc<HostBridge>,
23}
24
25/// Registered host function with metadata
26#[derive(Debug, Clone)]
27struct RegisteredHostFunction {
28	/// Function name
29	#[allow(dead_code)]
30	name:String,
31	/// Function signature
32	#[allow(dead_code)]
33	signature:FunctionSignature,
34	/// Synchronous callback
35	callback:Option<HostFunctionCallback>,
36	/// Registration timestamp
37	#[allow(dead_code)]
38	registered_at:u64,
39	/// Call statistics
40	stats:FunctionStats,
41}
42
/// Per-function call statistics.
///
/// Derives `PartialEq`/`Eq` so snapshots returned by the registry can be
/// compared directly (e.g. in tests or change detection).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FunctionStats {
	/// Number of times the function was called.
	pub call_count:u64,
	/// Accumulated execution time across all calls, in nanoseconds.
	pub total_execution_ns:u64,
	/// Timestamp of the most recent call, or `None` if it was never called.
	// NOTE(review): presumably seconds since the Unix epoch, like
	// `registered_at`; confirm against the code that fills it in.
	pub last_call_at:Option<u64>,
	/// Number of calls that ended in an error.
	pub error_count:u64,
}
55
56/// Export configuration for WASM functions
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct ExportConfig {
59	/// Enable function export by default
60	pub auto_export:bool,
61	/// Enable timing statistics
62	pub enable_stats:bool,
63	/// Maximum number of functions that can be exported
64	pub max_functions:usize,
65	/// Function name prefix for exports
66	pub name_prefix:Option<String>,
67}
68
69impl Default for ExportConfig {
70	fn default() -> Self {
71		Self {
72			auto_export:true,
73			enable_stats:true,
74			max_functions:1000,
75			name_prefix:Some("host_".to_string()),
76		}
77	}
78}
79
80/// Function export for WASM
81pub struct FunctionExportImpl {
82	registry:Arc<HostFunctionRegistry>,
83	config:ExportConfig,
84}
85
86impl FunctionExportImpl {
87	/// Create a new function export manager
88	pub fn new(bridge:Arc<HostBridge>) -> Self {
89		Self {
90			registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
91			config:ExportConfig::default(),
92		}
93	}
94
95	/// Create with custom configuration
96	pub fn with_config(bridge:Arc<HostBridge>, config:ExportConfig) -> Self {
97		Self {
98			registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
99			config,
100		}
101	}
102
103	/// Register a host function for export to WASM
104	pub async fn register_function(
105		&self,
106		name:&str,
107		signature:FunctionSignature,
108		callback:HostFunctionCallback,
109	) -> Result<()> {
110		dev_log!("wasm", "Registering host function for export: {}", name);
111
112		let functions = self.registry.functions.read().await;
113
114		// Check max function limit
115		if functions.len() >= self.config.max_functions {
116			return Err(anyhow::anyhow!(
117				"Maximum number of exported functions reached: {}",
118				self.config.max_functions
119			));
120		}
121
122		drop(functions);
123
124		let mut functions = self.registry.functions.write().await;
125
126		let registered_at = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
127
128		functions.insert(
129			name.to_string(),
130			RegisteredHostFunction {
131				name:name.to_string(),
132				signature,
133				callback:Some(callback),
134				registered_at,
135				stats:FunctionStats::default(),
136			},
137		);
138
139		dev_log!("wasm", "Host function registered for WASM export: {}", name);
140		Ok(())
141	}
142
143	/// Register multiple host functions
144	pub async fn register_functions(
145		&self,
146		signatures:Vec<FunctionSignature>,
147		callbacks:Vec<HostFunctionCallback>,
148	) -> Result<()> {
149		if signatures.len() != callbacks.len() {
150			return Err(anyhow::anyhow!("Number of signatures must match number of callbacks"));
151		}
152
153		for (sig, callback) in signatures.into_iter().zip(callbacks) {
154			let name = sig.name.clone();
155			self.register_function(&name, sig, callback).await?;
156		}
157
158		Ok(())
159	}
160
161	/// Export all registered functions to a WASMtime linker
162	pub async fn export_to_linker<T>(&self, linker:&mut Linker<T>) -> Result<()>
163	where
164		T: Send + 'static, {
165		dev_log!(
166			"wasm",
167			"Exporting {} host functions to linker",
168			self.registry.functions.read().await.len()
169		);
170
171		let functions = self.registry.functions.read().await;
172
173		for (name, func) in functions.iter() {
174			self.export_single_function(linker, name, func)?;
175		}
176
177		dev_log!("wasm", "All host functions exported to linker");
178		Ok(())
179	}
180
181	/// Export a single function to the linker
182	fn export_single_function<T>(&self, linker:&mut Linker<T>, name:&str, func:&RegisteredHostFunction) -> Result<()>
183	where
184		T: Send + 'static, {
185		dev_log!("wasm", "Exporting function: {}", name);
186
187		let callback = func
188			.callback
189			.ok_or_else(|| anyhow::anyhow!("No callback available for function: {}", name))?;
190
191		let func_name = if let Some(prefix) = &self.config.name_prefix {
192			format!("{}{}", prefix, name)
193		} else {
194			name.to_string()
195		};
196
197		let func_name_for_debug = func_name.clone();
198		let func_name_inner = func_name.clone();
199
200		// Create a wrapper function that handles stats and error handling
201		let _wrapped_callback =
202			move |_caller:Caller<'_, T>, args:&[wasmtime::Val]| -> Result<Vec<wasmtime::Val>, wasmtime::Trap> {
203				let _start = std::time::Instant::now();
204
205				// Convert args to bytes
206				let args_bytes:Result<Vec<bytes::Bytes>, _> = args
207					.iter()
208					.map(|arg| {
209						match arg {
210							wasmtime::Val::I32(i) => {
211								serde_json::to_vec(i)
212									.map(bytes::Bytes::from)
213									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
214							},
215							wasmtime::Val::I64(i) => {
216								serde_json::to_vec(i)
217									.map(bytes::Bytes::from)
218									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
219							},
220							wasmtime::Val::F32(f) => {
221								serde_json::to_vec(f)
222									.map(bytes::Bytes::from)
223									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
224							},
225							wasmtime::Val::F64(f) => {
226								serde_json::to_vec(f)
227									.map(bytes::Bytes::from)
228									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
229							},
230							_ => Err(anyhow::anyhow!("Unsupported argument type")),
231						}
232					})
233					.collect();
234
235				let args_bytes = args_bytes.map_err(|_| {
236					dev_log!("wasm", "warn: error converting arguments for function '{}'", func_name_inner);
237					wasmtime::Trap::StackOverflow
238				})?;
239
240				// Call the callback
241				let result = callback(args_bytes);
242
243				match result {
244					Ok(response_bytes) => {
245						// Deserialize response
246						let result_val:serde_json::Value = serde_json::from_slice(&response_bytes).map_err(|_| {
247							dev_log!("wasm", "warn: error deserializing response for function '{}'", func_name_inner);
248							wasmtime::Trap::StackOverflow
249						})?;
250
251						let ret_val = match result_val {
252							serde_json::Value::Number(n) => {
253								if let Some(i) = n.as_i64() {
254									wasmtime::Val::I32(i as i32)
255								} else if let Some(f) = n.as_f64() {
256									wasmtime::Val::I64(f as i64)
257								} else {
258									dev_log!("wasm", "warn: invalid number format for function '{}'", func_name_inner);
259									return Err(wasmtime::Trap::StackOverflow);
260								}
261							},
262							_ => {
263								dev_log!("wasm", "warn: unsupported response type for function '{}'", func_name_inner);
264								return Err(wasmtime::Trap::StackOverflow);
265							},
266						};
267
268						Ok(vec![ret_val])
269					},
270					Err(e) => {
271						// Error handling
272						dev_log!("wasm", "host function '{}' returned error: {}", func_name_inner, e);
273						Err(wasmtime::Trap::StackOverflow)
274					},
275				}
276			};
277
278		// Define the function signature for WASMtime
279		let _wasmparser_signature = wasmparser::FuncType::new([wasmparser::ValType::I32], [wasmparser::ValType::I32]);
280
281		// Register host function with the linker using simple i32->i32 signature
282		// In Wasmtime 20, func_wrap expects parameters to be inferred from the closure
283		// signature
284		let func_name_for_logging = func_name.clone();
285		linker
286			.func_wrap(
287				"_host", // Module name for host functions
288				&func_name,
289				move |_caller:wasmtime::Caller<'_, T>, input_param:i32| -> i32 {
290					// Track function call for metrics
291					let start = std::time::Instant::now();
292
293					// Convert input parameter to bytes for callback
294					let args_bytes = match serde_json::to_vec(&input_param).map(bytes::Bytes::from) {
295						Ok(b) => b,
296						Err(e) => {
297							dev_log!("wasm", "warn: serialization error for function '{}': {}", func_name_for_logging, e);
298							return -1i32;
299						},
300					};
301
302					// Call the registered callback
303					let result = callback(vec![args_bytes]);
304
305					match result {
306						Ok(response_bytes) => {
307							// Deserialize response
308							let result_val:serde_json::Value = match serde_json::from_slice(&response_bytes) {
309								Ok(v) => v,
310								Err(_) => {
311									dev_log!("wasm", "warn: error deserializing response for function '{}'", func_name_for_logging);
312									return -1i32;
313								},
314							};
315
316							// Extract result value
317							let ret_val = match result_val {
318								serde_json::Value::Number(n) => {
319									if let Some(i) = n.as_i64() {
320										i as i32
321									} else if let Some(f) = n.as_f64() {
322										f as i32
323									} else {
324										dev_log!("wasm", "warn: invalid number format for function '{}'", func_name_for_logging);
325										-1i32
326									}
327								},
328								serde_json::Value::Bool(b) => {
329									if b {
330										1i32
331									} else {
332										0i32
333									}
334								},
335								_ => {
336									dev_log!("wasm", "warn: unsupported response type for function '{}', expected number or bool", func_name_for_logging);
337									-1i32
338								},
339							};
340
341							// Log successful call
342							let duration = start.elapsed();
343							dev_log!("wasm", "[FunctionExport] Host function '{}' executed successfully in {}µs", func_name_for_logging, duration.as_micros());
344
345							ret_val
346						},
347						Err(e) => {
348							// Error handling - return error code to WASM caller
349							dev_log!("wasm", "[FunctionExport] Host function '{}' returned error: {}", func_name_for_logging, e);
350							// Return -1 to indicate error to WASM caller
351							-1i32
352						},
353					}
354				},
355			)
356			.map_err(|e| {
357				dev_log!("wasm", "warn: [FunctionExport] failed to wrap host function '{}': {}", func_name_for_debug, e);
358				e
359			})?;
360
361		dev_log!("wasm", "[FunctionExport] Host function '{}' registered successfully", func_name_for_debug);
362
363		Ok(())
364	}
365
366	/// Convert our signature to WASMtime signature type
367	#[allow(dead_code)]
368	fn wasmtime_signature_from_signature(&self, _sig:&FunctionSignature) -> Result<wasmparser::FuncType> {
369		// This is a placeholder - actual implementation depends on the exact types
370		// In production, this would map ParamType and ReturnType to WASMtime types
371		Ok(wasmparser::FuncType::new([], []))
372	}
373
374	/// Get all registered function names
375	pub async fn get_function_names(&self) -> Vec<String> {
376		self.registry.functions.read().await.keys().cloned().collect()
377	}
378
379	/// Get function statistics
380	pub async fn get_function_stats(&self, name:&str) -> Option<FunctionStats> {
381		self.registry.functions.read().await.get(name).map(|f| f.stats.clone())
382	}
383
384	/// Unregister a function
385	pub async fn unregister_function(&self, name:&str) -> Result<bool> {
386		let mut functions = self.registry.functions.write().await;
387		let removed = functions.remove(name).is_some();
388
389		if removed {
390			dev_log!("wasm", "Unregistered host function: {}", name);
391		} else {
392			dev_log!("wasm", "warn: attempted to unregister non-existent function: {}", name);
393		}
394
395		Ok(removed)
396	}
397
398	/// Clear all registered functions
399	pub async fn clear(&self) {
400		dev_log!("wasm", "Clearing all registered host functions");
401		self.registry.functions.write().await.clear();
402	}
403}
404
#[cfg(test)]
mod tests {
	use super::*;

	/// Build an `i32 -> i32` signature with the given name.
	fn i32_signature(name:&str) -> FunctionSignature {
		FunctionSignature {
			name:name.to_string(),
			param_types:vec![ParamType::I32],
			return_type:Some(ReturnType::I32),
			is_async:false,
		}
	}

	/// Build an export manager with the default config and a fresh bridge.
	fn new_export() -> FunctionExportImpl { FunctionExportImpl::new(Arc::new(HostBridgeImpl::new())) }

	#[tokio::test]
	async fn test_function_export_creation() {
		let export = new_export();

		assert!(export.get_function_names().await.is_empty());
	}

	#[tokio::test]
	async fn test_register_function() {
		let export = new_export();

		let callback = |args:Vec<bytes::Bytes>| Ok(args.first().cloned().unwrap_or_default());

		let result:anyhow::Result<()> = export.register_function("echo", i32_signature("echo"), callback).await;
		assert!(result.is_ok());
		assert_eq!(export.get_function_names().await, vec!["echo".to_string()]);
	}

	#[tokio::test]
	async fn test_unregister_function() {
		let export = new_export();

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
		export
			.register_function("test", i32_signature("test"), callback)
			.await
			.expect("registration should succeed");

		assert!(export.unregister_function("test").await.unwrap());
		assert!(export.get_function_names().await.is_empty());
		// A second removal finds nothing to remove.
		assert!(!export.unregister_function("test").await.unwrap());
	}

	#[tokio::test]
	async fn test_register_respects_max_functions() {
		let config = ExportConfig { max_functions:1, ..ExportConfig::default() };
		let export = FunctionExportImpl::with_config(Arc::new(HostBridgeImpl::new()), config);

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
		assert!(export.register_function("first", i32_signature("first"), callback).await.is_ok());
		assert!(export.register_function("second", i32_signature("second"), callback).await.is_err());
		assert_eq!(export.get_function_names().await.len(), 1);
	}

	#[tokio::test]
	async fn test_new_function_has_empty_stats() {
		let export = new_export();

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
		export
			.register_function("stats", i32_signature("stats"), callback)
			.await
			.expect("registration should succeed");

		let stats = export.get_function_stats("stats").await.expect("registered function has stats");
		assert_eq!(stats.call_count, 0);
		assert_eq!(stats.error_count, 0);
		assert!(export.get_function_stats("missing").await.is_none());
	}

	#[tokio::test]
	async fn test_clear() {
		let export = new_export();

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
		export.register_function("a", i32_signature("a"), callback).await.unwrap();
		export.register_function("b", i32_signature("b"), callback).await.unwrap();

		export.clear().await;
		assert!(export.get_function_names().await.is_empty());
	}

	#[test]
	fn test_export_config_default() {
		let config = ExportConfig::default();
		assert!(config.auto_export);
		assert!(config.enable_stats);
		assert_eq!(config.max_functions, 1000);
		assert_eq!(config.name_prefix.as_deref(), Some("host_"));
	}

	#[test]
	fn test_function_stats_default() {
		let stats = FunctionStats::default();
		assert_eq!(stats.call_count, 0);
		assert_eq!(stats.total_execution_ns, 0);
		assert_eq!(stats.last_call_at, None);
		assert_eq!(stats.error_count, 0);
	}
}