Example Polynomial Fitting
Jump to navigation
Jump to search
In this example it is shown how EJML can be used to fit a polynomial of arbitrary degree to a set of data. The key concepts shown here are: 1) how to create a linear solver using LinearSolverFactory, 2) how to use an adjustable linear solver, and 3) how to reshape matrices effectively. This is all done using the procedural interface.
First, a best-fit polynomial is fit to a set of data; then outliers are removed from the observation set and the coefficients are recomputed. Outliers are removed efficiently using an adjustable solver that does not re-solve the whole system from scratch.
External Resources:
- PolynomialFit.java source code
- <disqus>Discuss this example</disqus>
PolynomialFit Example Code
/**
 * <p>
 * This example demonstrates how a polynomial can be fit to a set of data. This is done by
 * using a least squares solver that is adjustable. By using an adjustable solver, elements
 * can be inexpensively removed and the coefficients recomputed. This is much less expensive
 * than re-solving the whole system from scratch.
 * </p>
 * <p>
 * The following is demonstrated:
 * </p>
 * <ol>
 * <li>Creating a solver using LinearSolverFactory</li>
 * <li>Using an adjustable solver</li>
 * <li>Reshaping</li>
 * </ol>
 *
 * @author Peter Abeles
 */
public class PolynomialFit {
    // Vandermonde matrix: row i holds [1, x_i, x_i^2, ..., x_i^degree]
    DMatrixRMaj A;
    // matrix containing computed polynomial coefficients; element j multiplies x^j
    DMatrixRMaj coef;
    // observation matrix (a column vector with one row per observation)
    DMatrixRMaj y;
    // solver used to compute the least-squares coefficients
    AdjustableLinearSolver_DDRM solver;

    /**
     * Constructor.
     *
     * @param degree The polynomial's degree which is to be fit to the observations.
     * @throws IllegalArgumentException if {@code degree} is negative
     */
    public PolynomialFit( int degree ) {
        if (degree < 0) {
            throw new IllegalArgumentException("degree must be non-negative, got " + degree);
        }
        coef = new DMatrixRMaj(degree + 1, 1);
        A = new DMatrixRMaj(1, degree + 1);
        y = new DMatrixRMaj(1, 1);

        // create a solver that allows elements to be added or removed efficiently
        solver = LinearSolverFactory_DDRM.adjustable();
    }

    /**
     * Returns the computed coefficients, lowest order first: element j multiplies x^j.
     *
     * <p>The returned array is the internal storage of the coefficient matrix, not a copy.</p>
     *
     * @return polynomial coefficients that best fit the data.
     */
    public double[] getCoef() {
        return coef.data;
    }

    /**
     * Computes the best fit set of polynomial coefficients to the provided observations.
     *
     * @param samplePoints where the observations were sampled.
     * @param observations A set of observations, one per sample point.
     * @throws IllegalArgumentException if the two arrays have different lengths
     * @throws RuntimeException if the solver fails to process the system
     */
    public void fit( double[] samplePoints, double[] observations ) {
        if (samplePoints.length != observations.length) {
            throw new IllegalArgumentException("samplePoints and observations must have the same length: "
                    + samplePoints.length + " != " + observations.length);
        }

        // Create a copy of the observations and put it into a matrix
        y.reshape(observations.length, 1, false);
        System.arraycopy(observations, 0, y.data, 0, observations.length);

        // reshape the matrix to avoid unnecessarily declaring new memory
        // save values is set to false since its old values don't matter
        A.reshape(y.numRows, coef.numRows, false);

        // set up the Vandermonde matrix: A[i][j] = samplePoints[i]^j
        for (int i = 0; i < observations.length; i++) {
            double power = 1;
            for (int j = 0; j < coef.numRows; j++) {
                A.set(i, j, power);
                power *= samplePoints[i];
            }
        }

        // process the A matrix and see if it failed
        if (!solver.setA(A)) {
            throw new RuntimeException("Solver failed");
        }

        // solve for the coefficients
        solver.solve(y, coef);
    }

    /**
     * Removes the observation that fits the model the worst and recomputes the coefficients.
     * This is done efficiently by using an adjustable solver. Often times the elements with
     * the largest errors are outliers and not part of the system being modeled. By removing them
     * a more accurate set of coefficients can be computed.
     *
     * <p>Nothing is removed if doing so would leave fewer observations than coefficients,
     * since the least-squares system would then be under-determined.</p>
     *
     * @throws RuntimeException if the solver fails to remove the row
     */
    public void removeWorstFit() {
        // removing a row from a square (or smaller) system leaves more unknowns than equations
        if (y.numRows <= coef.numRows) {
            return;
        }

        int worstIndex = findWorstIndex();

        // no observation could be selected (e.g. every residual is NaN), so just return
        if (worstIndex == -1) {
            return;
        }

        // remove that observation from y and its row from A so both stay aligned
        removeObservation(worstIndex);

        // update the solver's decomposition of A
        if (!solver.removeRowFromA(worstIndex)) {
            throw new RuntimeException("Solver failed");
        }

        // solve for the parameters again
        solver.solve(y, coef);
    }

    /**
     * Finds the observation with the largest absolute residual under the current coefficients.
     *
     * @return index of the worst-fitting observation, or -1 if none was found
     */
    private int findWorstIndex() {
        int worstIndex = -1;
        double worstError = -1;
        for (int i = 0; i < y.numRows; i++) {
            double predictedObs = 0;
            for (int j = 0; j < coef.numRows; j++) {
                predictedObs += A.get(i, j)*coef.get(j, 0);
            }
            double error = Math.abs(predictedObs - y.get(i, 0));
            if (error > worstError) {
                worstError = error;
                worstIndex = i;
            }
        }
        return worstIndex;
    }

    /**
     * Removes an observation: its element in the observation matrix and the matching row of A.
     *
     * @param index which element is to be removed
     */
    private void removeObservation( int index ) {
        removeRow(y, index);
        removeRow(A, index);
    }

    /**
     * Removes row {@code index} from a row-major matrix in place by shifting every following
     * row up by one and shrinking the row count.
     *
     * @param M matrix to modify
     * @param index row to remove
     */
    private static void removeRow( DMatrixRMaj M, int index ) {
        final int cols = M.numCols;
        final int rowsAfter = M.numRows - index - 1;
        System.arraycopy(M.data, (index + 1)*cols, M.data, index*cols, rowsAfter*cols);
        M.reshape(M.numRows - 1, cols, true);
    }
}