Example Polynomial Fitting
Jump to navigation
Jump to search
The printable version is no longer supported and may have rendering errors. Please update your browser bookmarks and please use the default browser print function instead.
This example shows how EJML can be used to fit a polynomial of arbitrary degree to a set of data. The key concepts shown here are: 1) how to create a linear solver using LinearSolverFactory, 2) how to use an adjustable linear solver, and 3) effective matrix reshaping. This is all done using the procedural interface.
First a best-fit polynomial is fit to a set of data, then outliers are removed from the observation set and the coefficients are recomputed. Outliers are removed efficiently using an adjustable solver that does not re-solve the whole system from scratch.
External Resources:
- PolynomialFit.java source code
- <disqus>Discuss this example</disqus>
PolynomialFit Example Code
/**
 * <p>
 * This example demonstrates how a polynomial can be fit to a set of data. This is done by
 * using a least squares solver that is adjustable. By using an adjustable solver elements
 * can be inexpensively removed and the coefficients recomputed. This is much less expensive
 * than re-solving the whole system from scratch.
 * </p>
 * <p>
 * The following is demonstrated:<br>
 * <ol>
 * <li>Creating a solver using LinearSolverFactory</li>
 * <li>Using an adjustable solver</li>
 * <li>reshaping</li>
 * </ol>
 * @author Peter Abeles
 */
public class PolynomialFit {
    // Vandermonde matrix: row i holds 1, x_i, x_i^2, ... for sample point x_i
    DMatrixRMaj A;
    // matrix containing computed polynomial coefficients (one row per power, lowest power first)
    DMatrixRMaj coef;
    // observation matrix (column vector, one row per observation)
    DMatrixRMaj y;

    // solver used to compute the coefficients; rows can be removed from it after setA()
    AdjustableLinearSolver_DDRM solver;

    /**
     * Constructor.
     *
     * @param degree The polynomial's degree which is to be fit to the observations.
     */
    public PolynomialFit( int degree ) {
        coef = new DMatrixRMaj(degree + 1, 1);
        // A and y start out as one-row placeholders; fit() reshapes them to the data size
        A = new DMatrixRMaj(1, degree + 1);
        y = new DMatrixRMaj(1, 1);

        // create a solver that allows elements to be added or removed efficiently
        solver = LinearSolverFactory_DDRM.adjustable();
    }

    /**
     * Returns the computed coefficients.
     *
     * <p>The returned array is the coefficient matrix's own backing array, not a copy.
     * Element {@code j} is the coefficient that multiplies {@code x^j}.</p>
     *
     * @return polynomial coefficients that best fit the data.
     */
    public double[] getCoef() {
        return coef.data;
    }

    /**
     * Computes the best fit set of polynomial coefficients to the provided observations.
     *
     * @param samplePoints where the observations were sampled.
     * @param observations A set of observations, one per sample point.
     * @throws IllegalArgumentException if the two arrays do not have the same length.
     * @throws RuntimeException if the solver rejects the Vandermonde matrix.
     */
    public void fit( double[] samplePoints, double[] observations ) {
        if( samplePoints.length != observations.length ) {
            throw new IllegalArgumentException(
                    "samplePoints and observations must have the same length: "
                            + samplePoints.length + " != " + observations.length);
        }

        // Create a copy of the observations and put it into a matrix
        y.reshape(observations.length, 1, false);
        System.arraycopy(observations, 0, y.data, 0, observations.length);

        // reshape the matrix to avoid unnecessarily declaring new memory
        // save values is set to false since its old values don't matter
        A.reshape(y.numRows, coef.numRows, false);

        // set up the A matrix: column j of row i is samplePoints[i] raised to the power j
        for( int i = 0; i < observations.length; i++ ) {
            double power = 1;
            for( int j = 0; j < coef.numRows; j++ ) {
                A.set(i, j, power);
                power *= samplePoints[i];
            }
        }

        // process the A matrix and see if it failed
        if( !solver.setA(A) ) {
            throw new RuntimeException("Solver failed");
        }

        // solve for the coefficients
        solver.solve(y, coef);
    }

    /**
     * Removes the observation that fits the model the worst and recomputes the coefficients.
     * This is done efficiently by using an adjustable solver. Often times the elements with
     * the largest errors are outliers and not part of the system being modeled. By removing them
     * a more accurate set of coefficients can be computed.
     *
     * <p>Nothing is removed when the number of observations is already no larger than the
     * number of coefficients, since dropping another row would leave fewer equations than
     * unknowns.</p>
     *
     * @throws RuntimeException if the solver fails to remove the selected row.
     */
    public void removeWorstFit() {
        // keep at least as many observations as there are coefficients to solve for
        if( y.numRows <= coef.numRows ) {
            return;
        }

        // find the observation with the most error
        int worstIndex = -1;
        double worstError = -1;

        for( int i = 0; i < y.numRows; i++ ) {
            double predictedObs = 0;
            for( int j = 0; j < coef.numRows; j++ ) {
                predictedObs += A.get(i, j) * coef.get(j, 0);
            }

            double error = Math.abs(predictedObs - y.get(i, 0));
            if( error > worstError ) {
                worstError = error;
                worstIndex = i;
            }
        }

        // no row compared greater than the initial error (e.g. every residual was NaN)
        if( worstIndex == -1 ) {
            return;
        }

        // update the solver first so that a failure leaves y and A untouched
        if( !solver.removeRowFromA(worstIndex) ) {
            throw new RuntimeException("Solver failed");
        }

        // remove that observation from this class's own y and A
        removeObservation(worstIndex);

        // solve for the parameters again
        solver.solve(y, coef);
    }

    /**
     * Removes an observation: drops the element from the observation matrix and the
     * matching row from the Vandermonde matrix, so that row i of A keeps describing
     * element i of y on later calls to {@link #removeWorstFit()}.
     *
     * @param index which element is to be removed
     */
    private void removeObservation( int index ) {
        // number of rows that sit after the removed one and must move up by one
        final int trailingRows = y.numRows - 1 - index;

        // y is a column vector, so element i is stored at data[i]
        System.arraycopy(y.data, index + 1, y.data, index, trailingRows);
        y.numRows--;

        // A is row-major: row i starts at data[i*numCols]
        final int cols = A.numCols;
        System.arraycopy(A.data, (index + 1) * cols, A.data, index * cols, trailingRows * cols);
        A.numRows--;
    }
}