From 12caac1f8028869d4ce95eec8f696c7f18ccb30a Mon Sep 17 00:00:00 2001 From: Alexander Capehart Date: Mon, 17 Feb 2025 17:25:16 -0700 Subject: [PATCH] musikr: fix iostream lifecycle Finally pass over ownership of the RsIOStream to the normal IOStream. --- musikr/src/main/jni/shim/iostream_shim.cpp | 26 ++-- musikr/src/main/jni/shim/iostream_shim.hpp | 2 +- musikr/src/main/jni/src/jstream.rs | 169 ++++++++++----------- musikr/src/main/jni/src/taglib/bridge.rs | 14 +- musikr/src/main/jni/src/taglib/iostream.rs | 48 +++--- 5 files changed, 120 insertions(+), 139 deletions(-) diff --git a/musikr/src/main/jni/shim/iostream_shim.cpp b/musikr/src/main/jni/shim/iostream_shim.cpp index 66d7238ab..007863d47 100644 --- a/musikr/src/main/jni/shim/iostream_shim.cpp +++ b/musikr/src/main/jni/shim/iostream_shim.cpp @@ -11,7 +11,7 @@ namespace taglib_shim class WrappedRsIOStream : public TagLib::IOStream { public: - explicit WrappedRsIOStream(RsIOStream& stream); + explicit WrappedRsIOStream(rust::Box stream); ~WrappedRsIOStream() override; // TagLib::IOStream interface implementation @@ -29,28 +29,28 @@ namespace taglib_shim bool isOpen() const override; private: - RsIOStream& rust_stream; + rust::Box rust_stream; }; - WrappedRsIOStream::WrappedRsIOStream(RsIOStream& stream) : rust_stream(stream) {} + WrappedRsIOStream::WrappedRsIOStream(rust::Box stream) : rust_stream(std::move(stream)) {} WrappedRsIOStream::~WrappedRsIOStream() = default; TagLib::FileName WrappedRsIOStream::name() const { - return rust::string(rust_stream.name()).c_str(); + return rust::string(rust_stream->name()).c_str(); } TagLib::ByteVector WrappedRsIOStream::readBlock(size_t length) { std::vector buffer(length); - size_t bytes_read = rust_stream.read(rust::Slice(buffer.data(), length)); + size_t bytes_read = rust_stream->read(rust::Slice(buffer.data(), length)); return TagLib::ByteVector(reinterpret_cast(buffer.data()), bytes_read); } void WrappedRsIOStream::writeBlock(const TagLib::ByteVector &data) { - rust_stream.write(rust::Slice( + rust_stream->write(rust::Slice( reinterpret_cast(data.data()), data.size())); } @@ -118,7 +118,7 @@ namespace taglib_shim default: throw std::runtime_error("Invalid seek position"); } - rust_stream.seek(offset, whence); + rust_stream->seek(offset, whence); } void WrappedRsIOStream::clear() @@ -129,22 +129,22 @@ namespace taglib_shim void WrappedRsIOStream::truncate(TagLib::offset_t length) { - rust_stream.truncate(length); + rust_stream->truncate(length); } TagLib::offset_t WrappedRsIOStream::tell() const { - return rust_stream.tell(); + return rust_stream->tell(); } TagLib::offset_t WrappedRsIOStream::length() { - return rust_stream.length(); + return rust_stream->length(); } bool WrappedRsIOStream::readOnly() const { - return rust_stream.is_readonly(); + return rust_stream->is_readonly(); } bool WrappedRsIOStream::isOpen() const @@ -153,9 +153,9 @@ namespace taglib_shim } // Factory function to create a new RustIOStream - std::unique_ptr wrap_RsIOStream(RsIOStream& stream) + std::unique_ptr wrap_RsIOStream(rust::Box stream) { - return std::unique_ptr(new WrappedRsIOStream(stream)); + return std::unique_ptr(new WrappedRsIOStream(std::move(stream))); } } // namespace taglib_shim \ No newline at end of file diff --git a/musikr/src/main/jni/shim/iostream_shim.hpp b/musikr/src/main/jni/shim/iostream_shim.hpp index 7eb70db04..839be0161 100644 --- a/musikr/src/main/jni/shim/iostream_shim.hpp +++ b/musikr/src/main/jni/shim/iostream_shim.hpp @@ -12,5 +12,5 @@ struct RsIOStream; namespace taglib_shim { // Factory functions with external linkage - std::unique_ptr wrap_RsIOStream(RsIOStream& stream); + std::unique_ptr wrap_RsIOStream(rust::Box stream); } // namespace taglib_shim \ No newline at end of file diff --git a/musikr/src/main/jni/src/jstream.rs b/musikr/src/main/jni/src/jstream.rs index b41bd8119..c9aec1bb9 100644 --- a/musikr/src/main/jni/src/jstream.rs +++ b/musikr/src/main/jni/src/jstream.rs @@ -15,6 +15,86 @@ impl<'local, 'a> JInputStream<'local> { } impl<'local> IOStream for JInputStream<'local> { + fn read_block(&mut self, buf: &mut [u8]) -> usize { + // Create a direct ByteBuffer from the Rust slice + let byte_buffer = unsafe { + self.env + .borrow_mut() + .new_direct_byte_buffer(buf.as_mut_ptr(), buf.len()) + .expect("Failed to create ByteBuffer") + }; + + // Call readBlock safely + let success = self + .env + .borrow_mut() + .call_method( + &self.input, + "readBlock", + "(Ljava/nio/ByteBuffer;)Z", + &[JValue::Object(&byte_buffer)], + ) + .and_then(|result| result.z()) + .expect("Failed to call readBlock"); + + if !success { + return 0; + } + + buf.len() + } + + fn write_block(&mut self, _data: &[u8]) { + panic!("JInputStream is read-only"); + } + + fn seek(&mut self, pos: SeekFrom) { + let (method, offset) = match pos { + SeekFrom::Start(offset) => ("seekFromBeginning", offset as i64), + SeekFrom::Current(offset) => ("seekFromCurrent", offset), + SeekFrom::End(offset) => ("seekFromEnd", offset), + }; + + // Call the appropriate seek method safely + let success = self + .env + .borrow_mut() + .call_method(&self.input, method, "(J)Z", &[JValue::Long(offset)]) + .and_then(|result| result.z()) + .expect("Failed to seek"); + + if !success { + panic!("Failed to seek"); + } + } + + fn truncate(&mut self, _length: i64) { + panic!("JInputStream is read-only"); + } + + fn tell(&self) -> i64 { + let position = self + .env + .borrow_mut() + .call_method(&self.input, "tell", "()J", &[]) + .and_then(|result| result.j()) + .expect("Failed to get position"); + + if position == i64::MIN { + panic!("Failed to get position"); + } + + position + } + + fn length(&self) -> i64 { + self.env + .borrow_mut() + .call_method(&self.input, "length", "()J", &[]) + .and_then(|result| result.j()) + .expect("Failed to get length") + } + fn name(&self) -> String { // Call the Java name() method safely let name = self @@ -35,92 +115,3 @@ impl<'local> IOStream for JInputStream<'local> { true // JInputStream is always read-only } } - -impl<'local> Read for JInputStream<'local> { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - // Create a direct ByteBuffer from the Rust slice - let byte_buffer = unsafe { - self.env - .borrow_mut() - .new_direct_byte_buffer(buf.as_mut_ptr(), buf.len()) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))? - }; - - // Call readBlock safely - let success = self - .env - .borrow_mut() - .call_method( - &self.input, - "readBlock", - "(Ljava/nio/ByteBuffer;)Z", - &[JValue::Object(&byte_buffer)], - ) - .and_then(|result| result.z()) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; - - if !success { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "Failed to read block", - )); - } - - Ok(buf.len()) - } -} - -impl<'local> Write for JInputStream<'local> { - fn write(&mut self, _buf: &[u8]) -> std::io::Result { - Err(std::io::Error::new( - std::io::ErrorKind::PermissionDenied, - "JInputStream is read-only", - )) - } - - fn flush(&mut self) -> std::io::Result<()> { - Ok(()) // Nothing to flush in a read-only stream - } -} - -impl<'local, 'a> Seek for JInputStream<'local> { - fn seek(&mut self, pos: SeekFrom) -> std::io::Result { - let (method, offset) = match pos { - SeekFrom::Start(offset) => ("seekFromBeginning", offset as i64), - SeekFrom::Current(offset) => ("seekFromCurrent", offset), - SeekFrom::End(offset) => ("seekFromEnd", offset), - }; - - // Call the appropriate seek method safely - let success = self - .env - .borrow_mut() - .call_method(&self.input, method, "(J)Z", &[JValue::Long(offset)]) - .and_then(|result| result.z()) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; - - if !success { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "Failed to seek", - )); - } - - // Return current position safely - let position = self - .env - .borrow_mut() - .call_method(&self.input, "tell", "()J", &[]) - .and_then(|result| result.j()) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; - - if position == i64::MIN { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "Failed to get position", - )); - } - - Ok(position as u64) - } -} diff --git a/musikr/src/main/jni/src/taglib/bridge.rs b/musikr/src/main/jni/src/taglib/bridge.rs index a410d178a..f8f8429cc 100644 --- a/musikr/src/main/jni/src/taglib/bridge.rs +++ b/musikr/src/main/jni/src/taglib/bridge.rs @@ -5,16 +5,16 @@ mod bridge_impl { // Expose Rust IOStream to C++ extern "Rust" { #[cxx_name = "RsIOStream"] - type DynIOStream<'a>; + type DynIOStream<'io_stream>; - fn name(self: &mut DynIOStream<'_>) -> String; + fn name(self: &DynIOStream<'_>) -> String; fn read(self: &mut DynIOStream<'_>, buffer: &mut [u8]) -> usize; fn write(self: &mut DynIOStream<'_>, data: &[u8]); fn seek(self: &mut DynIOStream<'_>, offset: i64, whence: i32); fn truncate(self: &mut DynIOStream<'_>, length: i64); - fn tell(self: &mut DynIOStream<'_>) -> i64; - fn length(self: &mut DynIOStream<'_>) -> i64; - fn is_readonly(self: &mut DynIOStream<'_>) -> bool; + fn tell(self: &DynIOStream<'_>) -> i64; + fn length(self: &DynIOStream<'_>) -> i64; + fn is_readonly(self: &DynIOStream<'_>) -> bool; } #[namespace = "taglib_shim"] @@ -42,8 +42,8 @@ mod bridge_impl { #[namespace = "TagLib"] #[cxx_name = "IOStream"] - type CPPIOStream; - fn wrap_RsIOStream(stream: Pin<&mut DynIOStream>) -> UniquePtr; + type CPPIOStream<'io_stream>; + fn wrap_RsIOStream<'io_stream>(stream: Box>) -> UniquePtr>; #[namespace = "TagLib"] #[cxx_name = "FileRef"] diff --git a/musikr/src/main/jni/src/taglib/iostream.rs b/musikr/src/main/jni/src/taglib/iostream.rs index 5a256973f..37b38f31f 100644 --- a/musikr/src/main/jni/src/taglib/iostream.rs +++ b/musikr/src/main/jni/src/taglib/iostream.rs @@ -3,22 +3,26 @@ use cxx::UniquePtr; use std::io::{Read, Seek, SeekFrom, Write}; use std::pin::Pin; -pub trait IOStream: Read + Write + Seek { +pub trait IOStream { + fn read_block(&mut self, buffer: &mut [u8]) -> usize; + fn write_block(&mut self, data: &[u8]); + fn seek(&mut self, pos: SeekFrom); + fn truncate(&mut self, length: i64); + fn tell(&self) -> i64; + fn length(&self) -> i64; fn name(&self) -> String; fn is_readonly(&self) -> bool; } pub(super) struct BridgedIOStream<'io_stream> { - rs_stream: Pin>>, - cpp_stream: UniquePtr, + cpp_stream: UniquePtr>, } impl<'io_stream> BridgedIOStream<'io_stream> { pub fn new(stream: T) -> Self { - let mut rs_stream = Box::pin(DynIOStream(Box::new(stream))); - let cpp_stream = bridge::wrap_RsIOStream(rs_stream.as_mut()); + let rs_stream: Box> = Box::new(DynIOStream(Box::new(stream))); + let cpp_stream: UniquePtr> = bridge::wrap_RsIOStream(rs_stream); BridgedIOStream { - rs_stream, cpp_stream, } } @@ -28,31 +32,21 @@ impl<'io_stream> BridgedIOStream<'io_stream> { } } -impl<'io_stream> Drop for BridgedIOStream<'io_stream> { - fn drop(&mut self) { - unsafe { - // CPP stream references the rust stream, so it must be dropped first - std::ptr::drop_in_place(&mut self.cpp_stream); - std::ptr::drop_in_place(&mut self.rs_stream); - }; - } -} - #[repr(C)] pub(super) struct DynIOStream<'io_stream>(Box); impl<'io_stream> DynIOStream<'io_stream> { // Implement the exposed functions for cxx bridge - pub fn name(&mut self) -> String { + pub fn name(&self) -> String { self.0.name() } pub fn read(&mut self, buffer: &mut [u8]) -> usize { - self.0.read(buffer).unwrap_or(0) + self.0.read_block(buffer) } pub fn write(&mut self, data: &[u8]) { - self.0.write_all(data).unwrap(); + self.0.write_block(data); } pub fn seek(&mut self, offset: i64, whence: i32) { @@ -62,23 +56,19 @@ impl<'io_stream> DynIOStream<'io_stream> { 2 => SeekFrom::End(offset), _ => panic!("Invalid seek whence"), }; - self.0.seek(pos).unwrap(); + self.0.seek(pos); } pub fn truncate(&mut self, length: i64) { - self.0.seek(SeekFrom::Start(length as u64)).unwrap(); - // TODO: Actually implement truncate once we have a better trait bound + self.0.truncate(length); } - pub fn tell(&mut self) -> i64 { - self.0.seek(SeekFrom::Current(0)).unwrap() as i64 + pub fn tell(&self) -> i64 { + self.0.tell() } - pub fn length(&mut self) -> i64 { - let current = self.0.seek(SeekFrom::Current(0)).unwrap(); - let end = self.0.seek(SeekFrom::End(0)).unwrap(); - self.0.seek(SeekFrom::Start(current)).unwrap(); - end as i64 + pub fn length(&self) -> i64 { + self.0.length() } pub fn is_readonly(&self) -> bool {