wsprism_gateway/transport/
ws.rs

1//! WebSocket handler (transport + lifecycle).
2//!
3//! - Pre-upgrade defenses: IP handshake limiter (429) and tenant capacity (503)
4//! - Trace ID generation/propagation into spans and sys.* messages
5//! - Session/room governance and policy enforcement
6//! - Labeled metrics for policy decisions/errors + sampled Hot Lane latency
7
8use axum::{
9    extract::{connect_info::ConnectInfo, ws::CloseFrame, ws::Message, ws::WebSocket, ws::WebSocketUpgrade, Query, State},
10    http::{HeaderMap, StatusCode, header::RETRY_AFTER},
11    response::{IntoResponse},
12};
13use futures_util::{SinkExt, StreamExt};
14use serde::Deserialize;
15use serde_json::json;
16use tokio::sync::mpsc;
17use tokio::time::{timeout, Duration, Instant};
18use std::net::SocketAddr;
19use std::sync::atomic::{AtomicU64, Ordering};
20use std::sync::Arc;
21use std::time::{SystemTime, UNIX_EPOCH};
22use wsprism_core::error::{Result, WsPrismError};
23use crate::app_state::AppState;
24use crate::policy::engine::{ConnRateLimiter, HotErrorMode, OnExceed, PolicyDecision};
25use crate::realtime::core::Connection;
26use crate::realtime::RealtimeCore;
27use crate::realtime::RealtimeCtx;
28use crate::transport::codec::{decode, Inbound};
29use crate::transport::handshake::retry_after_header_secs;
30use crate::obs::metrics::GatewayMetrics;
31
/// Process-wide sequence backing [`gen_sid`]; the first id handed out is `1`.
static NEXT_SID: AtomicU64 = AtomicU64::new(1);
/// Process-wide sequence backing the suffix of [`gen_trace`]; starts at `1`.
static NEXT_TRACE: AtomicU64 = AtomicU64::new(1);

/// Returns the next session id: the `NEXT_SID` counter rendered as lowercase hex.
fn gen_sid() -> String {
    let seq = NEXT_SID.fetch_add(1, Ordering::Relaxed);
    format!("{seq:x}")
}

