diff --git a/core/src/context/diagnostic.rs b/core/src/context/diagnostic.rs index c069d017f..bf27da071 100644 --- a/core/src/context/diagnostic.rs +++ b/core/src/context/diagnostic.rs @@ -39,7 +39,7 @@ impl DiagnosticContext { shutdown, disk_guid, error: Arc::new(error.into()), - rpc_continuations: RpcContinuations::new(), + rpc_continuations: RpcContinuations::new(None), }))) } } diff --git a/core/src/context/init.rs b/core/src/context/init.rs index b7d5eac6a..5f6c35222 100644 --- a/core/src/context/init.rs +++ b/core/src/context/init.rs @@ -32,7 +32,7 @@ impl InitContext { error: watch::channel(None).0, progress, shutdown, - rpc_continuations: RpcContinuations::new(), + rpc_continuations: RpcContinuations::new(None), }))) } } diff --git a/core/src/context/rpc.rs b/core/src/context/rpc.rs index 61ce35020..204b000b5 100644 --- a/core/src/context/rpc.rs +++ b/core/src/context/rpc.rs @@ -62,8 +62,8 @@ pub struct RpcContextSeed { pub db: TypedPatchDb, pub sync_db: watch::Sender, pub account: SyncRwLock, - pub net_controller: Arc, pub os_net_service: NetService, + pub net_controller: Arc, pub s9pk_arch: Option<&'static str>, pub services: ServiceMap, pub cancellable_installs: SyncMutex>>, @@ -346,10 +346,10 @@ impl RpcContext { services, cancellable_installs: SyncMutex::new(BTreeMap::new()), metrics_cache, + rpc_continuations: RpcContinuations::new(Some(shutdown.clone())), shutdown, lxc_manager: Arc::new(LxcManager::new()), open_authed_continuations: OpenAuthedContinuations::new(), - rpc_continuations: RpcContinuations::new(), wifi_manager: Arc::new(RwLock::new(wifi_interface.clone().map(|i| WpaCli::init(i)))), current_secret: Arc::new( Jwk::generate_ec_key(josekit::jwk::alg::ec::EcCurve::P256).map_err(|e| { diff --git a/core/src/context/setup.rs b/core/src/context/setup.rs index d4d0bb9de..3d16624ef 100644 --- a/core/src/context/setup.rs +++ b/core/src/context/setup.rs @@ -85,7 +85,7 @@ impl SetupContext { result: OnceCell::new(), disk_guid: OnceCell::new(), shutdown, - rpc_continuations: RpcContinuations::new(), + rpc_continuations: RpcContinuations::new(None), install_rootfs: SyncMutex::new(None), language: SyncMutex::new(None), keyboard: SyncMutex::new(None), diff --git a/core/src/registry/context.rs b/core/src/registry/context.rs index c9773d14b..2aa5739f2 100644 --- a/core/src/registry/context.rs +++ b/core/src/registry/context.rs @@ -141,7 +141,7 @@ impl RegistryContext { listen: config.registry_listen.unwrap_or(DEFAULT_REGISTRY_LISTEN), db, datadir, - rpc_continuations: RpcContinuations::new(), + rpc_continuations: RpcContinuations::new(None), client: Client::builder() .proxy(Proxy::custom(move |url| { if url.host_str().map_or(false, |h| h.ends_with(".onion")) { diff --git a/core/src/rpc_continuations.rs b/core/src/rpc_continuations.rs index 42c3ae858..e084264ab 100644 --- a/core/src/rpc_continuations.rs +++ b/core/src/rpc_continuations.rs @@ -17,6 +17,7 @@ use ts_rs::TS; #[allow(unused_imports)] use crate::prelude::*; +use crate::shutdown::Shutdown; use crate::util::future::TimedResource; use crate::util::net::WebSocket; use crate::util::{FromStrParser, new_guid}; @@ -98,12 +99,15 @@ pub type RestHandler = Box RestFuture + Send>; pub struct WebSocketFuture { kill: Option>, + shutdown: Option>>, fut: BoxFuture<'static, ()>, } impl Future for WebSocketFuture { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.kill.as_ref().map_or(false, |k| !k.is_empty()) { + if self.kill.as_ref().map_or(false, |k| !k.is_empty()) + || self.shutdown.as_ref().map_or(false, |s| !s.is_empty()) + { Poll::Ready(()) } else { self.fut.poll_unpin(cx) @@ -138,6 +142,7 @@ impl RpcContinuation { RpcContinuation::WebSocket(TimedResource::new( Box::new(|ws| WebSocketFuture { kill: None, + shutdown: None, fut: handler(ws.into()).boxed(), }), timeout, @@ -170,6 +175,7 @@ impl RpcContinuation { RpcContinuation::WebSocket(TimedResource::new( Box::new(|ws| WebSocketFuture { kill, + shutdown: None, fut: handler(ws.into()).boxed(), }), timeout, @@ -183,15 +189,21 @@ impl RpcContinuation { } } -pub struct RpcContinuations(AsyncMutex>); +pub struct RpcContinuations { + continuations: AsyncMutex>, + shutdown: Option>>, +} impl RpcContinuations { - pub fn new() -> Self { - RpcContinuations(AsyncMutex::new(BTreeMap::new())) + pub fn new(shutdown: Option>>) -> Self { + RpcContinuations { + continuations: AsyncMutex::new(BTreeMap::new()), + shutdown, + } } #[instrument(skip_all)] pub async fn clean(&self) { - let mut continuations = self.0.lock().await; + let mut continuations = self.continuations.lock().await; let mut to_remove = Vec::new(); for (guid, cont) in &*continuations { if cont.is_timed_out() { @@ -206,23 +218,28 @@ impl RpcContinuations { #[instrument(skip_all)] pub async fn add(&self, guid: Guid, handler: RpcContinuation) { self.clean().await; - self.0.lock().await.insert(guid, handler); + self.continuations.lock().await.insert(guid, handler); } pub async fn get_ws_handler(&self, guid: &Guid) -> Option { - let mut continuations = self.0.lock().await; + let mut continuations = self.continuations.lock().await; if !matches!(continuations.get(guid), Some(RpcContinuation::WebSocket(_))) { return None; } let Some(RpcContinuation::WebSocket(x)) = continuations.remove(guid) else { return None; }; - x.get().await + let handler = x.get().await?; + let shutdown = self.shutdown.as_ref().map(|s| s.subscribe()); + Some(Box::new(move |ws| { + let mut fut = handler(ws); + fut.shutdown = shutdown; + fut + })) } pub async fn get_rest_handler(&self, guid: &Guid) -> Option { - let mut continuations: tokio::sync::MutexGuard<'_, BTreeMap> = - self.0.lock().await; + let mut continuations = self.continuations.lock().await; if !matches!(continuations.get(guid), Some(RpcContinuation::Rest(_))) { return None; } diff --git a/core/src/tunnel/context.rs b/core/src/tunnel/context.rs index 1cb23c49e..0d6ab5df8 100644 --- a/core/src/tunnel/context.rs +++ b/core/src/tunnel/context.rs @@ -201,7 +201,7 @@ impl TunnelContext { listen, db, datadir, - rpc_continuations: RpcContinuations::new(), + rpc_continuations: RpcContinuations::new(None), open_authed_continuations: OpenAuthedContinuations::new(), ephemeral_sessions: SyncMutex::new(Sessions::new()), net_iface, diff --git a/core/src/upload.rs b/core/src/upload.rs index 5812834da..96f82b812 100644 --- a/core/src/upload.rs +++ b/core/src/upload.rs @@ -13,7 +13,6 @@ use futures::{FutureExt, Stream, StreamExt, ready}; use http::header::CONTENT_LENGTH; use http::{HeaderMap, StatusCode}; use imbl_value::InternedString; -use tokio::fs::File; use tokio::io::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::watch; @@ -23,6 +22,7 @@ use crate::progress::{PhaseProgressTrackerHandle, ProgressUnits}; use crate::rpc_continuations::{Guid, RpcContinuation}; use crate::s9pk::merkle_archive::source::ArchiveSource; use crate::s9pk::merkle_archive::source::multi_cursor_file::{FileCursor, MultiCursorFile}; +use crate::util::direct_io::DirectIoFile; use crate::util::io::{TmpDir, create_file}; pub async fn upload( @@ -69,16 +69,6 @@ impl Progress { false } } - fn handle_write(&mut self, res: &std::io::Result) -> bool { - match res { - Ok(a) => { - self.written += *a as u64; - self.tracker += *a as u64; - true - } - Err(e) => self.handle_error(e), - } - } async fn expected_size(watch: &mut watch::Receiver) -> Option { watch .wait_for(|progress| progress.error.is_some() || progress.expected_size.is_some()) @@ -192,16 +182,19 @@ impl UploadingFile { complete: false, }); let file = create_file(path).await?; + let multi_cursor = MultiCursorFile::open(&file).await?; + let direct_file = DirectIoFile::from_tokio_file(file).await?; let uploading = Self { tmp_dir: None, - file: MultiCursorFile::open(&file).await?, + file: multi_cursor, progress: progress.1, }; Ok(( UploadHandle { tmp_dir: None, - file, + file: direct_file, progress: progress.0, + last_synced: 0, }, uploading, )) @@ -346,8 +339,9 @@ impl AsyncSeek for UploadingFileReader { pub struct UploadHandle { tmp_dir: Option>, #[pin] - file: File, + file: DirectIoFile, progress: watch::Sender, + last_synced: u64, } impl UploadHandle { pub async fn upload(&mut self, request: Request) { @@ -394,6 +388,19 @@ impl UploadHandle { if let Err(e) = self.file.sync_all().await { self.progress.send_if_modified(|p| p.handle_error(&e)); } + // Update progress with final synced bytes + self.update_sync_progress(); + } + fn update_sync_progress(&mut self) { + let synced = self.file.bytes_synced(); + let delta = synced - self.last_synced; + if delta > 0 { + self.last_synced = synced; + self.progress.send_modify(|p| { + p.written += delta; + p.tracker += delta; + }); + } } } #[pin_project::pinned_drop] @@ -410,13 +417,23 @@ impl AsyncWrite for UploadHandle { buf: &[u8], ) -> Poll> { let this = self.project(); + // Update progress based on bytes actually flushed to disk + let synced = this.file.bytes_synced(); + let delta = synced - *this.last_synced; + if delta > 0 { + *this.last_synced = synced; + this.progress.send_modify(|p| { + p.written += delta; + p.tracker += delta; + }); + } match this.file.poll_write(cx, buf) { - Poll::Ready(res) => { + Poll::Ready(Err(e)) => { this.progress - .send_if_modified(|progress| progress.handle_write(&res)); - Poll::Ready(res) + .send_if_modified(|progress| progress.handle_error(&e)); + Poll::Ready(Err(e)) } - Poll::Pending => Poll::Pending, + a => a, } } fn poll_flush( diff --git a/core/src/util/direct_io.rs b/core/src/util/direct_io.rs new file mode 100644 index 000000000..9c3880f8c --- /dev/null +++ b/core/src/util/direct_io.rs @@ -0,0 +1,298 @@ +use std::alloc::Layout; +use std::io::Write; +use std::os::fd::AsRawFd; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::AsyncWrite; +use tokio::task::JoinHandle; + +const BLOCK_SIZE: usize = 4096; +const BUF_CAP: usize = 256 * 1024; // 256KB + +/// Aligned buffer for O_DIRECT I/O. +struct AlignedBuf { + ptr: *mut u8, + len: usize, +} + +// SAFETY: We have exclusive ownership of the allocation. +unsafe impl Send for AlignedBuf {} + +impl AlignedBuf { + fn new() -> Self { + let layout = Layout::from_size_align(BUF_CAP, BLOCK_SIZE).unwrap(); + // SAFETY: layout has non-zero size + let ptr = unsafe { std::alloc::alloc(layout) }; + if ptr.is_null() { + std::alloc::handle_alloc_error(layout); + } + Self { ptr, len: 0 } + } + + fn as_slice(&self) -> &[u8] { + // SAFETY: ptr is valid for len bytes, properly aligned, exclusively owned + unsafe { std::slice::from_raw_parts(self.ptr, self.len) } + } + + fn push(&mut self, data: &[u8]) -> usize { + let n = data.len().min(BUF_CAP - self.len); + // SAFETY: src and dst don't overlap, both valid for n bytes + unsafe { + std::ptr::copy_nonoverlapping(data.as_ptr(), self.ptr.add(self.len), n); + } + self.len += n; + n + } + + fn aligned_len(&self) -> usize { + self.len & !(BLOCK_SIZE - 1) + } + + fn drain_front(&mut self, n: usize) { + debug_assert!(n <= self.len); + let remaining = self.len - n; + if remaining > 0 { + // SAFETY: regions may overlap, so we use copy (memmove) + unsafe { + std::ptr::copy(self.ptr.add(n), self.ptr, remaining); + } + } + self.len = remaining; + } + + /// Extract aligned data into a new buffer for flushing, leaving remainder. + fn take_aligned(&mut self) -> Option<(AlignedBuf, u64)> { + let aligned = self.aligned_len(); + if aligned == 0 { + return None; + } + let mut flush_buf = AlignedBuf::new(); + flush_buf.push(&self.as_slice()[..aligned]); + self.drain_front(aligned); + Some((flush_buf, aligned as u64)) + } +} + +impl Drop for AlignedBuf { + fn drop(&mut self) { + let layout = Layout::from_size_align(BUF_CAP, BLOCK_SIZE).unwrap(); + // SAFETY: ptr was allocated with this layout in new() + unsafe { std::alloc::dealloc(self.ptr, layout) }; + } +} + +enum FileState { + Idle(std::fs::File), + Flushing(JoinHandle>), + Done, +} + +/// A file writer that uses O_DIRECT to bypass the kernel page cache. +/// +/// Buffers writes in an aligned buffer and flushes to disk in the background. +/// New writes can proceed while a flush is in progress (double-buffering). +/// Progress is tracked via [`bytes_synced`](Self::bytes_synced), which reflects +/// bytes actually written to disk. +pub struct DirectIoFile { + file_state: FileState, + buf: AlignedBuf, + synced: u64, +} + +impl DirectIoFile { + fn new(file: std::fs::File) -> Self { + Self { + file_state: FileState::Idle(file), + buf: AlignedBuf::new(), + synced: 0, + } + } + + /// Convert an existing tokio File into a DirectIoFile by adding O_DIRECT. + pub async fn from_tokio_file(file: tokio::fs::File) -> std::io::Result { + let std_file = file.into_std().await; + let fd = std_file.as_raw_fd(); + // SAFETY: fd is valid, F_GETFL/F_SETFL are standard fcntl ops + unsafe { + let flags = libc::fcntl(fd, libc::F_GETFL); + if flags == -1 { + return Err(std::io::Error::last_os_error()); + } + if libc::fcntl(fd, libc::F_SETFL, flags | libc::O_DIRECT) == -1 { + return Err(std::io::Error::last_os_error()); + } + } + Ok(Self::new(std_file)) + } + + /// Number of bytes confirmed written to disk. + pub fn bytes_synced(&self) -> u64 { + self.synced + } + + /// Flush any remaining buffered data and sync to disk. + /// + /// Removes the O_DIRECT flag for the final partial-block write, then + /// calls fsync. Updates `bytes_synced` to the final total. + pub async fn sync_all(&mut self) -> std::io::Result<()> { + // Wait for any in-flight flush + self.await_flush().await?; + + let FileState::Idle(file) = std::mem::replace(&mut self.file_state, FileState::Done) + else { + return Ok(()); + }; + + let mut buf = std::mem::replace(&mut self.buf, AlignedBuf::new()); + let remaining = buf.len as u64; + + tokio::task::spawn_blocking(move || { + let mut file = file; + + // Write any aligned portion + let aligned = buf.aligned_len(); + if aligned > 0 { + let slice = unsafe { std::slice::from_raw_parts(buf.ptr, aligned) }; + file.write_all(slice)?; + buf.drain_front(aligned); + } + + // Write remainder with O_DIRECT disabled + if buf.len > 0 { + let fd = file.as_raw_fd(); + // SAFETY: fd is valid, F_GETFL/F_SETFL are standard fcntl ops + unsafe { + let flags = libc::fcntl(fd, libc::F_GETFL); + libc::fcntl(fd, libc::F_SETFL, flags & !libc::O_DIRECT); + } + file.write_all(buf.as_slice())?; + } + + file.sync_all() + }) + .await + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))??; + + self.synced += remaining; + Ok(()) + } + + async fn await_flush(&mut self) -> std::io::Result<()> { + if let FileState::Flushing(handle) = &mut self.file_state { + let (file, flushed) = handle + .await + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))??; + self.synced += flushed; + self.file_state = FileState::Idle(file); + } + Ok(()) + } + + /// Non-blocking poll: try to complete a pending flush. + /// Returns Ready(Ok(())) if idle (or just became idle), Pending if still flushing. + fn poll_complete_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + if let FileState::Flushing(handle) = &mut self.file_state { + match Pin::new(handle).poll(cx) { + Poll::Ready(Ok(Ok((file, flushed)))) => { + self.synced += flushed; + self.file_state = FileState::Idle(file); + } + Poll::Ready(Ok(Err(e))) => { + self.file_state = FileState::Done; + return Poll::Ready(Err(e)); + } + Poll::Ready(Err(e)) => { + self.file_state = FileState::Done; + return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))); + } + Poll::Pending => return Poll::Pending, + } + } + Poll::Ready(Ok(())) + } + + /// Start a background flush of aligned data if the file is idle. + fn maybe_start_flush(&mut self) { + if !matches!(self.file_state, FileState::Idle(_)) { + return; + } + let Some((flush_buf, count)) = self.buf.take_aligned() else { + return; + }; + let FileState::Idle(file) = std::mem::replace(&mut self.file_state, FileState::Done) + else { + unreachable!() + }; + let handle = tokio::task::spawn_blocking(move || { + let mut file = file; + file.write_all(flush_buf.as_slice())?; + Ok((file, count)) + }); + self.file_state = FileState::Flushing(handle); + } +} + +impl AsyncWrite for DirectIoFile { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // Try to complete any pending flush (non-blocking, registers waker) + match self.poll_complete_flush(cx) { + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + _ => {} // Pending is fine — we can still accept writes into the buffer + } + + // If file just became idle and buffer has aligned data, start a flush + // to free buffer space before accepting new data + self.maybe_start_flush(); + + // Accept data into the buffer + let n = self.buf.push(buf); + if n == 0 { + // Buffer full, must wait for flush to complete and free space. + // Waker was already registered by poll_complete_flush above. + return Poll::Pending; + } + + // If file is idle and we now have aligned data, start flushing + self.maybe_start_flush(); + + Poll::Ready(Ok(n)) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.poll_complete_flush(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok(())) => {} + } + + if self.buf.aligned_len() > 0 { + self.maybe_start_flush(); + // Poll the just-started flush + return self.poll_complete_flush(cx).map(|r| r.map(|_| ())); + } + + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.poll_complete_flush(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok(())) => {} + } + + self.file_state = FileState::Done; + Poll::Ready(Ok(())) + } +} diff --git a/core/src/util/mod.rs b/core/src/util/mod.rs index 6cdc345a1..e9e98e039 100644 --- a/core/src/util/mod.rs +++ b/core/src/util/mod.rs @@ -38,6 +38,7 @@ pub mod collections; pub mod cpupower; pub mod crypto; pub mod data_url; +pub mod direct_io; pub mod future; pub mod http_reader; pub mod io;