{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}

{-# OPTIONS_GHC -fno-warn-missing-signatures #-}

module Internal.Sparse(
    GMatrix(..), CSR(..), mkCSR, fromCSR,
    mkSparse, mkDiagR, mkDense,
    AssocMatrix,
    toDense,
    gmXv, (!#>)
)where

import Internal.Vector
import Internal.Matrix
import Internal.Numeric
import qualified Data.Vector.Storable as V
import Data.Function(on)
import Control.Arrow((***))
import Control.Monad(when)
import Data.List(groupBy, sort)
import Foreign.C.Types(CInt(..))

import Internal.Devel
import System.IO.Unsafe(unsafePerformIO)
import Foreign(Ptr)
import Text.Printf(printf)

infixl 0 ~!~
-- | @cond ~!~ msg@ aborts with 'error' @msg@ when @cond@ is 'True' and is
-- @pure ()@ otherwise. The very low fixity lets a whole comparison sit on
-- the left-hand side without parentheses.
(~!~) :: Applicative f => Bool -> String -> f ()
failed ~!~ msg
    | failed    = error msg
    | otherwise = pure ()

-- | A sparse matrix given as a list of @((row, column), value)@ entries
-- with 0-based indices.
type AssocMatrix = [((Int, Int), Double)]

-- | Compressed-sparse-row storage.
--
-- As built by 'mkCSR', both index arrays hold 1-based values: column
-- indices start at 1 and the row-pointer array starts at 1.
data CSR = CSR
    { csrVals  :: Vector Double  -- ^ stored values, row after row
    , csrCols  :: Vector CInt    -- ^ column index of each stored value
    , csrRows  :: Vector CInt    -- ^ row pointers (one more than the row count)
    , csrNRows :: Int            -- ^ number of rows
    , csrNCols :: Int            -- ^ number of columns
    } deriving (Show)

-- | Compressed-sparse-column storage (not exported).
--
-- Inside this module a 'CSC' only comes from transposing a 'CSR' with
-- 'tr'. That reuses the CSR arrays unchanged: 'cscRows' holds the CSR
-- column indices and 'cscCols' holds the CSR row pointers.
data CSC = CSC
    { cscVals  :: Vector Double  -- ^ stored values
    , cscRows  :: Vector CInt    -- ^ row index of each stored value
    , cscCols  :: Vector CInt    -- ^ column pointers
    , cscNRows :: Int            -- ^ number of rows
    , cscNCols :: Int            -- ^ number of columns
    } deriving (Show)


-- | Build a compressed-sparse-row matrix from @((row, col), value)@
-- entries with 0-based indices.
--
-- The stored column indices and row pointers are 1-based. The row count is
-- one more than the largest row index present. Rows below that index with
-- no entries get an empty (zero-length) slice instead of being skipped, so
-- every row pointer lines up with its row. Trailing rows with no entries
-- cannot be expressed by an association list.
--
-- The column count is one more than the largest column index. An empty
-- list yields a 0x0 matrix.
mkCSR :: AssocMatrix -> CSR
mkCSR sm' = CSR{..}
  where
    sm = sort sm'
    rowIdx = map (fst . fst) sm
    -- (row, number of entries) for each row that has at least one entry;
    -- sorting makes equal row indices adjacent, and in ascending order
    occupied = [ (i, 1 + length more) | (i : more) <- groupBy (==) rowIdx ]
    -- entry count for every row 0 .. csrNRows-1, zero for empty rows
    rszs = fill 0 occupied
    fill :: Int -> [(Int, Int)] -> [CInt]
    fill _ [] = []
    fill next ((i, k) : rest) = replicate (i - next) 0 ++ fi k : fill (i + 1) rest
    csrRows = fromList (scanl (+) 1 rszs)
    csrVals = fromList (map snd sm)
    csrCols = fromList (map (succ . fi . snd . fst) sm)
    csrNRows = length rszs
    -- V.maximum is partial, so the empty case is handled separately
    csrNCols
      | null sm   = 0
      | otherwise = fromIntegral (V.maximum csrCols)

{- | A general real matrix stored in one of several layouts: compressed
     sparse rows ('SparseR'), compressed sparse columns ('SparseC'), a
     diagonal ('Diag'), or an ordinary dense 'Matrix' ('Dense').

     Every constructor carries the logical size in 'nRows' and 'nCols'.
     Those two selectors are total. The payload selectors ('gmCSR',
     'gmCSC', 'diagVals', 'gmDense') each fail on the other constructors.

>>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)]
>>> m
SparseR {gmCSR = CSR {csrVals = fromList [1.0,2.0],
                      csrCols = fromList [1000,2000],
                      csrRows = fromList [1,2,3],
                      csrNRows = 2,
                      csrNCols = 2000},
                      nRows = 2,
                      nCols = 2000}

>>> let m = mkDense (mat 2 [1..4])
>>> m
Dense {gmDense = (2><2)
 [ 1.0, 2.0
 , 3.0, 4.0 ], nRows = 2, nCols = 2}

-}
data GMatrix
    = SparseR
        { gmCSR :: CSR
        , nRows :: Int
        , nCols :: Int
        }
    | SparseC
        { gmCSC :: CSC
        , nRows :: Int
        , nCols :: Int
        }
    | Diag
        { diagVals :: Vector Double
        , nRows    :: Int
        , nCols    :: Int
        }
    | Dense
        { gmDense :: Matrix Double
        , nRows   :: Int
        , nCols   :: Int
        }
    -- Banded: placeholder only, no constructor is implemented.
    deriving (Show)


