forgot length check

This commit is contained in:
2024-05-21 16:53:06 -04:00
parent b2cae01bf8
commit abd2a2f81c
+18 -15
View File
@@ -1,5 +1,6 @@
use anyhow::Result; use anyhow::Result;
use std::time::Duration; use std::time::Duration;
use tokio::io::AsyncReadExt;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tracing::error; use tracing::error;
@@ -9,10 +10,9 @@ use tracing::Instrument;
use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::filter::LevelFilter;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use wtransport::endpoint::IncomingSession; use wtransport::endpoint::IncomingSession;
use wtransport::{connection, Endpoint}; use wtransport::Endpoint;
use wtransport::Identity; use wtransport::Identity;
use wtransport::ServerConfig; use wtransport::ServerConfig;
use tokio::io::AsyncReadExt;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use std::collections::HashMap; use std::collections::HashMap;
@@ -40,7 +40,9 @@ async fn main() -> Result<()> {
for id in 0.. { for id in 0.. {
let incoming_session = server.accept().await; let incoming_session = server.accept().await;
tokio::spawn(handle_connection(incoming_session, id).instrument(info_span!("Connection", id))); tokio::spawn(
handle_connection(incoming_session, id).instrument(info_span!("Connection", id)),
);
} }
Ok(()) Ok(())
@@ -54,8 +56,6 @@ async fn handle_connection(incoming_session: IncomingSession, id: usize) {
} }
async fn handle_connection_impl(incoming_session: IncomingSession, id: usize) -> Result<()> { async fn handle_connection_impl(incoming_session: IncomingSession, id: usize) -> Result<()> {
let mut buffer = vec![0; 65536].into_boxed_slice();
info!("Waiting for session request..."); info!("Waiting for session request...");
let session_request = incoming_session.await?; let session_request = incoming_session.await?;
@@ -80,15 +80,20 @@ async fn handle_connection_impl(incoming_session: IncomingSession, id: usize) ->
info!("Spawing jobs..."); info!("Spawing jobs...");
// Spawn tasks to handle transmitting data between the WebTransport client and Mumble TCP Server // Spawn tasks to handle transmitting data between the WebTransport client and Mumble TCP Server
tokio::spawn(handle_client_to_server(stream.1, server_tcp.1).instrument(info_span!("Handler", "Client to server"))); tokio::spawn(
tokio::spawn(handle_server_to_client(stream.0, server_tcp.0).instrument(info_span!("Handler", "Server to client"))); handle_client_to_server(stream.1, server_tcp.1)
.instrument(info_span!("Handler", "Client to server")),
);
tokio::spawn(
handle_server_to_client(stream.0, server_tcp.0)
.instrument(info_span!("Handler", "Server to client")),
);
info!("Spawned jobs."); info!("Spawned jobs.");
Ok(()) Ok(())
} }
async fn handle_client_to_server( async fn handle_client_to_server(
client_stream: wtransport::RecvStream, client_stream: wtransport::RecvStream,
server_stream: OwnedWriteHalf, server_stream: OwnedWriteHalf,
@@ -104,13 +109,12 @@ async fn client_to_server_loop(
let mut buffer = vec![0; 65536].into_boxed_slice(); let mut buffer = vec![0; 65536].into_boxed_slice();
loop { loop {
info!("Reading Data"); info!("Reading Data");
let _bytes_read = match client_stream.read(&mut buffer).await? { let bytes_read = match client_stream.read(&mut buffer).await? {
Some(bytes_read) => bytes_read, Some(bytes_read) => bytes_read,
None => continue, None => break Ok(()),
}; };
info!("Writing data"); info!("Writing data");
server_stream.try_write(&buffer)?; server_stream.try_write(&buffer[..bytes_read])?;
} }
} }
@@ -128,9 +132,8 @@ async fn server_to_client_loop(
) -> Result<()> { ) -> Result<()> {
let mut buffer = vec![0; 65536].into_boxed_slice(); let mut buffer = vec![0; 65536].into_boxed_slice();
loop { loop {
server_stream.read(&mut buffer).await?; let bytes_read = server_stream.read(&mut buffer).await?;
client_stream.write_all(&buffer[..bytes_read]).await?;
client_stream.write_all(&buffer).await?;
} }
} }