Skip to content

Commit 4af5215

Browse files
iradonovIvan Radonov
and
Ivan Radonov
authored
added Schur complement to linear algebra (TheAlgorithms#4793)
* added schur complement and tests to linear algebra * updated according to checklist * updated variable names and typing * added two testcases for input validation * fixed import order Co-authored-by: Ivan Radonov <ivan.radonov@ad.mentormate.bg>
1 parent fa88559 commit 4af5215

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed
+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import unittest
2+
3+
import numpy as np
4+
5+
6+
def schur_complement(
7+
mat_a: np.ndarray,
8+
mat_b: np.ndarray,
9+
mat_c: np.ndarray,
10+
pseudo_inv: np.ndarray = None,
11+
) -> np.ndarray:
12+
"""
13+
Schur complement of a symmetric matrix X given as a 2x2 block matrix
14+
consisting of matrices A, B and C.
15+
Matrix A must be quadratic and non-singular.
16+
In case A is singular, a pseudo-inverse may be provided using
17+
the pseudo_inv argument.
18+
19+
Link to Wiki: https://en.wikipedia.org/wiki/Schur_complement
20+
See also Convex Optimization – Boyd and Vandenberghe, A.5.5
21+
>>> import numpy as np
22+
>>> a = np.array([[1, 2], [2, 1]])
23+
>>> b = np.array([[0, 3], [3, 0]])
24+
>>> c = np.array([[2, 1], [6, 3]])
25+
>>> schur_complement(a, b, c)
26+
array([[ 5., -5.],
27+
[ 0., 6.]])
28+
"""
29+
shape_a = np.shape(mat_a)
30+
shape_b = np.shape(mat_b)
31+
shape_c = np.shape(mat_c)
32+
33+
if shape_a[0] != shape_b[0]:
34+
raise ValueError(
35+
f"Expected the same number of rows for A and B. \
36+
Instead found A of size {shape_a} and B of size {shape_b}"
37+
)
38+
39+
if shape_b[1] != shape_c[1]:
40+
raise ValueError(
41+
f"Expected the same number of columns for B and C. \
42+
Instead found B of size {shape_b} and C of size {shape_c}"
43+
)
44+
45+
a_inv = pseudo_inv
46+
if a_inv is None:
47+
try:
48+
a_inv = np.linalg.inv(mat_a)
49+
except np.linalg.LinAlgError:
50+
raise ValueError(
51+
"Input matrix A is not invertible. Cannot compute Schur complement."
52+
)
53+
54+
return mat_c - mat_b.T @ a_inv @ mat_b
55+
56+
57+
class TestSchurComplement(unittest.TestCase):
58+
def test_schur_complement(self) -> None:
59+
a = np.array([[1, 2, 1], [2, 1, 2], [3, 2, 4]])
60+
b = np.array([[0, 3], [3, 0], [2, 3]])
61+
c = np.array([[2, 1], [6, 3]])
62+
63+
s = schur_complement(a, b, c)
64+
65+
input_matrix = np.block([[a, b], [b.T, c]])
66+
67+
det_x = np.linalg.det(input_matrix)
68+
det_a = np.linalg.det(a)
69+
det_s = np.linalg.det(s)
70+
71+
self.assertAlmostEqual(det_x, det_a * det_s)
72+
73+
def test_improper_a_b_dimensions(self) -> None:
74+
a = np.array([[1, 2, 1], [2, 1, 2], [3, 2, 4]])
75+
b = np.array([[0, 3], [3, 0], [2, 3]])
76+
c = np.array([[2, 1], [6, 3]])
77+
78+
with self.assertRaises(ValueError):
79+
schur_complement(a, b, c)
80+
81+
def test_improper_b_c_dimensions(self) -> None:
82+
a = np.array([[1, 2, 1], [2, 1, 2], [3, 2, 4]])
83+
b = np.array([[0, 3], [3, 0], [2, 3]])
84+
c = np.array([[2, 1, 3], [6, 3, 5]])
85+
86+
with self.assertRaises(ValueError):
87+
schur_complement(a, b, c)
88+
89+
90+
if __name__ == "__main__":
91+
import doctest
92+
93+
doctest.testmod()
94+
unittest.main()

0 commit comments

Comments
 (0)