Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Changed

- Made the (&&) and (||) operators short-circuit also in the Haskell side.
uplc code is unaffected and is already short-circuiting.
3 changes: 3 additions & 0 deletions plutus-tx/plutus-tx.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ library
PlutusTx.Semigroup
PlutusTx.Show
PlutusTx.Show.TH
PlutusTx.Eq.TH
PlutusTx.Sqrt
PlutusTx.TH
PlutusTx.These
Expand Down Expand Up @@ -208,6 +209,8 @@ test-suite plutus-tx-test
Blueprint.Definition.Spec
Blueprint.Spec
List.Spec
Bool.Spec
Eq.Spec
Rational.Laws
Rational.Laws.Additive
Rational.Laws.Construction
Expand Down
12 changes: 10 additions & 2 deletions plutus-tx/src/PlutusTx/Bool.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,35 @@ import Prelude (Bool (..), otherwise)

-- `(&&)` and `(||)` are handled specially in the plugin to make sure they can short-circuit.
-- See Note [Lazy boolean operators] in the plugin.
-- In the Haskell-side, we are using `default-extensions: Strict` throughout PlutusTx,
-- which means that we have to sure that the second argument is lazy to short-circuit the `(&&)` and `(||)`.

{-| Logical AND. Short-circuits if the first argument evaluates to `False`.

>>> True && False
False

>>> False && error ()
False
-}
infixr 3 &&

