naming is hard

Find element’s position in Rust – 9 times faster!

When I started learning Rust and exploring slice methods, I was surprised that Rust doesn’t provide a built-in method for finding the position of an element in a slice.

Of course, this is not a big problem, as anyone can implement a simple and concise one-liner for that purpose, like the following:

// for simplicity, I will use u8 type instead of generics throughout the post
fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack.iter().position(|&x| x == needle)
}
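
For example, a quick sanity check of this one-liner (a minimal sketch):

fn main() {
    assert_eq!(find(&[1, 2, 3], 2), Some(1)); // first occurrence, zero-based index
    assert_eq!(find(&[1, 2, 3], 4), None);    // needle absent
}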

This is really nice, but what about the performance of this implementation? It turns out that while the performance of this method is good, you can make the implementation 9 times faster by restructuring the code to help the LLVM compiler backend vectorize it.

Let’s dig a bit deeper into how we can achieve this performance boost for such a simple task.

First version

First of all, let’s look at the output that rustc produces (godbolt output, compiled with the -C opt-level=3 -C target-cpu=native flags in the release profile):

# input : rdi=haystack.ptr, rsi=haystack.size, rdx=needle
# output: rax=None/Some, rdx=Some(v)
example::find:
        test    rsi, rsi # if haystack.len() == 0
        je      .LBB0_1  #   return None
        mov     ecx, edx # b = needle
        xor     eax, eax # result = None
        xor     edx, edx # i = 0
.LBB0_3:                                 # loop {
        cmp     byte ptr [rdi + rdx], cl #   if haystack[i] == b
        je      .LBB0_4                  #     return Option::Some(i)
        inc     rdx                      #   i++
        cmp     rsi, rdx                 #   if haystack.len() != i
        jne     .LBB0_3                  #     continue
        mov     rdx, rsi                 #   i = haystack.len()
        ret                              #   return None
.LBB0_1:                                 # }
        xor     edx, edx
        xor     eax, eax
        ret
.LBB0_4:
        mov     eax, 1
        ret

While the assembler code is very clean and does the job, it lacks one important component – vectorization! For such a simple operation over the primitive u8 type, I would expect LLVM to come up with some clever vectorized version for a performance boost. For example, consider the compilation result for the haystack.iter().sum() one-liner:
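
Here is a minimal sketch of that one-liner as a standalone function (keeping the u8 element type used throughout the post):

pub fn sum(haystack: &[u8]) -> u8 {
    // note: u8 addition wraps silently in release builds (overflow checks are off)
    haystack.iter().sum()
}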

# core vectorized loop; whole listing omitted for brevity
.LBB0_7:
        vpaddb  ymm0, ymm0, ymmword ptr [rdi + rax]
        vpaddb  ymm1, ymm1, ymmword ptr [rdi + rax + 32]
        vpaddb  ymm2, ymm2, ymmword ptr [rdi + rax + 64]
        vpaddb  ymm3, ymm3, ymmword ptr [rdi + rax + 96]
        sub     rax, -128
        cmp     rcx, rax
        jne     .LBB0_7

You can notice that here the compiler decided to use the YMM registers, which are 256-bit wide AVX registers. So, in the case of u8 elements (256 / 8 = 32 per register), a single vpaddb operation will simultaneously process 32 elements of our slice and can potentially speed up the sum() method by a factor of 32!

So, what is preventing LLVM from using vectorization in our implementation of the find method? We can make a guess, or use LLVM optimization remarks, which will hint at the problems with our implementation:

$> rustc main.rs -C opt-level=3 -C target-cpu=native -Cdebuginfo=full \
    -Cllvm-args=-pass-remarks-analysis='.*' \
    -Cllvm-args=-pass-remarks-missed='.*' \
    -Cllvm-args=-pass-remarks='.*'
remark: /.../iter/macros.rs:361:24: loop not vectorized: Cannot vectorize potentially faulting early exit loop
remark: /.../iter/macros.rs:361:24: loop not vectorized
remark: /.../iter/macros.rs:0:0: Cannot SLP vectorize list: vectorization was impossible with available vectorization factors
remark: /.../iter/macros.rs:370:14: Cannot SLP vectorize list: only 2 elements of buildvalue, trying reduction first.
remark: /.../iter/macros.rs:0:0: Cannot SLP vectorize list: vectorization was impossible with available vectorization factors

The remarks listed above highlight three important things:

  1. LLVM has two types of vectorization optimizations: LoopVectorizer and SLPVectorizer
  2. LoopVectorizer was unable to optimize the loop because it has an early return (“Cannot vectorize potentially faulting early exit loop”)
  3. SLPVectorizer was unable to optimize the loop either, for some other reason

Let’s try to fix LoopVectorizer’s complaint and implement the find method without an early return!

To make things clear: the early return appears in our code because, under the hood, our one-liner is transformed into the following form:

fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    for (i, &b) in haystack.iter().enumerate() {
        if b == needle {
            return Some(i) // <- here, the early return!
        }
    }
    None
}

Implementing find without early returns

As our implementation of the find method needs to locate the first occurrence of the element, it has to return as soon as it encounters the element during the natural left-to-right slice traversal. So, with natural-order enumeration, an early return is inevitable. But what if we traverse the slice elements in reverse order, from right to left? In that case we need to find the “last” occurrence of the element (according to the enumeration order), and we no longer need an early return!

Note that right-to-left enumeration is pretty inefficient, as it needs to traverse the whole slice to find the first occurrence, while the original find implementation finishes much faster when the first needle occurrence appears close to the start of the slice. We will fix that later, but for now let’s focus on removing the early return to please the vectorization optimizers.

So, the implementation of the find method without early returns looks like this:

pub fn find_no_early_returns(haystack: &[u8], needle: u8) -> Option<usize> {
    let mut position = None;
    for (i, &b) in haystack.iter().enumerate().rev() {
        if b == needle {
            position = Some(i);
        }
    }
    position
}
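
A quick check that the reversed version agrees with the original one (both report the first occurrence, because in reverse order the last assignment wins):

fn main() {
    let haystack = [5u8, 3, 7, 3, 9];
    assert_eq!(find(&haystack, 3), Some(1));
    assert_eq!(find_no_early_returns(&haystack, 3), Some(1)); // overwrites Some(3), then Some(1)
    assert_eq!(find_no_early_returns(&haystack, 4), None);
}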

Unfortunately, this doesn’t help – there are still no SIMD instructions in the output assembly. But we can notice drastic changes in the output binary, and also in the remarks produced by LLVM – the compiler has now unrolled our main loop and compares bytes in chunks of eight:

# this is just a part of the assembly; you can find the full output via the godbolt link
.LBB0_11:
        cmp     byte ptr [r8 + r11 - 1], dl
        lea     r14, [rsi + r11 - 1]
        cmovne  r14, rcx
        lea     rcx, [rsi + r11 - 2]
        cmove   rax, rbx
        cmp     byte ptr [r8 + r11 - 2], dl
        cmovne  rcx, r14
        lea     r14, [rsi + r11 - 3]
        cmove   rax, rbx
        ...
        cmove   rax, rbx
        cmp     byte ptr [r8 + r11 - 8], dl
        cmovne  rcx, r14
        cmove   rax, rbx

You can notice that the compiler generates truly branchless code for the unrolled section (literally, no jump instructions!). This can be surprising at first sight, but the compiler actually makes use of the cmove (“conditional move if equal”) instruction, which moves a value between operands only when the flags register is in a specific state. This instruction performs much better than an ordinary CMP/JE pair and can be used in simple conditional scenarios like the one in our no-early-return implementation of the find function.
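
To see the pattern in isolation, here is the kind of simple conditional that typically lowers to a cmov-family instruction (a sketch; the exact output depends on compiler version and flags):

pub fn select(cond: bool, a: u64, b: u64) -> u64 {
    // with optimizations on, LLVM usually emits test + cmov here instead of a branch
    if cond { a } else { b }
}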

And the LLVM remarks are now different:

$> rustc main.rs -C opt-level=3 -C target-cpu=native -Cdebuginfo=full \
    -Cllvm-args=-pass-remarks-analysis='.*' \
    -Cllvm-args=-pass-remarks-missed='.*' \
    -Cllvm-args=-pass-remarks='.*'
remark: /.../slice/iter/macros.rs:25:86: loop not vectorized: value that could not be identified as reduction is used outside the loop
remark: /.../slice/iter/macros.rs:25:86: loop not vectorized
remark: main.rs:0:0: Cannot SLP vectorize list: vectorization was impossible with available vectorization factors
remark: main.rs:0:0: List vectorization was possible but not beneficial with cost 0 >= 0
remark: main.rs:0:0: List vectorization was possible but not beneficial with cost 0 >= 0
remark: main.rs:17:2: Cannot SLP vectorize list: only 2 elements of buildvalue, trying reduction first.
remark: main.rs:0:0: List vectorization was possible but not beneficial with cost 0 >= 0
remark: :0:0: Cannot SLP vectorize list: vectorization was impossible with available vectorization factors
remark: :0:0: List vectorization was possible but not beneficial with cost 0 >= 0
remark: :0:0: List vectorization was possible but not beneficial with cost 0 >= 0
remark: /.../slice/iter/macros.rs:25:86: unrolled loop by a factor of 8 with run-time trip count

Here, we see that LoopVectorizer was still unable to optimize the loop, due to the external variable (position) used in it. However, SLPVectorizer kicked in and decided not to optimize the loop, even though it was possible (so it actually found a way to vectorize the loop, hooray!).

It’s a bit strange that LoopVectorizer was unable to detect the reduction variable in our case because, for arithmetic operations, it usually works well. But maybe LoopVectorizer expects nice mathematical properties from the operation (like commutativity), which the assignment operator obviously doesn’t have.
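
For contrast, here is a reduction built around a commutative operation (integer max), the kind LoopVectorizer normally recognizes. This is just an illustrative sketch; I haven’t benchmarked it for this post:

pub fn max_byte(haystack: &[u8]) -> u8 {
    // a commutative, order-independent reduction – no early exit, no positional state
    haystack.iter().fold(0, |acc, &b| acc.max(b))
}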

Activate SLPVectorizer!

Under the hood, the SLPVectorizer “combines similar independent instructions into vector instructions”, and it’s not obvious what it can do in our case, because all our operations already sit inside the loop and there are no more repetitions to merge.

The SLPVectorizer example from the LLVM documentation looks like this:

void foo(int a1, int a2, int b1, int b2, int *A) {
  A[0] = a1*(a1 + b1);
  A[1] = a2*(a2 + b2);
  A[2] = a1*(a1 + b1);
  A[3] = a2*(a2 + b2);
}

But given that the loop-unroll optimization was triggered in our case, it seems natural that, after unrolling, SLPVectorizer can fold a bunch of repeated actions into vectorized code. What we can do here is make it more explicit that unrolling is beneficial in our case by using a fixed-size array reference &[u8; 32] instead of an arbitrary slice. And this works!

// only signature changed - rest of the code is the same
pub fn find_no_early_returns(haystack: &[u8; 32], needle: u8) -> Option<usize> { ... }
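
Spelled out in full, the fixed-size variant is just the earlier body behind the new signature:

pub fn find_no_early_returns(haystack: &[u8; 32], needle: u8) -> Option<usize> {
    let mut position = None;
    for (i, &b) in haystack.iter().enumerate().rev() {
        if b == needle {
            position = Some(i);
        }
    }
    position
}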

For the signature above we will finally see vectorized code in the compiler output:

example::find_no_early_returns:
        vmovd   xmm0, esi
        vpbroadcastb    ymm0, xmm0
        vpcmpeqb        ymm0, ymm0, ymmword ptr [rdi]
        vpmovmskb       ecx, ymm0
        mov     eax, ecx
        shr     eax, 30
        and     eax, 1
        xor     rax, 31
        test    ecx, 536870912
        mov     edx, 29
        cmove   rdx, rax
        ... # there are a lot of instructions for determining actual position of matched byte
        test    cl, 2
        mov     esi, 1
        cmove   rsi, rax
        xor     edx, edx
        test    cl, 1
        cmove   rdx, rsi
        xor     eax, eax
        test    ecx, ecx
        setne   al
        vzeroupper
        ret

Also, SLPVectorizer reported successful optimization of the loop in the compiler remarks:

$> rustc main.rs -C opt-level=3 -C target-cpu=native -Cdebuginfo=full \
    -Cllvm-args=-pass-remarks-analysis='.*' \
    -Cllvm-args=-pass-remarks-missed='.*' \
    -Cllvm-args=-pass-remarks='.*'
remark: main.rs:17:2: Cannot SLP vectorize list: only 2 elements of buildvalue, trying reduction first.
remark: main.rs:12:12: Vectorized horizontal reduction with cost -21 and with tree size 3

Ok, the victory seems very close! Now we need to overcome the fact that, for SLPVectorizer to kick in, we have to use fixed-size arrays.

Final vectorized version of find method

As we know that SLPVectorizer kicks in when the slice has a fixed size, we should split our original slice into fixed-size chunks and then process them independently with the no-early-return method. Our first attempt can be to use the chunks method, which does exactly what we want, so let’s try it out.

While the chunks method splits the slice into sub-slices of unknown size from the language perspective (they are still &[u8] – not &[u8; N]), we can hope that LLVM’s aggressive inlining will help us here, and the optimizer will understand that the slices are fixed-size after all the inlining has happened under the hood.

fn find_no_early_returns(haystack: &[u8], needle: u8) -> Option<usize> { ... }
pub fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack
        .chunks(32)
        .enumerate()
        .find_map(|(i, chunk)| find_no_early_returns(chunk, needle).map(|x| 32 * i + x))
}

Unfortunately, this doesn’t work – the compiler again produces boring assembly with only the unrolling optimization applied. But if we stop and think about it a bit, this is actually expected behaviour! The chunks method treats all chunks uniformly, including the last one – which can contain fewer than 32 elements, so the chunk size is not a compile-time constant!

Luckily, the Rust developers thought about this and added the chunks_exact method specifically for such cases! This method splits the slice into equally sized chunks and provides access to the potentially smaller tail through an additional method: remainder.
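
A small illustration of how chunks_exact splits a slice and where the tail goes (this mirrors the example from the standard library documentation):

fn main() {
    let data = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10];
    let mut chunks = data.chunks_exact(4);
    assert_eq!(chunks.next().unwrap(), &[1, 2, 3, 4]); // only full chunks are yielded...
    assert_eq!(chunks.next().unwrap(), &[5, 6, 7, 8]);
    assert!(chunks.next().is_none());
    assert_eq!(chunks.remainder(), &[9, 10]);          // ...the short tail lives here
}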

This final step allows us to make our dream come true: a vectorized find function using only safe Rust!

// bonus: a refactoring of the find_no_early_returns function to make it more elegant!
// rfold walks right-to-left, so the final accumulator value is the leftmost match
fn find_no_early_returns(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack.iter().enumerate()
        .filter(|(_, &b)| b == needle)
        .rfold(None, |_, (i, _)| Some(i))
}

fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    let chunks = haystack.chunks_exact(32);
    let remainder = chunks.remainder();
    chunks.enumerate()
        .find_map(|(i, chunk)| find_no_early_returns(chunk, needle).map(|x| 32 * i + x))
        // or_else keeps the remainder scan lazy: it runs only when no full chunk matched
        .or_else(|| find_no_early_returns(remainder, needle).map(|x| (haystack.len() & !0x1f) + x))
}
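
A quick sanity check covering both the chunked path and the remainder path (a minimal sketch):

fn main() {
    let mut haystack = vec![0u8; 300]; // 9 full chunks of 32 bytes, remainder of 12
    haystack[257] = 42;                // lands inside a full chunk: 32 * 8 + 1 = 257
    haystack[290] = 7;                 // lands inside the remainder (offsets 288..300)
    assert_eq!(find(&haystack, 42), Some(257));
    assert_eq!(find(&haystack, 7), Some(290));
    assert_eq!(find(&haystack, 1), None);
}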

Conclusion and benchmarks

While LLVM is good at auto-vectorization, sometimes you need to push it a bit to make your case more obvious to the optimization passes. And here the Rust standard library, combined with LLVM’s aggressive inlining, lets you make non-trivial structural changes in the code while keeping it completely safe (I imagine the chunks_exact method can be very helpful in various other cases).

Also, it was interesting to see how LLVM compiler hints can be exposed to the developer for a deeper understanding of the compiler’s behavior (it turned out this is very easy to do, as rustc allows you to pass additional arguments to the LLVM backend).

And as a final result, the last implementation of the find method is 9 times faster than the naive implementation presented at the beginning of the article. You can find the full benchmark source code here: rust-find-bench

  • find_chunks_exact_no_early_return – the no-early-return version with the chunks_exact method
  • find_chunks_exact – the naive version with the chunks_exact method
  • find_naive – the naive version from the beginning of the article
  • find_chunks – the naive version with the chunks method
method                               time       speedup
find_chunks_exact_no_early_return    40.18ns    x9.0
find_chunks_exact                    126.77ns   x2.7
find_naive                           356.07ns   x1.0
find_chunks                          510.16ns   x0.7