-- | Wrap a dense matrix as a 'GMatrix', recording its size.
mkDense :: Matrix Double -> GMatrix
mkDense m = Dense { gmDense = m, nRows = rows m, nCols = cols m }

-- | Build a compressed-row sparse 'GMatrix' from @((row, col), value)@
-- entries. The size is whatever 'mkCSR' infers.
mkSparse :: AssocMatrix -> GMatrix
mkSparse entries = fromCSR (mkCSR entries)

-- | Wrap a 'CSR' as a 'GMatrix', copying its row and column counts into
-- 'nRows' and 'nCols'.
fromCSR :: CSR -> GMatrix
fromCSR csr = SparseR
    { gmCSR = csr
    , nRows = csrNRows csr
    , nCols = csrNCols csr
    }


-- | Build an @r x c@ diagonal 'GMatrix' from its leading diagonal entries.
--
-- The vector may be shorter than @min r c@ ('gmXv' treats the missing
-- entries as zero). A vector longer than that is rejected with 'error'.
mkDiagR :: Int -> Int -> Vector Double -> GMatrix
mkDiagR r c v
    | tooLong   = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v)
    | otherwise = Diag { diagVals = v, nRows = r, nCols = c }
  where
    tooLong = dim v > min r c


-- Argument-group aliases for the sparse C kernels.
-- 'IV' prepends a (CInt, Ptr CInt) pair and 'V' prepends a
-- (CInt, Ptr Double) pair to @t@. Presumably each pair is a length and a
-- pointer, as produced by the '#'/'#!' marshalling in 'gmXv'; confirm
-- against Internal.Devel.
type IV t = CInt -> Ptr CInt   -> t
type  V t = CInt -> Ptr Double -> t
-- Full kernel type. 'gmXv' passes, in order: stored values, two index
-- arrays, the input vector and the output vector. The CInt result is
-- handed to '#|' (presumably a status code).
type SMxV = V (IV (IV (V (V (IO CInt)))))

-- | Multiply a general matrix by a dense vector.
--
-- Every layout first checks that the vector length equals 'nCols' and
-- calls 'error' with the offending sizes if not. After that check:
--
-- * the sparse layouts run a C kernel that fills a fresh vector of
--   'nRows' entries;
-- * the diagonal layout scales the leading part of the vector and pads
--   with zeros up to 'nRows';
-- * the dense layout uses 'mXv'.
gmXv :: GMatrix -> Vector Double -> Vector Double
gmXv SparseR { gmCSR = CSR{..}, .. } v =
    sparseXv c_smXv "CSRXv" "gmXv (CSR): incorrect sizes: (%d,%d) x %d"
             nRows nCols csrVals csrCols csrRows v
gmXv SparseC { gmCSC = CSC{..}, .. } v =
    sparseXv c_smTXv "CSCXv" "gmXv (CSC): incorrect sizes: (%d,%d) x %d"
             nRows nCols cscVals cscRows cscCols v
