wip salvo server

This commit is contained in:
2024-11-11 17:14:53 -07:00
parent 80aedc7269
commit 3c6a436690
7 changed files with 1030 additions and 722 deletions
+231 -260
View File
@@ -1,14 +1,15 @@
use axum::extract::State;
use axum::http::{Response, StatusCode};
use axum::response::IntoResponse;
use color_eyre::eyre::{anyhow, Context, Error, Result};
use mumble_web2_common::GuiConfig;
use once_cell::sync::OnceCell;
use salvo::conn::rustls::{Keycert, RustlsConfig};
use salvo::logging::Logger;
use salvo::prelude::*;
use salvo::proto::quic::BidiStream;
use serde::Deserialize;
use std::future::IntoFuture;
use std::net::{SocketAddr, ToSocketAddrs};
use std::path::PathBuf;
use std::time::Duration;
use tokio::fs::read_to_string;
use std::sync::Arc;
use tokio::fs;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::pin;
@@ -16,25 +17,229 @@ use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVe
use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use tokio_rustls::rustls::{ClientConfig, DigitallySignedStruct};
use tokio_rustls::{rustls, TlsConnector};
use tracing::error;
use tracing::info;
use tracing::info_span;
use tracing::Instrument;
use tracing::{error, instrument};
use tracing_subscriber::filter::LevelFilter;
use tracing_subscriber::EnvFilter;
use wtransport::endpoint::IncomingSession;
use wtransport::Endpoint;
use wtransport::Identity;
use wtransport::ServerConfig;
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
#[derive(Clone, Deserialize)]
struct Config {
https_listen_address: SocketAddr,
http_listen_address: Option<SocketAddr>,
cert_path: PathBuf,
key_path: PathBuf,
mumble_server_url: String,
mumble_server_address: Option<SocketAddr>,
gui_path: PathBuf,
gui: GuiConfig,
}
type GlobalMap = Mutex<HashMap<usize, wtransport::Connection>>;
static CONFIG: OnceCell<Config> = OnceCell::new();
lazy_static! {
static ref DATA_MAP: GlobalMap = Mutex::new(HashMap::new());
#[handler]
#[instrument]
async fn serve_gui_index_html(req: &Request, res: &mut Response) {
let config = CONFIG.get().unwrap();
// Load the HTML file
let path = config.gui_path.join("index.html");
let html = match fs::read_to_string(&path).await {
Ok(content) => content,
Err(err) => {
error!("could not load {}: {:?}", path.display(), err);
res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
return;
}
};
// Insert the script tag with configuration
let modified_html = html.replace(
"</head>",
&format!(
"<script>window.config = {}</script>\n</head>",
serde_json::to_string(&config.gui).unwrap(),
),
);
res.render(Text::Html(modified_html));
}
#[handler]
async fn redirect_to_gui(res: &mut Response) {
res.render(Redirect::permanent("/gui"));
}
async fn init_config() -> Result<()> {
let mut config: Config = toml::from_str(
&fs::read_to_string("./config.toml")
.await
.context("reading config.toml (try making a copy of config.toml.example)")?,
)?;
let mumble_server_addr = config
.mumble_server_url
.to_socket_addrs()
.context(format!(
"parsing mumble_server_url={}",
config.mumble_server_url
))?
.next()
.ok_or(anyhow!(
"no socket addrs in mumble_server_url={}",
config.mumble_server_url
))?;
config.mumble_server_address = Some(mumble_server_addr);
CONFIG
.set(config)
.map_err(|_| anyhow!("config already initialized"))?;
Ok(())
}
#[tokio::main]
async fn main() -> Result<()> {
init_logging();
init_config().await?;
let config = CONFIG.get().unwrap();
// Server routing
let router = Router::new()
.get(redirect_to_gui)
.push(Router::with_path("/proxy").goal(connect_proxy))
.push(Router::with_path("/gui").get(serve_gui_index_html))
.push(Router::with_path("/gui/<*+rest>").get(StaticDir::new(config.gui_path.clone())))
.hoop(Logger::new());
// Read server certs
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
.map_err(|e| anyhow!("could not install crypto provider {e:?}"))?;
let cert = fs::read(&config.cert_path)
.await
.context(format!("reading cert {}", config.cert_path.display()))?;
let key = fs::read(&config.key_path)
.await
.context(format!("reading key {}", config.key_path.display()))?;
let rustls_config = RustlsConfig::new(Keycert::new().cert(cert.as_slice()).key(key.as_slice()));
// Create http listeners
let http_listener = config.http_listen_address.map(TcpListener::new);
let https_listener =
TcpListener::new(config.https_listen_address).rustls(rustls_config.clone());
let http3_listener = QuinnListener::new(rustls_config, config.https_listen_address);
// Start server
match (http_listener, https_listener, http3_listener) {
(Some(a), b, c) => {
let accepter = a.join(b).join(c).bind().await;
Server::new(accepter).serve(router).await;
}
(None, b, c) => {
let accepter = b.join(c).bind().await;
Server::new(accepter).serve(router).await;
}
}
Ok(())
}
#[handler]
#[instrument]
async fn connect_proxy(req: &mut Request, res: &mut Response) {
info!("received proxy request");
let mumble_server_address = CONFIG.get().unwrap().mumble_server_address.unwrap();
let wt = match req.web_transport_mut().await {
Ok(wt) => wt,
Err(err) => {
res.status_code(StatusCode::BAD_REQUEST);
res.render(format!("error with webtransport: {err:?}"));
return;
}
};
info!("got webtransport for connection");
use salvo::webtransport::server::AcceptedBi;
let (id, bi) = match wt.accept_bi().await {
Ok(Some(AcceptedBi::BidiStream(id, bi))) => (id, bi),
Ok(Some(AcceptedBi::Request(req, _))) => {
res.status_code(StatusCode::BAD_REQUEST);
res.render(format!(
"expected webtransport stream but got request {req:?}"
));
return;
}
Ok(None) => {
res.status_code(StatusCode::BAD_REQUEST);
res.render(format!("no bidirectional connection requested"));
return;
}
Err(err) => {
res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
res.render(format!("error with bidirectional connection: {err:?}"));
return;
}
};
/*
let id = wt.session_id();
let bi = match wt.open_bi(id).await {
Ok(bi) => bi,
Err(err) => {
res.status_code(StatusCode::BAD_REQUEST);
res.render(format!("could not open bidirectional stream: {err:?}"));
return;
}
};
*/
let (outgoing, incoming) = bi.split();
tokio::spawn(async move {
if let Err(error) = connect_proxy_impl(mumble_server_address, incoming, outgoing).await {
error!("error connecting proxy {error:?}")
}
});
res.render("connected");
}
#[instrument(skip(incoming, outgoing))]
async fn connect_proxy_impl(
mumble_server_address: SocketAddr,
incoming: impl AsyncRead + Send + Sync + 'static,
outgoing: impl AsyncWrite + Send + Sync + 'static,
) -> Result<()> {
info!("connecting to Mumble server...");
let config = ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoCertificateVerification))
.with_no_client_auth();
let connector = TlsConnector::from(Arc::new(config));
let server_tcp = TcpStream::connect(mumble_server_address).await?;
let server_stream = connector
.connect("example.com".try_into()?, server_tcp)
.await?;
let (read_server, write_server) = tokio::io::split(server_stream);
info!("connected to Mumble server");
// Spawn tasks to handle transmitting data between the WebTransport client and Mumble TCP Server
let c2s = tokio::spawn(
pass_bytes_loop(incoming, write_server)
.instrument(info_span!("Handler", "Client to server")),
);
let s2c = tokio::spawn(
pass_bytes_loop(read_server, outgoing)
.instrument(info_span!("Handler", "Server to client")),
);
tokio::select! {
res = c2s => res??,
res = s2c => res??,
};
Ok(())
}
#[derive(Debug)]
@@ -89,261 +294,27 @@ impl ServerCertVerifier for NoCertificateVerification {
}
}
#[derive(Clone, Deserialize)]
struct Config {
proxy_listen_address: SocketAddr,
http_listen_address: SocketAddr,
cert_path: PathBuf,
key_path: PathBuf,
#[serde(default)]
serve_https: bool,
mumble_server_url: String,
gui_path: PathBuf,
gui: GuiConfig,
}
//async fn serve_index_html_with_config(State(config): State<Config>) -> impl IntoResponse {
async fn serve_index_html_with_config(State(config): State<Config>) -> impl IntoResponse {
// Load the HTML file
let html = match read_to_string(config.gui_path.join("index.html")).await {
Ok(content) => content,
Err(_) => return (StatusCode::NOT_FOUND, "File not found").into_response(),
};
// Insert the script tag with configuration
let modified_html = html.replace(
"</head>",
&format!(
"<script>window.config = {}</script>\n</head>",
serde_json::to_string(&config.gui).unwrap(),
),
);
// Create a response with the modified HTML
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "text/html")
.body(modified_html)
.unwrap()
.into_response()
}
fn configure_tls(config: &Config) -> Result<rustls::ServerConfig, Error> {
// Thanks perplexity!
use rustls_pemfile::{certs, pkcs8_private_keys};
use std::fs::File;
use std::io::BufReader;
// Create a new ServerConfig with no client authentication
//(rustls::server::NoClientAuth::new());
// Read the certificate file
let cert_file = File::open(&config.cert_path)
.context(format!("opening cert {}", config.cert_path.display()))?;
let mut cert_reader = BufReader::new(cert_file);
let cert_chain = certs(&mut cert_reader).collect::<Result<_, _>>()?;
// Read the private key file
let key_file = File::open(&config.key_path)
.context(format!("opening key {}", config.key_path.display()))?;
let mut key_reader = BufReader::new(key_file);
let key = pkcs8_private_keys(&mut key_reader)
.next()
.ok_or(anyhow!("no keys in key.pem"))??;
// Set the certificate chain and private key
let config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert_chain, key.into())?;
Ok(config)
}
#[tokio::main]
async fn main() -> Result<()> {
init_logging();
let proxy_config: Config = toml::from_str(
&read_to_string("./config.toml")
.await
.context("reading config.toml (try making a copy of config.toml.example)")?,
)?;
let mumble_server_addr = proxy_config
.mumble_server_url
.to_socket_addrs()
.context(format!(
"parsing mumble_server_url={}",
proxy_config.mumble_server_url
))?
.next()
.ok_or(anyhow!(
"no socket addrs in mumble_server_url={}",
proxy_config.mumble_server_url
))?;
// Setup HTTP Server
//let http = axum::Router::new().route("/", axum::routing::get(serve_gui));
let app = axum::Router::new()
.route("/", axum::routing::get(serve_index_html_with_config))
.fallback_service(tower_http::services::ServeDir::new(&proxy_config.gui_path))
.with_state(proxy_config.clone());
if proxy_config.serve_https {
tokio::spawn(
axum_server::bind_rustls(
proxy_config.http_listen_address,
axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(configure_tls(
&proxy_config,
)?)),
)
.serve(app.into_make_service())
.into_future(),
);
} else {
tokio::spawn(
axum_server::bind(proxy_config.http_listen_address)
.serve(app.into_make_service())
.into_future(),
);
}
// Setup WebTransport proxy listener
let identity = Identity::load_pemfiles(proxy_config.cert_path, proxy_config.key_path).await?;
let config = ServerConfig::builder()
.with_bind_address(proxy_config.proxy_listen_address)
.with_identity(&identity)
.keep_alive_interval(Some(Duration::from_secs(20)))
.build();
let server = Endpoint::server(config)?;
info!("Server ready!");
for id in 0.. {
let incoming_session = server.accept().await;
tokio::spawn(
handle_connection(incoming_session, id, mumble_server_addr)
.instrument(info_span!("Connection", id)),
);
}
Ok(())
}
async fn handle_connection(
incoming_session: IncomingSession,
id: usize,
mumble_server_address: SocketAddr,
) {
// Wrapper to handle connection establishment failures
if let Err(e) = handle_connection_impl(incoming_session, id, mumble_server_address).await {
error!("{:?}", e);
}
}
async fn handle_connection_impl(
incoming_session: IncomingSession,
id: usize,
mumble_server_address: SocketAddr,
) -> Result<()> {
info!("Waiting for session request...");
let session_request = incoming_session.await?;
info!(
"New session: Authority: '{}', Path: '{}'",
session_request.authority(),
session_request.path()
);
let connection = session_request.accept().await?;
let stream = connection.accept_bi().await?;
info!("Connecting to corresponding Mumble server...");
let config = ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoCertificateVerification))
.with_no_client_auth();
let connector = TlsConnector::from(Arc::new(config));
let server_tcp = TcpStream::connect(mumble_server_address).await?;
let server_stream = connector
.connect("example.com".try_into()?, server_tcp)
.await?;
let (read_server, write_server) = tokio::io::split(server_stream);
info!("Connected to Mumble Server!");
// Store connection in global map to prevent it getting dropped
DATA_MAP.lock().unwrap().insert(id, connection);
info!("Spawing jobs...");
// Spawn tasks to handle transmitting data between the WebTransport client and Mumble TCP Server
tokio::spawn(
handle_client_to_server(stream.1, write_server)
.instrument(info_span!("Handler", "Client to server")),
);
tokio::spawn(
handle_server_to_client(stream.0, read_server)
.instrument(info_span!("Handler", "Server to client")),
);
info!("Spawned jobs.");
Ok(())
}
async fn handle_client_to_server(
client_stream: wtransport::RecvStream,
server_stream: impl AsyncWrite,
) {
let result = client_to_server_loop(client_stream, server_stream).await;
error!("{:?}", result);
}
async fn client_to_server_loop(
mut client_stream: wtransport::RecvStream,
server_stream: impl AsyncWrite,
async fn pass_bytes_loop(
client_stream: impl AsyncRead + Sync + Send + 'static,
server_stream: impl AsyncWrite + Send + Sync + 'static,
) -> Result<()> {
let mut buffer = vec![0; 65536].into_boxed_slice();
pin!(client_stream);
pin!(server_stream);
loop {
let bytes_read = match client_stream.read(&mut buffer).await? {
Some(bytes_read) => bytes_read,
None => break Ok(()),
};
let bytes_read = client_stream.read(&mut buffer).await?;
if bytes_read == 0 {
break Ok(());
}
server_stream.write_all(&buffer[..bytes_read]).await?;
server_stream.flush().await?;
}
}
async fn handle_server_to_client(
client_stream: wtransport::SendStream,
server_stream: impl AsyncRead,
) {
let result = server_to_client_loop(client_stream, server_stream).await;
error!("{:?}", result);
}
async fn server_to_client_loop(
mut client_stream: wtransport::SendStream,
server_stream: impl AsyncRead,
) -> Result<()> {
let mut buffer = vec![0; 65536].into_boxed_slice();
pin!(server_stream);
loop {
let bytes_read = server_stream.read(&mut buffer).await?;
client_stream.write_all(&buffer[..bytes_read]).await?;
}
}
fn init_logging() {
let env_filter = EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.with_default_directive(LevelFilter::DEBUG.into())
.from_env_lossy();
tracing_subscriber::fmt()