Basix
math.h
// Copyright (C) 2021 Igor Baratta
//
// This file is part of DOLFINx (https://www.fenicsproject.org)
//
// SPDX-License-Identifier: LGPL-3.0-or-later
6
7#pragma once
8
#include "mdspan.hpp"
#include <array>
#include <cassert>
#include <cstddef>
#include <span>
#include <utility>
#include <vector>
13
18namespace basix::math
19{
20
21namespace impl
22{
27void dot_blas(const std::span<const double>& A,
28 std::array<std::size_t, 2> Ashape,
29 const std::span<const double>& B,
30 std::array<std::size_t, 2> Bshape, const std::span<double>& C);
31} // namespace impl
32
37template <typename U, typename V>
38std::pair<std::vector<typename U::value_type>, std::array<std::size_t, 2>>
39outer(const U& u, const V& v)
40{
41 std::vector<typename U::value_type> result(u.size() * v.size());
42 for (std::size_t i = 0; i < u.size(); ++i)
43 for (std::size_t j = 0; j < v.size(); ++j)
44 result[i * v.size() + j] = u[i] * v[j];
45
46 return {std::move(result), {u.size(), v.size()}};
47}
48
53template <typename U, typename V>
54std::array<typename U::value_type, 3> cross(const U& u, const V& v)
55{
56 assert(u.size() == 3);
57 assert(v.size() == 3);
58 return {u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2],
59 u[0] * v[1] - u[1] * v[0]};
60}
61
68std::pair<std::vector<double>, std::vector<double>>
69eigh(const std::span<const double>& A, std::size_t n);
70
75std::vector<double>
76solve(const std::experimental::mdspan<
77 const double, std::experimental::dextents<std::size_t, 2>>& A,
78 const std::experimental::mdspan<
79 const double, std::experimental::dextents<std::size_t, 2>>& B);
80
84bool is_singular(const std::experimental::mdspan<
85 const double, std::experimental::dextents<std::size_t, 2>>& A);
86
91std::vector<std::size_t>
92transpose_lu(std::pair<std::vector<double>, std::array<std::size_t, 2>>& A);
93
99template <typename U, typename V, typename W>
100void dot(const U& A, const V& B, W&& C)
101{
102 assert(A.extent(1) == B.extent(0));
103 assert(C.extent(0) == C.extent(0));
104 assert(C.extent(1) == B.extent(1));
105 if (A.extent(0) * B.extent(1) * A.extent(1) < 4096)
106 {
107 std::fill_n(C.data_handle(), C.extent(0) * C.extent(1), 0);
108 for (std::size_t i = 0; i < A.extent(0); ++i)
109 for (std::size_t j = 0; j < B.extent(1); ++j)
110 for (std::size_t k = 0; k < A.extent(1); ++k)
111 C(i, j) += A(i, k) * B(k, j);
112 }
113 else
114 {
115 impl::dot_blas(
116 std::span(A.data_handle(), A.size()), {A.extent(0), A.extent(1)},
117 std::span(B.data_handle(), B.size()), {B.extent(0), B.extent(1)},
118 std::span(C.data_handle(), C.size()));
119 }
120}
121
125std::vector<double> eye(std::size_t n);
126
127} // namespace basix::math
Definition: math.h:19
std::pair< std::vector< double >, std::vector< double > > eigh(const std::span< const double > &A, std::size_t n)
Definition: math.cpp:55
void dot(const U &A, const V &B, W &&C)
Definition: math.h:100
std::vector< std::size_t > transpose_lu(std::pair< std::vector< double >, std::array< std::size_t, 2 > > &A)
Definition: math.cpp:162
std::array< typename U::value_type, 3 > cross(const U &u, const V &v)
Definition: math.h:54
std::vector< double > solve(const std::experimental::mdspan< const double, std::experimental::dextents< std::size_t, 2 > > &A, const std::experimental::mdspan< const double, std::experimental::dextents< std::size_t, 2 > > &B)
Definition: math.cpp:92
bool is_singular(const std::experimental::mdspan< const double, std::experimental::dextents< std::size_t, 2 > > &A)
Definition: math.cpp:130
std::vector< double > eye(std::size_t n)
Definition: math.cpp:186
std::pair< std::vector< typename U::value_type >, std::array< std::size_t, 2 > > outer(const U &u, const V &v)
Compute the outer product of vectors u and v.
Definition: math.h:39