gmXv Diag{..} v
    | dim v /= nCols =
        error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
                       nRows nCols nDiag (dim v)
    | otherwise = vjoin [scaled, padding]
  where
    nDiag   = dim diagVals
    -- only the first nDiag entries of v meet a diagonal element
    scaled  = subVector 0 nDiag v `mul` diagVals
    -- the remaining rows of the result are zero
    padding = konst 0 (nRows - nDiag)
gmXv Dense{..} v
    | dim v /= nCols =
        error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d"
                       nRows nCols (dim v)
    | otherwise = mXv gmDense v

-- Shared driver for the two compressed sparse layouts.
-- It checks the input length, allocates an output of @nr@ entries and
-- lets the C kernel fill it. The 'unsafePerformIO' is intended to be
-- benign: the output is freshly allocated here and returned only after the
-- kernel runs (this assumes the kernel writes nothing but @out@).
--
-- Arguments, in order:
--   kernel   C kernel
--   tag      label handed to '#|' for kernel failures
--   sizeFmt  printf format for the size-mismatch error
--   nr, nc   rows and columns of the matrix
--   vals     stored values
--   idx      first index array
--   ptr      second index array
--   v        input vector
sparseXv :: SMxV
         -> String
         -> String
         -> Int
         -> Int
         -> Vector Double
         -> Vector CInt
         -> Vector CInt
         -> Vector Double
         -> Vector Double
sparseXv kernel tag sizeFmt nr nc vals idx ptr v = unsafePerformIO $ do
    dim v /= nc ~!~ printf sizeFmt nr nc (dim v)
    out <- createVector nr
    (vals # idx # ptr # v #! out) kernel #| tag
    return out


{- | Infix form of 'gmXv': general matrix times vector.

>>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)]
>>> m !#> vector [1..2000]
[1000.0,4000.0]

-}
infixr 8 !#>
(!#>) :: GMatrix -> Vector Double -> Vector Double
m !#> v = gmXv m v

--------------------------------------------------------------------------------

-- C kernel used by 'gmXv' for the 'SparseR' layout. Its arguments follow
-- 'SMxV': CSR values, column indices, row pointers, input vector, output
-- vector.
foreign import ccall unsafe "smXv"
  c_smXv :: SMxV

-- C kernel used by 'gmXv' for the 'SparseC' layout. It has the same
-- argument layout and receives the CSC arrays, which 'tr' builds by
-- relabelling the CSR arrays. Judging by the name, it computes the
-- transposed product of those arrays; confirm against the C source.
foreign import ccall unsafe "smTXv"
  c_smTXv :: SMxV

--------------------------------------------------------------------------------

-- | Dense matrix holding the given @((row, col), value)@ entries, with
-- zeros everywhere else.
--
-- The size is inferred from the entries: one more than the largest row
-- index, by one more than the largest column index. An empty list yields a
-- 0x0 matrix instead of failing inside 'maximum'.
toDense :: AssocMatrix -> Matrix Double
toDense [] = assoc (0, 0) 0 []
toDense asm = assoc (r + 1, c + 1) 0 asm
  where
    -- asm is non-empty here, so both maximums are defined
    (r, c) = (maximum *** maximum) . unzip . map fst $ asm


-- | Transposing a CSR reinterprets its arrays as a CSC without copying.
-- Column indices become row indices, row pointers become column pointers,
-- and the dimensions swap.
instance Transposable CSR CSC where
    tr CSR{..} = CSC
        { cscVals  = csrVals
        , cscRows  = csrCols
        , cscCols  = csrRows
        , cscNRows = csrNCols
        , cscNCols = csrNRows
        }
    tr' = tr

-- | Inverse of the CSR-to-CSC relabelling: the same arrays are read back as
-- a CSR, and the dimensions swap.
instance Transposable CSC CSR where
    tr CSC{..} = CSR
        { csrVals  = cscVals
        , csrCols  = cscRows
        , csrRows  = cscCols
        , csrNRows = cscNCols
        , csrNCols = cscNRows
        }
    tr' = tr

-- | Transpose a general matrix while keeping a matching layout. Sparse rows
-- and sparse columns swap into each other. A diagonal keeps its values.
-- A dense matrix is transposed with the 'Matrix' instance. In every case
-- the row and column counts swap.
instance Transposable GMatrix GMatrix where
    tr g = case g of
        SparseR s r c -> SparseC (tr s) c r
        SparseC s r c -> SparseR (tr s) c r
        Diag d r c    -> Diag d c r
        Dense a r c   -> Dense (tr a) c r
    tr' = tr