In this article, we’ll explore what machine learning is and how to apply it effectively using Java, with hands-on examples and practical libraries.
Introduction
For a long time, Java wasn’t considered the go-to language for machine learning - Python dominated the space with libraries like TensorFlow and PyTorch.
However, Java has powerful tools for ML: DeepLearning4J, Tribuo, and Smile, which allow developers to build models directly within the JVM ecosystem.
In this article, we’ll explore how to use these libraries, show practical examples, and compare their strengths and weaknesses.
1. DeepLearning4J (DL4J)
- What it is: A JVM-based deep learning framework supporting neural networks and integration with Apache Spark.
- When to use: If you need deep learning capabilities, GPU acceleration, and seamless integration with Java applications.
Example: Building and Using a Simple Neural Network
The following snippet demonstrates training, prediction, and evaluation with DL4J. To run it, place the statements inside a main method:
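import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
// 1. Define the neural network configuration
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(123) // fixed seed for reproducible weight initialization
    .updater(new Nesterovs(0.1, 0.9)) // learning rate 0.1, momentum 0.9
    .list()
    .layer(new DenseLayer.Builder()
        .nIn(4) // 4 input features
        .nOut(3) // hidden layer size
        .activation(Activation.RELU)
        .build())
    .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) // multi-class cross-entropy
        .nIn(3) // must match the hidden layer's nOut, since no input type is set
        .nOut(3) // 3 output classes
        .activation(Activation.SOFTMAX)
        .build())
    .build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();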
// 2. Prepare training data (features and one-hot labels)
INDArray input = Nd4j.create(new double[][]{
{0.1, 0.2, 0.3, 0.4},
{0.5, 0.6, 0.7, 0.8},
{0.9, 1.0, 1.1, 1.2},
{1.3, 1.4, 1.5, 1.6},
{1.7, 1.8, 1.9, 2.0}
});
INDArray labels = Nd4j.create(new double[][]{
{1, 0, 0},
{0, 1, 0},
{0, 0, 1},
{1, 0, 0},
{0, 1, 0}
});
DataSet trainingData = new DataSet(input, labels);
// 3. Train the model
int nEpochs = 1000;
for (int i = 0; i < nEpochs; i++) {
model.fit(trainingData);
}
// 4. Make predictions
INDArray testInput = Nd4j.create(new double[][]{
{0.2, 0.3, 0.4, 0.5} // new example
});
INDArray output = model.output(testInput);
System.out.println("Predicted probabilities: " + output);
int predictedClass = Nd4j.argMax(output, 1).getInt(0);
System.out.println("Predicted class: " + predictedClass);
// 5. Evaluate the model on the training data (this measures fit, not generalization)
Evaluation eval = new Evaluation(3); // 3 classes
INDArray predicted = model.output(input);
eval.eval(labels, predicted);
System.out.println(eval.stats());
How it works:
- Defines a simple 2-layer network (4 input features → hidden layer → 3-class output).
- Uses dummy training data (5 examples, one-hot labels).
- Trains the network for 1000 epochs.
- Makes a prediction for a new input example.
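- Evaluates the network on the training data and prints accuracy and other statistics. Because no held-out data is used, these numbers show how well the model fits the examples, not how well it generalizes.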
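To make the prediction and evaluation steps more concrete, here is a minimal plain-Java sketch of what a softmax output, an argmax, and an accuracy score compute. It does not use DL4J; the class name SoftmaxSketch and the sample scores are hypothetical and chosen only for illustration.

```java
import java.util.Arrays;

final class SoftmaxSketch {
    /** Converts raw scores into probabilities that sum to 1. */
    static double[] softmax(double[] scores) {
        // Subtract the maximum first so Math.exp cannot overflow.
        double max = Arrays.stream(scores).max().getAsDouble();
        double[] probs = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    /** Returns the index of the largest value, i.e. the predicted class. */
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of rows whose predicted class matches the one-hot label. */
    static double accuracy(double[][] probabilities, double[][] oneHotLabels) {
        int correct = 0;
        for (int i = 0; i < probabilities.length; i++) {
            if (argMax(probabilities[i]) == argMax(oneHotLabels[i])) {
                correct++;
            }
        }
        return (double) correct / probabilities.length;
    }

    public static void main(String[] args) {
        double[] probs = softmax(new double[]{2.0, 1.0, 0.1}); // hypothetical raw scores
        System.out.println("Probabilities: " + Arrays.toString(probs));
        System.out.println("Predicted class: " + argMax(probs));
    }
}
```

The predicted class is simply the position of the highest probability, which is why the one-hot labels and the network output can be compared row by row.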
Pros:
- GPU acceleration
- Full support for deep neural networks
- Integrates well with existing Java applications
Cons:
- Steeper learning curve than Python frameworks such as TensorFlow and PyTorch
2. Tribuo
- What it is: A Java ML library from Oracle for classification, regression, clustering, and more.
- When to use: When you need a fast start with classical ML algorithms in Java.
Example: Text Classification
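// Assumes Tribuo 4.x; imports from org.tribuo, org.tribuo.classification, org.tribuo.classification.sgd.linear,
// org.tribuo.data.text.impl, org.tribuo.provenance, and org.tribuo.util.tokens.impl are omitted for brevity.
LabelFactory labelFactory = new LabelFactory();
// Tribuo features are name/value pairs, so raw text is tokenized into unigram features first
TextFeatureExtractorImpl<Label> extractor = new TextFeatureExtractorImpl<>(new BasicPipeline(new BreakIteratorTokenizer(Locale.US), 1));
MutableDataset<Label> dataset = new MutableDataset<>(new SimpleDataSourceProvenance("toy spam data", labelFactory), labelFactory);
dataset.add(extractor.extract(new Label("spam"), "Win a free prize!"));
dataset.add(extractor.extract(new Label("ham"), "Let's meet tomorrow."));
Trainer<Label> trainer = new LogisticRegressionTrainer(); // linear model trained with SGD
Model<Label> model = trainer.train(dataset);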
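A text classifier cannot consume raw strings directly; each message first has to become a set of named numeric features. The sketch below shows one common way to do that, a bag-of-words count, in plain Java. It is not Tribuo's text pipeline; the class name BagOfWordsSketch and the tokenization rule are assumptions made for illustration.

```java
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;

final class BagOfWordsSketch {
    /** Lowercases the text, splits it on non-letter characters, and counts each token. */
    static Map<String, Double> extract(String text) {
        Map<String, Double> counts = new TreeMap<>();
        for (String token : text.toLowerCase(Locale.ROOT).split("[^a-z']+")) {
            if (!token.isEmpty()) { // a leading separator produces one empty token
                counts.merge(token, 1.0, Double::sum);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(extract("Win a free prize!"));
        System.out.println(extract("Let's meet tomorrow."));
    }
}
```

Each distinct word becomes a feature name and its count becomes the feature value, which is the shape of input a linear classifier can train on.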
Pros:
- Easy integration
- Flexible architecture for experimentation
- Supports model evaluation
Cons:
- Not specialized for deep learning
3. Smile
- What it is: A lightweight ML and statistics library for the JVM.
- When to use: For quick analytics, regression, clustering, and data visualization.
Example: Linear Regression
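import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.regression.LinearModel;
import smile.regression.OLS;
double[][] data = {{1, 1.1}, {2, 1.9}, {3, 3.0}, {4, 4.1}}; // each row is {x, y}
DataFrame df = DataFrame.of(data, "x", "y");
LinearModel model = OLS.fit(Formula.lhs("y"), df); // formula-based API, assuming Smile 2.x/3.x
System.out.println("Prediction for 5: " + model.predict(new double[]{5}));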
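For a single input variable, ordinary least squares has a closed-form solution: the slope is the covariance of x and y divided by the variance of x, and the intercept follows from the means. The plain-Java sketch below works this out for the same four data points. It is an illustration of the underlying arithmetic, not Smile's implementation, and the class name SimpleOlsSketch is hypothetical.

```java
final class SimpleOlsSketch {
    /** Returns {intercept, slope} of the least-squares line through the points (x[i], y[i]). */
    static double[] fit(double[] x, double[] y) {
        int n = x.length;
        double meanX = 0.0;
        double meanY = 0.0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        double sxy = 0.0; // sum of (x - meanX) * (y - meanY)
        double sxx = 0.0; // sum of (x - meanX)^2
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        double slope = sxy / sxx;
        double intercept = meanY - slope * meanX;
        return new double[]{intercept, slope};
    }

    /** Evaluates the fitted line at a new x value. */
    static double predict(double[] model, double x) {
        return model[0] + model[1] * x;
    }

    public static void main(String[] args) {
        double[] model = fit(new double[]{1, 2, 3, 4}, new double[]{1.1, 1.9, 3.0, 4.1});
        System.out.println("Slope: " + model[1] + ", intercept: " + model[0]);
        System.out.println("Prediction for 5: " + predict(model, 5));
    }
}
```

For this data the slope comes out at roughly 1.01 with an intercept near zero, so the prediction for x = 5 is roughly 5.05.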
Pros:
- Fast and lightweight
- Many built-in algorithms
- Great for statistical analysis
Cons:
- Less focus on deep learning
Practical Use Cases of Java ML Libraries
DeepLearning4J (DL4J) - Deep Learning in Enterprise
Use cases:
- Image recognition / computer vision
- NLP / Chatbots
- Time-series predictions
Mini Case Study: Retail Shelf Monitoring
A retail company wants to automatically detect empty shelves using store cameras. A convolutional neural network (CNN) built with DL4J can classify camera images to flag shelves with missing products.
Tribuo - Classical Machine Learning in Java
Use cases:
- Text classification
- Predictive analytics
- Anomaly detection
Mini Case Study: Fraud Detection in Finance
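// Assumes Tribuo 4.x; imports from org.tribuo, org.tribuo.impl, org.tribuo.classification,
// org.tribuo.classification.sgd.linear, and org.tribuo.provenance are omitted for brevity.
LabelFactory labelFactory = new LabelFactory();
MutableDataset<Label> dataset = new MutableDataset<>(new SimpleDataSourceProvenance("toy transactions", labelFactory), labelFactory);
// Features must be numeric, so the country is one-hot encoded as "country=NG" / "country=US"
dataset.add(new ArrayExample<>(new Label("fraud"), new String[]{"amount", "country=NG"}, new double[]{1000, 1}));
dataset.add(new ArrayExample<>(new Label("legit"), new String[]{"amount", "country=US"}, new double[]{50, 1}));
Trainer<Label> trainer = new LogisticRegressionTrainer(); // linear model trained with SGD
Model<Label> model = trainer.train(dataset);
// The label of a new transaction is unknown at prediction time
Example<Label> newTransaction = new ArrayExample<>(LabelFactory.UNKNOWN_LABEL, new String[]{"amount", "country=NG"}, new double[]{500, 1});
Prediction<Label> prediction = model.predict(newTransaction);
System.out.println("Predicted: " + prediction.getOutput().getLabel());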
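The transactions above mix a numeric field (amount) with a categorical one (country). Linear models work on numeric vectors, so categorical values are usually one-hot encoded first. The plain-Java sketch below shows the idea; it is independent of Tribuo, and the class name OneHotSketch and the list of known countries are assumptions for illustration.

```java
import java.util.Arrays;
import java.util.List;

final class OneHotSketch {
    /**
     * Builds a feature vector: position 0 holds the amount, followed by one
     * 0/1 slot per known country.
     */
    static double[] encode(double amount, String country, List<String> knownCountries) {
        double[] features = new double[1 + knownCountries.size()];
        features[0] = amount;
        int index = knownCountries.indexOf(country);
        if (index >= 0) {
            features[1 + index] = 1.0;
        }
        // A country that was never seen during training keeps all slots at 0.
        return features;
    }

    public static void main(String[] args) {
        List<String> countries = Arrays.asList("NG", "US"); // hypothetical vocabulary
        System.out.println(Arrays.toString(encode(500, "NG", countries)));
        System.out.println(Arrays.toString(encode(50, "US", countries)));
    }
}
```

In practice the amount would usually be scaled as well, since a raw value such as 1000 can dominate the 0/1 country indicators during gradient-based training.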
Smile - Fast Analytics and Prototyping
Use cases:
- Statistical analysis / regression
- Clustering / segmentation
- Recommendation systems
Mini Case Study: Customer Segmentation
import java.util.Arrays;
import smile.clustering.KMeans;
double[][] data = {
    {5.2, 10}, {6.5, 12}, {1.0, 2}, {1.2, 3}, {7.0, 11}
};
KMeans kmeans = KMeans.fit(data, 2); // 2 clusters
int[] labels = kmeans.y; // cluster index per row (public field, assuming Smile 2.x/3.x)
System.out.println("Cluster assignments: " + Arrays.toString(labels));
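Under the hood, k-means alternates between assigning each point to its nearest centroid and moving each centroid to the mean of its points. The plain-Java sketch below runs that loop (Lloyd's algorithm) on the same five customers. It is not Smile's implementation: the class name KMeansSketch is hypothetical, and seeding the centroids with the first k points is a simplification made for determinism.

```java
import java.util.Arrays;

final class KMeansSketch {
    /** Runs Lloyd's algorithm and returns the cluster index assigned to each point. */
    static int[] cluster(double[][] data, int k, int maxIterations) {
        int dim = data[0].length;
        // Simplification: seed the centroids with the first k points.
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = data[c].clone();
        }
        int[] labels = new int[data.length];
        Arrays.fill(labels, -1);

        for (int iter = 0; iter < maxIterations; iter++) {
            // Assignment step: attach every point to its nearest centroid.
            boolean changed = false;
            for (int i = 0; i < data.length; i++) {
                int nearest = 0;
                for (int c = 1; c < k; c++) {
                    if (squaredDistance(data[i], centroids[c])
                            < squaredDistance(data[i], centroids[nearest])) {
                        nearest = c;
                    }
                }
                if (labels[i] != nearest) {
                    labels[i] = nearest;
                    changed = true;
                }
            }
            if (!changed) {
                break; // converged: no point switched cluster
            }
            // Update step: move each centroid to the mean of its points.
            double[][] sums = new double[k][dim];
            int[] counts = new int[k];
            for (int i = 0; i < data.length; i++) {
                counts[labels[i]]++;
                for (int d = 0; d < dim; d++) {
                    sums[labels[i]][d] += data[i][d];
                }
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] > 0) { // an empty cluster keeps its old centroid
                    for (int d = 0; d < dim; d++) {
                        centroids[c][d] = sums[c][d] / counts[c];
                    }
                }
            }
        }
        return labels;
    }

    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int d = 0; d < a.length; d++) {
            sum += (a[d] - b[d]) * (a[d] - b[d]);
        }
        return sum;
    }

    public static void main(String[] args) {
        double[][] data = {{5.2, 10}, {6.5, 12}, {1.0, 2}, {1.2, 3}, {7.0, 11}};
        System.out.println("Cluster assignments: " + Arrays.toString(cluster(data, 2, 100)));
    }
}
```

The two low-value customers end up in one cluster and the three high-value customers in the other. Production libraries typically use smarter seeding than the first k points, so the cluster numbering may differ even when the grouping is the same.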
Library Comparison
Library | Best for | Example Tasks | Notes |
---|---|---|---|
DL4J | Deep learning, GPU tasks | Image recognition, NLP, time-series | High learning curve, enterprise-ready |
Tribuo | Classical ML | Classification, regression, anomaly detection | Easy to integrate in microservices |
Smile | Analytics & prototyping | Clustering, regression, statistics | Lightweight, fast, less focus on deep learning |
Conclusion
Java’s ML ecosystem covers a broad range of workloads:
- DL4J → heavy deep learning applications.
- Tribuo → classical ML in production-ready Java apps.
- Smile → lightweight prototyping, clustering, regression, and statistics.
With these libraries, Java developers can build modern machine learning applications without leaving the JVM ecosystem, whether it’s enterprise AI, fintech analytics, or user behavior modeling.