In this article, we’ll explore what machine learning is and how to apply it effectively using Java, with hands-on examples and practical libraries.

Introduction

For a long time, Java wasn’t considered the go-to language for machine learning; Python dominated the space with libraries like TensorFlow and PyTorch. However, Java has capable ML tools of its own: DeepLearning4J, Tribuo, and Smile let developers build models directly within the JVM ecosystem.

In this article, we’ll explore how to use these libraries, show practical examples, and compare their strengths and weaknesses.

1. DeepLearning4J (DL4J)

What it is: A JVM-based deep learning framework supporting neural networks and integration with Apache Spark.

When to use: If you need deep learning capabilities, GPU acceleration, and seamless integration with Java applications.

Example: Building and Using a Simple Neural Network

Here’s a complete Java program demonstrating training, prediction, and evaluation using DL4J:

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.evaluation.classification.Evaluation; // org.deeplearning4j.eval.Evaluation in older releases
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.learning.config.Nesterovs;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class SimpleNeuralNetwork {
        public static void main(String[] args) {
            // 1. Define the neural network configuration
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)                          // fixed seed for reproducible weights
                .updater(new Nesterovs(0.1, 0.9))   // learning rate 0.1, momentum 0.9
                .list()
                .layer(new DenseLayer.Builder()
                    .nIn(4)   // 4 input features
                    .nOut(3)  // hidden layer size
                    .activation(Activation.RELU)
                    .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(3)   // must match the hidden layer size
                    .nOut(3)  // 3 output classes
                    .activation(Activation.SOFTMAX)
                    .build())
                .build();

            MultiLayerNetwork model = new MultiLayerNetwork(conf);
            model.init();

            // 2. Prepare training data (features and one-hot labels)
            INDArray input = Nd4j.create(new double[][]{
                {0.1, 0.2, 0.3, 0.4},
                {0.5, 0.6, 0.7, 0.8},
                {0.9, 1.0, 1.1, 1.2},
                {1.3, 1.4, 1.5, 1.6},
                {1.7, 1.8, 1.9, 2.0}
            });
            INDArray labels = Nd4j.create(new double[][]{
                {1, 0, 0},
                {0, 1, 0},
                {0, 0, 1},
                {1, 0, 0},
                {0, 1, 0}
            });
            DataSet trainingData = new DataSet(input, labels);

            // 3. Train the model
            int nEpochs = 1000;
            for (int i = 0; i < nEpochs; i++) {
                model.fit(trainingData);
            }

            // 4. Make predictions
            INDArray testInput = Nd4j.create(new double[][]{
                {0.2, 0.3, 0.4, 0.5} // new example
            });
            INDArray output = model.output(testInput);
            System.out.println("Predicted probabilities: " + output);
            int predictedClass = Nd4j.argMax(output, 1).getInt(0);
            System.out.println("Predicted class: " + predictedClass);

            // 5. Evaluate model on training data
            Evaluation eval = new Evaluation(3); // 3 classes
            INDArray predicted = model.output(input);
            eval.eval(labels, predicted);
            System.out.println(eval.stats());
        }
    }

How it works:

- Defines a simple 2-layer network (4 input features → hidden layer → 3-class output).
- Uses dummy training data (5 examples, one-hot labels).
- Trains the network for 1000 epochs.
- Makes a prediction for a new input example.
- Evaluates the network on the training data and prints accuracy and other statistics. With no held-out test set, these numbers show how well the model fits the data, not how well it generalizes.

Pros:

- GPU acceleration
- Full support for deep neural networks
- Integrates well with existing Java applications

Cons:

- Steeper learning curve compared to Python

2. Tribuo

What it is: A Java ML library from Oracle for classification, regression, clustering, and more.

When to use: When you need a fast start with classical ML algorithms in Java.
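Classical ML libraries such as Tribuo operate on named numeric features, so raw text has to be turned into numbers before a model can be trained on it. The sketch below shows one minimal way to do that in plain Java, using a lowercase bag-of-words count. The class name `BagOfWords` and the method `featurize` are hypothetical helpers made up for illustration; they are not part of Tribuo, which ships its own text-processing pipeline.

```java
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;

public class BagOfWords {

    /**
     * Lowercases the text, splits it on anything that is not a letter a-z,
     * and counts how often each token occurs.
     */
    public static Map<String, Double> featurize(String text) {
        Map<String, Double> counts = new TreeMap<>(); // sorted for stable output
        for (String token : text.toLowerCase(Locale.ROOT).split("[^a-z]+")) {
            if (!token.isEmpty()) {
                counts.merge(token, 1.0, Double::sum);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(featurize("Win a free prize!"));
        // prints {a=1.0, free=1.0, prize=1.0, win=1.0}
    }
}
```

Each map entry can then be passed to a library as one (feature name, value) pair. Counting tokens is only one possible encoding; binary presence flags or TF-IDF weights are common alternatives.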
Example: Text Classification

    import org.tribuo.Model;
    import org.tribuo.MutableDataset;
    import org.tribuo.Trainer;
    import org.tribuo.classification.Label;
    import org.tribuo.classification.LabelFactory;
    import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;
    import org.tribuo.impl.ArrayExample;
    import org.tribuo.provenance.SimpleDataSourceProvenance;

    LabelFactory factory = new LabelFactory();
    MutableDataset<Label> dataset = new MutableDataset<>(
        new SimpleDataSourceProvenance("toy spam data", factory), factory);
    // Tribuo features are (name, double) pairs, so each word becomes a binary feature.
    dataset.add(new ArrayExample<>(new Label("spam"),
        new String[]{"win", "a", "free", "prize"}, new double[]{1.0, 1.0, 1.0, 1.0}));
    dataset.add(new ArrayExample<>(new Label("ham"),
        new String[]{"let's", "meet", "tomorrow"}, new double[]{1.0, 1.0, 1.0}));
    // Logistic regression trained with stochastic gradient descent
    Trainer<Label> trainer = new LogisticRegressionTrainer();
    Model<Label> model = trainer.train(dataset);

Pros:

- Easy integration
- Flexible architecture for experimentation
- Supports model evaluation

Cons:

- Not specialized for deep learning

3. Smile

What it is: A lightweight ML and statistics library for the JVM.

When to use: For quick analytics, regression, clustering, and data visualization.
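To make the regression use case concrete, it can help to see what an ordinary least squares fit computes for a single input feature: the slope is the covariance of x and y divided by the variance of x, and the intercept follows from the two means. The sketch below does this in plain Java on a small toy dataset. `SimpleOls` and its method names are hypothetical and written only for illustration; this is not how Smile implements regression internally.

```java
public class SimpleOls {
    private final double slope;
    private final double intercept;

    private SimpleOls(double slope, double intercept) {
        this.slope = slope;
        this.intercept = intercept;
    }

    /** Fits y = intercept + slope * x by ordinary least squares. */
    public static SimpleOls fit(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two (x, y) pairs");
        }
        int n = x.length;
        double meanX = 0.0, meanY = 0.0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        double sxy = 0.0, sxx = 0.0; // sums of cross-deviations and squared deviations
        for (int i = 0; i < n; i++) {
            double dx = x[i] - meanX;
            sxy += dx * (y[i] - meanY);
            sxx += dx * dx;
        }
        if (sxx == 0.0) {
            throw new IllegalArgumentException("all x values are identical");
        }
        double slope = sxy / sxx;
        return new SimpleOls(slope, meanY - slope * meanX);
    }

    public double predict(double x) { return intercept + slope * x; }
    public double slope() { return slope; }
    public double intercept() { return intercept; }

    public static void main(String[] args) {
        SimpleOls model = fit(new double[]{1, 2, 3, 4}, new double[]{1.1, 1.9, 3.0, 4.1});
        System.out.println("Slope: " + model.slope());
        System.out.println("Prediction for 5: " + model.predict(5));
    }
}
```

On this toy data the fitted slope is about 1.01 and the prediction for x = 5 is about 5.05. A library earns its keep once there are several features, missing values, or a need for standard errors and significance tests.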
Example: Linear Regression

    import smile.data.DataFrame;
    import smile.data.formula.Formula;
    import smile.regression.LinearModel;
    import smile.regression.OLS;

    // Assumes the Smile 2.x API, where OLS is fit from a DataFrame and a Formula.
    // Each row is {x, y}; the column names are arbitrary.
    DataFrame data = DataFrame.of(
        new double[][]{{1, 1.1}, {2, 1.9}, {3, 3.0}, {4, 4.1}}, "x", "y");
    LinearModel ols = OLS.fit(Formula.lhs("y"), data); // regress y on the remaining column
    System.out.println("Prediction for 5: " + ols.predict(new double[]{5}));

Pros:

- Fast and lightweight
- Many built-in algorithms
- Great for statistical analysis

Cons:

- Less focus on deep learning

Practical Use Cases of Java ML Libraries

DeepLearning4J (DL4J) - Deep Learning in Enterprise

Use cases:

- Image recognition / computer vision
- NLP / Chatbots
- Time-series predictions

Mini Case Study: Retail Shelf Monitoring

A retail company wants to automatically detect empty shelves using store cameras. A convolutional neural network (CNN) built with DL4J can process the camera images to identify missing products.
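The core operation inside a CNN layer is small enough to sketch by hand: a kernel slides over the image and produces a high response wherever the local pattern matches. The plain-Java example below applies a 2x2 vertical-edge kernel to a made-up 4x4 grayscale image whose left half is dark and right half is bright. The class `ShelfConvolution`, the image, and the kernel are all hypothetical; a real DL4J network would learn many such kernels from labeled camera frames instead of using a hand-written one.

```java
import java.util.Arrays;

public class ShelfConvolution {

    /**
     * "Valid" 2D cross-correlation (what deep learning frameworks call convolution):
     * slides the kernel over the image without padding and sums the element-wise products.
     */
    public static double[][] convolve(double[][] image, double[][] kernel) {
        int outRows = image.length - kernel.length + 1;
        int outCols = image[0].length - kernel[0].length + 1;
        double[][] out = new double[outRows][outCols];
        for (int r = 0; r < outRows; r++) {
            for (int c = 0; c < outCols; c++) {
                for (int kr = 0; kr < kernel.length; kr++) {
                    for (int kc = 0; kc < kernel[0].length; kc++) {
                        out[r][c] += image[r + kr][c + kc] * kernel[kr][kc];
                    }
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Toy image: 0 = dark pixel, 1 = bright pixel.
        double[][] image = {
            {0, 0, 1, 1},
            {0, 0, 1, 1},
            {0, 0, 1, 1},
            {0, 0, 1, 1}
        };
        // Responds to a dark-to-bright transition from left to right.
        double[][] verticalEdge = {
            {-1, 1},
            {-1, 1}
        };
        for (double[] row : convolve(image, verticalEdge)) {
            System.out.println(Arrays.toString(row));
        }
        // prints [0.0, 2.0, 0.0] three times: the edge sits in the middle column
    }
}
```

The output is non-zero only where the kernel straddles the boundary between the dark and bright regions, which is the kind of local feature a trained network combines across layers to decide whether a shelf section is empty.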
Tribuo - Classical Machine Learning in Java

Use cases:

- Text classification
- Predictive analytics
- Anomaly detection

Mini Case Study: Fraud Detection in Finance

    import org.tribuo.Model;
    import org.tribuo.MutableDataset;
    import org.tribuo.Prediction;
    import org.tribuo.Trainer;
    import org.tribuo.classification.Label;
    import org.tribuo.classification.LabelFactory;
    import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;
    import org.tribuo.impl.ArrayExample;
    import org.tribuo.provenance.SimpleDataSourceProvenance;

    LabelFactory factory = new LabelFactory();
    MutableDataset<Label> dataset = new MutableDataset<>(
        new SimpleDataSourceProvenance("toy transactions", factory), factory);
    // Features must be numeric, so the country is one-hot encoded as "country=XX" -> 1.0.
    dataset.add(new ArrayExample<>(new Label("fraud"),
        new String[]{"amount", "country=NG"}, new double[]{1000.0, 1.0}));
    dataset.add(new ArrayExample<>(new Label("legit"),
        new String[]{"amount", "country=US"}, new double[]{50.0, 1.0}));

    Trainer<Label> trainer = new LogisticRegressionTrainer();
    Model<Label> model = trainer.train(dataset);

    // The label of a new transaction is unknown at prediction time.
    ArrayExample<Label> newTransaction = new ArrayExample<>(factory.getUnknownOutput(),
        new String[]{"amount", "country=NG"}, new double[]{500.0, 1.0});
    Prediction<Label> prediction = model.predict(newTransaction);
    System.out.println("Predicted: " + prediction.getOutput().getLabel());

Smile - Fast Analytics and Prototyping

Use cases:

- Statistical analysis / regression
- Clustering / segmentation
- Recommendation systems

Mini Case Study: Customer Segmentation

    import java.util.Arrays;
    import smile.clustering.KMeans;

    double[][] data = {
        {5.2, 10},
        {6.5, 12},
        {1.0, 2},
        {1.2, 3},
        {7.0, 11}
    };
    KMeans kmeans = KMeans.fit(data, 2); // 2 clusters
    int[] labels = kmeans.y;             // cluster label per row (public field, assuming Smile 2.x)
    System.out.println("Cluster assignments: " + Arrays.toString(labels));

Library Comparison

| Library | Best for | Example Tasks | Notes |
|---------|----------|---------------|-------|
| DL4J | Deep learning, GPU tasks | Image recognition, NLP, time-series | Steep learning curve, enterprise-ready |
| Tribuo | Classical ML | Classification, regression, anomaly detection | Easy to integrate in microservices |
| Smile | Analytics & prototyping | Clustering, regression, statistics | Lightweight, fast, less focus on deep learning |

Conclusion

Java’s ML ecosystem is robust and evolving rapidly:

- DL4J → heavy deep learning applications.
- Tribuo → classical ML in production-ready Java apps.
- Smile → lightweight prototyping, clustering, regression, and statistics.

With these libraries, Java developers can build modern machine learning applications without leaving the JVM ecosystem, whether it’s enterprise AI, fintech analytics, or user behavior modeling.