Now that the ICFP deadline is past, I've returned to working on adding SIMD support to GHC and associated libraries. The short-term goal is to be able to leverage SIMD instructions from the vector library and, hopefully transparently, from Data Parallel Haskell. Back in November I added a fairly complete set of primops to GHC that generate SSE instructions when using the LLVM back-end. If you're interested in the original design for SIMD support in GHC and the state of its implementation, you can read about it here. It's a bit of a hack and requires the LLVM back-end, but it does work.
Primops are necessary, but what we'd really like is a higher-level interface to these low-level instructions. This post is a very short introduction to my recent efforts in that direction. All the code I describe is public; you can find directions for getting up and running with the GHC simd branch on GitHub.
Because this is Haskell, we'll start by introducing a data type for a SIMD vector that is indexed by the type of the scalar values it contains. The term "vector" is already overloaded, so, at Simon Peyton Jones' suggestion, we call a SIMD vector containing scalars of type `a` a `Multi a`. Because we want to choose a different primitive representation for each different `a`, `Multi` is a type family (actually an associated type family). Along with `Multi`, we define a type class `MultiPrim` that allows us to treat primitive operations on `Multi`s in a uniform way, just as the `Prim` type class defined by the primitive library does for scalars. Here's the first part of the definition of the `MultiPrim` type class and the associated `Multi` family. You can see that it defines functions for replicating scalars across `Multi`s, folding a function over the scalar elements of a `Multi`, reading a `Multi` out of a `ByteArray#`, etc. Right now there are instance definitions for `Multi Float`, `Multi Double`, `Multi Int32`, `Multi Int64`, and `Multi Int`. This type class and the rest of the code I'll be showing are actually part of the simd branch of the vector library that I've put on GitHub. You can look there for further details, like the `Num` instances defined for the `Multi`s.
```haskell
class (Prim a, Prim (Multi a)) => MultiPrim a where
    data Multi a

    -- | The number of elements of type @a@ in a @Multi a@.
    multiplicity :: Multi a -> Int

    -- | Replicate a scalar across a @Multi a@.
    multireplicate :: a -> Multi a

    -- | Map a function over the elements of a @Multi a@.
    multimap :: (a -> a) -> Multi a -> Multi a

    -- | Fold a function over the elements of a @Multi a@.
    multifold :: (b -> a -> b) -> b -> Multi a -> b

    -- | Read a multi-value from the array. The offset is in elements of type
    -- @a@ rather than in elements of type @Multi a@.
    indexByteArrayAsMulti# :: ByteArray# -> Int# -> Multi a

    -- (remaining methods omitted)
```
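To make the associated-data-family idea concrete without the simd branch, here is a toy model that should compile with a stock GHC. It is only a sketch under stated assumptions: the `Toy`-prefixed names are hypothetical, each `ToyMulti` is an ordinary data type with strict fields rather than an SSE register, and there is no `ByteArray#` reading. What it does show is the shape of the design: each instance picks its own representation, four lanes for `Float` and two for `Double`, mirroring how four `Float`s or two `Double`s fit in a 128-bit SSE register.

```haskell
{-# LANGUAGE TypeFamilies #-}

-- Toy model (hypothetical names, not the simd branch's code): a class with an
-- associated data family, so each scalar type chooses its own representation.
-- ToyMulti values are ordinary strict-field data, not SSE registers.
class ToyMultiPrim a where
  data ToyMulti a
  toyMultiplicity   :: ToyMulti a -> Int
  toyMultireplicate :: a -> ToyMulti a
  toyMultimap       :: (a -> a) -> ToyMulti a -> ToyMulti a
  toyMultifold      :: (b -> a -> b) -> b -> ToyMulti a -> b

-- Float: four lanes, like the 128-bit FloatX4#.
instance ToyMultiPrim Float where
  data ToyMulti Float = ToyFloatX4 !Float !Float !Float !Float
    deriving Show
  toyMultiplicity _ = 4
  toyMultireplicate x = ToyFloatX4 x x x x
  toyMultimap f (ToyFloatX4 a b c d) = ToyFloatX4 (f a) (f b) (f c) (f d)
  toyMultifold f z (ToyFloatX4 a b c d) = f (f (f (f z a) b) c) d

-- Double: two lanes, like the 128-bit DoubleX2#.
instance ToyMultiPrim Double where
  data ToyMulti Double = ToyDoubleX2 !Double !Double
    deriving Show
  toyMultiplicity _ = 2
  toyMultireplicate x = ToyDoubleX2 x x
  toyMultimap f (ToyDoubleX2 a b) = ToyDoubleX2 (f a) (f b)
  toyMultifold f z (ToyDoubleX2 a b) = f (f z a) b
```

For example, `toyMultifold (+) 0 (toyMultireplicate (1.5 :: Float))` sums four copies of `1.5`, while the same call at type `Double` sums only two.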
Now that we have the `Multi` type, we would like to use it to operate over `Vector`s, that is, vector types from the vector library. A `Vector` has scalar elements, so for us to be able to use SIMD operations on these scalars we need to know something about the representation the `Vector` uses, namely that it lays out scalars contiguously in memory. The `PackedVector` type class lets us express this constraint in Haskell's type system. I won't say anything more about it here, but instances are defined for the appropriate vector types in the `Data.Vector.Unboxed` and `Data.Vector.Storable` modules.
Of course the next step is to define appropriate versions of our old friends,
map, zip, and fold, that will let us exploit SIMD operations. Here they are.
```haskell
mmap :: (PackedVector v a, PackedVector v b)
     => (a -> b)
     -> (Multi a -> Multi b)
     -> v a
     -> v b

mzipWith :: (PackedVector v a, PackedVector v b, PackedVector v c)
         => (a -> b -> c)
         -> (Multi a -> Multi b -> Multi c)
         -> v a
         -> v b
         -> v c

mfoldl' :: PackedVector v b
        => (a -> b -> a)
        -> (a -> Multi b -> a)
        -> a
        -> v b
        -> a
```
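One way to read these signatures is through a list-based reference model. The sketch below is hypothetical: the `model`-prefixed names, `chunks`, `chunkWidth`, and the use of 4-element lists as stand-ins for `Multi` values are all assumptions for illustration, not code from the simd branch. Whole 4-element chunks go through the `Multi` function, and any leftover elements go through the scalar function.

```haskell
-- Hypothetical reference model: 4-element lists stand in for Multi values.
-- The chunk width of 4 is an assumption matching FloatX4.
chunkWidth :: Int
chunkWidth = 4

-- Split a list into full chunks of chunkWidth and a scalar leftover tail.
chunks :: [a] -> ([[a]], [a])
chunks xs
  | length front == chunkWidth = let (cs, rest) = chunks back in (front : cs, rest)
  | otherwise                  = ([], xs)
  where
    (front, back) = splitAt chunkWidth xs

-- Model of mmap: the "multi" function handles whole chunks,
-- the scalar function handles the leftovers.
modelMmap :: (a -> b) -> ([a] -> [b]) -> [a] -> [b]
modelMmap f fm xs = concatMap fm cs ++ map f rest
  where
    (cs, rest) = chunks xs

-- Model of mzipWith, built by zipping first and reusing modelMmap.
modelMzipWith :: (a -> b -> c) -> ([a] -> [b] -> [c]) -> [a] -> [b] -> [c]
modelMzipWith f fm xs ys = modelMmap (uncurry f) (uncurry fm . unzip) (zip xs ys)
```

The model makes an implicit contract visible: the result only matches `map f` when the `Multi` function is the lane-wise version of the scalar one, because which elements take which path depends on the vector's length.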
If you're familiar with the vector library, you may know it uses stream fusion to generate very efficient code: many operations are typically compiled to tight loops similar to what one would get from a C compiler. Stream fusion works by re-expressing high-level operations like map, zip, and fold in terms of "step" functions. Each step function takes some state and produces either a new element along with new state, just new state, or a value that says it is done processing elements. To support computing over vectors using SIMD operations, I have added a new "stream" variant so that step functions can produce not just scalar elements, but `Multi` elements. That is, at every step, the stream consumer could be handed either a scalar or a `Multi` and must be prepared for either case. `mmap`, `mzipWith`, and `mfoldl'` are almost exactly like their scalar-only counterparts, but they each take an extra function argument for handling `Multi`s.
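The extended stream idea can be sketched with a toy `Step` type that has an extra constructor for multi-element steps. Everything here is hypothetical: `YieldMulti`, `X4`, `fromList`, `pairSum`, and this list-backed `Stream` are illustrative names, not the simd branch's definitions (the vector library's own `Step` has only `Yield`, `Skip`, and `Done`). The consumer `mfoldl'` must handle both `Yield` and `YieldMulti`, which is exactly why it takes two functions.

```haskell
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExistentialQuantification #-}

-- Hypothetical 4-lane stand-in for Multi a.
data X4 a = X4 a a a a

-- Stream-fusion Step, extended with a (hypothetical) multi-element case.
data Step s a
  = Yield a s            -- one scalar element
  | YieldMulti (X4 a) s  -- four elements at once
  | Skip s
  | Done

data Stream a = forall s. Stream (s -> Step s a) s

-- Producer: emit whole chunks of four as YieldMulti, the tail as scalars.
fromList :: [a] -> Stream a
fromList = Stream step
  where
    step (a : b : c : d : rest) = YieldMulti (X4 a b c d) rest
    step (x : rest)             = Yield x rest
    step []                     = Done

-- Consumer: like foldl', but with a second function for multi elements.
mfoldl' :: (acc -> a -> acc) -> (acc -> X4 a -> acc) -> acc -> Stream a -> acc
mfoldl' f fm z0 (Stream step s0) = go z0 s0
  where
    go !z s = case step s of
      Yield x s'      -> go (f z x) s'
      YieldMulti m s' -> go (fm z m) s'
      Skip s'         -> go z s'
      Done            -> z

-- Two accumulators (a scalar and a 4-lane partial sum), combined at the end.
pairSum :: [Float] -> Float
pairSum xs = s + a + b + c + d
  where
    (s, X4 a b c d) = mfoldl' plus1 plusm (0, X4 0 0 0 0) (fromList xs)
    plus1 (x, mx) y = (x + y, mx)
    plusm (x, X4 a0 b0 c0 d0) (X4 a1 b1 c1 d1) =
      (x, X4 (a0 + a1) (b0 + b1) (c0 + c1) (d0 + d1))
```

`pairSum` carries the same kind of scalar-plus-multi accumulator pair as the `multisum` example that follows: for `[1 .. 10]`, the elements 1 through 8 arrive as two `YieldMulti` steps and 9 and 10 arrive as scalars.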
Let's see if this stuff actually works by starting off with something easy: summing up all the elements in a vector. The following code uses the new vector library primitives `multifold` and `U.mfoldl'` to exploit SIMD instructions.
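```haskell
import qualified Data.Vector.Unboxed as U
import Data.Primitive.Multi

multisum :: U.Vector Float -> Float
multisum v =
    -- Fold the lanes of the SIMD partial sum into the scalar partial sum.
    multifold (+) s ms
  where
    s  :: Float
    ms :: Multi Float
    -- Two accumulators: s for elements handed over one at a time,
    -- ms for elements handed over as whole Multis.
    (s, ms) = U.mfoldl' plus1 plusm (0, 0) v

    plusm (x, mx) my = (x, mx + my)
    plus1 (x, mx) y  = (x + y, mx)
```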
We'll compare it with five other versions. "Scalar" and "Scalar (C)" are plain old scalar versions written in Haskell and C, respectively. "Manual" and "Manual (C)" are hand-written SIMD versions in Haskell and C, respectively. The Haskell version explicitly iterates over the vector instead of using a fold. The "vector" version is the code we just saw, and the "multivector" version is based on a library I wrote to test out fusion when I first added SSE support to GHC. It implements a small subset of the vector library API. Here we go.
Not bad. The following table gives the timings for vectors with 2^24 elements. In this case, Haskell is as fast as C. This isn't too surprising, as we've seen before that Haskell can be as fast as C.
Timings for the `sum` function, summing vectors of size 2^24.

| Variant | Time (ms) |
|---|---|
| Scalar | 19.7 ± 0.2 |
| Scalar (C) | 19.7 ± 0.4 |
| Manual | 4.62 ± 0.03 |
| Manual (C) | 4.58 ± 0.02 |
| vector | 4.62 ± 0.02 |
| multivector | 4.62 ± 0.02 |
Of course, summing up the elements in a vector isn't so hard. The great thing about the vector library is that you can write high-level Haskell code and, through the magic of fusion, end up with a tight inner loop that looks like what you might have gotten out of a C compiler had you chosen to write in C. Let's try a slightly more difficult computation that will require fusion: dot product.
Computing the dot product efficiently requires fusing two loops to perform a combined addition and multiplication. Here is the scalar version in Haskell:
```haskell
import qualified Data.Vector.Unboxed as U

dotp :: U.Vector Float -> U.Vector Float -> Float
dotp v w =
    U.sum $ U.zipWith (*) v w
```
And here is our first cut at a SIMD version.
```haskell
import qualified Data.Vector.Unboxed as U
import Data.Primitive.Multi

multidotp :: U.Vector Float -> U.Vector Float -> Float
multidotp v w =
    multifold (+) s ms
  where
    s  :: Float
    ms :: Multi Float
    -- Multiply pointwise (scalar and SIMD versions of (*)), then sum with
    -- the same two-accumulator fold as multisum.
    (s, ms) = U.mfoldl' plus1 plusm (0, 0) $ U.mzipWith (*) (*) v w

    plusm (x, mx) my = (x, mx + my)
    plus1 (x, mx) y  = (x + y, mx)
```
Let's look at performance once more. Again, "Manual" is a Haskell version that
manually iterates over the vector once and fuses the addition and
multiplication, the idea being that this is what we would hope to get out of GHC
after fusion, inlining, constructor specialization, etc.
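For intuition about the shape of loop we are hoping for, here is a scalar stand-in written against the `Data.Vector.Unboxed` API (it needs the vector package). It is a sketch under assumptions, not the author's "Manual" code, which uses SIMD primops: the hypothetical `manualDotp` makes a single pass, fuses the multiply and the add, keeps four independent accumulators that play the role of SSE lanes, and finishes with a scalar clean-up loop for leftover elements.

```haskell
{-# LANGUAGE BangPatterns #-}

import qualified Data.Vector.Unboxed as U

-- Hypothetical scalar stand-in for a fused, single-pass dot product.
-- The accumulators a0..a3 mimic the four lanes of a FloatX4 register;
-- no SIMD instructions are actually involved.
manualDotp :: U.Vector Float -> U.Vector Float -> Float
manualDotp v w = goMulti 0 0 0 0 0
  where
    n  = min (U.length v) (U.length w)
    -- Length of the prefix that splits evenly into 4-element chunks.
    n4 = n - n `rem` 4
    -- Product of the i-th elements; unsafeIndex skips bounds checks,
    -- which is safe because every caller ensures i < n.
    prod i = U.unsafeIndex v i * U.unsafeIndex w i

    -- Main loop: four multiply-adds per iteration, one per "lane".
    goMulti :: Int -> Float -> Float -> Float -> Float -> Float
    goMulti !i !a0 !a1 !a2 !a3
      | i < n4    = goMulti (i + 4) (a0 + prod i) (a1 + prod (i + 1))
                                    (a2 + prod (i + 2)) (a3 + prod (i + 3))
      | otherwise = goScalar i (a0 + a1 + a2 + a3)

    -- Clean-up loop: at most three leftover elements, one at a time.
    goScalar :: Int -> Float -> Float
    goScalar !i !acc
      | i < n     = goScalar (i + 1) (acc + prod i)
      | otherwise = acc
```

Because it reassociates the additions, this version can differ from `dotp` in the last few bits for general inputs, just as a SIMD sum can.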
For reference, here are the timings for the case with n = 2^24 again.

Timings for the `dotp` function, calculating the dot product of vectors of size 2^24.

| Variant | Time (ms) |
|---|---|
| Scalar | 16.98 ± 0.08 |
| Scalar (C) | 16.63 ± 0.09 |
| Manual | 8.87 ± 0.03 |
| Manual (C) | 8.64 ± 0.02 |
| vector | 13.03 ± 0.07 |
| multivector | 9.5 ± 0.1 |
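To put these dot-product numbers in perspective, here are the ratios of the mean times in the table, each relative to the hand-written Haskell ("Manual") version:

$$
\begin{aligned}
\frac{T_{\text{vector}}}{T_{\text{Manual}}} &= \frac{13.03}{8.87} \approx 1.47 \\
\frac{T_{\text{multivector}}}{T_{\text{Manual}}} &= \frac{9.5}{8.87} \approx 1.07 \\
\frac{T_{\text{Manual (C)}}}{T_{\text{Manual}}} &= \frac{8.64}{8.87} \approx 0.97
\end{aligned}
$$

So the vector version takes roughly 47% longer than the hand-written loop, while multivector is within about 7% of it and the C version is about 3% faster.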
Not so hot. Although our hand-written Haskell implementation ("Manual" in the plot and table) is competitive with C, the "vector" version is not. Interestingly, the "multivector" version is competitive. What could be going on?
The first thing that jumps to mind is that fusion might not be kicking in: I could've screwed up the implementation of the SIMD-enabled combinators! To check this hypothesis, let's look at the GHC Core generated for the main loop in `multidotp` (this is the loop that iterates over elements SIMD-vector-wise):
```
letrec {
  $s$wmfoldlM_loopm_s4ri [Occ=LoopBreaker]
    :: GHC.Prim.Int#
       -> GHC.Prim.Int#
       -> GHC.Prim.~#
            *
            Data.Primitive.Multi.FloatX4.FloatX4
            (Data.Primitive.Multi.Multi GHC.Types.Float)
       -> GHC.Prim.FloatX4#
       -> GHC.Prim.Float#
       -> (# GHC.Types.Float,
             Data.Primitive.Multi.Multi GHC.Types.Float #)
  [LclId, Arity=5, Str=DmdType LLLLL]
  $s$wmfoldlM_loopm_s4ri =
    \ (sc_s4nR :: GHC.Prim.Int#)
      (sc1_s4nS :: GHC.Prim.Int#)
      (sg_s4nT
         :: GHC.Prim.~#
              *
              Data.Primitive.Multi.FloatX4.FloatX4
              (Data.Primitive.Multi.Multi GHC.Types.Float))
      (sc2_s4nU :: GHC.Prim.FloatX4#)
      (sc3_s4nV :: GHC.Prim.Float#) ->
      case GHC.Prim.>=# sc1_s4nS ipv7_aHm of _ {
        GHC.Types.False ->
          case GHC.Prim.indexFloatArrayAsFloatX4#
                 ipv2_s4kn (GHC.Prim.+# ipv_s4kl sc1_s4nS)
          of wild_a4j3 { __DEFAULT ->
          case GHC.Prim.>=# sc_s4nR ipv6_XI3 of _ {
            GHC.Types.False ->
              case GHC.Prim.indexFloatArrayAsFloatX4#
                     ipv5_s4l7 (GHC.Prim.+# ipv3_s4l5 sc_s4nR)
              of wild3_X4jF { __DEFAULT ->
              $s$wmfoldlM_loopm_s4ri
                (GHC.Prim.+# sc_s4nR 4)
                (GHC.Prim.+# sc1_s4nS 4)
                @~ (Sym (Data.Primitive.Multi.NTCo:R:MultiFloat) ; Sym
                      (Data.Primitive.Multi.TFCo:R:MultiFloat))
                (GHC.Prim.plusFloatX4#
                   sc2_s4nU (GHC.Prim.timesFloatX4# wild_a4j3 wild3_X4jF))
                sc3_s4nV
              };
            GHC.Types.True -> ...
          }
          };
        GHC.Types.True -> ...
      };
```
We can see that the two loops have been fused. I won't show the Core for the other Haskell implementations, but I'll note that it looks pretty much the same except for one thing: `multidotp` is carrying around two pieces of state during the fold it performs, a scalar `Float` and a `Multi Float`. That shouldn't make a difference, though; these guys should just live in two separate registers. There's only one reasonable thing left to do: look at some assembly.
Just so we have an idea of what we want to see, let's examine the inner loop
of the C version first:
```
.L3:
    movaps (%rdi,%rax), %xmm0
    mulps (%rdx,%rax), %xmm0
    addq $16, %rax
    cmpq %r8, %rax
    addps %xmm0, %xmm1
    jne .L3
```
Cool. Our array pointers live in `rdi` and `rdx`, our index in `rax`, and the array bounds in `r8`. Now on to the "manual" Haskell version.
```
.LBB5_3:                                # %n5oi
                                        # =>This Inner Loop Header: Depth=1
    movups (%rcx), %xmm2
    movups (%rdx), %xmm1
    mulps %xmm2, %xmm1
    addps %xmm1, %xmm0
    addq $16, %rcx
    addq $16, %rdx
    addq $4, %r14
    cmpq %r14, %rax
    jg .LBB5_3
```
Still pretty good. This time our array pointers live in `rcx` and `rdx`, our index in `r14`, and our bounds in `rax`. Note that the index is now measured in `float`s instead of bytes. How about the "multivector" version?
```
 1: .LBB1_2:                # %n3JW.i
 2:                         # in Loop: Header=BB1_1 Depth=1
 3:     cmpq %rax, %r8
 4:     jle .LBB1_5
 5: # BB#3:                 # %n3K9.i
 6:                         # in Loop: Header=BB1_1 Depth=1
 7:     movq 8(%rcx), %rdx
 8:     addq %rax, %rdx
 9:     movq 16(%rcx), %rdi
10:     movups 16(%rdi,%rdx,4), %xmm2
11:     movups (%rbx), %xmm1
12:     mulps %xmm2, %xmm1
13:     addps %xmm1, %xmm0
14:     movups %xmm0, -56(%rcx) #
15:     addq $16, %rbx
16:     addq $4, %rax
17: .LBB1_1:                # %tailrecurse.i
18:                         # =>This Inner Loop Header: Depth=1
19:     cmpq %rax, %r9
20:     jg .LBB1_2
```
There is definitely more junk here. Still, it's not horrible, except for line 14, where we spill the result to the stack. Apparently the spill doesn't cost us much. Now the "vector" version that had performance issues.
```
 1: .LBB4_2:                # %n4H3
 2:                         # in Loop: Header=BB4_1 Depth=1
 3:     cmpq %r14, 43(%rbx)
 4:     jle .LBB4_5
 5: # BB#3:                 # %n4Hw
 6:                         # in Loop: Header=BB4_1 Depth=1
 7:     movq 35(%rbx), %rdx
 8:     addq %r14, %rdx
 9:     movq 3(%rbx), %rcx
10:     movq 11(%rbx), %rdi
11:     movups 16(%rdi,%rdx,4), %xmm0
12:     movq 27(%rbx), %rdx
13:     addq %rsi, %rdx
14:     movups 16(%rcx,%rdx,4), %xmm1
15:     mulps %xmm0, %xmm1
16:     movups (%rbp), %xmm0 #
17:     addps %xmm1, %xmm0
18:     movups %xmm0, (%rbp) #
19:     addq $4, %r14
20:     addq $4, %rsi
21: .LBB4_1:                # %tailrecurse
22:                         # =>This Inner Loop Header: Depth=1
23:     cmpq %rsi, %rax
24:     jg .LBB4_2
```
Ah-hah, there's our likely culprit: our accumulator is loaded from the stack in line 16 and spilled back in line 18. Yuck! It looks like carrying around that extra bit of state really cost us. I'm not sure why LLVM didn't spill the `Float` portion of the state to the stack temporarily so that it could use the register for the main loop, but it seems likely that it is related to the GHC calling convention used by the LLVM back-end.
I'm disappointed that we weren't able to get C-competitive performance from our
high-level Haskell code, especially since it seems so tantalizingly close. At
least there is hope that with some prodding we can convince LLVM to keep our
accumulating parameter in a register.