1
2use std::{
48 collections::HashMap,
49 sync::{
50 Arc,
51 atomic::{AtomicBool, AtomicU64, Ordering},
52 },
53};
54
55use anyhow::Result;
56use futures_util::{SinkExt, StreamExt, stream::SplitSink};
57use serde_json::Value;
58use tokio::{
59 net::{TcpListener, TcpStream},
60 sync::{Mutex, oneshot},
61};
62use tokio_tungstenite::{
63 MaybeTlsStream,
64 WebSocketStream,
65 accept_async,
66 connect_async,
67 tungstenite::{Message, Utf8Bytes},
68};
69
70#[derive(Clone)]
72pub struct SharedSecret(pub [u8; 32]);
73
74impl SharedSecret {
75 pub fn random() -> Self {
76 Self(rand::random::<[u8; 32]>())
80 }
81
82 pub fn as_hex(&self) -> String { hex::encode(self.0) }
83
84 pub fn from_hex(Hex:&str) -> Result<Self> {
85 let Bytes = hex::decode(Hex)?;
86
87 if Bytes.len() != 32 {
88 anyhow::bail!("shared secret must be 32 bytes (got {})", Bytes.len());
89 }
90
91 let mut Out = [0u8; 32];
92
93 Out.copy_from_slice(&Bytes);
94
95 Ok(Self(Out))
96 }
97}
98
99pub type HandlerFn =
102 Arc<dyn Fn(Value) -> futures_util::future::BoxFuture<'static, Result<Value, String>> + Send + Sync>;
103
104#[derive(Default)]
106pub struct HandlerRegistry {
107 Handlers:Mutex<HashMap<String, HandlerFn>>,
108}
109
110impl HandlerRegistry {
111 pub fn new() -> Arc<Self> { Arc::new(Self::default()) }
112
113 pub async fn Register(&self, Method:String, Handler:HandlerFn) {
114 self.Handlers.lock().await.insert(Method, Handler);
115 }
116
117 pub async fn Lookup(&self, Method:&str) -> Option<HandlerFn> { self.Handlers.lock().await.get(Method).cloned() }
118}
119
120pub async fn ServeLocal(Port:u16, Secret:SharedSecret, Registry:Arc<HandlerRegistry>) -> Result<()> {
127 let Address = format!("127.0.0.1:{}", Port);
128
129 let Listener = TcpListener::bind(&Address).await?;
130
131 tracing::info!(target: "Mist::WebSocket", "server listening on {}", Address);
132
133 let PortStr = format!("{}", Port);
137
138 CommonLibrary::Telemetry::CaptureEvent::Fn(
139 "land:mist:server:start",
140 Some(vec![("address", Address.as_str()), ("port", PortStr.as_str())]),
141 );
142
143 loop {
144 let (Stream, Peer) = match Listener.accept().await {
145 Ok(P) => P,
146
147 Err(Error) => {
148 tracing::warn!(target: "Mist::WebSocket", "accept error: {}", Error);
149
150 continue;
151 },
152 };
153
154 let SecretClone = Secret.clone();
155
156 let RegistryClone = Registry.clone();
157
158 tokio::spawn(async move {
159 if let Err(Error) = HandleConnection(Stream, SecretClone, RegistryClone).await {
160 tracing::warn!(target: "Mist::WebSocket", "connection from {} closed with error: {}", Peer, Error);
161 }
162 });
163 }
164}
165
166async fn HandleConnection(Stream:TcpStream, _Secret:SharedSecret, Registry:Arc<HandlerRegistry>) -> Result<()> {
167 let WebSocketStream = accept_async(Stream).await?;
174
175 let (mut Sink, mut Source) = WebSocketStream.split();
176
177 while let Some(MessageResult) = Source.next().await {
178 let Message = match MessageResult {
179 Ok(M) => M,
180
181 Err(Error) => {
182 tracing::debug!(target: "Mist::WebSocket", "frame read error: {}", Error);
183
184 break;
185 },
186 };
187
188 match Message {
189 Message::Text(Text) => {
190 let Envelope:Value = match serde_json::from_str(&Text) {
191 Ok(V) => V,
192
193 Err(Error) => {
194 tracing::debug!(target: "Mist::WebSocket", "bad text frame: {}", Error);
195
196 continue;
197 },
198 };
199
200 let Method = Envelope.get("method").and_then(|V| V.as_str()).unwrap_or("");
201
202 let Identifier = Envelope.get("id").cloned().unwrap_or(Value::Null);
203
204 let Params = Envelope.get("params").cloned().unwrap_or(Value::Array(vec![]));
205
206 if Method.is_empty() {
207 continue;
208 }
209
210 let Handler = Registry.Lookup(Method).await;
211
212 let Response = match Handler {
213 Some(H) => {
214 match H(Params).await {
215 Ok(Value) => serde_json::json!({ "id": Identifier, "result": Value }),
216
217 Err(ErrorMessage) => serde_json::json!({ "id": Identifier, "error": ErrorMessage }),
218 }
219 },
220
221 None => {
222 serde_json::json!({
223 "id": Identifier,
224 "error": format!("Unknown method: {}", Method),
225 })
226 },
227 };
228
229 if Identifier.is_null() {
230 continue;
232 }
233
234 if let Err(Error) = Sink.send(Message::Text(Utf8Bytes::from(Response.to_string()))).await {
235 tracing::debug!(target: "Mist::WebSocket", "send error: {}", Error);
236
237 break;
238 }
239 },
240
241 Message::Binary(Bytes) => {
242 tracing::trace!(target: "Mist::WebSocket", "binary frame ({} bytes) ignored - reserved for phase 2", Bytes.len());
243 },
244
245 Message::Close(_) => break,
246
247 _ => {},
248 }
249 }
250
251 Ok(())
252}
253
254type PendingMap = Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Value, String>>>>>;
256
257pub struct Client {
260 Sink:Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
264
265 Pending:PendingMap,
266
267 NextIdentifier:AtomicU64,
268
269 Closed:AtomicBool,
270}
271
272impl Client {
273 pub async fn connect(Address:&str) -> Result<Arc<Self>> {
278 let (Stream, _Response) = connect_async(Address).await?;
279
280 let (Sink, mut Source) = Stream.split();
281
282 let Sink = Arc::new(Mutex::new(Sink));
283
284 let Pending:PendingMap = Arc::new(Mutex::new(HashMap::new()));
285
286 let SelfReference = Arc::new(Self {
287 Sink,
288 Pending:Pending.clone(),
289 NextIdentifier:AtomicU64::new(1),
290 Closed:AtomicBool::new(false),
291 });
292
293 let SelfForReader = SelfReference.clone();
296
297 tokio::spawn(async move {
298 while let Some(MessageResult) = Source.next().await {
299 let Frame = match MessageResult {
300 Ok(M) => M,
301 Err(_) => break,
302 };
303 match Frame {
304 Message::Text(Text) => {
305 if let Ok(Envelope) = serde_json::from_str::<Value>(&Text) {
306 let Identifier = Envelope.get("id").and_then(|V| V.as_u64());
307 if let Some(Identifier) = Identifier {
308 let Sender = SelfForReader.Pending.lock().await.remove(&Identifier);
309 if let Some(Sender) = Sender {
310 let Result = if let Some(ErrorValue) = Envelope.get("error") {
311 Err(ErrorValue.to_string())
312 } else {
313 Ok(Envelope.get("result").cloned().unwrap_or(Value::Null))
314 };
315 let _ = Sender.send(Result);
316 }
317 }
318 }
319 },
320 Message::Close(_) => break,
321 _ => {},
322 }
323 }
324 SelfForReader.Closed.store(true, Ordering::Relaxed);
325 let mut Pending = SelfForReader.Pending.lock().await;
327 for (_, Sender) in Pending.drain() {
328 let _ = Sender.send(Err("connection closed".into()));
329 }
330 });
331
332 Ok(SelfReference)
333 }
334
335 pub async fn invoke(&self, Method:&str, Params:Value) -> Result<Value, String> {
339 if self.Closed.load(Ordering::Relaxed) {
340 return Err("connection closed".into());
341 }
342
343 let Identifier = self.NextIdentifier.fetch_add(1, Ordering::Relaxed);
344
345 let (Tx, Rx) = oneshot::channel();
346
347 self.Pending.lock().await.insert(Identifier, Tx);
348
349 let Envelope = serde_json::json!({ "id": Identifier, "method": Method, "params": Params });
350
351 let Text = Envelope.to_string();
352
353 let SendResult = self.Sink.lock().await.send(Message::Text(Utf8Bytes::from(Text))).await;
354
355 if SendResult.is_err() {
356 self.Pending.lock().await.remove(&Identifier);
357
358 return Err("send failed".into());
359 }
360
361 Rx.await.map_err(|_| "request cancelled".to_string())?
362 }
363
364 pub async fn notify(&self, Method:&str, Params:Value) -> Result<(), String> {
366 if self.Closed.load(Ordering::Relaxed) {
367 return Err("connection closed".into());
368 }
369
370 let Envelope = serde_json::json!({ "id": Value::Null, "method": Method, "params": Params });
371
372 let Text = Envelope.to_string();
373
374 self.Sink
375 .lock()
376 .await
377 .send(Message::Text(Utf8Bytes::from(Text)))
378 .await
379 .map_err(|Error| Error.to_string())
380 }
381
382 pub fn is_closed(&self) -> bool { self.Closed.load(Ordering::Relaxed) }
383}