wsprism_gateway/realtime/core/
session_registry.rs1use axum::extract::ws::Message;
2use dashmap::{DashMap, DashSet};
3use tokio::sync::mpsc;
4
5use std::sync::atomic::{AtomicU64, Ordering};
6use wsprism_core::error::{Result, WsPrismError};
7
8#[derive(Clone)]
10pub struct Connection {
11 pub tx: mpsc::Sender<Message>,
12}
13
14#[derive(Clone)]
15struct SessionEntry {
16 conn: Connection,
17 created_seq: u64,
18 tenant_id: String,
20}
21
22#[derive(Default)]
27pub struct SessionRegistry {
28 sessions: DashMap<String, SessionEntry>,
29 user_index: DashMap<String, DashSet<String>>,
30 tenant_counts: DashMap<String, AtomicU64>,
32 seq: AtomicU64,
33}
34
35impl SessionRegistry {
36 pub fn new() -> Self {
37 Self {
38 sessions: DashMap::new(),
39 user_index: DashMap::new(),
40 tenant_counts: DashMap::new(),
41 seq: AtomicU64::new(1),
42 }
43 }
44
45 pub fn try_insert(
53 &self,
54 tenant_id: String,
55 user_key: String,
56 session_key: String,
57 conn: Connection,
58 max_total: u64
59 ) -> Result<()> {
60 let counter = self.tenant_counts.entry(tenant_id.clone()).or_insert_with(|| AtomicU64::new(0));
61
62 if max_total > 0 {
64 let current = counter.load(Ordering::Relaxed);
65 if current >= max_total {
66 return Err(WsPrismError::ResourceExhausted("tenant session limit reached".into()));
67 }
68 }
69
70 counter.fetch_add(1, Ordering::Relaxed);
72
73 if max_total > 0 {
75 if counter.load(Ordering::Relaxed) > max_total {
76 counter.fetch_sub(1, Ordering::Relaxed);
77 return Err(WsPrismError::ResourceExhausted("tenant session limit reached (race)".into()));
78 }
79 }
80
81 self.user_index
82 .entry(user_key)
83 .or_insert_with(DashSet::new)
84 .insert(session_key.clone());
85
86 let created_seq = self.seq.fetch_add(1, Ordering::Relaxed);
87 self.sessions.insert(session_key, SessionEntry { conn, created_seq, tenant_id });
88
89 Ok(())
90 }
91
92 pub fn remove_session(&self, user_key: &str, session_key: &str) -> Option<Connection> {
93 if let Some(set) = self.user_index.get(user_key) {
94 set.remove(session_key);
95 if set.is_empty() {
96 drop(set);
97 self.user_index.remove(user_key);
98 }
99 }
100
101 if let Some((_, entry)) = self.sessions.remove(session_key) {
102 if let Some(counter) = self.tenant_counts.get(&entry.tenant_id) {
104 counter.fetch_sub(1, Ordering::Relaxed);
105 }
106 Some(entry.conn)
107 } else {
108 None
109 }
110 }
111
112 pub fn get_session(&self, session_key: &str) -> Option<Connection> {
113 self.sessions.get(session_key).map(|r| r.value().conn.clone())
114 }
115
116 pub fn get_user_sessions(&self, user_key: &str) -> Vec<Connection> {
117 let Some(set) = self.user_index.get(user_key) else { return vec![]; };
118 set.iter()
119 .filter_map(|sid| self.get_session(sid.key()))
120 .collect()
121 }
122
123 pub fn count_user_sessions(&self, user_key: &str) -> usize {
124 self.user_index.get(user_key).map(|s| s.len()).unwrap_or(0)
125 }
126
127 pub fn count_tenant_sessions(&self, tenant_id: &str) -> u64 {
129 self.tenant_counts.get(tenant_id).map(|c| c.load(Ordering::Relaxed)).unwrap_or(0)
130 }
131
132 pub fn all_sessions(&self) -> Vec<(String, Connection)> {
137 self.sessions
138 .iter()
139 .map(|r| (r.key().clone(), r.value().conn.clone()))
140 .collect()
141 }
142
143 pub fn len_sessions(&self) -> usize {
144 self.sessions.len()
145 }
146
147 pub fn evict_oldest(&self, user_key: &str) -> Option<(String, Connection)> {
150 let set = self.user_index.get(user_key)?;
151 let keys: Vec<String> = set.iter().map(|s| s.key().to_string()).collect();
152 drop(set);
153
154 let mut victim_key: Option<String> = None;
155 let mut victim_seq: u64 = u64::MAX;
156 for k in &keys {
157 if let Some(e) = self.sessions.get(k) {
158 if e.value().created_seq < victim_seq {
159 victim_seq = e.value().created_seq;
160 victim_key = Some(k.clone());
161 }
162 }
163 }
164
165 let victim_key = victim_key?;
166 let conn = self.remove_session(user_key, &victim_key)?;
167 Some((victim_key, conn))
168 }
169}