Skip to content

Adds LU decomposition algorithm #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 4, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 48 additions & 40 deletions src/com/jwetherell/algorithms/data_structures/Matrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import java.util.Comparator;

/**
* Matrx. This Matrix implementation is designed to be more efficient
* Matrx. This Matrix implementation is designed to be more efficient
* in cache. A matrix is a rectangular array of numbers, symbols, or expressions.
*
* http://en.wikipedia.org/wiki/Matrix_(mathematics)
*
* <p>
* @see <a href="https://en.wikipedia.org/wiki/Matrix_(mathematics)">Matrix (Wikipedia)</a>
* <br>
* @author Justin Wetherell <[email protected]>
*/
@SuppressWarnings("unchecked")
Expand All @@ -29,27 +29,27 @@ public int compare(T o1, T o2) {
int result = 0;
if (o1 instanceof BigDecimal || o2 instanceof BigDecimal) {
BigDecimal c1 = (BigDecimal)o1;
BigDecimal c2 = (BigDecimal)o2;
BigDecimal c2 = (BigDecimal)o2;
result = c1.compareTo(c2);
} else if (o1 instanceof BigInteger || o2 instanceof BigInteger) {
BigInteger c1 = (BigInteger)o1;
BigInteger c2 = (BigInteger)o2;
BigInteger c2 = (BigInteger)o2;
result = c1.compareTo(c2);
} else if (o1 instanceof Long || o2 instanceof Long) {
Long c1 = o1.longValue();
Long c2 = o2.longValue();
Long c2 = o2.longValue();
result = c1.compareTo(c2);
} else if (o1 instanceof Double || o2 instanceof Double) {
Double c1 = o1.doubleValue();
Double c2 = o2.doubleValue();
result = c1.compareTo(c2);
Double c2 = o2.doubleValue();
result = c1.compareTo(c2);
} else if (o1 instanceof Float || o2 instanceof Float) {
Float c1 = o1.floatValue();
Float c2 = o2.floatValue();
result = c1.compareTo(c2);
Float c2 = o2.floatValue();
result = c1.compareTo(c2);
} else {
Integer c1 = o1.intValue();
Integer c2 = o2.intValue();
Integer c2 = o2.intValue();
result = c1.compareTo(c2);
}
return result;
Expand All @@ -58,7 +58,7 @@ public int compare(T o1, T o2) {

/**
* Matrix with 'rows' number of rows and 'cols' number of columns.
*
*
* @param rows Number of rows in Matrix.
* @param cols Number of columns in Matrix.
*/
Expand All @@ -71,7 +71,7 @@ public Matrix(int rows, int cols) {
/**
* Matrix with 'rows' number of rows and 'cols' number of columns, populates
* the double index matrix.
*
*
* @param rows Number of rows in Matrix.
* @param cols Number of columns in Matrix.
* @param matrix 2D matrix used to populate Matrix.
Expand Down Expand Up @@ -116,15 +116,15 @@ public void set(int row, int col, T value) {
}

public Matrix<T> identity() throws Exception{
if(this.rows != this.cols)
throw new Exception("Matrix should be a square");
if(this.rows != this.cols)
throw new Exception("Matrix should be a square");

final T element = this.get(0, 0);
final T zero;
final T one;
if (element instanceof BigDecimal) {
zero = (T)BigDecimal.ZERO;
one = (T)BigDecimal.ONE;
if (element instanceof BigDecimal) {
zero = (T)BigDecimal.ZERO;
one = (T)BigDecimal.ONE;
} else if(element instanceof BigInteger){
zero = (T)BigInteger.ZERO;
one = (T)BigInteger.ONE;
Expand All @@ -142,20 +142,20 @@ public Matrix<T> identity() throws Exception{
one = (T)new Integer(1);
}

final T array[][] = (T[][])new Number[this.rows][this.cols];
for(int i = 0; i < this.rows; ++i) {
for(int j = 0 ; j < this.cols; ++j){
array[i][j] = zero;
}
}

final Matrix<T> identityMatrix = new Matrix<T>(this.rows, this.cols, array);
for(int i = 0; i < this.rows;++i){
identityMatrix.set(i, i, one);
}
return identityMatrix;
final T array[][] = (T[][])new Number[this.rows][this.cols];
for(int i = 0; i < this.rows; ++i) {
for(int j = 0 ; j < this.cols; ++j){
array[i][j] = zero;
}
}

final Matrix<T> identityMatrix = new Matrix<T>(this.rows, this.cols, array);
for(int i = 0; i < this.rows;++i){
identityMatrix.set(i, i, one);
}
return identityMatrix;
}

public Matrix<T> add(Matrix<T> input) {
Matrix<T> output = new Matrix<T>(this.rows, this.cols);
if ((this.cols != input.cols) || (this.rows != input.rows))
Expand Down Expand Up @@ -249,7 +249,7 @@ public Matrix<T> multiply(Matrix<T> input) {
for (int i = 0; i < cols; i++) {
T m1 = row[i];
T m2 = column[i];

BigDecimal result2 = ((BigDecimal)m1).multiply(((BigDecimal)m2));
result = result.add(result2);
}
Expand All @@ -259,7 +259,7 @@ public Matrix<T> multiply(Matrix<T> input) {
for (int i = 0; i < cols; i++) {
T m1 = row[i];
T m2 = column[i];

BigInteger result2 = ((BigInteger)m1).multiply(((BigInteger)m2));
result = result.add(result2);
}
Expand All @@ -269,7 +269,7 @@ public Matrix<T> multiply(Matrix<T> input) {
for (int i = 0; i < cols; i++) {
T m1 = row[i];
T m2 = column[i];

Long result2 = m1.longValue() * m2.longValue();
result = result+result2;
}
Expand All @@ -279,7 +279,7 @@ public Matrix<T> multiply(Matrix<T> input) {
for (int i = 0; i < cols; i++) {
T m1 = row[i];
T m2 = column[i];

Double result2 = m1.doubleValue() * m2.doubleValue();
result = result+result2;
}
Expand All @@ -289,7 +289,7 @@ public Matrix<T> multiply(Matrix<T> input) {
for (int i = 0; i < cols; i++) {
T m1 = row[i];
T m2 = column[i];

Float result2 = m1.floatValue() * m2.floatValue();
result = result+result2;
}
Expand All @@ -300,7 +300,7 @@ public Matrix<T> multiply(Matrix<T> input) {
for (int i = 0; i < cols; i++) {
T m1 = row[i];
T m2 = column[i];

Integer result2 = m1.intValue() * m2.intValue();
result = result+result2;
}
Expand Down Expand Up @@ -348,7 +348,7 @@ public boolean equals(Object obj) {
for (int i=0; i<matrix.length; i++) {
T t1 = matrix[i];
T t2 = m.matrix[i];
int result = comparator.compare(t1, t2);
int result = comparator.compare(t1, t2);
if (result!=0)
return false;
}
Expand All @@ -371,4 +371,12 @@ public String toString() {
}
return builder.toString();
}
}

public int getRows() {
return rows;
}

public int getCols() {
return cols;
}
}
95 changes: 95 additions & 0 deletions src/com/jwetherell/algorithms/mathematics/LUDecomposition.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package com.jwetherell.algorithms.mathematics;

import com.jwetherell.algorithms.data_structures.Matrix;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
* LU decomposition of matrix M produces 2 matrices L and U such that M = L*U
* where L is lower triangular matrix and U is upper triangular matrix
* <p>
* https://en.wikipedia.org/wiki/LU_decomposition
* <br>
* @author Mateusz Cianciara <[email protected]>
*/
public class LUDecomposition {
private Double[][] L = null;
private Double[][] A = null;
private Integer[] permutation = null;
private int n = 0;

public Matrix<Double> getL() {
return new Matrix<Double>(n, n, L);
}

public Matrix<Double> getU() {
return new Matrix<Double>(n, n, A);
}

public List<Integer> getPermutation() {
return new ArrayList<Integer>(Arrays.asList(permutation));
}

public LUDecomposition(Matrix<Double> input) {
if (input.getCols() != input.getRows()) {
throw new IllegalArgumentException("Matrix is not square");
}
n = input.getCols();
L = new Double[n][n];
A = new Double[n][n];
permutation = new Integer[n];

for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
L[i][j] = 0.0;
A[i][j] = input.get(i, j);
}
}
for (int i = 0; i < n; i++) {
L[i][i] = 1.0;
permutation[i] = i;
}
for (int row = 0; row < n; row++) {
// find max in column
int max_in_col = row;
double curr_big = Math.abs(A[row][row]);
for (int k = row + 1; k < n; k++) {
if (curr_big < Math.abs(A[k][row])) {
max_in_col = k;
curr_big = Math.abs(A[k][row]);
}
}

//swap rows
if (row != max_in_col) {
for (int i = 0; i < n; i++) {
double temp = A[row][i];
A[row][i] = A[max_in_col][i];
A[max_in_col][i] = temp;
if (i < row) {
temp = L[row][i];
L[row][i] = L[max_in_col][i];
L[max_in_col][i] = temp;
}
}
int temp = permutation[row];
permutation[row] = permutation[max_in_col];
permutation[max_in_col] = temp;
}
//zero column number row
double p = A[row][row];
if (p == 0) return;

for (int i = row + 1; i < n; i++) {
double y = A[i][row];
L[i][row] = y / p;

for (int j = row; j < n; j++) {
A[i][j] -= A[row][j] * (y / p);
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.jwetherell.algorithms.mathematics.test;

import com.jwetherell.algorithms.data_structures.Matrix;
import com.jwetherell.algorithms.mathematics.LUDecomposition;
import org.junit.Test;

import static org.junit.Assert.*;


public class LUDecompositionTest {
private boolean epsiMatrixCompare(Matrix<Double> a, Matrix<Double> b, double epsi) {
if (a.getRows() != b.getRows() || a.getCols() != b.getCols()) {
throw new IllegalArgumentException("Matrices are not the same shape");
}
for (int i = 0; i < a.getRows(); i++) {
for (int j = 0; j < a.getCols(); j++) {
if (Math.abs(a.get(i, j) - b.get(i, j)) > epsi) {
return false;
}
}
}
return true;
}

@Test
public void decompositionTest1() throws Exception {
Double[][] m = new Double[][]{{4.0, 3.0}, {6.0, 3.0}};
Double[][] resultL = new Double[][]{{1.0, 0.0}, {2.0 / 3.0, 1.0}};
Double[][] resultU = new Double[][]{{6.0, 3.0}, {0.0, 1.0}};

LUDecomposition luDecomposition = new LUDecomposition(new Matrix<Double>(2, 2, m));
assertTrue(epsiMatrixCompare(luDecomposition.getL(), new Matrix<Double>(2, 2, resultL), 10e-4));
assertTrue(epsiMatrixCompare(luDecomposition.getU(), new Matrix<Double>(2, 2, resultU), 10e-4));
}

@Test
public void decompositionTest2() throws Exception {
Double[][] m = new Double[][]{{5.0, 3.0, 2.0}, {1.0, 2.0, 0.0}, {3.0, 0.0, 4.0}};
Double[][] resultL = new Double[][]{{1.0, 0.0, 0.0}, {0.6, 1.0, 0.0}, {0.2, -0.7778, 1.0}};
Double[][] resultU = new Double[][]{{5.0, 3.0, 2.0}, {0.0, -1.8, 2.8}, {0.0, 0.0, 1.778}};

LUDecomposition luDecomposition = new LUDecomposition(new Matrix<Double>(3, 3, m));
assertTrue(epsiMatrixCompare(luDecomposition.getL(), new Matrix<Double>(3, 3, resultL), 10e-4));
assertTrue(epsiMatrixCompare(luDecomposition.getU(), new Matrix<Double>(3, 3, resultU), 10e-4));
}
}