use anyhow::{anyhow, Context, Result}; use axum::extract::State; use axum::http::{Response, StatusCode}; use axum::response::IntoResponse; use serde::{Deserialize, Serialize}; use std::future::IntoFuture; use std::net::{SocketAddr, ToSocketAddrs}; use std::path::PathBuf; use std::time::Duration; use tokio::fs::read_to_string; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::pin; use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerifier}; use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, UnixTime}; use tokio_rustls::rustls::{ClientConfig, DigitallySignedStruct}; use tokio_rustls::{rustls, TlsConnector}; use tracing::error; use tracing::info; use tracing::info_span; use tracing::Instrument; use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::EnvFilter; use wtransport::endpoint::IncomingSession; use wtransport::Endpoint; use wtransport::Identity; use wtransport::ServerConfig; use lazy_static::lazy_static; use std::collections::HashMap; use std::sync::{Arc, Mutex}; type GlobalMap = Mutex>; lazy_static! { static ref DATA_MAP: GlobalMap = Mutex::new(HashMap::new()); } #[derive(Debug)] struct NoCertificateVerification; impl ServerCertVerifier for NoCertificateVerification { fn verify_server_cert( &self, _end_entity: &CertificateDer<'_>, _intermediates: &[CertificateDer<'_>], _server_name: &ServerName<'_>, _ocsp: &[u8], _now: UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ rustls::SignatureScheme::RSA_PKCS1_SHA1, rustls::SignatureScheme::ECDSA_SHA1_Legacy, rustls::SignatureScheme::RSA_PKCS1_SHA256, rustls::SignatureScheme::ECDSA_NISTP256_SHA256, rustls::SignatureScheme::RSA_PKCS1_SHA384, rustls::SignatureScheme::ECDSA_NISTP384_SHA384, rustls::SignatureScheme::RSA_PKCS1_SHA512, rustls::SignatureScheme::ECDSA_NISTP521_SHA512, rustls::SignatureScheme::RSA_PSS_SHA256, rustls::SignatureScheme::RSA_PSS_SHA384, rustls::SignatureScheme::RSA_PSS_SHA512, rustls::SignatureScheme::ED25519, rustls::SignatureScheme::ED448, ] } } #[derive(Clone, Deserialize, Serialize)] struct GuiConfig { proxy_url: Option, cert_hash: Option>, } #[derive(Clone, Deserialize)] struct Config { proxy_listen_address: SocketAddr, http_listen_address: SocketAddr, cert_path: String, key_path: String, #[serde(default)] serve_https: bool, mumble_server_url: String, gui_path: PathBuf, gui: GuiConfig, } //async fn serve_index_html_with_config(State(config): State) -> impl IntoResponse { async fn serve_index_html_with_config(State(config): State) -> impl IntoResponse { // Load the HTML file let html = match read_to_string(config.gui_path.join("index.html")).await { Ok(content) => content, Err(_) => return (StatusCode::NOT_FOUND, "File not found").into_response(), }; // Insert the script tag with configuration let modified_html = html.replace( "", &format!( "\n", serde_json::to_string(&config.gui).unwrap(), ), ); // Create a response with the modified HTML Response::builder() .status(StatusCode::OK) .header("Content-Type", "text/html") .body(modified_html) .unwrap() .into_response() } fn configure_tls(config: &Config) -> Result { // Thanks perplexity! use rustls_pemfile::{certs, pkcs8_private_keys}; use std::fs::File; use std::io::BufReader; // Create a new ServerConfig with no client authentication //(rustls::server::NoClientAuth::new()); // Read the certificate file let cert_file = File::open(&config.cert_path)?; let mut cert_reader = BufReader::new(cert_file); let cert_chain = certs(&mut cert_reader).collect::>()?; // Read the private key file let key_file = File::open(&config.key_path)?; let mut key_reader = BufReader::new(key_file); let key = pkcs8_private_keys(&mut key_reader) .next() .ok_or(anyhow!("no keys in key.pem"))??; // Set the certificate chain and private key let config = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(cert_chain, key.into())?; Ok(config) } #[tokio::main] async fn main() -> Result<()> { init_logging(); let proxy_config: Config = toml::from_str( &read_to_string("./config.toml") .await .context("reading config.toml (try making a copy of config.toml.example)")?, )?; let mumble_server_addr = proxy_config .mumble_server_url .to_socket_addrs() .context(format!( "parsing mumble_server_url={}", proxy_config.mumble_server_url ))? .next() .ok_or(anyhow!( "no socket addrs in mumble_server_url={}", proxy_config.mumble_server_url ))?; // Setup HTTP Server //let http = axum::Router::new().route("/", axum::routing::get(serve_gui)); let app = axum::Router::new() .route("/", axum::routing::get(serve_index_html_with_config)) .fallback_service(tower_http::services::ServeDir::new(&proxy_config.gui_path)) .with_state(proxy_config.clone()); if proxy_config.serve_https { tokio::spawn( axum_server::bind_rustls( proxy_config.http_listen_address, axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(configure_tls( &proxy_config, )?)), ) .serve(app.into_make_service()) .into_future(), ); } else { tokio::spawn( axum_server::bind(proxy_config.http_listen_address) .serve(app.into_make_service()) .into_future(), ); } // Setup WebTransport proxy listener let identity = Identity::load_pemfiles(proxy_config.cert_path, proxy_config.key_path).await?; let config = ServerConfig::builder() .with_bind_address(proxy_config.proxy_listen_address) .with_identity(&identity) .keep_alive_interval(Some(Duration::from_secs(20))) .build(); let server = Endpoint::server(config)?; info!("Server ready!"); for id in 0.. { let incoming_session = server.accept().await; tokio::spawn( handle_connection(incoming_session, id, mumble_server_addr) .instrument(info_span!("Connection", id)), ); } Ok(()) } async fn handle_connection( incoming_session: IncomingSession, id: usize, mumble_server_address: SocketAddr, ) { // Wrapper to handle connection establishment failures if let Err(e) = handle_connection_impl(incoming_session, id, mumble_server_address).await { error!("{:?}", e); } } async fn handle_connection_impl( incoming_session: IncomingSession, id: usize, mumble_server_address: SocketAddr, ) -> Result<()> { info!("Waiting for session request..."); let session_request = incoming_session.await?; info!( "New session: Authority: '{}', Path: '{}'", session_request.authority(), session_request.path() ); let connection = session_request.accept().await?; let stream = connection.accept_bi().await?; info!("Connecting to corresponding Mumble server..."); let config = ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(NoCertificateVerification)) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(config)); let server_tcp = TcpStream::connect(mumble_server_address).await?; let server_stream = connector .connect("example.com".try_into()?, server_tcp) .await?; let (read_server, write_server) = tokio::io::split(server_stream); info!("Connected to Mumble Server!"); // Store connection in global map to prevent it getting dropped DATA_MAP.lock().unwrap().insert(id, connection); info!("Spawing jobs..."); // Spawn tasks to handle transmitting data between the WebTransport client and Mumble TCP Server tokio::spawn( handle_client_to_server(stream.1, write_server) .instrument(info_span!("Handler", "Client to server")), ); tokio::spawn( handle_server_to_client(stream.0, read_server) .instrument(info_span!("Handler", "Server to client")), ); info!("Spawned jobs."); Ok(()) } async fn handle_client_to_server( client_stream: wtransport::RecvStream, server_stream: impl AsyncWrite, ) { let result = client_to_server_loop(client_stream, server_stream).await; error!("{:?}", result); } async fn client_to_server_loop( mut client_stream: wtransport::RecvStream, server_stream: impl AsyncWrite, ) -> Result<()> { let mut buffer = vec![0; 65536].into_boxed_slice(); pin!(server_stream); loop { let bytes_read = match client_stream.read(&mut buffer).await? { Some(bytes_read) => bytes_read, None => break Ok(()), }; server_stream.write_all(&buffer[..bytes_read]).await?; server_stream.flush().await?; } } async fn handle_server_to_client( client_stream: wtransport::SendStream, server_stream: impl AsyncRead, ) { let result = server_to_client_loop(client_stream, server_stream).await; error!("{:?}", result); } async fn server_to_client_loop( mut client_stream: wtransport::SendStream, server_stream: impl AsyncRead, ) -> Result<()> { let mut buffer = vec![0; 65536].into_boxed_slice(); pin!(server_stream); loop { let bytes_read = server_stream.read(&mut buffer).await?; client_stream.write_all(&buffer[..bytes_read]).await?; } } fn init_logging() { let env_filter = EnvFilter::builder() .with_default_directive(LevelFilter::INFO.into()) .from_env_lossy(); tracing_subscriber::fmt() .with_target(true) .with_level(true) .with_env_filter(env_filter) .init(); }