Example Polynomial Fitting

This example shows how EJML can be used to fit a polynomial of arbitrary degree to a set of data. The key concepts shown here are: 1) how to create a linear solver using LinearSolverFactory, 2) how to use an adjustable linear solver, and 3) effective matrix reshaping. This is all done using the procedural interface.
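
Concretely, for a polynomial of degree n fit to m sample points, the solver finds the coefficients c that best satisfy the following system in the least-squares sense; the matrix on the left is the Vandermonde matrix built from the sample points, and the right-hand side holds the observations:

<math>
\begin{bmatrix}
1 & x_1 & x_1^2 & \cdots & x_1^n \\
1 & x_2 & x_2^2 & \cdots & x_2^n \\
\vdots & \vdots & \vdots & & \vdots \\
1 & x_m & x_m^2 & \cdots & x_m^n
\end{bmatrix}
\begin{bmatrix} c_0 \\ c_1 \\ \vdots \\ c_n \end{bmatrix}
\approx
\begin{bmatrix} y_1 \\ y_2 \\ \vdots \\ y_m \end{bmatrix}
</math>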

First a best-fit polynomial is computed for the full set of data, then outliers are removed from the observation set and the coefficients are recomputed. Outliers are removed efficiently using an adjustable solver that does not re-solve the whole system from scratch.
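
As a quick, hypothetical usage sketch of the class shown below (the driver class name, sample data, and polynomial degree are made up for illustration and are not part of the library):

public class PolynomialFitExample {
    public static void main( String[] args ) {
        // sample points and noisy observations; the value at x = 4 is an intentional outlier
        double[] x = {0, 1, 2, 3, 4, 5};
        double[] y = {1.1, 2.9, 9.2, 19.3, 80.0, 51.2};

        // fit a 2nd degree polynomial to all of the observations
        PolynomialFit fit = new PolynomialFit(2);
        fit.fit(x, y);

        // drop the single worst-fitting observation and recompute the coefficients
        fit.removeWorstFit();

        // coefficients are ordered from the constant term up: c0 + c1*x + c2*x^2
        double[] coef = fit.getCoef();
        System.out.println("c0 = " + coef[0] + "  c1 = " + coef[1] + "  c2 = " + coef[2]);
    }
}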

External Resources:
* [https://github.com/lessthanoptimal/ejml/blob/v0.41/examples/src/org/ejml/example/PolynomialFit.java PolynomialFit.java source code]

= PolynomialFit Example Code =

import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.factory.LinearSolverFactory_DDRM;
import org.ejml.interfaces.linsol.AdjustableLinearSolver_DDRM;

/**
 * <p>
 * This example demonstrates how a polynomial can be fit to a set of data.  This is done by
 * using a least squares solver that is adjustable.  By using an adjustable solver elements
 * can be inexpensively removed and the coefficients recomputed.  This is much less expensive
 * than resolving the whole system from scratch.
 * </p>
 * <p>
 * The following is demonstrated:<br>
 * <ol>
 *  <li>Creating a solver using LinearSolverFactory</li>
 *  <li>Using an adjustable solver</li>
 *  <li>reshaping</li>
 * </ol>
 *
 * @author Peter Abeles
 */
public class PolynomialFit {
    // Vandermonde matrix
    DMatrixRMaj A;
    // matrix containing computed polynomial coefficients
    DMatrixRMaj coef;
    // observation matrix
    DMatrixRMaj y;

    // solver used to compute the polynomial coefficients
    AdjustableLinearSolver_DDRM solver;

    /**
     * Constructor.
     *
     * @param degree The polynomial's degree which is to be fit to the observations.
     */
    public PolynomialFit( int degree ) {
        coef = new DMatrixRMaj(degree + 1, 1);
        A = new DMatrixRMaj(1, degree + 1);
        y = new DMatrixRMaj(1, 1);

        // create a solver that allows elements to be added or removed efficiently
        solver = LinearSolverFactory_DDRM.adjustable();
    }

    /**
     * Returns the computed coefficients
     *
     * @return polynomial coefficients that best fit the data.
     */
    public double[] getCoef() {
        return coef.data;
    }

    /**
     * Computes the best fit set of polynomial coefficients to the provided observations.
     *
     * @param samplePoints where the observations were sampled.
     * @param observations A set of observations.
     */
    public void fit( double[] samplePoints, double[] observations ) {
        // Create a copy of the observations and put it into a matrix
        y.reshape(observations.length, 1, false);
        System.arraycopy(observations, 0, y.data, 0, observations.length);

        // reshape the matrix to avoid unnecessarily declaring new memory
        // saveValues is set to false since the old values don't matter
        A.reshape(y.numRows, coef.numRows, false);

        // set up the A matrix
        for (int i = 0; i < observations.length; i++) {

            double obs = 1;

            for (int j = 0; j < coef.numRows; j++) {
                A.set(i, j, obs);
                obs *= samplePoints[i];
            }
        }
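        // at this point each row i of A contains [ 1, x_i, x_i^2, ..., x_i^degree ],
        // i.e. A is the Vandermonde matrix of the sample points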

        // process the A matrix and see if it failed
        if (!solver.setA(A))
            throw new RuntimeException("Solver failed");

        // solve for the coefficients
        solver.solve(y, coef);
    }

    /**
     * Removes the observation that fits the model the worst and recomputes the coefficients.
     * This is done efficiently by using an adjustable solver.  Often the elements with
     * the largest errors are outliers and not part of the system being modeled.  By removing them
     * a more accurate set of coefficients can be computed.
     */
    public void removeWorstFit() {
        // find the observation with the most error
        int worstIndex = -1;
        double worstError = -1;

        for (int i = 0; i < y.numRows; i++) {
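            // evaluate the fitted polynomial at sample point i by reusing row i of A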
            double predictedObs = 0;

            for (int j = 0; j < coef.numRows; j++) {
                predictedObs += A.get(i, j)*coef.get(j, 0);
            }

            double error = Math.abs(predictedObs - y.get(i, 0));

            if (error > worstError) {
                worstError = error;
                worstIndex = i;
            }
        }

        // nothing left to remove, so just return
        if (worstIndex == -1)
            return;

        // remove that observation
        removeObservation(worstIndex);

        // update A
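        // the adjustable solver updates itself for the removed row instead of
        // decomposing the modified A from scratch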
        solver.removeRowFromA(worstIndex);

        // solve for the parameters again
        solver.solve(y, coef);
    }

    /**
     * Removes an element from the observation matrix.
     *
     * @param index which element is to be removed
     */
    private void removeObservation( int index ) {
        final int N = y.numRows - 1;
        final double[] d = y.data;

        // shift
        for (int i = index; i < N; i++) {
            d[i] = d[i + 1];
        }
        y.numRows--;
    }
}