I started to learn Rust only recently, and while exploring slice methods I was a bit surprised that I didn’t find any method for locating the position of an element in a slice:
fn find(haystack: &[u8], needle: u8) -> Option<usize> { ... }
I had some experience with Zig, which has a very useful std.mem module with many generic functions, including indexOf, which internally implements the Boyer-Moore-Horspool pattern-matching algorithm over a generic element type T:
fn indexOf(comptime T: type, haystack: []const T, needle: []const T) ?usize { ... }
After discussing this with Rust experts, I quickly got the response that I can just use the methods of the Iterator trait:
fn find(haystack: &[u8], needle: u8) -> Option<usize> {
haystack.iter().position(|&x| x == needle)
}
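For a quick sanity check, this is how the helper behaves (the example values are my own):
fn main() {
    let haystack = b"hello, world";
    assert_eq!(find(haystack, b'w'), Some(7)); // position of the first match
    assert_eq!(find(haystack, b'z'), None);    // no match at all
}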
Nice! But what about the performance of this method? At first, I was afraid that using a closure would lead to poor performance (coming from Go with its non-LLVM-based compiler, which has fairly limited inlining optimization). But, unsurprisingly for most developers, LLVM (and Rust) can optimize this method very nicely, and rustc produces a very clean binary with the -C opt-level=3 -C target-cpu=native release profile flags:
# input : rdi=haystack.ptr, rsi=haystack.size, rdx=needle
# output: rax=None/Some, rdx=Some(v)
example::find:
test rsi, rsi # if haystack.len() == 0
je .LBB0_1 # return None
mov ecx, edx # b = needle
xor eax, eax # result = None
xor edx, edx # i = 0
.LBB0_3: # loop {
cmp byte ptr [rdi + rdx], cl # if haystack[i] == b
je .LBB0_4 # return Option::Some(i)
inc rdx # i++
cmp rsi, rdx # if haystack.len() != i
jne .LBB0_3 # continue
mov rdx, rsi # i = haystack.len()
ret # return None
.LBB0_1: # }
xor edx, edx
xor eax, eax
ret
.LBB0_4:
mov eax, 1
ret
Can we improve the method’s performance?
find without early returns

We can notice that for other iterator methods like filter, the compiler will use SSE / AVX instructions (if the target CPU supports them). So what is preventing the compiler from using SIMD instructions for the position method? Within our team, we came to the conclusion that the position implementation returns early, which makes it harder for LLVM to apply SIMD (although I have no proof of that).
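For contrast, here is a reduction without an early exit that the compiler does vectorize, in line with the filter observation above (a minimal sketch; count_matches is my own name, and codegen depends on target flags):
pub fn count_matches(haystack: &[u8], needle: u8) -> usize {
    // No early exit: the trip count depends only on the slice length,
    // which gives LLVM a predictable loop to vectorize.
    haystack.iter().filter(|&&b| b == needle).count()
}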
We can assume that the compiler will be able to vectorize a function if it performs a predictable number of operations (either static, or a simple function of input properties such as the slice length). How can we achieve that for the position function? Actually, there is a nice way to implement it without a break in the middle of the loop: we just need to process the slice in reverse order! Then we can simply reassign the result variable whenever we find a matching element:
pub fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> {
let mut position = None::<usize>;
for (i, &b) in haystack.iter().enumerate().rev() {
if b == needle {
position = Some(i);
}
}
position
}
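Note the semantics: since we scan in reverse and keep overwriting, the last assignment wins, which corresponds to the earliest match – the same result as the forward version (a quick check with made-up input):
fn main() {
    // b'a' occurs at indices 1, 3 and 5; the reverse scan overwrites
    // Some(5) and Some(3) with Some(1), the first occurrence.
    assert_eq!(find_branchless(b"banana", b'a'), Some(1));
}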
Unfortunately, this doesn’t help – there are still no SIMD instructions in the output assembly. But wait, we can notice drastic changes in the output binary – it now seems like the compiler unrolled our main loop and compares bytes in chunks of size 8:
# there is just a part of the assembler, you can find full output by the godbolt link
.LBB0_11:
cmp byte ptr [r8 + r11 - 1], dl
lea r14, [rsi + r11 - 1]
lea r15, [rsi + r11 - 7]
cmovne r14, rcx
cmove rax, rbx
cmp byte ptr [r8 + r11 - 2], dl
lea rcx, [rsi + r11 - 2]
cmovne rcx, r14
cmove rax, rbx
...
cmp byte ptr [r8 + r11 - 8], dl
lea rcx, [rsi + r11 - 8]
cmovne rcx, r15
cmove rax, rbx
That looks promising! Unrolling helps performance by itself, but we may also be on the right path to successfully guiding the compiler toward vectorization!
At this point, I had no clue how else to simplify the compiler’s life, except for one last thing – we can make the slice length constant and hope that this finally activates LLVM’s vectorization engine. It turns out that this was enough! If we use [u8; 16] or [u8; 32] types for the input arguments, then LLVM will use 128-bit or 256-bit SSE / AVX registers and the corresponding instructions!
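For reference, here is a sketch of the fixed-size variant that produces the assembly below (my own adaptation of the function above):
pub fn find_branchless(haystack: &[u8; 16], needle: u8) -> Option<usize> {
    // Identical body; only the argument type pins the length at compile time.
    let mut position = None::<usize>;
    for (i, &b) in haystack.iter().enumerate().rev() {
        if b == needle {
            position = Some(i);
        }
    }
    position
}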
example::find_branchless:
vmovd xmm0, esi
vpbroadcastb xmm0, xmm0
vpcmpeqb xmm0, xmm0, xmmword ptr [rdi]
vpextrb eax, xmm0, 14
vpextrb ecx, xmm0, 13
vpextrb edx, xmm0, 10
and eax, 1
xor rax, 15
test cl, 1
mov ecx, 13
cmove rcx, rax
vpextrb eax, xmm0, 12
# ... there are a lot of instructions for determining actual position of matched byte ...
vpmovmskb esi, xmm0
cmove rdx, rcx
xor eax, eax
test esi, esi
setne al
ret
You can notice that the compiler generates truly branchless code (literally, no jump instructions!). This can be surprising at first sight, but the compiler actually makes use of the cmove (“conditional move”) instruction, which moves a value between operands only when the flags register is in a specific state. This instruction avoids the branch misprediction penalties of an ordinary CMP / JE pair and makes it possible to implement simple conditional logic like the one in our branchless implementation of the find function.
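As an aside, a pattern like the following typically lowers to a cmove rather than a branch (my own illustration; the exact codegen is not guaranteed):
pub fn max_u64(a: u64, b: u64) -> u64 {
    // A branchless select: usually compiles to cmp + cmov on x86-64.
    if a > b { a } else { b }
}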
find
Ok, it’s great that we finally forced rustc to use vectorization. But the current find implementation is barely usable, because it only works for byte arrays of fixed size! What can we do about that?
Here the plan is straightforward – we can split the input slice into chunks of bounded size and apply our branchless find implementation to each of them. Rust has a nice chunks method which does exactly what we want; let’s try to use it:
fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> { ... }
pub fn find(haystack: &[u8], needle: u8) -> Option<usize> {
haystack
.chunks(32)
.enumerate()
.find_map(|(i, chunk)| find_branchless(chunk, needle).map(|x| 32 * i + x) )
}
Unfortunately, this doesn’t work – the compiler again produces boring assembly with only the unrolling optimization applied. But if we stop and think about it, this is actually expected! The chunking logic makes every chunk unpredictable in size, because there are no guarantees about the exact size of the last chunk (and every chunk can be the last one!).
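To make that concrete, here is what chunks(32) yields for a slice whose length is not a multiple of 32 (the length is my own example):
fn main() {
    let data = [0u8; 70];
    let sizes: Vec<usize> = data.chunks(32).map(|c| c.len()).collect();
    // The trailing chunk is shorter, so the chunk size is not a constant.
    assert_eq!(sizes, vec![32, 32, 6]);
}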
Luckily, the Rust developer team thought about this and added the chunks_exact method specifically for such cases! It splits the slice into equally sized chunks and provides access to the potentially smaller tail through an additional method: remainder.
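A quick illustration of the difference (again with a made-up length):
fn main() {
    let data = [0u8; 70];
    let chunks = data.chunks_exact(32);
    assert_eq!(chunks.remainder().len(), 6); // the tail is exposed separately
    assert_eq!(chunks.count(), 2);           // only full 32-byte chunks remain
}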
This final step allows us to make our dream come true: a vectorized find function using only safe Rust!
// bonus: a refactored find_branchless that is a bit more elegant!
fn find_branchless(haystack: &[u8], needle: u8) -> Option<usize> {
    // rfold consumes the iterator from the back, so the last accumulator
    // value it produces corresponds to the first matching index.
    haystack.iter().enumerate()
        .filter(|(_, &b)| b == needle)
        .rfold(None, |_, (i, _)| Some(i))
}
fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    let chunks = haystack.chunks_exact(32);
    let remainder = chunks.remainder();
    chunks.enumerate()
        .find_map(|(i, chunk)| find_branchless(chunk, needle).map(|x| 32 * i + x))
        // haystack.len() & !0x1f is the offset at which the remainder starts;
        // or_else keeps the remainder scan lazy, so it only runs on a miss.
        .or_else(|| find_branchless(remainder, needle).map(|x| (haystack.len() & !0x1f) + x))
}
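A quick end-to-end check covering both the full-chunk path and the remainder path (the test values are mine):
fn main() {
    let mut data = vec![0u8; 100]; // 3 full 32-byte chunks + a 4-byte remainder
    data[97] = 7;                  // match inside the remainder
    assert_eq!(find(&data, 7), Some(97));
    data[10] = 7;                  // an earlier match inside the first chunk wins
    assert_eq!(find(&data, 7), Some(10));
}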
The full benchmark source code is available here: ./rust-find-bench
method | time | speedup
---|---|---
find_naive | 366.06 ns | x1.0
find_chunks | 414.06 ns | x0.9
find_chunks_exact | 133.53 ns | x2.7
find_chunks_exact_branchless | 40.48 ns | x9.0
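For reproducing numbers like these, a minimal Criterion-style harness could look as follows (a sketch assuming the criterion crate and that find is exported from the benched crate; this is not the actual benchmark code from the repository):
use criterion::{black_box, criterion_group, criterion_main, Criterion};

fn bench_find(c: &mut Criterion) {
    let haystack = vec![0u8; 4096]; // the needle is absent: worst-case full scan
    c.bench_function("find_chunks_exact_branchless", |b| {
        b.iter(|| find(black_box(&haystack), black_box(1)))
    });
}

criterion_group!(benches, bench_find);
criterion_main!(benches);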