diff --git a/Cargo.lock b/Cargo.lock index 9baac20..27b3f3d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -201,6 +201,27 @@ dependencies = [ "cc", ] +[[package]] +name = "directories" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16f5094c54661b38d03bd7e50df373292118db60b585c08a411c6d840017fe7d" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.59.0", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -287,6 +308,7 @@ version = "0.1.0" dependencies = [ "anyhow", "axum", + "directories", "getrandom 0.3.3", "hex", "libc", @@ -617,6 +639,16 @@ dependencies = [ "windows-targets 0.53.2", ] +[[package]] +name = "libredox" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1580801010e535496706ba011c15f8532df6b42297d2e471fec38ceadd8c0638" +dependencies = [ + "bitflags", + "libc", +] + [[package]] name = "litemap" version = "0.8.0" @@ -754,6 +786,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "parking_lot" version = "0.12.3" @@ -945,6 +983,17 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_users" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b" +dependencies = [ + "getrandom 0.2.16", + "libredox", + "thiserror 2.0.12", +] + [[package]] name = "regex" version = "1.11.1" diff --git a/gamestream-webtransport-proxy/Cargo.toml b/gamestream-webtransport-proxy/Cargo.toml index de1140c..d52fe84 100644 --- a/gamestream-webtransport-proxy/Cargo.toml +++ b/gamestream-webtransport-proxy/Cargo.toml @@ -6,6 +6,7 @@ edition = "2024" [dependencies] anyhow = "1.0.98" axum = "0.8.4" +directories = "6.0.0" getrandom = { version = "0.3.3", features = ["std"] } hex = "0.4.3" libc = "0.2.174" diff --git a/gamestream-webtransport-proxy/src/pair.rs b/gamestream-webtransport-proxy/src/pair.rs index 0809bcf..f4b220b 100644 --- a/gamestream-webtransport-proxy/src/pair.rs +++ b/gamestream-webtransport-proxy/src/pair.rs @@ -1,3 +1,7 @@ +use std::fs; +use std::io::Write; + +use axum::Json; use axum::extract::Path; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; @@ -8,10 +12,9 @@ use serde::{Deserialize, Serialize}; use openssl::pkey::{PKey, Private}; use openssl::rsa::Rsa; -use openssl::x509::{self, X509}; +use openssl::x509::X509; use anyhow::Result; -use url_constructor::UrlConstructor; #[derive(Debug, Deserialize)] struct ServerCertResponse { @@ -30,6 +33,11 @@ pub struct ServerChallengeResponseResponse { pairingsecret: String, } +#[derive(Debug, Deserialize)] +pub struct ClientPairingSecretResponse { + paired: i32, +} + #[derive(Debug)] struct ServerCert { cert: Vec, @@ -41,6 +49,11 @@ pub struct ServerPairingSecret { signature: Vec, } +#[derive(Debug, Serialize)] +struct PairResult { + paired: bool, +} + fn get_cert_and_private_key() -> Result<(X509, PKey)> { let rsa = Rsa::generate(2048)?; let key_pair = PKey::from_rsa(rsa)?; @@ -74,6 +87,49 @@ fn get_cert_and_private_key() -> Result<(X509, PKey)> { Ok((cert, key_pair)) } +fn save_cert_and_key_to_disk( + cert: X509, + key: PKey, + host: &String, + port: u16, +) -> Result<()> { + let project_dirs = + directories::ProjectDirs::from("xyz", "ohea", "gamestream-webtransport-proxy") + .ok_or(anyhow::anyhow!("Could not get project dirs"))?; + let data_dir = project_dirs.data_dir(); + let cert_dir = data_dir.join("certs"); + fs::create_dir_all(&cert_dir)?; + + let cert_filepath = cert_dir.join(format!("{host}_{port}_cert")); + let key_filepath = cert_dir.join(format!("{host}_{port}_key")); + + let mut cert_file_builder = std::fs::OpenOptions::new(); + cert_file_builder.create(true); + cert_file_builder.truncate(true); + cert_file_builder.write(true); + + let mut key_file_builder = std::fs::OpenOptions::new(); + key_file_builder.create(true); + key_file_builder.truncate(true); + key_file_builder.write(true); + + #[cfg(target_family = "unix")] + { + use std::os::unix::fs::OpenOptionsExt; + + key_file_builder.mode(0o600); + cert_file_builder.mode(0o600); + } + + let mut cert_file = cert_file_builder.open(&cert_filepath)?; + let mut key_file = key_file_builder.open(&key_filepath)?; + + cert_file.write_all(&cert.to_pem()?)?; + key_file.write_all(&key.private_key_to_pem_pkcs8()?)?; + + Ok(()) +} + async fn get_url(base_url: &mut url_constructor::UrlConstructor) -> Result { let mut uuidv2 = [0u8; 16]; openssl::rand::rand_bytes(&mut uuidv2)?; @@ -260,14 +316,14 @@ async fn do_challenge( } pub async fn get_base_url( - host: String, + host: &String, port: u16, unique_id: String, ) -> url_constructor::UrlConstructor { let mut base_url = url_constructor::UrlConstructor::new(); base_url .scheme("http") - .host(&host) + .host(host) .port(port) .subdir("pair") .param("uniqueid", unique_id) @@ -284,10 +340,10 @@ pub async fn generate_pin() -> [u8; 4] { // TODO: reenable real RNG let mut rng = rand::rng(); for i in 0..pin.len() { - pin[i] = rng.random_range(48..58); // Generate ascii number 0-9 + // Generate ascii number 0-9 + pin[i] = rng.random_range(48..58); print!("{}", pin[i] as char); } - // Print as a four-digit, zero-padded integer println!(""); } pin @@ -328,7 +384,7 @@ pub async fn send_client_pairing_secret( mut base_url: url_constructor::UrlConstructor, client_secret_data: &[u8; 16], private_key: &PKey, -) -> Result<()> { +) -> Result { let signature = create_signature(client_secret_data, private_key).await?; let mut client_secret = Vec::with_capacity(client_secret_data.len() + signature.len()); @@ -340,9 +396,8 @@ pub async fn send_client_pairing_secret( let url = base_url.param("clientpairingsecret", client_secret_hex); let text = get_url(url).await?; - println!("{text:?}"); - Ok(()) + Ok(serde_xml_rs::from_str(&text)?) } async fn get_unique_id() -> Result { @@ -358,7 +413,7 @@ pub async fn get_pair(Path((host, port)): Path<(String, u16)>) -> Response { let unique_id = get_unique_id().await.unwrap(); - let base_url = get_base_url(host, port, unique_id).await; + let base_url = get_base_url(&host, port, unique_id).await; let pin = generate_pin().await; @@ -422,12 +477,32 @@ pub async fn get_pair(Path((host, port)): Path<(String, u16)>) -> Response { return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } - if let Err(e) = - send_client_pairing_secret(base_url.clone(), &client_secret_data, &private_key).await - { - println!("Could not send client pairing secret: {e}"); + let client_pairing_secret_response = + match send_client_pairing_secret(base_url.clone(), &client_secret_data, &private_key).await + { + Ok(p) => p, + Err(e) => { + println!("Could not send client pairing secret: {e}"); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; + + let pairing_result = PairResult { + paired: client_pairing_secret_response.paired == 1, + }; + + if pairing_result.paired { + println!("Successfully paired to server {host}:{port}"); + } else { + println!("Failed to pair with server {host}:{port}"); + return Json(pairing_result).into_response(); + } + + // Save certificate to disk so it can be used for subsequent connections + if let Err(e) = save_cert_and_key_to_disk(cert, private_key, &host, port) { + println!("Could not save cert and key to disk: {e}"); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); }; - "test".into_response() + Json(pairing_result).into_response() }