
Find slice element position in Rust, fast!

I started learning Rust only recently, and while exploring slice methods I was a bit surprised that I didn’t find any method for finding the position of an element in a slice:

fn find(haystack: &[u8], needle: u8) -> Option<usize> { ... }

I had some experience with Zig, which has a very useful std.mem module with many generic functions, including indexOf, which internally implements the Boyer-Moore-Horspool pattern matching algorithm over a generic element type T:

fn indexOf(comptime T: type, haystack: []const T, needle: []const T) ?usize { ... }

After discussing this with Rust experts, I quickly got the response that I can just use methods of the Iterator trait:

fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack.iter().position(|&x| x == needle)
}

Nice! But what about the performance of this method? At first, I was afraid that using a closure would lead to poor performance (coming from Go, whose non-LLVM-based compiler has pretty limited inlining optimization). But, unsurprisingly for most developers, LLVM (and Rust) can optimize this method very nicely, and rustc produces a very clean binary with the -C opt-level=3 -C target-cpu=native release profile flags:

# input : rdi=haystack.ptr, rsi=haystack.size, rdx=needle
# output: rax=None/Some, rdx=Some(v)
example::find:
        test    rsi, rsi # if haystack.len() == 0
        je      .LBB0_1  #   return None
        mov     ecx, edx # b = needle
        xor     eax, eax # result = None
        xor     edx, edx # i = 0
.LBB0_3:                                 # loop {
        cmp     byte ptr [rdi + rdx], cl #   if haystack[i] == b
        je      .LBB0_4                  #     return Option::Some(i)
        inc     rdx                      #   i++
        cmp     rsi, rdx                 #   if haystack.len() != i
        jne     .LBB0_3                  #     continue
        mov     rdx, rsi                 #   i = haystack.len()
        ret                              #   return None
.LBB0_1:                                 # }
        xor     edx, edx
        xor     eax, eax
        ret
.LBB0_4:
        mov     eax, 1
        ret

Can we improve the method’s performance?

Implementing find without early returns

We can notice that for other iterator methods like filter, the compiler will use SSE / AVX instructions (if the target CPU supports them). So what is preventing the compiler from using SIMD instructions for the position method? Within our team we came to the conclusion that the position implementation returns early, which makes it harder for LLVM to apply SIMD (although I have no proof of that).
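
For contrast, here is a minimal reduction with no early exit – counting matches via filter – which LLVM happily vectorizes with the same flags (the function name and shape are mine, just for illustration):

pub fn count_matches(haystack: &[u8], needle: u8) -> usize {
    // no early return: the trip count depends only on the slice length,
    // so LLVM can compare bytes in wide SIMD registers and sum the mask
    haystack.iter().filter(|&&b| b == needle).count()
}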

We can assume that the compiler is able to vectorize a function if it has a predictable amount of operations (static, or a simple function of input properties like the slice length). How can we achieve that for the position function? Actually, there is a nice way to implement it without a break in the middle of the loop: we just need to process the slice in reverse order! Then we can simply reassign the result variable whenever we find a matching element:

pub fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> {
    let mut position = None::<usize>;
    for (i, &b) in haystack.iter().enumerate().rev() {
        if b == needle {
            position = Some(i);
        }
    }
    position
}

Unfortunately, this doesn’t help – there are still no SIMD instructions in the output assembly. But wait, we can notice drastic changes in the output binary – now it seems like the compiler unrolled our main loop and compares bytes in chunks of size 8:

# this is just part of the assembly; you can find the full output at the godbolt link
.LBB0_11:
        cmp     byte ptr [r8 + r11 - 1], dl
        lea     r14, [rsi + r11 - 1]
        lea     r15, [rsi + r11 - 7]
        cmovne  r14, rcx
        cmove   rax, rbx
        cmp     byte ptr [r8 + r11 - 2], dl
        lea     rcx, [rsi + r11 - 2]
        cmovne  rcx, r14
        cmove   rax, rbx
        ...
        cmp     byte ptr [r8 + r11 - 8], dl
        lea     rcx, [rsi + r11 - 8]
        cmovne  rcx, r15
        cmove   rax, rbx

That looks promising! Unrolling helps performance by itself, and we may be on the right path to successfully guiding the compiler towards vectorization!

Vectorization by any means!

At this point, I had no clue how to simplify the compiler’s life, except for one last thing – we can make the slice length constant and hope that this will finally activate the vectorization engine in LLVM. Turns out that this was enough! If we use [u8; 16] or [u8; 32] types for the input arguments, LLVM will use 128-bit or 256-bit SSE / AVX registers and the corresponding instructions!
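
The body stays exactly the same – only the argument type changes (shown here with [u8; 16], which produces the 128-bit xmm version below):

pub fn find_branchless(haystack: &[u8; 16], needle: u8) -> Option<usize> {
    let mut position = None::<usize>;
    for (i, &b) in haystack.iter().enumerate().rev() {
        if b == needle {
            position = Some(i);
        }
    }
    position
}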

example::find_branchless:
        vmovd   xmm0, esi
        vpbroadcastb    xmm0, xmm0
        vpcmpeqb        xmm0, xmm0, xmmword ptr [rdi]
        vpextrb eax, xmm0, 14
        vpextrb ecx, xmm0, 13
        vpextrb edx, xmm0, 10
        and     eax, 1
        xor     rax, 15
        test    cl, 1
        mov     ecx, 13
        cmove   rcx, rax
        vpextrb eax, xmm0, 12
        # ... there are a lot of instructions for determining actual position of matched byte ...
        vpmovmskb       esi, xmm0
        cmove   rdx, rcx
        xor     eax, eax
        test    esi, esi
        setne   al
        ret

You can notice that the compiler generates truly branch-less code (literally, no jump instructions!). This can be surprising at first sight, but the compiler makes use of the cmove (“conditional move”) instruction, which moves a value between operands only if the flags register is in a specific state. Because it never suffers a branch misprediction, this instruction performs much better than an ordinary CMP / JEQ pair and allows implementing simple conditional scenarios like the one in our branch-less find.
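
As an aside, you don’t need SIMD to see cmove in action – even a plain conditional select typically compiles to it at -C opt-level=3 (a tiny standalone example; the function is mine):

pub fn select(cond: bool, a: u64, b: u64) -> u64 {
    // with optimizations enabled this usually becomes `test` + `cmove`,
    // not a conditional jump
    if cond { a } else { b }
}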

Vectorized version of find

OK, it’s great that we finally forced rustc to use vectorization. But the current find implementation is barely usable because it works only for byte arrays of fixed size! What can we do about that?

Here the plan is straightforward – we can split the input slice into chunks of bounded size and apply our branch-less implementation of find to each of them. Rust has a nice chunks method which does exactly what we want; let’s try to use it:

fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> { ... }
pub fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack
        .chunks(32)
        .enumerate()
        // map the in-chunk offset back to an index in the whole slice
        .find_map(|(i, chunk)| find_branchless(chunk, needle).map(|x| 32 * i + x))
}

Unfortunately, this doesn’t work – the compiler again produces boring assembly with only the unrolling optimization applied. But if we stop and think about it, this is actually expected! The chunking logic makes every chunk unpredictable in size, because there are no guarantees about the exact size of the last chunk (and every chunk can be the last one!).
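
We can see the problem directly – chunks yields a shorter final chunk whenever the length is not a multiple of the chunk size (a small illustration with a made-up 70-byte slice):

let data = [0u8; 70];
let lengths: Vec<usize> = data.chunks(32).map(|c| c.len()).collect();
assert_eq!(lengths, [32, 32, 6]); // the tail breaks the “every chunk is 32 bytes” invariant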

Luckily, the Rust developer team thought about this and added the chunks_exact method specifically for such cases! This method splits the slice into equally sized chunks and provides access to the potentially smaller tail through an additional method: remainder.
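
With chunks_exact, the same 70-byte slice splits into two guaranteed 32-byte chunks plus an explicit 6-byte tail (a quick sketch):

let data = [0u8; 70];
let chunks = data.chunks_exact(32);
assert_eq!(chunks.remainder().len(), 6); // 70 = 2 * 32 + 6
assert_eq!(chunks.count(), 2);           // every yielded chunk is exactly 32 bytes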

This final step allows us to make our dream come true: a vectorized find function in only safe Rust!

// bonus: a refactoring of find_branchless to make it more elegant!
fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> {
    // rfold walks the iterator from the back, so the accumulator ends up
    // holding the front-most match – the same trick as the reversed loop above
    haystack.iter().enumerate()
        .filter(|(_, &b)| b == needle)
        .rfold(None, |_, (i, _)| Some(i))
}

fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    let chunks = haystack.chunks_exact(32);
    let remainder = chunks.remainder();
    chunks.enumerate()
        .find_map(|(i, chunk)| find_branchless(chunk, needle).map(|x| 32 * i + x))
        // `len & !0x1f` rounds the length down to a multiple of 32,
        // i.e. the offset in the original slice where the remainder starts
        .or(find_branchless(remainder, needle).map(|x| (haystack.len() & !0x1f) + x))
}

Benchmarks

The full benchmark source code is available here: ./rust-find-bench
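
The real harness lives in the repository above; just to give a feel for its shape, a minimal criterion-style sketch might look like this (the crate choice, haystack size, and names here are my assumptions, not the exact benchmark code):

use criterion::{black_box, criterion_group, criterion_main, Criterion};

// assumes the `find` function from this post is in scope
fn bench_find(c: &mut Criterion) {
    // the needle is absent, so the whole slice has to be scanned (worst case)
    let haystack = vec![0u8; 1024];
    c.bench_function("find_chunks_exact_branchless", |b| {
        b.iter(|| find(black_box(&haystack), black_box(1)))
    });
}

criterion_group!(benches, bench_find);
criterion_main!(benches);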

method                          time         speedup
find_naive                      366.06 ns    x1.0
find_chunks                     414.06 ns    x0.9
find_chunks_exact               133.53 ns    x2.7
find_chunks_exact_branchless     40.48 ns    x9.0