diff --git a/integrations/dav-server/src/file.rs b/integrations/dav-server/src/file.rs index 1f21cf882596..ecfd91ad5e16 100644 --- a/integrations/dav-server/src/file.rs +++ b/integrations/dav-server/src/file.rs @@ -23,8 +23,8 @@ use dav_server::fs::{DavFile, OpenOptions}; use dav_server::fs::{DavMetaData, FsResult}; use dav_server::fs::{FsError, FsFuture}; use futures::FutureExt; -use futures::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; -use opendal::{FuturesAsyncReader, FuturesAsyncWriter, Operator}; +use futures::{AsyncReadExt, AsyncSeekExt}; +use opendal::{Buffer, FuturesAsyncReader, Operator, Writer}; use super::metadata::OpendalMetaData; use super::utils::*; @@ -54,7 +54,7 @@ impl Debug for OpendalFile { enum State { Read(FuturesAsyncReader), - Write(FuturesAsyncWriter), + Write(Option), } impl OpendalFile { @@ -74,9 +74,8 @@ impl OpendalFile { .writer_with(&path) .append(options.append) .await - .map_err(convert_error)? - .into_futures_async_write(); - State::Write(w) + .map_err(convert_error)?; + State::Write(Some(w)) } else { return Err(FsError::NotImplemented); }; @@ -106,13 +105,19 @@ impl DavFile for OpendalFile { fn write_buf(&mut self, mut buf: Box) -> FsFuture<'_, ()> { async move { - let State::Write(w) = &mut self.state else { + let State::Write(Some(w)) = &mut self.state else { return Err(FsError::GeneralFailure); }; - w.write_all(&buf.copy_to_bytes(buf.remaining())) + if w.write(Buffer::from(buf.copy_to_bytes(buf.remaining()))) .await - .map_err(|_| FsError::GeneralFailure)?; + .is_err() + { + let _ = w.abort().await; + self.state = State::Write(None); + return Err(FsError::GeneralFailure); + } + Ok(()) } .boxed() @@ -120,11 +125,17 @@ impl DavFile for OpendalFile { fn write_bytes(&mut self, buf: Bytes) -> FsFuture<'_, ()> { async move { - let State::Write(w) = &mut self.state else { + let State::Write(Some(w)) = &mut self.state else { return Err(FsError::GeneralFailure); }; - w.write_all(&buf).await.map_err(|_| FsError::GeneralFailure) + if w.write(Buffer::from(buf)).await.is_err() { + let _ = w.abort().await; + self.state = State::Write(None); + return Err(FsError::GeneralFailure); + } + + Ok(()) } .boxed() } @@ -158,12 +169,18 @@ impl DavFile for OpendalFile { fn flush(&mut self) -> FsFuture<'_, ()> { async move { - let State::Write(w) = &mut self.state else { + let State::Write(Some(w)) = &mut self.state else { return Err(FsError::GeneralFailure); }; - w.flush().await.map_err(|_| FsError::GeneralFailure)?; - w.close().await.map_err(|_| FsError::GeneralFailure) + if w.close().await.is_err() { + let _ = w.abort().await; + self.state = State::Write(None); + return Err(FsError::GeneralFailure); + } + + self.state = State::Write(None); + Ok(()) } .boxed() } diff --git a/integrations/dav-server/tests/test.rs b/integrations/dav-server/tests/test.rs index 7403ea0174e3..e991afc5a93c 100644 --- a/integrations/dav-server/tests/test.rs +++ b/integrations/dav-server/tests/test.rs @@ -22,10 +22,17 @@ use dav_server::fs::OpenOptions; use dav_server::fs::{DavFileSystem, ReadDirMeta}; use dav_server_opendalfs::OpendalFs; use futures::StreamExt; +use opendal::Buffer; use opendal::Operator; +use opendal::raw::oio; +use opendal::raw::{Access, AccessorInfo, OpWrite, RpWrite}; use opendal::services::Fs; +use opendal::{Capability, Error, ErrorKind, Metadata}; +use std::fmt::{Debug, Formatter}; use std::fs; use std::path::Path; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; #[tokio::test] async fn test() -> Result<()> { @@ -213,3 +220,161 @@ async fn test_read_dir() { fs::remove_dir_all(TMP_PATH).unwrap(); } + +#[derive(Clone)] +struct AbortTrackingAccess { + info: Arc, + aborted: Arc, + fail_on_write: bool, + fail_on_close: bool, +} + +impl Debug for AbortTrackingAccess { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AbortTrackingAccess").finish() + } +} + +impl AbortTrackingAccess { + fn with_failures(aborted: Arc, fail_on_write: bool, fail_on_close: bool) -> Self { + let info = AccessorInfo::default(); + info.set_scheme("memory") + .set_root("/") + .set_name("abort-tracking") + .set_native_capability(Capability { + write: true, + ..Default::default() + }); + + Self { + info: info.into(), + aborted, + fail_on_write, + fail_on_close, + } + } + + fn write_failure(aborted: Arc) -> Self { + Self::with_failures(aborted, true, false) + } + + fn close_failure(aborted: Arc) -> Self { + Self::with_failures(aborted, false, true) + } +} + +impl Access for AbortTrackingAccess { + type Reader = oio::Reader; + type Writer = oio::Writer; + type Lister = oio::Lister; + type Deleter = oio::Deleter; + type Copier = oio::Copier; + + fn info(&self) -> Arc { + self.info.clone() + } + + async fn write(&self, _: &str, _: OpWrite) -> opendal::Result<(RpWrite, Self::Writer)> { + Ok(( + RpWrite::new(), + Box::new(AbortTrackingWriter { + aborted: self.aborted.clone(), + fail_on_write: self.fail_on_write, + fail_on_close: self.fail_on_close, + }), + )) + } +} + +struct AbortTrackingWriter { + aborted: Arc, + fail_on_write: bool, + fail_on_close: bool, +} + +impl oio::Write for AbortTrackingWriter { + async fn write(&mut self, _: Buffer) -> opendal::Result<()> { + if self.fail_on_write { + return Err(Error::new(ErrorKind::Unexpected, "injected write failure")); + } + + Ok(()) + } + + async fn close(&mut self) -> opendal::Result { + if self.fail_on_close { + return Err(Error::new(ErrorKind::Unexpected, "injected close failure")); + } + + Ok(Metadata::default()) + } + + async fn abort(&mut self) -> opendal::Result<()> { + self.aborted.store(true, Ordering::SeqCst); + Ok(()) + } +} + +#[tokio::test] +async fn test_failed_write_aborts_before_drop() { + let aborted = Arc::new(AtomicBool::new(false)); + let op = Operator::from_inner(Arc::new(AbortTrackingAccess::write_failure( + aborted.clone(), + ))); + let webdavfs = OpendalFs::new(op); + + let mut file = webdavfs + .open( + &DavPath::new("/failed-write").unwrap(), + OpenOptions { + write: true, + ..OpenOptions::default() + }, + ) + .await + .unwrap(); + + let err = file.write_bytes(Bytes::from(vec![1; 300 * 1024])).await; + assert!(err.is_err()); + + drop(file); + + assert!( + aborted.load(Ordering::SeqCst), + "writer.abort() should be called when a write fails before close()" + ); +} + +#[tokio::test] +async fn test_failed_close_aborts_before_drop() { + let aborted = Arc::new(AtomicBool::new(false)); + let op = Operator::from_inner(Arc::new(AbortTrackingAccess::close_failure( + aborted.clone(), + ))); + let webdavfs = OpendalFs::new(op); + + let mut file = webdavfs + .open( + &DavPath::new("/failed-close").unwrap(), + OpenOptions { + write: true, + ..OpenOptions::default() + }, + ) + .await + .unwrap(); + + file.write_bytes(Bytes::from(vec![1; 300 * 1024])) + .await + .unwrap(); + + let err = file.flush().await; + assert!(err.is_err()); + + drop(file); + + assert!( + aborted.load(Ordering::SeqCst), + "writer.abort() should be called when close() fails during flush()" + ); +}