Difference between revisions of "Example Polynomial Fitting"
Jump to navigation
Jump to search
Line 4: | Line 4: | ||
External Resources: | External Resources: | ||
− | * [https://github.com/lessthanoptimal/ejml/blob/v0. | + | * [https://github.com/lessthanoptimal/ejml/blob/v0.41/examples/src/org/ejml/example/PolynomialFit.java PolynomialFit.java source code] |
* <disqus>Discuss this example</disqus> | * <disqus>Discuss this example</disqus> | ||
Line 24: | Line 24: | ||
* <li>reshaping</li> | * <li>reshaping</li> | ||
* </ol> | * </ol> | ||
+ | * | ||
* @author Peter Abeles | * @author Peter Abeles | ||
*/ | */ | ||
public class PolynomialFit { | public class PolynomialFit { | ||
− | |||
// Vandermonde matrix | // Vandermonde matrix | ||
DMatrixRMaj A; | DMatrixRMaj A; | ||
Line 44: | Line 44: | ||
*/ | */ | ||
public PolynomialFit( int degree ) { | public PolynomialFit( int degree ) { | ||
− | coef = new DMatrixRMaj(degree+1,1); | + | coef = new DMatrixRMaj(degree + 1, 1); |
− | A = new DMatrixRMaj(1,degree+1); | + | A = new DMatrixRMaj(1, degree + 1); |
− | y = new DMatrixRMaj(1,1); | + | y = new DMatrixRMaj(1, 1); |
// create a solver that allows elements to be added or removed efficiently | // create a solver that allows elements to be added or removed efficiently | ||
Line 67: | Line 67: | ||
* @param observations A set of observations. | * @param observations A set of observations. | ||
*/ | */ | ||
− | public void fit( double | + | public void fit( double[] samplePoints, double[] observations ) { |
// Create a copy of the observations and put it into a matrix | // Create a copy of the observations and put it into a matrix | ||
− | y.reshape(observations.length,1,false); | + | y.reshape(observations.length, 1, false); |
− | System.arraycopy(observations,0, y.data,0,observations.length); | + | System.arraycopy(observations, 0, y.data, 0, observations.length); |
// reshape the matrix to avoid unnecessarily declaring new memory | // reshape the matrix to avoid unnecessarily declaring new memory | ||
// save values is set to false since its old values don't matter | // save values is set to false since its old values don't matter | ||
− | A.reshape(y.numRows, coef.numRows,false); | + | A.reshape(y.numRows, coef.numRows, false); |
// set up the A matrix | // set up the A matrix | ||
− | for( int i = 0; i < observations.length; i++ ) { | + | for (int i = 0; i < observations.length; i++) { |
double obs = 1; | double obs = 1; | ||
− | for( int j = 0; j < coef.numRows; j++ ) { | + | for (int j = 0; j < coef.numRows; j++) { |
− | A.set(i,j,obs); | + | A.set(i, j, obs); |
obs *= samplePoints[i]; | obs *= samplePoints[i]; | ||
} | } | ||
Line 88: | Line 88: | ||
// process the A matrix and see if it failed | // process the A matrix and see if it failed | ||
− | if( !solver.setA(A) ) | + | if (!solver.setA(A)) |
throw new RuntimeException("Solver failed"); | throw new RuntimeException("Solver failed"); | ||
// solver the the coefficients | // solver the the coefficients | ||
− | solver.solve(y,coef); | + | solver.solve(y, coef); |
} | } | ||
Line 103: | Line 103: | ||
public void removeWorstFit() { | public void removeWorstFit() { | ||
// find the observation with the most error | // find the observation with the most error | ||
− | int worstIndex=-1; | + | int worstIndex = -1; |
double worstError = -1; | double worstError = -1; | ||
− | for( int i = 0; i < y.numRows; i++ ) { | + | for (int i = 0; i < y.numRows; i++) { |
double predictedObs = 0; | double predictedObs = 0; | ||
− | for( int j = 0; j < coef.numRows; j++ ) { | + | for (int j = 0; j < coef.numRows; j++) { |
− | predictedObs += A.get(i,j)*coef.get(j,0); | + | predictedObs += A.get(i, j)*coef.get(j, 0); |
} | } | ||
− | double error = Math.abs(predictedObs- y.get(i,0)); | + | double error = Math.abs(predictedObs - y.get(i, 0)); |
− | if( error > worstError ) { | + | if (error > worstError) { |
worstError = error; | worstError = error; | ||
worstIndex = i; | worstIndex = i; | ||
Line 122: | Line 122: | ||
// nothing left to remove, so just return | // nothing left to remove, so just return | ||
− | if( worstIndex == -1 ) | + | if (worstIndex == -1) |
return; | return; | ||
Line 132: | Line 132: | ||
// solve for the parameters again | // solve for the parameters again | ||
− | solver.solve(y,coef); | + | solver.solve(y, coef); |
} | } | ||
Line 141: | Line 141: | ||
*/ | */ | ||
private void removeObservation( int index ) { | private void removeObservation( int index ) { | ||
− | final int N = y.numRows-1; | + | final int N = y.numRows - 1; |
− | final double | + | final double[] d = y.data; |
// shift | // shift | ||
− | for( int i = index; i < N; i++ ) { | + | for (int i = index; i < N; i++) { |
− | d[i] = d[i+1]; | + | d[i] = d[i + 1]; |
} | } | ||
y.numRows--; | y.numRows--; |
Revision as of 07:29, 7 July 2021
In this example it is shown how EJML can be used to fit a polynomial of arbitrary degree to a set of data. The key concepts shown here are; 1) how to create a linear using LinearSolverFactory, 2) use an adjustable linear solver, 3) and effective matrix reshaping. This is all done using the procedural interface.
First a best fit polynomial is fit to a set of data and then a outliers are removed from the observation set and the coefficients recomputed. Outliers are removed efficiently using an adjustable solver that does not resolve the whole system again.
External Resources:
- PolynomialFit.java source code
- <disqus>Discuss this example</disqus>
PolynomialFit Example Code
/**
* <p>
* This example demonstrates how a polynomial can be fit to a set of data. This is done by
* using a least squares solver that is adjustable. By using an adjustable solver elements
* can be inexpensively removed and the coefficients recomputed. This is much less expensive
* than resolving the whole system from scratch.
* </p>
* <p>
* The following is demonstrated:<br>
* <ol>
* <li>Creating a solver using LinearSolverFactory</li>
* <li>Using an adjustable solver</li>
* <li>reshaping</li>
* </ol>
*
* @author Peter Abeles
*/
public class PolynomialFit {
// Vandermonde matrix
DMatrixRMaj A;
// matrix containing computed polynomial coefficients
DMatrixRMaj coef;
// observation matrix
DMatrixRMaj y;
// solver used to compute
AdjustableLinearSolver_DDRM solver;
/**
* Constructor.
*
* @param degree The polynomial's degree which is to be fit to the observations.
*/
public PolynomialFit( int degree ) {
coef = new DMatrixRMaj(degree + 1, 1);
A = new DMatrixRMaj(1, degree + 1);
y = new DMatrixRMaj(1, 1);
// create a solver that allows elements to be added or removed efficiently
solver = LinearSolverFactory_DDRM.adjustable();
}
/**
* Returns the computed coefficients
*
* @return polynomial coefficients that best fit the data.
*/
public double[] getCoef() {
return coef.data;
}
/**
* Computes the best fit set of polynomial coefficients to the provided observations.
*
* @param samplePoints where the observations were sampled.
* @param observations A set of observations.
*/
public void fit( double[] samplePoints, double[] observations ) {
// Create a copy of the observations and put it into a matrix
y.reshape(observations.length, 1, false);
System.arraycopy(observations, 0, y.data, 0, observations.length);
// reshape the matrix to avoid unnecessarily declaring new memory
// save values is set to false since its old values don't matter
A.reshape(y.numRows, coef.numRows, false);
// set up the A matrix
for (int i = 0; i < observations.length; i++) {
double obs = 1;
for (int j = 0; j < coef.numRows; j++) {
A.set(i, j, obs);
obs *= samplePoints[i];
}
}
// process the A matrix and see if it failed
if (!solver.setA(A))
throw new RuntimeException("Solver failed");
// solver the the coefficients
solver.solve(y, coef);
}
/**
* Removes the observation that fits the model the worst and recomputes the coefficients.
* This is done efficiently by using an adjustable solver. Often times the elements with
* the largest errors are outliers and not part of the system being modeled. By removing them
* a more accurate set of coefficients can be computed.
*/
public void removeWorstFit() {
// find the observation with the most error
int worstIndex = -1;
double worstError = -1;
for (int i = 0; i < y.numRows; i++) {
double predictedObs = 0;
for (int j = 0; j < coef.numRows; j++) {
predictedObs += A.get(i, j)*coef.get(j, 0);
}
double error = Math.abs(predictedObs - y.get(i, 0));
if (error > worstError) {
worstError = error;
worstIndex = i;
}
}
// nothing left to remove, so just return
if (worstIndex == -1)
return;
// remove that observation
removeObservation(worstIndex);
// update A
solver.removeRowFromA(worstIndex);
// solve for the parameters again
solver.solve(y, coef);
}
/**
* Removes an element from the observation matrix.
*
* @param index which element is to be removed
*/
private void removeObservation( int index ) {
final int N = y.numRows - 1;
final double[] d = y.data;
// shift
for (int i = index; i < N; i++) {
d[i] = d[i + 1];
}
y.numRows--;
}
}