1use std::{collections::HashMap, sync::Arc};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11use crate::dev_log;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct APICallRequest {
16 pub extension_id:String,
18 pub api_method:String,
20 pub arguments:Vec<serde_json::Value>,
22 pub correlation_id:Option<String>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct APICallResponse {
29 pub success:bool,
31 pub data:Option<serde_json::Value>,
33 pub error:Option<String>,
35 pub correlation_id:Option<String>,
37}
38
39#[allow(dead_code)]
41pub struct APICall {
42 extension_id:String,
44 api_method:String,
46 arguments:Vec<serde_json::Value>,
48 timestamp:u64,
50}
51
52#[allow(dead_code)]
54type APIMethodHandler = fn(&str, Vec<serde_json::Value>) -> Result<serde_json::Value>;
55
56#[allow(dead_code)]
58type AsyncAPIMethodHandler =
59 fn(&str, Vec<serde_json::Value>) -> Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + Unpin>;
60
61#[derive(Clone)]
63pub struct APIMethodInfo {
64 #[allow(dead_code)]
66 name:String,
67 #[allow(dead_code)]
69 description:String,
70 #[allow(dead_code)]
72 parameters:Option<serde_json::Value>,
73 #[allow(dead_code)]
75 returns:Option<serde_json::Value>,
76 #[allow(dead_code)]
78 is_async:bool,
79 call_count:u64,
81 total_time_us:u64,
83}
84
85pub struct APIBridgeImpl {
87 api_methods:Arc<RwLock<HashMap<String, APIMethodInfo>>>,
89 stats:Arc<RwLock<APIStats>>,
91 contexts:Arc<RwLock<HashMap<String, APIContext>>>,
93}
94
95#[derive(Debug, Clone, Default, Serialize, Deserialize)]
97pub struct APIStats {
98 pub total_calls:u64,
100 pub successful_calls:u64,
102 pub failed_calls:u64,
104 pub avg_latency_us:u64,
106 pub active_contexts:usize,
108}
109
110#[derive(Debug, Clone)]
112pub struct APIContext {
113 pub extension_id:String,
115 pub context_id:String,
117 pub workspace_folder:Option<String>,
119 pub active_editor:Option<String>,
121 pub selections:Vec<Selection>,
123 pub created_at:u64,
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct Selection {
130 pub start_line:u32,
132 pub start_character:u32,
134 pub end_line:u32,
136 pub end_character:u32,
138}
139
140impl Default for Selection {
141 fn default() -> Self { Self { start_line:0, start_character:0, end_line:0, end_character:0 } }
142}
143
144impl APIBridgeImpl {
145 pub fn new() -> Self {
147 let bridge = Self {
148 api_methods:Arc::new(RwLock::new(HashMap::new())),
149 stats:Arc::new(RwLock::new(APIStats::default())),
150 contexts:Arc::new(RwLock::new(HashMap::new())),
151 };
152
153 bridge.register_builtin_methods();
154
155 bridge
156 }
157
158 fn register_builtin_methods(&self) {
160 dev_log!("extensions", "Registered built-in VS Code API methods");
169 }
170
171 pub async fn register_method(
173 &self,
174 name:&str,
175 description:&str,
176 parameters:Option<serde_json::Value>,
177 returns:Option<serde_json::Value>,
178 is_async:bool,
179 ) -> Result<()> {
180 let mut methods = self.api_methods.write().await;
181
182 if methods.contains_key(name) {
183 dev_log!("extensions", "warn: API method already registered: {}", name);
184 }
185
186 methods.insert(
187 name.to_string(),
188 APIMethodInfo {
189 name:name.to_string(),
190 description:description.to_string(),
191 parameters,
192 returns,
193 is_async,
194 call_count:0,
195 total_time_us:0,
196 },
197 );
198
199 dev_log!("extensions", "Registered API method: {}", name);
200
201 Ok(())
202 }
203
204 pub async fn create_context(&self, extension_id:&str) -> Result<APIContext> {
206 let context_id = format!("{}-{}", extension_id, uuid::Uuid::new_v4());
207
208 let context = APIContext {
209 extension_id:extension_id.to_string(),
210 context_id:context_id.clone(),
211 workspace_folder:None,
212 active_editor:None,
213 selections:Vec::new(),
214 created_at:std::time::SystemTime::now()
215 .duration_since(std::time::UNIX_EPOCH)
216 .map(|d| d.as_secs())
217 .unwrap_or(0),
218 };
219
220 let mut contexts = self.contexts.write().await;
221 contexts.insert(context_id.clone(), context.clone());
222
223 let mut stats = self.stats.write().await;
225 stats.active_contexts = contexts.len();
226
227 dev_log!("extensions", "Created API context for extension: {}", extension_id);
228
229 Ok(context)
230 }
231
232 pub async fn get_context(&self, context_id:&str) -> Option<APIContext> {
234 self.contexts.read().await.get(context_id).cloned()
235 }
236
237 pub async fn update_context(&self, context:APIContext) -> Result<()> {
239 let mut contexts = self.contexts.write().await;
240 contexts.insert(context.context_id.clone(), context);
241 Ok(())
242 }
243
244 pub async fn remove_context(&self, context_id:&str) -> Result<bool> {
246 let mut contexts = self.contexts.write().await;
247 let removed = contexts.remove(context_id).is_some();
248
249 if removed {
250 let mut stats = self.stats.write().await;
251 stats.active_contexts = contexts.len();
252 }
253
254 Ok(removed)
255 }
256
257 pub async fn handle_call(&self, request:APICallRequest) -> Result<APICallResponse> {
259 let start = std::time::Instant::now();
260
261 dev_log!("extensions", "Handling API call: {} from extension {}", request.api_method, request.extension_id);
262
263 let exists = {
265 let methods = self.api_methods.read().await;
266 methods.contains_key(&request.api_method)
267 };
268
269 if !exists {
270 return Ok(APICallResponse {
271 success:false,
272 data:None,
273 error:Some(format!("API method not found: {}", request.api_method)),
274 correlation_id:request.correlation_id,
275 });
276 }
277
278 let result = self
281 .execute_method(&request.extension_id, &request.api_method, &request.arguments)
282 .await;
283
284 let elapsed_us = start.elapsed().as_micros() as u64;
285
286 let mut stats = self.stats.write().await;
288 stats.total_calls += 1;
289 stats.total_calls += 1;
290 if exists {
291 stats.successful_calls += 1;
292 stats.avg_latency_us =
294 (stats.avg_latency_us * (stats.successful_calls - 1) + elapsed_us) / stats.successful_calls;
295 }
296
297 {
299 let mut methods = self.api_methods.write().await;
300 if let Some(method) = methods.get_mut(&request.api_method) {
301 method.call_count += 1;
302 method.total_time_us += elapsed_us;
303 }
304 }
305
306 dev_log!("extensions", "API call {} completed in {}µs", request.api_method, elapsed_us);
307
308 match result {
309 Ok(data) => {
310 Ok(
311 APICallResponse {
312 success:true,
313 data:Some(data),
314 error:None,
315 correlation_id:request.correlation_id,
316 },
317 )
318 },
319 Err(e) => {
320 Ok(APICallResponse {
321 success:false,
322 data:None,
323 error:Some(e.to_string()),
324 correlation_id:request.correlation_id,
325 })
326 },
327 }
328 }
329
330 async fn execute_method(
332 &self,
333 _extension_id:&str,
334 _method_name:&str,
335 _arguments:&[serde_json::Value],
336 ) -> Result<serde_json::Value> {
337 Ok(serde_json::Value::Null)
346 }
347
348 pub async fn stats(&self) -> APIStats { self.stats.read().await.clone() }
350
351 pub async fn get_methods(&self) -> Vec<APIMethodInfo> { self.api_methods.read().await.values().cloned().collect() }
353
354 pub async fn unregister_method(&self, name:&str) -> Result<bool> {
356 let mut methods = self.api_methods.write().await;
357 let removed = methods.remove(name).is_some();
358
359 if removed {
360 dev_log!("extensions", "Unregistered API method: {}", name);
361 }
362
363 Ok(removed)
364 }
365}
366
367impl Default for APIBridgeImpl {
368 fn default() -> Self { Self::new() }
369}
370
#[cfg(test)]
mod tests {
	use super::*;

	/// A freshly constructed bridge has recorded no calls.
	#[tokio::test]
	async fn test_api_bridge_creation() {
		let snapshot = APIBridgeImpl::new().stats().await;

		assert_eq!(snapshot.total_calls, 0);
		assert_eq!(snapshot.successful_calls, 0);
	}

	/// A created context carries the extension id and a non-empty context id.
	#[tokio::test]
	async fn test_context_creation() {
		let bridge = APIBridgeImpl::new();

		let created = bridge.create_context("test.ext").await.unwrap();

		assert_eq!(created.extension_id, "test.ext");
		assert!(!created.context_id.is_empty());
	}

	/// A registered method shows up in the method listing.
	#[tokio::test]
	async fn test_method_registration() {
		let bridge = APIBridgeImpl::new();

		let outcome:Result<()> = bridge.register_method("test.method", "Test method", None, None, false).await;
		assert!(outcome.is_ok());

		let registered:Vec<APIMethodInfo> = bridge.get_methods().await;
		assert!(registered.iter().any(|info| info.name == "test.method"));
	}

	/// Request fields hold exactly what they were constructed with.
	#[tokio::test]
	async fn test_api_call_request() {
		let extension_id = "test.ext".to_string();
		let api_method = "test.method".to_string();
		let arguments = vec![serde_json::json!("arg1")];
		let correlation_id = Some("test-id".to_string());

		let request = APICallRequest { extension_id, api_method, arguments, correlation_id };

		assert_eq!(request.extension_id, "test.ext");
		assert_eq!(request.api_method, "test.method");
		assert_eq!(request.arguments.len(), 1);
	}

	/// The default selection starts and ends on line zero.
	#[test]
	fn test_selection_default() {
		let Selection { start_line, end_line, .. } = Selection::default();

		assert_eq!(start_line, 0);
		assert_eq!(end_line, 0);
	}
}