wsprism_gateway/policy/
engine.rs

//! Compiled policy runtime for a tenant.
//!
//! Parses allowlists, enforces frame-size and rate limits, and exposes
//! per-connection rate limiters for the transport layer.
5
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant};
8
9use wsprism_core::error::ClientCode;
10
11pub use crate::config::schema::{HotErrorMode, OnExceed, SessionMode};
12use crate::config::schema::{RateLimitScope, SessionPolicy, TenantPolicy};
13
14use super::allowlist::{
15    compile_ext_rules, compile_hot_rules, is_ext_allowed, is_hot_allowed, ExtRule, HotRule,
16};
17
18/// Decision from policy evaluation.
19#[derive(Debug, Clone)]
20pub enum PolicyDecision {
21    Pass,
22    Drop,
23    Reject { code: ClientCode, msg: &'static str },
24    Close { code: ClientCode, msg: &'static str },
25}
26
27/// Tenant-scoped policy runtime.
28/// Construct once at startup, then share via Arc.
29pub struct TenantPolicyRuntime {
30    pub tenant_id: String,
31
32    max_frame_bytes: usize,
33    ext_rules: Vec<ExtRule>,
34    hot_rules: Vec<HotRule>,
35
36    // Rate limit configuration
37    rate_limit_scope: RateLimitScope,
38    conn_rps: u32,
39    conn_burst: u32,
40    tenant_limiter: Option<RateLimiter>,
41
42    // Session policy
43    sessions: SessionPolicy,
44
45    // Hot lane behavior
46    hot_error_mode: HotErrorMode,
47    hot_requires_active_room: bool,
48}
49
50impl TenantPolicyRuntime {
51    pub fn new(
52        tenant_id: String,
53        max_frame_bytes: usize,
54        policy: &TenantPolicy,
55    ) -> wsprism_core::Result<Self> {
56        let ext_rules = compile_ext_rules(&policy.ext_allowlist)?;
57        let hot_rules = compile_hot_rules(&policy.hot_allowlist)?;
58
59        let tenant_limiter = match policy.rate_limit_scope {
60            RateLimitScope::Tenant | RateLimitScope::Both => {
61                Some(RateLimiter::new(policy.rate_limit_rps, policy.rate_limit_burst))
62            }
63            RateLimitScope::Connection => None,
64        };
65
66        Ok(Self {
67            tenant_id,
68            max_frame_bytes,
69            ext_rules,
70            hot_rules,
71            rate_limit_scope: policy.rate_limit_scope,
72            conn_rps: policy.rate_limit_rps,
73            conn_burst: policy.rate_limit_burst,
74            tenant_limiter,
75            sessions: policy.sessions.clone(),
76            hot_error_mode: policy.hot_error_mode,
77            hot_requires_active_room: policy.hot_requires_active_room,
78        })
79    }
80
81    pub fn session_policy(&self) -> &SessionPolicy {
82        &self.sessions
83    }
84    pub fn hot_error_mode(&self) -> HotErrorMode {
85        self.hot_error_mode
86    }
87    pub fn hot_requires_active_room(&self) -> bool {
88        self.hot_requires_active_room
89    }
90
91    /// Create per-connection limiter if enabled (Connection/Both).
92    pub fn new_connection_limiter(&self) -> Option<ConnRateLimiter> {
93        match self.rate_limit_scope {
94            RateLimitScope::Connection | RateLimitScope::Both => {
95                Some(ConnRateLimiter::new(self.conn_rps, self.conn_burst))
96            }
97            RateLimitScope::Tenant => None,
98        }
99    }
100
101    /// Cheap global checks for any inbound payload.
102    pub fn check_len(&self, bytes_len: usize) -> PolicyDecision {
103        if bytes_len > self.max_frame_bytes {
104            return PolicyDecision::Close {
105                code: ClientCode::BadRequest,
106                msg: "frame too large",
107            };
108        }
109        PolicyDecision::Pass
110    }
111
112    /// Ext Lane policy: svc/type allowlist + (optional) tenant-level rate limit.
113    pub fn check_text(&self, bytes_len: usize, svc: &str, msg_type: &str) -> PolicyDecision {
114        match self.check_len(bytes_len) {
115            PolicyDecision::Pass => {}
116            other => return other,
117        }
118
119        if let Some(lim) = &self.tenant_limiter {
120            if !lim.allow() {
121                return PolicyDecision::Drop;
122            }
123        }
124
125        if self.ext_rules.is_empty() {
126            return PolicyDecision::Reject {
127                code: ClientCode::BadRequest,
128                msg: "ext_allowlist empty (strict deny)",
129            };
130        }
131
132        if !is_ext_allowed(&self.ext_rules, svc, msg_type) {
133            return PolicyDecision::Reject {
134                code: ClientCode::BadRequest,
135                msg: "svc/type not allowed",
136            };
137        }
138
139        PolicyDecision::Pass
140    }
141
142    /// Hot Lane policy: svc_id/opcode allowlist + (optional) tenant-level rate limit.
143    pub fn check_hot(&self, bytes_len: usize, svc_id: u8, opcode: u8) -> PolicyDecision {
144        match self.check_len(bytes_len) {
145            PolicyDecision::Pass => {}
146            other => return other,
147        }
148
149        if let Some(lim) = &self.tenant_limiter {
150            if !lim.allow() {
151                return PolicyDecision::Drop;
152            }
153        }
154
155        if self.hot_rules.is_empty() {
156            return PolicyDecision::Drop; // strict deny
157        }
158
159        if !is_hot_allowed(&self.hot_rules, svc_id, opcode) {
160            return PolicyDecision::Drop;
161        }
162
163        PolicyDecision::Pass
164    }
165}
166
167/// Per-connection token bucket (no mutex).
168#[derive(Debug)]
169pub struct ConnRateLimiter {
170    bucket: TokenBucket,
171}
172
173impl ConnRateLimiter {
174    pub fn new(rps: u32, burst: u32) -> Self {
175        Self {
176            bucket: TokenBucket::new(rps, burst),
177        }
178    }
179
180    pub fn allow(&mut self) -> bool {
181        self.bucket.allow()
182    }
183}
184
185/// Minimal token-bucket limiter (tenant-level, shared).
186struct RateLimiter {
187    inner: Arc<Mutex<TokenBucket>>,
188}
189
190impl RateLimiter {
191    fn new(rps: u32, burst: u32) -> Self {
192        Self {
193            inner: Arc::new(Mutex::new(TokenBucket::new(rps, burst))),
194        }
195    }
196
197    fn allow(&self) -> bool {
198        // Poisoned mutex means logic bug; treat as "deny" instead of panic.
199        // (enterprise: never bring down gateway)
200        if let Ok(mut g) = self.inner.lock() {
201            g.allow()
202        } else {
203            false
204        }
205    }
206}
207
/// Integer token bucket: holds up to `capacity` tokens and earns `rps` tokens per second.
///
/// Tokens are whole units. `last` marks the instant up to which elapsed time has
/// already been converted into tokens. A refill consumes only the time that paid for
/// whole tokens, so fractional progress carries over and the long-run rate matches
/// `rps` however often `allow` is called.
#[derive(Debug)]
struct TokenBucket {
    /// Refill rate in tokens per second (always >= 1).
    rps: u32,
    /// Maximum number of stored tokens (always >= 1).
    capacity: u32,
    /// Tokens currently available; invariant: `tokens <= capacity`.
    tokens: u32,
    /// Elapsed time before this instant has already been turned into tokens.
    last: Instant,
}

impl TokenBucket {
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    /// Creates a full bucket.
    ///
    /// Zero `rps`/`burst` are clamped to 1, so the bucket always refills and can hold
    /// at least one token.
    fn new(rps: u32, burst: u32) -> Self {
        let rps = rps.max(1);
        let capacity = burst.max(1);
        Self {
            rps,
            capacity,
            tokens: capacity,
            last: Instant::now(),
        }
    }

    /// Refills, then takes one token if available. Returns `true` when a token was taken.
    fn allow(&mut self) -> bool {
        self.refill();

        if self.tokens == 0 {
            return false;
        }
        self.tokens -= 1;
        true
    }

    /// Converts the time elapsed since `last` into whole tokens, without overflow.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed_ns = now.saturating_duration_since(self.last).as_nanos();
        let rps = u128::from(self.rps);

        // Cannot overflow: a Duration is below 2^94 ns and rps is below 2^32.
        let earned = elapsed_ns * rps / Self::NANOS_PER_SEC;
        if earned == 0 {
            // Less than one token's worth of time: leave `last` alone so it accumulates.
            return;
        }

        let room = u128::from(self.capacity.saturating_sub(self.tokens));
        if earned >= room {
            // Bucket is (now) full; idle time beyond that is not banked.
            self.tokens = self.capacity;
            self.last = now;
        } else {
            // earned < room <= u32::MAX, so this narrowing is lossless.
            let earned = earned as u32;
            self.tokens += earned;

            // Advance `last` by exactly the time that paid for `earned` tokens (rounded
            // up, so we never over-credit). This is <= elapsed, so `last` never passes
            // `now`, and it is < 2^32 * 1e9 ns, so it fits in u64.
            let spent_ns = (u128::from(earned) * Self::NANOS_PER_SEC + rps - 1) / rps;
            self.last += Duration::from_nanos(spent_ns as u64);
        }
    }
}