Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Migration to vector #147

Closed
wants to merge 13 commits into from
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@
stack.yaml
dist
dist-newstyle
.ghc.environment*
.ghc.environment*
.cabal-sandbox
cabal.sandbox.config

118 changes: 60 additions & 58 deletions Math/NumberTheory/MoebiusInversion.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
-- Generalised Möbius inversion
--
{-# LANGUAGE BangPatterns, FlexibleContexts #-}

module Math.NumberTheory.MoebiusInversion
( generalInversion
, totientSum
) where
( generalInversion
, totientSum
) where

import Data.Array.ST
import Control.Monad
import Control.Monad.ST

Expand All @@ -28,7 +28,7 @@ totientSum n
| n < 1 = 0
| otherwise = generalInversion (triangle . fromIntegral) n
where
triangle k = (k*(k+1)) `quot` 2
triangle k = (k * (k + 1)) `quot` 2

-- | @generalInversion g n@ evaluates the generalised Möbius inversion of @g@
-- at the argument @n@.
Expand Down Expand Up @@ -76,65 +76,67 @@ totientSum n
-- method is only appropriate to compute isolated values of @f@.
generalInversion :: (Int -> Integer) -> Int -> Integer
generalInversion fun n
| n < 1 = error "Möbius inversion only defined on positive domain"
| n == 1 = fun 1
| n == 2 = fun 2 - fun 1
| n == 3 = fun 3 - 2*fun 1
| otherwise = fastInvert fun n
| n < 1 = error "Möbius inversion only defined on positive domain"
| n == 1 = fun 1
| n == 2 = fun 2 - fun 1
| n == 3 = fun 3 - 2 * fun 1
| otherwise = fastInvert fun n

fastInvert :: (Int -> Integer) -> Int -> Integer
fastInvert fun n = big `unsafeAt` 0
fastInvert fun n = big `unsafeIndex` 0
where
!k0 = integerSquareRoot (n `quot` 2)
!mk0 = n `quot` (2*k0+1)
!mk0 = n `quot` (2 * k0 + 1)
kmax a m = (a `quot` m - 1) `quot` 2
big = runSTArray $ do
small <- newArray_ (0,mk0) :: ST s (STArray s Int Integer)
big =
runST $ do
small <- unsafeNew (mk0 + 1) :: ST s (STVector s Integer)
unsafeWrite small 0 0
unsafeWrite small 1 $! (fun 1)
when (mk0 >= 2) $
unsafeWrite small 2 $! (fun 2 - fun 1)
unsafeWrite small 1 $! fun 1
when (mk0 >= 2) $ unsafeWrite small 2 $! (fun 2 - fun 1)
let calcit switch change i
| mk0 < i = return (switch,change)
| i == change = calcit (switch+1) (change + 4*switch+6) i
| otherwise = do
let mloop !acc k !m
| k < switch = kloop acc k
| otherwise = do
val <- unsafeRead small m
let nxtk = kmax i (m+1)
mloop (acc - fromIntegral (k-nxtk)*val) nxtk (m+1)
kloop !acc k
| k == 0 = do
unsafeWrite small i $! acc
calcit switch change (i+1)
| otherwise = do
val <- unsafeRead small (i `quot` (2*k+1))
kloop (acc-val) (k-1)
mloop (fun i - fun (i `quot` 2)) ((i-1) `quot` 2) 1
| mk0 < i = return (switch, change)
| i == change = calcit (switch + 1) (change + 4 * switch + 6) i
| otherwise = do
let mloop !acc k !m
| k < switch = kloop acc k
| otherwise = do
val <- unsafeRead small m
let nxtk = kmax i (m + 1)
mloop (acc - fromIntegral (k - nxtk) * val) nxtk (m + 1)
kloop !acc k
| k == 0 = do
unsafeWrite small i $! acc
calcit switch change (i + 1)
| otherwise = do
val <- unsafeRead small (i `quot` (2 * k + 1))
kloop (acc - val) (k - 1)
mloop (fun i - fun (i `quot` 2)) ((i - 1) `quot` 2) 1
(sw, ch) <- calcit 1 8 3
large <- newArray_ (0,k0-1)
large <- unsafeNew k0 :: ST s (STVector s Integer)
let calcbig switch change j
| j == 0 = return large
| (2*j-1)*change <= n = calcbig (switch+1) (change + 4*switch+6) j
| otherwise = do
let i = n `quot` (2*j-1)
mloop !acc k m
| k < switch = kloop acc k
| otherwise = do
val <- unsafeRead small m
let nxtk = kmax i (m+1)
mloop (acc - fromIntegral (k-nxtk)*val) nxtk (m+1)
kloop !acc k
| k == 0 = do
unsafeWrite large (j-1) $! acc
calcbig switch change (j-1)
| otherwise = do
let m = i `quot` (2*k+1)
val <- if m <= mk0
then unsafeRead small m
else unsafeRead large (k*(2*j-1)+j-1)
kloop (acc-val) (k-1)
mloop (fun i - fun (i `quot` 2)) ((i-1) `quot` 2) 1
calcbig sw ch k0

| j == 0 = return large
| (2 * j - 1) * change <= n =
calcbig (switch + 1) (change + 4 * switch + 6) j
| otherwise = do
let i = n `quot` (2 * j - 1)
mloop !acc k m
| k < switch = kloop acc k
| otherwise = do
val <- unsafeRead small m
let nxtk = kmax i (m + 1)
mloop (acc - fromIntegral (k - nxtk) * val) nxtk (m + 1)
kloop !acc k
| k == 0 = do
unsafeWrite large (j - 1) $! acc
calcbig switch change (j - 1)
| otherwise = do
let m = i `quot` (2 * k + 1)
val <-
if m <= mk0
then unsafeRead small m
else unsafeRead large (k * (2 * j - 1) + j - 1)
kloop (acc - val) (k - 1)
mloop (fun i - fun (i `quot` 2)) ((i - 1) `quot` 2) 1
_ <- calcbig sw ch k0
unsafeFreeze large
116 changes: 60 additions & 56 deletions Math/NumberTheory/MoebiusInversion/Int.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
--
{-# LANGUAGE BangPatterns, FlexibleContexts #-}
{-# OPTIONS_GHC -fspec-constr-count=8 #-}

module Math.NumberTheory.MoebiusInversion.Int
( generalInversion
, totientSum
) where
( generalInversion
, totientSum
) where

import Prelude hiding (replicate)

import Data.Array.ST
import Control.Monad
import Control.Monad.ST

Expand All @@ -29,7 +31,7 @@ totientSum n
| n < 1 = 0
| otherwise = generalInversion (triangle . fromIntegral) n
where
triangle k = (k*(k+1)) `quot` 2
triangle k = (k * (k + 1)) `quot` 2

-- | @generalInversion g n@ evaluates the generalised Möbius inversion of @g@
-- at the argument @n@.
Expand Down Expand Up @@ -77,64 +79,66 @@ totientSum n
-- method is only appropriate to compute isolated values of @f@.
generalInversion :: (Int -> Int) -> Int -> Int
generalInversion fun n
| n < 1 = error "Möbius inversion only defined on positive domain"
| n == 1 = fun 1
| n == 2 = fun 2 - fun 1
| n == 3 = fun 3 - 2*fun 1
| otherwise = fastInvert fun n
| n < 1 = error "Möbius inversion only defined on positive domain"
| n == 1 = fun 1
| n == 2 = fun 2 - fun 1
| n == 3 = fun 3 - 2 * fun 1
| otherwise = fastInvert fun n

fastInvert :: (Int -> Int) -> Int -> Int
fastInvert fun n = big `unsafeAt` 0
fastInvert fun n = big `unsafeIndex` 0
where
!k0 = integerSquareRoot (n `quot` 2)
!mk0 = n `quot` (2*k0+1)
!mk0 = n `quot` (2 * k0 + 1)
kmax a m = (a `quot` m - 1) `quot` 2
big = runSTUArray $ do
small <- newArray_ (0,mk0) :: ST s (STUArray s Int Int)
big =
runST $ do
small <- replicate (mk0 + 1) 0 :: ST s (STVector s Int)
unsafeWrite small 0 0
unsafeWrite small 1 (fun 1)
when (mk0 >= 2) $
unsafeWrite small 2 (fun 2 - fun 1)
when (mk0 >= 2) $ unsafeWrite small 2 (fun 2 - fun 1)
let calcit switch change i
| mk0 < i = return (switch,change)
| i == change = calcit (switch+1) (change + 4*switch+6) i
| otherwise = do
let mloop !acc k !m
| k < switch = kloop acc k
| otherwise = do
val <- unsafeRead small m
let nxtk = kmax i (m+1)
mloop (acc - (k-nxtk)*val) nxtk (m+1)
kloop !acc k
| k == 0 = do
unsafeWrite small i acc
calcit switch change (i+1)
| otherwise = do
val <- unsafeRead small (i `quot` (2*k+1))
kloop (acc-val) (k-1)
mloop (fun i - fun (i `quot` 2)) ((i-1) `quot` 2) 1
| mk0 < i = return (switch, change)
| i == change = calcit (switch + 1) (change + 4 * switch + 6) i
| otherwise = do
let mloop !acc k !m
| k < switch = kloop acc k
| otherwise = do
val <- unsafeRead small m
let nxtk = kmax i (m + 1)
mloop (acc - (k - nxtk) * val) nxtk (m + 1)
kloop !acc k
| k == 0 = do
unsafeWrite small i acc
calcit switch change (i + 1)
| otherwise = do
val <- unsafeRead small (i `quot` (2 * k + 1))
kloop (acc - val) (k - 1)
mloop (fun i - fun (i `quot` 2)) ((i - 1) `quot` 2) 1
(sw, ch) <- calcit 1 8 3
large <- newArray_ (0,k0-1)
large <- replicate k0 0 :: ST s (STVector s Int)
let calcbig switch change j
| j == 0 = return large
| (2*j-1)*change <= n = calcbig (switch+1) (change + 4*switch+6) j
| otherwise = do
let i = n `quot` (2*j-1)
mloop !acc k m
| k < switch = kloop acc k
| otherwise = do
val <- unsafeRead small m
let nxtk = kmax i (m+1)
mloop (acc - (k-nxtk)*val) nxtk (m+1)
kloop !acc k
| k == 0 = do
unsafeWrite large (j-1) acc
calcbig switch change (j-1)
| otherwise = do
let m = i `quot` (2*k+1)
val <- if m <= mk0
then unsafeRead small m
else unsafeRead large (k*(2*j-1)+j-1)
kloop (acc-val) (k-1)
mloop (fun i - fun (i `quot` 2)) ((i-1) `quot` 2) 1
calcbig sw ch k0
| j == 0 = return large
| (2 * j - 1) * change <= n =
calcbig (switch + 1) (change + 4 * switch + 6) j
| otherwise = do
let i = n `quot` (2 * j - 1)
mloop !acc k m
| k < switch = kloop acc k
| otherwise = do
val <- unsafeRead small m
let nxtk = kmax i (m + 1)
mloop (acc - (k - nxtk) * val) nxtk (m + 1)
kloop !acc k
| k == 0 = do
unsafeWrite large (j - 1) acc
calcbig switch change (j - 1)
| otherwise = do
let m = i `quot` (2 * k + 1)
val <-
if m <= mk0
then unsafeRead small m
else unsafeRead large (k * (2 * j - 1) + j - 1)
kloop (acc - val) (k - 1)
mloop (fun i - fun (i `quot` 2)) ((i - 1) `quot` 2) 1
calcbig sw ch k0 >>= unsafeFreeze
Loading