Skip to content

Commit

Permalink
Union experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
edsko committed Mar 6, 2025
1 parent 2730554 commit 20d9164
Show file tree
Hide file tree
Showing 9 changed files with 463 additions and 0 deletions.
29 changes: 29 additions & 0 deletions union-experiment/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
Copyright (c) 2024-2025, Well-Typed LLP and Anduril Industries Inc.


Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
146 changes: 146 additions & 0 deletions union-experiment/app/Example/C.hsc
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
module Example.C (
Dim2(..)
, Dim3(..)
, DimPayload(..)
, Dim(..)
, set_DimPayload_dim2
, set_DimPayload_dim3
, get_DimPayload_dim2
, get_DimPayload_dim3
, max
, grow
) where

import Prelude hiding (max)

import Data.Primitive.ByteArray
import Foreign
import Foreign.C
import System.IO.Unsafe

import Util.PeekPokeByteArray

#include "example.h"

{-------------------------------------------------------------------------------
Dim2 and Dim3
-------------------------------------------------------------------------------}

data Dim2 = Dim2 CInt CInt
data Dim3 = Dim3 CInt CInt CInt

instance Storable Dim2 where
sizeOf _ = #size Dim2_t
alignment _ = #alignment Dim2_t

peek ptr = do
x <- (#peek Dim2_t, x) ptr
y <- (#peek Dim2_t, y) ptr
return (Dim2 x y)

poke ptr (Dim2 x y) = do
(#poke Dim2_t, x) ptr x
(#poke Dim2_t, y) ptr y

instance Storable Dim3 where
sizeOf _ = #size Dim3_t
alignment _ = #alignment Dim3_t

peek ptr = do
x <- (#peek Dim3_t, x) ptr
y <- (#peek Dim3_t, y) ptr
z <- (#peek Dim3_t, z) ptr
return (Dim3 x y z)

poke ptr (Dim3 x y z) = do
(#poke Dim3_t, x) ptr x
(#poke Dim3_t, y) ptr y
(#poke Dim3_t, z) ptr z

{-------------------------------------------------------------------------------
DimPayload
-------------------------------------------------------------------------------}

newtype DimPayload = UnsafeDimPayload {
getDimPayload :: ByteArray
}

mkDimPayload :: ByteArray -> DimPayload
mkDimPayload bytes
| sizeofByteArray bytes == (#size DimPayload_t)
= UnsafeDimPayload bytes

| otherwise
= error $ concat [
"mkDimPayload: expected "
, show ((#size DimPayload_t) :: CInt)
, " bytes, but got "
, show (sizeofByteArray bytes)
]

instance Storable DimPayload where
sizeOf _ = #size DimPayload_t
alignment _ = #alignment DimPayload_t

peek = \ptr -> mkDimPayload <$> peekByteArray ptr (#size DimPayload_t)
poke = \ptr -> pokeByteArray ptr . getDimPayload

set_DimPayload_dim2 :: Dim2 -> DimPayload
set_DimPayload_dim2 = setUnionPayload

set_DimPayload_dim3 :: Dim3 -> DimPayload
set_DimPayload_dim3 = setUnionPayload

get_DimPayload_dim2 :: DimPayload -> Dim2
get_DimPayload_dim2 = getUnionPayload

get_DimPayload_dim3 :: DimPayload -> Dim3
get_DimPayload_dim3 = getUnionPayload

{-------------------------------------------------------------------------------
Dim
-------------------------------------------------------------------------------}

data Dim = Dim {
tag :: CInt
, payload :: DimPayload
}

instance Storable Dim where
sizeOf _ = #size Dim_t
alignment _ = #alignment Dim_t

peek ptr = do
tag <- (#peek Dim_t, tag) ptr
payload <- (#peek Dim_t, payload) ptr
return Dim{tag, payload}

poke ptr Dim{tag, payload} = do
(#poke Dim_t, tag) ptr tag
(#poke Dim_t, payload) ptr payload

{-------------------------------------------------------------------------------
Foreign imports
-------------------------------------------------------------------------------}

foreign import capi "example.h dim_max"
c_dim_max :: Ptr Dim -> IO CInt

foreign import capi "example.h dim_grow"
c_dim_grow :: Ptr Dim -> Ptr Dim -> IO ()

{-------------------------------------------------------------------------------
Wrapper functions
-------------------------------------------------------------------------------}

max :: Dim -> CInt
max inp = unsafePerformIO $
with inp $ \inpPtr -> c_dim_max inpPtr

grow :: Dim -> Dim
grow inp = unsafePerformIO $
with inp $ \inpPtr ->
alloca $ \outPtr -> do
c_dim_grow inpPtr outPtr
peek outPtr

70 changes: 70 additions & 0 deletions union-experiment/app/Example/Hs.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
module Example.Hs (
-- * Definition
Dim(..)
, max
, grow
-- * Translation to C
, toC
, fromC
) where

import Prelude hiding (max)

import Foreign.C (CInt)
import Test.QuickCheck

import Example.C qualified as C

{-------------------------------------------------------------------------------
Definition
-------------------------------------------------------------------------------}

data Dim =
Dim2 Int Int
| Dim3 Int Int Int
deriving stock (Show, Eq)

max :: Dim -> Int
max (Dim2 x y) = maximum [x, y]
max (Dim3 x y z) = maximum [x, y, z]

grow :: Dim -> Dim
grow dim = Dim3 (max dim) (max dim) (max dim)

{-------------------------------------------------------------------------------
QuickCheck
-------------------------------------------------------------------------------}

instance Arbitrary Dim where
arbitrary = fromEither <$> arbitrary
shrink = map fromEither . shrink . toEither

toEither :: Dim -> Either (Int, Int) (Int, Int, Int)
toEither (Dim2 x y) = Left (x, y)
toEither (Dim3 x y z) = Right (x, y, z)

fromEither :: Either (Int, Int) (Int, Int, Int) -> Dim
fromEither (Left (x, y)) = Dim2 x y
fromEither (Right (x, y, z)) = Dim3 x y z

{-------------------------------------------------------------------------------
Translation
-------------------------------------------------------------------------------}

toC :: Dim -> C.Dim
toC = \case
Dim2 x y -> C.Dim 0 (C.set_DimPayload_dim2 $ C.Dim2 (c x) (c y))
Dim3 x y z -> C.Dim 1 (C.set_DimPayload_dim3 $ C.Dim3 (c x) (c y) (c z))
where
c :: Int -> CInt
c = fromIntegral

fromC :: C.Dim -> Dim
fromC = \case
C.Dim 0 (C.get_DimPayload_dim2 -> C.Dim2 x y) -> Dim2 (c x) (c y)
C.Dim 1 (C.get_DimPayload_dim3 -> C.Dim3 x y z) -> Dim3 (c x) (c y) (c z)
C.Dim tag _ -> error $ "fromC: unexpected tag " ++ show tag
where
c :: CInt -> Int
c = fromIntegral

25 changes: 25 additions & 0 deletions union-experiment/app/Main.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module Main where

import Test.Tasty
import Test.Tasty.QuickCheck

import Example.Hs qualified as Hs
import Example.C qualified as C

main :: IO ()
main = defaultMain $ testGroup "union-experiment" [
testGroup "hs-vs-c" [
testProperty "max" prop_CHs_max
, testProperty "grow" prop_CHs_grow
]
]

prop_CHs_max :: Hs.Dim -> Property
prop_CHs_max dim =
Hs.max dim
=== (fromIntegral . C.max . Hs.toC $ dim)

prop_CHs_grow :: Hs.Dim -> Property
prop_CHs_grow dim =
Hs.grow dim
=== (Hs.fromC . C.grow . Hs.toC $ dim)
90 changes: 90 additions & 0 deletions union-experiment/app/Util/PeekPokeByteArray.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
-- | Utilities for dealing with 'ByteArray' and 'Storable'
--
-- The additional copying we have to do here is a bit annoying, but in the end
-- an FFI implementation based on 'Storable' is never going to be /extremely/
-- fast, as we are effectively (de)serializing. A few additional @memcpy@
-- operations are therefore not going to be a huge difference.
--
-- We /could/ choose to use pinned bytearrays. This would avoid /some/ copying,
-- but by no means all: we'd still need one copy (instead of two) in
-- 'peekByteArray' and 'pokeByteArray', and the calls to 'peek' and 'poke' in
-- 'peekFromByteArray' and 'pokeToByteArray' will (likely) do copying of their
-- own as well.
module Util.PeekPokeByteArray (
-- * Support for defining 'Storable' instances for union types
peekByteArray
, pokeByteArray
-- * Support for defining setters and getters for union types
, setUnionPayload
, getUnionPayload
) where

import Control.Monad.Primitive
import Data.Coerce
import Data.Primitive.ByteArray
import Foreign
import System.IO.Unsafe

{-------------------------------------------------------------------------------
Support for defining 'Storable' instances for union types
-------------------------------------------------------------------------------}

peekByteArray :: Ptr a -> Int -> IO ByteArray
peekByteArray src n = do
pinnedCopy <- newPinnedByteArray n
withMutableByteArrayContents pinnedCopy $ \dest ->
copyBytes dest (castPtr src) n
freezeByteArray pinnedCopy 0 n

pokeByteArray :: Ptr a -> ByteArray -> IO ()
pokeByteArray dest bytes = do
pinnedCopy <- thawPinned bytes
withMutableByteArrayContents pinnedCopy $ \src ->
copyBytes dest (castPtr src) n
where
n = sizeofByteArray bytes

{-------------------------------------------------------------------------------
Support for defining setters and getters for union types
-------------------------------------------------------------------------------}

setUnionPayload :: forall payload union.
( Storable payload
, Storable union
, Coercible union ByteArray
)
=> payload -> union
setUnionPayload = coerce . pokeToByteArray (sizeOf (undefined :: union))

getUnionPayload :: forall payload union.
( Storable payload
, Storable union
, Coercible union ByteArray
)
=> union -> payload
getUnionPayload = peekFromByteArray . coerce

peekFromByteArray :: Storable a => ByteArray -> a
peekFromByteArray bytes = unsafePerformIO $ do
pinnedCopy <- thawPinned bytes
withMutableByteArrayContents pinnedCopy $ \ptr ->
peek (castPtr ptr)

pokeToByteArray :: Storable a => Int -> a -> ByteArray
pokeToByteArray n x = unsafePerformIO $ do
pinnedCopy <- newPinnedByteArray n
withMutableByteArrayContents pinnedCopy $ \ptr ->
poke (castPtr ptr) x
freezeByteArray pinnedCopy 0 n

{-------------------------------------------------------------------------------
Internal auxiliary
-------------------------------------------------------------------------------}

thawPinned :: ByteArray -> IO (MutableByteArray RealWorld)
thawPinned src = do
dest <- newPinnedByteArray n
copyByteArray dest 0 src 0 n
return dest
where
n = sizeofByteArray src
1 change: 1 addition & 0 deletions union-experiment/cabal.project
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
packages: .
24 changes: 24 additions & 0 deletions union-experiment/cbits/example.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "example.h"

int maxint(int a, int b) {
return a > b ? a : b;
}

int dim_max(Dim_t* dim) {
switch(dim->tag) {
case 0: {
Dim2_t* payload = &(dim->payload.dim2);
return maxint(payload->x, payload->y);
}
default: {
Dim3_t* payload = &(dim->payload.dim3);
return maxint(payload->x, maxint(payload->y, payload->z));
}
};
}

void dim_grow(Dim_t* in, Dim_t* out) {
int max = dim_max(in);
Dim_t result = { tag: 1, payload: { dim3: { x: max, y: max, z: max } } };
*out = result;
}
Loading

0 comments on commit 20d9164

Please sign in to comment.