wsprism_gateway/
app_state.rs

//! Shared application state for the wsPrism Gateway.
//!
//! Changes introduced in Sprint 3:
//! - Wire `RealtimeCore` and `Dispatcher` together and register the built-in services.
//! - Surface startup errors explicitly (return `Result` instead of panicking).
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use wsprism_core::error::{Result, WsPrismError};
11
12use crate::{config::GatewayConfig, policy};
13use crate::dispatch::Dispatcher;
14use crate::realtime::RealtimeCore;
15use crate::obs::metrics::GatewayMetrics;
16use crate::services::{ChatService, EchoBinaryService};
17// Sprint 5
18use crate::transport::handshake::HandshakeDefender;
19
/// Boot-time strictness switch for the tenant-allowlist / dispatcher cross-check.
///
/// When `true`, `AppState::new` aborts with an error as soon as a tenant allowlist
/// names a service the dispatcher does not serve. When `false` (the current setting),
/// such mismatches are only logged as warnings.
const FAIL_FAST_ON_MISMATCH: bool = false;
22
23/// Shared, clonable gateway application state (config + policy + runtimes).
24#[derive(Clone)]
25pub struct AppState {
26    inner: Arc<AppStateInner>,
27    realtime: Arc<RealtimeCore>,
28    dispatcher: Arc<Dispatcher>,
29    metrics: Arc<GatewayMetrics>,
30    // Sprint 5
31    handshake: Arc<HandshakeDefender>,
32}
33
34struct AppStateInner {
35    cfg: GatewayConfig,
36    tenant_policy: HashMap<String, Arc<policy::TenantPolicyRuntime>>,
37}
38
39impl AppState {
40    /// Build application state (config + compiled policies + runtimes).
41    ///
42    /// Returns `Result` so the binary can surface startup errors without panic.
43    pub fn new(cfg: GatewayConfig) -> Result<Self> {
44        let metrics = Arc::new(GatewayMetrics::default());
45        // Sprint 5
46        let handshake = Arc::new(HandshakeDefender::new(cfg.gateway.handshake_limit.clone()));
47
48        // 1) Compile tenant policy runtimes
49        let mut tenant_policy = HashMap::new();
50        for t in &cfg.tenants {
51            let runtime = policy::TenantPolicyRuntime::new(
52                t.id.clone(),
53                t.limits.max_frame_bytes,
54                &t.policy,
55            )
56            .map_err(|e| {
57                WsPrismError::BadRequest(format!(
58                    "tenant policy compile failed (tenant={}): {e}",
59                    t.id
60                ))
61            })?;
62
63            tenant_policy.insert(t.id.clone(), Arc::new(runtime));
64        }
65
66        // 2) Create core components
67        let realtime = Arc::new(RealtimeCore::new());
68        let dispatcher = Dispatcher::new();
69
70        // 3) Register built-in services (Sprint 3)
71        dispatcher.register_text(Arc::new(ChatService::new()));
72        dispatcher.register_hot(Arc::new(EchoBinaryService::new(1)));
73
74        // allowlist <-> dispatcher sanity check
75        {
76            let text_svcs = dispatcher.registered_text_svcs();
77            let hot_svcs = dispatcher.registered_hot_svcs();
78
79            let exempt_text = ["room", "sys"]; // transport/internal
80
81            for t in &cfg.tenants {
82                // ext rules: "svc:type"
83                for rule in &t.policy.ext_allowlist {
84                    if let Some((svc, _ty)) = rule.split_once(':') {
85                        if exempt_text.contains(&svc) { continue; }
86                        if !text_svcs.contains(&svc) {
87                            tracing::warn!(tenant=%t.id, rule=%rule, "ext_allowlist refers to unregistered text service");
88                            if FAIL_FAST_ON_MISMATCH {
89                                return Err(WsPrismError::BadRequest(format!(
90                                    "tenant {} ext_allowlist references unregistered text service: {}",
91                                    t.id, svc
92                                )));
93                            }
94                        }
95                    }
96                }
97
98                // hot rules: "sid:opcode"
99                for rule in &t.policy.hot_allowlist {
100                    if let Some((sid_s, _op)) = rule.split_once(':') {
101                        if let Ok(sid) = sid_s.parse::<u8>() {
102                            if !hot_svcs.contains(&sid) {
103                                tracing::warn!(tenant=%t.id, rule=%rule, sid=%sid, "hot_allowlist refers to unregistered binary service");
104                                if FAIL_FAST_ON_MISMATCH {
105                                    return Err(WsPrismError::BadRequest(format!(
106                                        "tenant {} hot_allowlist references unregistered hot service id: {}",
107                                        t.id, sid
108                                    )));
109                                }
110                            }
111                        }
112                    }
113                }
114            }
115        }
116
117        Ok(Self {
118            inner: Arc::new(AppStateInner { cfg, tenant_policy }),
119            realtime,
120            dispatcher: Arc::new(dispatcher),
121            metrics,
122            handshake,
123        })
124    }
125
126    pub fn cfg(&self) -> &GatewayConfig {
127        &self.inner.cfg
128    }
129
130    pub fn tenant_policy(&self, tenant_id: &str) -> Option<Arc<policy::TenantPolicyRuntime>> {
131        self.inner.tenant_policy.get(tenant_id).cloned()
132    }
133
134    pub fn resolve_ticket(&self, ticket: &str) -> Result<String> {
135        match ticket {
136            "dev" => Ok("user:dev".to_string()),
137            _ => Err(WsPrismError::AuthFailed),
138        }
139    }
140
141    pub fn realtime(&self) -> Arc<RealtimeCore> {
142        Arc::clone(&self.realtime)
143    }
144
145    pub fn dispatcher(&self) -> Arc<Dispatcher> {
146        Arc::clone(&self.dispatcher)
147    }
148
149    pub fn metrics(&self) -> Arc<GatewayMetrics> {
150        Arc::clone(&self.metrics)
151    }
152
153    pub fn handshake(&self) -> Arc<HandshakeDefender> {
154        Arc::clone(&self.handshake)
155    }
156
157    pub fn is_draining(&self) -> bool {
158        self.metrics.is_draining()
159    }
160
161    /// Enter draining mode (idempotent).
162    pub fn enter_draining(&self) {
163        let _ = self.metrics.set_draining();
164    }
165
166    /// Extra counters that are owned by other modules (egress drop/timeouts).
167    pub fn metrics_extra(&self) -> Vec<(&'static str, u64)> {
168        vec![
169            (
170                "wsprism_egress_drop_total",
171                crate::realtime::core::egress_drop_count(),
172            ),
173            (
174                "wsprism_egress_send_fail_total",
175                crate::realtime::core::egress_send_fail_count(),
176            ),
177        ]
178    }
179}