use axum::{
    extract::{
        connect_info::ConnectInfo,
        ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade},
        Query, State,
    },
    http::{header::RETRY_AFTER, HeaderMap, StatusCode},
    response::IntoResponse,
};
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use serde_json::json;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::mpsc;
use tokio::time::{timeout, Duration, Instant};
use tracing::Instrument;
use wsprism_core::error::{Result, WsPrismError};

use crate::app_state::AppState;
use crate::obs::metrics::GatewayMetrics;
use crate::policy::engine::{ConnRateLimiter, HotErrorMode, OnExceed, PolicyDecision};
use crate::realtime::core::Connection;
use crate::realtime::{RealtimeCore, RealtimeCtx};
use crate::transport::codec::{decode, Inbound};
use crate::transport::handshake::retry_after_header_secs;

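// Session and trace ids come from process-local monotonic counters: cheap and
// collision-free within a single gateway process. They are not unique across
// processes; a multi-node deployment would need a node prefix or similar.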
static NEXT_SID: AtomicU64 = AtomicU64::new(1);
static NEXT_TRACE: AtomicU64 = AtomicU64::new(1);

fn gen_sid() -> String {
    format!("{:x}", NEXT_SID.fetch_add(1, Ordering::Relaxed))
}

fn gen_trace() -> String {
    let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
    let seq = NEXT_TRACE.fetch_add(1, Ordering::Relaxed);
    format!("{:x}-{:x}", now, seq)
}

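/// Query parameters on the WebSocket upgrade request. A hypothetical example,
/// assuming the handler is routed at `/ws`:
/// `GET /ws?tenant=acme&ticket=abc123&sid=7f`. A client may supply its own
/// `sid`; when omitted, the gateway generates one.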
#[derive(Debug, Deserialize)]
pub struct WsQuery {
    pub tenant: String,
    pub ticket: String,
    #[serde(default)]
    pub sid: Option<String>,
}

struct SessionState {
    active_room: Option<String>,
    last_activity: Instant,
    conn_limiter: Option<ConnRateLimiter>,
}

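// Builders for the `sys` service envelope. Every frame carries the protocol
// version `v`, the service tag `svc`, a message `type`, and the request
// `trace_id`; payload fields ride under `data` (except `joined`, which puts
// `room` at the top level).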
fn sys_authed_json(tenant: &str, user: &str, sid: &str, trace_id: &str) -> String {
    json!({ "v": 1, "svc": "sys", "type": "authed", "data": { "tenant": tenant, "user": user, "sid": sid }, "trace_id": trace_id }).to_string()
}

fn sys_error_json(code: &str, msg: &str, trace_id: &str) -> String {
    json!({ "v": 1, "svc": "sys", "type": "error", "data": { "code": code, "msg": msg }, "trace_id": trace_id }).to_string()
}

fn sys_kicked_json(reason: &str, trace_id: &str) -> String {
    json!({ "v": 1, "svc": "sys", "type": "kicked", "data": { "reason": reason }, "trace_id": trace_id }).to_string()
}

fn sys_joined_json(room: &str, trace_id: &str) -> String {
    json!({ "v": 1, "svc": "sys", "type": "joined", "room": room, "trace_id": trace_id }).to_string()
}

fn sys_left_json(trace_id: &str) -> String {
    json!({ "v": 1, "svc": "sys", "type": "left", "trace_id": trace_id }).to_string()
}

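/// RAII guard tied to one session: however the task ends (clean close, error,
/// or panic), `Drop` removes the session entry, clears presence, and
/// decrements the active-sessions gauge exactly once.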
struct SessionCleanup {
    core: Arc<RealtimeCore>,
    tenant_id: String,
    user_key: String,
    session_key: String,
    metrics: Arc<GatewayMetrics>,
}

impl Drop for SessionCleanup {
    fn drop(&mut self) {
        let _ = self.core.sessions.remove_session(&self.user_key, &self.session_key);
        self.core.presence.cleanup_session(&self.tenant_id, &self.user_key, &self.session_key);
        self.metrics.ws_active_sessions.dec(&[("tenant", &self.tenant_id)]);
        tracing::debug!(s = %self.session_key, "session raii cleanup done");
    }
}

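/// HTTP handler for upgrade requests: enforces the per-IP handshake rate
/// limit, refuses new sessions while draining, validates the tenant and its
/// session capacity, then hands the socket to `run_session`.
///
/// A minimal wiring sketch (route path, state constructor, and listener setup
/// are assumptions, not part of this file):
///
/// ```ignore
/// use axum::{routing::get, Router};
///
/// let app = Router::new()
///     .route("/ws", get(ws_upgrade))
///     .with_state(app_state);
/// // ConnectInfo<SocketAddr> extraction requires serving with
/// // app.into_make_service_with_connect_info::<std::net::SocketAddr>().
/// ```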
pub async fn ws_upgrade(
    State(app): State<AppState>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    ws: WebSocketUpgrade,
    Query(q): Query<WsQuery>,
) -> impl IntoResponse {
    // Per-IP handshake rate limit; tell the client when to retry.
    if let Err(wait_secs) = app.handshake().check(addr.ip()).await {
        app.metrics().handshake_rejections.inc(&[("tenant", &q.tenant), ("reason", "rate_limit")]);
        let (val, _) = retry_after_header_secs(wait_secs);
        let mut headers = HeaderMap::new();
        headers.insert(RETRY_AFTER, val.parse().unwrap());
        return (StatusCode::TOO_MANY_REQUESTS, headers, "Too Many Requests").into_response();
    }
    if app.is_draining() {
        return (StatusCode::SERVICE_UNAVAILABLE, "draining").into_response();
    }
    let Some(t_cfg) = app.cfg().tenants.iter().find(|t| t.id == q.tenant) else {
        return (StatusCode::BAD_REQUEST, "Unknown Tenant").into_response();
    };
    // Reject early when the tenant is at its session cap (0 means unlimited).
    let limit = t_cfg.limits.max_sessions_total;
    if limit > 0 {
        let current = app.realtime().sessions.count_tenant_sessions(&q.tenant);
        if current >= limit {
            app.metrics().handshake_rejections.inc(&[("tenant", &q.tenant), ("reason", "tenant_capacity")]);
            let mut headers = HeaderMap::new();
            headers.insert(RETRY_AFTER, "1".parse().unwrap());
            return (StatusCode::SERVICE_UNAVAILABLE, headers, "Tenant Capacity Exceeded").into_response();
        }
    }

    app.metrics().ws_upgrades.inc(&[("tenant", &q.tenant), ("status", "ok")]);
    ws.on_upgrade(move |socket| async move {
        if let Err(e) = run_session(app, q, socket).await {
            tracing::error!("session error: {}", e);
        }
    })
}

async fn run_session(app: AppState, q: WsQuery, socket: WebSocket) -> Result<()> {
    let policy = app
        .tenant_policy(&q.tenant)
        .ok_or_else(|| WsPrismError::BadRequest("unknown tenant".into()))?;
    let user_id = app.resolve_ticket(&q.ticket)?;
    let sid = q.sid.clone().unwrap_or_else(gen_sid);
    let trace_id = gen_trace();
    let span = tracing::info_span!("ws", %trace_id, t = %q.tenant, u = %user_id, s = %sid);

    // A `span.enter()` guard must not be held across `.await` points, so the
    // whole session future is instrumented with the span instead.
    async move {
        let core = app.realtime();
        let dispatcher = app.dispatcher();
        let metrics = app.metrics();
        let user_key = format!("{}::{}", q.tenant, user_id);
        let session_key = format!("{}::{}::{}", q.tenant, user_id, sid);
        let (out_tx, mut out_rx) = mpsc::channel(1024);
        let (mut ws_tx, mut ws_rx) = socket.split();

        // Enforce the per-user session cap before registering this session.
        let sp = policy.session_policy();
        let max_user_sessions = sp.max_sessions_per_user as usize;
        let current_user_sessions = core.sessions.count_user_sessions(&user_key);
        if current_user_sessions >= max_user_sessions {
            metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "session"), ("decision", "reject"), ("reason", "max_user_sessions")]);
            match sp.on_exceed {
                OnExceed::Deny => {
                    // Write directly to the socket: the outbound pump below never
                    // starts on this path, so a message queued on out_tx would be
                    // dropped unsent.
                    let _ = ws_tx.send(Message::Text(sys_error_json("TOO_MANY_SESSIONS", "limit exceeded", &trace_id))).await;
                    return Ok(());
                }
                OnExceed::KickOldest => {
                    if let Some((victim, victim_conn)) = core.sessions.evict_oldest(&user_key) {
                        let _ = victim_conn.tx.try_send(Message::Text(sys_kicked_json("max_sessions_exceeded", &trace_id)));
                        let _ = victim_conn.tx.try_send(Message::Close(Some(CloseFrame { code: 1008, reason: "kicked".into() })));
                        core.presence.cleanup_session(&q.tenant, &user_key, &victim);
                        // No gauge decrement here: the victim's own SessionCleanup
                        // decrements it when its task exits, and doing it twice
                        // would skew the metric.
                    }
                }
            }
        }

        // Tenant config was validated during the upgrade; re-resolve it here
        // without panicking.
        let t_cfg = app
            .cfg()
            .tenants
            .iter()
            .find(|t| t.id == q.tenant)
            .ok_or_else(|| WsPrismError::BadRequest("unknown tenant".into()))?;
        core.sessions.try_insert(
            q.tenant.clone(),
            user_key.clone(),
            session_key.clone(),
            Connection { tx: out_tx.clone() },
            t_cfg.limits.max_sessions_total,
        )?;
        metrics.ws_active_sessions.inc(&[("tenant", &q.tenant)]);
        // From here on, deregistration and the gauge decrement belong to the
        // RAII guard, whichever way this function exits.
        let _cleanup = SessionCleanup {
            core: core.clone(),
            tenant_id: q.tenant.clone(),
            user_key: user_key.clone(),
            session_key: session_key.clone(),
            metrics: metrics.clone(),
        };
        out_tx
            .send(Message::Text(sys_authed_json(&q.tenant, &user_id, &sid, &trace_id)))
            .await
            .map_err(|_| WsPrismError::Internal("closed".into()))?;

        let gw = &app.cfg().gateway;
        let mut ping_tick = tokio::time::interval(Duration::from_millis(gw.ping_interval_ms));
        let mut idle_tick = tokio::time::interval(Duration::from_millis(1000));
        let idle_timeout = Duration::from_millis(gw.idle_timeout_ms);
        let writer_timeout = Duration::from_millis(gw.writer_send_timeout_ms);
        let mut sess = SessionState {
            active_room: None,
            last_activity: Instant::now(),
            conn_limiter: policy.new_connection_limiter(),
        };

        // Counts hot-lane messages so dispatch latency can be sampled instead
        // of timed on every message.
        let mut hot_op_counter: u64 = 0;

        loop {
            tokio::select! {
                // Outbound pump: this task is the sole writer on ws_tx.
                maybe_out = out_rx.recv() => {
                    match maybe_out {
                        Some(m) => {
                            match timeout(writer_timeout, ws_tx.send(m)).await {
                                Ok(Ok(())) => {}
                                // The socket failed outright; end the session.
                                Ok(Err(_)) => break,
                                // Slow consumer: the write did not complete in time.
                                Err(_) => {
                                    metrics.writer_timeouts.inc(&[("tenant", &q.tenant)]);
                                    break;
                                }
                            }
                        }
                        None => break,
                    }
                }
                incoming = ws_rx.next() => {
                    let Some(Ok(msg)) = incoming else { break; };
                    sess.last_activity = Instant::now();
                    let decoded = match decode(msg) {
                        Ok(d) => d,
                        Err(e) => {
                            metrics.decode_errors.inc(&[("tenant", &q.tenant)]);
                            let _ = out_tx.send(Message::Text(sys_error_json(e.client_code().as_str(), &e.to_string(), &trace_id))).await;
                            break;
                        }
                    };
                    match decoded {
                        Inbound::Ping(p) => { let _ = out_tx.send(Message::Pong(p)).await; }
                        Inbound::Pong(_) => {}
                        Inbound::Close => break,
                        Inbound::Text { env, bytes_len } => {
                            // Per-connection rate limit first, then tenant policy.
                            if let Some(lim) = sess.conn_limiter.as_mut() {
                                if !lim.allow() {
                                    metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "ext"), ("decision", "drop"), ("reason", "conn_rate_limit")]);
                                    continue;
                                }
                            }
                            match policy.check_text(bytes_len, &env.svc, &env.msg_type) {
                                PolicyDecision::Pass => {}
                                PolicyDecision::Drop => {
                                    metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "ext"), ("decision", "drop"), ("reason", "policy")]);
                                    continue;
                                }
                                PolicyDecision::Reject { code, msg } => {
                                    metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "ext"), ("decision", "reject"), ("reason", code.as_str())]);
                                    let _ = out_tx.send(Message::Text(sys_error_json(code.as_str(), msg, &trace_id))).await;
                                    continue;
                                }
                                PolicyDecision::Close { code, msg } => {
                                    metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "ext"), ("decision", "close"), ("reason", code.as_str())]);
                                    let _ = out_tx.send(Message::Text(sys_error_json(code.as_str(), msg, &trace_id))).await;
                                    break;
                                }
                            }
                            // Room join/leave is handled inline so the session can
                            // track its active room; everything else goes through
                            // the service dispatcher.
                            if env.svc == "room" && env.msg_type == "join" {
                                let room = env.room.clone().unwrap_or_else(|| "default".to_string());
                                let ctx = RealtimeCtx::new(q.tenant.clone(), user_id.clone(), sid.clone(), trace_id.clone(), sess.active_room.clone(), core.clone());
                                match ctx.join_room_with_limits(&room, &t_cfg.limits) {
                                    Ok(_) => {
                                        sess.active_room = Some(room.clone());
                                        let _ = out_tx.send(Message::Text(sys_joined_json(&room, &trace_id))).await;
                                    }
                                    Err(e) => {
                                        metrics.service_errors.inc(&[("tenant", &q.tenant), ("svc", "room"), ("type", "join_failed")]);
                                        let _ = out_tx.send(Message::Text(sys_error_json(e.client_code().as_str(), &e.to_string(), &trace_id))).await;
                                    }
                                }
                                continue;
                            }
                            if env.svc == "room" && env.msg_type == "leave" {
                                if let Some(room) = sess.active_room.take() {
                                    let ctx = RealtimeCtx::new(q.tenant.clone(), user_id.clone(), sid.clone(), trace_id.clone(), None, core.clone());
                                    ctx.leave_room(&room);
                                }
                                let _ = out_tx.send(Message::Text(sys_left_json(&trace_id))).await;
                                continue;
                            }
                            let ctx = RealtimeCtx::new(q.tenant.clone(), user_id.clone(), sid.clone(), trace_id.clone(), sess.active_room.clone(), core.clone());
                            let start = Instant::now();
                            let res = dispatcher.dispatch_text(ctx, env).await;
                            metrics.dispatch_duration.observe(&[("tenant", &q.tenant), ("lane", "ext")], start.elapsed());
                            if let Err(e) = res {
                                metrics.service_errors.inc(&[("tenant", &q.tenant), ("lane", "ext")]);
                                let _ = out_tx.send(Message::Text(sys_error_json(e.client_code().as_str(), &e.to_string(), &trace_id))).await;
                            }
                        }
                        Inbound::Hot { frame, bytes_len } => {
                            match policy.check_hot(bytes_len, frame.svc_id, frame.opcode) {
                                PolicyDecision::Pass => {}
                                PolicyDecision::Drop => {
                                    metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "hot"), ("decision", "drop"), ("reason", "policy")]);
                                    continue;
                                }
                                PolicyDecision::Reject { code, msg } => {
                                    metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "hot"), ("decision", "reject"), ("reason", code.as_str())]);
                                    // Hot-lane errors are surfaced only when the tenant
                                    // policy selects `SysError` mode.
                                    if let HotErrorMode::SysError = policy.hot_error_mode() {
                                        let _ = out_tx.send(Message::Text(sys_error_json(code.as_str(), msg, &trace_id))).await;
                                    }
                                    continue;
                                }
                                PolicyDecision::Close { code, msg } => {
                                    metrics.policy_decisions.inc(&[("tenant", &q.tenant), ("lane", "hot"), ("decision", "close"), ("reason", code.as_str())]);
                                    if let HotErrorMode::SysError = policy.hot_error_mode() {
                                        let _ = out_tx.send(Message::Text(sys_error_json(code.as_str(), msg, &trace_id))).await;
                                    }
                                    break;
                                }
                            }
                            if policy.hot_requires_active_room() && sess.active_room.is_none() {
                                if let HotErrorMode::SysError = policy.hot_error_mode() {
                                    let _ = out_tx.send(Message::Text(sys_error_json("BAD_REQUEST", "no active room", &trace_id))).await;
                                }
                                continue;
                            }
                            let ctx = RealtimeCtx::new(q.tenant.clone(), user_id.clone(), sid.clone(), trace_id.clone(), sess.active_room.clone(), core.clone());

                            // Time only every 1024th hot message so the fast path
                            // avoids two clock reads per message.
                            hot_op_counter = hot_op_counter.wrapping_add(1);
                            let should_sample = (hot_op_counter & 1023) == 0;
                            let start = if should_sample { Some(Instant::now()) } else { None };

                            let res = dispatcher.dispatch_hot(ctx, frame).await;
                            if let Some(s) = start {
                                metrics.dispatch_duration.observe(&[("tenant", &q.tenant), ("lane", "hot")], s.elapsed());
                            }
                            if let Err(e) = res {
                                metrics.service_errors.inc(&[("tenant", &q.tenant), ("lane", "hot")]);
                                if let HotErrorMode::SysError = policy.hot_error_mode() {
                                    let _ = out_tx.send(Message::Text(sys_error_json(e.client_code().as_str(), &e.to_string(), &trace_id))).await;
                                }
                            }
                        }
                    }
                }
                // Keepalive and idle sweep run on their own ticks.
                _ = ping_tick.tick() => {
                    let _ = out_tx.send(Message::Ping(Vec::new())).await;
                }
                _ = idle_tick.tick() => {
                    if sess.last_activity.elapsed() >= idle_timeout {
                        let _ = out_tx.send(Message::Text(sys_error_json("TIMEOUT", "idle", &trace_id))).await;
                        break;
                    }
                }
            }
        }
        Ok(())
    }
    .instrument(span)
    .await
}
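
#[cfg(test)]
mod tests {
    // A minimal test sketch for the pure helpers above; assertions only cover
    // behavior visible in this file.
    use super::*;

    #[test]
    fn generated_ids_are_distinct() {
        // Both generators are backed by monotonic counters, so consecutive
        // ids must differ even within the same nanosecond.
        assert_ne!(gen_sid(), gen_sid());
        assert_ne!(gen_trace(), gen_trace());
    }

    #[test]
    fn sys_error_envelope_shape() {
        let v: serde_json::Value =
            serde_json::from_str(&sys_error_json("TIMEOUT", "idle", "t-1")).unwrap();
        assert_eq!(v["svc"], "sys");
        assert_eq!(v["type"], "error");
        assert_eq!(v["data"]["code"], "TIMEOUT");
        assert_eq!(v["trace_id"], "t-1");
    }

    #[test]
    fn hot_sampling_hits_every_1024th_op() {
        // Mirrors the mask used in the hot lane: counter & 1023 == 0.
        let sampled = (1..=4096u64).filter(|c| c & 1023 == 0).count();
        assert_eq!(sampled, 4);
    }
}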