diff --git a/Cargo.lock b/Cargo.lock index c97fdaf..a90f797 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -520,6 +520,16 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "etag" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b3d0661a2ccddc26cba0b834e9b717959ed6fdd76c7129ee159c170a875bf44" +dependencies = [ + "str-buf", + "xxhash-rust", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -639,7 +649,10 @@ dependencies = [ "salvo", "serde", "serde-xml-rs", + "serde_json", "tokio", + "tracing", + "tracing-subscriber", "url-constructor", "uuid", ] @@ -1120,6 +1133,12 @@ dependencies = [ "simple_asn1", ] +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "libc" version = "0.2.174" @@ -1295,6 +1314,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -1400,6 +1429,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot" version = "0.12.3" @@ -1958,6 +1993,15 @@ dependencies = [ "security-framework 3.2.0", ] +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = "1.12.0" @@ -1993,21 +2037,45 @@ checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "salvo" -version = "0.79.0" +version = "0.80.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca8e3fda2c363337d9179477b947546433da95c94d7d2ba325e21fa9c314ff78" +checksum = "7ce0331005d85590a43295118391452dd04cf7826c2f8251c2362874b8d8b67a" dependencies = [ + "salvo-craft", "salvo-jwt-auth", "salvo-oapi", "salvo-proxy", "salvo_core", + "salvo_extra", +] + +[[package]] +name = "salvo-craft" +version = "0.80.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e0e91ffd1029d42f100923525be9ed0cb500e651100580ba9e67cae368f34b3" +dependencies = [ + "salvo-craft-macros", +] + +[[package]] +name = "salvo-craft-macros" +version = "0.80.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9514000ff4ff9697ec8334f8be2cf519a0a3c6b77685fb1ad7a271f952231dc0" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "syn 2.0.104", ] [[package]] name = "salvo-jwt-auth" -version = "0.79.0" +version = "0.80.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88690329319f476e45ea15208f947dd0f162388dcfb14f34b9fbd665e11ca286" +checksum = "d9d41ac7983e15ce964f2c2b55d07ea83f1ec8c40abac4fc912ae2c34c272205" dependencies = [ "base64", "bytes", @@ -2025,9 +2093,9 @@ dependencies = [ [[package]] name = "salvo-oapi" -version = "0.79.0" +version = "0.80.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b278c065715279c485d2d74be6823388f7ddb32200da7d47b7cdc3f4780e40f5" +checksum = "7b9ef2a41e4bebb38f859d12b257a386e26eb736ad2932b39a7fa6b4a1574cbc" dependencies = [ "anyhow", "base64", @@ -2060,9 +2128,9 @@ dependencies = [ [[package]] name = "salvo-oapi-macros" -version = "0.79.0" +version = "0.80.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a68d0fdfa852c79ea04fe035b5717d5cced58cb7f28fa5f20f74b16b0bbddd61" +checksum = "e11a94dacba1e1faeeb27c0affcecea7913ae086f9dfc8476dfb115a61c9f412" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2075,9 +2143,9 @@ dependencies = [ [[package]] name = "salvo-proxy" -version = "0.79.0" +version = "0.80.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ac4a3774c12134537dbf7b4de0a3a33355bc8181b7aedecf2e676f6155c16dd" +checksum = "60554233611342021acaf1340679427e388e358486fa5ef7684d90aa738c854b" dependencies = [ "fastrand", "futures-util", @@ -2093,9 +2161,9 @@ dependencies = [ [[package]] name = "salvo-serde-util" -version = "0.79.0" +version = "0.80.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bf8b2a2576a893b80542c400ae4d3b1b1a903ad400ce6a7affe930e4ad9346d" +checksum = "2014d193ab04bf917c574c155801279474becd532a02f7dbe64928bdd2b072a4" dependencies = [ "proc-macro2", "quote", @@ -2104,9 +2172,9 @@ dependencies = [ [[package]] name = "salvo_core" -version = "0.79.0" +version = "0.80.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7a887d957cfc87e4f25e33d37bfc8aac0544ab95478b4620faaa2f0be522c27" +checksum = "fb81e2093c75ff8c3273196c7b5e3b2fb9c5f838ef7e925753ce7e58ddbd6fe4" dependencies = [ "async-trait", "base64", @@ -2133,6 +2201,7 @@ dependencies = [ "pin-project", "rand 0.9.1", "regex", + "rustls-pemfile", "salvo_macros", "serde", "serde-xml-rs", @@ -2147,10 +2216,32 @@ dependencies = [ ] [[package]] -name = "salvo_macros" -version = "0.79.0" +name = "salvo_extra" +version = "0.80.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc781932226f47832bff8e11d32d2c799081b08116cdfe212712095aa908a7af" +checksum = "3f90ca584472adbb3dacf4e82bdf46f8022771916342d598b5521f6bfe6b8bfd" +dependencies = [ + "base64", + "etag", + "futures-util", + "http-body-util", + "hyper", + "pin-project", + "salvo_core", + "serde", + "serde_json", + "tokio", + "tokio-tungstenite", + "tower", + "tracing", + "ulid", +] + +[[package]] +name = "salvo_macros" +version = "0.80.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83c13353864ea07e0673511a843b762c775b7f9364657ff5c49ad5fccc5ce0a0" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2317,6 +2408,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -2390,6 +2490,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "str-buf" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ceb97b7225c713c2fd4db0153cb6b3cab244eb37900c3f634ed4d43310d8c34" + [[package]] name = "subtle" version = "2.6.1" @@ -2518,6 +2624,15 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + [[package]] name = "time" version = "0.3.41" @@ -2623,6 +2738,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "489a59b6730eda1b0171fcfda8b121f4bee2b35cba8645ca35c5f7ba3eb736c1" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.15" @@ -2664,8 +2791,10 @@ dependencies = [ "pin-project-lite", "sync_wrapper", "tokio", + "tokio-util", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -2727,6 +2856,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +dependencies = [ + "nu-ansi-term", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", ] [[package]] @@ -2735,6 +2890,19 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadc29d668c91fcc564941132e17b28a7ceb2f3ebf0b9dae3e03fd7a6748eb0d" +dependencies = [ + "bytes", + "log", + "rand 0.9.1", + "thiserror 2.0.12", + "utf-8", +] + [[package]] name = "typenum" version = "1.18.0" @@ -2747,6 +2915,7 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "470dbf6591da1b39d43c14523b2b469c86879a53e8b758c8e090a470fe7b1fbe" dependencies = [ + "rand 0.9.1", "web-time", ] @@ -2801,6 +2970,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76ebdb602caa7eec1e27aaf8364e23c9807d0deb2aae98c6e0fa37ea05954d6d" +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -2815,9 +2990,16 @@ checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" dependencies = [ "getrandom 0.3.3", "js-sys", + "serde", "wasm-bindgen", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "vcpkg" version = "0.2.15" @@ -2977,15 +3159,37 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.61.0" @@ -3241,6 +3445,12 @@ version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a62ce76d9b56901b19a74f19431b0d8b3bc7ca4ad685a746dfd78ca8f4fc6bda" +[[package]] +name = "xxhash-rust" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" + [[package]] name = "yansi" version = "1.0.1" diff --git a/gamestream-webtransport-proxy/Cargo.toml b/gamestream-webtransport-proxy/Cargo.toml index 798cf12..5075d2d 100644 --- a/gamestream-webtransport-proxy/Cargo.toml +++ b/gamestream-webtransport-proxy/Cargo.toml @@ -16,9 +16,12 @@ reqwest = { version = "0.12.20", features = [ "rustls-tls", "native-tls", ], default-features = false } -salvo = { version = "0.79.0", features = ["oapi"] } +salvo = { version = "0.80.0", features = ["oapi", "craft", "logging"] } serde = { version = "1.0.219", features = ["serde_derive"] } serde-xml-rs = "0.8.1" +serde_json = "1.0.140" tokio = { version = "1.45.1", features = ["full"] } +tracing = "0.1.41" +tracing-subscriber = "0.3.19" url-constructor = "0.1.0" -uuid = { version = "1.17.0", features = ["v4"] } +uuid = { version = "1.17.0", features = ["v4", "serde"] } diff --git a/gamestream-webtransport-proxy/src/apps.rs b/gamestream-webtransport-proxy/src/apps.rs new file mode 100644 index 0000000..ac3ffb8 --- /dev/null +++ b/gamestream-webtransport-proxy/src/apps.rs @@ -0,0 +1,115 @@ +use std::collections::HashMap; + +use salvo::prelude::*; +use serde::{Deserialize, Serialize}; +use tracing::{debug, error}; + +use crate::{common::AppResult, config::ConfigReader}; + +#[derive(Deserialize)] +struct AppListRespApp { + #[serde(rename = "AppTitle")] + app_title: String, + #[serde(rename = "UUID")] + uuid: uuid::Uuid, + #[serde(rename = "IsHdrSupported")] + is_hdr_supported: bool, + #[serde(rename = "ID")] + id: u64, +} + +#[derive(Deserialize)] +struct AppListResp { + #[serde(rename = "App")] + apps: Vec, +} + +#[derive(Debug, Serialize, ToSchema)] +struct App { + title: String, + id: u64, + hdr_supported: bool, +} + +#[derive(Debug, Serialize, ToSchema)] +struct GetAppsResponse { + apps: HashMap>, +} + +#[craft] +impl crate::config::ConfigFile { + #[craft(endpoint(status_codes(StatusCode::OK, StatusCode::INTERNAL_SERVER_ERROR)))] + pub async fn get_apps(self: ::std::sync::Arc) -> AppResult> { + let standard_error = Err(crate::common::AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "failed to get available apps".to_string(), + }); + + let reader = self.read().await; + let unique_id = match reader.unique_id() { + Ok(u) => u, + Err(e) => { + error!("could not get unique id: {e}"); + return standard_error; + } + }; + + let servers = match reader.servers() { + Ok(s) => s, + Err(e) => { + error!("could not get servers: {e}"); + return standard_error; + } + }; + + let mut get_apps_resp = GetAppsResponse { + apps: HashMap::new(), + }; + + for (_, server) in servers.into_iter() { + let mut base_url = crate::common::base_url( + "https", + &server.host, + server.https_port(), + &unique_id, + "applist", + None, + ); + + let resp = match crate::common::get_url(&mut base_url, true).await { + Ok(r) => r, + Err(e) => { + error!("could not get applist from server {}: {}", server.name, e); + continue; + } + }; + debug!(resp); + + let applist_resp: AppListResp = match serde_xml_rs::from_str(&resp) { + Ok(r) => r, + Err(e) => { + error!( + "could not parse applist response from server {}: {}", + server.name, e + ); + continue; + } + }; + + let resp_vec = applist_resp + .apps + .into_iter() + .map(|a| App { + title: a.app_title, + hdr_supported: a.is_hdr_supported, + id: a.id, + }) + .rev() + .collect(); + + get_apps_resp.apps.insert(server.name, resp_vec); + } + + Ok(Json(get_apps_resp)) + } +} diff --git a/gamestream-webtransport-proxy/src/certs.rs b/gamestream-webtransport-proxy/src/certs.rs index eccb702..41b5a22 100644 --- a/gamestream-webtransport-proxy/src/certs.rs +++ b/gamestream-webtransport-proxy/src/certs.rs @@ -110,7 +110,7 @@ pub fn save_cert_and_key_to_disk(cert: &X509, key: &PKey) -> Result<()> Ok(()) } -pub fn http_client_with_identity() -> Result { +pub fn identity() -> Result { let cert_dir = get_and_create_cert_dir()?; let cert_filepath = cert_dir.join("cert"); let key_filepath = cert_dir.join("key"); @@ -118,7 +118,8 @@ pub fn http_client_with_identity() -> Result { let cert_bytes = fs::read(cert_filepath)?; let key_bytes = fs::read(key_filepath)?; - let identity = reqwest::tls::Identity::from_pkcs8_pem(&cert_bytes, &key_bytes)?; - - Ok(reqwest::Client::builder().identity(identity).build()?) + Ok(reqwest::tls::Identity::from_pkcs8_pem( + &cert_bytes, + &key_bytes, + )?) } diff --git a/gamestream-webtransport-proxy/src/common.rs b/gamestream-webtransport-proxy/src/common.rs index 620b9c4..440a733 100644 --- a/gamestream-webtransport-proxy/src/common.rs +++ b/gamestream-webtransport-proxy/src/common.rs @@ -1,7 +1,8 @@ -use serde::Serialize; -//use salvo::http::{StatusCode, StatusError}; +use anyhow::Result; use salvo::oapi::{self, EndpointOutRegister, ToSchema}; use salvo::prelude::*; +use serde::Serialize; +use tracing::debug; #[derive(Debug, Serialize, ToSchema)] struct ApiError { @@ -14,7 +15,7 @@ pub struct AppError { pub description: String, } -pub type AppResult = anyhow::Result; +pub type AppResult = Result; #[async_trait] impl Writer for AppError { @@ -43,3 +44,53 @@ impl EndpointOutRegister for AppError { } } } + +pub fn base_url( + scheme: &str, + host: &String, + base_port: u16, + unique_id: &String, + path: &str, + params: Option>, +) -> url_constructor::UrlConstructor { + let mut base_url = url_constructor::UrlConstructor::new(); + base_url + .scheme(scheme) + .host(host) + .port(base_port) + .subdir(path) + .param("uniqueid", unique_id); + + if let Some(p) = params { + for (k, v) in p.into_iter() { + base_url.param(k, v); + } + } + base_url +} + +pub async fn get_url( + base_url: &mut url_constructor::UrlConstructor, + with_identity: bool, +) -> Result { + let mut uuidv2 = [0u8; 16]; + openssl::rand::rand_bytes(&mut uuidv2)?; + let uuidv2_hex = hex::encode(uuidv2); + + let url = base_url.param("uuid", uuidv2_hex).build(); + debug!("Getting url: {url}"); + + let mut http_builder = reqwest::Client::builder(); + http_builder = http_builder.user_agent("Mozilla/5.0"); + http_builder = http_builder.danger_accept_invalid_certs(true); + if with_identity { + let identity = crate::certs::identity()?; + http_builder = http_builder.identity(identity); + } + + let client = http_builder.build()?; + + let resp = client.get(url).send().await?; + let text = resp.text().await?; + Ok(text) +} diff --git a/gamestream-webtransport-proxy/src/config.rs b/gamestream-webtransport-proxy/src/config.rs new file mode 100644 index 0000000..90b5cb2 --- /dev/null +++ b/gamestream-webtransport-proxy/src/config.rs @@ -0,0 +1,158 @@ +use std::{ + collections::HashMap, + fs::{self, File}, + path::PathBuf, +}; + +use anyhow::{Result, anyhow}; +use serde::{Deserialize, Serialize}; +use tokio::sync::{RwLockReadGuard, RwLockWriteGuard}; + +#[derive(Serialize, Deserialize)] +pub struct Server { + pub name: String, + pub host: String, + pub base_port: u16, +} + +impl Server { + pub fn http_port(&self) -> u16 { + self.base_port + 189 + } + pub fn https_port(&self) -> u16 { + self.base_port + 184 + } +} + +#[derive(Serialize, Deserialize)] +struct Config { + servers: HashMap, + unique_id: String, +} + +pub struct ConfigFile { + lock: tokio::sync::RwLock<()>, + path: PathBuf, +} + +pub trait ConfigReader { + fn servers(&self) -> Result>; + fn unique_id(&self) -> Result; +} + +pub trait ConfigWriter { + fn add_server(&self, server: Server) -> Result<()>; +} + +pub struct ConfigReadAccess<'a> { + _guard: RwLockReadGuard<'a, ()>, + config: &'a ConfigFile, +} + +pub struct ConfigWriteAccess<'a> { + _guard: RwLockWriteGuard<'a, ()>, + config: &'a ConfigFile, +} + +impl ConfigFile { + fn load_config(&self) -> Result { + tracing::debug!("parsing config file"); + + let config_file = File::open(&self.path)?; + Ok(serde_json::from_reader(config_file)?) + } + + fn write_config(&self, config: Config) -> Result<()> { + tracing::debug!("serializing config file"); + let config_file = File::create(&self.path)?; + Ok(serde_json::to_writer_pretty(config_file, &config)?) + } + + pub fn new() -> Result { + let project_dirs = + directories::ProjectDirs::from("xyz", "ohea", "gamestream-webtransport-proxy") + .ok_or(anyhow::anyhow!("Could not get project dirs"))?; + + let data_dir = project_dirs.data_dir(); + fs::create_dir_all(data_dir)?; + + let state_path = data_dir.join("state.json"); + + if let Err(e) = File::open(&state_path) { + if e.kind() == std::io::ErrorKind::NotFound { + write_default_config_to_path(&state_path)?; + } else { + return Err(anyhow!(e)); + } + } + + Ok(ConfigFile { + lock: tokio::sync::RwLock::new(()), + path: state_path, + }) + } +} + +impl ConfigFile { + pub async fn read(&self) -> ConfigReadAccess { + ConfigReadAccess { + _guard: self.lock.read().await, + config: self, + } + } + + pub async fn write(&self) -> ConfigWriteAccess { + ConfigWriteAccess { + _guard: self.lock.write().await, + config: self, + } + } +} + +impl<'a> ConfigReader for ConfigReadAccess<'a> { + fn servers(&self) -> Result> { + let config = self.config.load_config()?; + Ok(config.servers) + } + + fn unique_id(&self) -> Result { + let config = self.config.load_config()?; + Ok(config.unique_id) + } +} + +impl<'a> ConfigWriter for ConfigWriteAccess<'a> { + fn add_server(&self, server: Server) -> Result<()> { + let mut config = self.config.load_config()?; + + if config.servers.contains_key(&server.name) { + return Err(anyhow!( + "cannot add duplicate server with name: {}", + server.name + )); + } + + config.servers.insert(server.name.clone(), server); + + self.config.write_config(config)?; + + Ok(()) + } +} + +pub fn get_unique_id() -> Result { + let mut bytes = [0u8; 8]; + openssl::rand::rand_bytes(&mut bytes)?; + Ok(hex::encode(bytes)) +} + +fn write_default_config_to_path(path: &PathBuf) -> Result<()> { + let default_config = Config { + servers: HashMap::new(), + unique_id: get_unique_id()?, + }; + + let config_file = File::create(path)?; + + Ok(serde_json::to_writer_pretty(config_file, &default_config)?) +} diff --git a/gamestream-webtransport-proxy/src/main.rs b/gamestream-webtransport-proxy/src/main.rs index faeb1a9..eddb306 100644 --- a/gamestream-webtransport-proxy/src/main.rs +++ b/gamestream-webtransport-proxy/src/main.rs @@ -1,3 +1,4 @@ +use salvo::logging::Logger; use salvo::prelude::*; use std::ffi::CString; @@ -6,8 +7,10 @@ use moonlight_common_c_sys::{ CONNECTION_LISTENER_CALLBACKS, ENCFLG_NONE, SCM_H264, STREAM_CFG_LOCAL, VIDEO_FORMAT_H264, }; +mod apps; mod certs; mod common; +mod config; mod pair; #[allow(unused)] @@ -105,15 +108,26 @@ fn get_listener_callbacks() -> CONNECTION_LISTENER_CALLBACKS { //} #[tokio::main] -async fn main() { +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .init(); + + let config = config::ConfigFile::new()?; + let config_arc = std::sync::Arc::new(config); + let router = Router::new() - .push(Router::with_path("pair").post(pair::post_pair)) - .push(Router::with_path("apps").get(apps::get_apps)); + .push(Router::with_path("pair").post(config_arc.post_pair())) + .push(Router::with_path("apps").get(config_arc.get_apps())); let doc = OpenApi::new("test api", "0.0.1").merge_router(&router); let router = router .unshift(doc.into_router("/api-doc/openapi.json")) .unshift(SwaggerUi::new("/api-doc/openapi.json").into_router("/swagger-ui")); + let service = Service::new(router).hoop(Logger::new()); + let listener = TcpListener::new("0.0.0.0:3001"); let acceptor = listener.join(TcpListener::new("0.0.0.0:3000")).bind().await; - salvo::Server::new(acceptor).serve(router).await; + salvo::Server::new(acceptor).serve(service).await; + + Ok(()) } diff --git a/gamestream-webtransport-proxy/src/pair.rs b/gamestream-webtransport-proxy/src/pair.rs index ec3eceb..88e822c 100644 --- a/gamestream-webtransport-proxy/src/pair.rs +++ b/gamestream-webtransport-proxy/src/pair.rs @@ -5,15 +5,16 @@ use openssl::x509::X509; use rand::Rng; use salvo::prelude::*; use serde::{Deserialize, Serialize}; +use tracing::{debug, error, info}; -use crate::common::{AppError, AppResult}; +use crate::common::{AppError, AppResult, get_url}; +use crate::config::{ConfigReader, ConfigWriter}; #[derive(Debug, Deserialize, ToSchema)] struct PostPairParams { + name: String, host: String, - port: u16, - #[allow(unused)] - pair_endpoint: Option, + base_port: u16, } #[derive(Debug, Serialize, ToSchema)] @@ -56,23 +57,6 @@ struct ServerPairingSecret { signature: Vec, } -async fn get_url(base_url: &mut url_constructor::UrlConstructor) -> Result { - let mut uuidv2 = [0u8; 16]; - openssl::rand::rand_bytes(&mut uuidv2)?; - let uuidv2_hex = hex::encode(uuidv2); - - let url = base_url.param("uuid", uuidv2_hex).build(); - //println!("Getting url: {url}"); - - let mut http_builder = reqwest::Client::builder(); - http_builder = http_builder.user_agent("Mozilla/5.0"); - let client = http_builder.build().unwrap(); - - let resp = client.get(url).send().await?; - let text = resp.text().await?; - Ok(text) -} - async fn get_server_cert( mut base_url: url_constructor::UrlConstructor, salt_hex: String, @@ -84,7 +68,7 @@ async fn get_server_cert( .param("salt", &salt_hex) .param("clientcert", &cert_hex); - let text = get_url(url).await?; + let text = get_url(url, false).await?; let server_cert: ServerCertResponse = serde_xml_rs::from_str(&text)?; let server_cert_bytes = hex::decode(server_cert.plaincert)?; @@ -116,7 +100,7 @@ async fn get_server_challenge( let challenge_hex = hex::encode(challenge_enc); let url = base_url.param("clientchallenge", challenge_hex); - let text = get_url(url).await?; + let text = get_url(url, false).await?; Ok(serde_xml_rs::from_str(&text)?) } @@ -151,8 +135,8 @@ fn generate_challenge_response( )?; cipher_ctx.cipher_final(&mut client_challenge_response_data)?; - //let client_challenge_response_data_hex = hex::encode(&client_challenge_response_data); - //println!("client_challenge_response_data_hex: {client_challenge_response_data_hex}"); + let client_challenge_response_data_hex = hex::encode(&client_challenge_response_data); + debug!("client_challenge_response_data_hex: {client_challenge_response_data_hex}"); // Extract ASN.1 signature from certificate let asn_signature = cert.signature(); @@ -166,8 +150,8 @@ fn generate_challenge_response( challenge_response.extend_from_slice(signature_data); challenge_response.extend_from_slice(client_secret_data); - //let challenge_response_hex = hex::encode(&challenge_response); - //println!("challenge_response_hex: {challenge_response_hex}"); + let challenge_response_hex = hex::encode(&challenge_response); + debug!("challenge_response_hex: {challenge_response_hex}"); let mut hasher = Sha256::new(); hasher.update(&challenge_response); @@ -196,7 +180,7 @@ async fn send_server_challenge_response( ) -> Result { let url = base_url.param("serverchallengeresp", server_challenge_response); - let text = get_url(url).await?; + let text = get_url(url, false).await?; Ok(serde_xml_rs::from_str(&text)?) } @@ -218,11 +202,11 @@ async fn do_challenge( cert: &X509, ) -> Result { let aes_key = generate_aes_key(salt, pin); - //let aes_hex = hex::encode(&aes_key); - //println!("aes_hex: {aes_hex}"); + let aes_hex = hex::encode(&aes_key); + debug!("aes_hex: {aes_hex}"); let client_challenge_response = get_server_challenge(base_url.clone(), &aes_key).await?; - //println!("{client_challenge_response:?}"); + debug!("{client_challenge_response:?}"); let challenge_response = generate_challenge_response( client_challenge_response.challengeresponse, @@ -241,32 +225,16 @@ async fn do_challenge( }) } -fn get_base_url(host: &String, port: u16, unique_id: String) -> url_constructor::UrlConstructor { - let mut base_url = url_constructor::UrlConstructor::new(); - base_url - .scheme("http") - .host(host) - .port(port) - .subdir("pair") - .param("uniqueid", unique_id) - .param("devicename", "roth") // TODO: what is this roth thing? - .param("updateState", "1"); - base_url -} - fn generate_pin() -> [u8; 4] { let mut pin = [0u8; 4]; { - print!("pairing pin: "); - - // TODO: reenable real RNG let mut rng = rand::rng(); for i in 0..pin.len() { // Generate ascii number 0-9 pin[i] = rng.random_range(48..58); - print!("{}", pin[i] as char); } - println!(""); + let pin_string: String = pin.iter().map(|&b| b as char).collect(); + info!("pairing pin: {}", pin_string); } pin } @@ -317,133 +285,161 @@ async fn send_client_pairing_secret( let url = base_url.param("clientpairingsecret", client_secret_hex); - let text = get_url(url).await?; + let text = get_url(url, false).await?; Ok(serde_xml_rs::from_str(&text)?) } -fn get_unique_id() -> Result { - let mut bytes = [0u8; 8]; - openssl::rand::rand_bytes(&mut bytes)?; - Ok(hex::encode(bytes)) -} +#[craft] +impl crate::config::ConfigFile { + #[craft(endpoint(status_codes(StatusCode::OK, StatusCode::INTERNAL_SERVER_ERROR)))] + pub async fn post_pair( + self: ::std::sync::Arc, + body: salvo::oapi::extract::JsonBody, + ) -> AppResult<()> { + let params = body.into_inner(); -#[salvo::oapi::endpoint(status_codes(StatusCode::OK, StatusCode::INTERNAL_SERVER_ERROR))] -pub async fn post_pair(body: salvo::oapi::extract::JsonBody) -> AppResult<()> { - let params = body.into_inner(); + let server = crate::config::Server { + host: params.host, + base_port: params.base_port, + name: params.name, + }; - let unique_id = match get_unique_id() { - Ok(u) => u, - Err(e) => { - println!("Could not generate unique id: {e}"); + let unique_id = match self.read().await.unique_id() { + Ok(u) => u, + Err(e) => { + error!("Could not get unique id: {e}"); + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); + } + }; + + let base_url = crate::common::base_url( + "http", + &server.host, + server.http_port(), + &unique_id, + "pair", + Some(vec![("devicename", "roth"), ("updateState", "1")]), + ); + + let pin = generate_pin(); + + let mut client_secret_data = [0u8; 16]; + if let Err(e) = openssl::rand::rand_bytes(&mut client_secret_data) { + error!("Could not generate client secret data: {e}"); return Err(AppError { status_code: StatusCode::INTERNAL_SERVER_ERROR, description: "Pairing failed".to_string(), }); } - }; - let base_url = get_base_url(¶ms.host, params.port, unique_id); + // Get or generate cert / private key + let (cert, private_key) = match crate::certs::get_cert_and_key() { + Ok(v) => v, + Err(e) => { + error!("Could not generate certs: {e}"); + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); + } + }; - let pin = generate_pin(); + // Convert to hex + let cert_pem = cert.to_pem().unwrap(); + let cert_hex = hex::encode(&cert_pem); - let mut client_secret_data = [0u8; 16]; - if let Err(e) = openssl::rand::rand_bytes(&mut client_secret_data) { - println!("Could not generate client secret data: {e}"); - return Err(AppError { - status_code: StatusCode::INTERNAL_SERVER_ERROR, - description: "Pairing failed".to_string(), - }); - } + // Generate salt and convert to hex + let mut salt = [0u8; 16]; + openssl::rand::rand_bytes(&mut salt).unwrap(); + let salt_hex = hex::encode(salt); - // Get or generate cert / private key - let (cert, private_key) = match crate::certs::get_cert_and_key() { - Ok(v) => v, - Err(e) => { - println!("Could not generate certs: {e}"); - return Err(AppError { - status_code: StatusCode::INTERNAL_SERVER_ERROR, - description: "Pairing failed".to_string(), - }); - } - }; - - // Convert to hex - let cert_pem = cert.to_pem().unwrap(); - let cert_hex = hex::encode(&cert_pem); - - // Generate salt and convert to hex - let mut salt = [0u8; 16]; - openssl::rand::rand_bytes(&mut salt).unwrap(); - let salt_hex = hex::encode(salt); - - // Get the server certificate and start the pairing process - // This returns once the user has submitted the pin to the - // server out of band. - let server_cert = match get_server_cert(base_url.clone(), salt_hex, cert_hex).await { - Ok(s) => s, - Err(e) => { - println!("Could not get server cert: {e}"); - return Err(AppError { - status_code: StatusCode::INTERNAL_SERVER_ERROR, - description: "Pairing failed".to_string(), - }); - } - }; - //println!("{server_cert:?}"); - - // Do the challenge response process - // This returns the pairing secret - let server_pairing_secret = - match do_challenge(base_url.clone(), &client_secret_data, pin, salt, &cert).await { + // Get the server certificate and start the pairing process + // This returns once the user has submitted the pin to the + // server out of band. + let server_cert = match get_server_cert(base_url.clone(), salt_hex, cert_hex).await { Ok(s) => s, Err(e) => { - println!("Could not do challenge: {e}"); + error!("Could not get server cert: {e}"); return Err(AppError { status_code: StatusCode::INTERNAL_SERVER_ERROR, description: "Pairing failed".to_string(), }); } }; - //println!("{server_pairing_secret:?}"); + let server_cert_hex = hex::encode(&server_cert.cert); + debug!("server_cert_hex: {server_cert_hex:?}"); - // Verify the pairing_secret signature - if let Err(e) = verify_signature( - server_pairing_secret.pairing_secret, - server_pairing_secret.signature, - server_cert.cert, - ) { - println!("Could not verify signature: {e}"); - return Err(AppError { - status_code: StatusCode::INTERNAL_SERVER_ERROR, - description: "Pairing failed".to_string(), - }); - } + // Do the challenge response process + // This returns the pairing secret + let server_pairing_secret = + match do_challenge(base_url.clone(), &client_secret_data, pin, salt, &cert).await { + Ok(s) => s, + Err(e) => { + error!("Could not do challenge: {e}"); + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); + } + }; + let server_pairing_secret_hex = hex::encode(&server_pairing_secret.pairing_secret); + debug!("server_pairing_secret_hex: {server_pairing_secret_hex:?}"); - let client_pairing_secret_response = - match send_client_pairing_secret(base_url.clone(), &client_secret_data, &private_key).await - { - Ok(p) => p, + // Verify the pairing_secret signature + if let Err(e) = verify_signature( + server_pairing_secret.pairing_secret, + server_pairing_secret.signature, + server_cert.cert, + ) { + error!("Could not verify signature: {e}"); + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); + } + + let client_pairing_secret_response = + match send_client_pairing_secret(base_url.clone(), &client_secret_data, &private_key) + .await + { + Ok(p) => p, + Err(e) => { + error!("Could not send client pairing secret: {e}"); + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); + } + }; + + if client_pairing_secret_response.paired != 1 { + error!("Failed to pair with server"); + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); + } else { + info!( + "Paired with server {}:{} successfully!", + server.host, server.base_port + ); + } + + match self.write().await.add_server(server) { + Ok(_) => (), Err(e) => { - println!("Could not send client pairing secret: {e}"); + error!("Could not write to config file: {e}"); return Err(AppError { status_code: StatusCode::INTERNAL_SERVER_ERROR, description: "Pairing failed".to_string(), }); } - }; + } - if client_pairing_secret_response.paired != 1 { - println!("Failed to pair with server"); - return Err(AppError { - status_code: StatusCode::INTERNAL_SERVER_ERROR, - description: "Pairing failed".to_string(), - }); - } else { - let host = ¶ms.host; - let port = params.port; - println!("Paired with server {host}:{port} successfully!"); + Ok(()) } - - Ok(()) }