(&&) :: Bool -> Bool -> Bool
(&&) l r = if l then r else False
(&&) l ~r = if l then r else False
{-# OPAQUE (&&) #-}

{-| Logical OR. Short-circuits if the first argument evaluates to `True`.

>>> True || False
True

>>> True || error ()
True
-}
infixr 2 ||

(||) :: Bool -> Bool -> Bool
(||) l r = if l then True else r
(||) l ~r = if l then True else r
{-# OPAQUE (||) #-}

{-| Logical negation
Expand Down
16 changes: 4 additions & 12 deletions plutus-tx/src/PlutusTx/Eq.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
{-# OPTIONS_GHC -fno-omit-interface-pragmas #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module PlutusTx.Eq (Eq (..), (/=)) where
module PlutusTx.Eq (Eq (..), (/=), deriveEq) where

import PlutusTx.Eq.TH
import PlutusTx.Bool
import PlutusTx.Builtins qualified as Builtins
import PlutusTx.Either (Either (..))
Expand All @@ -10,17 +12,7 @@ import Prelude (Maybe (..))

{- HLINT ignore -}

infix 4 ==, /=

-- Copied from the GHC definition

-- | The 'Eq' class defines equality ('==').
class Eq a where
(==) :: a -> a -> Bool

-- (/=) deliberately omitted, to make this a one-method class which has a
-- simpler representation

infix 4 /=
(/=) :: (Eq a) => a -> a -> Bool
x /= y = not (x == y)
{-# INLINEABLE (/=) #-}
Expand Down
66 changes: 66 additions & 0 deletions plutus-tx/src/PlutusTx/Eq/TH.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
{-# LANGUAGE TemplateHaskellQuotes #-}
module PlutusTx.Eq.TH (Eq (..), deriveEq) where

import PlutusTx.Bool ((&&), Bool (True))
import Prelude hiding (Eq, (==), (&&), Bool (True))
import Data.Foldable
import Data.Traversable
import Language.Haskell.TH as TH
import Language.Haskell.TH.Datatype as TH
import Data.Deriving.Internal (varTToName)

infix 4 ==

-- Copied from the GHC definition

-- | The 'Eq' class defines equality ('==').
class Eq a where
(==) :: a -> a -> Bool

-- (/=) deliberately omitted, to make this a one-method class which has a
-- simpler representation

deriveEq :: TH.Name -> TH.Q [TH.Dec]
deriveEq name = do
TH.DatatypeInfo
{ TH.datatypeName = tyConName
, TH.datatypeInstTypes = tyVars0
, TH.datatypeCons = cons
} <-
TH.reifyDatatype name
let
-- The purpose of the `TH.VarT . varTToName` roundtrip is to remove the kind
-- signatures attached to the type variables in `tyVars0`. Otherwise, the
-- `KindSignatures` extension would be needed whenever `length tyVars0 > 0`.
tyVars = TH.VarT . varTToName <$> tyVars0
instanceCxt :: TH.Cxt
instanceCxt = TH.AppT (TH.ConT ''Eq) <$> tyVars
instanceType :: TH.Type
instanceType = TH.AppT (TH.ConT ''Eq) $ foldl' TH.AppT (TH.ConT tyConName) tyVars

pure <$> instanceD (pure instanceCxt) (pure instanceType)
[funD '(==) (fmap deriveEqCons cons <> [pure eqDefaultClause])
, TH.pragInlD '(==) TH.Inlinable TH.FunLike TH.AllPhases
]


-- Clause: Cons1 l1 l2 l3 .. ln == Cons1 r1 r2 r3 .. rn
deriveEqCons :: ConstructorInfo -> Q Clause
deriveEqCons (ConstructorInfo {constructorName = name, constructorFields = fields })
= do
argsL <- for [1 .. length fields] $ \i -> TH.newName ("l" <> show i)
argsR <- for [1 .. length fields] $ \i -> TH.newName ("r" <> show i)
pure (TH.Clause [ConP name [] (fmap VarP argsL), ConP name [] (fmap VarP argsR)]
(NormalB $
foldr
(\ (argL,argR) acc ->
TH.InfixE(pure $ TH.InfixE (pure $ TH.VarE argL) (TH.VarE '(==)) (pure $ TH.VarE argR)) (TH.VarE '(&&)) (pure acc))
(TH.ConE 'True)
(zip argsL argsR)
)
[]
)

-- Clause: _ == _ = False
eqDefaultClause :: Clause
eqDefaultClause = TH.Clause [WildP, WildP] (TH.NormalB (ConE 'False)) []
18 changes: 18 additions & 0 deletions plutus-tx/test/Bool/Spec.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module Bool.Spec (boolTests) where

import PlutusTx.Builtins as Tx
import PlutusTx.Bool qualified as Tx

import Prelude (($))
import Test.Tasty
import Test.Tasty.HUnit

boolTests :: TestTree
boolTests =
testGroup
"PlutusTx.Bool tests"
-- in uplc the &&,|| are treated specially to short-circuit
-- this makes sures that the Haskell counterparts also short-circuit
[ testCase "shortcircuit_&&" $ Tx.False Tx.&& Tx.error () @?= Tx.False
, testCase "shortcircuit_||" $ Tx.True Tx.|| Tx.error () @?= Tx.True
]
41 changes: 41 additions & 0 deletions plutus-tx/test/Eq/Spec.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}
module Eq.Spec (eqTests) where

import PlutusTx.Builtins as Tx
import PlutusTx.Bool qualified as Tx
import PlutusTx.Eq as Tx
import Control.Exception

import Data.Either

import Prelude hiding (Eq (..), error)
import Prelude qualified as HS (Eq (..),)
import Test.Tasty
import Test.Tasty.HUnit

data SomeLargeADT a b c d e =
SomeLargeADT1 Integer a Tx.Bool b c d
| SomeLargeADT2
| SomeLargeADT3 { f1 :: e, f2 :: e, _f3 :: e, _f4 :: e, _f5 :: e }
deriving stock HS.Eq
deriveEq ''SomeLargeADT

eqTests :: TestTree
eqTests =
let v1 :: SomeLargeADT () BuiltinString () () () = SomeLargeADT1 1 () Tx.True "foobar" () ()
v2 :: SomeLargeADT () () () () () = SomeLargeADT2
v3 :: SomeLargeADT () () () () Integer = SomeLargeADT3 1 2 3 4 5
v3Error1 = v3 { f1 = 0, f2 = error () } -- mismatch comes first, error comes later
v3Error2 = v3 { f1 = error (), f2 = 0 } -- error comes first, mismatch later

in testGroup
"PlutusTx.Eq tests"
[testCase "reflexive1" $ (v1 Tx.== v1) @?= (v1 HS.== v1)
, testCase "reflexive2" $ (v2 Tx.== v2) @?= (v2 HS.== v2)
, testCase "reflexive3" $ (v3 Tx.== v3) @?= (v3 HS.== v3)
, testCase "shortcircuit" $ (v3 Tx.== v3Error1) @?= (v3 Tx.== v3Error1) -- should not throw an error
, testCase "throws" $ try @SomeException (evaluate $ v3 Tx.== v3Error2) >>= assertBool "did not throw error" . isLeft -- should throw erro
]
4 changes: 4 additions & 0 deletions plutus-tx/test/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import Hedgehog (MonadGen, Property, PropertyT, annotateShow, assert, forAll, pr
import Hedgehog.Gen qualified as Gen
import Hedgehog.Range qualified as Range
import List.Spec (listTests)
import Bool.Spec (boolTests)
import Eq.Spec (eqTests)
import PlutusCore.Data (Data (B, Constr, I, List, Map))
import PlutusTx.Enum (Enum (..))
import PlutusTx.Numeric (negate)
Expand All @@ -45,6 +47,8 @@ tests =
, bytestringTests
, enumTests
, listTests
, boolTests
, eqTests
, lawsTests
, Show.Spec.propertyTests
, Show.Spec.goldenTests
Expand Down