/// Returns a trace id of the form `<unix-nanos-hex>-<sequence-hex>`.
///
/// If the system clock reports a time before the Unix epoch, the timestamp
/// part falls back to `0`; the sequence part still advances.
fn gen_trace() -> String {
    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    };
    let ordinal = NEXT_TRACE.fetch_add(1, Ordering::Relaxed);
    format!("{nanos:x}-{ordinal:x}")
}
41
/// WebSocket upgrade query parameters.
///
/// Extracted from the upgrade request's query string by `ws_upgrade`.
#[derive(Debug, Deserialize)]
pub struct WsQuery {
    /// Tenant id; `ws_upgrade` matches it against the `id` of each configured tenant.
    pub tenant: String,
    /// Ticket handed to `AppState::resolve_ticket` to obtain the user id.
    pub ticket: String,
    /// Optional client-provided session id (tab/browser id). Generated if absent.
    #[serde(default)]
    pub sid: Option<String>,
}
51
/// Per-connection mutable state used inside the WS loop.
struct SessionState {
    /// Room set by a successful `room`/`join` and cleared by `room`/`leave`.
    active_room: Option<String>,
    /// When the last inbound frame was read; compared against the idle timeout.
    last_activity: Instant,
    /// Limiter obtained from the tenant policy and applied to inbound text
    /// frames; when `None`, no per-connection rate check is performed.
    conn_limiter: Option<ConnRateLimiter>,
}
58
59fn sys_authed_json(tenant: &str, user: &str, sid: &str, trace_id: &str) -> String {
60    json!({ "v": 1, "svc": "sys", "type": "authed", "data": { "tenant": tenant, "user": user, "sid": sid }, "trace_id": trace_id }).to_string()
61}
62fn sys_error_json(code: &str, msg: &str, trace_id: &str) -> String {
63    json!({ "v": 1, "svc": "sys", "type": "error", "data": { "code": code, "msg": msg }, "trace_id": trace_id }).to_string()
64}
65fn sys_kicked_json(reason: &str, trace_id: &str) -> String {
66    json!({ "v": 1, "svc": "sys", "type": "kicked", "data": { "reason": reason }, "trace_id": trace_id }).to_string()
67}
68fn sys_joined_json(room: &str, trace_id: &str) -> String {
69    json!({ "v": 1, "svc": "sys", "type": "joined", "room": room, "trace_id": trace_id }).to_string()
70}
71fn sys_left_json(trace_id: &str) -> String {
72    json!({ "v": 1, "svc": "sys", "type": "left", "trace_id": trace_id }).to_string()
73}
74
/// RAII guard that tears down session and presence entries on exit.
///
/// Created by `run_session` right after the session is registered. Its `Drop`
/// impl removes the `session_key` entry for `user_key` from `core.sessions`,
/// cleans the session's presence for `tenant_id`, and decrements the
/// tenant-labelled `ws_active_sessions` metric.
struct SessionCleanup {
    core: Arc<RealtimeCore>, tenant_id: String, user_key: String, session_key: String, metrics: Arc<GatewayMetrics>,
}
79impl Drop for SessionCleanup {
80    fn drop(&mut self) {
81        let _ = self.core.sessions.remove_session(&self.user_key, &self.session_key);
82        self.core.presence.cleanup_session(&self.tenant_id, &self.user_key, &self.session_key);
83        self.metrics.ws_active_sessions.dec(&[("tenant", &self.tenant_id)]);
84        tracing::debug!(s=%self.session_key, "session raii cleanup done");
85    }
86}
87
88pub async fn ws_upgrade(
89    State(app): State<AppState>, ConnectInfo(addr): ConnectInfo<SocketAddr>, ws: WebSocketUpgrade, Query(q): Query<WsQuery>,
90) -> impl IntoResponse {
91    if let Err(wait_secs) = app.handshake().check(addr.ip()).await {
92        app.metrics().handshake_rejections.inc(&[("tenant", &q.tenant), ("reason", "rate_limit")]);
93        let (val, _) = retry_after_header_secs(wait_secs);
94        let mut headers = HeaderMap::new();
95        headers.insert(RETRY_AFTER, val.parse().unwrap());
96        return (StatusCode::TOO_MANY_REQUESTS, headers, "Too Many Requests").into_response();
97    }
98    if app.is_draining() { return (StatusCode::SERVICE_UNAVAILABLE, "draining").into_response(); }
99    if let Some(t_cfg) = app.cfg().tenants.iter().find(|t| t.id == q.tenant) {
100        let limit = t_cfg.limits.max_sessions_total;
101        if limit > 0 {
102            let current = app.realtime().sessions.count_tenant_sessions(&q.tenant);
103            if current >= limit {
104                app.metrics().handshake_rejections.inc(&[("tenant", &q.tenant), ("reason", "tenant_capacity")]);
105                let mut headers = HeaderMap::new();
106                headers.insert(RETRY_AFTER, "1".parse().unwrap());
107                return (StatusCode::SERVICE_UNAVAILABLE, headers, "Tenant Capacity Exceeded").into_response();
108            }
109        }
110    } else { return (StatusCode::BAD_REQUEST, "Unknown Tenant").into_response(); }
111
112    app.metrics().ws_upgrades.inc(&[("tenant", &q.tenant), ("status", "ok")]);
113    ws.on_upgrade(move |socket| async move {
114        if let Err(e) = run_session(app, q, socket).await { tracing::error!("session error: {}", e); }
115    })
116}
117
118async fn run_session(app: AppState, q: WsQuery, socket: WebSocket) -> Result<()> {
119    let policy = app.tenant_policy(&q.tenant).ok_or(WsPrismError::BadRequest("unknown tenant".into()))?;
120    let user_id = app.resolve_ticket(&q.ticket)?;
121    let sid = q.sid.unwrap_or_else(gen_sid);
122    let trace_id = gen_trace();
123    let core = app.realtime();
124    let dispatcher = app.dispatcher();
125    let metrics = app.metrics();
126    let user_key = format!("{}::{}", q.tenant, user_id);
127    let session_key = format!("{}::{}::{}", q.tenant, user_id, sid);
128    let span = tracing::info_span!("ws", %trace_id, t=%q.tenant, u=%user_id, s=%sid);
129    let _enter = span.enter();
130    let (out_tx, mut out_rx) = mpsc::channel(1024);
131    let (mut ws_tx, mut ws_rx) = socket.split();
132
133    let sp = policy.session_policy();
134    let max_user_sessions = sp.max_sessions_per_user as usize;
135    let current_user_sessions = core.sessions.count_user_sessions(&user_key);
136    if current_user_sessions >= max_user_sessions {
137         metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "session"), ("decision", "reject"), ("reason", "max_user_sessions")]);
138         match sp.on_exceed {
139             OnExceed::Deny => {
140                 let _ = out_tx.send(Message::Text(sys_error_json("TOO_MANY_SESSIONS", "limit exceeded", &trace_id))).await;
141                 return Ok(());
142             }
143             OnExceed::KickOldest => {
144                 if let Some((victim, victim_conn)) = core.sessions.evict_oldest(&user_key) {
145                     let _ = victim_conn.tx.try_send(Message::Text(sys_kicked_json("max_sessions_exceeded", &trace_id)));
146                     let _ = victim_conn.tx.try_send(Message::Close(Some(CloseFrame { code: 1008, reason: "kicked".into() })));
147                     core.presence.cleanup_session(&q.tenant, &user_key, &victim);
148                     metrics.ws_active_sessions.dec(&[("tenant", &q.tenant)]);
149                 }
150             }
151         }
152    }
153
154    let t_cfg = app.cfg().tenants.iter().find(|t| t.id == q.tenant).unwrap();
155    core.sessions.try_insert(q.tenant.clone(), user_key.clone(), session_key.clone(), Connection{ tx: out_tx.clone() }, t_cfg.limits.max_sessions_total)?;
156    metrics.ws_active_sessions.inc(&[("tenant", &q.tenant)]);
157    let _cleanup = SessionCleanup { core: core.clone(), tenant_id: q.tenant.clone(), user_key: user_key.clone(), session_key: session_key.clone(), metrics: metrics.clone() };
158    out_tx.send(Message::Text(sys_authed_json(&q.tenant, &user_id, &sid, &trace_id))).await.map_err(|_| WsPrismError::Internal("closed".into()))?;
159
160    let gw = &app.cfg().gateway;
161    let mut ping_tick = tokio::time::interval(Duration::from_millis(gw.ping_interval_ms));
162    let mut idle_tick = tokio::time::interval(Duration::from_millis(1000));
163    let idle_timeout = Duration::from_millis(gw.idle_timeout_ms);
164    let writer_timeout = Duration::from_millis(gw.writer_send_timeout_ms);
165    let mut sess = SessionState { active_room: None, last_activity: Instant::now(), conn_limiter: policy.new_connection_limiter() };
166    
167    // Sampling Counter
168    let mut hot_op_counter: u64 = 0;
169
170    loop {
171        tokio::select! {
172            maybe_out = out_rx.recv() => {
173                match maybe_out {
174                    Some(m) => {
175                        if timeout(writer_timeout, ws_tx.send(m)).await.is_err() {
176                             metrics.writer_timeouts.inc(&[("tenant", &q.tenant)]);
177                             break;
178                        }
179                    }
180                    None => break,
181                }
182            }
183            incoming = ws_rx.next() => {
184                let Some(Ok(msg)) = incoming else { break; };
185                sess.last_activity = Instant::now();
186                let decoded = match decode(msg) {
187                    Ok(d) => d,
188                    Err(e) => {
189                        metrics.decode_errors.inc(&[("tenant", &q.tenant)]);
190                        let _ = out_tx.send(Message::Text(sys_error_json(e.client_code().as_str(), &e.to_string(), &trace_id))).await;
191                        break;
192                    }
193                };
194                match decoded {
195                    Inbound::Ping(p) => { let _ = out_tx.send(Message::Pong(p)).await; },
196                    Inbound::Pong(_) => {},
197                    Inbound::Close => break,
198                    Inbound::Text { env, bytes_len } => {
199                        if let Some(lim) = sess.conn_limiter.as_mut() {
200                            if !lim.allow() { 
201                                metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "ext"), ("decision", "drop"), ("reason", "conn_rate_limit")]);
202                                continue; 
203                            }
204                        }
205                        match policy.check_text(bytes_len, &env.svc, &env.msg_type) {
206                            PolicyDecision::Pass => {},
207                            PolicyDecision::Drop => { 
208                                metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "ext"), ("decision", "drop"), ("reason", "policy")]);
209                                continue; 
210                            },
211                            PolicyDecision::Reject { code, msg } => {
212                                // SAFE LABEL: code.as_str()
213                                metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "ext"), ("decision", "reject"), ("reason", code.as_str())]);
214                                let _ = out_tx.send(Message::Text(sys_error_json(code.as_str(), msg, &trace_id))).await;
215                                continue;
216                            },
217                            PolicyDecision::Close { code, msg } => {
218                                // SAFE LABEL: code.as_str()
219                                metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "ext"), ("decision", "close"), ("reason", code.as_str())]);
220                                let _ = out_tx.send(Message::Text(sys_error_json(code.as_str(), msg, &trace_id))).await;
221                                break;
222                            }
223                        }
224                        if env.svc == "room" && env.msg_type == "join" {
225                            let room = env.room.clone().unwrap_or_else(|| "default".to_string());
226                            let ctx = RealtimeCtx::new(q.tenant.clone(), user_id.clone(), sid.clone(), trace_id.clone(), sess.active_room.clone(), core.clone());
227                            match ctx.join_room_with_limits(&room, &t_cfg.limits) {
228                                Ok(_) => {
229                                    sess.active_room = Some(room.clone());
230                                    let _ = out_tx.send(Message::Text(sys_joined_json(&room, &trace_id))).await;
231                                },
232                                Err(e) => {
233                                    metrics.service_errors.inc(&[("tenant", &q.tenant), ("svc", "room"), ("type", "join_failed")]);
234                                    let _ = out_tx.send(Message::Text(sys_error_json(e.client_code().as_str(), &e.to_string(), &trace_id))).await;
235                                }
236                            }
237                            continue;
238                        }
239                        if env.svc == "room" && env.msg_type == "leave" {
240                            if let Some(room) = sess.active_room.take() {
241                                let ctx = RealtimeCtx::new(q.tenant.clone(), user_id.clone(), sid.clone(), trace_id.clone(), None, core.clone());
242                                ctx.leave_room(&room);
243                            }
244                            let _ = out_tx.send(Message::Text(sys_left_json(&trace_id))).await;
245                            continue;
246                        }
247                        let ctx = RealtimeCtx::new(q.tenant.clone(), user_id.clone(), sid.clone(), trace_id.clone(), sess.active_room.clone(), core.clone());
248                        let start = Instant::now();
249                        let res = dispatcher.dispatch_text(ctx, env).await;
250                        // Always measure Ext lane
251                        metrics.dispatch_duration.observe(&[("tenant", &q.tenant), ("lane", "ext")], start.elapsed());
252                        if let Err(e) = res {
253                             metrics.service_errors.inc(&[("tenant", &q.tenant), ("lane", "ext")]);
254                             let _ = out_tx.send(Message::Text(sys_error_json(e.client_code().as_str(), &e.to_string(), &trace_id))).await;
255                        }
256                    },
257                    Inbound::Hot { frame, bytes_len } => {
258                         match policy.check_hot(bytes_len, frame.svc_id, frame.opcode) {
259                            PolicyDecision::Pass => {},
260                            PolicyDecision::Drop => { 
261                                metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "hot"), ("decision", "drop"), ("reason", "policy")]);
262                                continue; 
263                            },
264                            PolicyDecision::Reject { code, msg } => {
265                                metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "hot"), ("decision", "reject"), ("reason", code.as_str())]);
266                                if let HotErrorMode::SysError = policy.hot_error_mode() {
267                                    let _ = out_tx.send(Message::Text(sys_error_json(code.as_str(), msg, &trace_id))).await;
268                                }
269                                continue;
270                            },
271                            PolicyDecision::Close { code, msg } => {
272                                metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "hot"), ("decision", "close"), ("reason", code.as_str())]);
273                                if let HotErrorMode::SysError = policy.hot_error_mode() {
274                                    let _ = out_tx.send(Message::Text(sys_error_json(code.as_str(), msg, &trace_id))).await;
275                                }
276                                break;
277                            }
278                         }
279                         if policy.hot_requires_active_room() && sess.active_room.is_none() {
280                             if let HotErrorMode::SysError = policy.hot_error_mode() {
281                                 let _ = out_tx.send(Message::Text(sys_error_json("BAD_REQUEST", "no active room", &trace_id))).await;
282                             }
283                             continue;
284                         }
285                         let ctx = RealtimeCtx::new(q.tenant.clone(), user_id.clone(), sid.clone(), trace_id.clone(), sess.active_room.clone(), core.clone());
286                         
287                         // Hot Lane Sampling (1/1024)
288                         hot_op_counter = hot_op_counter.wrapping_add(1);
289                         let should_sample = (hot_op_counter & 1023) == 0;
290                         let start = if should_sample { Some(Instant::now()) } else { None };
291                         
292                         let res = dispatcher.dispatch_hot(ctx, frame).await;
293                         if let Some(s) = start {
294                             metrics.dispatch_duration.observe(&[("tenant", &q.tenant), ("lane", "hot")], s.elapsed());
295                         }
296                         if let Err(e) = res {
297                             metrics.service_errors.inc(&[("tenant", &q.tenant), ("lane", "hot")]);
298                             if let HotErrorMode::SysError = policy.hot_error_mode() {
299                                 let _ = out_tx.send(Message::Text(sys_error_json(e.client_code().as_str(), &e.to_string(), &trace_id))).await;
300                             }
301                         }
302                    }
303                }
304            }
305            _ = ping_tick.tick() => { let _ = out_tx.send(Message::Ping(Vec::new())).await; }
306            _ = idle_tick.tick() => {
307                if sess.last_activity.elapsed() >= idle_timeout {
308                    let _ = out_tx.send(Message::Text(sys_error_json("TIMEOUT", "idle", &trace_id))).await;
309                    break;
310                }
311            }
312        }
313    }
314    Ok(())
315}