wsprism_gateway/
app_state.rs1use std::collections::HashMap;
8use std::sync::Arc;
9
10use wsprism_core::error::{Result, WsPrismError};
11
12use crate::{config::GatewayConfig, policy};
13use crate::dispatch::Dispatcher;
14use crate::realtime::RealtimeCore;
15use crate::obs::metrics::GatewayMetrics;
16use crate::services::{ChatService, EchoBinaryService};
17use crate::transport::handshake::HandshakeDefender;
19
/// Controls how `AppState::new` reacts when a tenant allowlist names a service
/// that is not registered with the dispatcher.
///
/// While `false`, such a mismatch is only logged as a warning; when `true`,
/// construction is aborted with `WsPrismError::BadRequest`.
const FAIL_FAST_ON_MISMATCH: bool = false;
22
23#[derive(Clone)]
25pub struct AppState {
26 inner: Arc<AppStateInner>,
27 realtime: Arc<RealtimeCore>,
28 dispatcher: Arc<Dispatcher>,
29 metrics: Arc<GatewayMetrics>,
30 handshake: Arc<HandshakeDefender>,
32}
33
34struct AppStateInner {
35 cfg: GatewayConfig,
36 tenant_policy: HashMap<String, Arc<policy::TenantPolicyRuntime>>,
37}
38
39impl AppState {
40 pub fn new(cfg: GatewayConfig) -> Result<Self> {
44 let metrics = Arc::new(GatewayMetrics::default());
45 let handshake = Arc::new(HandshakeDefender::new(cfg.gateway.handshake_limit.clone()));
47
48 let mut tenant_policy = HashMap::new();
50 for t in &cfg.tenants {
51 let runtime = policy::TenantPolicyRuntime::new(
52 t.id.clone(),
53 t.limits.max_frame_bytes,
54 &t.policy,
55 )
56 .map_err(|e| {
57 WsPrismError::BadRequest(format!(
58 "tenant policy compile failed (tenant={}): {e}",
59 t.id
60 ))
61 })?;
62
63 tenant_policy.insert(t.id.clone(), Arc::new(runtime));
64 }
65
66 let realtime = Arc::new(RealtimeCore::new());
68 let dispatcher = Dispatcher::new();
69
70 dispatcher.register_text(Arc::new(ChatService::new()));
72 dispatcher.register_hot(Arc::new(EchoBinaryService::new(1)));
73
74 {
76 let text_svcs = dispatcher.registered_text_svcs();
77 let hot_svcs = dispatcher.registered_hot_svcs();
78
79 let exempt_text = ["room", "sys"]; for t in &cfg.tenants {
82 for rule in &t.policy.ext_allowlist {
84 if let Some((svc, _ty)) = rule.split_once(':') {
85 if exempt_text.contains(&svc) { continue; }
86 if !text_svcs.contains(&svc) {
87 tracing::warn!(tenant=%t.id, rule=%rule, "ext_allowlist refers to unregistered text service");
88 if FAIL_FAST_ON_MISMATCH {
89 return Err(WsPrismError::BadRequest(format!(
90 "tenant {} ext_allowlist references unregistered text service: {}",
91 t.id, svc
92 )));
93 }
94 }
95 }
96 }
97
98 for rule in &t.policy.hot_allowlist {
100 if let Some((sid_s, _op)) = rule.split_once(':') {
101 if let Ok(sid) = sid_s.parse::<u8>() {
102 if !hot_svcs.contains(&sid) {
103 tracing::warn!(tenant=%t.id, rule=%rule, sid=%sid, "hot_allowlist refers to unregistered binary service");
104 if FAIL_FAST_ON_MISMATCH {
105 return Err(WsPrismError::BadRequest(format!(
106 "tenant {} hot_allowlist references unregistered hot service id: {}",
107 t.id, sid
108 )));
109 }
110 }
111 }
112 }
113 }
114 }
115 }
116
117 Ok(Self {
118 inner: Arc::new(AppStateInner { cfg, tenant_policy }),
119 realtime,
120 dispatcher: Arc::new(dispatcher),
121 metrics,
122 handshake,
123 })
124 }
125
126 pub fn cfg(&self) -> &GatewayConfig {
127 &self.inner.cfg
128 }
129
130 pub fn tenant_policy(&self, tenant_id: &str) -> Option<Arc<policy::TenantPolicyRuntime>> {
131 self.inner.tenant_policy.get(tenant_id).cloned()
132 }
133
134 pub fn resolve_ticket(&self, ticket: &str) -> Result<String> {
135 match ticket {
136 "dev" => Ok("user:dev".to_string()),
137 _ => Err(WsPrismError::AuthFailed),
138 }
139 }
140
141 pub fn realtime(&self) -> Arc<RealtimeCore> {
142 Arc::clone(&self.realtime)
143 }
144
145 pub fn dispatcher(&self) -> Arc<Dispatcher> {
146 Arc::clone(&self.dispatcher)
147 }
148
149 pub fn metrics(&self) -> Arc<GatewayMetrics> {
150 Arc::clone(&self.metrics)
151 }
152
153 pub fn handshake(&self) -> Arc<HandshakeDefender> {
154 Arc::clone(&self.handshake)
155 }
156
157 pub fn is_draining(&self) -> bool {
158 self.metrics.is_draining()
159 }
160
161 pub fn enter_draining(&self) {
163 let _ = self.metrics.set_draining();
164 }
165
166 pub fn metrics_extra(&self) -> Vec<(&'static str, u64)> {
168 vec![
169 (
170 "wsprism_egress_drop_total",
171 crate::realtime::core::egress_drop_count(),
172 ),
173 (
174 "wsprism_egress_send_fail_total",
175 crate::realtime::core::egress_send_fail_count(),
176 ),
177 ]
178 }
179}