
Mocking in Rust

·1103 words·6 mins
Author
Jake Treacher

Overview

This post details my journey of mocking in Rust.

The team was new to Rust. We own a low-latency system that performs actions based on the customer’s group. However, the necessary customer details were only available from a separate service, and any network request would be too costly in terms of latency. Since customer groups rarely change, we decided to periodically download a file for each group containing all the associated users, allowing us to perform the check locally. I worked on this implementation.

The initial code looked something like this:

use aws_sdk_s3::Client as S3Client;

// Result<bool> represents whether a new file was downloaded
pub async fn sync(
    customer_client: Arc<Mutex<CustomerClient>>,
    s3_client: Arc<S3Client>,
    data_dir: &PathBuf,
    bucket: &str,
    prefix: &str,
) -> anyhow::Result<bool> {
    let s3_get_object_response = get_latest_object(
        &s3_client,
        bucket,
        prefix,
    )
    .await?;

    let e_tag = s3_get_object_response
        .e_tag()
        .ok_or_else(|| anyhow!("Unable to parse e_tag from response"))?
        .to_owned();

    let update_available = customer_client.lock().e_tag().ne(&e_tag);
    if !update_available {
        return Ok(false);
    }

    let file = NamedTempFile::new_in(data_dir)?;
    stream_data_to_disk(
        s3_get_object_response,
        file.as_file(),
    )
    .await?;

    customer_client.lock().update(e_tag, file);

    Ok(true)
}

Before marking our code as complete, we should ensure it is properly tested. There’s a general sentiment that people don’t like testing, but I believe the issue isn’t with testing itself but rather with how difficult it is to set up the test. That is the challenge we face now: this code is difficult to test. All our test cases depend on the result of get_latest_object(), which relies on an S3Client to make a network request.

In Rust, we don’t have the luxury of sophisticated dependency injection frameworks like those available in Java. However, there are still mechanisms we can use to simplify our testing experience.

Testing Approaches

Test Double

Rust’s configuration predicates allow us to change imports based on whether we’re in a test environment. By creating a duplicate struct, we can import the test double during testing, enabling us to use predetermined responses.

Although this provides us with the ability to mock, it does come with some drawbacks:

  • This solution works by overriding the implementation of a function, get_latest_object. We can only override this once, thus limiting us to a single test scenario.
  • We must import mock::S3Client whenever we import S3Client. This results in excessive use of #[cfg(not(test))], which I find to be tedious and annoying.
// s3_client.rs

use aws_sdk_s3::output::GetObjectOutput;

pub struct S3Client {
    client: aws_sdk_s3::Client,
}

impl S3Client {
    // async, because callers `.await` the result
    pub async fn get_latest_object(
        &self,
        bucket: &str,
        prefix: &str,
    ) -> anyhow::Result<GetObjectOutput> {
        // make network requests using aws_sdk_s3::Client
    }
}

#[cfg(test)]
pub mod mock {
    use aws_sdk_s3::output::GetObjectOutput;

    // Duplicate of the real struct: same name, same members, and same
    // method signatures, so callers compile against either one.
    pub struct S3Client {
        pub client: aws_sdk_s3::Client,
    }

    impl S3Client {
        pub async fn get_latest_object(
            &self,
            _bucket: &str,
            _prefix: &str,
        ) -> anyhow::Result<GetObjectOutput> {
            // create generic test scenario
            let get_object_output = GetObjectOutput::builder()
                .set_body(Some(TEST_DATA))
                .set_e_tag(Some(String::from(TEST_ETAG)))
                .build();
            Ok(get_object_output)
        }
    }
}
// sync.rs

#[cfg(not(test))]
use dependency::s3_client::S3Client;
#[cfg(test)]
use dependency::s3_client::mock::S3Client;

pub async fn sync(
    customer_client: Arc<Mutex<CustomerClient>>,
    s3_client: Arc<S3Client>,
    ...
) -> anyhow::Result<bool> {
    let s3_get_object_response = s3_client
        .get_latest_object(bucket, prefix)
        .await;
    // ...
}

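#[cfg(test)]
mod tests {
    use super::sync;
    use dependency::s3_client::mock::S3Client;
    use std::sync::Arc;

    // Assumes the tokio runtime; any async test harness would do.
    #[tokio::test]
    async fn test_sync_happy_case() {
        let client = aws_sdk_s3::Client::new(...);
        // sync() expects an Arc<S3Client>
        let s3_client = Arc::new(S3Client { client });

        // ...

        let is_updated = sync(
            customer_client,
            s3_client,
            ...
        )
        .await
        .unwrap();

        // ...
    }
}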
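To see the import swap in isolation, here is a minimal, std-only sketch. The `real` and `mock` modules, the `latest_e_tag` method, and both return values are hypothetical stand-ins of mine rather than anything from the AWS SDK, and the "real" client returns a placeholder instead of making a network call.

```rust
// Hypothetical stand-in for the production client (no network call here).
#[cfg(not(test))]
mod real {
    pub struct S3Client;

    impl S3Client {
        pub fn latest_e_tag(&self) -> String {
            // A real client would ask S3; this sketch returns a placeholder.
            String::from("etag-from-s3")
        }
    }
}

// Test double: same struct name and method signature as `real::S3Client`.
#[cfg(test)]
mod mock {
    pub struct S3Client;

    impl S3Client {
        pub fn latest_e_tag(&self) -> String {
            // The single canned response that every test has to share.
            String::from("etag-from-test-double")
        }
    }
}

// Exactly one of these imports survives conditional compilation.
#[cfg(test)]
use mock::S3Client;
#[cfg(not(test))]
use real::S3Client;

/// Returns true when the remote object differs from the local copy.
pub fn update_available(client: &S3Client, current_e_tag: &str) -> bool {
    client.latest_e_tag() != current_e_tag
}
```

Because the canned value is compiled into the double, every test sees the same response, which is the first drawback listed above.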

Trait

To provide greater flexibility, we wrap the function-to-be-mocked within a trait. In the previous scenario, the struct needed to have the same members and functions, otherwise the compiler would complain. By implementing a trait, we no longer have this restriction. This enables the test struct to contain the desired response, providing flexibility to the scenarios we want to test.

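// s3_client.rs

// Assumed dependency: the async-trait crate, which allows async fns
// to be called through `dyn S3Client`.
use async_trait::async_trait;
use aws_sdk_s3::output::GetObjectOutput;

pub struct DefaultS3Client {
    client: aws_sdk_s3::Client,
}

#[async_trait]
pub trait S3Client: Send + Sync {
    async fn get_latest_object(
        &self,
        bucket: &str,
        prefix: &str,
    ) -> anyhow::Result<GetObjectOutput>;
}

#[async_trait]
impl S3Client for DefaultS3Client {
    async fn get_latest_object(
        &self,
        bucket: &str,
        prefix: &str,
    ) -> anyhow::Result<GetObjectOutput> {
        // make network requests using aws_sdk_s3::Client
    }
}

#[cfg(test)]
pub mod mock {
    use super::*;
    use std::sync::Mutex;

    pub struct MockS3Client {
        // anyhow::Result is not Clone, so the canned response is stored
        // in an Option and handed out exactly once.
        pub response: Mutex<Option<anyhow::Result<GetObjectOutput>>>,
    }

    #[async_trait]
    impl S3Client for MockS3Client {
        async fn get_latest_object(
            &self,
            _bucket: &str,
            _prefix: &str,
        ) -> anyhow::Result<GetObjectOutput> {
            self.response
                .lock()
                .unwrap()
                .take()
                .expect("mock response already consumed")
        }
    }
}
// sync.rs

use dependency::s3_client::S3Client;

pub async fn sync(
    customer_client: Arc<Mutex<CustomerClient>>,
    s3_client: Arc<dyn S3Client>,
    ...
) -> anyhow::Result<bool> {
    let s3_get_object_response = s3_client
        .get_latest_object(bucket, prefix)
        .await;
    // ...
}

#[cfg(test)]
mod tests {
    use super::sync;
    use aws_sdk_s3::output::GetObjectOutput;
    use dependency::s3_client::mock::MockS3Client;
    use std::sync::Arc;

    // Assumes the tokio runtime; any async test harness would do.
    #[tokio::test]
    async fn test_sync_happy_case() {
        let get_object_output = GetObjectOutput::builder()
            .set_body(Some(TEST_DATA))
            .set_e_tag(Some(String::from(
                TEST_ETAG,
            )))
            .build();
        let response = Ok(get_object_output);
        // Arc<MockS3Client> coerces to Arc<dyn S3Client> at the call site
        let s3_client = Arc::new(MockS3Client {
            response: std::sync::Mutex::new(Some(response)),
        });

        // ...

        let is_updated = sync(
            customer_client,
            s3_client,
            ...
        )
        .await
        .unwrap();

        // ...
    }
}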
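For a version that runs without the AWS SDK, here is a rough, synchronous sketch of the same trait seam. `Object`, `ObjectStore`, `MockObjectStore`, and the `String` error type are simplified, hypothetical stand-ins for `GetObjectOutput`, `S3Client`, `MockS3Client`, and `anyhow::Error`; the real client is async, which this sketch deliberately leaves out.

```rust
use std::sync::Arc;

/// Hypothetical, simplified stand-in for `GetObjectOutput`.
#[derive(Clone, Debug, PartialEq)]
pub struct Object {
    pub e_tag: String,
    pub body: Vec<u8>,
}

/// The seam: production and test code both depend on this trait.
pub trait ObjectStore {
    fn get_latest_object(&self, bucket: &str, prefix: &str) -> Result<Object, String>;
}

/// Test implementation that returns whatever response it was built with.
pub struct MockObjectStore {
    pub response: Result<Object, String>,
}

impl ObjectStore for MockObjectStore {
    fn get_latest_object(&self, _bucket: &str, _prefix: &str) -> Result<Object, String> {
        // Cloning works here because both Object and String are Clone.
        self.response.clone()
    }
}

/// Returns Ok(true) when the store holds a newer object than `current_e_tag`.
pub fn sync(store: Arc<dyn ObjectStore>, current_e_tag: &str) -> Result<bool, String> {
    let object = store.get_latest_object("bucket", "prefix")?;
    Ok(object.e_tag != current_e_tag)
}
```

Each test builds its own `MockObjectStore`, so the happy path, the "no update" path, and the error path can all be exercised without touching the production struct.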

Closure

The previous example requires boilerplate code to extract the trait and configure a struct to contain the response. An alternative approach is to extract all relevant network calls into a separate function (i.e. get_latest_object()) and then wrap that function within a closure.

In this instance, the closure’s return type needs its own generic parameter, R, because the closure returns a future. Future is a trait rather than a concrete type, so we can’t write F: Fn() -> Future<Output = anyhow::Result<GetObjectOutput>> directly.

// sync.rs

pub async fn sync(
    customer_client: Arc<Mutex<CustomerClient>>,
    s3_client: Arc<S3Client>,
    data_dir: &PathBuf,
    bucket: &str,
    prefix: &str,
) -> anyhow::Result<bool> {
    let closure = || async {
        get_latest_object(
            Arc::clone(&s3_client),
            bucket,
            prefix,
        )
        .await
    };

    sync_inner(customer_client, data_dir, closure).await
}

async fn sync_inner<F, R>(
    customer_client: Arc<Mutex<CustomerClient>>,
    data_dir: &PathBuf,
    get_latest_object_closure: F,
) -> anyhow::Result<bool>
where
    F: Fn() -> R,
    R: Future<Output = anyhow::Result<GetObjectOutput>>,
{
    let s3_get_object_response =
        get_latest_object_closure().await?;
    // ...
}

#[cfg(test)]
mod tests {
    #[tokio::test] // assumes the tokio runtime
    async fn test_sync_happy_case() {
        let closure = || async {
            let get_object_output = GetObjectOutput::builder()
                .set_body(Some(TEST_DATA))
                .set_e_tag(Some(String::from(
                    TEST_ETAG,
                )))
                .build();
            Ok(get_object_output)
        };

        // ...

        let is_updated = sync_inner(
            customer_client,
            data_dir,
            closure,
        )
        .await
        .unwrap();

        // ...
    }
}
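The two-generic signature can be tried out without the AWS SDK or an async runtime. The sketch below is only an approximation: `FetchResult` (a bare ETag string), the `String` error type, and the tiny `block_on` helper are hypothetical stand-ins of mine for `anyhow::Result<GetObjectOutput>` and a real executor such as tokio.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical stand-in for the S3 response: just the object's ETag.
pub type FetchResult = Result<String, String>;

/// `F` is the closure type; `R` is the concrete future type it returns.
pub async fn sync_inner<F, R>(current_e_tag: &str, fetch_latest: F) -> Result<bool, String>
where
    F: Fn() -> R,
    R: Future<Output = FetchResult>,
{
    let latest_e_tag = fetch_latest().await?;
    Ok(latest_e_tag != current_e_tag)
}

/// Waker that does nothing; enough for futures that never actually wait.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

/// Tiny busy-polling executor so the sketch runs without an async runtime.
pub fn block_on<T>(future: impl Future<Output = T>) -> T {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut context = Context::from_waker(&waker);
    let mut future = pin!(future);
    loop {
        if let Poll::Ready(output) = future.as_mut().poll(&mut context) {
            return output;
        }
    }
}
```

A test then passes something like `|| async { Ok::<String, String>(String::from("v2")) }` as `fetch_latest` and drives the returned future with `block_on`.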

No Mocking

I felt like a badass when I got the Closure approach working. But it’s an example of over-engineering. In this specific scenario, it’s not necessary to pass in a closure: we can fetch the response in the containing function and just pass in the result.

// sync.rs

pub async fn sync(
    customer_client: Arc<Mutex<CustomerClient>>,
    s3_client: Arc<S3Client>,
    data_dir: &PathBuf,
    bucket: &str,
    prefix: &str,
) -> anyhow::Result<bool> {
    let s3_get_object_response = get_latest_object(
        Arc::clone(&s3_client),
        bucket,
        prefix,
    )
    .await;

    sync_inner(customer_client, data_dir, s3_get_object_response)
        .await
}

async fn sync_inner(
    customer_client: Arc<Mutex<CustomerClient>>,
    data_dir: &PathBuf,
    s3_get_object_response: anyhow::Result<GetObjectOutput>,
) -> anyhow::Result<bool> {
    // ...
}

#[cfg(test)]
mod tests {
    #[tokio::test] // assumes the tokio runtime
    async fn test_sync_happy_case() {
        let s3_get_object_response = Ok(
            GetObjectOutput::builder()
                .set_body(Some(TEST_DATA))
                .set_e_tag(Some(String::from(TEST_ETAG)))
                .build(),
        );

        // ...

        let is_updated = sync_inner(
            customer_client,
            data_dir,
            s3_get_object_response,
        )
        .await
        .unwrap();

        // ...
    }
}
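Stripped of the AWS types, the idea is simply that the inner function takes data instead of a client. Here is a small std-only sketch, where `Object` and the `String` error type are hypothetical stand-ins for `GetObjectOutput` and `anyhow::Error`, and the file download step is left out:

```rust
/// Hypothetical, simplified stand-in for `GetObjectOutput`.
#[derive(Debug)]
pub struct Object {
    pub e_tag: Option<String>,
    pub body: Vec<u8>,
}

/// Decision logic only: the response was already fetched by the caller.
/// Returns Ok(true) when the object is newer than `current_e_tag`.
pub fn sync_inner(current_e_tag: &str, response: Result<Object, String>) -> Result<bool, String> {
    // Propagate a failed fetch unchanged.
    let object = response?;
    let e_tag = object
        .e_tag
        .ok_or_else(|| String::from("Unable to parse e_tag from response"))?;
    Ok(e_tag != current_e_tag)
}
```

With this shape, a failed fetch or a response without an ETag is just another value the test constructs by hand.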

Conclusion

This journey led me to a solution that, in hindsight, seems obvious. Coming from a Java background, my instinct was to mock everything. However, I’ve learned that mocking isn’t always the best approach. Admittedly, we miss out on some line coverage for our S3 network requests, but integration tests can make up for this shortfall.