How to use mutable references as arguments for generic functions in Rust


I am practicing the strategy pattern by sorting a vector with a function selected at runtime. Ideally, the data would be changed in place through a mutable reference.

I tried adding &mut in multiple places, but could not figure out how to make my code work with mutable references.

Here is my code:

use rand::Rng;

#[derive(Clone)]
pub struct Sorter<F: Fn(Vec<T>) -> Vec<T>, T: Ord + Eq + Copy> {
    sort_strategy: F,
    sort_data: Vec<T>,
}

impl<F: Fn(Vec<T>) -> Vec<T>, T: Ord + Eq + Copy> Sorter<F, T> {
    // return self with sorted data
    // how do i remove the return self and make it a mutable reference to self
    pub fn sort(mut self) -> Self {
        self.sort_data = (self.sort_strategy)(self.sort_data);
        self
    }

    pub fn new(sort_strategy: F, sort_data: Vec<T>) -> Sorter<F, T> {
        Sorter {
            sort_strategy,
            sort_data,
        }
    }
}

// bubble sort function
pub fn bubble_sort<T: Eq + PartialOrd + Copy>(input_vector: Vec<T>) -> Vec<T> {
    let mut output_vector = input_vector.clone();
    for _ in 0..output_vector.len() {
        for j in 0..output_vector.len().saturating_sub(1) {
            if output_vector[j] > output_vector[j + 1] {
                output_vector.swap(j + 1, j);
            }
        }
    }
    output_vector
}

// quick sort function
pub fn quick_sort<T: Eq + PartialOrd + Copy>(input_vector: Vec<T>) -> Vec<T> {
    if input_vector.len() <= 1 {
        return input_vector;
    }

    let pivot = rand::thread_rng().gen_range(0..input_vector.len());
    let pivot_val = input_vector[pivot];
    let mut little_vector: Vec<T> = Vec::new();
    let mut big_vector: Vec<T> = Vec::new();

    for i in input_vector.iter().enumerate() {
        if i.0 == pivot {
            continue;
        }
        if *(i.1) > pivot_val {
            big_vector.push(*(i.1));
            continue;
        }
        if *(i.1) <= pivot_val {
            little_vector.push(*(i.1))
        }
    }

    little_vector = quick_sort(little_vector);
    little_vector.push(pivot_val);
    quick_sort(big_vector)
        .iter()
        .for_each(|n| little_vector.push(*n));

    little_vector
}

#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;

    #[test]
    // test that bubble_sort is functional
    fn test_bubble_sort() {
        let data = vec![5, 4, 3, 2, 1];
        let result = bubble_sort(data);
        assert_eq!(result, vec![1, 2, 3, 4, 5])
    }

    #[test]
    // test that quick_sort is functional
    fn test_quick_sort() {
        let data = vec![5, 4, 3, 2, 1];
        let result = quick_sort(data);
        assert_eq!(result, vec![1, 2, 3, 4, 5])
    }

    #[test]
    // test that bubble_sort works in struct Sorter
    fn test_stratergy_pattern_bubble() {
        let sorter = Sorter::new(bubble_sort, vec![5, 4, 3, 2, 1]);
        let sorter = sorter.sort();
        assert_eq!(sorter.sort_data, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    // test that quick_sort works in struct Sorter
    fn test_stratergy_pattern_quick() {
        let sorter = Sorter::new(quick_sort, vec![5, 4, 3, 2, 1]);
        let sorter = sorter.sort();
        assert_eq!(sorter.sort_data, vec![1, 2, 3, 4, 5]);
    }
}

You can also find it on the Rust playground.

Answer by user4815162342:

As a first step, "remove the return self and make it a mutable reference to self" means we'd like Sorter::sort() to look like this:

pub fn sort(&mut self) {
    self.sort_data = (self.sort_strategy)(self.sort_data);
}

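Doing that fails to compile with a "cannot move out of ... which is behind a mutable reference" error. The expression (self.sort_strategy)(self.sort_data) tries to move the vector out of self.sort_data and pass it to a function that expects an owned value. With an owned self, as in your original sort(mut self), the borrow checker accepts such a partial move and lets you assign the field again afterwards. Through &mut self, however, the value belongs to someone else and must stay valid at every point, so even a temporary move out of the field is rejected; if sort_strategy panicked, for example, the owner would be left with a struct whose field had been moved out.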
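
To make the difference concrete, here is a minimal sketch using a hypothetical Holder struct and a hypothetical by-value function double_all standing in for sort_strategy. The owned version compiles because the compiler tracks the move and the re-initialization of the field; the &mut version has to leave a valid value behind, which is what std::mem::take (described next) does.

```rust
// Hypothetical stand-in for `sort_strategy`: takes and returns an owned Vec.
pub fn double_all(v: Vec<i32>) -> Vec<i32> {
    v.into_iter().map(|x| x * 2).collect()
}

// Hypothetical struct with a single owned field.
pub struct Holder {
    data: Vec<i32>,
}

// With an owned `self`, moving the field out and assigning it back is fine:
// the compiler sees that `s.data` is moved and then re-initialized.
pub fn update_owned(mut s: Holder) -> Holder {
    s.data = double_all(s.data);
    s
}

// Through `&mut`, the field may never be left moved-out, so an empty Vec is
// swapped in first and the real data is handed to the function by value.
pub fn update_borrowed(s: &mut Holder) {
    // s.data = double_all(s.data); // error: cannot move out of `s.data`
    s.data = double_all(std::mem::take(&mut s.data));
}
```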

If you don't want to change the signature of sort_strategy, the easiest fix is to temporarily replace self.sort_data with an empty vector, i.e. use std::mem::replace(&mut self.sort_data, vec![]). For types that implement Default, such as Vec, there is a std::mem::take() function that replaces the value with its default and returns the old one. An empty Vec does not allocate, so the swap is cheap. Then sort() looks like this:

pub fn sort(&mut self) {
    self.sort_data = (self.sort_strategy)(std::mem::take(&mut self.sort_data));
}

After trivially changing the tests so that they no longer capture the result of sort(), your tests pass. (Playground)
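
Put together, the Sorter from the question with only sort() changed might look like the sketch below. The closure used in the test is a hypothetical stand-in for bubble_sort or quick_sort (it just calls the standard library's Vec::sort) so that the example stays self-contained; the point is the call pattern: `let mut sorter`, then `sorter.sort();` with nothing captured.

```rust
// Sorter from the question; only `sort` changed to take `&mut self`.
// The strategy still takes and returns an owned Vec.
#[derive(Clone)]
pub struct Sorter<F: Fn(Vec<T>) -> Vec<T>, T: Ord + Eq + Copy> {
    sort_strategy: F,
    sort_data: Vec<T>,
}

impl<F: Fn(Vec<T>) -> Vec<T>, T: Ord + Eq + Copy> Sorter<F, T> {
    pub fn new(sort_strategy: F, sort_data: Vec<T>) -> Sorter<F, T> {
        Sorter {
            sort_strategy,
            sort_data,
        }
    }

    // Swap an empty Vec into `sort_data`, give the real data to the
    // strategy by value, then store the sorted result back.
    pub fn sort(&mut self) {
        self.sort_data = (self.sort_strategy)(std::mem::take(&mut self.sort_data));
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_strategy_pattern_in_place() {
        // Hypothetical strategy: sorts its owned Vec and returns it.
        let mut sorter = Sorter::new(
            |mut v: Vec<i32>| {
                v.sort();
                v
            },
            vec![5, 4, 3, 2, 1],
        );
        sorter.sort(); // no `let sorter = sorter.sort();` any more
        assert_eq!(sorter.sort_data, vec![1, 2, 3, 4, 5]);
    }
}
```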

The next step is to be consistent and adjust the sort functions to also work in-place, i.e. change the bound from Fn(Vec<T>) -> Vec<T> to Fn(&mut Vec<T>) or, better yet, Fn(&mut [T]) (because you don't need to grow or shrink the vector, you can sort the slice directly). In that case Sorter will look like this:

#[derive(Clone)]
pub struct Sorter<F, T> {
    sort_strategy: F,
    sort_data: Vec<T>,
}

impl<F: Fn(&mut [T]), T: Ord + Eq> Sorter<F, T> {
    pub fn sort(&mut self) {
        (self.sort_strategy)(&mut self.sort_data);
    }

    // new is unchanged
}
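
Since the question's goal is to pick the strategy at runtime, note that every function item (bubble_sort::<i32>, quick_sort::<i32>, ...) has its own unique type, so a runtime choice needs a common type for F. One option, sketched here, is the function-pointer type fn(&mut [T]). The strategy functions stable_strategy and unstable_strategy, and the selector choose_strategy, are hypothetical placeholders that just delegate to the standard library's slice sorts; you would substitute your own in-place sorts.

```rust
#[derive(Clone)]
pub struct Sorter<F, T> {
    sort_strategy: F,
    sort_data: Vec<T>,
}

impl<F: Fn(&mut [T]), T: Ord> Sorter<F, T> {
    pub fn new(sort_strategy: F, sort_data: Vec<T>) -> Sorter<F, T> {
        Sorter {
            sort_strategy,
            sort_data,
        }
    }

    // Sorts the stored data in place; `&mut Vec<T>` coerces to `&mut [T]`.
    pub fn sort(&mut self) {
        (self.sort_strategy)(&mut self.sort_data);
    }
}

// Hypothetical placeholder strategies.
pub fn stable_strategy<T: Ord>(v: &mut [T]) {
    v.sort();
}

pub fn unstable_strategy<T: Ord>(v: &mut [T]) {
    v.sort_unstable();
}

// Hypothetical runtime selector: both function items coerce to the same
// function-pointer type, so either can be returned.
pub fn choose_strategy<T: Ord>(stable: bool) -> fn(&mut [T]) {
    if stable {
        stable_strategy::<T>
    } else {
        unstable_strategy::<T>
    }
}
```

Usage would be along the lines of `let mut s = Sorter::new(choose_strategy::<i32>(flag), data); s.sort();`. If strategies need to capture state (closures), Box<dyn Fn(&mut [T])> is the usual alternative to a plain function pointer.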

Sorting functions like bubble_sort() no longer return a value; they take &mut [T] and sort in place instead of allocating, which is also how the standard library defines its sorting methods such as slice::sort() and slice::sort_unstable(). In your case:

pub fn bubble_sort<T: Eq + Ord>(vector: &mut [T]) {
    for _ in 0..vector.len() {
        // saturating_sub avoids underflow (and a panic) on an empty slice
        for j in 0..vector.len().saturating_sub(1) {
            if vector[j] > vector[j + 1] {
                vector.swap(j + 1, j);
            }
        }
    }
}

Playground
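
To illustrate why &mut [T] is a more flexible parameter than &mut Vec<T>, here is the in-place bubble_sort again (with the inner range written as len().saturating_sub(1) so that an empty slice does not underflow). The same function accepts a whole Vec, a sub-range of one, or a plain array.

```rust
// In-place bubble sort over any mutable slice.
pub fn bubble_sort<T: Ord>(slice: &mut [T]) {
    for _ in 0..slice.len() {
        // `saturating_sub(1)` yields 0 for an empty slice instead of underflowing.
        for j in 0..slice.len().saturating_sub(1) {
            if slice[j] > slice[j + 1] {
                slice.swap(j, j + 1);
            }
        }
    }
}
```

For example, `bubble_sort(&mut v)` sorts a whole Vec, while `bubble_sort(&mut v[1..4])` sorts only those three elements and leaves the rest untouched.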

(quick_sort() takes a bit more work to rewrite like this, so I skipped it.)
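
For completeness, one possible in-place quick_sort is sketched below; it is not the answerer's code. It swaps the question's random pivot for the Lomuto partition scheme with the middle element as pivot, so it needs neither the rand crate nor T: Copy. The partition helper is a hypothetical name. Like any simple Lomuto variant, it degrades to quadratic time on inputs with many equal elements.

```rust
// In-place quicksort using Lomuto partitioning.
pub fn quick_sort<T: Ord>(slice: &mut [T]) {
    if slice.len() <= 1 {
        return;
    }
    let p = partition(slice);
    // `split_at_mut` hands out two non-overlapping mutable halves,
    // so both sides can be sorted recursively without cloning.
    let (left, right) = slice.split_at_mut(p);
    quick_sort(left);
    quick_sort(&mut right[1..]); // right[0] is the pivot, already in place
}

// Moves the pivot (initially the middle element) to its final position and
// returns that index. Afterwards everything left of it is <= the pivot and
// everything right of it is > the pivot.
fn partition<T: Ord>(slice: &mut [T]) -> usize {
    let last = slice.len() - 1;
    let mid = slice.len() / 2;
    slice.swap(mid, last);
    let mut store = 0;
    for i in 0..last {
        if slice[i] <= slice[last] {
            slice.swap(i, store);
            store += 1;
        }
    }
    slice.swap(store, last);
    store
}
```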

Side notes:

  • you don't need T: Copy; swapping moves values rather than copying them.
  • require T: Ord rather than T: PartialOrd, because sorts depend on total order, which is the guarantee provided by Ord.
  • even in the design where the sorter takes and returns a Vec, a sorter like bubble_sort() naturally sorts in place. In that case it can work on the input vector and return it; there is no reason to clone it into an "output" vector.
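
Applying all three side notes to the original by-value design gives something like the sketch below: the function takes ownership of the Vec, sorts it in place, and returns the same Vec. It requires only T: Ord, so it also works for non-Copy types such as String.

```rust
// By-value variant: no clone, no `T: Copy`, and `Ord` instead of `PartialOrd`.
pub fn bubble_sort<T: Ord>(mut vector: Vec<T>) -> Vec<T> {
    for _ in 0..vector.len() {
        // `saturating_sub(1)` keeps an empty Vec from underflowing the range.
        for j in 0..vector.len().saturating_sub(1) {
            if vector[j] > vector[j + 1] {
                vector.swap(j, j + 1);
            }
        }
    }
    vector // the same allocation that was passed in
}
```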