wsprism_gateway/transport/
handshake.rs1use std::net::IpAddr;
11use std::time::{Instant, SystemTime, UNIX_EPOCH};
12
13use dashmap::DashMap;
14use tokio::sync::Mutex;
15use crate::config::schema::HandshakeConfig;
16
/// Rate-limiting bucket that holds up to `capacity` tokens.
///
/// Tokens are replenished continuously at `refill_per_sec` and consumed by
/// [`LeakyBucket::try_take`]; the bucket starts full.
#[derive(Debug)]
pub struct LeakyBucket {
    /// Maximum number of tokens the bucket can hold (always at least 1).
    capacity: u32,
    /// Tokens currently available; fractional because refill is continuous.
    tokens: f64,
    /// Tokens added per second (always at least 1).
    refill_per_sec: f64,
    /// Instant of the most recent refill.
    last: Instant,
}

impl LeakyBucket {
    /// Creates a full bucket.
    ///
    /// Both `capacity` and `refill_per_sec` are raised to 1 when given as 0,
    /// so the bucket can always hold and regain at least one token.
    pub fn new(capacity: u32, refill_per_sec: u32) -> Self {
        let capacity = capacity.max(1);
        let refill_per_sec = f64::from(refill_per_sec.max(1));
        Self {
            capacity,
            tokens: f64::from(capacity),
            refill_per_sec,
            last: Instant::now(),
        }
    }

    /// Credits the tokens earned since the previous refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let previous = std::mem::replace(&mut self.last, now);
        let gained = now.duration_since(previous).as_secs_f64() * self.refill_per_sec;
        self.tokens = f64::min(self.tokens + gained, f64::from(self.capacity));
    }

    /// Attempts to withdraw `cost` tokens (a `cost` of 0 is treated as 1).
    ///
    /// # Errors
    ///
    /// When the bucket holds fewer tokens than requested, nothing is withdrawn
    /// and `Err(secs)` is returned, where `secs` is the number of whole seconds
    /// (rounded up, never less than 1) the missing tokens take to refill.
    pub fn try_take(&mut self, cost: u32) -> Result<(), u64> {
        self.refill();
        let needed = f64::from(cost.max(1));

        if self.tokens < needed {
            let deficit = needed - self.tokens;
            let secs = (deficit / self.refill_per_sec).ceil().max(1.0);
            return Err(secs as u64);
        }

        self.tokens -= needed;
        Ok(())
    }
}
58
59#[derive(Debug)]
66pub struct HandshakeDefender {
67 cfg: HandshakeConfig,
68 global: Mutex<LeakyBucket>,
69 per_ip: DashMap<IpAddr, Mutex<LeakyBucket>>,
70}
71
72impl HandshakeDefender {
73 pub fn new(cfg: HandshakeConfig) -> Self {
74 Self {
75 global: Mutex::new(LeakyBucket::new(cfg.global_burst, cfg.global_rps)),
76 per_ip: DashMap::new(),
77 cfg,
78 }
79 }
80
81 pub fn enabled(&self) -> bool {
82 self.cfg.enabled
83 }
84
85 pub async fn check(&self, ip: IpAddr) -> Result<(), u64> {
88 if !self.cfg.enabled {
89 return Ok(());
90 }
91
92 {
94 let mut g = self.global.lock().await;
95 if let Err(ra) = g.try_take(1) {
96 return Err(ra);
97 }
98 }
99
100 let entry = self.per_ip.entry(ip).or_insert_with(|| {
102 Mutex::new(LeakyBucket::new(self.cfg.per_ip_burst, self.cfg.per_ip_rps))
103 });
104 {
105 let mut b = entry.value().lock().await;
106 if let Err(ra) = b.try_take(1) {
107 return Err(ra);
108 }
109 }
110
111 if self.per_ip.len() > self.cfg.max_ip_entries {
113 let nanos = SystemTime::now()
116 .duration_since(UNIX_EPOCH)
117 .unwrap_or_default()
118 .subsec_nanos();
119
120 if nanos % 100 < 10 {
122 self.per_ip.retain(|_, _| {
129 let n = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().subsec_nanos();
130 n % 10 != 0 });
132 tracing::warn!(len = self.per_ip.len(), "handshake defender ip map trimmed");
133 }
134 }
135
136 Ok(())
137 }
138}
139
/// Produces the value for a retry-after header.
///
/// `secs` is raised to 1 when it is 0. Returns the clamped value both as a
/// decimal string and as a number.
pub fn retry_after_header_secs(secs: u64) -> (String, u64) {
    let clamped = if secs == 0 { 1 } else { secs };
    (clamped.to_string(), clamped)
}