use anyhow::Result; use std::net::ToSocketAddrs; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::pin; use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerifier}; use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, UnixTime}; use tokio_rustls::rustls::{ClientConfig, DigitallySignedStruct}; use tokio_rustls::{rustls, TlsConnector}; use tracing::error; use tracing::info; use tracing::info_span; use tracing::Instrument; use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::EnvFilter; use wtransport::endpoint::IncomingSession; use wtransport::Endpoint; use wtransport::Identity; use wtransport::ServerConfig; use lazy_static::lazy_static; use std::collections::HashMap; use std::sync::{Arc, Mutex}; type GlobalMap = Mutex>; lazy_static! { static ref DATA_MAP: GlobalMap = Mutex::new(HashMap::new()); } #[derive(Debug)] struct NoCertificateVerification; impl ServerCertVerifier for NoCertificateVerification { fn verify_server_cert( &self, _end_entity: &CertificateDer<'_>, _intermediates: &[CertificateDer<'_>], _server_name: &ServerName<'_>, _ocsp: &[u8], _now: UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ rustls::SignatureScheme::RSA_PKCS1_SHA1, rustls::SignatureScheme::ECDSA_SHA1_Legacy, rustls::SignatureScheme::RSA_PKCS1_SHA256, rustls::SignatureScheme::ECDSA_NISTP256_SHA256, rustls::SignatureScheme::RSA_PKCS1_SHA384, rustls::SignatureScheme::ECDSA_NISTP384_SHA384, rustls::SignatureScheme::RSA_PKCS1_SHA512, rustls::SignatureScheme::ECDSA_NISTP521_SHA512, rustls::SignatureScheme::RSA_PSS_SHA256, rustls::SignatureScheme::RSA_PSS_SHA384, rustls::SignatureScheme::RSA_PSS_SHA512, rustls::SignatureScheme::ED25519, rustls::SignatureScheme::ED448, ] } } #[tokio::main] async fn main() -> Result<()> { init_logging(); let config = ServerConfig::builder() .with_bind_default(4433) .with_identity(&Identity::load_pemfiles("cert.pem", "key.pem").await?) .keep_alive_interval(Some(Duration::from_secs(20))) .build(); let server = Endpoint::server(config)?; info!("Server ready!"); for id in 0.. { let incoming_session = server.accept().await; tokio::spawn( handle_connection(incoming_session, id).instrument(info_span!("Connection", id)), ); } Ok(()) } async fn handle_connection(incoming_session: IncomingSession, id: usize) { // Wrapper to handle connection establishment failures if let Err(e) = handle_connection_impl(incoming_session, id).await { error!("{:?}", e); } } async fn handle_connection_impl(incoming_session: IncomingSession, id: usize) -> Result<()> { info!("Waiting for session request..."); let session_request = incoming_session.await?; info!( "New session: Authority: '{}', Path: '{}'", session_request.authority(), session_request.path() ); let connection = session_request.accept().await?; let stream = connection.accept_bi().await?; info!("Connecting to corresponding Mumble server..."); let config = ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(NoCertificateVerification)) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(config)); let addr = env!("WEBTRANSPORT_PROXY_MUMBLE_SERVER_URL") .to_string() .to_socket_addrs()? .next() .unwrap(); let server_tcp = TcpStream::connect(addr).await?; let server_stream = connector .connect("ohea.xyz".try_into()?, server_tcp) .await?; let (read_server, write_server) = tokio::io::split(server_stream); info!("Connected to Mumble Server!"); // Store connection in global map to prevent it getting dropped DATA_MAP.lock().unwrap().insert(id, connection); info!("Spawing jobs..."); // Spawn tasks to handle transmitting data between the WebTransport client and Mumble TCP Server tokio::spawn( handle_client_to_server(stream.1, write_server) .instrument(info_span!("Handler", "Client to server")), ); tokio::spawn( handle_server_to_client(stream.0, read_server) .instrument(info_span!("Handler", "Server to client")), ); info!("Spawned jobs."); Ok(()) } async fn handle_client_to_server( client_stream: wtransport::RecvStream, server_stream: impl AsyncWrite, ) { let result = client_to_server_loop(client_stream, server_stream).await; error!("{:?}", result); } async fn client_to_server_loop( mut client_stream: wtransport::RecvStream, server_stream: impl AsyncWrite, ) -> Result<()> { let mut buffer = vec![0; 65536].into_boxed_slice(); pin!(server_stream); loop { let bytes_read = match client_stream.read(&mut buffer).await? { Some(bytes_read) => bytes_read, None => break Ok(()), }; server_stream.write_all(&buffer[..bytes_read]).await?; server_stream.flush().await?; } } async fn handle_server_to_client( client_stream: wtransport::SendStream, server_stream: impl AsyncRead, ) { let result = server_to_client_loop(client_stream, server_stream).await; error!("{:?}", result); } async fn server_to_client_loop( mut client_stream: wtransport::SendStream, server_stream: impl AsyncRead, ) -> Result<()> { let mut buffer = vec![0; 65536].into_boxed_slice(); pin!(server_stream); loop { let bytes_read = server_stream.read(&mut buffer).await?; client_stream.write_all(&buffer[..bytes_read]).await?; } } fn init_logging() { let env_filter = EnvFilter::builder() .with_default_directive(LevelFilter::INFO.into()) .from_env_lossy(); tracing_subscriber::fmt() .with_target(true) .with_level(true) .with_env_filter(env_filter) .init(); }