wsprism_gateway/realtime/core/session_registry.rs

use std::sync::atomic::{AtomicU64, Ordering};

use axum::extract::ws::Message;
use dashmap::{DashMap, DashSet};
use tokio::sync::mpsc;

use wsprism_core::error::{Result, WsPrismError};

/// One session's outbound queue sender.
#[derive(Clone)]
pub struct Connection {
    pub tx: mpsc::Sender<Message>,
}

#[derive(Clone)]
struct SessionEntry {
    conn: Connection,
    created_seq: u64,
    // Sprint 5: store the tenant here to facilitate cleanup without looking
    // up other maps.
    tenant_id: String,
}
/// Session registry:
/// - `session_key -> Connection`
/// - `user_key -> {session_key...}`
/// - `tenant_id -> live session count` (atomic)
#[derive(Default)]
pub struct SessionRegistry {
    sessions: DashMap<String, SessionEntry>,
    user_index: DashMap<String, DashSet<String>>,
    // Sprint 5: O(1) tenant counter
    tenant_counts: DashMap<String, AtomicU64>,
    seq: AtomicU64,
}

impl SessionRegistry {
    pub fn new() -> Self {
        Self {
            sessions: DashMap::new(),
            user_index: DashMap::new(),
            tenant_counts: DashMap::new(),
            seq: AtomicU64::new(1),
        }
    }
    // Sprint 5: try_insert with limit enforcement
    /// Insert a session while enforcing a tenant-wide cap (best-effort).
    /// A `max_total` of 0 disables the cap.
    ///
    /// Concurrency note: for throughput this uses lock-free atomics with an
    /// optimistic increment-then-check pattern. Under extreme contention a
    /// small transient overshoot of `max_total` is possible before the
    /// counter is corrected; this is an intentional trade-off to avoid a
    /// global lock.
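    ///
    /// A minimal call sketch (the channel capacity and cap value are
    /// illustrative; marked `ignore` because it depends on the surrounding
    /// crate layout):
    ///
    /// ```ignore
    /// let (tx, _rx) = tokio::sync::mpsc::channel(64);
    /// registry.try_insert(
    ///     "tenant-a".into(),
    ///     "user-1".into(),
    ///     "sess-123".into(),
    ///     Connection { tx },
    ///     1_000, // per-tenant cap; 0 disables the check
    /// )?;
    /// ```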
    pub fn try_insert(
        &self,
        tenant_id: String,
        user_key: String,
        session_key: String,
        conn: Connection,
        max_total: u64,
    ) -> Result<()> {
        let counter = self
            .tenant_counts
            .entry(tenant_id.clone())
            .or_insert_with(|| AtomicU64::new(0));

        // Fast-path rejection before touching the counter.
        if max_total > 0 && counter.load(Ordering::Relaxed) >= max_total {
            return Err(WsPrismError::ResourceExhausted(
                "tenant session limit reached".into(),
            ));
        }

        // Optimistic increment.
        counter.fetch_add(1, Ordering::Relaxed);

        // Re-check to narrow the race window between the load above and the
        // increment; roll back if this insert pushed the count past the cap.
        if max_total > 0 && counter.load(Ordering::Relaxed) > max_total {
            counter.fetch_sub(1, Ordering::Relaxed);
            return Err(WsPrismError::ResourceExhausted(
                "tenant session limit reached (race)".into(),
            ));
        }

        // Release the entry guard so the tenant_counts shard is not held
        // across the remaining map operations.
        drop(counter);

        self.user_index
            .entry(user_key)
            .or_insert_with(DashSet::new)
            .insert(session_key.clone());

        let created_seq = self.seq.fetch_add(1, Ordering::Relaxed);
        self.sessions
            .insert(session_key, SessionEntry { conn, created_seq, tenant_id });

        Ok(())
    }
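    /// Hedged alternative sketch (not part of the original API; the helper
    /// name is illustrative): strict, never-overshoot reservation via a CAS
    /// loop. `try_insert` above accepts a brief overshoot for throughput;
    /// this variant trades retries under contention for a hard guarantee.
    #[allow(dead_code)]
    fn try_reserve_slot(counter: &AtomicU64, max_total: u64) -> bool {
        if max_total == 0 {
            // Mirror `try_insert`: a cap of 0 means "no limit".
            counter.fetch_add(1, Ordering::Relaxed);
            return true;
        }
        counter
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |c| {
                (c < max_total).then_some(c + 1)
            })
            .is_ok()
    }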
    pub fn remove_session(&self, user_key: &str, session_key: &str) -> Option<Connection> {
        if let Some(set) = self.user_index.get(user_key) {
            set.remove(session_key);
        }
        // Drop the user entry once its set is empty. `remove_if` re-checks
        // the predicate under the shard lock, so a concurrent insert cannot
        // be lost the way a separate is_empty()/remove() pair could lose it.
        self.user_index.remove_if(user_key, |_, set| set.is_empty());

        let (_, entry) = self.sessions.remove(session_key)?;
        // Sprint 5: decrement the tenant counter. The counter entry itself is
        // kept so the tenant's next insert reuses it.
        if let Some(counter) = self.tenant_counts.get(&entry.tenant_id) {
            counter.fetch_sub(1, Ordering::Relaxed);
        }
        Some(entry.conn)
    }
    pub fn get_session(&self, session_key: &str) -> Option<Connection> {
        self.sessions.get(session_key).map(|r| r.value().conn.clone())
    }

    pub fn get_user_sessions(&self, user_key: &str) -> Vec<Connection> {
        let Some(set) = self.user_index.get(user_key) else { return vec![]; };
        set.iter()
            .filter_map(|sid| self.get_session(sid.key()))
            .collect()
    }
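    /// Hedged usage sketch (the helper name is illustrative, not part of the
    /// original API): fan a message out to every live session of one user.
    /// `try_send` drops the message rather than blocking when a session's
    /// outbound queue is full; whether that is acceptable is a policy choice.
    #[allow(dead_code)]
    pub fn broadcast_to_user(&self, user_key: &str, msg: &Message) {
        for conn in self.get_user_sessions(user_key) {
            let _ = conn.tx.try_send(msg.clone());
        }
    }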
    pub fn count_user_sessions(&self, user_key: &str) -> usize {
        self.user_index.get(user_key).map(|s| s.len()).unwrap_or(0)
    }

    // Sprint 5: O(1) tenant counter lookup
    pub fn count_tenant_sessions(&self, tenant_id: &str) -> u64 {
        self.tenant_counts
            .get(tenant_id)
            .map(|c| c.load(Ordering::Relaxed))
            .unwrap_or(0)
    }
    /// Snapshot of all active sessions.
    ///
    /// Returns a vector of `(session_key, Connection)`. Intended for
    /// best-effort shutdown/draining logic.
    pub fn all_sessions(&self) -> Vec<(String, Connection)> {
        self.sessions
            .iter()
            .map(|r| (r.key().clone(), r.value().conn.clone()))
            .collect()
    }
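    /// Hedged usage sketch building on `all_sessions` (the helper name is
    /// illustrative): best-effort drain for shutdown. Queues a Close frame to
    /// each session; full or already-closed queues are silently skipped.
    #[allow(dead_code)]
    pub fn drain_all(&self) {
        for (_key, conn) in self.all_sessions() {
            let _ = conn.tx.try_send(Message::Close(None));
        }
    }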
    pub fn len_sessions(&self) -> usize {
        self.sessions.len()
    }
    /// Evict this user's oldest session (smallest `created_seq`).
    /// Returns `(victim_session_key, victim_connection)`.
    pub fn evict_oldest(&self, user_key: &str) -> Option<(String, Connection)> {
        // Snapshot the key set first so no `user_index` guard is held while
        // `sessions` is read below.
        let set = self.user_index.get(user_key)?;
        let keys: Vec<String> = set.iter().map(|s| s.key().clone()).collect();
        drop(set);

        let mut victim_key: Option<String> = None;
        let mut victim_seq: u64 = u64::MAX;
        for k in &keys {
            if let Some(e) = self.sessions.get(k) {
                if e.value().created_seq < victim_seq {
                    victim_seq = e.value().created_seq;
                    victim_key = Some(k.clone());
                }
            }
        }

        let victim_key = victim_key?;
        let conn = self.remove_session(user_key, &victim_key)?;
        Some((victim_key, conn))
    }
}
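
// A minimal test sketch (not in the original file): exercises the tenant cap
// and oldest-first eviction. Key names and the channel capacity are
// illustrative; a plain `#[test]` suffices because nothing is awaited.
#[cfg(test)]
mod tests {
    use super::*;

    fn conn() -> Connection {
        let (tx, _rx) = mpsc::channel(8);
        Connection { tx }
    }

    #[test]
    fn cap_and_oldest_eviction() {
        let reg = SessionRegistry::new();
        assert!(reg.try_insert("t1".into(), "u1".into(), "s1".into(), conn(), 2).is_ok());
        assert!(reg.try_insert("t1".into(), "u1".into(), "s2".into(), conn(), 2).is_ok());
        // A third session exceeds the cap of 2 for tenant "t1".
        assert!(reg.try_insert("t1".into(), "u1".into(), "s3".into(), conn(), 2).is_err());
        assert_eq!(reg.count_tenant_sessions("t1"), 2);

        // "s1" was inserted first, so it is evicted first.
        let (victim, _) = reg.evict_oldest("u1").expect("a victim should exist");
        assert_eq!(victim, "s1");
        assert_eq!(reg.count_tenant_sessions("t1"), 1);
        assert_eq!(reg.count_user_sessions("u1"), 1);
    }
}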