diff --git a/gamestream-webtransport-proxy/src/common.rs b/gamestream-webtransport-proxy/src/common.rs new file mode 100644 index 0000000..620b9c4 --- /dev/null +++ b/gamestream-webtransport-proxy/src/common.rs @@ -0,0 +1,45 @@ +use serde::Serialize; +//use salvo::http::{StatusCode, StatusError}; +use salvo::oapi::{self, EndpointOutRegister, ToSchema}; +use salvo::prelude::*; + +#[derive(Debug, Serialize, ToSchema)] +struct ApiError { + pub description: String, +} + +#[derive(Debug)] +pub struct AppError { + pub status_code: StatusCode, + pub description: String, +} + +pub type AppResult = anyhow::Result; + +#[async_trait] +impl Writer for AppError { + async fn write(self, _: &mut Request, _depot: &mut Depot, res: &mut Response) { + res.status_code = Some(self.status_code); + Json(ApiError { + description: self.description, + }) + .render(res) + } +} + +impl EndpointOutRegister for AppError { + fn register(components: &mut oapi::Components, operation: &mut oapi::Operation) { + let errors = vec![ + StatusError::not_found(), + StatusError::request_timeout(), + StatusError::internal_server_error(), + ]; + for StatusError { code, brief, .. } in errors { + operation.responses.insert( + code.as_str(), + oapi::Response::new(brief) + .add_content("application/json", ApiError::to_schema(components)), + ) + } + } +} diff --git a/gamestream-webtransport-proxy/src/main.rs b/gamestream-webtransport-proxy/src/main.rs index 1fac51d..faeb1a9 100644 --- a/gamestream-webtransport-proxy/src/main.rs +++ b/gamestream-webtransport-proxy/src/main.rs @@ -3,13 +3,14 @@ use std::ffi::CString; use moonlight_common_c_sys::{ _SERVER_INFORMATION, _STREAM_CONFIGURATION, COLOR_RANGE_LIMITED, COLORSPACE_REC_601, - CONNECTION_LISTENER_CALLBACKS, ENCFLG_NONE, PCONNECTION_LISTENER_CALLBACKS, SCM_H264, - STREAM_CFG_LOCAL, VIDEO_FORMAT_H264, + CONNECTION_LISTENER_CALLBACKS, ENCFLG_NONE, SCM_H264, STREAM_CFG_LOCAL, VIDEO_FORMAT_H264, }; mod certs; +mod common; mod pair; +#[allow(unused)] fn get_server_info() -> _SERVER_INFORMATION { _SERVER_INFORMATION { // TODO: these all leak @@ -23,6 +24,7 @@ fn get_server_info() -> _SERVER_INFORMATION { } } +#[allow(unused)] fn get_stream_config() -> _STREAM_CONFIGURATION { let mut remote_input_aes_key_u8: [u8; 16] = [0; 16]; let remote_input_aes_iv: [i8; 16] = [0; 16]; @@ -52,9 +54,11 @@ fn get_stream_config() -> _STREAM_CONFIGURATION { } unsafe extern "C" { - fn printf(format: *const i8, ...) -> (); + #[allow(unused)] + fn printf(format: *const i8, ...); } +#[allow(unused)] fn get_listener_callbacks() -> CONNECTION_LISTENER_CALLBACKS { CONNECTION_LISTENER_CALLBACKS { stageStarting: None, @@ -73,36 +77,38 @@ fn get_listener_callbacks() -> CONNECTION_LISTENER_CALLBACKS { } } -fn barmain() { - //let server_info = moonlight_common_c_sys::LiInitializeServerInformation(); - let mut server_info = get_server_info(); - let mut stream_config = get_stream_config(); - let mut listener_callbacks = get_listener_callbacks(); - let mut ret = 0; - unsafe { - ret = moonlight_common_c_sys::LiStartConnection( - &mut server_info, - &mut stream_config, - &mut listener_callbacks, - std::ptr::null_mut(), - std::ptr::null_mut(), - std::ptr::null_mut(), - 0, - std::ptr::null_mut(), - 0, - ); - } - - println!("{ret}"); - - loop {} - - //println!("Hello, world!"); -} +//fn barmain() { +// //let server_info = moonlight_common_c_sys::LiInitializeServerInformation(); +// let mut server_info = get_server_info(); +// let mut stream_config = get_stream_config(); +// let mut listener_callbacks = get_listener_callbacks(); +// let ret; +// unsafe { +// ret = moonlight_common_c_sys::LiStartConnection( +// &mut server_info, +// &mut stream_config, +// &mut listener_callbacks, +// std::ptr::null_mut(), +// std::ptr::null_mut(), +// std::ptr::null_mut(), +// 0, +// std::ptr::null_mut(), +// 0, +// ); +// } +// +// println!("{ret}"); +// +// loop {} +// +// //println!("Hello, world!"); +//} #[tokio::main] async fn main() { - let mut router = Router::new().push(Router::with_path("pair").post(pair::post_pair)); + let router = Router::new() + .push(Router::with_path("pair").post(pair::post_pair)) + .push(Router::with_path("apps").get(apps::get_apps)); let doc = OpenApi::new("test api", "0.0.1").merge_router(&router); let router = router .unshift(doc.into_router("/api-doc/openapi.json")) diff --git a/gamestream-webtransport-proxy/src/pair.rs b/gamestream-webtransport-proxy/src/pair.rs index b90ea80..ec3eceb 100644 --- a/gamestream-webtransport-proxy/src/pair.rs +++ b/gamestream-webtransport-proxy/src/pair.rs @@ -6,10 +6,13 @@ use rand::Rng; use salvo::prelude::*; use serde::{Deserialize, Serialize}; +use crate::common::{AppError, AppResult}; + #[derive(Debug, Deserialize, ToSchema)] struct PostPairParams { host: String, port: u16, + #[allow(unused)] pair_endpoint: Option, } @@ -20,6 +23,7 @@ struct PostPairReturn { #[derive(Debug, Deserialize)] struct ServerCertResponse { + #[allow(unused)] paired: i32, plaincert: String, } @@ -31,6 +35,7 @@ struct ClientChallengeResponse { #[derive(Debug, Deserialize)] struct ServerChallengeResponseResponse { + #[allow(unused)] paired: i32, pairingsecret: String, } @@ -146,7 +151,7 @@ fn generate_challenge_response( )?; cipher_ctx.cipher_final(&mut client_challenge_response_data)?; - let client_challenge_response_data_hex = hex::encode(&client_challenge_response_data); + //let client_challenge_response_data_hex = hex::encode(&client_challenge_response_data); //println!("client_challenge_response_data_hex: {client_challenge_response_data_hex}"); // Extract ASN.1 signature from certificate @@ -161,7 +166,7 @@ fn generate_challenge_response( challenge_response.extend_from_slice(signature_data); challenge_response.extend_from_slice(client_secret_data); - let challenge_response_hex = hex::encode(&challenge_response); + //let challenge_response_hex = hex::encode(&challenge_response); //println!("challenge_response_hex: {challenge_response_hex}"); let mut hasher = Sha256::new(); @@ -213,7 +218,7 @@ async fn do_challenge( cert: &X509, ) -> Result { let aes_key = generate_aes_key(salt, pin); - let aes_hex = hex::encode(&aes_key); + //let aes_hex = hex::encode(&aes_key); //println!("aes_hex: {aes_hex}"); let client_challenge_response = get_server_challenge(base_url.clone(), &aes_key).await?; @@ -323,15 +328,18 @@ fn get_unique_id() -> Result { Ok(hex::encode(bytes)) } -#[salvo::oapi::endpoint] -pub async fn post_pair(body: salvo::oapi::extract::JsonBody) -> StatusCode { +#[salvo::oapi::endpoint(status_codes(StatusCode::OK, StatusCode::INTERNAL_SERVER_ERROR))] +pub async fn post_pair(body: salvo::oapi::extract::JsonBody) -> AppResult<()> { let params = body.into_inner(); let unique_id = match get_unique_id() { Ok(u) => u, Err(e) => { println!("Could not generate unique id: {e}"); - return StatusCode::INTERNAL_SERVER_ERROR; + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); } }; @@ -342,7 +350,10 @@ pub async fn post_pair(body: salvo::oapi::extract::JsonBody) -> let mut client_secret_data = [0u8; 16]; if let Err(e) = openssl::rand::rand_bytes(&mut client_secret_data) { println!("Could not generate client secret data: {e}"); - return StatusCode::INTERNAL_SERVER_ERROR; + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); } // Get or generate cert / private key @@ -350,7 +361,10 @@ pub async fn post_pair(body: salvo::oapi::extract::JsonBody) -> Ok(v) => v, Err(e) => { println!("Could not generate certs: {e}"); - return StatusCode::INTERNAL_SERVER_ERROR; + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); } }; @@ -370,7 +384,10 @@ pub async fn post_pair(body: salvo::oapi::extract::JsonBody) -> Ok(s) => s, Err(e) => { println!("Could not get server cert: {e}"); - return StatusCode::INTERNAL_SERVER_ERROR; + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); } }; //println!("{server_cert:?}"); @@ -382,7 +399,10 @@ pub async fn post_pair(body: salvo::oapi::extract::JsonBody) -> Ok(s) => s, Err(e) => { println!("Could not do challenge: {e}"); - return StatusCode::INTERNAL_SERVER_ERROR; + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); } }; //println!("{server_pairing_secret:?}"); @@ -394,7 +414,10 @@ pub async fn post_pair(body: salvo::oapi::extract::JsonBody) -> server_cert.cert, ) { println!("Could not verify signature: {e}"); - return StatusCode::INTERNAL_SERVER_ERROR; + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); } let client_pairing_secret_response = @@ -403,18 +426,24 @@ pub async fn post_pair(body: salvo::oapi::extract::JsonBody) -> Ok(p) => p, Err(e) => { println!("Could not send client pairing secret: {e}"); - return StatusCode::INTERNAL_SERVER_ERROR; + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); } }; if client_pairing_secret_response.paired != 1 { println!("Failed to pair with server"); - return StatusCode::INTERNAL_SERVER_ERROR; + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Pairing failed".to_string(), + }); } else { let host = ¶ms.host; let port = params.port; println!("Paired with server {host}:{port} successfully!"); } - StatusCode::OK + Ok(()) }