use crossbeam::atomic::AtomicCell; use df::tract::{mut_slice_as_arrayviewmut, slice_as_arrayview}; use df::tract::{DfParams, DfTract, RuntimeParams}; use dioxus_asset_resolver::read_asset_bytes; use manganis::{asset, Asset}; use std::cell::RefCell; use std::sync::Arc; use tracing::{error, info}; use crate::imp::SpawnHandle; static DF_MODEL: Asset = asset!("/assets/DeepFilterNet3_ll_onnx.tar.gz"); // TODO: make this user configurable. static DEFAULT_NOISE_FLOOR: f32 = 0.001; // 200ms hold at 48kHz sample rate static HOLD_SAMPLES_MAX: usize = 48000 / 5; // 9600 samples = 200ms /// Indicates the transmission state after processing audio. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TransmitState { /// Audio is above threshold, or below but within hold period - transmit normally Transmitting, /// Hold period expired - send this frame as terminator (end_bit = true) Terminator, /// Silent and not transmitting - don't send anything Silent, } enum DenoisingModelState { Nothing, Downloading(Arc>>), Availible(Box), } fn with_denoising_model(spawn: &SpawnHandle, func: impl FnOnce(&mut DfTract) -> O) -> Option { // Using a thread local is super gross, but DfTract is not Send (so it can never leave the current // thread) while AudioProcessing itself might change threads whenever. thread_local! { static STATE: RefCell = const { RefCell::new(DenoisingModelState::Nothing) }; } STATE.with_borrow_mut(|state| match state { DenoisingModelState::Nothing => { let cell = Arc::new(AtomicCell::new(None)); let cell_task = cell.clone(); *state = DenoisingModelState::Downloading(cell); let model = DF_MODEL.to_string(); spawn.spawn(async move { let model_bytes = match read_asset_bytes(&model).await { Ok(b) => b, Err(e) => { error!("could not read denoising model from \"{model}\": {e:?}"); return; } }; let params = match DfParams::from_bytes(&model_bytes) { Ok(p) => p, Err(e) => { error!("could not load denoising model parameters: {e:?}"); return; } }; cell_task.store(Some(params)); }); None } DenoisingModelState::Downloading(cell) => { if let Some(params) = cell.take() { let mut tract = match DfTract::new(params, &RuntimeParams::default_with_ch(1)) { Ok(t) => Box::new(t), Err(e) => { error!("could not create denoising engine: {e:?}"); return None; } }; info!("instantiated denoising engine"); let out = func(&mut tract); *state = DenoisingModelState::Availible(tract); Some(out) } else { None } } DenoisingModelState::Availible(tract) => Some(func(tract)), }) } pub struct AudioProcessor { denoise: bool, spawn: SpawnHandle, buffer: Vec, noise_floor: f32, /// Whether we were transmitting in the previous frame was_transmitting: bool, /// Number of samples we've been below threshold (for hold period) hold_samples: usize, } impl AudioProcessor { pub fn new(denoise: bool) -> Self { AudioProcessor { denoise, spawn: SpawnHandle::current(), buffer: Vec::new(), noise_floor: DEFAULT_NOISE_FLOOR, was_transmitting: false, hold_samples: 0, } } } impl AudioProcessor { pub fn process( &mut self, audio: &[f32], channels: usize, output: &mut Vec, ) -> TransmitState { let mut include_raw = true; if self.denoise { with_denoising_model(&self.spawn, |df| { include_raw = false; self.buffer.extend(audio.iter().step_by(channels).copied()); output.reserve(audio.len()); let hop = df.hop_size; let mut i = 0; while self.buffer[i..].len() >= hop { let audio = &self.buffer[i..][..hop]; i += audio.len(); let j = output.len(); output.extend(std::iter::repeat_n(0f32, audio.len())); let output = &mut output[j..]; df.process( slice_as_arrayview(audio, &[audio.len()]) .into_shape((1, audio.len())) .unwrap(), mut_slice_as_arrayviewmut(output, &[output.len()]) .into_shape((1, output.len())) .unwrap(), ); } self.buffer.splice(..i, []); }); } if include_raw { output.extend(audio.iter().step_by(channels).copied()); } // Calculate average amplitude for VAD let avg: f32 = if output.is_empty() { 0.0 } else { output.iter().map(|x| x.abs()).sum::() / output.len() as f32 }; let above_threshold = avg >= self.noise_floor; let samples_in_frame = output.len(); let state = if above_threshold { // Above threshold - reset hold counter and transmit self.hold_samples = 0; self.was_transmitting = true; TransmitState::Transmitting } else if self.was_transmitting && self.hold_samples < HOLD_SAMPLES_MAX { // Below threshold but in hold period - keep transmitting self.hold_samples += samples_in_frame; TransmitState::Transmitting } else if self.was_transmitting { // Hold period expired - send terminator self.was_transmitting = false; self.hold_samples = 0; TransmitState::Terminator } else { // Not transmitting and below threshold - stay silent output.clear(); // Don't accumulate stale audio during silence TransmitState::Silent }; state } } pub type AudioProcessorSender = Arc>>;