When I started learning Rust and exploring its slice methods, I was surprised that Rust doesn’t provide a built-in method for finding the position of an element in a slice.
Of course, this is not a big problem, as anyone can implement a simple and concise one-liner for that purpose, like the following:
// for simplicity, I will use u8 type instead of generics throughout the post
fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack.iter().position(|&x| x == needle)
}
This is really nice, but what about the performance of this implementation? It turns out that while the performance of this method is good, you can make the implementation 9 times faster by restructuring the code to help the LLVM compiler backend vectorize it.
Let’s dig a bit deeper into how we can achieve this performance boost for such a simple task.
First of all, let’s look at the compiler output that rustc produces (godbolt output with the -C opt-level=3 -C target-cpu=native flags on top of the release profile):
# input : rdi=haystack.ptr, rsi=haystack.size, rdx=needle
# output: rax=None/Some, rdx=Some(v)
example::find:
test rsi, rsi # if haystack.len() == 0
je .LBB0_1 # return None
mov ecx, edx # b = needle
xor eax, eax # result = None
xor edx, edx # i = 0
.LBB0_3: # loop {
cmp byte ptr [rdi + rdx], cl # if haystack[i] == b
je .LBB0_4 # return Option::Some(i)
inc rdx # i++
cmp rsi, rdx # if haystack.len() != i
jne .LBB0_3 # continue
mov rdx, rsi # i = haystack.len()
ret # return None
.LBB0_1: # }
xor edx, edx
xor eax, eax
ret
.LBB0_4:
mov eax, 1
ret
While the assembly code is very clean and does the job, it lacks one important component – vectorization! For such a simple operation over the primitive u8 type I would expect LLVM to come up with some clever vectorized version for a performance boost. For example, consider the compilation result for the haystack.iter().sum() one-liner:
# core vectorized loop; whole listing omitted for brevity
.LBB0_7:
vpaddb ymm0, ymm0, ymmword ptr [rdi + rax]
vpaddb ymm1, ymm1, ymmword ptr [rdi + rax + 32]
vpaddb ymm2, ymm2, ymmword ptr [rdi + rax + 64]
vpaddb ymm3, ymm3, ymmword ptr [rdi + rax + 96]
sub rax, -128
cmp rcx, rax
jne .LBB0_7
You can notice that here the compiler decided to use the YMM registers, which are 256-bit wide AVX registers. So, in the case of u8 elements, a single vpaddb instruction simultaneously processes 32 elements of our slice and can potentially speed up the sum() method by a factor of 32!
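For reference, the Rust source behind that listing can be as short as the find one-liner. A minimal sketch (the wrapper function and its u8 return type are my assumption; the important part is the iter().sum() call):
// a possible shape of the sum one-liner (assumed), analogous to the find example above
fn sum(haystack: &[u8]) -> u8 {
    haystack.iter().sum()
}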
So, what is preventing LLVM from using vectorization in our implementation of the find method? We can make a guess, or we can use LLVM optimization remarks, which will hint at the problems with our implementation:
$> rustc main.rs -C opt-level=3 -C target-cpu=native -Cdebuginfo=full \
-Cllvm-args=-pass-remarks-analysis='.*' \
-Cllvm-args=-pass-remarks-missed='.*' \
-Cllvm-args=-pass-remarks='.*'
remark: /.../iter/macros.rs:361:24: loop not vectorized: Cannot vectorize potentially faulting early exit loop
remark: /.../iter/macros.rs:361:24: loop not vectorized
remark: /.../iter/macros.rs:0:0: Cannot SLP vectorize list: vectorization was impossible with available vectorization factors
remark: /.../iter/macros.rs:370:14: Cannot SLP vectorize list: only 2 elements of buildvalue, trying reduction first.
remark: /.../iter/macros.rs:0:0: Cannot SLP vectorize list: vectorization was impossible with available vectorization factors
The remarks listed above highlight a few important things:
- LLVM has two types of vectorization optimizations: LoopVectorizer and SLPVectorizer;
- LoopVectorizer was unable to optimize the loop because it has an early return (“Cannot vectorize potentially faulting early exit loop”);
- SLPVectorizer was unable to optimize the loop either, for some other reason.
Let’s try to fix LoopVectorizer’s complaint and implement the find method without an early return!
To make things clear, the early return appears in our code because, under the hood, the iterator chain is transformed into the following form:
fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    for (i, &b) in haystack.iter().enumerate() {
        if b == needle {
            return Some(i) // <- here, the early return!
        }
    }
    None
}
find without early returns
As our implementation of the find method needs to find the first occurrence of the element, it has to return as soon as it encounters the element during the natural left-to-right slice traversal. So, in the case of natural-order enumeration, the return is inevitable. But what if we traverse the slice elements in reverse order, from right to left? In that case we need to find the “last” occurrence of the element (according to the enumeration order) and we no longer need an early return!
Note that right-to-left enumeration is pretty inefficient, as it needs to traverse the whole slice in order to find the first occurrence, while the original find implementation finishes much faster if the first needle occurrence appears close to the start of the slice. We will fix that later, but for now let’s focus on fixing the early-return problem to please the vectorization optimizers.
So, the implementation of the find
method without early returns looks like this:
pub fn find_no_early_returns(haystack: &[u8], needle: u8) -> Option<usize> {
    let mut position = None;
    for (i, &b) in haystack.iter().enumerate().rev() {
        if b == needle {
            position = Some(i);
        }
    }
    position
}
Unfortunately, this doesn’t help – there are still no SIMD instructions in the output assembly. But we can notice drastic changes in the output binary and also in the remarks produced by LLVM – now the compiler unrolls our main loop and compares bytes in chunks of 8:
# this is just a part of the assembly; you can find the full output via the godbolt link
.LBB0_11:
cmp byte ptr [r8 + r11 - 1], dl
lea r14, [rsi + r11 - 1]
cmovne r14, rcx
lea rcx, [rsi + r11 - 2]
cmove rax, rbx
cmp byte ptr [r8 + r11 - 2], dl
cmovne rcx, r14
lea r14, [rsi + r11 - 3]
cmove rax, rbx
...
cmove rax, rbx
cmp byte ptr [r8 + r11 - 8], dl
cmovne rcx, r14
cmove rax, rbx
You can notice that the compiler generates truly branchless code for the unrolled section (literally, no jump instructions!). This can be surprising at first sight, but the compiler makes use of the cmove (“conditional move”) instruction, which moves a value between operands only if the flags register is in a specific state. By avoiding conditional jumps it sidesteps the branch-misprediction penalties that an ordinary cmp/je pair can incur, and it works well in simple conditional scenarios like the one in the no-early-return implementation of the find function.
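To illustrate the kind of code that lowers to a conditional move, here is a tiny, hypothetical example (not from this article’s code); with optimizations enabled the compiler typically emits a cmov-style instruction for it instead of a branch, although codegen is never guaranteed:
// simple branch-free selection: usually compiled to cmp + cmov rather than a jump
fn pick(cond: bool, a: u64, b: u64) -> u64 {
    if cond { a } else { b }
}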
And the LLVM remarks are now different:
$> rustc main.rs -C opt-level=3 -C target-cpu=native -Cdebuginfo=full \
-Cllvm-args=-pass-remarks-analysis='.*' \
-Cllvm-args=-pass-remarks-missed='.*' \
-Cllvm-args=-pass-remarks='.*'
remark: /.../slice/iter/macros.rs:25:86: loop not vectorized: value that could not be identified as reduction is used outside the loop
remark: /.../slice/iter/macros.rs:25:86: loop not vectorized
remark: main.rs:0:0: Cannot SLP vectorize list: vectorization was impossible with available vectorization factors
remark: main.rs:0:0: List vectorization was possible but not beneficial with cost 0 >= 0
remark: main.rs:0:0: List vectorization was possible but not beneficial with cost 0 >= 0
remark: main.rs:17:2: Cannot SLP vectorize list: only 2 elements of buildvalue, trying reduction first.
remark: main.rs:0:0: List vectorization was possible but not beneficial with cost 0 >= 0
remark: :0:0: Cannot SLP vectorize list: vectorization was impossible with available vectorization factors
remark: :0:0: List vectorization was possible but not beneficial with cost 0 >= 0
remark: :0:0: List vectorization was possible but not beneficial with cost 0 >= 0
remark: /.../slice/iter/macros.rs:25:86: unrolled loop by a factor of 8 with run-time trip count
Here, we see that LoopVectorizer was still unable to optimize the loop due to the external variable (position) used in it. However, SLPVectorizer kicked in and decided not to optimize the loop, even though it was possible (so it actually found a way to vectorize the loop, hooray!).
It’s a bit strange that LoopVectorizer was unable to detect the reduction variable in our case because, for arithmetic operations, it usually works well. But maybe LoopVectorizer expects nice mathematical properties from the operation (like commutativity), which the assignment operator obviously doesn’t have.
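For comparison, here is a sketch of a reduction that loop vectorization usually copes with (the count function is hypothetical and not part of our find implementation); its combining operation is addition, which is associative and commutative, so per-lane partial results can be merged at the end:
// counting matches: per-element results are combined with +, so the compiler
// can keep partial counts in SIMD lanes and sum them after the loop
fn count(haystack: &[u8], needle: u8) -> usize {
    haystack.iter().filter(|&&b| b == needle).count()
}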
SLPVectorizer!
The SLPVectorizer under the hood “combines similar independent instructions into vector instructions”, and it’s not obvious what it can do in our case, because all of our operations already sit inside the loop body and there are no repetitions left to combine. An example of SLPVectorizer from the LLVM documentation looks like this:
void foo(int a1, int a2, int b1, int b2, int *A) {
    A[0] = a1*(a1 + b1);
    A[1] = a2*(a2 + b2);
    A[2] = a1*(a1 + b1);
    A[3] = a2*(a2 + b2);
}
But given that the loop-unroll optimization kicked in for our code, it seems natural that after unrolling SLPVectorizer could combine the bunch of repeated actions into vectorized code. What we can do here is make it more explicit that unrolling is beneficial in our case by using a fixed-size slice &[u8; 32] instead of an arbitrary one. And this works!
// only signature changed - rest of the code is the same
pub fn find_no_early_returns(haystack: &[u8; 32], needle: u8) -> Option<usize> { ... }
For the signature above we will finally see vectorized code in the compiler output:
example::find_no_early_returns:
vmovd xmm0, esi
vpbroadcastb ymm0, xmm0
vpcmpeqb ymm0, ymm0, ymmword ptr [rdi]
vpmovmskb ecx, ymm0
mov eax, ecx
shr eax, 30
and eax, 1
xor rax, 31
test ecx, 536870912
mov edx, 29
cmove rdx, rax
... # there are a lot of instructions for determining actual position of matched byte
test cl, 2
mov esi, 1
cmove rsi, rax
xor edx, edx
test cl, 1
cmove rdx, rsi
xor eax, eax
test ecx, ecx
setne al
vzeroupper
ret
Also, SLPVectorizer
reported successful optimization of the loop in the compiler remarks:
$> rustc main.rs -C opt-level=3 -C target-cpu=native -Cdebuginfo=full \
-Cllvm-args=-pass-remarks-analysis='.*' \
-Cllvm-args=-pass-remarks-missed='.*' \
-Cllvm-args=-pass-remarks='.*'
remark: main.rs:17:2: Cannot SLP vectorize list: only 2 elements of buildvalue, trying reduction first.
remark: main.rs:12:12: Vectorized horizontal reduction with cost -21 and with tree size 3
OK, the victory seems very close! Now we just need to overcome the issue that, for SLPVectorizer to work, we have to use fixed-size slices.
find method
Since SLPVectorizer kicks in when the slice has a fixed size, we should split our original slice into fixed-size chunks and then process them independently with the no-early-return method. Our first attempt can be to use the chunks method, which does exactly what we want, so let’s try it out.
While the chunks method splits the slice into sub-slices of unknown size from the language perspective (they are still &[u8] – not &[u8; N]), we can hope that LLVM’s aggressive inlining will help us here, and that the optimizer will understand the slices are fixed-size after all the inlining happens under the hood.
fn find_no_early_returns(haystack: &[u8], needle: u8) -> Option<usize> { ... }

pub fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack
        .chunks(32)
        .enumerate()
        .find_map(|(i, chunk)| find_no_early_returns(chunk, needle).map(|x| 32 * i + x))
}
Unfortunately, this doesn’t work – the compiler again produces boring assembly with only the unrolling optimization applied. But if we stop and think about it a bit, this is actually expected behaviour! The chunks method treats all chunks uniformly, including the last chunk – which can have fewer than 32 elements!
Luckily, the Rust developer team thought about this and added the chunks_exact method specifically for such cases! This method splits the slice into equally sized chunks and provides access to the potentially smaller tail through an additional method: remainder.
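As a quick, standalone illustration of that API (the example data is made up and unrelated to find):
// every chunk yielded by chunks_exact has exactly the requested length;
// the leftover tail is exposed separately via remainder()
fn main() {
    let data = [1u8, 2, 3, 4, 5, 6, 7];
    let mut chunks = data.chunks_exact(3);
    assert_eq!(chunks.next(), Some(&[1u8, 2, 3][..]));
    assert_eq!(chunks.next(), Some(&[4u8, 5, 6][..]));
    assert_eq!(chunks.next(), None);
    assert_eq!(chunks.remainder(), &[7u8]);
}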
This final step allows us to make our dream come true: a vectorized find function using only safe Rust
!
// bonus: a refactoring of the find_no_early_returns function to make it more elegant!
fn find_no_early_returns(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack.iter().enumerate()
        .filter(|(_, &b)| b == needle)
        .rfold(None, |_, (i, _)| Some(i))
}

fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    let chunks = haystack.chunks_exact(32);
    let remainder = chunks.remainder();
    chunks.enumerate()
        .find_map(|(i, chunk)| find_no_early_returns(chunk, needle).map(|x| 32 * i + x))
        .or(find_no_early_returns(remainder, needle).map(|x| (haystack.len() & !0x1f) + x))
}
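A quick sanity check of the final implementation against the naive one-liner (a hypothetical snippet, not part of the benchmark):
fn main() {
    // the needle sits inside a full 32-byte chunk; searching for a byte that is
    // absent also exercises the remainder path
    let mut haystack = vec![0u8; 1000];
    haystack[777] = 42;
    assert_eq!(find(&haystack, 42), Some(777));
    assert_eq!(find(&haystack, 42), haystack.iter().position(|&x| x == 42));
    assert_eq!(find(&haystack, 43), None);
}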
While LLVM is good at auto-vectorization, sometimes you need to push it a bit to make your case more obvious to the optimization passes. And here the Rust standard library, combined with LLVM’s aggressive inlining, allows us to make non-trivial structural changes in the code while keeping it completely safe (I imagine the chunks_exact method can be very helpful in various other cases).
Also, it was interesting to see how LLVM compiler hints can be exposed to the developer for a deeper understanding of the compiler’s behavior (it turned out this is very easy to do, as rustc allows you to pass additional arguments to the LLVM backend).
And as a final result, the last implementation of the find method is 9 times faster than the naive implementation presented at the beginning of the article. You can find the full benchmark source code here: rust-find-bench. The benchmark compares the following variants:
- find_chunks_exact_no_early_return – the no-early-return version with the chunks_exact method
- find_chunks_exact – the naive version with the chunks_exact method
- find_naive – the naive version from the beginning of the article
- find_chunks – the naive version with the chunks method
| method | time | speedup |
|---|---|---|
| find_chunks_exact_no_early_return | 40.18ns | x9.0 |
| find_chunks_exact | 126.77ns | x2.7 |
| find_naive | 356.07ns | x1.0 |
| find_chunks | 510.16ns | x0.7 |