-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
463 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
Copyright (c) 2024-2025, Well-Typed LLP and Anduril Industries Inc. | ||
|
||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
|
||
* Redistributions of source code must retain the above copyright | ||
notice, this list of conditions and the following disclaimer. | ||
|
||
* Redistributions in binary form must reproduce the above | ||
copyright notice, this list of conditions and the following | ||
disclaimer in the documentation and/or other materials provided | ||
with the distribution. | ||
|
||
* Neither the name of the copyright holder nor the names of its | ||
contributors may be used to endorse or promote products derived | ||
from this software without specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | ||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | ||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | ||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | ||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | ||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | ||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | ||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | ||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
module Example.C ( | ||
Dim2(..) | ||
, Dim3(..) | ||
, DimPayload(..) | ||
, Dim(..) | ||
, set_DimPayload_dim2 | ||
, set_DimPayload_dim3 | ||
, get_DimPayload_dim2 | ||
, get_DimPayload_dim3 | ||
, max | ||
, grow | ||
) where | ||
|
||
import Prelude hiding (max) | ||
|
||
import Data.Primitive.ByteArray | ||
import Foreign | ||
import Foreign.C | ||
import System.IO.Unsafe | ||
|
||
import Util.PeekPokeByteArray | ||
|
||
#include "example.h" | ||
|
||
{------------------------------------------------------------------------------- | ||
Dim2 and Dim3 | ||
-------------------------------------------------------------------------------} | ||
|
||
data Dim2 = Dim2 CInt CInt | ||
data Dim3 = Dim3 CInt CInt CInt | ||
|
||
instance Storable Dim2 where | ||
sizeOf _ = #size Dim2_t | ||
alignment _ = #alignment Dim2_t | ||
|
||
peek ptr = do | ||
x <- (#peek Dim2_t, x) ptr | ||
y <- (#peek Dim2_t, y) ptr | ||
return (Dim2 x y) | ||
|
||
poke ptr (Dim2 x y) = do | ||
(#poke Dim2_t, x) ptr x | ||
(#poke Dim2_t, y) ptr y | ||
|
||
instance Storable Dim3 where | ||
sizeOf _ = #size Dim3_t | ||
alignment _ = #alignment Dim3_t | ||
|
||
peek ptr = do | ||
x <- (#peek Dim3_t, x) ptr | ||
y <- (#peek Dim3_t, y) ptr | ||
z <- (#peek Dim3_t, z) ptr | ||
return (Dim3 x y z) | ||
|
||
poke ptr (Dim3 x y z) = do | ||
(#poke Dim3_t, x) ptr x | ||
(#poke Dim3_t, y) ptr y | ||
(#poke Dim3_t, z) ptr z | ||
|
||
{------------------------------------------------------------------------------- | ||
DimPayload | ||
-------------------------------------------------------------------------------} | ||
|
||
newtype DimPayload = UnsafeDimPayload { | ||
getDimPayload :: ByteArray | ||
} | ||
|
||
mkDimPayload :: ByteArray -> DimPayload | ||
mkDimPayload bytes | ||
| sizeofByteArray bytes == (#size DimPayload_t) | ||
= UnsafeDimPayload bytes | ||
|
||
| otherwise | ||
= error $ concat [ | ||
"mkDimPayload: expected " | ||
, show ((#size DimPayload_t) :: CInt) | ||
, " bytes, but got " | ||
, show (sizeofByteArray bytes) | ||
] | ||
|
||
instance Storable DimPayload where | ||
sizeOf _ = #size DimPayload_t | ||
alignment _ = #alignment DimPayload_t | ||
|
||
peek = \ptr -> mkDimPayload <$> peekByteArray ptr (#size DimPayload_t) | ||
poke = \ptr -> pokeByteArray ptr . getDimPayload | ||
|
||
set_DimPayload_dim2 :: Dim2 -> DimPayload | ||
set_DimPayload_dim2 = setUnionPayload | ||
|
||
set_DimPayload_dim3 :: Dim3 -> DimPayload | ||
set_DimPayload_dim3 = setUnionPayload | ||
|
||
get_DimPayload_dim2 :: DimPayload -> Dim2 | ||
get_DimPayload_dim2 = getUnionPayload | ||
|
||
get_DimPayload_dim3 :: DimPayload -> Dim3 | ||
get_DimPayload_dim3 = getUnionPayload | ||
|
||
{------------------------------------------------------------------------------- | ||
Dim | ||
-------------------------------------------------------------------------------} | ||
|
||
data Dim = Dim { | ||
tag :: CInt | ||
, payload :: DimPayload | ||
} | ||
|
||
instance Storable Dim where | ||
sizeOf _ = #size Dim_t | ||
alignment _ = #alignment Dim_t | ||
|
||
peek ptr = do | ||
tag <- (#peek Dim_t, tag) ptr | ||
payload <- (#peek Dim_t, payload) ptr | ||
return Dim{tag, payload} | ||
|
||
poke ptr Dim{tag, payload} = do | ||
(#poke Dim_t, tag) ptr tag | ||
(#poke Dim_t, payload) ptr payload | ||
|
||
{------------------------------------------------------------------------------- | ||
Foreign imports | ||
-------------------------------------------------------------------------------} | ||
|
||
foreign import capi "example.h dim_max" | ||
c_dim_max :: Ptr Dim -> IO CInt | ||
|
||
foreign import capi "example.h dim_grow" | ||
c_dim_grow :: Ptr Dim -> Ptr Dim -> IO () | ||
|
||
{------------------------------------------------------------------------------- | ||
Wrapper functions | ||
-------------------------------------------------------------------------------} | ||
|
||
max :: Dim -> CInt | ||
max inp = unsafePerformIO $ | ||
with inp $ \inpPtr -> c_dim_max inpPtr | ||
|
||
grow :: Dim -> Dim | ||
grow inp = unsafePerformIO $ | ||
with inp $ \inpPtr -> | ||
alloca $ \outPtr -> do | ||
c_dim_grow inpPtr outPtr | ||
peek outPtr | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
module Example.Hs ( | ||
-- * Definition | ||
Dim(..) | ||
, max | ||
, grow | ||
-- * Translation to C | ||
, toC | ||
, fromC | ||
) where | ||
|
||
import Prelude hiding (max) | ||
|
||
import Foreign.C (CInt) | ||
import Test.QuickCheck | ||
|
||
import Example.C qualified as C | ||
|
||
{------------------------------------------------------------------------------- | ||
Definition | ||
-------------------------------------------------------------------------------} | ||
|
||
data Dim = | ||
Dim2 Int Int | ||
| Dim3 Int Int Int | ||
deriving stock (Show, Eq) | ||
|
||
max :: Dim -> Int | ||
max (Dim2 x y) = maximum [x, y] | ||
max (Dim3 x y z) = maximum [x, y, z] | ||
|
||
grow :: Dim -> Dim | ||
grow dim = Dim3 (max dim) (max dim) (max dim) | ||
|
||
{------------------------------------------------------------------------------- | ||
QuickCheck | ||
-------------------------------------------------------------------------------} | ||
|
||
instance Arbitrary Dim where | ||
arbitrary = fromEither <$> arbitrary | ||
shrink = map fromEither . shrink . toEither | ||
|
||
toEither :: Dim -> Either (Int, Int) (Int, Int, Int) | ||
toEither (Dim2 x y) = Left (x, y) | ||
toEither (Dim3 x y z) = Right (x, y, z) | ||
|
||
fromEither :: Either (Int, Int) (Int, Int, Int) -> Dim | ||
fromEither (Left (x, y)) = Dim2 x y | ||
fromEither (Right (x, y, z)) = Dim3 x y z | ||
|
||
{------------------------------------------------------------------------------- | ||
Translation | ||
-------------------------------------------------------------------------------} | ||
|
||
toC :: Dim -> C.Dim | ||
toC = \case | ||
Dim2 x y -> C.Dim 0 (C.set_DimPayload_dim2 $ C.Dim2 (c x) (c y)) | ||
Dim3 x y z -> C.Dim 1 (C.set_DimPayload_dim3 $ C.Dim3 (c x) (c y) (c z)) | ||
where | ||
c :: Int -> CInt | ||
c = fromIntegral | ||
|
||
fromC :: C.Dim -> Dim | ||
fromC = \case | ||
C.Dim 0 (C.get_DimPayload_dim2 -> C.Dim2 x y) -> Dim2 (c x) (c y) | ||
C.Dim 1 (C.get_DimPayload_dim3 -> C.Dim3 x y z) -> Dim3 (c x) (c y) (c z) | ||
C.Dim tag _ -> error $ "fromC: unexpected tag " ++ show tag | ||
where | ||
c :: CInt -> Int | ||
c = fromIntegral | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
module Main where | ||
|
||
import Test.Tasty | ||
import Test.Tasty.QuickCheck | ||
|
||
import Example.Hs qualified as Hs | ||
import Example.C qualified as C | ||
|
||
main :: IO () | ||
main = defaultMain $ testGroup "union-experiment" [ | ||
testGroup "hs-vs-c" [ | ||
testProperty "max" prop_CHs_max | ||
, testProperty "grow" prop_CHs_grow | ||
] | ||
] | ||
|
||
prop_CHs_max :: Hs.Dim -> Property | ||
prop_CHs_max dim = | ||
Hs.max dim | ||
=== (fromIntegral . C.max . Hs.toC $ dim) | ||
|
||
prop_CHs_grow :: Hs.Dim -> Property | ||
prop_CHs_grow dim = | ||
Hs.grow dim | ||
=== (Hs.fromC . C.grow . Hs.toC $ dim) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
-- | Utilities for dealing with 'ByteArray' and 'Storable' | ||
-- | ||
-- The additional copying we have to do here is a bit annoying, but in the end | ||
-- an FFI implementation based on 'Storable' is never going to be /extremely/ | ||
-- fast, as we are effectively (de)serializing. A few additional @memcpy@ | ||
-- operations are therefore not going to be a huge difference. | ||
-- | ||
-- We /could/ choose to use pinned bytearrays. This would avoid /some/ copying, | ||
-- but by no means all: we'd still need one copy (instead of two) in | ||
-- 'peekByteArray' and 'pokeByteArray', and the calls to 'peek' and 'poke' in | ||
-- 'peekFromByteArray' and 'pokeToByteArray' will (likely) do copying of their | ||
-- own as well. | ||
module Util.PeekPokeByteArray ( | ||
-- * Support for defining 'Storable' instances for union types | ||
peekByteArray | ||
, pokeByteArray | ||
-- * Support for defining setters and getters for union types | ||
, setUnionPayload | ||
, getUnionPayload | ||
) where | ||
|
||
import Control.Monad.Primitive | ||
import Data.Coerce | ||
import Data.Primitive.ByteArray | ||
import Foreign | ||
import System.IO.Unsafe | ||
|
||
{------------------------------------------------------------------------------- | ||
Support for defining 'Storable' instances for union types | ||
-------------------------------------------------------------------------------} | ||
|
||
peekByteArray :: Ptr a -> Int -> IO ByteArray | ||
peekByteArray src n = do | ||
pinnedCopy <- newPinnedByteArray n | ||
withMutableByteArrayContents pinnedCopy $ \dest -> | ||
copyBytes dest (castPtr src) n | ||
freezeByteArray pinnedCopy 0 n | ||
|
||
pokeByteArray :: Ptr a -> ByteArray -> IO () | ||
pokeByteArray dest bytes = do | ||
pinnedCopy <- thawPinned bytes | ||
withMutableByteArrayContents pinnedCopy $ \src -> | ||
copyBytes dest (castPtr src) n | ||
where | ||
n = sizeofByteArray bytes | ||
|
||
{------------------------------------------------------------------------------- | ||
Support for defining setters and getters for union types | ||
-------------------------------------------------------------------------------} | ||
|
||
setUnionPayload :: forall payload union. | ||
( Storable payload | ||
, Storable union | ||
, Coercible union ByteArray | ||
) | ||
=> payload -> union | ||
setUnionPayload = coerce . pokeToByteArray (sizeOf (undefined :: union)) | ||
|
||
getUnionPayload :: forall payload union. | ||
( Storable payload | ||
, Storable union | ||
, Coercible union ByteArray | ||
) | ||
=> union -> payload | ||
getUnionPayload = peekFromByteArray . coerce | ||
|
||
peekFromByteArray :: Storable a => ByteArray -> a | ||
peekFromByteArray bytes = unsafePerformIO $ do | ||
pinnedCopy <- thawPinned bytes | ||
withMutableByteArrayContents pinnedCopy $ \ptr -> | ||
peek (castPtr ptr) | ||
|
||
pokeToByteArray :: Storable a => Int -> a -> ByteArray | ||
pokeToByteArray n x = unsafePerformIO $ do | ||
pinnedCopy <- newPinnedByteArray n | ||
withMutableByteArrayContents pinnedCopy $ \ptr -> | ||
poke (castPtr ptr) x | ||
freezeByteArray pinnedCopy 0 n | ||
|
||
{------------------------------------------------------------------------------- | ||
Internal auxiliary | ||
-------------------------------------------------------------------------------} | ||
|
||
thawPinned :: ByteArray -> IO (MutableByteArray RealWorld) | ||
thawPinned src = do | ||
dest <- newPinnedByteArray n | ||
copyByteArray dest 0 src 0 n | ||
return dest | ||
where | ||
n = sizeofByteArray src |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
packages: . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#include "example.h" | ||
|
||
int maxint(int a, int b) { | ||
return a > b ? a : b; | ||
} | ||
|
||
int dim_max(Dim_t* dim) { | ||
switch(dim->tag) { | ||
case 0: { | ||
Dim2_t* payload = &(dim->payload.dim2); | ||
return maxint(payload->x, payload->y); | ||
} | ||
default: { | ||
Dim3_t* payload = &(dim->payload.dim3); | ||
return maxint(payload->x, maxint(payload->y, payload->z)); | ||
} | ||
}; | ||
} | ||
|
||
void dim_grow(Dim_t* in, Dim_t* out) { | ||
int max = dim_max(in); | ||
Dim_t result = { tag: 1, payload: { dim3: { x: max, y: max, z: max } } }; | ||
*out = result; | ||
} |
Oops, something went wrong.