Difference between revisions of "Example Polynomial Fitting"
		
		
		
		
		
		Jump to navigation
		Jump to search
		
				
		
		
	
| (One intermediate revision by the same user not shown) | |||
| Line 4: | Line 4: | ||
| External Resources: | External Resources: | ||
| − | * [https://github.com/lessthanoptimal/ejml/blob/v0. | + | * [https://github.com/lessthanoptimal/ejml/blob/v0.41/examples/src/org/ejml/example/PolynomialFit.java PolynomialFit.java source code] | 
| − | |||
| = PolynomialFit Example Code = | = PolynomialFit Example Code = | ||
| Line 24: | Line 23: | ||
|   *  <li>reshaping</li> |   *  <li>reshaping</li> | ||
|   * </ol> |   * </ol> | ||
| + |  * | ||
|   * @author Peter Abeles |   * @author Peter Abeles | ||
|   */ |   */ | ||
| public class PolynomialFit { | public class PolynomialFit { | ||
| − | |||
|      // Vandermonde matrix |      // Vandermonde matrix | ||
|      DMatrixRMaj A; |      DMatrixRMaj A; | ||
| Line 44: | Line 43: | ||
|       */ |       */ | ||
|      public PolynomialFit( int degree ) { |      public PolynomialFit( int degree ) { | ||
| − |          coef = new DMatrixRMaj(degree+1,1); | + |          coef = new DMatrixRMaj(degree + 1, 1); | 
| − |          A = new DMatrixRMaj(1,degree+1); | + |          A = new DMatrixRMaj(1, degree + 1); | 
| − |          y = new DMatrixRMaj(1,1); | + |          y = new DMatrixRMaj(1, 1); | 
|          // create a solver that allows elements to be added or removed efficiently |          // create a solver that allows elements to be added or removed efficiently | ||
| Line 67: | Line 66: | ||
|       * @param observations A set of observations. |       * @param observations A set of observations. | ||
|       */ |       */ | ||
| − |      public void fit( double  | + |      public void fit( double[] samplePoints, double[] observations ) { | 
|          // Create a copy of the observations and put it into a matrix |          // Create a copy of the observations and put it into a matrix | ||
| − |          y.reshape(observations.length,1,false); | + |          y.reshape(observations.length, 1, false); | 
| − |          System.arraycopy(observations,0, y.data,0,observations.length); | + |          System.arraycopy(observations, 0, y.data, 0, observations.length); | 
|          // reshape the matrix to avoid unnecessarily declaring new memory |          // reshape the matrix to avoid unnecessarily declaring new memory | ||
|          // save values is set to false since its old values don't matter |          // save values is set to false since its old values don't matter | ||
| − |          A.reshape(y.numRows, coef.numRows,false); | + |          A.reshape(y.numRows, coef.numRows, false); | 
|          // set up the A matrix |          // set up the A matrix | ||
| − |          for( int i = 0; i < observations.length; i++ ) { | + |          for (int i = 0; i < observations.length; i++) { | 
|              double obs = 1; |              double obs = 1; | ||
| − |              for( int j = 0; j < coef.numRows; j++ ) { | + |              for (int j = 0; j < coef.numRows; j++) { | 
| − |                  A.set(i,j,obs); | + |                  A.set(i, j, obs); | 
|                  obs *= samplePoints[i]; |                  obs *= samplePoints[i]; | ||
|              } |              } | ||
| Line 88: | Line 87: | ||
|          // process the A matrix and see if it failed |          // process the A matrix and see if it failed | ||
| − |          if( !solver.setA(A) ) | + |          if (!solver.setA(A)) | 
|              throw new RuntimeException("Solver failed"); |              throw new RuntimeException("Solver failed"); | ||
|          // solver the the coefficients |          // solver the the coefficients | ||
| − |          solver.solve(y,coef); | + |          solver.solve(y, coef); | 
|      } |      } | ||
| Line 103: | Line 102: | ||
|      public void removeWorstFit() { |      public void removeWorstFit() { | ||
|          // find the observation with the most error |          // find the observation with the most error | ||
| − |          int worstIndex=-1; | + |          int worstIndex = -1; | 
|          double worstError = -1; |          double worstError = -1; | ||
| − |          for( int i = 0; i < y.numRows; i++ ) { | + |          for (int i = 0; i < y.numRows; i++) { | 
|              double predictedObs = 0; |              double predictedObs = 0; | ||
| − |              for( int j = 0; j < coef.numRows; j++ ) { | + |              for (int j = 0; j < coef.numRows; j++) { | 
| − |                  predictedObs += A.get(i,j)*coef.get(j,0); | + |                  predictedObs += A.get(i, j)*coef.get(j, 0); | 
|              } |              } | ||
| − |              double error = Math.abs(predictedObs- y.get(i,0)); | + |              double error = Math.abs(predictedObs - y.get(i, 0)); | 
| − |              if( error > worstError ) { | + |              if (error > worstError) { | 
|                  worstError = error; |                  worstError = error; | ||
|                  worstIndex = i; |                  worstIndex = i; | ||
| Line 122: | Line 121: | ||
|          // nothing left to remove, so just return |          // nothing left to remove, so just return | ||
| − |          if( worstIndex == -1 ) | + |          if (worstIndex == -1) | 
|              return; |              return; | ||
| Line 132: | Line 131: | ||
|          // solve for the parameters again |          // solve for the parameters again | ||
| − |          solver.solve(y,coef); | + |          solver.solve(y, coef); | 
|      } |      } | ||
| Line 141: | Line 140: | ||
|       */ |       */ | ||
|      private void removeObservation( int index ) { |      private void removeObservation( int index ) { | ||
| − |          final int N = y.numRows-1; | + |          final int N = y.numRows - 1; | 
| − |          final double  | + |          final double[] d = y.data; | 
|          // shift |          // shift | ||
| − |          for( int i = index; i < N; i++ ) { | + |          for (int i = index; i < N; i++) { | 
| − |              d[i] = d[i+1]; | + |              d[i] = d[i + 1]; | 
|          } |          } | ||
|          y.numRows--; |          y.numRows--; | ||
Latest revision as of 08:30, 7 July 2021
In this example it is shown how EJML can be used to fit a polynomial of arbitrary degree to a set of data. The key concepts shown here are; 1) how to create a linear using LinearSolverFactory, 2) use an adjustable linear solver, 3) and effective matrix reshaping. This is all done using the procedural interface.
First a best fit polynomial is fit to a set of data and then a outliers are removed from the observation set and the coefficients recomputed. Outliers are removed efficiently using an adjustable solver that does not resolve the whole system again.
External Resources:
PolynomialFit Example Code
/**
 * <p>
 * This example demonstrates how a polynomial can be fit to a set of data.  This is done by
 * using a least squares solver that is adjustable.  By using an adjustable solver elements
 * can be inexpensively removed and the coefficients recomputed.  This is much less expensive
 * than resolving the whole system from scratch.
 * </p>
 * <p>
 * The following is demonstrated:<br>
 * <ol>
 *  <li>Creating a solver using LinearSolverFactory</li>
 *  <li>Using an adjustable solver</li>
 *  <li>reshaping</li>
 * </ol>
 *
 * @author Peter Abeles
 */
public class PolynomialFit {
    // Vandermonde matrix
    DMatrixRMaj A;
    // matrix containing computed polynomial coefficients
    DMatrixRMaj coef;
    // observation matrix
    DMatrixRMaj y;
    // solver used to compute
    AdjustableLinearSolver_DDRM solver;
    /**
     * Constructor.
     *
     * @param degree The polynomial's degree which is to be fit to the observations.
     */
    public PolynomialFit( int degree ) {
        coef = new DMatrixRMaj(degree + 1, 1);
        A = new DMatrixRMaj(1, degree + 1);
        y = new DMatrixRMaj(1, 1);
        // create a solver that allows elements to be added or removed efficiently
        solver = LinearSolverFactory_DDRM.adjustable();
    }
    /**
     * Returns the computed coefficients
     *
     * @return polynomial coefficients that best fit the data.
     */
    public double[] getCoef() {
        return coef.data;
    }
    /**
     * Computes the best fit set of polynomial coefficients to the provided observations.
     *
     * @param samplePoints where the observations were sampled.
     * @param observations A set of observations.
     */
    public void fit( double[] samplePoints, double[] observations ) {
        // Create a copy of the observations and put it into a matrix
        y.reshape(observations.length, 1, false);
        System.arraycopy(observations, 0, y.data, 0, observations.length);
        // reshape the matrix to avoid unnecessarily declaring new memory
        // save values is set to false since its old values don't matter
        A.reshape(y.numRows, coef.numRows, false);
        // set up the A matrix
        for (int i = 0; i < observations.length; i++) {
            double obs = 1;
            for (int j = 0; j < coef.numRows; j++) {
                A.set(i, j, obs);
                obs *= samplePoints[i];
            }
        }
        // process the A matrix and see if it failed
        if (!solver.setA(A))
            throw new RuntimeException("Solver failed");
        // solver the the coefficients
        solver.solve(y, coef);
    }
    /**
     * Removes the observation that fits the model the worst and recomputes the coefficients.
     * This is done efficiently by using an adjustable solver.  Often times the elements with
     * the largest errors are outliers and not part of the system being modeled.  By removing them
     * a more accurate set of coefficients can be computed.
     */
    public void removeWorstFit() {
        // find the observation with the most error
        int worstIndex = -1;
        double worstError = -1;
        for (int i = 0; i < y.numRows; i++) {
            double predictedObs = 0;
            for (int j = 0; j < coef.numRows; j++) {
                predictedObs += A.get(i, j)*coef.get(j, 0);
            }
            double error = Math.abs(predictedObs - y.get(i, 0));
            if (error > worstError) {
                worstError = error;
                worstIndex = i;
            }
        }
        // nothing left to remove, so just return
        if (worstIndex == -1)
            return;
        // remove that observation
        removeObservation(worstIndex);
        // update A
        solver.removeRowFromA(worstIndex);
        // solve for the parameters again
        solver.solve(y, coef);
    }
    /**
     * Removes an element from the observation matrix.
     *
     * @param index which element is to be removed
     */
    private void removeObservation( int index ) {
        final int N = y.numRows - 1;
        final double[] d = y.data;
        // shift
        for (int i = index; i < N; i++) {
            d[i] = d[i + 1];
        }
        y.numRows--;
    }
}