HelloGitHub 推出的 《解说开源项目》 系列。这一期是由亚马逊工程师: Keerthan Vasist ,为我们解说 DJL(完全由 Java 构建的深度学习平台)系列的第 4 篇。
一、前言很长时间以来,Java 都是一个很受企业欢迎的编程言语。得益于丰厚的生态以及完善维护的包和框架,Java 拥有着庞大的开发者社区。虽然深度学习运用的不断演进和落地,提供应 Java 开发者的框架和库却十分充足。现今主要盛行的深度学习模型都是用 Python 编译和训练的。关于 Java 开发者而言,假设要进军深度学习界,就需求重新学习并接受一门新的编程言语同时还要学习深度学习的复杂知识。这使得大部分 Java 开发者学习和转型深度学习开发变得困难重重。
为了增加 Java 开发者学习深度学习的成本,AWS 构建了 Deep Java Library (DJL),一个为 Java 开发者定制的开源深度学习框架。它为 Java 开发者对接主流深度学习框架提供了一个桥梁。
在这篇文章中,我们会尝试用 DJL 构建一个深度学习模型并用它训练 MNIST 手写数字辨认义务。
二、什么是深度学习?在我们正式末尾之前,我们先来了解一下机器学习和深度学习的基本概念。
机器学习是一个经过应用统计学知识,将数据输入到计算机中停止训练并完成特定目的义务的进程。这种归结学习的办法可以让计算机学习一些特征并停止一系列复杂的义务,比如辨认照片中的物体。由于需求写复杂的逻辑以及测量标准,这些义务在传统计算迷信范围中很难完成。
深度学习是机器学习的一个分支,主要侧重于关于人工神经网络的开发。人工神经网络是经过研讨人脑如何学习和完成目的的进程中归结而得出一套计算逻辑。它经过模拟部分人脑神经间信息传递的进程,从而完成各类复杂的义务。深度学习中的“深度”来源于我们会在人工神经网络中编织构建出许多层(layer)从而进一步对数据信息停止更深层的传导。深度学习技术运用范围十分普遍,如今被用来做目的检测、举措辨认、机器翻译、语意剖析等各类理想运用中。
三、训练 MNIST 手写数字辨认 3.1 项目配置你可以用如下的 gradle 配置来引入依赖项。在这个案例中,我们用 DJL 的 api 包 (中心 DJL 组件) 和 basicdataset 包 (DJL 数据集) 来构建神经网络和数据集。这个案例中我们运用了 MXNet 作为深度学习引擎,所以我们会引入 mxnet-engine 和 mxnet-native-auto 两个包。这个案例也可以运转在 PyTorch 引擎下,只需求交流成对应的软件包即可。
plugins {
id 'java'
}
repositories {
jcenter()
}
dependencies {
implementation platform("ai.djl:bom:0.8.0")
implementation "ai.djl:api"
implementation "ai.djl:basicdataset"
// MXNet
runtimeOnly "ai.djl.mxnet:mxnet-engine"
runtimeOnly "ai.djl.mxnet:mxnet-native-auto"
}
3.2 NDArray 和 NDManagerNDArray 是 DJL 存储数据结构和数学运算的基本结构。一个 NDArray 表达了一个定长的多维数组。NDArray 的运用办法相似于 Python 中的 numpy.ndarray 。
NDManager 是 NDArray 的老板。它担任管理 NDArray 的产生和回收进程,这样可以协助我们更好的对 Java 内存停止优化。每一个 NDArray 都会是由一个 NDManager 发明出来,同时它们会在 NDManager 封锁时一同封锁。NDManager 和 NDArray 都是由 Java 的 AutoClosable 构建,这样可以确保在运转完毕时及时停止回收。想了解更多关于它们的用法和实际,请参阅我们前一期文章:
DJL 之 Java 玩转多维数组,就像 NumPy 一样
Model在 DJL 中,训练和推理都是从 Model class 末尾构建的。我们在这里主要讲训练进程中的构建办法。下面我们为 Model 创立一个新的目的。由于 Model 也是承袭了 AutoClosable 结构体,我们会用一个 try block 完成:
try (Model model = Model.newInstance()) {
...
// 主体训练代码
...
}
预备数据MNIST(Modified National Institute of Standards and Technology)数据库包含少量手写数字的图,通常被用来训练图像处置系统。DJL 曾经将 MNIST 的数据集收录到了 basicdataset 数据集里,每个 MNIST 的图的大小是 28 x 28 。假设你有本人的数据集,你也可以经过 DJL 数据集导入教程来导入数据集到你的训练义务中。
数据集导入教程: #how-to-create-your-own-dataset
int batchSize = 32; // 批大小
Mnist trainingDataset = Mnist.builder()
.optUsage(Usage.TRAIN) // 训练集
.setSampling(batchSize, true)
.build();
Mnist validationDataset = Mnist.builder()
.optUsage(Usage.TEST) // 验证集
.setSampling(batchSize, true)
.build();
(责任编辑:admin)