Учебный проект по классификации изображений на Java.
Проект начался с реализации MultiPerceptron для задания Princeton COS126 – Image Classification, а затем был расширен более современными моделями:
- MultiPerceptron
- Logistic Regression
- MLP (Multi-Layer Perceptron)
- CNN (Convolutional Neural Network)
В основной части проекта реализованы:
Perceptron.java— бинарный perceptronMultiPerceptron.java— многоклассовая классификация (One-vs-All)LogisticRegression.java— многоклассовая логистическая регрессияMLP.java— простая нейронная сетьImageClassifier.java— единый pipeline для обучения и тестирования
Поддерживаемые режимы запуска:
perceptronlogregmlp
Для сверточной нейронной сети добавлен отдельный Maven-проект:
cnn-java/pom.xmlcnn-java/src/main/java/CnnDigits.java
CNN вынесена отдельно, потому что использует Maven + DL4J, а не ручную компиляцию через javac -cp ....
| Model | Description |
|---|---|
| MultiPerceptron | линейная модель One-vs-All |
| Logistic Regression | вероятностная линейная модель |
| MLP | простая полносвязная нейронная сеть |
| CNN | сверточная нейронная сеть для изображений |
PerceptronClassifier
│
├─ src
│ ├─ Perceptron.java
│ ├─ MultiPerceptron.java
│ ├─ LogisticRegression.java
│ ├─ MLP.java
│ └─ ImageClassifier.java
│
├─ lib
│ └─ stdlib.jar
│
├─ out
│ └─ compiled classes
│
├─ datasets
│ ├─ digits
│ │ ├─ digits.jar
│ │ ├─ training.zip
│ │ ├─ testing.zip
│ │ ├─ digits-training5.txt
│ │ ├─ digits-training10.txt
│ │ ├─ digits-training20.txt
│ │ ├─ digits-training30.txt
│ │ ├─ digits-training40.txt
│ │ ├─ digits-training50.txt
│ │ ├─ digits-training100.txt
│ │ ├─ digits-training6K.txt
│ │ ├─ digits-training60K.txt
│ │ ├─ digits-testing3.txt
│ │ ├─ digits-testing10.txt
│ │ ├─ digits-testing20.txt
│ │ ├─ digits-testing30.txt
│ │ ├─ digits-testing40.txt
│ │ ├─ digits-testing50.txt
│ │ ├─ digits-testing100.txt
│ │ ├─ digits-testing1K.txt
│ │ └─ digits-testing10K.txt
│ │
│ ├─ animals
│ ├─ fashion
│ ├─ Kuzushiji
│ ├─ music
│ └─ fruit
│
└─ cnn-java
├─ pom.xml
└─ src
└─ main
└─ java
└─ CnnDigits.java
В основной части проекта используется Multiclass Perceptron (One-vs-All).
Идея алгоритма:
- Для каждого класса создаётся отдельный perceptron
- Каждый perceptron обучается распознавать свой класс
- При классификации выбирается perceptron с максимальной оценкой
Обучение выполняется по правилу обновления весов:
если prediction ≠ label:
weights[label] += x
weights[predicted] -= x
Где:
x— вектор признаков изображенияlabel— правильный классpredicted— предсказанный класс
Таким образом модель постепенно учится различать классы изображений.
flowchart LR
A[Image 28x28] --> B[Feature Extraction]
B --> C[Vector 784]
C --> D[MultiPerceptron]
D --> P0[Perceptron class 0]
D --> P1[Perceptron class 1]
D --> P2[Perceptron class 2]
D --> P9[Perceptron class 9]
P0 --> S0[Score]
P1 --> S1
P2 --> S2
P9 --> S9
S0 --> M[Max Score]
S1 --> M
S2 --> M
S9 --> M
M --> O[Predicted class]
1️⃣ изображение 28×28 2️⃣ преобразуется в 784-мерный вектор признаков 3️⃣ каждый perceptron обучается распознавать один класс 4️⃣ выбирается perceptron с максимальной оценкой
Это стратегия One-vs-All.
Класс Perceptron реализует бинарный perceptron, который работает с двумя метками:
+1
-1
- хранит массив весов
- вычисляет взвешенную сумму
- предсказывает класс
- обновляет веса при ошибке
Основные методы:
weightedSum()
predict()
train()
Если perceptron ошибается, веса обновляются:
weights[i] += label * x[i]
Этот класс является базовым строительным блоком модели.
Класс MultiPerceptron реализует многоклассовую классификацию.
Используется стратегия:
One-vs-All
То есть:
- для каждого класса создаётся отдельный perceptron
- каждый perceptron пытается распознать свой класс
При ошибке обновляются два perceptron:
perceptrons[predicted] -= x
perceptrons[label] += x
Основные методы:
predictMulti()
trainMulti()
numberOfClasses()
numberOfInputs()
Этот класс реализует многоклассовую логистическую регрессию.
Особенности:
- используется функция softmax
- модель предсказывает вероятности классов
- веса обучаются методом градиентного спуска
Логистическая регрессия лучше perceptron, потому что:
- оптимизирует функцию потерь
- работает со стохастическим градиентом
Класс MLP реализует простую нейронную сеть.
Структура сети:
Input layer (784)
Hidden layer
Output layer (10)
Используются:
- ReLU или sigmoid в скрытом слое
- softmax на выходе
Обучение происходит через:
backpropagation
Это главный класс проекта.
Он объединяет:
- загрузку данных
- извлечение признаков
- обучение модели
- тестирование
1️⃣ читает training-файл
2️⃣ загружает изображения
3️⃣ извлекает признаки
4️⃣ обучает модель
5️⃣ тестирует модель
6️⃣ вычисляет test error rate
Архитектура проекта выглядит так:
ImageClassifier
│
▼
extractFeatures()
│
▼
vector (784 features)
│
▼
Model
│ │ │
▼ ▼ ▼
Perceptron LogisticRegression MLP
│
▼
prediction
1️⃣ ImageClassifier читает изображение
2️⃣ преобразует его в вектор признаков
3️⃣ передаёт его выбранной модели
Модель:
- вычисляет оценки классов
- выбирает максимальную
После тестирования программа выводит:
test error rate
В основном проекте изображение преобразуется в одномерный вектор признаков длины:
width * height
Для изображения 28×28 это:
784 признака
В ImageClassifier.java используется яркость пикселя, нормализованная в диапазон [0, 1].
Для датасета digits в проекте используются два способа хранения изображений.
Файлы:
digits-training6K.txtdigits-training60K.txtdigits-testing1K.txtdigits-testing10K.txt
используют пути вида:
jar:file:digits.jar!/training/7/4545.png 7
jar:file:digits.jar!/training/5/49785.png 5
Это означает, что изображения читаются напрямую из архива:
datasets/digits/digits.jar
Для таких запусков распаковка не нужна.
Файлы вроде:
digits-training5.txtdigits-training10.txtdigits-training20.txtdigits-training30.txtdigits-training40.txtdigits-training50.txtdigits-training100.txtdigits-testing3.txtdigits-testing10.txtdigits-testing20.txtdigits-testing30.txtdigits-testing40.txtdigits-testing50.txtdigits-testing100.txt
используют обычные относительные пути вида:
digits/training/1/99.png 1
digits/training/9/19.png 9
digits/training/0/69.png 0
Для таких файлов нужна распакованная папка с PNG.
Файлы:
datasets/digits/training.zipdatasets/digits/testing.zip
содержат распаковываемые PNG-изображения, которые нужны:
- для маленьких тестовых запусков
- для локального просмотра изображений
- для CNN, если сеть читает изображения из папок
training/0..9иtesting/0..9
То есть:
- для
digits.jarраспаковка не нужна - для CNN распаковка нужна
Нужно распаковать оба архива:
training.ziptesting.zip
в папку:
datasets/digits/
После правильной распаковки структура должна получиться такой:
datasets/digits
│
├─ digits.jar
├─ training.zip
├─ testing.zip
│
├─ training/
│ ├─ 0/
│ ├─ 1/
│ ├─ 2/
│ ├─ 3/
│ ├─ 4/
│ ├─ 5/
│ ├─ 6/
│ ├─ 7/
│ ├─ 8/
│ └─ 9/
│
└─ testing/
├─ 0/
├─ 1/
├─ 2/
├─ 3/
├─ 4/
├─ 5/
├─ 6/
├─ 7/
├─ 8/
└─ 9/
Из папки datasets/digits:
Expand-Archive training.zip
Expand-Archive testing.zipunzip training.zip
unzip testing.zipДля CnnDigits.java ожидается структура:
training/
0/
1/
...
9/
testing/
0/
1/
...
9/
Если после распаковки архив создаёт лишнюю вложенную папку, например:
datasets/digits/digits/training
datasets/digits/digits/testing
то путь в CnnDigits.java нужно либо:
- исправить,
- либо переместить папки на уровень выше.
В локальном запуске CNN использовалось:
training: 60000 imagestesting: 10062 images
Команды подсчёта на Windows:
dir /s /b "C:\ … your files … \*.png" | find /c /v ""
dir /s /b "C:\ … your files … \digits\testing\*.png" | find /c /v ""Принцип работы:
dir /s /bвыводит все.pngфайлы по одному пути на строкуfind /c /v ""считает количество непустых строк- одна строка = один файл
| Model | Test Error Rate | Accuracy |
|---|---|---|
| MultiPerceptron | 0.1293 | 87.07% |
| Logistic Regression | 0.1039 | 89.61% |
| MLP | 0.0313 | 96.87% |
| CNN | 0.0167 | 98.33% |
accuracy = 1 - error_rate
Пример:
1 - 0.0167 = 0.9833 = 98.33%
Эксперименты на digits показывают, что качество модели растёт по мере усложнения архитектуры:
MultiPerceptron → Logistic Regression → MLP → CNN
Причина:
- Perceptron и Logistic Regression — более простые модели
- MLP умеет находить нелинейные зависимости
- CNN лучше всего подходит для изображений, потому что учитывает пространственную структуру, локальные признаки, контуры и формы
| Dataset | Training Size | Test Size | Error Rate |
|---|---|---|---|
| digits | 60 000 | 10 000 | 0.1293 |
| fashion | 60 000 | 10 000 | 0.2204 |
| Kuzushiji | 60 000 | 10 000 | 0.4587 |
| animals | 60 000 | 12 000 | 0.7328 |
| music | 50 000 | 10 000 | 0.5479 |
| fruit | 30 000 | 6 000 | 0.1361 |
Из корня проекта:
javac -cp ".;lib\stdlib.jar" -d out src\*.javaПерейти в папку датасета:
cd /d "C:\ … your files … \datasets\digits"java -cp "..\..\out;..\..\lib\stdlib.jar" ImageClassifier digits-training60K.txt digits-testing10K.txtили явно:
java -cp "..\..\out;..\..\lib\stdlib.jar" ImageClassifier digits-training60K.txt digits-testing10K.txt perceptronjava -cp "..\..\out;..\..\lib\stdlib.jar" ImageClassifier digits-training60K.txt digits-testing10K.txt logregjava -cp "..\..\out;..\..\lib\stdlib.jar" ImageClassifier digits-training60K.txt digits-testing10K.txt mlpCNN находится в папке:
cnn-java/
Используемый стек:
- Maven
- Deeplearning4j
- ND4J
- DataVec
cnn-java
├─ pom.xml
└─ src
└─ main
└─ java
└─ CnnDigits.java
Из корня проекта:
cd /d "C: … your files … \PerceptronClassifier"
mkdir cnn-java
mkdir cnn-java\src
mkdir cnn-java\src\main
mkdir cnn-java\src\main\javaПосле этого нужно создать файлы:
cnn-java\pom.xmlcnn-java\src\main\java\CnnDigits.java
Официальный сайт:
https://maven.apache.org/download.cgi
Нужно скачать Binary zip archive, например:
apache-maven-3.9.13-bin.zip
Например сюда:
D:\Tools\apache-maven-3.9.13
Важно, чтобы существовал файл:
D:\Tools\apache-maven-3.9.13\bin\mvn.cmd
"D:\Tools\apache-maven-3.9.13\bin\mvn.cmd" -vПример:
Apache Maven 3.9.13
Java version: 21
Из папки cnn-java:
cd /d "C:\ … your files … \PerceptronClassifier\cnn-java"
"C:\Tools\apache-maven-3.9.13\bin\mvn.cmd" compile
"C:\Tools\apache-maven-3.9.13\bin\mvn.cmd" exec:javaКоманда:
"D:\Tools\apache-maven-3.9.13\bin\mvn.cmd" compileделает следующее:
- читает
pom.xml - скачивает нужные библиотеки
- компилирует
src\main\java\CnnDigits.java
Команда:
"D:\Tools\apache-maven-3.9.13\bin\mvn.cmd" exec:javaделает следующее:
- запускает главный Java-класс, указанный в
pom.xml - обучает CNN
- выводит метрики качества
После 2 эпох CNN показала:
- Accuracy: 98.33%
- Test Error Rate: 0.0167
| Dataset | Source |
|---|---|
| digits | MNIST |
| fashion | Fashion-MNIST |
| Kuzushiji | Kuzushiji-MNIST |
| animals | Google QuickDraw |
| music | Google QuickDraw |
| fruit | Google QuickDraw |
Amanzhol
Учебный проект по Java и Machine Learning.
Реализованы:
- Perceptron
- Logistic Regression
- MLP
- CNN
# ⭐ Если проект оказался полезным
Если этот проект оказался полезным или интересным,
можно поставить ⭐ репозиторию на GitHub.