wsprism_gateway/policy/
engine.rs1use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant};
8
9use wsprism_core::error::ClientCode;
10
11pub use crate::config::schema::{HotErrorMode, OnExceed, SessionMode};
12use crate::config::schema::{RateLimitScope, SessionPolicy, TenantPolicy};
13
14use super::allowlist::{
15 compile_ext_rules, compile_hot_rules, is_ext_allowed, is_hot_allowed, ExtRule, HotRule,
16};
17
18#[derive(Debug, Clone)]
20pub enum PolicyDecision {
21 Pass,
22 Drop,
23 Reject { code: ClientCode, msg: &'static str },
24 Close { code: ClientCode, msg: &'static str },
25}
26
27pub struct TenantPolicyRuntime {
30 pub tenant_id: String,
31
32 max_frame_bytes: usize,
33 ext_rules: Vec<ExtRule>,
34 hot_rules: Vec<HotRule>,
35
36 rate_limit_scope: RateLimitScope,
38 conn_rps: u32,
39 conn_burst: u32,
40 tenant_limiter: Option<RateLimiter>,
41
42 sessions: SessionPolicy,
44
45 hot_error_mode: HotErrorMode,
47 hot_requires_active_room: bool,
48}
49
50impl TenantPolicyRuntime {
51 pub fn new(
52 tenant_id: String,
53 max_frame_bytes: usize,
54 policy: &TenantPolicy,
55 ) -> wsprism_core::Result<Self> {
56 let ext_rules = compile_ext_rules(&policy.ext_allowlist)?;
57 let hot_rules = compile_hot_rules(&policy.hot_allowlist)?;
58
59 let tenant_limiter = match policy.rate_limit_scope {
60 RateLimitScope::Tenant | RateLimitScope::Both => {
61 Some(RateLimiter::new(policy.rate_limit_rps, policy.rate_limit_burst))
62 }
63 RateLimitScope::Connection => None,
64 };
65
66 Ok(Self {
67 tenant_id,
68 max_frame_bytes,
69 ext_rules,
70 hot_rules,
71 rate_limit_scope: policy.rate_limit_scope,
72 conn_rps: policy.rate_limit_rps,
73 conn_burst: policy.rate_limit_burst,
74 tenant_limiter,
75 sessions: policy.sessions.clone(),
76 hot_error_mode: policy.hot_error_mode,
77 hot_requires_active_room: policy.hot_requires_active_room,
78 })
79 }
80
81 pub fn session_policy(&self) -> &SessionPolicy {
82 &self.sessions
83 }
84 pub fn hot_error_mode(&self) -> HotErrorMode {
85 self.hot_error_mode
86 }
87 pub fn hot_requires_active_room(&self) -> bool {
88 self.hot_requires_active_room
89 }
90
91 pub fn new_connection_limiter(&self) -> Option<ConnRateLimiter> {
93 match self.rate_limit_scope {
94 RateLimitScope::Connection | RateLimitScope::Both => {
95 Some(ConnRateLimiter::new(self.conn_rps, self.conn_burst))
96 }
97 RateLimitScope::Tenant => None,
98 }
99 }
100
101 pub fn check_len(&self, bytes_len: usize) -> PolicyDecision {
103 if bytes_len > self.max_frame_bytes {
104 return PolicyDecision::Close {
105 code: ClientCode::BadRequest,
106 msg: "frame too large",
107 };
108 }
109 PolicyDecision::Pass
110 }
111
112 pub fn check_text(&self, bytes_len: usize, svc: &str, msg_type: &str) -> PolicyDecision {
114 match self.check_len(bytes_len) {
115 PolicyDecision::Pass => {}
116 other => return other,
117 }
118
119 if let Some(lim) = &self.tenant_limiter {
120 if !lim.allow() {
121 return PolicyDecision::Drop;
122 }
123 }
124
125 if self.ext_rules.is_empty() {
126 return PolicyDecision::Reject {
127 code: ClientCode::BadRequest,
128 msg: "ext_allowlist empty (strict deny)",
129 };
130 }
131
132 if !is_ext_allowed(&self.ext_rules, svc, msg_type) {
133 return PolicyDecision::Reject {
134 code: ClientCode::BadRequest,
135 msg: "svc/type not allowed",
136 };
137 }
138
139 PolicyDecision::Pass
140 }
141
142 pub fn check_hot(&self, bytes_len: usize, svc_id: u8, opcode: u8) -> PolicyDecision {
144 match self.check_len(bytes_len) {
145 PolicyDecision::Pass => {}
146 other => return other,
147 }
148
149 if let Some(lim) = &self.tenant_limiter {
150 if !lim.allow() {
151 return PolicyDecision::Drop;
152 }
153 }
154
155 if self.hot_rules.is_empty() {
156 return PolicyDecision::Drop; }
158
159 if !is_hot_allowed(&self.hot_rules, svc_id, opcode) {
160 return PolicyDecision::Drop;
161 }
162
163 PolicyDecision::Pass
164 }
165}
166
167#[derive(Debug)]
169pub struct ConnRateLimiter {
170 bucket: TokenBucket,
171}
172
173impl ConnRateLimiter {
174 pub fn new(rps: u32, burst: u32) -> Self {
175 Self {
176 bucket: TokenBucket::new(rps, burst),
177 }
178 }
179
180 pub fn allow(&mut self) -> bool {
181 self.bucket.allow()
182 }
183}
184
185struct RateLimiter {
187 inner: Arc<Mutex<TokenBucket>>,
188}
189
190impl RateLimiter {
191 fn new(rps: u32, burst: u32) -> Self {
192 Self {
193 inner: Arc::new(Mutex::new(TokenBucket::new(rps, burst))),
194 }
195 }
196
197 fn allow(&self) -> bool {
198 if let Ok(mut g) = self.inner.lock() {
201 g.allow()
202 } else {
203 false
204 }
205 }
206}
207
/// Token bucket holding up to `capacity` tokens, refilled at `rps` tokens per
/// second; each admitted message spends one token.
#[derive(Debug)]
struct TokenBucket {
    /// Refill rate in tokens per second (clamped to at least 1).
    rps: u32,
    /// Maximum number of stored tokens (clamped to at least 1).
    capacity: u32,
    /// Tokens currently available; invariant: `tokens <= capacity`.
    tokens: u32,
    /// Instant up to which refill credit has already been granted.
    last: Instant,
}

impl TokenBucket {
    /// Gaps shorter than this skip the refill computation. No credit is lost;
    /// it is granted on a later call because `last` is left untouched.
    const MIN_REFILL_INTERVAL: Duration = Duration::from_millis(50);
    const NANOS_PER_SEC: u64 = 1_000_000_000;

    /// Creates a full bucket. Zero `rps`/`burst` are clamped to 1 so the bucket
    /// can always eventually admit a message.
    fn new(rps: u32, burst: u32) -> Self {
        let rps = rps.max(1);
        let capacity = burst.max(1);
        Self {
            rps,
            capacity,
            tokens: capacity,
            last: Instant::now(),
        }
    }

    /// Refills, then spends one token if available. Returns whether a token was spent.
    fn allow(&mut self) -> bool {
        self.refill();

        if self.tokens == 0 {
            return false;
        }
        self.tokens -= 1;
        true
    }

    /// Credits the whole tokens earned since `last`, capped at `capacity`.
    ///
    /// Only the time corresponding to the credited tokens is consumed from the
    /// clock, so fractional progress toward the next token carries over. The
    /// previous version reset `last = now`, which lowered the effective rate.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last);
        if elapsed < Self::MIN_REFILL_INTERVAL {
            return;
        }

        // u128 keeps `nanos * rps` exact: no overflow and no truncating cast,
        // even after a very long idle period with a very high rate.
        let earned =
            elapsed.as_nanos() * u128::from(self.rps) / u128::from(Self::NANOS_PER_SEC);
        if earned == 0 {
            return;
        }

        // Cannot underflow: tokens <= capacity is maintained by every mutation.
        let room = self.capacity - self.tokens;
        match u32::try_from(earned) {
            Ok(earned) if earned < room => {
                self.tokens += earned;
                // earned < 2^32, so earned * 1e9 < 2^62: fits in u64.
                let credited_nanos =
                    u64::from(earned) * Self::NANOS_PER_SEC / u64::from(self.rps);
                // credited_nanos <= elapsed, so `last` never moves past `now`.
                self.last += Duration::from_nanos(credited_nanos);
            }
            _ => {
                // Bucket is now full: surplus credit is discarded and the clock restarts.
                self.tokens = self.capacity;
                self.last = now;
            }
        }
    }
}
251}