Skip to content

Commit

Permalink
Add callback support to FileDescription
Browse files Browse the repository at this point in the history
    - Implementing atomic reads for contiguous and scattered buffers

Signed-off-by: shamb0 <[email protected]>
  • Loading branch information
shamb0 committed Dec 26, 2024
1 parent 81afa3e commit 88cbd55
Show file tree
Hide file tree
Showing 7 changed files with 1,219 additions and 22 deletions.
18 changes: 16 additions & 2 deletions src/concurrency/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,23 @@ pub enum MachineCallbackState {
TimedOut,
}

/// Generic callback trait for machine operations.
/// A generic callback trait for asynchronous machine operations.
///
/// # Type Parameters
/// * `'tcx`: The typing context lifetime
///
/// Callbacks are executed while holding mutable access to the interpreter
/// context, so they must maintain interpreter invariants.
pub trait MachineCallback<'tcx>: VisitProvenance {
/// Called when the operation completes (successfully or with timeout).
/// Called when an operation completes, either successfully or due to timeout.
///
/// # Arguments
/// * `self`: Owned callback, boxed to allow for dynamic dispatch
/// * `ecx`: Mutable interpreter context
/// * `result`: Operation completion state
///
/// # Returns
/// Result indicating if callback execution succeeded
fn call(
self: Box<Self>,
ecx: &mut InterpCx<'tcx, MiriMachine<'tcx>>,
Expand Down
177 changes: 177 additions & 0 deletions src/shims/files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub trait FileDescription: std::fmt::Debug + Any {
/// Reads as much as possible into the given buffer `ptr`.
/// `len` indicates how many bytes we should try to read.
/// `dest` is where the return value should be stored: number of bytes read, or `-1` in case of error.
#[allow(dead_code)]
fn read<'tcx>(
&self,
_self_ref: &FileDescriptionRef,
Expand All @@ -29,6 +30,29 @@ pub trait FileDescription: std::fmt::Debug + Any {
throw_unsup_format!("cannot read from {}", self.name());
}

/// Performs an atomic read operation on the file.
///
/// # Arguments
/// * `self_ref` - Strong reference to file description for lifetime management
/// * `communicate_allowed` - Whether external communication is permitted
/// * `op` - The I/O operation containing buffer and layout information
/// * `dest` - Destination for storing operation results
/// * `ecx` - Mutable reference to interpreter context
///
/// # Returns
/// * `Ok(())` on successful read
/// * `Err(_)` if read fails or is unsupported
fn read_atomic<'tcx>(
&self,
_self_ref: &FileDescriptionRef,
_communicate_allowed: bool,
_op: &mut IoTransferOperation<'tcx>,
_dest: &MPlaceTy<'tcx>,
_ecx: &mut MiriInterpCx<'tcx>,
) -> InterpResult<'tcx> {
throw_unsup_format!("cannot read from {}", self.name());
}

/// Writes as much as possible from the given buffer `ptr`.
/// `len` indicates how many bytes we should try to write.
/// `dest` is where the return value should be stored: number of bytes written, or `-1` in case of error.
Expand Down Expand Up @@ -409,3 +433,156 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
interp_ok(())
}
}

/// Represents an atomic I/O operation that handles data transfer between memory regions.
/// Supports both contiguous and scattered memory layouts for efficient I/O operations.
#[derive(Clone)]
pub struct IoTransferOperation<'tcx> {
/// Intermediate buffer for atomic transfer operations.
/// For reads: Temporary storage before distribution to destinations
/// For writes: Aggregation point before writing to file
transfer_buffer: Vec<u8>,

/// Memory layout specification for the transfer operation.
layout: IoBufferLayout,

/// Total number of bytes to be processed in this operation.
total_size: usize,

/// Interpreter context lifetime marker.
_phantom: std::marker::PhantomData<&'tcx ()>,
}

/// Specifies how memory regions are organized for I/O operations
#[derive(Clone)]
enum IoBufferLayout {
/// Single continuous memory region for transfer.
Contiguous { address: Pointer },
/// Multiple discontinuous memory regions.
Scattered { regions: Vec<(Pointer, usize)> },
}

impl VisitProvenance for IoTransferOperation<'_> {
fn visit_provenance(&self, _visit: &mut VisitWith<'_>) {
// Visits any references that need provenance tracking.
// Currently a no-op as IoTransferOperation contains no such references.
}
}

impl<'tcx> IoTransferOperation<'tcx> {
/// Creates a new I/O operation for a contiguous memory region.
pub fn new_contiguous(ptr: Pointer, len: usize) -> Self {
IoTransferOperation {
transfer_buffer: vec![0; len],
layout: IoBufferLayout::Contiguous { address: ptr },
total_size: len,
_phantom: std::marker::PhantomData,
}
}

/// Creates a new I/O operation for scattered memory regions.
pub fn new_scattered(buffers: Vec<(Pointer, usize)>) -> Self {
let total_size = buffers.iter().map(|(_, len)| len).sum();
IoTransferOperation {
transfer_buffer: vec![0; total_size],
layout: IoBufferLayout::Scattered { regions: buffers },
total_size,
_phantom: std::marker::PhantomData,
}
}

/// Provides mutable access to the transfer buffer.
pub fn buffer_mut(&mut self) -> &mut [u8] {
&mut self.transfer_buffer
}

/// Distributes data from the transfer buffer to final destinations.
pub fn distribute_data(
&mut self,
ecx: &mut MiriInterpCx<'tcx>,
dest: &MPlaceTy<'tcx>,
bytes_processed: usize,
) -> InterpResult<'tcx> {
if bytes_processed > self.total_size {
return ecx.set_last_error_and_return(LibcError("EINVAL"), dest);
}

match &self.layout {
IoBufferLayout::Contiguous { address } => {
// POSIX Compliance: Verify buffer accessibility before writing
if ecx
.check_ptr_access(
*address,
Size::from_bytes(bytes_processed),
CheckInAllocMsg::MemoryAccessTest,
)
.report_err()
.is_err()
{
return ecx.set_last_error_and_return(LibcError("EFAULT"), dest);
}

// Attempt the write operation
if ecx
.write_bytes_ptr(
*address,
self.transfer_buffer[..bytes_processed].iter().copied(),
)
.report_err()
.is_err()
{
return ecx.set_last_error_and_return(LibcError("EIO"), dest);
}
}

IoBufferLayout::Scattered { regions } => {
let mut current_pos = 0;

for (ptr, len) in regions {
if current_pos >= bytes_processed {
break;
}

// Calculate copy size with safe arithmetic
let remaining_bytes = bytes_processed
.checked_sub(current_pos)
.expect("current_pos should never exceed bytes_read");
let copy_size = (*len).min(remaining_bytes);

// POSIX Compliance: Verify each buffer's accessibility
if ecx
.check_ptr_access(
*ptr,
Size::from_bytes(copy_size),
CheckInAllocMsg::MemoryAccessTest,
)
.report_err()
.is_err()
{
return ecx.set_last_error_and_return(LibcError("EFAULT"), dest);
}

let end_pos = current_pos
.checked_add(copy_size)
.expect("end position calculation should not overflow");

// Attempt the write operation with proper error handling
if ecx
.write_bytes_ptr(
*ptr,
self.transfer_buffer[current_pos..end_pos].iter().copied(),
)
.report_err()
.is_err()
{
return ecx.set_last_error_and_return(LibcError("EIO"), dest);
}

current_pos = end_pos;
}
}
}

interp_ok(())
}
}
Loading

0 comments on commit 88cbd55

Please sign in to comment.