Simple Linear Regression in 2 minutes. #machinelearning #linearregression #java

On the right data, Simple Linear Regression wins. While the rest of the ML/AI world pushes tools with far more scope than most problems need, sometimes our best tools are hidden in plain sight.

Apache Commons Math is old and kinda forgotten, but kinda cool. Simple Linear Regression is hiding in there, and it's easy to put together.

1. Add the dependency

Put this in the dependencies section of your pom.xml file:

<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-math3 -->
<dependency>
  <groupId>org.apache.commons</groupId>
  <artifactId>commons-math3</artifactId>
  <version>3.6.1</version>
</dependency>

2. Import the class

In your Java class add this import statement.

import org.apache.commons.math3.stat.regression.SimpleRegression;

3. Add your (x, y) data points

I’m reading in a list of comma-delimited strings, so each line gets split and converted to doubles first. The basic premise of building the model is simple though: call addData(x, y) once per observation.

// Needs: import java.util.List;
public SimpleRegression getLinearRegressionModel(List<String> lines) {
  SimpleRegression sr = new SimpleRegression();
  for (String s : lines) {
    // Each line is expected to look like "x,y"
    String[] ssplit = s.split(",");
    double x = Double.parseDouble(ssplit[0]);
    double y = Double.parseDouble(ssplit[1]);
    // Feed one observation into the model
    sr.addData(x, y);
  }
  return sr;
}
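The method above trusts every line to be a clean "x,y" pair. Real files tend to have header rows, blank lines and the odd bit of junk, so here is a rough sketch of a more forgiving parser in plain Java. The class name PairParserSketch and the method parsePairs are made up for this post, and silently skipping bad lines is just one assumption about how you'd want to handle them.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper for illustration; not part of Commons Math.
final class PairParserSketch {

    /** Parses "x,y" lines into {x, y} pairs, skipping lines that can't be used. */
    static List<double[]> parsePairs(List<String> lines) {
        List<double[]> pairs = new ArrayList<>();
        for (String line : lines) {
            if (line == null || line.trim().isEmpty()) {
                continue; // blank line
            }
            String[] parts = line.split(",");
            if (parts.length < 2) {
                continue; // no second column
            }
            try {
                double x = Double.parseDouble(parts[0].trim());
                double y = Double.parseDouble(parts[1].trim());
                pairs.add(new double[] {x, y});
            } catch (NumberFormatException e) {
                // Header row or non-numeric value: skip it (an assumption, you may prefer to fail)
            }
        }
        return pairs;
    }

    public static void main(String[] args) {
        List<String> lines = Arrays.asList("hours,score", "1,52", "", "2, 61", "oops", "3,70");
        for (double[] p : parsePairs(lines)) {
            System.out.println(p[0] + " -> " + p[1]);
        }
    }
}
```

Each surviving pair can then be handed to the model with sr.addData(p[0], p[1]).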

4. Make some predictions

The SimpleRegression class will give you back the slope and intercept, and from there it's plain sailing to make a prediction.

private String runPredictions(SimpleRegression sr, int runs) {
  StringBuilder sb = new StringBuilder();
  // Display the intercept of the regression
  sb.append("Intercept: " + sr.getIntercept());
  sb.append("\n");
  // Display the slope of the regression.
  sb.append("Slope: " + sr.getSlope());
  sb.append("\n");
  // Display the slope standard error
  sb.append("Slope Standard Error: " + sr.getSlopeStdErr());
  sb.append("\n");
  // Display the R2 value (getRSquare() returns plain R2, not an adjusted R2)
  sb.append("R2 value: " + sr.getRSquare());
  sb.append("\n");
  sb.append("*************************************************");
  sb.append("\n");
  sb.append("Running random predictions......");
  sb.append("\n");
  // Needs: import java.util.Random;
  Random r = new Random();
  for (int i = 0 ; i < runs ; i++) {
    int rn = r.nextInt(10);
    sb.append("Input score: " + rn + " prediction: " + Math.round(sr.predict(rn)));
    sb.append("\n");
  }
  return sb.toString();
}
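If the slope and intercept feel like a black box, the arithmetic behind a simple least squares fit is small enough to sketch by hand. This is plain Java with made-up names (LeastSquaresSketch, fit, predict) and it is not how Commons Math implements things internally; it's only meant to show that a prediction is just intercept + slope * x.

```java
// Hypothetical illustration of ordinary least squares for one input variable.
final class LeastSquaresSketch {

    /** Returns {intercept, slope} of the least squares line through the points. */
    static double[] fit(double[] x, double[] y) {
        int n = x.length;
        double meanX = 0;
        double meanY = 0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        double sxy = 0; // sum of (x - meanX) * (y - meanY)
        double sxx = 0; // sum of (x - meanX)^2
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        double slope = sxy / sxx;
        double intercept = meanY - slope * meanX;
        return new double[] {intercept, slope};
    }

    /** A prediction is the point on the fitted line at x. */
    static double predict(double[] model, double x) {
        return model[0] + model[1] * x;
    }

    public static void main(String[] args) {
        // Points that sit exactly on y = 1 + 2x
        double[] x = {1, 2, 3, 4};
        double[] y = {3, 5, 7, 9};
        double[] model = fit(x, y);
        System.out.println("Intercept: " + model[0]);
        System.out.println("Slope: " + model[1]);
        System.out.println("Prediction for 5: " + predict(model, 5));
    }
}
```

With data that sits exactly on a line, the fit recovers that line. With noisy data it returns the line that minimises the squared vertical distances to the points.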

Job done.

Now remember the key metric is the R2 score, sr.getRSquare() from your model. It’s a number between 0 and 1 that tells you how much of the variation in y is explained by x. 0 means the line explains nothing and the model shouldn’t be used; 1 means every point sits exactly on the line. It isn’t a probability, so an R2 of 0.5 is not a coin flip; it means the line accounts for half of the variance and the predictions will be loose. As a rule of thumb, aim for a minimum of 0.8 (80%) and you’re well on your way to bragging about your predictions at the pub, or on Twitter, or Facebook, or at the pub on Twitter and Facebook.
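To see where an R2 number comes from, here is a small sketch that computes it by hand as 1 minus (squared error left over after the fit) divided by (total squared variation in y). The names RSquaredSketch and rSquared are made up for this post, and the sample data and its fitted line (intercept 0, slope 1.9) are invented for illustration. For a least squares line with an intercept, this should be the same quantity getRSquare() reports, though I haven't lined the two up digit for digit here.

```java
// Hypothetical illustration of the R2 calculation for a fitted line.
final class RSquaredSketch {

    /** R2 = 1 - SSE / SST for the line y = intercept + slope * x. */
    static double rSquared(double[] x, double[] y, double intercept, double slope) {
        int n = y.length;
        double meanY = 0;
        for (double v : y) {
            meanY += v;
        }
        meanY /= n;

        double sse = 0; // squared error left over after the fit
        double sst = 0; // total squared variation of y around its mean
        for (int i = 0; i < n; i++) {
            double predicted = intercept + slope * x[i];
            sse += (y[i] - predicted) * (y[i] - predicted);
            sst += (y[i] - meanY) * (y[i] - meanY);
        }
        return 1.0 - sse / sst;
    }

    public static void main(String[] args) {
        // Invented sample data; its least squares line is y = 0 + 1.9x
        double[] x = {1, 2, 3, 4};
        double[] y = {2, 4, 5, 8};
        System.out.println("R2: " + rSquared(x, y, 0.0, 1.9));
    }
}
```

For this sample the leftover squared error is 0.7 against a total variation of 18.75, so R2 comes out at roughly 0.96: the line explains about 96% of the variation in y.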

 

 
