我用 GO 语言封装了一个机器学习框架,并实现了一个小型的 GPT 模型来对对联
libgotorch
首先利用 libtorch 库封装了一个 libgotorch 库,已支持最新的 libtorch2.0.1
问题一:cgo 中返回的 tensor 对象在栈上,直接使用可能会有内存安全问题
我做了一层简单的封装来使其创建到堆上,但其引发的问题是需要手动管理内存,因此我编写了 mmgr 包在每一个 tensor 对象创建的时候自动加入 mmgr 的 storage 当中,最后在每一轮训练完毕后通过 GC 方法释放堆上的 tensor 对象
问题二:windows 下的 libtorch 库通过 msvc 编译,提供的是 C++ 接口,无法在 mingw 中无法正常链接
解决方案是通过在封装一个动态链接库并暴露 C 语言接口,在 mingw 中即可正常链接
通过解决以上两个问题,已可以在 go 语言中使用 libtorch 库并实现自己的模型了
对对联
下面进入正题,我在 tnn 库中实现了一个小型的 GPT 模型来实现对对联:couplet,下面让我们来看一下最终效果
注意:该模型仅训练了开源数据集 couplet-dataset 中的前 1 万个样本
模型的参数结构如下:
整个模型共有 751 万个参数,模型包含 2 个 transformer 模块,由于在训练时只使用了 8 个 float32 来对每一个字进行表征,因此 attention 层的参数量较少,其他参数配置如下:
最后让我们来看看模型的泛化能力如何
效果不是很理想,可能还是跟训练的样本数量太少有关
另外还有一些示例可在 example 目录下找到,如使用 RNN 来学习如何画 sin 曲线等
最后是项目地址: