wsprism_gateway/transport/handshake.rs

//! Handshake Defender (pre-upgrade DoS guard).
//!
//! Purpose:
//! - Stop abuse *before* the WebSocket upgrade.
//! - Per-IP + global leaky-bucket limiter.
//! - Returns HTTP 429 with a Retry-After header hint.
//! - Note: cleanup is probabilistic and inline; under extreme IP churn it can
//!   briefly block the caller. A background cleaner is preferable for very high churn.

use std::net::IpAddr;
use std::time::{Instant, SystemTime, UNIX_EPOCH};

use dashmap::DashMap;
use tokio::sync::Mutex;

use crate::config::schema::HandshakeConfig;

/// Simple leaky bucket (token-bucket style: fixed capacity, constant refill rate; best-effort).
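///
/// A minimal usage sketch (illustrative only; the numbers are arbitrary and the
/// example is not compiled as a doctest):
///
/// ```ignore
/// let mut bucket = LeakyBucket::new(10, 5); // burst of 10, refills 5 tokens/sec
/// assert!(bucket.try_take(1).is_ok());      // within the burst: allowed
/// assert!(bucket.try_take(100).is_err());   // exceeds capacity: Err(retry_after_secs)
/// ```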
#[derive(Debug)]
pub struct LeakyBucket {
    capacity: u32,
    tokens: f64,
    refill_per_sec: f64,
    last: Instant,
}

impl LeakyBucket {
    pub fn new(capacity: u32, refill_per_sec: u32) -> Self {
        let cap = capacity.max(1);
        Self {
            capacity: cap,
            tokens: cap as f64,
            refill_per_sec: refill_per_sec.max(1) as f64,
            last: Instant::now(),
        }
    }

    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last).as_secs_f64();
        self.last = now;
        self.tokens = (self.tokens + elapsed * self.refill_per_sec).min(self.capacity as f64);
    }

    /// Consume `cost` tokens. Returns Ok if allowed, Err with retry_after seconds (ceil).
    pub fn try_take(&mut self, cost: u32) -> Result<(), u64> {
        self.refill();
        let c = cost.max(1) as f64;
        if self.tokens >= c {
            self.tokens -= c;
            Ok(())
        } else {
            let missing = c - self.tokens;
            let wait = (missing / self.refill_per_sec).ceil();
            Err(wait.max(1.0) as u64) // Retry-After min 1
        }
    }
}

/// A lightweight in-memory handshake rate limiter.
///
/// Concurrency note: `check` may invoke a probabilistic cleanup via `retain`
/// when `per_ip` grows large. That cleanup can briefly lock shards. For strict
/// latency guarantees, move cleanup to a background task instead of running
/// inline with request handling.
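///
/// Call-site sketch (hedged: `cfg` and `client_ip` stand in for values the
/// gateway resolves elsewhere):
///
/// ```ignore
/// let defender = HandshakeDefender::new(cfg);
/// match defender.check(client_ip).await {
///     Ok(()) => { /* proceed with the WebSocket upgrade */ }
///     Err(retry_after) => { /* reply 429 with a Retry-After of `retry_after` seconds */ }
/// }
/// ```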
#[derive(Debug)]
pub struct HandshakeDefender {
    cfg: HandshakeConfig,
    global: Mutex<LeakyBucket>,
    per_ip: DashMap<IpAddr, Mutex<LeakyBucket>>,
}

impl HandshakeDefender {
    pub fn new(cfg: HandshakeConfig) -> Self {
        Self {
            global: Mutex::new(LeakyBucket::new(cfg.global_burst, cfg.global_rps)),
            per_ip: DashMap::new(),
            cfg,
        }
    }

    pub fn enabled(&self) -> bool {
        self.cfg.enabled
    }

    /// Check handshake allowance. Returns Ok if allowed.
    /// On reject, returns retry-after seconds (min 1).
    pub async fn check(&self, ip: IpAddr) -> Result<(), u64> {
        if !self.cfg.enabled {
            return Ok(());
        }

        // 1) Global bucket.
        {
            let mut g = self.global.lock().await;
            if let Err(ra) = g.try_take(1) {
                return Err(ra);
            }
        }

        // 2) Per-IP bucket. Scope the DashMap entry guard so the shard lock is
        //    released before the size check and cleanup below; holding it there
        //    would deadlock against `len()` / `retain()`.
        {
            let entry = self.per_ip.entry(ip).or_insert_with(|| {
                Mutex::new(LeakyBucket::new(self.cfg.per_ip_burst, self.cfg.per_ip_rps))
            });
            let mut b = entry.value().lock().await;
            if let Err(ra) = b.try_take(1) {
                return Err(ra);
            }
        }

        // 3) Best-effort size control (lazy cleanup). No per-entry timestamps are
        //    stored, so eviction by recency is not possible; instead, when the map
        //    is over its limit, there is a ~10% chance per request to drop roughly
        //    10% of entries at random, using system-time nanoseconds as a cheap
        //    seed (avoids an external RNG crate). `retain` locks each shard while
        //    it runs, so this briefly blocks the calling request.
        if self.per_ip.len() > self.cfg.max_ip_entries {
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .subsec_nanos();

            if nanos % 100 < 10 {
                self.per_ip.retain(|_, _| {
                    let n = SystemTime::now()
                        .duration_since(UNIX_EPOCH)
                        .unwrap_or_default()
                        .subsec_nanos();
                    n % 10 != 0 // drop ~10% of entries
                });
                tracing::warn!(len = self.per_ip.len(), "handshake defender ip map trimmed");
            }
        }

        Ok(())
    }
}
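
/// A minimal background-cleaner sketch along the lines the module docs suggest
/// for very high IP churn (an assumption, not part of the original API): trim
/// the per-IP map off the request path. The 30-second interval and the
/// clear-everything policy are illustrative placeholders.
pub fn spawn_background_cleanup(
    defender: std::sync::Arc<HandshakeDefender>,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        let mut tick = tokio::time::interval(std::time::Duration::from_secs(30));
        loop {
            tick.tick().await;
            // Same-module access to the private fields; coarse, but it keeps
            // `retain`/shard locking out of the request path.
            if defender.per_ip.len() > defender.cfg.max_ip_entries {
                defender.per_ip.clear();
                tracing::warn!("handshake defender ip map cleared by background cleaner");
            }
        }
    })
}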

/// Helper: format a Retry-After value; returns the header string and the
/// numeric seconds, clamped to at least 1.
pub fn retry_after_header_secs(secs: u64) -> (String, u64) {
    let s = secs.max(1);
    (s.to_string(), s)
}
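
#[cfg(test)]
mod tests {
    //! Illustrative sketches only: they exercise the bucket math and the
    //! Retry-After helper without touching `HandshakeConfig`, whose exact
    //! fields live elsewhere in the crate.
    use super::*;

    #[test]
    fn bucket_allows_burst_then_reports_wait() {
        let mut b = LeakyBucket::new(2, 1);
        assert!(b.try_take(1).is_ok());
        assert!(b.try_take(1).is_ok());
        // An over-capacity request can never succeed and must report a wait hint.
        let wait = b.try_take(10).unwrap_err();
        assert!(wait >= 1);
    }

    #[test]
    fn retry_after_is_clamped_to_at_least_one_second() {
        assert_eq!(retry_after_header_secs(0), ("1".to_string(), 1));
        assert_eq!(retry_after_header_secs(7), ("7".to_string(), 7));
    }
}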