With the right data (one input, one output and a roughly straight-line relationship between them), simple linear regression wins. While the rest of the ML/AI world pushes tools with a far larger scope than most of us need, sometimes our best tools are hiding in plain sight.
Apache Commons Math is old and a bit forgotten, but still kind of cool, and simple linear regression is hiding in there, easy to put together with very little code.
1. Add the dependency
Put this in your pom.xml file:
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-math3 -->
<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-math3</artifactId>
    <version>3.6.1</version>
</dependency>
2. Import the class
In your Java class, add this import statement:
import org.apache.commons.math3.stat.regression.SimpleRegression;
3. Add your x and y data
I’m reading in a list of comma-delimited strings (one "x,y" pair per line), so I parse each one into two doubles before adding it. The basic premise of building the model is simple, though: call addData(x, y) once per observation.
// Needs: import java.util.List;
public SimpleRegression getLinearRegressionModel(List<String> lines) {
    SimpleRegression sr = new SimpleRegression();
    for (String s : lines) {
        // Each line is "x,y": first column is the input, second the output
        String[] ssplit = s.split(",");
        double x = Double.parseDouble(ssplit[0]);
        double y = Double.parseDouble(ssplit[1]);
        // Fold this observation into the model
        sr.addData(x, y);
    }
    return sr;
}
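To give a feel for what happens as you call addData, here is a hypothetical, simplified class that fits the same ordinary least-squares line incrementally, folding each point into running means and sums (a Welford-style update) instead of keeping every row. It's an illustration of the technique, not the Commons Math source, and the class name RunningLeastSquares is made up.

```java
/**
 * Hypothetical, simplified illustration of an incremental least-squares fit.
 * Each addData call folds one point into running means and centered sums,
 * so the raw rows never need to be kept. Not the Commons Math source.
 */
final class RunningLeastSquares {
    private long n = 0;
    private double meanX = 0.0;
    private double meanY = 0.0;
    private double sxx = 0.0; // running sum of (x - meanX)^2
    private double sxy = 0.0; // running sum of (x - meanX)(y - meanY)

    /** Welford-style update: add one (x, y) pair to the running totals. */
    void addData(double x, double y) {
        n++;
        double dx = x - meanX;   // distance from the OLD mean of x
        meanX += dx / n;
        meanY += (y - meanY) / n;
        sxx += dx * (x - meanX); // old-mean delta times new-mean delta
        sxy += dx * (y - meanY);
    }

    /** Least-squares slope Sxy / Sxx; NaN until there are two distinct x values. */
    double getSlope() {
        return sxy / sxx;
    }

    /** Intercept: mean(y) - slope * mean(x). */
    double getIntercept() {
        return meanY - getSlope() * meanX;
    }

    /** Point on the fitted line: intercept + slope * x. */
    double predict(double x) {
        return getIntercept() + getSlope() * x;
    }

    long getN() {
        return n;
    }
}
```

Feeding it the rows 1,3 / 2,5 / 3,7 gives a slope of 2 and an intercept of 1, the line y = 2x + 1, which is the same least-squares line SimpleRegression should report for that data.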
4. Make some predictions
The SimpleRegression object gives you back the slope and intercept, and from there it's plain sailing to make a prediction: sr.predict(x) returns intercept + slope * x.
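For reference, the line being fitted is ordinary least squares with an intercept (the default for new SimpleRegression()). These are the textbook formulas for its two coefficients and the prediction, with x̄ and ȳ as the sample means of x and y:

$$
b_1 = \frac{\sum_{i=1}^{n} (x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^{n} (x_i - \bar{x})^2}, \qquad
b_0 = \bar{y} - b_1 \bar{x}, \qquad
\hat{y} = b_0 + b_1 x
$$

getSlope() corresponds to b₁, getIntercept() to b₀, and predict(x) to ŷ.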
// Needs: import java.util.Random;
private String runPredictions(SimpleRegression sr, int runs) {
    StringBuilder sb = new StringBuilder();
    // Intercept of the fitted line
    sb.append("Intercept: ").append(sr.getIntercept()).append("\n");
    // Slope of the fitted line
    sb.append("Slope: ").append(sr.getSlope()).append("\n");
    // Standard error of the slope estimate
    sb.append("Slope standard error: ").append(sr.getSlopeStdErr()).append("\n");
    // Coefficient of determination (plain R-squared, not adjusted)
    sb.append("R-squared: ").append(sr.getRSquare()).append("\n");
    sb.append("*************************************************\n");
    sb.append("Running random predictions......\n");
    Random r = new Random();
    for (int i = 0; i < runs; i++) {
        // Random whole-number input from 0 to 9
        int rn = r.nextInt(10);
        sb.append("Input score: ").append(rn)
          .append(" prediction: ").append(Math.round(sr.predict(rn)))
          .append("\n");
    }
    return sb.toString();
}
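The standard error line (sr.getSlopeStdErr()) is worth a glance alongside the slope. As a reminder of the usual textbook definition it estimates, it is the square root of the residual mean square divided by the spread of x, where ŷᵢ is the fitted value for xᵢ:

$$
s_{b_1} = \sqrt{\frac{\mathrm{SSE}/(n-2)}{\sum_{i=1}^{n}(x_i - \bar{x})^2}}, \qquad
\mathrm{SSE} = \sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2
$$

Roughly speaking, a standard error that is small compared with the slope itself suggests the slope is well pinned down by the data; a large one means quite different lines would fit nearly as well.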
Job done.
Now remember, the key metric to check is the R² score from your model, sr.getRSquare(). It's the share of the variation in y that the fitted line explains, a number between 0 and 1. At 0 the line explains nothing beyond the average of y and the model isn't worth using; at 1 every point sits exactly on the line. R² isn't a hit rate, so 0.5 doesn't mean "coin flip"; it means the line accounts for half the variation in y. There's no universal cut-off, but as a rule of thumb aim for a minimum of 0.8 (80%) and you're well on your way to bragging about your predictions at the pub, or on Twitter, or Facebook, or at the pub on Twitter and Facebook…
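To make "share of the variation explained" concrete, here is a hypothetical from-scratch helper (the name RSquaredSketch is made up) that fits the least-squares line and computes R² = 1 - SSE/SST, where SSE sums squared residuals around the line and SST sums squared distances from the mean of y. It's a sketch for illustration, not the Commons Math implementation.

```java
/**
 * Hypothetical helper: R-squared of a simple least-squares fit, computed as
 * 1 - SSE / SST. Illustration only, not the Commons Math implementation.
 */
final class RSquaredSketch {
    /** Returns NaN if all x values or all y values are identical. */
    static double rSquared(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two (x, y) pairs");
        }
        int n = x.length;
        double meanX = 0.0, meanY = 0.0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        // Centered sums used by the least-squares slope
        double sxx = 0.0, sxy = 0.0;
        for (int i = 0; i < n; i++) {
            sxx += (x[i] - meanX) * (x[i] - meanX);
            sxy += (x[i] - meanX) * (y[i] - meanY);
        }
        double slope = sxy / sxx;
        double intercept = meanY - slope * meanX;

        // SSE: squared residuals around the line; SST: around the mean of y
        double sse = 0.0, sst = 0.0;
        for (int i = 0; i < n; i++) {
            double residual = y[i] - (intercept + slope * x[i]);
            sse += residual * residual;
            sst += (y[i] - meanY) * (y[i] - meanY);
        }
        return 1.0 - sse / sst;
    }
}
```

For the points (1,3), (2,5), (3,7), which lie exactly on y = 2x + 1, it returns 1.0. For (1,1), (2,3), (3,2) the fitted line still has a positive slope of 0.5, but it explains only a quarter of the variation in y, so R² comes out at 0.25.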