{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Diagrams.Solve.Tridiagonal
-- Copyright   :  (c) 2011-2015 diagrams-solve team (see LICENSE)
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  diagrams-discuss@googlegroups.com
--
-- Solving of tridiagonal and cyclic tridiagonal linear systems.
--
-----------------------------------------------------------------------------
module Diagrams.Solve.Tridiagonal
       ( solveTriDiagonal
       , solveCyclicTriDiagonal
       ) where

-- | @solveTriDiagonal as bs cs ds@ solves the linear system @A * X = ds@
--   for @X@, where @A@ is an @n@ by @n@ tridiagonal matrix whose main
--   diagonal is @bs@ (length @n@), whose diagonal just below the main one
--   is @as@ (length @n - 1@) and whose diagonal just above the main one is
--   @cs@ (length @n - 1@).  The right-hand side @ds@ has length @n@.
--
--   The method is the Thomas algorithm (Gaussian elimination specialised
--   to tridiagonal matrices, with no pivoting): a forward sweep builds the
--   modified coefficients @c'@ and @d'@, and back substitution then
--   recovers @X@.  See:
--   <http://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm>
--
--   Arguments 2, 3 and 4 must be nonempty, otherwise 'error' is called
--   (for @n = 1@ the single element of @cs@ is required but never read).
--   The list lengths are not validated beyond that.
solveTriDiagonal :: Fractional a => [a] -> [a] -> [a] -> [a] -> [a]
solveTriDiagonal as (b0 : bs) (c0 : cs) (d0 : ds) = backSub cPrimes dPrimes
  where
    -- Modified super-diagonal:
    --   c'_0 = c_0 / b_0,   c'_i = c_i / (b_i - c'_{i-1} * a_i).
    -- The list is defined in terms of itself; laziness lets each step read
    -- the coefficient produced by the previous one.
    cPrimes = c0 / b0 : sweepC cPrimes as bs cs

    -- Stops when only one sub-diagonal entry is left, so 'cPrimes' ends up
    -- with one element per super-diagonal entry.
    sweepC _ [_] _ _ = []
    sweepC (cPrev : cPrevs) (sub : subs) (diag : diags) (sup : sups) =
      sup / (diag - cPrev * sub) : sweepC cPrevs subs diags sups
    sweepC _ _ _ _ = error "solveTriDiagonal.f: impossible!"

    -- Modified right-hand side:
    --   d'_0 = d_0 / b_0,   d'_i = (d_i - d'_{i-1} * a_i) / (b_i - c'_{i-1} * a_i).
    dPrimes = d0 / b0 : sweepD dPrimes as bs cPrimes ds

    sweepD _ [] _ _ _ = []
    sweepD (dPrev : dPrevs) (sub : subs) (diag : diags) (cPrev : cPrevs) (rhs : rhss) =
      (rhs - dPrev * sub) / (diag - cPrev * sub) : sweepD dPrevs subs diags cPrevs rhss
    sweepD _ _ _ _ _ = error "solveTriDiagonal.g: impossible!"

    -- Back substitution from the last unknown upward:
    --   x_{n-1} = d'_{n-1},   x_i = d'_i - c'_i * x_{i+1}.
    backSub _ [d] = [d]
    backSub (c : cRest) (d : dRest) = d - c * next : later
      where
        later@(next : _) = backSub cRest dRest
    backSub _ _ = error "solveTriDiagonal.h: impossible!"

solveTriDiagonal _ _ _ _ = error "arguments 2,3,4 to solveTriDiagonal must be nonempty"

-- Helper that applies @f@ to the final element of a list and leaves every
-- other element untouched; an empty list comes back empty.  The list is
-- walked with one element of lookahead, so a prefix of an infinite list
-- can still be consumed.
modifyLast :: (a -> a) -> [a] -> [a]
modifyLast f = go
  where
    go xs = case xs of
      []                 -> []
      [x]                -> [f x]
      (x : rest@(_ : _)) -> x : go rest

-- Helper that builds a list of length @n@ of the form @[s,m,m,...,m,m,e]@:
-- the first element is @s@, the last is @e@, and every element in between
-- is @m@.
--
-- Degenerate sizes: @n < 1@ gives @[]@, and @n == 1@ gives @[s]@ (the start
-- value wins when start and end would share the single slot).  The previous
-- count-down recursion never reached its base case for @n == 1@ and
-- produced an infinite list; that case is now handled explicitly.
sparseVector :: Int -> a -> a -> a -> [a]
sparseVector n s m e
    | n < 1     = []
    | n == 1    = [s]
    | otherwise = s : replicate (n - 2) m ++ [e]

-- | Solves a cyclic tridiagonal system: the tridiagonal system described by
--   @as@, @bs@, @cs@ and @ds@ (same conventions as 'solveTriDiagonal'),
--   whose matrix additionally carries @beta@ in its top-right corner and
--   @alpha@ in its bottom-left corner.
--
--   It uses a special case of the Sherman-Morrison formula
--   (<http://en.wikipedia.org/wiki/Sherman-Morrison_formula>).  The matrix
--   is written as a plain tridiagonal matrix plus a rank-one correction.
--   Two tridiagonal systems with that matrix are solved, and their
--   solutions are combined.  This code is based on the @cyclic@ function in
--   section 2.7 of /Numerical Recipes in C/.
--
--   The head of @bs@ is used (negated) as a divisor, so it must be nonzero.
--   The second argument must be nonempty, otherwise 'error' is called.
--   NOTE(review): the system size is not checked; for @n < 3@ the corner
--   entries overlap diagonal slots — confirm such sizes are not expected.
solveCyclicTriDiagonal :: Fractional a => [a] -> [a] -> [a] -> [a] -> a -> a -> [a]
solveCyclicTriDiagonal as (b0 : bs) cs ds alpha beta =
    zipWith (\z y -> scale * z + y) zs ys
  where
    -- The cyclic matrix equals a tridiagonal matrix (diagonal bs') plus
    -- u v^T with u = [gamma, 0, ..., 0, alpha] and
    -- v = [1, 0, ..., 0, beta / gamma].
    gamma = negate b0

    -- The rank-one term adds gamma to the first diagonal entry and
    -- alpha * beta / gamma to the last one, so remove them here.
    bs' = (b0 - gamma) : modifyLast (\b -> b - alpha * beta / gamma) bs

    u = sparseVector (length ds) gamma 0 alpha

    -- Solve the tridiagonal part against the real right-hand side and
    -- against u.
    ys = solveTriDiagonal as bs' cs ds
    zs = solveTriDiagonal as bs' cs u

    (y0 : _) = ys
    (z0 : _) = zs

    -- Sherman-Morrison correction factor: -(v . y) / (1 + v . z).
    scale = negate ((y0 + beta * last ys / gamma) / (1.0 + z0 + beta * last zs / gamma))

solveCyclicTriDiagonal _ _ _ _ _ _ = error "second argument to solveCyclicTriDiagonal must be nonempty"