Calculating the nth prime in Rust, with memoization and thread-safety
By Pedro R. Borges
Going through the Rust track at exercism.org, I found the Nth Prime exercise. The exercise asks for the nth prime given an integer n, counting 2 as the 0th prime. To make it more interesting, I decided to memoize the calculated primes in a table and to make the calculation thread-safe so that several primes can be computed concurrently. The implementation described here is available as a GitHub repo.
The NthPrime type
Since the original exercise requires the implementation of a stand-alone function nth(n: u32) -> u32, in my solution at exercism.org I use a static vector as the table for the memoization, with a fixed size.
For the implementation in this article, I define a struct NthPrime with the memo table and the maximum position to be memoized as fields, and the nth function as a method of NthPrime.
I also changed the type of nth to receive a usize integer and return u64 primes.
I define the struct as follows:
use std::sync::RwLock;

pub struct NthPrime {
    memo_table: RwLock<Vec<u64>>,
    max_pos: usize,
}
To allow the memo table to be used concurrently, I placed it under a RwLock, which allows either any number of concurrent readers or a single exclusive writer to access the vector.
I don’t use any lock for the max_pos field, as it is read-only and set when the struct is created.
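As a minimal sketch of how a RwLock is used, separate from the NthPrime code, the two helpers below take a write lock to modify a vector and a read lock to inspect it. The names append_value and read_last are hypothetical and exist only for this illustration:

```rust
use std::sync::RwLock;

// Hypothetical helper: appends a value while holding the exclusive write lock.
fn append_value(table: &RwLock<Vec<u64>>, value: u64) {
    // The write guard is a temporary, so the lock is released
    // as soon as this statement ends.
    table.write().unwrap().push(value);
}

// Hypothetical helper: copies out the last element while holding a shared read lock.
fn read_last(table: &RwLock<Vec<u64>>) -> Option<u64> {
    let guard = table.read().unwrap();
    guard.last().copied()
}
```

Both calls use unwrap, as the rest of the article does, so a poisoned lock (one whose holder panicked) makes them panic as well.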
Creating the NthPrime struct
To implement the new method that creates an NthPrime we just have to initialize the two fields of the struct.
For the memo table, we need to create a new RwLock with the vector of memoized primes.
The vector itself is created with the first 2 primes, as required by the implementation of the nth function described below.
For this reason, the value of max_pos must be at least 1, and new raises any smaller value to 1.
impl NthPrime {
    pub fn new(max_pos: usize) -> NthPrime {
        NthPrime {
            memo_table: RwLock::new(vec![2, 3]),
            max_pos: usize::max(1, max_pos),
        }
    }
}
The nth method
Let’s now define nth(&self, n: usize) -> u64.
To find the nth prime, I consider 2 main cases:
- The prime can be memoized, that is, n ≤ max_pos. In this case, if the required position on the memo table is not filled, I fill the table up to that position. Then the prime to return is on the table.
- When n > max_pos, the prime cannot be memoized, and it is calculated by nth_not_memoized, after filling the memo table if it was not already filled.
The code for nth is as follows:
pub fn nth(&self, n: usize) -> u64 {
    if n <= self.max_pos {
        if n >= self.memo_table.read().unwrap().len() {
            self.fill_table_to(n);
        }
        return self.memo_table.read().unwrap()[n];
    }
    self.fill_table_to(self.max_pos);
    self.nth_not_memoized(n)
}
Note that to query the memo table we must acquire a read-lock over it, and so we use self.memo_table.read().unwrap() to access it.
Filling the memo table
The task of the fill_table_to method is to find the first prime after the last one on the table, store it, and repeat the process until the required table position is filled.
This is a straightforward process, but we must take into account the possibility of other threads accessing the memo table concurrently.
For example, while we calculate the next prime to store in the table,
another thread may get to do it before us.
If we then push it on the table, it would not be stored in the proper position.
To prevent this, if last_pos was the last position filled in the table when we started to calculate the next prime, we must make sure that the position to fill by a push is still last_pos + 1 before doing the push.
Crucially, both the push and the check before it must be done while holding an exclusive lock on the memo table:
{
    let mut locked_table = self.memo_table.write().unwrap();
    if last_pos + 1 == locked_table.len() {
        locked_table.push(next_prime);
    }
}
Since we need to modify the table, we acquire a write-lock over it instead of a read-lock as before.
Note also that we open an inner block so that locked_table goes out of scope once we are done, and hence the lock is released.
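The check-then-push step can also be seen in isolation. The following is a minimal sketch, assuming a free-standing helper named push_if_still_last (a hypothetical name, not part of NthPrime) that reports whether the push actually happened:

```rust
use std::sync::RwLock;

// Hypothetical helper: pushes `value` only if `last_pos` is still the last
// filled position. Returns true when the push happened, false when another
// writer got there first.
fn push_if_still_last(table: &RwLock<Vec<u64>>, last_pos: usize, value: u64) -> bool {
    // The check and the push both happen under the same write lock,
    // so no other thread can modify the table in between.
    let mut locked_table = table.write().unwrap();
    if last_pos + 1 == locked_table.len() {
        locked_table.push(value);
        true
    } else {
        false
    }
    // The write lock is released here, when `locked_table` is dropped.
}
```

Calling it twice with the same last_pos shows the effect: the first call appends the value, and the second one leaves the table untouched because the table has already grown.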
The definition of fill_table_to is:
fn fill_table_to(&self, end: usize) {
    loop {
        let (last_memoized, last_pos) = self.get_last_and_pos();
        if last_pos >= end {
            break;
        }
        let next_prime = self.prime_from(&[], last_memoized + 2);
        {
            let mut locked_table = self.memo_table.write().unwrap();
            if last_pos + 1 == locked_table.len() {
                locked_table.push(next_prime);
            }
        }
    }
}
Here, I use get_last_and_pos to read the last prime in the table and its position with the table under a lock, so that both values are consistent:
fn get_last_and_pos(&self) -> (u64, usize) {
    let locked_table = self.memo_table.read().unwrap();
    let last_pos = locked_table.len() - 1;
    (locked_table[last_pos], last_pos)
}
Finding the next prime
To find the next prime to memoize, I define prime_from.
This function receives the first candidate to test for primality as its second argument.
It looks for possible factors both in the memo table and in its first argument.
When used in fill_table_to, the possible factors for the candidates are all in the memo table, hence it receives an empty slice as the first argument.
A candidate is prime if it has no factors in the memo table or among the primes received as the first argument.
This is checked by the function has_factor_in, which also receives an upper bound for the possible factors.
We have, then:
fn prime_from(&self, primes_not_in_table: &[u64], start: u64) -> u64 {
    let mut candidate = start;
    loop {
        let highest_possible_factor = (candidate as f64).sqrt().ceil() as u64;
        if !has_factor_in(
            &self.memo_table.read().unwrap()[1..],
            highest_possible_factor,
            candidate,
        ) && !has_factor_in(primes_not_in_table, highest_possible_factor, candidate)
        {
            break;
        }
        candidate += 2;
    }
    candidate
}
Of course, even candidates are not considered, and this is ensured by the requirement that the memo table starts at least with 2 and 3.
Since I only test odd candidates, I exclude position 0 (which contains 2) from the table when I pass it to has_factor_in.
In has_factor_in, we just iterate over the elements on the slice while below the upper limit, and check if any is a factor:
fn has_factor_in(factors: &[u64], limit: u64, x: u64) -> bool {
    factors
        .iter()
        .take_while(|&f| *f <= limit)
        .any(|f| x % f == 0)
}
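To see how the limit argument affects the result, here is a small self-contained sketch. It repeats has_factor_in so that it compiles on its own, and adds a hypothetical helper, is_odd_prime, that applies the same square-root bound used in prime_from to a hand-picked slice of odd primes:

```rust
// Same function as above, repeated so that the sketch is self-contained.
fn has_factor_in(factors: &[u64], limit: u64, x: u64) -> bool {
    factors
        .iter()
        .take_while(|&f| *f <= limit)
        .any(|f| x % f == 0)
}

// Hypothetical helper: tests an odd candidate against a sorted slice of odd
// primes, assuming the slice contains every odd prime up to the square root
// of the candidate.
fn is_odd_prime(odd_primes: &[u64], candidate: u64) -> bool {
    let limit = (candidate as f64).sqrt().ceil() as u64;
    !has_factor_in(odd_primes, limit, candidate)
}
```

With the slice [3, 5, 7] and a limit of 4, has_factor_in finds the factor 3 of 15 but reports no factor for 35, since 5 and 7 lie above the limit. This is why the limit must be at least the square root of the candidate, as it is in is_odd_prime.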
Calculating primes not to be memoized
We now handle the second case considered in the nth method: when the prime to be calculated is not to be stored in the memo table.
This is the task of nth_not_memoized.
This method is similar to fill_table_to: it invokes prime_from to find the next prime, stores it, and repeats if needed.
However, the computed primes are stored in the vector primes_not_in_table, which is not shared with other threads.
Since no other thread contributes to the vector, the process has to be done exactly n - max_pos times, to calculate the nth prime.
We have, then:
fn nth_not_memoized(&self, n: usize) -> u64 {
    let mut last_found = self.memo_table.read().unwrap()[self.max_pos];
    let mut primes_not_in_table = Vec::with_capacity(n - self.max_pos);
    for _ in self.max_pos..n {
        last_found = self.prime_from(&primes_not_in_table, last_found + 2);
        primes_not_in_table.push(last_found);
    }
    last_found
}
This concludes the implementation of NthPrime.
Let’s now use it in some tests.
Testing it
I include three tests in the repo: test_memoized_primes, test_primes_not_memoized, and test_multithreaded_same_as_single_threaded.
The first two are single-threaded.
For them, I just create an NthPrime and check the results of computing some known primes, which in the second test are beyond max_pos and hence not memoized.
Of these, I include here only the second one:
#[test]
fn test_primes_not_memoized() {
    let ntp = NthPrime::new(1_000);
    assert_eq!(ntp.nth(3_000), 27_457);
    assert_eq!(ntp.nth(10_000), 104_743);
    assert_eq!(ntp.nth(78_498), 1_000_003);
}
The third test is more interesting. I first define 3 variables with some parameters:
- The number of threads to be spawned: n_threads. Each thread will compute a prime.
- The max_pos parameter for the NthPrimes created in the test.
- The upper limit for the random arguments used for nth: max_n.
The test then proceeds as follows:
- Generates n_threads random integers n, in the range 0..max_n, and computes nth(n) for each n. The random numbers are stored in random_args and the primes are stored in single_threaded_primes.
- Creates a new NthPrime object and spawns n_threads threads. The threads compute the same primes of the previous step, and they are stored in multi_threaded_primes.
- Checks that the primes calculated in the first and second steps are the same.
Since the threads on the second step share the newly created NthPrime and the random_args filled in the first step, both are wrapped in an Arc:
let ntp = Arc::new(NthPrime::new(max_pos));
let random_args = Arc::new(random_args);
The references to both objects are then cloned to be moved to each thread:
let shared_args = random_args.clone();
let shared_ntp = ntp.clone();
The complete test is then:
use rand::Rng;
use std::sync::Arc;
use std::thread;

#[test]
fn test_multithreaded_same_as_single_threaded() {
    let n_threads = 200;
    let max_pos = 10_000;
    let max_n = 13_000;
    let mut rng = rand::thread_rng();
    // Generate n_threads random_args n and their corresponding nth(n) primes
    let ntp = NthPrime::new(max_pos);
    let mut random_args = Vec::with_capacity(n_threads);
    let mut single_threaded_primes = Vec::with_capacity(n_threads);
    for _ in 0..n_threads {
        let n = rng.gen_range(0..max_n);
        random_args.push(n);
        single_threaded_primes.push(ntp.nth(n));
    }
    // Generate again the nth(n) primes, each in a separate thread
    let ntp = Arc::new(NthPrime::new(max_pos));
    let random_args = Arc::new(random_args);
    let mut multi_threaded_primes = Vec::with_capacity(n_threads);
    let mut handles = Vec::with_capacity(n_threads);
    for i in 0..n_threads {
        let shared_args = random_args.clone();
        let shared_ntp = ntp.clone();
        let handle = thread::spawn(move || shared_ntp.nth(shared_args[i]));
        handles.push(handle);
    }
    // Joining in spawn order keeps the results aligned with random_args
    for handle in handles {
        multi_threaded_primes.push(handle.join().unwrap());
    }
    // Check that multi_threaded_primes are the same as the single_threaded
    for i in 0..n_threads {
        assert_eq!(single_threaded_primes[i], multi_threaded_primes[i]);
    }
}
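The Arc sharing pattern used in this test can also be tried without NthPrime. The sketch below is only an illustration: squares_in_threads is a hypothetical function in which each thread reads one element of a shared, read-only vector and returns its square:

```rust
use std::sync::Arc;
use std::thread;

// Hypothetical example: one thread per element of a shared, read-only vector.
fn squares_in_threads(values: Vec<u64>) -> Vec<u64> {
    let shared = Arc::new(values);
    let handles: Vec<_> = (0..shared.len())
        .map(|i| {
            // Each thread gets its own reference-counted pointer to the vector.
            let shared_values = Arc::clone(&shared);
            thread::spawn(move || shared_values[i] * shared_values[i])
        })
        .collect();
    // Joining in spawn order keeps the results aligned with the inputs.
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```

No lock is needed here because the vector is never modified after it is wrapped in the Arc. NthPrime needs its RwLock only because the memo table grows while it is shared.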
Some comments
In this post, my main interest was to show how to manage the memoization with thread-safety, and how to use NthPrime, as shown in the multithreaded test.
There are, of course, many possible variations on, for example, how to calculate the prime numbers, or the integer sizes to use.
One variation of the technique used here, known as trial division, is to only consider as prime candidates integers of the form 6k±1, as done in this article. Another alternative is to use a sieve algorithm instead of trial division, as in the Primes crate.
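As a rough sketch of the 6k±1 idea, and not the code of the linked article, the functions below step only through candidates of that form. Both names, is_prime_6k and next_prime_6k, are hypothetical, and the primality test itself also uses only 2, 3 and divisors of the form 6k±1:

```rust
// Trial division using only 2, 3 and divisors of the form 6k ± 1.
fn is_prime_6k(x: u64) -> bool {
    if x < 2 {
        return false;
    }
    if x < 4 {
        return true; // 2 and 3
    }
    if x % 2 == 0 || x % 3 == 0 {
        return false;
    }
    let mut f = 5;
    // `f <= x / f` is the same test as `f * f <= x`, without overflow.
    while f <= x / f {
        if x % f == 0 || x % (f + 2) == 0 {
            return false;
        }
        f += 6;
    }
    true
}

// Hypothetical helper: the first prime greater than or equal to `start`,
// testing only candidates of the form 6k ± 1 once past 2 and 3.
fn next_prime_6k(start: u64) -> u64 {
    if start <= 2 {
        return 2;
    }
    if start == 3 {
        return 3;
    }
    // Move to the first candidate that is 1 or 5 modulo 6.
    let mut candidate = start;
    while candidate % 6 != 1 && candidate % 6 != 5 {
        candidate += 1;
    }
    loop {
        if is_prime_6k(candidate) {
            return candidate;
        }
        // From 6k - 1 the next candidate is 6k + 1 (+2);
        // from 6k + 1 it is 6(k + 1) - 1 (+4).
        candidate += if candidate % 6 == 5 { 2 } else { 4 };
    }
}
```

Compared with stepping through every odd number, this skips the odd multiples of 3, which is one third of the odd candidates.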
The implementation I describe here for nth_not_memoized may generate a spike in memory allocation to store the vector primes_not_in_table, which is just an extension of memo_table, discarded when the function terminates.
We could avoid storing these extra primes and consider as possible factors all the odd numbers after the last prime in the memo table.
This might not be as costly as it sounds, given the upper bound for the possible factors.
For example, the 1,000,000th prime is 15,485,867,
and to check if it is prime, the highest possible factor to consider is 3,936, which is less than the 546th prime.
Then, if max_pos is set to 546, computing primes up to the 1,000,000th would not actually check for factors outside the memo table.
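The arithmetic in this example, and the alternative of testing odd numbers instead of stored primes, can be sketched as follows. The function highest_possible_factor just names the bound already computed inside prime_from, while has_odd_factor_after is a hypothetical helper that is not part of the implementation described in this article:

```rust
// The same bound used in prime_from: the ceiling of the square root.
fn highest_possible_factor(candidate: u64) -> u64 {
    (candidate as f64).sqrt().ceil() as u64
}

// Hypothetical alternative to storing extra primes: try every odd number
// greater than the last memoized prime, up to `limit`.
// Assumes `last_memoized` is odd, so that `last_memoized + 2` is odd too.
fn has_odd_factor_after(last_memoized: u64, limit: u64, x: u64) -> bool {
    (last_memoized + 2..=limit).step_by(2).any(|f| x % f == 0)
}
```

For 15,485,867 the bound is 3,936, since 3,935² = 15,484,225 and 3,936² = 15,492,096. When the bound is not greater than the last memoized prime, the range in has_odd_factor_after is empty and no extra divisions are made.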
In the original version of this post, I stated that to calculate the nth prime you needed to calculate all the previous primes. This was wrong, as someone pointed out in a comment on Reddit. Using the Prime number theorem, we can obtain a range for the nth prime. This is Consequence two of the theorem in this article, which also discusses techniques for more accurate estimates. You might also be interested in The Nth Prime Page.
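To give an idea of what such a range looks like, here is a minimal sketch that uses two classical bounds from the literature, which are not necessarily the ones given in the linked article: for n ≥ 6, with primes counted from 1 (so that 2 is the 1st prime), n(ln n + ln ln n − 1) < p_n < n(ln n + ln ln n). The function name nth_prime_bounds is hypothetical:

```rust
// Bounds for the nth prime p_n, counting from 1 (p_1 = 2), valid for n >= 6:
//   n (ln n + ln ln n - 1) < p_n < n (ln n + ln ln n)
// Note the indexing: the article's nth(n) corresponds to p_(n + 1) here.
fn nth_prime_bounds(n: u64) -> (f64, f64) {
    let n = n as f64;
    let log_term = n.ln() + n.ln().ln();
    (n * (log_term - 1.0), n * log_term)
}
```

For instance, nth(10_000) in this article is the 10,001st prime counting from 1, and its value, 104,743, falls inside the range returned by nth_prime_bounds(10_001). A range like this could be combined with a sieve over the interval, instead of computing all the previous primes.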
Any thoughts you would like to share? Feel free to leave a comment! And thanks for reading!