diff --git a/.gitignore b/.gitignore index cb589bb20..cf0863e31 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,7 @@ stack.yaml dist dist-newstyle -.ghc.environment* \ No newline at end of file +.ghc.environment* +.cabal-sandbox +cabal.sandbox.config + diff --git a/Math/NumberTheory/MoebiusInversion.hs b/Math/NumberTheory/MoebiusInversion.hs index 86a8721f8..7e56023a3 100644 --- a/Math/NumberTheory/MoebiusInversion.hs +++ b/Math/NumberTheory/MoebiusInversion.hs @@ -7,12 +7,12 @@ -- Generalised Möbius inversion -- {-# LANGUAGE BangPatterns, FlexibleContexts #-} + module Math.NumberTheory.MoebiusInversion - ( generalInversion - , totientSum - ) where + ( generalInversion + , totientSum + ) where -import Data.Array.ST import Control.Monad import Control.Monad.ST @@ -28,7 +28,7 @@ totientSum n | n < 1 = 0 | otherwise = generalInversion (triangle . fromIntegral) n where - triangle k = (k*(k+1)) `quot` 2 + triangle k = (k * (k + 1)) `quot` 2 -- | @generalInversion g n@ evaluates the generalised Möbius inversion of @g@ -- at the argument @n@. @@ -76,65 +76,67 @@ totientSum n -- method is only appropriate to compute isolated values of @f@. generalInversion :: (Int -> Integer) -> Int -> Integer generalInversion fun n - | n < 1 = error "Möbius inversion only defined on positive domain" - | n == 1 = fun 1 - | n == 2 = fun 2 - fun 1 - | n == 3 = fun 3 - 2*fun 1 - | otherwise = fastInvert fun n + | n < 1 = error "Möbius inversion only defined on positive domain" + | n == 1 = fun 1 + | n == 2 = fun 2 - fun 1 + | n == 3 = fun 3 - 2 * fun 1 + | otherwise = fastInvert fun n fastInvert :: (Int -> Integer) -> Int -> Integer -fastInvert fun n = big `unsafeAt` 0 +fastInvert fun n = big `unsafeIndex` 0 where !k0 = integerSquareRoot (n `quot` 2) - !mk0 = n `quot` (2*k0+1) + !mk0 = n `quot` (2 * k0 + 1) kmax a m = (a `quot` m - 1) `quot` 2 - big = runSTArray $ do - small <- newArray_ (0,mk0) :: ST s (STArray s Int Integer) + big = + runST $ do + small <- unsafeNew (mk0 + 1) :: ST s (STVector s Integer) unsafeWrite small 0 0 - unsafeWrite small 1 $! (fun 1) - when (mk0 >= 2) $ - unsafeWrite small 2 $! (fun 2 - fun 1) + unsafeWrite small 1 $! fun 1 + when (mk0 >= 2) $ unsafeWrite small 2 $! (fun 2 - fun 1) let calcit switch change i - | mk0 < i = return (switch,change) - | i == change = calcit (switch+1) (change + 4*switch+6) i - | otherwise = do - let mloop !acc k !m - | k < switch = kloop acc k - | otherwise = do - val <- unsafeRead small m - let nxtk = kmax i (m+1) - mloop (acc - fromIntegral (k-nxtk)*val) nxtk (m+1) - kloop !acc k - | k == 0 = do - unsafeWrite small i $! acc - calcit switch change (i+1) - | otherwise = do - val <- unsafeRead small (i `quot` (2*k+1)) - kloop (acc-val) (k-1) - mloop (fun i - fun (i `quot` 2)) ((i-1) `quot` 2) 1 + | mk0 < i = return (switch, change) + | i == change = calcit (switch + 1) (change + 4 * switch + 6) i + | otherwise = do + let mloop !acc k !m + | k < switch = kloop acc k + | otherwise = do + val <- unsafeRead small m + let nxtk = kmax i (m + 1) + mloop (acc - fromIntegral (k - nxtk) * val) nxtk (m + 1) + kloop !acc k + | k == 0 = do + unsafeWrite small i $! acc + calcit switch change (i + 1) + | otherwise = do + val <- unsafeRead small (i `quot` (2 * k + 1)) + kloop (acc - val) (k - 1) + mloop (fun i - fun (i `quot` 2)) ((i - 1) `quot` 2) 1 (sw, ch) <- calcit 1 8 3 - large <- newArray_ (0,k0-1) + large <- unsafeNew k0 :: ST s (STVector s Integer) let calcbig switch change j - | j == 0 = return large - | (2*j-1)*change <= n = calcbig (switch+1) (change + 4*switch+6) j - | otherwise = do - let i = n `quot` (2*j-1) - mloop !acc k m - | k < switch = kloop acc k - | otherwise = do - val <- unsafeRead small m - let nxtk = kmax i (m+1) - mloop (acc - fromIntegral (k-nxtk)*val) nxtk (m+1) - kloop !acc k - | k == 0 = do - unsafeWrite large (j-1) $! acc - calcbig switch change (j-1) - | otherwise = do - let m = i `quot` (2*k+1) - val <- if m <= mk0 - then unsafeRead small m - else unsafeRead large (k*(2*j-1)+j-1) - kloop (acc-val) (k-1) - mloop (fun i - fun (i `quot` 2)) ((i-1) `quot` 2) 1 - calcbig sw ch k0 - + | j == 0 = return large + | (2 * j - 1) * change <= n = + calcbig (switch + 1) (change + 4 * switch + 6) j + | otherwise = do + let i = n `quot` (2 * j - 1) + mloop !acc k m + | k < switch = kloop acc k + | otherwise = do + val <- unsafeRead small m + let nxtk = kmax i (m + 1) + mloop (acc - fromIntegral (k - nxtk) * val) nxtk (m + 1) + kloop !acc k + | k == 0 = do + unsafeWrite large (j - 1) $! acc + calcbig switch change (j - 1) + | otherwise = do + let m = i `quot` (2 * k + 1) + val <- + if m <= mk0 + then unsafeRead small m + else unsafeRead large (k * (2 * j - 1) + j - 1) + kloop (acc - val) (k - 1) + mloop (fun i - fun (i `quot` 2)) ((i - 1) `quot` 2) 1 + _ <- calcbig sw ch k0 + unsafeFreeze large diff --git a/Math/NumberTheory/MoebiusInversion/Int.hs b/Math/NumberTheory/MoebiusInversion/Int.hs index 659ab8ef5..be18425db 100644 --- a/Math/NumberTheory/MoebiusInversion/Int.hs +++ b/Math/NumberTheory/MoebiusInversion/Int.hs @@ -8,12 +8,14 @@ -- {-# LANGUAGE BangPatterns, FlexibleContexts #-} {-# OPTIONS_GHC -fspec-constr-count=8 #-} + module Math.NumberTheory.MoebiusInversion.Int - ( generalInversion - , totientSum - ) where + ( generalInversion + , totientSum + ) where + +import Prelude hiding (replicate) -import Data.Array.ST import Control.Monad import Control.Monad.ST @@ -29,7 +31,7 @@ totientSum n | n < 1 = 0 | otherwise = generalInversion (triangle . fromIntegral) n where - triangle k = (k*(k+1)) `quot` 2 + triangle k = (k * (k + 1)) `quot` 2 -- | @generalInversion g n@ evaluates the generalised Möbius inversion of @g@ -- at the argument @n@. @@ -77,64 +79,66 @@ totientSum n -- method is only appropriate to compute isolated values of @f@. generalInversion :: (Int -> Int) -> Int -> Int generalInversion fun n - | n < 1 = error "Möbius inversion only defined on positive domain" - | n == 1 = fun 1 - | n == 2 = fun 2 - fun 1 - | n == 3 = fun 3 - 2*fun 1 - | otherwise = fastInvert fun n + | n < 1 = error "Möbius inversion only defined on positive domain" + | n == 1 = fun 1 + | n == 2 = fun 2 - fun 1 + | n == 3 = fun 3 - 2 * fun 1 + | otherwise = fastInvert fun n fastInvert :: (Int -> Int) -> Int -> Int -fastInvert fun n = big `unsafeAt` 0 +fastInvert fun n = big `unsafeIndex` 0 where !k0 = integerSquareRoot (n `quot` 2) - !mk0 = n `quot` (2*k0+1) + !mk0 = n `quot` (2 * k0 + 1) kmax a m = (a `quot` m - 1) `quot` 2 - big = runSTUArray $ do - small <- newArray_ (0,mk0) :: ST s (STUArray s Int Int) + big = + runST $ do + small <- replicate (mk0 + 1) 0 :: ST s (STVector s Int) unsafeWrite small 0 0 unsafeWrite small 1 (fun 1) - when (mk0 >= 2) $ - unsafeWrite small 2 (fun 2 - fun 1) + when (mk0 >= 2) $ unsafeWrite small 2 (fun 2 - fun 1) let calcit switch change i - | mk0 < i = return (switch,change) - | i == change = calcit (switch+1) (change + 4*switch+6) i - | otherwise = do - let mloop !acc k !m - | k < switch = kloop acc k - | otherwise = do - val <- unsafeRead small m - let nxtk = kmax i (m+1) - mloop (acc - (k-nxtk)*val) nxtk (m+1) - kloop !acc k - | k == 0 = do - unsafeWrite small i acc - calcit switch change (i+1) - | otherwise = do - val <- unsafeRead small (i `quot` (2*k+1)) - kloop (acc-val) (k-1) - mloop (fun i - fun (i `quot` 2)) ((i-1) `quot` 2) 1 + | mk0 < i = return (switch, change) + | i == change = calcit (switch + 1) (change + 4 * switch + 6) i + | otherwise = do + let mloop !acc k !m + | k < switch = kloop acc k + | otherwise = do + val <- unsafeRead small m + let nxtk = kmax i (m + 1) + mloop (acc - (k - nxtk) * val) nxtk (m + 1) + kloop !acc k + | k == 0 = do + unsafeWrite small i acc + calcit switch change (i + 1) + | otherwise = do + val <- unsafeRead small (i `quot` (2 * k + 1)) + kloop (acc - val) (k - 1) + mloop (fun i - fun (i `quot` 2)) ((i - 1) `quot` 2) 1 (sw, ch) <- calcit 1 8 3 - large <- newArray_ (0,k0-1) + large <- replicate k0 0 :: ST s (STVector s Int) let calcbig switch change j - | j == 0 = return large - | (2*j-1)*change <= n = calcbig (switch+1) (change + 4*switch+6) j - | otherwise = do - let i = n `quot` (2*j-1) - mloop !acc k m - | k < switch = kloop acc k - | otherwise = do - val <- unsafeRead small m - let nxtk = kmax i (m+1) - mloop (acc - (k-nxtk)*val) nxtk (m+1) - kloop !acc k - | k == 0 = do - unsafeWrite large (j-1) acc - calcbig switch change (j-1) - | otherwise = do - let m = i `quot` (2*k+1) - val <- if m <= mk0 - then unsafeRead small m - else unsafeRead large (k*(2*j-1)+j-1) - kloop (acc-val) (k-1) - mloop (fun i - fun (i `quot` 2)) ((i-1) `quot` 2) 1 - calcbig sw ch k0 + | j == 0 = return large + | (2 * j - 1) * change <= n = + calcbig (switch + 1) (change + 4 * switch + 6) j + | otherwise = do + let i = n `quot` (2 * j - 1) + mloop !acc k m + | k < switch = kloop acc k + | otherwise = do + val <- unsafeRead small m + let nxtk = kmax i (m + 1) + mloop (acc - (k - nxtk) * val) nxtk (m + 1) + kloop !acc k + | k == 0 = do + unsafeWrite large (j - 1) acc + calcbig switch change (j - 1) + | otherwise = do + let m = i `quot` (2 * k + 1) + val <- + if m <= mk0 + then unsafeRead small m + else unsafeRead large (k * (2 * j - 1) + j - 1) + kloop (acc - val) (k - 1) + mloop (fun i - fun (i `quot` 2)) ((i - 1) `quot` 2) 1 + calcbig sw ch k0 >>= unsafeFreeze diff --git a/Math/NumberTheory/Powers/Cubes.hs b/Math/NumberTheory/Powers/Cubes.hs index 87daf90de..233b8af2a 100644 --- a/Math/NumberTheory/Powers/Cubes.hs +++ b/Math/NumberTheory/Powers/Cubes.hs @@ -15,11 +15,10 @@ module Math.NumberTheory.Powers.Cubes , isCube' , isPossibleCube ) where - + #include "MachDeps.h" -import Data.Array.Unboxed -import Data.Array.ST +import Prelude hiding (replicate) import Data.Bits @@ -30,6 +29,9 @@ import GHC.Integer.Logarithms (integerLog2#) import Math.NumberTheory.Unsafe +import Control.Monad.ST (runST) +import Data.Foldable (for_) + -- | Calculate the integer cube root of an integer @n@, -- that is the largest integer @r@ such that @r^3 <= n@. -- Note that this is not symmetric about @0@, for example @@ -120,10 +122,10 @@ isCube' !n = isPossibleCube n #-} isPossibleCube :: Integral a => a -> Bool isPossibleCube !n = - unsafeAt cr512 (fromIntegral n .&. 511) - && unsafeAt cubeRes837 (fromIntegral (n `rem` 837)) - && unsafeAt cubeRes637 (fromIntegral (n `rem` 637)) - && unsafeAt cubeRes703 (fromIntegral (n `rem` 703)) + unsafeIndex cr512 (fromIntegral n .&. 511) + && unsafeIndex cubeRes837 (fromIntegral (n `rem` 837)) + && unsafeIndex cubeRes637 (fromIntegral (n `rem` 637)) + && unsafeIndex cubeRes703 (fromIntegral (n `rem` 703)) ---------------------------------------------------------------------- -- Utility Functions -- @@ -208,40 +210,38 @@ appCuRt n@(Jp# bn#) appCuRt _ = error "integerCubeRoot': negative argument" -- not very discriminating, but cheap, so it's an overall gain -cr512 :: UArray Int Bool -cr512 = runSTUArray $ do - ar <- newArray (0,511) True +cr512 :: Vector Bool +cr512 = runST $ do + v <- replicate 512 True let note s i - | i < 512 = unsafeWrite ar i False >> note s (i+s) + | i < 512 = unsafeWrite v i False >> note s (i+s) | otherwise = return () note 4 2 note 8 4 note 32 16 note 64 32 note 256 128 - unsafeWrite ar 256 False - return ar + unsafeWrite v 256 False + unsafeFreeze v -- Remainders modulo @3^3 * 31@ -cubeRes837 :: UArray Int Bool -cubeRes837 = runSTUArray $ do - ar <- newArray (0,836) False - let note 837 = return ar - note k = unsafeWrite ar ((k*k*k) `rem` 837) True >> note (k+1) - note 0 +cubeRes837 :: Vector Bool +cubeRes837 = runST $ do + v <- replicate 837 False + for_ [0..837] $ \k -> unsafeWrite v ((k*k*k) `rem` 837) True + unsafeFreeze v -- Remainders modulo @7^2 * 13@ -cubeRes637 :: UArray Int Bool -cubeRes637 = runSTUArray $ do - ar <- newArray (0,636) False - let note 637 = return ar - note k = unsafeWrite ar ((k*k*k) `rem` 637) True >> note (k+1) - note 0 - +cubeRes637 :: Vector Bool +cubeRes637 = runST $ do + v <- replicate 637 False + for_ [0..637]$ \k -> unsafeWrite v ((k*k*k) `rem` 637) True + unsafeFreeze v + -- Remainders modulo @19 * 37@ -cubeRes703 :: UArray Int Bool -cubeRes703 = runSTUArray $ do - ar <- newArray (0,702) False - let note 703 = return ar - note k = unsafeWrite ar ((k*k*k) `rem` 703) True >> note (k+1) - note 0 +cubeRes703 :: Vector Bool +cubeRes703 = runST $ do + v <- replicate 703 False + for_ [0..703] $ \k -> unsafeWrite v ((k*k*k) `rem` 703) True + unsafeFreeze v + diff --git a/Math/NumberTheory/Powers/Fourth.hs b/Math/NumberTheory/Powers/Fourth.hs index 506aadd93..6c77bae5f 100644 --- a/Math/NumberTheory/Powers/Fourth.hs +++ b/Math/NumberTheory/Powers/Fourth.hs @@ -18,14 +18,14 @@ module Math.NumberTheory.Powers.Fourth #include "MachDeps.h" +import Prelude hiding (replicate) + import GHC.Base import GHC.Integer import GHC.Integer.GMP.Internals import GHC.Integer.Logarithms (integerLog2#) -import Data.Array.Unboxed -import Data.Array.ST - +import Control.Monad.ST (runST) import Data.Bits import Math.NumberTheory.Unsafe @@ -105,9 +105,9 @@ isFourthPower' n = isPossibleFourthPower n && r2*r2 == n #-} isPossibleFourthPower :: Integral a => a -> Bool isPossibleFourthPower n = - biSqRes256 `unsafeAt` (fromIntegral n .&. 255) - && biSqRes425 `unsafeAt` (fromIntegral (n `rem` 425)) - && biSqRes377 `unsafeAt` (fromIntegral (n `rem` 377)) + biSqRes256 `unsafeIndex` (fromIntegral n .&. 255) + && biSqRes425 `unsafeIndex` fromIntegral (n `rem` 425) + && biSqRes377 `unsafeIndex` fromIntegral (n `rem` 377) {-# SPECIALISE newton4 :: Integer -> Integer -> Integer #-} newton4 :: Integral a => a -> a -> a @@ -149,28 +149,31 @@ appBiSqrt n@(Jp# bn#) appBiSqrt _ = error "integerFourthRoot': negative argument" -biSqRes256 :: UArray Int Bool -biSqRes256 = runSTUArray $ do - ar <- newArray (0,255) False +biSqRes256 :: Vector Bool +biSqRes256 = runST $ do + ar <- replicate 256 False let note 257 = return ar note i = unsafeWrite ar i True >> note (i+16) unsafeWrite ar 0 True unsafeWrite ar 16 True - note 1 + _ <- note 1 + unsafeFreeze ar -biSqRes425 :: UArray Int Bool -biSqRes425 = runSTUArray $ do - ar <- newArray (0,424) False +biSqRes425 :: Vector Bool +biSqRes425 = runST $ do + ar <- replicate 425 False let note 154 = return ar note i = unsafeWrite ar ((i*i*i*i) `rem` 425) True >> note (i+1) - note 0 + _ <- note 0 + unsafeFreeze ar -biSqRes377 :: UArray Int Bool -biSqRes377 = runSTUArray $ do - ar <- newArray (0,376) False +biSqRes377 :: Vector Bool +biSqRes377 = runST $ do + ar <- replicate 377 False let note 144 = return ar note i = unsafeWrite ar ((i*i*i*i) `rem` 377) True >> note (i+1) - note 0 + _ <- note 0 + unsafeFreeze ar biSqrtInt :: Int -> Int biSqrtInt 0 = 0 diff --git a/Math/NumberTheory/Powers/Squares.hs b/Math/NumberTheory/Powers/Squares.hs index d4abe1185..8af61a683 100644 --- a/Math/NumberTheory/Powers/Squares.hs +++ b/Math/NumberTheory/Powers/Squares.hs @@ -23,13 +23,13 @@ module Math.NumberTheory.Powers.Squares #include "MachDeps.h" -import Data.Array.Unboxed -import Data.Array.ST +import Prelude hiding (replicate) +import Data.Foldable (for_) import Data.Bits +import Control.Monad.ST (runST, ST) import Math.NumberTheory.Unsafe - import Math.NumberTheory.Powers.Squares.Internal -- | Calculate the integer square root of a nonnegative number @n@, @@ -143,9 +143,9 @@ isSquare' n #-} isPossibleSquare :: Integral a => a -> Bool isPossibleSquare n = - unsafeAt sr256 ((fromIntegral n) .&. 255) - && unsafeAt sr693 (fromIntegral (n `rem` 693)) - && unsafeAt sr325 (fromIntegral (n `rem` 325)) + unsafeIndex sr256 (fromIntegral n .&. 255) + && unsafeIndex sr693 (fromIntegral (n `rem` 693)) + && unsafeIndex sr325 (fromIntegral (n `rem` 325)) -- | Test whether a non-negative number may be a square. -- Non-negativity is not checked, passing negative arguments may @@ -166,50 +166,49 @@ isPossibleSquare n = #-} isPossibleSquare2 :: Integral a => a -> Bool isPossibleSquare2 n = - unsafeAt sr256 ((fromIntegral n) .&. 255) - && unsafeAt sr819 (fromIntegral (n `rem` 819)) - && unsafeAt sr1025 (fromIntegral (n `rem` 1025)) - && unsafeAt sr2047 (fromIntegral (n `rem` 2047)) - && unsafeAt sr4097 (fromIntegral (n `rem` 4097)) - && unsafeAt sr341 (fromIntegral (n `rem` 341)) + unsafeIndex sr256 (fromIntegral n .&. 255) + && unsafeIndex sr819 (fromIntegral (n `rem` 819)) + && unsafeIndex sr1025 (fromIntegral (n `rem` 1025)) + && unsafeIndex sr2047 (fromIntegral (n `rem` 2047)) + && unsafeIndex sr4097 (fromIntegral (n `rem` 4097)) + && unsafeIndex sr341 (fromIntegral (n `rem` 341)) ----------------------------------------------------------------------------- -- Auxiliary Stuff -- Make an array indicating whether a remainder is a square remainder. -sqRemArray :: Int -> UArray Int Bool -sqRemArray md = runSTUArray $ do - arr <- newArray (0,md-1) False - let !stop = (md `quot` 2) + 1 - fill k - | k < stop = unsafeWrite arr ((k*k) `rem` md) True >> fill (k+1) - | otherwise = return arr +sqRemArray :: Int -> Vector Bool +sqRemArray md = runST $ do + arr <- replicate md False :: ST s (STVector s Bool) + let !stop = md `quot` 2 + for_ [2..stop] $ + \k -> unsafeWrite arr ((k*k) `rem` md) True unsafeWrite arr 0 True unsafeWrite arr 1 True - fill 2 + unsafeFreeze arr -sr256 :: UArray Int Bool +sr256 :: Vector Bool sr256 = sqRemArray 256 -sr819 :: UArray Int Bool +sr819 :: Vector Bool sr819 = sqRemArray 819 -sr4097 :: UArray Int Bool +sr4097 :: Vector Bool sr4097 = sqRemArray 4097 -sr341 :: UArray Int Bool +sr341 :: Vector Bool sr341 = sqRemArray 341 -sr1025 :: UArray Int Bool +sr1025 :: Vector Bool sr1025 = sqRemArray 1025 -sr2047 :: UArray Int Bool +sr2047 :: Vector Bool sr2047 = sqRemArray 2047 -sr693 :: UArray Int Bool +sr693 :: Vector Bool sr693 = sqRemArray 693 -sr325 :: UArray Int Bool +sr325 :: Vector Bool sr325 = sqRemArray 325 -- Specialisations for Int, Word, and Integer diff --git a/Math/NumberTheory/Primes/Counting/Impl.hs b/Math/NumberTheory/Primes/Counting/Impl.hs index 2dd5fc71f..5d1a73995 100644 --- a/Math/NumberTheory/Primes/Counting/Impl.hs +++ b/Math/NumberTheory/Primes/Counting/Impl.hs @@ -6,42 +6,40 @@ -- -- Number of primes not exceeding @n@, @π(n)@, and @n@-th prime. -- -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE CPP #-} -{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ScopedTypeVariables #-} - {-# OPTIONS_GHC -fspec-constr-count=24 #-} {-# OPTIONS_HADDOCK hide #-} -module Math.NumberTheory.Primes.Counting.Impl - ( primeCount - , primeCountMaxArg - , nthPrime - , nthPrimeMaxArg - ) where +module Math.NumberTheory.Primes.Counting.Impl + ( primeCount + , primeCountMaxArg + , nthPrime + , nthPrimeMaxArg + ) where #include "MachDeps.h" +import Prelude hiding (replicate) +import Math.NumberTheory.Logarithms +import Math.NumberTheory.Powers.Cubes +import Math.NumberTheory.Powers.Squares +import Math.NumberTheory.Primes.Counting.Approximate import Math.NumberTheory.Primes.Sieve.Eratosthenes import Math.NumberTheory.Primes.Sieve.Indexing -import Math.NumberTheory.Primes.Counting.Approximate import Math.NumberTheory.Primes.Types -import Math.NumberTheory.Powers.Squares -import Math.NumberTheory.Powers.Cubes -import Math.NumberTheory.Logarithms import Math.NumberTheory.Unsafe -import Data.Array.ST import Control.Monad.ST import Data.Bits import Data.Int - +import Data.Vector.Mutable as MV (length) #if SIZEOF_HSWORD < 8 #define COUNT_T Int64 #else #define COUNT_T Int #endif - -- | Maximal allowed argument of 'primeCount'. Currently 8e18. primeCountMaxArg :: Integer primeCountMaxArg = 8000000000000000000 @@ -58,23 +56,27 @@ primeCountMaxArg = 8000000000000000000 -- . primeCount :: Integer -> Integer primeCount n - | n > primeCountMaxArg = error $ "primeCount: can't handle bound " ++ show n - | n < 2 = 0 - | n < 1000 = fromIntegral . length . takeWhile (<= n) . map unPrime . primeList . primeSieve $ max 242 n - | n < 30000 = runST $ do - ba <- sieveTo n - (s,e) <- getBounds ba - ct <- countFromTo s e ba - return (fromIntegral $ ct+3) - | otherwise = - let !ub = cop $ fromInteger n - !sr = integerSquareRoot' ub - !cr = nxtEnd $ integerCubeRoot' ub + 15 - nxtEnd k = k - (k `rem` 30) + 31 - !phn1 = calc ub cr - !cs = cr+6 - !pdf = sieveCount ub cs sr - in phn1 - pdf + | n > primeCountMaxArg = error $ "primeCount: can't handle bound " ++ show n + | n < 2 = 0 + | n < 1000 = + fromIntegral . + Prelude.length . takeWhile (<= n) . map unPrime . primeList . primeSieve $ + max 242 n + | n < 30000 = + runST $ do + let baST = sieveTo n :: ST s (STVector s Bool) + ba <- baST + ct <- countFromTo 0 (MV.length ba - 1) baST + return (fromIntegral $ ct + 3) + | otherwise = + let !ub = cop $ fromInteger n + !sr = integerSquareRoot' ub + !cr = nxtEnd $ integerCubeRoot' ub + 15 + nxtEnd k = k - (k `rem` 30) + 31 + !phn1 = calc ub cr + !cs = cr + 6 + !pdf = sieveCount ub cs sr + in phn1 - pdf -- | Maximal allowed argument of 'nthPrime'. Currently 1.5e17. nthPrimeMaxArg :: Integer @@ -87,321 +89,362 @@ nthPrimeMaxArg = 150000000000000000 -- The argument must be strictly positive, and must not exceed 'nthPrimeMaxArg'. nthPrime :: Integer -> Prime Integer nthPrime n - | n < 1 = error "Prime indexing starts at 1" - | n > nthPrimeMaxArg = error $ "nthPrime: can't handle index " ++ show n - | n < 200000 = Prime $ nthPrimeCt n - | ct0 < n = Prime $ tooLow n p0 (n-ct0) approxGap - | otherwise = Prime $ tooHigh n p0 (ct0-n) approxGap - where - p0 = nthPrimeApprox n - approxGap = (7 * fromIntegral (integerLog2' p0)) `quot` 10 - ct0 = primeCount p0 + | n < 1 = error "Prime indexing starts at 1" + | n > nthPrimeMaxArg = error $ "nthPrime: can't handle index " ++ show n + | n < 200000 = Prime $ nthPrimeCt n + | ct0 < n = Prime $ tooLow n p0 (n - ct0) approxGap + | otherwise = Prime $ tooHigh n p0 (ct0 - n) approxGap + where + p0 = nthPrimeApprox n + approxGap = (7 * fromIntegral (integerLog2' p0)) `quot` 10 + ct0 = primeCount p0 -------------------------------------------------------------------------------- -- The Works -- -------------------------------------------------------------------------------- - -- TODO: do something better in case we guess too high. -- Not too pressing, since I think a) nthPrimeApprox always underestimates -- in the range we can handle, and b) it's always "goodEnough" - tooLow :: Integer -> Integer -> Integer -> Integer -> Integer tooLow n a miss gap - | goodEnough = lowSieve a miss - | c1 < n = lowSieve p1 (n-c1) - | otherwise = lowSieve a miss -- a third count wouldn't make it faster, I think - where - est = miss*gap - p1 = a + (est * 19) `quot` 20 - goodEnough = 3*est*est*est < 2*p1*p1 -- a second counting would be more work than sieving - c1 = primeCount p1 + | goodEnough = lowSieve a miss + | c1 < n = lowSieve p1 (n - c1) + | otherwise = lowSieve a miss -- a third count wouldn't make it faster, I think + where + est = miss * gap + p1 = a + (est * 19) `quot` 20 + goodEnough = 3 * est * est * est < 2 * p1 * p1 -- a second counting would be more work than sieving + c1 = primeCount p1 tooHigh :: Integer -> Integer -> Integer -> Integer -> Integer tooHigh n a surp gap - | c < n = lowSieve b (n-c) - | otherwise = tooHigh n b (c-n) gap - where - b = a - (surp * gap * 11) `quot` 10 - c = primeCount b + | c < n = lowSieve b (n - c) + | otherwise = tooHigh n b (c - n) gap + where + b = a - (surp * gap * 11) `quot` 10 + c = primeCount b lowSieve :: Integer -> Integer -> Integer -lowSieve a miss = countToNth (miss+rep) psieves +lowSieve a miss = countToNth (miss + rep) psieves + where + strt = + if (fromInteger a .&. (1 :: Int)) == 1 + then a + 2 + else a + 1 + psieves@(PS vO ba:_) = psieveFrom strt + rep + | o0 < 0 = 0 + | otherwise = sum [1 | i <- [0 .. r2], ba `unsafeIndex` i] where - strt = if (fromInteger a .&. (1 :: Int)) == 1 - then a+2 - else a+1 - psieves@(PS vO ba:_) = psieveFrom strt - rep | o0 < 0 = 0 - | otherwise = sum [1 | i <- [0 .. r2], ba `unsafeAt` i] - where - o0 = strt - vO - 9 -- (strt - 2) - v0 - 7 - r0 = fromInteger o0 `rem` 30 - r1 = r0 `quot` 3 - r2 = min 7 (if r1 > 5 then r1-1 else r1) + o0 = strt - vO - 9 -- (strt - 2) - v0 - 7 + r0 = fromInteger o0 `rem` 30 + r1 = r0 `quot` 3 + r2 = + min + 7 + (if r1 > 5 + then r1 - 1 + else r1) -- highSieve :: Integer -> Integer -> Integer -> Integer -- highSieve a surp gap = error "Oh shit" - sieveCount :: COUNT_T -> COUNT_T -> COUNT_T -> Integer sieveCount ub cr sr = runST (sieveCountST ub cr sr) sieveCountST :: forall s. COUNT_T -> COUNT_T -> COUNT_T -> ST s Integer sieveCountST ub cr sr = do - let psieves = psieveFrom (fromIntegral cr) - pisr = approxPrimeCount sr - picr = approxPrimeCount cr - diff = pisr - picr - size = fromIntegral (diff + diff `quot` 50) + 30 - store <- unsafeNewArray_ (0,size-1) :: ST s (STUArray s Int COUNT_T) - let feed :: COUNT_T -> Int -> Int -> UArray Int Bool -> [PrimeSieve] -> ST s Integer - feed voff !wi !ri uar sves - | ri == sieveBits = case sves of - (PS vO ba : more) -> feed (fromInteger vO) wi 0 ba more - _ -> error "prime stream ended prematurely" - | pval > sr = do - stu <- unsafeThaw uar - eat 0 0 voff (wi-1) ri stu sves - | uar `unsafeAt` ri = do - unsafeWrite store wi (ub `quot` pval) - feed voff (wi+1) (ri+1) uar sves - | otherwise = feed voff wi (ri+1) uar sves - where - pval = voff + toPrim ri - eat :: Integer -> Integer -> COUNT_T -> Int -> Int -> STUArray s Int Bool -> [PrimeSieve] -> ST s Integer - eat !acc !btw voff !wi !si stu sves - | si == sieveBits = - case sves of - [] -> error "Premature end of prime stream" - (PS vO ba : more) -> do - nstu <- unsafeThaw ba - eat acc btw (fromInteger vO) wi 0 nstu more - | wi < 0 = return acc - | otherwise = do - qb <- unsafeRead store wi - let dist = qb - voff - 7 - if dist < fromIntegral sieveRange - then do - let (b,j) = idxPr (dist+7) - !li = (b `shiftL` 3) .|. j - new <- if li < si then return 0 else countFromTo si li stu - let nbtw = btw + fromIntegral new + 1 - eat (acc+nbtw) nbtw voff (wi-1) (li+1) stu sves - else do - let (cpl,fds) = dist `quotRem` fromIntegral sieveRange - (b,j) = idxPr (fds+7) - !li = (b `shiftL` 3) .|. j - ctLoop !lac 0 (PS vO ba : more) = do - nstu <- unsafeThaw ba - new <- countFromTo 0 li nstu - let nbtw = btw + lac + 1 + fromIntegral new - eat (acc+nbtw) nbtw (fromIntegral vO) (wi-1) (li+1) nstu more - ctLoop lac s (ps : more) = do - !new <- countAll ps - ctLoop (lac + fromIntegral new) (s-1) more - ctLoop _ _ [] = error "Primes ended" - new <- countFromTo si (sieveBits-1) stu - ctLoop (fromIntegral new) (cpl-1) sves - case psieves of - (PS vO ba : more) -> feed (fromInteger vO) 0 0 ba more - _ -> error "No primes sieved" + let psieves = psieveFrom (fromIntegral cr) + pisr = approxPrimeCount sr + picr = approxPrimeCount cr + diff = pisr - picr + size = fromIntegral (diff + diff `quot` 50) + 30 + store <- unsafeNew size :: ST s (STVector s COUNT_T) + let feed :: + COUNT_T -> Int -> Int -> Vector Bool -> [PrimeSieve] -> ST s Integer + feed voff !wi !ri uar sves + | ri == sieveBits = + case sves of + (PS vO ba:more) -> feed (fromInteger vO) wi 0 ba more + _ -> error "prime stream ended prematurely" + | pval > sr = do + let stuST = unsafeThaw uar :: ST s (STVector s Bool) + eat 0 0 voff (wi - 1) ri stuST sves + | uar `unsafeIndex` ri = do + unsafeWrite store wi (ub `quot` pval) + feed voff (wi + 1) (ri + 1) uar sves + | otherwise = feed voff wi (ri + 1) uar sves + where + pval = voff + toPrim ri + eat :: + Integer + -> Integer + -> COUNT_T + -> Int + -> Int + -> ST s (STVector s Bool) + -> [PrimeSieve] + -> ST s Integer + eat !acc !btw voff !wi !si stu sves + | si == sieveBits = + case sves of + [] -> error "Premature end of prime stream" + (PS vO ba:more) -> do + let nstu = unsafeThaw ba :: ST s (STVector s Bool) + eat acc btw (fromInteger vO) wi 0 nstu more + | wi < 0 = return acc + | otherwise = do + qb <- unsafeRead store wi + let dist = qb - voff - 7 + if dist < fromIntegral sieveRange + then do + let (b, j) = idxPr (dist + 7) + !li = (b `shiftL` 3) .|. j + new <- + if li < si + then return 0 + else countFromTo si li stu + let nbtw = btw + fromIntegral new + 1 + eat (acc + nbtw) nbtw voff (wi - 1) (li + 1) stu sves + else do + let (cpl, fds) = dist `quotRem` fromIntegral sieveRange + (b, j) = idxPr (fds + 7) + !li = (b `shiftL` 3) .|. j + ctLoop !lac 0 (PS vO ba:more) = do + let nstu = unsafeThaw ba + new <- countFromTo 0 li nstu + let nbtw = btw + lac + 1 + fromIntegral new + eat + (acc + nbtw) + nbtw + (fromIntegral vO) + (wi - 1) + (li + 1) + nstu + more + ctLoop lac s (ps:more) = do + !new <- countAll ps + ctLoop (lac + fromIntegral new) (s - 1) more + ctLoop _ _ [] = error "Primes ended" + new <- countFromTo si (sieveBits - 1) stu + ctLoop (fromIntegral new) (cpl - 1) sves + case psieves of + (PS vO ba:more) -> feed (fromInteger vO) 0 0 ba more + _ -> error "No primes sieved" calc :: COUNT_T -> COUNT_T -> Integer calc lim plim = runST (calcST lim plim) calcST :: forall s. COUNT_T -> COUNT_T -> ST s Integer calcST lim plim = do - !parr <- sieveTo (fromIntegral plim) - (plo,phi) <- getBounds parr - !pct <- countFromTo plo phi parr - !ar1 <- unsafeNewArray_ (0,end-1) - unsafeWrite ar1 0 lim - unsafeWrite ar1 1 1 - !ar2 <- unsafeNewArray_ (0,end-1) - let go :: Int -> Int -> STUArray s Int COUNT_T -> STUArray s Int COUNT_T -> ST s Integer - go cap pix old new - | pix == 2 = coll cap old - | otherwise = do - isp <- unsafeRead parr pix - if isp - then do - let !n = fromInteger (toPrim pix) - !ncap <- treat cap n old new - go ncap (pix-1) new old - else go cap (pix-1) old new - coll :: Int -> STUArray s Int COUNT_T -> ST s Integer - coll stop ar = - let cgo !acc i - | i < stop = do - !k <- unsafeRead ar i - !v <- unsafeRead ar (i+1) - cgo (acc + fromIntegral v*cp6 k) (i+2) - | otherwise = return (acc+fromIntegral pct+2) - in cgo 0 0 - go 2 start ar1 ar2 + let !parrST = sieveTo (fromIntegral plim) + !parr <- parrST + !pct <- countFromTo 0 (MV.length parr) parrST + !ar1 <- unsafeNew end + unsafeWrite ar1 0 lim + unsafeWrite ar1 1 1 + !ar2 <- unsafeNew end + let go :: + Int + -> Int + -> STVector s COUNT_T + -> STVector s COUNT_T + -> ST s Integer + go cap pix old new + | pix == 2 = coll cap old + | otherwise = do + isp <- unsafeRead parr pix + if isp + then do + let !n = fromInteger (toPrim pix) + !ncap <- treat cap n old new + go ncap (pix - 1) new old + else go cap (pix - 1) old new + coll :: Int -> STVector s COUNT_T -> ST s Integer + coll stop ar = + let cgo !acc i + | i < stop = do + !k <- unsafeRead ar i + !v <- unsafeRead ar (i + 1) + cgo (acc + fromIntegral v * cp6 k) (i + 2) + | otherwise = return (acc + fromIntegral pct + 2) + in cgo 0 0 + go 2 start ar1 ar2 where - (bt,ri) = idxPr plim - !start = 8*bt + ri - !size = fromIntegral $ (integerSquareRoot lim) `quot` 4 - !end = 2*size + (bt, ri) = idxPr plim + !start = 8 * bt + ri + !size = fromIntegral $ integerSquareRoot lim `quot` 4 + !end = 2 * size -treat :: Int -> COUNT_T -> STUArray s Int COUNT_T -> STUArray s Int COUNT_T -> ST s Int +treat :: Int -> COUNT_T -> STVector s COUNT_T -> STVector s COUNT_T -> ST s Int treat end n old new = do - qi0 <- locate n 0 (end `quot` 2 - 1) old - let collect stop !acc ix - | ix < end = do - !k <- unsafeRead old ix - if k < stop - then do - v <- unsafeRead old (ix+1) - collect stop (acc-v) (ix+2) - else return (acc,ix) - | otherwise = return (acc,ix) - goTreat !wi !ci qi - | qi < end = do - !key <- unsafeRead old qi - !val <- unsafeRead old (qi+1) - let !q0 = key `quot` n - !r0 = fromIntegral (q0 `rem` 30030) - !nkey = q0 - fromIntegral (cpDfAr `unsafeAt` r0) - nk0 = q0 + fromIntegral (cpGpAr `unsafeAt` (r0+1) + 1) - !nlim = n*nk0 - (wi1,ci1) <- copyTo end nkey old ci new wi - ckey <- unsafeRead old ci1 - (!acc, !ci2) <- if ckey == nkey - then do - !ov <- unsafeRead old (ci1+1) - return (ov-val,ci1+2) - else return (-val,ci1) - (!tot, !nqi) <- collect nlim acc (qi+2) - unsafeWrite new wi1 nkey - unsafeWrite new (wi1+1) tot - goTreat (wi1+2) ci2 nqi - | otherwise = copyRem end old ci new wi - goTreat 0 0 qi0 + qi0 <- locate n 0 (end `quot` 2 - 1) old + let collect stop !acc ix + | ix < end = do + !k <- unsafeRead old ix + if k < stop + then do + v <- unsafeRead old (ix + 1) + collect stop (acc - v) (ix + 2) + else return (acc, ix) + | otherwise = return (acc, ix) + goTreat !wi !ci qi + | qi < end = do + !key <- unsafeRead old qi + !val <- unsafeRead old (qi + 1) + let !q0 = key `quot` n + !r0 = fromIntegral (q0 `rem` 30030) + !nkey = q0 - fromIntegral (cpDfAr `unsafeIndex` r0) + nk0 = q0 + fromIntegral (cpGpAr `unsafeIndex` (r0 + 1) + 1) + !nlim = n * nk0 + (wi1, ci1) <- copyTo end nkey old ci new wi + ckey <- unsafeRead old ci1 + (!acc, !ci2) <- + if ckey == nkey + then do + !ov <- unsafeRead old (ci1 + 1) + return (ov - val, ci1 + 2) + else return (-val, ci1) + (!tot, !nqi) <- collect nlim acc (qi + 2) + unsafeWrite new wi1 nkey + unsafeWrite new (wi1 + 1) tot + goTreat (wi1 + 2) ci2 nqi + | otherwise = copyRem end old ci new wi + goTreat 0 0 qi0 -------------------------------------------------------------------------------- -- Auxiliaries -- -------------------------------------------------------------------------------- - -locate :: COUNT_T -> Int -> Int -> STUArray s Int COUNT_T -> ST s Int +locate :: COUNT_T -> Int -> Int -> STVector s COUNT_T -> ST s Int locate p low high arr = do - let go lo hi - | lo < hi = do - let !md = (lo+hi) `quot` 2 - v <- unsafeRead arr (2*md) - case compare p v of - LT -> go lo md - EQ -> return (2*md) - GT -> go (md+1) hi - | otherwise = return (2*lo) - go low high + let go lo hi + | lo < hi = do + let !md = (lo + hi) `quot` 2 + v <- unsafeRead arr (2 * md) + case compare p v of + LT -> go lo md + EQ -> return (2 * md) + GT -> go (md + 1) hi + | otherwise = return (2 * lo) + go low high {-# INLINE copyTo #-} -copyTo :: Int -> COUNT_T -> STUArray s Int COUNT_T -> Int - -> STUArray s Int COUNT_T -> Int -> ST s (Int,Int) +copyTo :: + Int + -> COUNT_T + -> STVector s COUNT_T + -> Int + -> STVector s COUNT_T + -> Int + -> ST s (Int, Int) copyTo end lim old oi new ni = do - let go ri wi - | ri < end = do - ok <- unsafeRead old ri - if ok < lim - then do - !ov <- unsafeRead old (ri+1) - unsafeWrite new wi ok - unsafeWrite new (wi+1) ov - go (ri+2) (wi+2) - else return (wi,ri) - | otherwise = return (wi,ri) - go oi ni + let go ri wi + | ri < end = do + ok <- unsafeRead old ri + if ok < lim + then do + !ov <- unsafeRead old (ri + 1) + unsafeWrite new wi ok + unsafeWrite new (wi + 1) ov + go (ri + 2) (wi + 2) + else return (wi, ri) + | otherwise = return (wi, ri) + go oi ni {-# INLINE copyRem #-} -copyRem :: Int -> STUArray s Int COUNT_T -> Int -> STUArray s Int COUNT_T -> Int -> ST s Int +copyRem :: + Int -> STVector s COUNT_T -> Int -> STVector s COUNT_T -> Int -> ST s Int copyRem end old oi new ni = do - let go ri wi - | ri < end = do - unsafeRead old ri >>= unsafeWrite new wi - go (ri+1) (wi+1) - | otherwise = return wi - go oi ni + let go ri wi + | ri < end = do + unsafeRead old ri >>= unsafeWrite new wi + go (ri + 1) (wi + 1) + | otherwise = return wi + go oi ni {-# INLINE cp6 #-} cp6 :: COUNT_T -> Integer cp6 k = case k `quotRem` 30030 of - (q,r) -> 5760*fromIntegral q + - fromIntegral (cpCtAr `unsafeAt` fromIntegral r) + (q, r) -> + 5760 * fromIntegral q + fromIntegral (cpCtAr `unsafeIndex` fromIntegral r) cop :: COUNT_T -> COUNT_T -cop m = m - fromIntegral (cpDfAr `unsafeAt` fromIntegral (m `rem` 30030)) - +cop m = m - fromIntegral (cpDfAr `unsafeIndex` fromIntegral (m `rem` 30030)) -------------------------------------------------------------------------------- -- Ugly helper arrays -- -------------------------------------------------------------------------------- - -cpCtAr :: UArray Int Int16 -cpCtAr = runSTUArray $ do - ar <- newArray (0,30029) 1 +cpCtAr :: Vector Int16 +cpCtAr = + runST $ do + ar <- replicate 30030 1 :: ST s (STVector s Int16) let zilch s i - | i < 30030 = unsafeWrite ar i 0 >> zilch s (i+s) - | otherwise = return () + | i < 30030 = unsafeWrite ar i 0 >> zilch s (i + s) + | otherwise = return () accumulate ct i - | i < 30030 = do - v <- unsafeRead ar i - let !ct' = ct+v - unsafeWrite ar i ct' - accumulate ct' (i+1) - | otherwise = return ar + | i < 30030 = do + v <- unsafeRead ar i + let !ct' = ct + v + unsafeWrite ar i ct' + accumulate ct' (i + 1) + | otherwise = return ar zilch 2 0 zilch 6 3 zilch 10 5 zilch 14 7 zilch 22 11 zilch 26 13 - accumulate 1 2 + _ <- accumulate 1 2 + unsafeFreeze ar -cpDfAr :: UArray Int Int8 -cpDfAr = runSTUArray $ do - ar <- newArray (0,30029) 0 +cpDfAr :: Vector Int8 +cpDfAr = + runST $ do + ar <- replicate 30030 0 :: ST s (STVector s Int8) let note s i - | i < 30029 = unsafeWrite ar i 1 >> note s (i+s) - | otherwise = return () + | i < 30029 = unsafeWrite ar i 1 >> note s (i + s) + | otherwise = return () accumulate d i - | i < 30029 = do - v <- unsafeRead ar i - if v == 0 - then accumulate 2 (i+2) - else do unsafeWrite ar i d - accumulate (d+1) (i+1) - | otherwise = return ar + | i < 30029 = do + v <- unsafeRead ar i + if v == 0 + then accumulate 2 (i + 2) + else do + unsafeWrite ar i d + accumulate (d + 1) (i + 1) + | otherwise = return ar note 2 0 note 6 3 note 10 5 note 14 7 note 22 11 note 26 13 - accumulate 2 3 + _ <- accumulate 2 3 + unsafeFreeze ar -cpGpAr :: UArray Int Int8 -cpGpAr = runSTUArray $ do - ar <- newArray (0,30030) 0 +cpGpAr :: Vector Int8 +cpGpAr = + runST $ do + ar <- replicate 30031 0 unsafeWrite ar 30030 1 let note s i - | i < 30029 = unsafeWrite ar i 1 >> note s (i+s) - | otherwise = return () + | i < 30029 = unsafeWrite ar i 1 >> note s (i + s) + | otherwise = return () accumulate d i - | i < 1 = return ar - | otherwise = do - v <- unsafeRead ar i - if v == 0 - then accumulate 2 (i-2) - else do unsafeWrite ar i d - accumulate (d+1) (i-1) - | otherwise = return ar + | i < 1 = return ar + | otherwise = do + v <- unsafeRead ar i + if v == 0 + then accumulate 2 (i - 2) + else do + unsafeWrite ar i d + accumulate (d + 1) (i - 1) + | otherwise = return ar note 2 0 note 6 3 note 10 5 note 14 7 note 22 11 note 26 13 - accumulate 2 30027 - + _ <- accumulate 2 30027 + unsafeFreeze ar diff --git a/Math/NumberTheory/Primes/Factorisation/Montgomery.hs b/Math/NumberTheory/Primes/Factorisation/Montgomery.hs index cfcca5e91..55a4bd334 100644 --- a/Math/NumberTheory/Primes/Factorisation/Montgomery.hs +++ b/Math/NumberTheory/Primes/Factorisation/Montgomery.hs @@ -18,21 +18,19 @@ -- -- Given enough time, the algorithm should be able to factor numbers of 100-120 digits, but it -- is best suited for numbers of up to 50-60 digits. - -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE CPP #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} - {-# OPTIONS_GHC -fno-warn-type-defaults #-} {-# OPTIONS_HADDOCK hide #-} module Math.NumberTheory.Primes.Factorisation.Montgomery - ( -- * Complete factorisation functions + -- * Complete factorisation functions -- ** Functions with input checking - factorise + ( factorise , defaultStdGenFactorisation -- ** Functions without input checking , factorise' @@ -49,12 +47,12 @@ module Math.NumberTheory.Primes.Factorisation.Montgomery import Control.Arrow import Control.Monad.Trans.State.Lazy -import System.Random import Data.Bits import Data.IntMap (IntMap) import qualified Data.IntMap as IM import Data.List (foldl') import Data.Maybe +import System.Random #if __GLASGOW_HASKELL__ < 803 import Data.Semigroup #endif @@ -65,8 +63,8 @@ import GHC.TypeNats.Compat import Math.NumberTheory.Curves.Montgomery import Math.NumberTheory.Euclidean.Coprimes (splitIntoCoprimes, unCoprimes) import Math.NumberTheory.Moduli.Class -import Math.NumberTheory.Powers.General (highestPower, largePFPower) -import Math.NumberTheory.Powers.Squares (integerSquareRoot') +import Math.NumberTheory.Powers.General (highestPower, largePFPower) +import Math.NumberTheory.Powers.Squares (integerSquareRoot') import Math.NumberTheory.Primes.Sieve.Eratosthenes import Math.NumberTheory.Primes.Sieve.Indexing import Math.NumberTheory.Primes.Testing.Probabilistic @@ -79,26 +77,38 @@ import Math.NumberTheory.Utils -- an arbitrary manner from the bit-pattern of @n@. factorise :: Integer -> [(Integer, Word)] factorise n - | abs n == 1 = [] - | n < 0 = factorise (-n) - | n == 0 = error "0 has no prime factorisation" - | otherwise = factorise' n + | abs n == 1 = [] + | n < 0 = factorise (-n) + | n == 0 = error "0 has no prime factorisation" + | otherwise = factorise' n -- | Like 'factorise', but without input checking, hence @n > 1@ is required. factorise' :: Integer -> [(Integer, Word)] -factorise' n = defaultStdGenFactorisation' (mkStdGen $ fromInteger n `xor` 0xdeadbeef) n +factorise' n = + defaultStdGenFactorisation' (mkStdGen $ fromInteger n `xor` 0xdeadbeef) n -- | @'stepFactorisation'@ is like 'factorise'', except that it doesn't use a -- pseudo random generator but steps through the curves in order. -- This strategy turns out to be surprisingly fast, on average it doesn't -- seem to be slower than the 'StdGen' based variant. stepFactorisation :: Integer -> [(Integer, Word)] -stepFactorisation n - = let (sfs,mb) = smallFactors 100000 n - in sfs ++ case mb of - Nothing -> [] - Just r -> curveFactorisation (Just 10000000000) bailliePSW - (\m k -> (if k < (m-1) then k else error "Curves exhausted",k+1)) 6 Nothing r +stepFactorisation n = + let (sfs, mb) = smallFactors 100000 n + in sfs ++ + case mb of + Nothing -> [] + Just r -> + curveFactorisation + (Just 10000000000) + bailliePSW + (\m k -> + ( if k < (m - 1) + then k + else error "Curves exhausted" + , k + 1)) + 6 + Nothing + r -- | @'defaultStdGenFactorisation'@ first strips off all small prime factors and then, -- if the factorisation is not complete, proceeds to curve factorisation. @@ -107,35 +117,42 @@ stepFactorisation n -- an error. defaultStdGenFactorisation :: StdGen -> Integer -> [(Integer, Word)] defaultStdGenFactorisation sg n - | n == 0 = error "0 has no prime factorisation" - | n < 0 = (-1,1) : defaultStdGenFactorisation sg (-n) - | n == 1 = [] - | otherwise = defaultStdGenFactorisation' sg n + | n == 0 = error "0 has no prime factorisation" + | n < 0 = (-1, 1) : defaultStdGenFactorisation sg (-n) + | n == 1 = [] + | otherwise = defaultStdGenFactorisation' sg n -- | Like 'defaultStdGenFactorisation', but without input checking, so -- @n@ must be larger than @1@. defaultStdGenFactorisation' :: StdGen -> Integer -> [(Integer, Word)] -defaultStdGenFactorisation' sg n - = let (sfs,mb) = smallFactors 100000 n - in sfs ++ case mb of - Nothing -> [] - Just m -> stdGenFactorisation (Just 10000000000) sg Nothing m +defaultStdGenFactorisation' sg n = + let (sfs, mb) = smallFactors 100000 n + in sfs ++ + case mb of + Nothing -> [] + Just m -> stdGenFactorisation (Just 10000000000) sg Nothing m ---------------------------------------------------------------------------------------------------- -- Factorisation wrappers -- ---------------------------------------------------------------------------------------------------- - -- | A wrapper around 'curveFactorisation' providing a few default arguments. -- The primality test is 'bailliePSW', the @prng@ function - naturally - -- 'randomR'. This function also requires small prime factors to have been -- stripped before. -stdGenFactorisation :: Maybe Integer -- ^ Lower bound for composite divisors - -> StdGen -- ^ Standard PRNG - -> Maybe Int -- ^ Estimated number of digits of smallest prime factor - -> Integer -- ^ The number to factorise - -> [(Integer, Word)] -- ^ List of prime factors and exponents -stdGenFactorisation primeBound sg digits n - = curveFactorisation primeBound bailliePSW (\m -> randomR (6,m-2)) sg digits n +stdGenFactorisation :: + Maybe Integer -- ^ Lower bound for composite divisors + -> StdGen -- ^ Standard PRNG + -> Maybe Int -- ^ Estimated number of digits of smallest prime factor + -> Integer -- ^ The number to factorise + -> [(Integer, Word)] -- ^ List of prime factors and exponents +stdGenFactorisation primeBound sg digits n = + curveFactorisation + primeBound + bailliePSW + (\m -> randomR (6, m - 2)) + sg + digits + n -- | 'curveFactorisation' is the driver for the factorisation. Its performance (and success) -- can be influenced by passing appropriate arguments. If you know that @n@ has no prime divisors @@ -155,78 +172,80 @@ stdGenFactorisation primeBound sg digits n -- -- 'curveFactorisation' is unlikely to succeed if @n@ has more than one (really) large prime factor. -- -curveFactorisation - :: forall g. - Maybe Integer -- ^ Lower bound for composite divisors - -> (Integer -> Bool) -- ^ A primality test +curveFactorisation :: + forall g. + Maybe Integer -- ^ Lower bound for composite divisors + -> (Integer -> Bool) -- ^ A primality test -> (Integer -> g -> (Integer, g)) -- ^ A PRNG - -> g -- ^ Initial PRNG state - -> Maybe Int -- ^ Estimated number of digits of the smallest prime factor - -> Integer -- ^ The number to factorise - -> [(Integer, Word)] -- ^ List of prime factors and exponents + -> g -- ^ Initial PRNG state + -> Maybe Int -- ^ Estimated number of digits of the smallest prime factor + -> Integer -- ^ The number to factorise + -> [(Integer, Word)] -- ^ List of prime factors and exponents curveFactorisation primeBound primeTest prng seed mbdigs n - | n == 1 = [] - | ptest n = [(n, 1)] - | otherwise = evalState (fact n digits) seed - where - digits :: Int - digits = fromMaybe 8 mbdigs - - ptest :: Integer -> Bool - ptest = maybe primeTest (\bd k -> k <= bd || primeTest k) primeBound - - rndR :: Integer -> State g Integer - rndR k = state (prng k) - - perfPw :: Integer -> (Integer, Word) - perfPw = maybe highestPower (largePFPower . integerSquareRoot') primeBound - - fact :: Integer -> Int -> State g [(Integer, Word)] - fact 1 _ = return mempty - fact m digs = do - let (b1, b2, ct) = findParms digs + | n == 1 = [] + | ptest n = [(n, 1)] + | otherwise = evalState (fact n digits) seed + where + digits :: Int + digits = fromMaybe 8 mbdigs + ptest :: Integer -> Bool + ptest = maybe primeTest (\bd k -> k <= bd || primeTest k) primeBound + rndR :: Integer -> State g Integer + rndR k = state (prng k) + perfPw :: Integer -> (Integer, Word) + perfPw = maybe highestPower (largePFPower . integerSquareRoot') primeBound + fact :: Integer -> Int -> State g [(Integer, Word)] + fact 1 _ = return mempty + fact m digs = do + let (b1, b2, ct) = findParms digs -- All factors (both @pfs@ and @cfs@), are pairwise coprime. This is -- because 'repFact' returns either a single factor, or output of 'workFact'. -- In its turn, 'workFact' returns either a single factor, -- or concats 'repFact's over coprime integers. Induction completes the proof. - Factors pfs cfs <- repFact m b1 b2 ct - case cfs of - [] -> return pfs - _ -> do - nfs <- forM cfs $ \(k, j) -> - map (second (* j)) <$> fact k (if null pfs then digs + 5 else digs) - return $ mconcat (pfs : nfs) - - repFact :: Integer -> Word -> Word -> Word -> State g Factors - repFact 1 _ _ _ = return mempty - repFact m b1 b2 count = - case perfPw m of - (_, 1) -> workFact m b1 b2 count - (b, e) - | ptest b -> return $ singlePrimeFactor b e - | otherwise -> modifyPowers (* e) <$> workFact b b1 b2 count - - workFact :: Integer -> Word -> Word -> Word -> State g Factors - workFact 1 _ _ _ = return mempty - workFact m _ _ 0 = return $ singleCompositeFactor m 1 - workFact m b1 b2 count = do - s <- rndR m - case s `modulo` fromInteger m of - InfMod{} -> error "impossible case" - SomeMod sm -> case montgomeryFactorisation b1 b2 sm of - Nothing -> workFact m b1 b2 (count - 1) - Just d -> do - let cs = unCoprimes $ splitIntoCoprimes [(d, 1), (m `quot` d, 1)] + Factors pfs cfs <- repFact m b1 b2 ct + case cfs of + [] -> return pfs + _ -> do + nfs <- + forM cfs $ \(k, j) -> + map (second (* j)) <$> + fact + k + (if null pfs + then digs + 5 + else digs) + return $ mconcat (pfs : nfs) + repFact :: Integer -> Word -> Word -> Word -> State g Factors + repFact 1 _ _ _ = return mempty + repFact m b1 b2 count = + case perfPw m of + (_, 1) -> workFact m b1 b2 count + (b, e) + | ptest b -> return $ singlePrimeFactor b e + | otherwise -> modifyPowers (* e) <$> workFact b b1 b2 count + workFact :: Integer -> Word -> Word -> Word -> State g Factors + workFact 1 _ _ _ = return mempty + workFact m _ _ 0 = return $ singleCompositeFactor m 1 + workFact m b1 b2 count = do + s <- rndR m + case s `modulo` fromInteger m of + InfMod {} -> error "impossible case" + SomeMod sm -> + case montgomeryFactorisation b1 b2 sm of + Nothing -> workFact m b1 b2 (count - 1) + Just d -> do + let cs = unCoprimes $ splitIntoCoprimes [(d, 1), (m `quot` d, 1)] -- Since all @cs@ are coprime, we can factor each of -- them and just concat results, without summing up -- powers of the same primes in different elements. - fmap mconcat $ flip mapM cs $ - \(x, xm) -> if ptest x - then pure $ singlePrimeFactor x xm - else repFact x b1 b2 (count - 1) + fmap mconcat $ + flip mapM cs $ \(x, xm) -> + if ptest x + then pure $ singlePrimeFactor x xm + else repFact x b1 b2 (count - 1) data Factors = Factors - { _primeFactors :: [(Integer, Word)] + { _primeFactors :: [(Integer, Word)] , _compositeFactors :: [(Integer, Word)] } @@ -237,21 +256,19 @@ singleCompositeFactor :: Integer -> Word -> Factors singleCompositeFactor a b = Factors [] [(a, b)] instance Semigroup Factors where - Factors pfs1 cfs1 <> Factors pfs2 cfs2 - = Factors (pfs1 <> pfs2) (cfs1 <> cfs2) + Factors pfs1 cfs1 <> Factors pfs2 cfs2 = Factors (pfs1 <> pfs2) (cfs1 <> cfs2) instance Monoid Factors where mempty = Factors [] [] mappend = (<>) modifyPowers :: (Word -> Word) -> Factors -> Factors -modifyPowers f (Factors pfs cfs) - = Factors (map (second f) pfs) (map (second f) cfs) +modifyPowers f (Factors pfs cfs) = + Factors (map (second f) pfs) (map (second f) cfs) ---------------------------------------------------------------------------------------------------- -- The workhorse -- ---------------------------------------------------------------------------------------------------- - -- | @'montgomeryFactorisation' n b1 b2 s@ tries to find a factor of @n@ using the -- curve and point determined by the seed @s@ (@6 <= s < n-1@), multiplying the -- point by the least common multiple of all numbers @<= b1@ and all primes @@ -266,20 +283,24 @@ modifyPowers f (Factors pfs cfs) -- -- The result is maybe a nontrivial divisor of @n@. montgomeryFactorisation :: KnownNat n => Word -> Word -> Mod n -> Maybe Integer -montgomeryFactorisation b1 b2 s = case newPoint (getVal s) n of - Nothing -> Nothing - Just (SomePoint p0) -> do +montgomeryFactorisation b1 b2 s = + case newPoint (getVal s) n of + Nothing -> Nothing + Just (SomePoint p0) -- Small step: for each prime p <= b1 -- multiply point 'p0' by the highest power p^k <= b1. - let q = foldl (flip multiply) p0 smallPowers - z = pointZ q - - fromIntegral <$> case gcd n z of + -> do + let q = foldl (flip multiply) p0 smallPowers + z = pointZ q + fromIntegral <$> + case gcd n z -- If small step did not succeed, perform a big step. - 1 -> case gcd n (bigStep q b1 b2) of - 1 -> Nothing - g -> Just g - g -> Just g + of + 1 -> + case gcd n (bigStep q b1 b2) of + 1 -> Nothing + g -> Just g + g -> Just g where n = getMod s smallPrimes = takeWhile (<= b1) (2 : 3 : 5 : list primeStore) @@ -288,7 +309,7 @@ montgomeryFactorisation b1 b2 s = case newPoint (getVal s) n of where go acc | acc <= b1 `quot` p = go (acc * p) - | otherwise = acc + | otherwise = acc -- | The implementation follows the algorithm at p. 6-7 -- of @@ -297,39 +318,44 @@ bigStep :: (KnownNat a24, KnownNat n) => Point a24 n -> Word -> Word -> Integer bigStep q b1 b2 = rs where n = pointN q - b0 = b1 - b1 `rem` wheel - qks = zip [0..] $ map (\k -> multiply k q) wheelCoprimes + qks = zip [0 ..] $ map (\k -> multiply k q) wheelCoprimes qs = enumAndMultiplyFromThenTo q b0 (b0 + wheel) b2 - - rs = foldl' (\ts (_cHi, p) -> foldl' (\us (_cLo, pq) -> - us * (pointZ p * pointX pq - pointX p * pointZ pq) `rem` n - ) ts qks) 1 qs + rs = + foldl' + (\ts (_cHi, p) -> + foldl' + (\us (_cLo, pq) -> + us * (pointZ p * pointX pq - pointX p * pointZ pq) `rem` n) + ts + qks) + 1 + qs wheel :: Word wheel = 210 wheelCoprimes :: [Word] -wheelCoprimes = [ k | k <- [1 .. wheel `div` 2], k `gcd` wheel == 1 ] +wheelCoprimes = [k | k <- [1 .. wheel `div` 2], k `gcd` wheel == 1] -- | Same as map (id *** flip multiply p) [from, thn .. to], -- but calculated in more efficient way. -enumAndMultiplyFromThenTo - :: (KnownNat a24, KnownNat n) +enumAndMultiplyFromThenTo :: + (KnownNat a24, KnownNat n) => Point a24 n -> Word -> Word -> Word -> [(Word, Point a24 n)] -enumAndMultiplyFromThenTo p from thn to = zip [from, thn .. to] progression +enumAndMultiplyFromThenTo p from thn to = zip [from,thn .. to] progression where step = thn - from - pFrom = multiply from p - pThen = multiply thn p + pThen = multiply thn p pStep = multiply step p - - progression = pFrom : pThen : zipWith (\x0 x1 -> add x0 pStep x1) progression (tail progression) + progression = + pFrom : + pThen : zipWith (\x0 x1 -> add x0 pStep x1) progression (tail progression) -- primes, compactly stored as a bit sieve primeStore :: [PrimeSieve] @@ -337,47 +363,60 @@ primeStore = psieveFrom 7 -- generate list of primes from arrays list :: [PrimeSieve] -> [Word] -list sieves = concat [[off + toPrim i | i <- [0 .. li], unsafeAt bs i] - | PS vO bs <- sieves, let { (_,li) = bounds bs; off = fromInteger vO; }] +list sieves = + concat + [ [off + toPrim i | i <- [0 .. li], unsafeIndex bs i] + | PS vO bs <- sieves + , let li = length bs + off = fromInteger vO + ] -- | @'smallFactors' bound n@ finds all prime divisors of @n > 1@ up to @bound@ by trial division and returns the -- list of these together with their multiplicities, and a possible remaining factor which may be composite. smallFactors :: Integer -> Integer -> ([(Integer, Word)], Maybe Integer) -smallFactors bd n = case shiftToOddCount n of - (0,m) -> go m prms - (k,m) -> (2,k) <: if m == 1 then ([],Nothing) else go m prms +smallFactors bd n = + case shiftToOddCount n of + (0, m) -> go m prms + (k, m) -> + (2, k) <: + if m == 1 + then ([], Nothing) + else go m prms where prms = map unPrime $ tail (primeStore >>= primeList) - x <: ~(l,b) = (x:l,b) + x <: ~(l, b) = (x : l, b) go m (p:ps) - | m < p*p = ([(m,1)], Nothing) - | bd < p = ([], Just m) - | otherwise = case splitOff p m of - (0,_) -> go m ps - (k,r) | r == 1 -> ([(p,k)], Nothing) - | otherwise -> (p,k) <: go r ps - go m [] = ([(m,1)], Nothing) + | m < p * p = ([(m, 1)], Nothing) + | bd < p = ([], Just m) + | otherwise = + case splitOff p m of + (0, _) -> go m ps + (k, r) + | r == 1 -> ([(p, k)], Nothing) + | otherwise -> (p, k) <: go r ps + go m [] = ([(m, 1)], Nothing) -- | For a given estimated decimal length of the smallest prime factor -- ("tier") return parameters B1, B2 and the number of curves to try -- before next "tier". -- Roughly based on http://www.mersennewiki.org/index.php/Elliptic_Curve_Method#Choosing_the_best_parameters_for_ECM testParms :: IntMap (Word, Word, Word) -testParms = IM.fromList - [ (12, ( 400, 40000, 10)) - , (15, ( 2000, 200000, 25)) - , (20, ( 11000, 1100000, 90)) - , (25, ( 50000, 5000000, 300)) - , (30, ( 250000, 25000000, 700)) - , (35, ( 1000000, 100000000, 1800)) - , (40, ( 3000000, 300000000, 5100)) - , (45, ( 11000000, 1100000000, 10600)) - , (50, ( 43000000, 4300000000, 19300)) - , (55, ( 110000000, 11000000000, 49000)) - , (60, ( 260000000, 26000000000, 124000)) - , (65, ( 850000000, 85000000000, 210000)) - , (70, (2900000000, 290000000000, 340000)) - ] +testParms = + IM.fromList + [ (12, (400, 40000, 10)) + , (15, (2000, 200000, 25)) + , (20, (11000, 1100000, 90)) + , (25, (50000, 5000000, 300)) + , (30, (250000, 25000000, 700)) + , (35, (1000000, 100000000, 1800)) + , (40, (3000000, 300000000, 5100)) + , (45, (11000000, 1100000000, 10600)) + , (50, (43000000, 4300000000, 19300)) + , (55, (110000000, 11000000000, 49000)) + , (60, (260000000, 26000000000, 124000)) + , (65, (850000000, 85000000000, 210000)) + , (70, (2900000000, 290000000000, 340000)) + ] findParms :: Int -> (Word, Word, Word) findParms digs = maybe (wheel, 1000, 7) snd (IM.lookupLT digs testParms) diff --git a/Math/NumberTheory/Primes/Sieve/Eratosthenes.hs b/Math/NumberTheory/Primes/Sieve/Eratosthenes.hs index 2ed3622ee..b975755f8 100644 --- a/Math/NumberTheory/Primes/Sieve/Eratosthenes.hs +++ b/Math/NumberTheory/Primes/Sieve/Eratosthenes.hs @@ -6,52 +6,51 @@ -- -- Sieve -- -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE CPP #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} - {-# OPTIONS_GHC -fspec-constr-count=8 #-} {-# OPTIONS_HADDOCK hide #-} -module Math.NumberTheory.Primes.Sieve.Eratosthenes - ( primes - , sieveFrom - , psieveFrom - , PrimeSieve(..) - , psieveList - , primeList - , primeSieve - , nthPrimeCt - , countFromTo - , countAll - , countToNth - , sieveBits - , sieveRange - , sieveTo - ) where +module Math.NumberTheory.Primes.Sieve.Eratosthenes + ( primes + , sieveFrom + , psieveFrom + , PrimeSieve(..) + , psieveList + , primeList + , primeSieve + , nthPrimeCt + , countFromTo + , countAll + , countToNth + , sieveBits + , sieveRange + , sieveTo + ) where #include "MachDeps.h" +import Prelude hiding (replicate) +import Control.Monad (when) import Control.Monad.ST -import Data.Array.ST -import Data.Array.Unboxed +import Data.Bits import Data.Coerce import Data.Proxy -import Control.Monad (when) -import Data.Bits #if WORD_SIZE_IN_BITS == 32 import Data.Word #endif - +import Data.Vector as V (length) +import Data.Vector.Mutable as MV (length) import Math.NumberTheory.Powers.Squares (integerSquareRoot) -import Math.NumberTheory.Unsafe -import Math.NumberTheory.Utils -import Math.NumberTheory.Utils.FromIntegral import Math.NumberTheory.Primes.Counting.Approximate import Math.NumberTheory.Primes.Sieve.Indexing import Math.NumberTheory.Primes.Types - +import Math.NumberTheory.Unsafe +import Math.NumberTheory.Utils +import Math.NumberTheory.Utils.FromIntegral +import Unsafe.Coerce (unsafeCoerce) #define IX_MASK 0xFFFFF #define IX_BITS 20 #define IX_J_MASK 0x7FFFFF @@ -59,16 +58,15 @@ import Math.NumberTheory.Primes.Types #define J_MASK 7 #define J_BITS 3 #define SIEVE_KB 128 - -- Sieve in 128K chunks. -- Large enough to get something done per chunk -- and hopefully small enough to fit in the cache. sieveBytes :: Int -sieveBytes = SIEVE_KB*1024 +sieveBytes = SIEVE_KB * 1024 -- Number of bits per chunk. sieveBits :: Int -sieveBits = 8*sieveBytes +sieveBits = 8 * sieveBytes -- Last index of chunk. lastIndex :: Int @@ -76,11 +74,10 @@ lastIndex = sieveBits - 1 -- Range of a chunk. sieveRange :: Int -sieveRange = 30*sieveBytes +sieveRange = 30 * sieveBytes sieveWords :: Int sieveWords = sieveBytes `quot` SIZEOF_HSWORD - #if SIZEOF_HSWORD == 8 type CacheWord = Word #define RMASK 63 @@ -94,9 +91,10 @@ type CacheWord = Word64 #define TOPB 16 #define TOPM 0xFFFF #endif - -- | Compact store of primality flags. -data PrimeSieve = PS !Integer {-# UNPACK #-} !(UArray Int Bool) +data PrimeSieve = + PS !Integer + {-# UNPACK #-}!(Vector Bool) -- | Sieve primes up to (and including) a bound (or 7, if bound is smaller). -- For small enough bounds, this is more efficient than @@ -108,39 +106,52 @@ data PrimeSieve = PS !Integer {-# UNPACK #-} !(UArray Int Bool) -- is often within memory limits, so don't give bounds larger than -- @8*10^9@ there. primeSieve :: Integer -> PrimeSieve -primeSieve bound = PS 0 (runSTUArray $ sieveTo bound) +primeSieve bound = + PS + 0 + (runST $ do + res <- sieveTo bound :: ST s (STVector s Bool) + unsafeFreeze res) -- | Generate a list of primes for consumption from a -- 'PrimeSieve'. -primeList :: forall a. Integral a => PrimeSieve -> [Prime a] +primeList :: + forall a. Integral a + => PrimeSieve + -> [Prime a] primeList ps@(PS v _) - | doesNotFit (Proxy :: Proxy a) v - = [] -- has an overflow already happened? - | v == 0 = (coerce :: [a] -> [Prime a]) - $ takeWhileIncreasing $ 2 : 3 : 5 : primeListInternal ps - | otherwise = (coerce :: [a] -> [Prime a]) - $ takeWhileIncreasing $ primeListInternal ps + | doesNotFit (Proxy :: Proxy a) v = [] -- has an overflow already happened? + | v == 0 = + (coerce :: [a] -> [Prime a]) $ + takeWhileIncreasing $ 2 : 3 : 5 : primeListInternal ps + | otherwise = + (coerce :: [a] -> [Prime a]) $ takeWhileIncreasing $ primeListInternal ps primeListInternal :: Num a => PrimeSieve -> [a] -primeListInternal (PS v0 bs) - = map ((+ fromInteger v0) . toPrim) - $ filter (unsafeAt bs) [lo..hi] - where - (lo, hi) = bounds bs +primeListInternal (PS v0 bs) = + map ((+ fromInteger v0) . toPrim) $ + filter (unsafeIndex bs) [0 .. (V.length bs - 1)] -- | Returns true if integer is beyond representation range of type a. -doesNotFit :: forall a. Integral a => Proxy a -> Integer -> Bool +doesNotFit :: + forall a. Integral a + => Proxy a + -> Integer + -> Bool doesNotFit _ v = toInteger (fromInteger v :: a) /= v -- | Extracts the longest strictly increasing prefix of the list -- (possibly infinite). takeWhileIncreasing :: Ord a => [a] -> [a] -takeWhileIncreasing = \case - [] -> [] - x : xs -> x : foldr go (const []) xs x - where - go :: Ord a => a -> (a -> [a]) -> a -> [a] - go y f z = if z < y then y : f y else [] +takeWhileIncreasing = + \case + [] -> [] + x:xs -> x : foldr go (const []) xs x + where go :: Ord a => a -> (a -> [a]) -> a -> [a] + go y f z = + if z < y + then y : f y + else [] -- | Ascending list of primes. -- @@ -168,9 +179,9 @@ takeWhileIncreasing = \case -- 15485867 -- (0.02 secs, 336,232 bytes) primes :: Integral a => [Prime a] -primes - = (coerce :: [a] -> [Prime a]) - $ takeWhileIncreasing $ 2 : 3 : 5 : concatMap primeListInternal psieveList +primes = + (coerce :: [a] -> [Prime a]) $ + takeWhileIncreasing $ 2 : 3 : 5 : concatMap primeListInternal psieveList -- | List of primes in the form of a list of 'PrimeSieve's, more compact than -- 'primes', thus it may be better to use @'psieveList' >>= 'primeList'@ @@ -178,201 +189,237 @@ primes psieveList :: [PrimeSieve] psieveList = makeSieves plim sqlim 0 0 cache where - plim = 4801 -- prime #647, 644 of them to use - sqlim = plim*plim - cache = runSTUArray $ do + plim = 4801 -- prime #647, 644 of them to use + sqlim = plim * plim + cache = + runST $ do sieve <- sieveTo (4801 :: Integer) - new <- unsafeNewArray_ (0,1287) :: ST s (STUArray s Int CacheWord) + new <- unsafeNew 1288 :: ST s (STVector s CacheWord) let fill j indx - | 1279 < indx = return new -- index of 4801 = 159*30 + 31 ~> 159*8+7 + | 1279 < indx = return new -- index of 4801 = 159*30 + 31 ~> 159*8+7 | otherwise = do p <- unsafeRead sieve indx if p then do let !i = indx .&. J_MASK k = indx `shiftR` J_BITS - strt1 = (k*(30*k + 2*rho i) + byte i) `shiftL` J_BITS + fromIntegral (idx i) + strt1 = + (k * (30 * k + 2 * rho i) + byte i) `shiftL` J_BITS + + fromIntegral (idx i) !strt = fromIntegral (strt1 .&. IX_MASK) !skip = fromIntegral (strt1 `shiftR` IX_BITS) - !ixes = fromIntegral indx `shiftL` IX_J_BITS + strt `shiftL` J_BITS + fromIntegral i + !ixes = + fromIntegral indx `shiftL` IX_J_BITS + + strt `shiftL` J_BITS + + fromIntegral i unsafeWrite new j skip - unsafeWrite new (j+1) ixes - fill (j+2) (indx+1) - else fill j (indx+1) - fill 0 0 - -makeSieves :: Integer -> Integer -> Integer -> Integer -> UArray Int CacheWord -> [PrimeSieve] + unsafeWrite new (j + 1) ixes + fill (j + 2) (indx + 1) + else fill j (indx + 1) + _ <- fill 0 0 + unsafeFreeze new + +makeSieves :: + Integer + -> Integer + -> Integer + -> Integer + -> Vector CacheWord + -> [PrimeSieve] makeSieves plim sqlim bitOff valOff cache | valOff' < sqlim = - let (nc, bs) = runST $ do - cch <- unsafeThaw cache :: ST s (STUArray s Int CacheWord) + let (nc, bs) = + runST $ do + cch <- unsafeThaw cache :: ST s (STVector s CacheWord) bs0 <- slice cch fcch <- unsafeFreeze cch fbs0 <- unsafeFreeze bs0 return (fcch, fbs0) - in PS valOff bs : makeSieves plim sqlim bitOff' valOff' nc - | otherwise = - let plim' = plim + 4800 - sqlim' = plim' * plim' - (nc,bs) = runST $ do + in PS valOff bs : makeSieves plim sqlim bitOff' valOff' nc + | otherwise = + let plim' = plim + 4800 + sqlim' = plim' * plim' + (nc, bs) = + runST $ do cch <- growCache bitOff plim cache bs0 <- slice cch fcch <- unsafeFreeze cch fbs0 <- unsafeFreeze bs0 return (fcch, fbs0) - in PS valOff bs : makeSieves plim' sqlim' bitOff' valOff' nc - where - valOff' = valOff + fromIntegral sieveRange - bitOff' = bitOff + fromIntegral sieveBits + in PS valOff bs : makeSieves plim' sqlim' bitOff' valOff' nc + where + valOff' = valOff + fromIntegral sieveRange + bitOff' = bitOff + fromIntegral sieveBits -slice :: STUArray s Int CacheWord -> ST s (STUArray s Int Bool) +slice :: STVector s CacheWord -> ST s (STVector s Bool) slice cache = do - hi <- snd `fmap` getBounds cache - sieve <- newArray (0,lastIndex) True - let treat pr - | hi < pr = return sieve - | otherwise = do - w <- unsafeRead cache pr - if w /= 0 - then unsafeWrite cache pr (w-1) - else do - ixes <- unsafeRead cache (pr+1) - let !stj = fromIntegral ixes .&. IX_J_MASK -- position of multiple and index of cofactor - !ixw = fromIntegral (ixes `shiftR` IX_J_BITS) -- prime data, up to 41 bits - !i = ixw .&. J_MASK - !k = ixw - i -- On 32-bits, k > 44717396 means overflow is possible in tick - !o = i `shiftL` J_BITS - !j = stj .&. J_MASK -- index of cofactor - !s = stj `shiftR` J_BITS -- index of first multiple to tick off - (n, u) <- tick k o j s - let !skip = fromIntegral (n `shiftR` IX_BITS) - !strt = fromIntegral (n .&. IX_MASK) - unsafeWrite cache pr skip - unsafeWrite cache (pr+1) ((ixes .&. complement IX_J_MASK) .|. strt `shiftL` J_BITS .|. fromIntegral u) - treat (pr+2) - tick stp off j ix - | lastIndex < ix = return (ix - sieveBits, j) - | otherwise = do - p <- unsafeRead sieve ix - when p (unsafeWrite sieve ix False) - tick stp off ((j+1) .&. J_MASK) (ix + stp*delta j + tau (off+j)) - treat 0 + let hi = MV.length cache + 1 + sieve <- replicate (lastIndex + 1) True + let treat pr + | hi < pr = return sieve + | otherwise = do + w <- unsafeRead cache pr + if w /= 0 + then unsafeWrite cache pr (w - 1) + else do + ixes <- unsafeRead cache (pr + 1) + let !stj = fromIntegral ixes .&. IX_J_MASK -- position of multiple and index of cofactor + !ixw = fromIntegral (ixes `shiftR` IX_J_BITS) -- prime data, up to 41 bits + !i = ixw .&. J_MASK + !k = ixw - i -- On 32-bits, k > 44717396 means overflow is possible in tick + !o = i `shiftL` J_BITS + !j = stj .&. J_MASK -- index of cofactor + !s = stj `shiftR` J_BITS -- index of first multiple to tick off + (n, u) <- tick k o j s + let !skip = fromIntegral (n `shiftR` IX_BITS) + !strt = fromIntegral (n .&. IX_MASK) + unsafeWrite cache pr skip + unsafeWrite + cache + (pr + 1) + ((ixes .&. complement IX_J_MASK) .|. strt `shiftL` J_BITS .|. + fromIntegral u) + treat (pr + 2) + tick stp off j ix + | lastIndex < ix = return (ix - sieveBits, j) + | otherwise = do + p <- unsafeRead sieve ix + when p (unsafeWrite sieve ix False) + tick stp off ((j + 1) .&. J_MASK) (ix + stp * delta j + tau (off + j)) + treat 0 -- | Sieve up to bound in one go. -sieveTo :: Integer -> ST s (STUArray s Int Bool) +sieveTo :: Integer -> ST s (STVector s Bool) sieveTo bound = arr where - (bytes,lidx) = idxPr bound - !mxidx = 8*bytes+lidx + (bytes, lidx) = idxPr bound + !mxidx = 8 * bytes + lidx mxval :: Integer - mxval = 30*fromIntegral bytes + fromIntegral (rho lidx) + mxval = 30 * fromIntegral bytes + fromIntegral (rho lidx) !mxsve = integerSquareRoot mxval - (kr,r) = idxPr mxsve - !svbd = 8*kr+r + (kr, r) = idxPr mxsve + !svbd = 8 * kr + r arr = do - ar <- newArray (0,mxidx) True - let start k i = 8*(k*(30*k+2*rho i) + byte i) + idx i - tick stp off j ix - | mxidx < ix = return () - | otherwise = do - p <- unsafeRead ar ix - when p (unsafeWrite ar ix False) - tick stp off ((j+1) .&. J_MASK) (ix + stp*delta j + tau (off+j)) - sift ix - | svbd < ix = return ar - | otherwise = do - p <- unsafeRead ar ix - when p (do let i = ix .&. J_MASK - k = ix `shiftR` J_BITS - !off = i `shiftL` J_BITS - !stp = ix - i - tick stp off i (start k i)) - sift (ix+1) - sift 0 - -growCache :: Integer -> Integer -> UArray Int CacheWord -> ST s (STUArray s Int CacheWord) + ar <- replicate (mxidx + 1) True + let start k i = 8 * (k * (30 * k + 2 * rho i) + byte i) + idx i + tick stp off j ix + | mxidx < ix = return () + | otherwise = do + p <- unsafeRead ar ix + when p (unsafeWrite ar ix False) + tick + stp + off + ((j + 1) .&. J_MASK) + (ix + stp * delta j + tau (off + j)) + sift ix + | svbd < ix = return ar + | otherwise = do + p <- unsafeRead ar ix + when + p + (do let i = ix .&. J_MASK + k = ix `shiftR` J_BITS + !off = i `shiftL` J_BITS + !stp = ix - i + tick stp off i (start k i)) + sift (ix + 1) + sift 0 + +growCache :: + Integer -> Integer -> Vector CacheWord -> ST s (STVector s CacheWord) growCache offset plim old = do - let (_,num) = bounds old - (bt,ix) = idxPr plim - !start = 8*bt+ix+1 - !nlim = plim+4800 - sieve <- sieveTo nlim -- Implement SieveFromTo for this, it's pretty wasteful when nlim isn't - (_,hi) <- getBounds sieve -- very small anymore - more <- countFromToWd start hi sieve - new <- unsafeNewArray_ (0,num+2*more) :: ST s (STUArray s Int CacheWord) - let copy i - | num < i = return () - | otherwise = do - unsafeWrite new i (old `unsafeAt` i) - copy (i+1) - copy 0 - let fill j indx - | hi < indx = return new - | otherwise = do - p <- unsafeRead sieve indx - if p - then do - let !i = indx .&. J_MASK - k :: Integer - k = fromIntegral (indx `shiftR` J_BITS) - strt0 = ((k*(30*k + fromIntegral (2*rho i)) - + fromIntegral (byte i)) `shiftL` J_BITS) - + fromIntegral (idx i) - strt1 = strt0 - offset - !strt = fromIntegral strt1 .&. IX_MASK - !skip = fromIntegral (strt1 `shiftR` IX_BITS) - !ixes = fromIntegral indx `shiftL` IX_J_BITS .|. strt `shiftL` J_BITS .|. fromIntegral i - unsafeWrite new j skip - unsafeWrite new (j+1) ixes - fill (j+2) (indx+1) - else fill j (indx+1) - fill (num+1) start + let num = V.length old + (bt, ix) = idxPr plim + !start = 8 * bt + ix + 1 + !nlim = plim + 4800 + let sieveST = sieveTo nlim :: ST s (STVector s Bool) -- Implement SieveFromTo for this, it's pretty wasteful when nlim isn't + sieve <- sieveST + let hi = (MV.length sieve) - 1 -- very small anymore + more <- countFromToWd start hi sieveST + new <- unsafeNew (1 + num + 2 * more) :: ST s (STVector s CacheWord) + let copy i + | num < i = return () + | otherwise = do + unsafeWrite new i (old `unsafeIndex` i) + copy (i + 1) + copy 0 + let fill j indx + | hi < indx = return new + | otherwise = do + p <- unsafeRead sieve indx + if p + then do + let !i = indx .&. J_MASK + k :: Integer + k = fromIntegral (indx `shiftR` J_BITS) + strt0 = + ((k * (30 * k + fromIntegral (2 * rho i)) + + fromIntegral (byte i)) `shiftL` + J_BITS) + + fromIntegral (idx i) + strt1 = strt0 - offset + !strt = fromIntegral strt1 .&. IX_MASK + !skip = fromIntegral (strt1 `shiftR` IX_BITS) + !ixes = + fromIntegral indx `shiftL` IX_J_BITS .|. + strt `shiftL` J_BITS .|. + fromIntegral i + unsafeWrite new j skip + unsafeWrite new (j + 1) ixes + fill (j + 2) (indx + 1) + else fill j (indx + 1) + fill (num + 1) start + +castSTVector :: ST s (STVector s Bool) -> ST s (STVector s Word) +castSTVector = unsafeCoerce -- Danger: relies on start and end being the first resp. last -- index in a Word -- Do not use except in growCache and psieveFrom {-# INLINE countFromToWd #-} -countFromToWd :: Int -> Int -> STUArray s Int Bool -> ST s Int +countFromToWd :: Int -> Int -> ST s (STVector s Bool) -> ST s Int countFromToWd start end ba = do - wa <- (castSTUArray :: STUArray s Int Bool -> ST s (STUArray s Int Word)) ba - let !sb = start `shiftR` WSHFT - !eb = end `shiftR` WSHFT - count !acc i - | eb < i = return acc - | otherwise = do - w <- unsafeRead wa i - count (acc + bitCountWord w) (i+1) - count 0 sb + wa <- (castSTVector ba) + let !sb = start `shiftR` WSHFT + !eb = end `shiftR` WSHFT + count !acc i + | eb < i = return acc + | otherwise = do + w <- unsafeRead wa i + count (acc + bitCountWord w) (i + 1) + count 0 sb -- count set bits between two indices (inclusive) -- start and end must both be valid indices and start <= end -countFromTo :: Int -> Int -> STUArray s Int Bool -> ST s Int +countFromTo :: Int -> Int -> ST s (STVector s Bool) -> ST s Int countFromTo start end ba = do - wa <- (castSTUArray :: STUArray s Int Bool -> ST s (STUArray s Int Word)) ba - let !sb = start `shiftR` WSHFT - !si = start .&. RMASK - !eb = end `shiftR` WSHFT - !ei = end .&. RMASK - count !acc i - | i == eb = do - w <- unsafeRead wa i - return (acc + bitCountWord (w `shiftL` (RMASK - ei))) - | otherwise = do - w <- unsafeRead wa i - count (acc + bitCountWord w) (i+1) - if sb < eb - then do - w <- unsafeRead wa sb - count (bitCountWord (w `shiftR` si)) (sb+1) - else do - w <- unsafeRead wa sb - let !w1 = w `shiftR` si - return (bitCountWord (w1 `shiftL` (RMASK - ei + si))) + wa <- (castSTVector ba) + let !sb = start `shiftR` WSHFT + !si = start .&. RMASK + !eb = end `shiftR` WSHFT + !ei = end .&. RMASK + count !acc i + | i == eb = do + w :: Word <- unsafeRead wa i + return (acc + bitCountWord (w `shiftL` (RMASK - ei))) + | otherwise = do + w <- unsafeRead wa i + count (acc + bitCountWord w) (i + 1) + if sb < eb + then do + w <- unsafeRead wa sb + count (bitCountWord (w `shiftR` si)) (sb + 1) + else do + w <- unsafeRead wa sb + let !w1 = w `shiftR` si + return (bitCountWord (w1 `shiftL` (RMASK - ei + si))) -- | @'sieveFrom' n@ creates the list of primes not less than @n@. sieveFrom :: Integer -> [Prime Integer] -sieveFrom n = case psieveFrom n of - ps -> dropWhile ((< n) . unPrime) (ps >>= primeList) +sieveFrom n = + case psieveFrom n of + ps -> dropWhile ((< n) . unPrime) (ps >>= primeList) -- | @'psieveFrom' n@ creates the list of 'PrimeSieve's starting roughly -- at @n@. Due to the organisation of the sieve, the list may contain @@ -381,76 +428,90 @@ sieveFrom n = case psieveFrom n of -- to use this if it is to be reused. psieveFrom :: Integer -> [PrimeSieve] psieveFrom n = makeSieves plim sqlim bitOff valOff cache - where - k0 = ((n `max` 7) - 7) `quot` 30 -- beware arithmetic underflow - valOff = 30*k0 - bitOff = 8*k0 - start = valOff+7 - ssr = integerSquareRoot (start-1) + 1 - end1 = start - 6 + fromIntegral sieveRange - plim0 = integerSquareRoot end1 - plim = plim0 + 4801 - (plim0 `rem` 4800) - sqlim = plim*plim - cache = runSTUArray $ do - sieve <- sieveTo plim - (lo,hi) <- getBounds sieve - pct <- countFromToWd lo hi sieve - new <- unsafeNewArray_ (0,2*pct-1) :: ST s (STUArray s Int CacheWord) - let fill j indx - | hi < indx = return new - | otherwise = do - isPr <- unsafeRead sieve indx - if isPr - then do - let !i = indx .&. J_MASK - !moff = i `shiftL` J_BITS - k :: Integer - k = fromIntegral (indx `shiftR` J_BITS) - p = 30*k+fromIntegral (rho i) - q0 = (start-1) `quot` p - (skp0,q1) = q0 `quotRem` fromIntegral sieveRange - (b0,r0) - | q1 == 0 = (-1,6) - | q1 < 7 = (-1,7) - | otherwise = idxPr (fromIntegral q1 :: Int) - (b1,r1) | r0 == 7 = (b0+1,0) - | otherwise = (b0,r0+1) - b2 = skp0*fromIntegral sieveBytes + fromIntegral b1 - strt0 = ((k*(30*b2 + fromIntegral (rho r1)) - + b2 * fromIntegral (rho i) - + fromIntegral (mu (moff + r1))) `shiftL` J_BITS) - + fromIntegral (nu (moff + r1)) - strt1 = ((k*(30*k + fromIntegral (2*rho i)) - + fromIntegral (byte i)) `shiftL` J_BITS) - + fromIntegral (idx i) - (strt2,r2) - | p < ssr = (strt0 - bitOff,r1) - | otherwise = (strt1 - bitOff, i) - !strt = fromIntegral strt2 .&. IX_MASK - !skip = fromIntegral (strt2 `shiftR` IX_BITS) - !ixes = fromIntegral indx `shiftL` IX_J_BITS .|. strt `shiftL` J_BITS .|. fromIntegral r2 - unsafeWrite new j skip - unsafeWrite new (j+1) ixes - fill (j+2) (indx+1) - else fill j (indx+1) - fill 0 0 + where + k0 = ((n `max` 7) - 7) `quot` 30 -- beware arithmetic underflow + valOff = 30 * k0 + bitOff = 8 * k0 + start = valOff + 7 + ssr = integerSquareRoot (start - 1) + 1 + end1 = start - 6 + fromIntegral sieveRange + plim0 = integerSquareRoot end1 + plim = plim0 + 4801 - (plim0 `rem` 4800) + sqlim = plim * plim + cache = + runST $ do + let sieveST = sieveTo plim + sieve <- sieveST + let (lo, hi) = (0, MV.length sieve) + pct <- countFromToWd lo hi sieveST + new <- unsafeNew (2 * pct) :: ST s (STVector s CacheWord) + let fill j indx + | hi < indx = return new + | otherwise = do + isPr <- unsafeRead sieve indx + if isPr + then do + let !i = indx .&. J_MASK + !moff = i `shiftL` J_BITS + k :: Integer + k = fromIntegral (indx `shiftR` J_BITS) + p = 30 * k + fromIntegral (rho i) + q0 = (start - 1) `quot` p + (skp0, q1) = q0 `quotRem` fromIntegral sieveRange + (b0, r0) + | q1 == 0 = (-1, 6) + | q1 < 7 = (-1, 7) + | otherwise = idxPr (fromIntegral q1 :: Int) + (b1, r1) + | r0 == 7 = (b0 + 1, 0) + | otherwise = (b0, r0 + 1) + b2 = skp0 * fromIntegral sieveBytes + fromIntegral b1 + strt0 = + ((k * (30 * b2 + fromIntegral (rho r1)) + + b2 * fromIntegral (rho i) + + fromIntegral (mu (moff + r1))) `shiftL` + J_BITS) + + fromIntegral (nu (moff + r1)) + strt1 = + ((k * (30 * k + fromIntegral (2 * rho i)) + + fromIntegral (byte i)) `shiftL` + J_BITS) + + fromIntegral (idx i) + (strt2, r2) + | p < ssr = (strt0 - bitOff, r1) + | otherwise = (strt1 - bitOff, i) + !strt = fromIntegral strt2 .&. IX_MASK + !skip = fromIntegral (strt2 `shiftR` IX_BITS) + !ixes = + fromIntegral indx `shiftL` IX_J_BITS .|. + strt `shiftL` J_BITS .|. + fromIntegral r2 + unsafeWrite new j skip + unsafeWrite new (j + 1) ixes + fill (j + 2) (indx + 1) + else fill j (indx + 1) + _ <- fill 0 0 + unsafeFreeze new -- prime counting - nthPrimeCt :: Integer -> Integer -nthPrimeCt 1 = 2 -nthPrimeCt 2 = 3 -nthPrimeCt 3 = 5 -nthPrimeCt 4 = 7 -nthPrimeCt 5 = 11 -nthPrimeCt 6 = 13 +nthPrimeCt 1 = 2 +nthPrimeCt 2 = 3 +nthPrimeCt 3 = 5 +nthPrimeCt 4 = 7 +nthPrimeCt 5 = 11 +nthPrimeCt 6 = 13 nthPrimeCt n - | n < 1 = error "nthPrimeCt: negative argument" - | n < 200000 = let bd0 = nthPrimeApprox n - bnd = bd0 + bd0 `quot` 32 + 37 - !sv = primeSieve bnd - in countToNth (n-3) [sv] - | otherwise = countToNth (n-3) (psieveFrom (intToInteger $ fromInteger n .&. (7 :: Int))) + | n < 1 = error "nthPrimeCt: negative argument" + | n < 200000 = + let bd0 = nthPrimeApprox n + bnd = bd0 + bd0 `quot` 32 + 37 + !sv = primeSieve bnd + in countToNth (n - 3) [sv] + | otherwise = + countToNth + (n - 3) + (psieveFrom (intToInteger $ fromInteger n .&. (7 :: Int))) -- find the n-th set bit in a list of PrimeSieves, -- aka find the (n+3)-rd prime @@ -458,117 +519,292 @@ countToNth :: Integer -> [PrimeSieve] -> Integer countToNth !n ps = runST (countDown n ps) countDown :: Integer -> [PrimeSieve] -> ST s Integer -countDown !n (ps@(PS v0 bs) : more) +countDown !n (ps@(PS v0 bs):more) | n > 278734 || (v0 /= 0 && n > 253000) = do ct <- countAll ps countDown (n - fromIntegral ct) more | otherwise = do - stu <- unsafeThaw bs - wa <- (castSTUArray :: STUArray s Int Bool -> ST s (STUArray s Int Word)) stu + let stu = unsafeThaw bs :: ST s (STVector s Bool) + wa <- castSTVector stu let go !k i - | i == sieveWords = countDown k more - | otherwise = do + | i == sieveWords = countDown k more + | otherwise = do w <- unsafeRead wa i let !bc = fromIntegral $ bitCountWord w if bc < k - then go (k-bc) (i+1) - else let !j = fromIntegral (bc - k) - !px = top w j (fromIntegral bc) - in return (v0 + toPrim (px+(i `shiftL` WSHFT))) + then go (k - bc) (i + 1) + else let !j = fromIntegral (bc - k) + !px = top w j (fromIntegral bc) + in return (v0 + toPrim (px + (i `shiftL` WSHFT))) go n 0 countDown _ [] = error "Prime stream ended prematurely" -- count all set bits in a chunk, do it wordwise for speed. countAll :: PrimeSieve -> ST s Int countAll (PS _ bs) = do - stu <- unsafeThaw bs - wa <- (castSTUArray :: STUArray s Int Bool -> ST s (STUArray s Int Word)) stu - let go !ct i - | i == sieveWords = return ct - | otherwise = do - w <- unsafeRead wa i - go (ct + bitCountWord w) (i+1) - go 0 0 + let stu = unsafeThaw bs + wa <- castSTVector stu + let go !ct i + | i == sieveWords = return ct + | otherwise = do + w <- unsafeRead wa i + go (ct + bitCountWord w) (i + 1) + go 0 0 -- Find the j-th highest of bc set bits in the Word w. top :: Word -> Int -> Int -> Int top w j bc = go 0 TOPB TOPM bn w - where - !bn = bc-j - go !_ _ !_ !_ 0 = error "Too few bits set" - go bs 0 _ _ wd = if wd .&. 1 == 0 then error "Too few bits, shift 0" else bs - go bs a msk ix wd = - case bitCountWord (wd .&. msk) of - lc | lc < ix -> go (bs+a) a msk (ix-lc) (wd `uncheckedShiftR` a) - | otherwise -> - let !na = a `shiftR` 1 - in go bs na (msk `uncheckedShiftR` na) ix wd + where + !bn = bc - j + go !_ _ !_ !_ 0 = error "Too few bits set" + go bs 0 _ _ wd = + if wd .&. 1 == 0 + then error "Too few bits, shift 0" + else bs + go bs a msk ix wd = + case bitCountWord (wd .&. msk) of + lc + | lc < ix -> go (bs + a) a msk (ix - lc) (wd `uncheckedShiftR` a) + | otherwise -> + let !na = a `shiftR` 1 + in go bs na (msk `uncheckedShiftR` na) ix wd {-# INLINE delta #-} delta :: Int -> Int -delta i = unsafeAt deltas i +delta i = unsafeIndex deltas i -deltas :: UArray Int Int -deltas = listArray (0,7) [4,2,4,2,4,6,2,6] +deltas :: Vector Int +deltas = fromList [4, 2, 4, 2, 4, 6, 2, 6] {-# INLINE tau #-} tau :: Int -> Int -tau i = unsafeAt taus i - -taus :: UArray Int Int -taus = listArray (0,63) - [ 7, 4, 7, 4, 7, 12, 3, 12 - , 12, 6, 11, 6, 12, 18, 5, 18 - , 14, 7, 13, 7, 14, 21, 7, 21 - , 18, 9, 19, 9, 18, 27, 9, 27 - , 20, 10, 21, 10, 20, 30, 11, 30 - , 25, 12, 25, 12, 25, 36, 13, 36 - , 31, 15, 31, 15, 31, 47, 15, 47 - , 33, 17, 33, 17, 33, 49, 17, 49 - ] +tau i = unsafeIndex taus i + +taus :: Vector Int +taus = + fromList + [ 7 + , 4 + , 7 + , 4 + , 7 + , 12 + , 3 + , 12 + , 12 + , 6 + , 11 + , 6 + , 12 + , 18 + , 5 + , 18 + , 14 + , 7 + , 13 + , 7 + , 14 + , 21 + , 7 + , 21 + , 18 + , 9 + , 19 + , 9 + , 18 + , 27 + , 9 + , 27 + , 20 + , 10 + , 21 + , 10 + , 20 + , 30 + , 11 + , 30 + , 25 + , 12 + , 25 + , 12 + , 25 + , 36 + , 13 + , 36 + , 31 + , 15 + , 31 + , 15 + , 31 + , 47 + , 15 + , 47 + , 33 + , 17 + , 33 + , 17 + , 33 + , 49 + , 17 + , 49 + ] {-# INLINE byte #-} byte :: Int -> Int -byte i = unsafeAt startByte i +byte i = unsafeIndex startByte i -startByte :: UArray Int Int -startByte = listArray (0,7) [1,3,5,9,11,17,27,31] +startByte :: Vector Int +startByte = fromList [1, 3, 5, 9, 11, 17, 27, 31] {-# INLINE idx #-} idx :: Int -> Int -idx i = unsafeAt startIdx i +idx i = unsafeIndex startIdx i -startIdx :: UArray Int Int -startIdx = listArray (0,7) [4,7,4,4,7,4,7,7] +startIdx :: Vector Int +startIdx = fromList [4, 7, 4, 4, 7, 4, 7, 7] {-# INLINE mu #-} mu :: Int -> Int -mu i = unsafeAt mArr i +mu i = unsafeIndex mArr i {-# INLINE nu #-} nu :: Int -> Int -nu i = unsafeAt nArr i - -mArr :: UArray Int Int -mArr = listArray (0,63) - [ 1, 2, 2, 3, 4, 5, 6, 7 - , 2, 3, 4, 6, 6, 8, 10, 11 - , 2, 4, 5, 7, 8, 9, 12, 13 - , 3, 6, 7, 9, 10, 12, 16, 17 - , 4, 6, 8, 10, 11, 14, 18, 19 - , 5, 8, 9, 12, 14, 17, 22, 23 - , 6, 10, 12, 16, 18, 22, 27, 29 - , 7, 11, 13, 17, 19, 23, 29, 31 - ] - -nArr :: UArray Int Int -nArr = listArray (0,63) - [ 4, 3, 7, 6, 2, 1, 5, 0 - , 3, 7, 5, 0, 6, 2, 4, 1 - , 7, 5, 4, 1, 0, 6, 3, 2 - , 6, 0, 1, 4, 5, 7, 2, 3 - , 2, 6, 0, 5, 7, 3, 1, 4 - , 1, 2, 6, 7, 3, 4, 0, 5 - , 5, 4, 3, 2, 1, 0, 7, 6 - , 0, 1, 2, 3, 4, 5, 6, 7 - ] +nu i = unsafeIndex nArr i + +mArr :: Vector Int +mArr = + fromList + [ 1 + , 2 + , 2 + , 3 + , 4 + , 5 + , 6 + , 7 + , 2 + , 3 + , 4 + , 6 + , 6 + , 8 + , 10 + , 11 + , 2 + , 4 + , 5 + , 7 + , 8 + , 9 + , 12 + , 13 + , 3 + , 6 + , 7 + , 9 + , 10 + , 12 + , 16 + , 17 + , 4 + , 6 + , 8 + , 10 + , 11 + , 14 + , 18 + , 19 + , 5 + , 8 + , 9 + , 12 + , 14 + , 17 + , 22 + , 23 + , 6 + , 10 + , 12 + , 16 + , 18 + , 22 + , 27 + , 29 + , 7 + , 11 + , 13 + , 17 + , 19 + , 23 + , 29 + , 31 + ] + +nArr :: Vector Int +nArr = + fromList + [ 4 + , 3 + , 7 + , 6 + , 2 + , 1 + , 5 + , 0 + , 3 + , 7 + , 5 + , 0 + , 6 + , 2 + , 4 + , 1 + , 7 + , 5 + , 4 + , 1 + , 0 + , 6 + , 3 + , 2 + , 6 + , 0 + , 1 + , 4 + , 5 + , 7 + , 2 + , 3 + , 2 + , 6 + , 0 + , 5 + , 7 + , 3 + , 1 + , 4 + , 1 + , 2 + , 6 + , 7 + , 3 + , 4 + , 0 + , 5 + , 5 + , 4 + , 3 + , 2 + , 1 + , 0 + , 7 + , 6 + , 0 + , 1 + , 2 + , 3 + , 4 + , 5 + , 6 + , 7 + ] diff --git a/Math/NumberTheory/Primes/Sieve/Indexing.hs b/Math/NumberTheory/Primes/Sieve/Indexing.hs index 7ba06a1c6..5eb5d3956 100644 --- a/Math/NumberTheory/Primes/Sieve/Indexing.hs +++ b/Math/NumberTheory/Primes/Sieve/Indexing.hs @@ -7,40 +7,47 @@ -- Auxiliary stuff, conversion between number and index, -- remainders modulo 30 and related things. {-# OPTIONS_HADDOCK hide #-} + module Math.NumberTheory.Primes.Sieve.Indexing - ( idxPr - , toPrim - , rho - ) where + ( idxPr + , toPrim + , rho + ) where -import Data.Array.Unboxed import Data.Bits import Math.NumberTheory.Unsafe {-# INLINE idxPr #-} -idxPr :: Integral a => a -> (Int,Int) +idxPr :: Integral a => a -> (Int, Int) idxPr n0 - | n0 < 7 = (0, 0) - | otherwise = (fromIntegral bytes0, rm3) + | n0 < 7 = (0, 0) + | otherwise = (fromIntegral bytes0, rm3) where - n = if (fromIntegral n0 .&. 1 == (1 :: Int)) - then n0 else (n0-1) - (bytes0,rm0) = (n-7) `quotRem` 30 + n = + if (fromIntegral n0 .&. 1 == (1 :: Int)) + then n0 + else (n0 - 1) + (bytes0, rm0) = (n - 7) `quotRem` 30 rm1 = fromIntegral rm0 rm2 = rm1 `quot` 3 - rm3 = min 7 (if rm2 > 5 then rm2-1 else rm2) + rm3 = + min + 7 + (if rm2 > 5 + then rm2 - 1 + else rm2) {-# INLINE toPrim #-} toPrim :: Num a => Int -> a -toPrim ix = 30*fromIntegral k + fromIntegral (rho i) +toPrim ix = 30 * fromIntegral k + fromIntegral (rho i) where i = ix .&. 7 k = ix `shiftR` 3 {-# INLINE rho #-} rho :: Int -> Int -rho i = unsafeAt residues i +rho = unsafeIndex residues -residues :: UArray Int Int -residues = listArray (0,7) [7,11,13,17,19,23,29,31] +residues :: Vector Int +residues = fromList [7, 11, 13, 17, 19, 23, 29, 31] diff --git a/Math/NumberTheory/Unsafe.hs b/Math/NumberTheory/Unsafe.hs index a63683c37..ec9a61c24 100644 --- a/Math/NumberTheory/Unsafe.hs +++ b/Math/NumberTheory/Unsafe.hs @@ -6,64 +6,47 @@ -- -- Layer to switch between safe and unsafe arrays. -- - {-# LANGUAGE CPP #-} module Math.NumberTheory.Unsafe - ( UArray - , bounds - , castSTUArray - , unsafeAt + ( STVector + , Vector + , unsafeIndex , unsafeFreeze - , unsafeNewArray_ + , unsafeNew , unsafeRead , unsafeThaw , unsafeWrite + , replicate + , fromList ) where +import Data.Vector (Vector, fromList) +import Data.Vector.Mutable (STVector, replicate) +import Prelude hiding (read, replicate) #ifdef CheckBounds +import Control.Monad.ST (ST) +import Data.Vector ((!), freeze, thaw) +import Data.Vector.Mutable (new, read, write) -import Data.Array.Base - ( UArray - , castSTUArray - ) -import Data.Array.IArray - ( IArray - , bounds - , (!) - ) -import Data.Array.MArray - -unsafeAt :: (IArray a e, Ix i) => a i e -> i -> e -unsafeAt = (!) +unsafeIndex :: Vector a -> Int -> a +unsafeIndex = (!) -unsafeFreeze :: (Ix i, MArray a e m, IArray b e) => a i e -> m (b i e) +unsafeFreeze :: STVector s a -> ST s (Vector a) unsafeFreeze = freeze -unsafeNewArray_ :: (Ix i, MArray a e m) => (i, i) -> m (a i e) -unsafeNewArray_ = newArray_ +unsafeNew :: Int -> ST s (STVector s a) +unsafeNew = new -unsafeRead :: (MArray a e m, Ix i) => a i e -> i -> m e -unsafeRead = readArray +unsafeRead :: STVector s a -> Int -> ST s a +unsafeRead = read -unsafeThaw :: (Ix i, IArray a e, MArray b e m) => a i e -> m (b i e) +unsafeThaw :: Vector a -> ST s (STVector s a) unsafeThaw = thaw -unsafeWrite :: (MArray a e m, Ix i) => a i e -> i -> e -> m () -unsafeWrite = writeArray - +unsafeWrite :: STVector s a -> Int -> a -> ST s () +unsafeWrite = write #else - -import Data.Array.Base - ( UArray - , bounds - , castSTUArray - , unsafeAt - , unsafeFreeze - , unsafeNewArray_ - , unsafeRead - , unsafeThaw - , unsafeWrite - ) - +import Data.Vector (unsafeFreeze, unsafeIndex, unsafeThaw) +import Data.Vector.Mutable (unsafeNew, unsafeRead, unsafeWrite) #endif