diff --git a/gamestream-webtransport-proxy/src/main.rs b/gamestream-webtransport-proxy/src/main.rs index 7b94830..8d338be 100644 --- a/gamestream-webtransport-proxy/src/main.rs +++ b/gamestream-webtransport-proxy/src/main.rs @@ -124,9 +124,9 @@ async fn run_backend(port: u16) -> Result<()> { Ok(()) } -async fn run_proxy(port: u16, stream_id: uuid::Uuid) -> Result<()> { +async fn run_proxy(port: u16, stream_id: uuid::Uuid, stream_token: String) -> Result<()> { let (config, cert_hash) = certs::get_webtransport_stream_config(stream_id)?; - let proxy = proxy::Proxy::new(cert_hash); + let proxy = proxy::Proxy::new(cert_hash, stream_token); let proxy_arc = std::sync::Arc::new(proxy); let router = Router::new() @@ -166,8 +166,11 @@ async fn main() -> anyhow::Result<()> { .nth(3) .ok_or(anyhow!("Cert ID argument missing"))?, )?; + let stream_token = std::env::args() + .nth(4) + .ok_or(anyhow!("Stream token argument missing"))?; - run_proxy(port, stream_id).await + run_proxy(port, stream_id, stream_token).await } _ => Err(anyhow!("Unknown mode: {mode}")), } diff --git a/gamestream-webtransport-proxy/src/proxy/handler.rs b/gamestream-webtransport-proxy/src/proxy/handler.rs index da9ed27..dbbd14a 100644 --- a/gamestream-webtransport-proxy/src/proxy/handler.rs +++ b/gamestream-webtransport-proxy/src/proxy/handler.rs @@ -85,6 +85,18 @@ impl crate::proxy::Proxy { description: "Could not start stream".to_string(), }); + // Validate single-use stream token via the shared helper so this + // handler and its unit tests exercise the same code path. + let provided_token = req.query::("token").unwrap_or_default(); + if let Err(msg) = super::validate_stream_token(&self, &provided_token).await { + error!("Stream token validation failed: {msg}"); + return Err(AppError { + status_code: StatusCode::UNAUTHORIZED, + description: msg, + }); + } + info!("Stream token validated and consumed"); + info!("WebTransport connection initiated"); let (wt_stream_send, wt_stream_recv, wt_datagram_send) = match setup_webtransport(req).await { diff --git a/gamestream-webtransport-proxy/src/proxy/mod.rs b/gamestream-webtransport-proxy/src/proxy/mod.rs index a74caa5..93a8825 100644 --- a/gamestream-webtransport-proxy/src/proxy/mod.rs +++ b/gamestream-webtransport-proxy/src/proxy/mod.rs @@ -11,16 +11,16 @@ mod video; pub struct Proxy { pub cert_hash: [u8; 32], - //pub cert_hash: String, pub stream: RwLock>, + pub stream_token: RwLock>, } impl Proxy { - pub fn new(cert_hash: [u8; 32]) -> Self { - //pub fn new(cert_hash: String) -> Self { + pub fn new(cert_hash: [u8; 32], stream_token: String) -> Self { Proxy { stream: RwLock::new(None), cert_hash, + stream_token: RwLock::new(Some(stream_token)), } } } @@ -78,6 +78,22 @@ async fn proxy_main( Ok(()) } +/// Validate a provided token against the stored token. Consumes the token on success (single-use). +/// Returns Ok(()) if valid, Err with description if invalid or already consumed. +pub async fn validate_stream_token(proxy: &Proxy, provided: &str) -> std::result::Result<(), String> { + let mut token_guard = proxy.stream_token.write().await; + match token_guard.take() { + Some(expected) if expected == provided => Ok(()), + Some(_) => { + // Wrong token: still consumed by the `take()` above. Any validation + // attempt — correct or not — invalidates the token, so a wrong + // guess cannot be followed by a correct one. + Err("Invalid stream token".to_string()) + } + None => Err("Stream token already used".to_string()), + } +} + async fn spawn_gamestream(stream: backend::Stream) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); let (stop_tx, stop_rx) = tokio::sync::oneshot::channel::<()>(); @@ -99,3 +115,59 @@ async fn spawn_gamestream(stream: backend::Stream) -> Result { .context("Could not get gamestream communication channels")?, }) } + +#[cfg(test)] +mod tests { + use super::*; + + fn make_proxy(token: &str) -> Proxy { + Proxy { + cert_hash: [0u8; 32], + stream: RwLock::new(None), + stream_token: RwLock::new(Some(token.to_string())), + } + } + + #[tokio::test] + async fn test_valid_token_accepted() { + let proxy = make_proxy("abc123"); + let result = validate_stream_token(&proxy, "abc123").await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_wrong_token_rejected() { + let proxy = make_proxy("abc123"); + let result = validate_stream_token(&proxy, "wrong").await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Invalid stream token"); + } + + #[tokio::test] + async fn test_missing_token_rejected() { + let proxy = make_proxy("abc123"); + let result = validate_stream_token(&proxy, "").await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_token_consumed_after_use() { + let proxy = make_proxy("abc123"); + let first = validate_stream_token(&proxy, "abc123").await; + assert!(first.is_ok()); + + let second = validate_stream_token(&proxy, "abc123").await; + assert!(second.is_err()); + assert_eq!(second.unwrap_err(), "Stream token already used"); + } + + #[tokio::test] + async fn test_wrong_attempt_consumes_token() { + let proxy = make_proxy("abc123"); + // Wrong token attempt should consume it + let _ = validate_stream_token(&proxy, "wrong").await; + // Correct token should also fail now + let result = validate_stream_token(&proxy, "abc123").await; + assert!(result.is_err()); + } +} diff --git a/gamestream-webtransport-proxy/src/stream.rs b/gamestream-webtransport-proxy/src/stream.rs index 1ca9dd3..edd4ba2 100644 --- a/gamestream-webtransport-proxy/src/stream.rs +++ b/gamestream-webtransport-proxy/src/stream.rs @@ -25,7 +25,7 @@ struct PostStreamStartParams { struct PostStreamStartResponse { url: String, cert_hash: [u8; 32], - //cert_hash: String, + stream_token: String, } #[derive(Deserialize)] @@ -301,6 +301,19 @@ impl crate::backend::Backend { let port = self.port + ::try_from((*writer).len()).unwrap(); + // Generate single-use stream token for proxy authentication + let stream_token = { + let mut bytes = [0u8; 32]; + openssl::rand::rand_bytes(&mut bytes).map_err(|e| { + error!("Failed to generate stream token: {e}"); + AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Could not start stream".to_string(), + } + })?; + hex::encode(bytes) + }; + // Spawn WebTransport proxy let binary_path = match std::env::current_exe() { Ok(b) => b, @@ -314,7 +327,7 @@ impl crate::backend::Backend { stream_id, port ); match tokio::process::Command::new(binary_path) - .args(["proxy", &port.to_string(), &stream_id.to_string()]) + .args(["proxy", &port.to_string(), &stream_id.to_string(), &stream_token]) .spawn() { Ok(_) => (), @@ -355,6 +368,7 @@ impl crate::backend::Backend { let post_stream_response = PostStreamStartResponse { url: webtransport_url, cert_hash: setup_resp.cert_hash, + stream_token, }; Ok(Json(post_stream_response))