libgotorch

首先利用 libtorch 库封装了一个 libgotorch 库,已支持最新的 libtorch2.0.1

问题一:cgo 中返回的 tensor 对象在栈上,直接使用可能会有内存安全问题

我做了一层简单的封装来使其创建到堆上,但其引发的问题是需要手动管理内存,因此我编写了 mmgr 包在每一个 tensor 对象创建的时候自动加入 mmgr 的 storage 当中,最后在每一轮训练完毕后通过 GC 方法释放堆上的 tensor 对象

问题二:windows 下的 libtorch 库通过 msvc 编译,提供的是 C++ 接口,无法在 mingw 中无法正常链接

解决方案是通过在封装一个动态链接库并暴露 C 语言接口,在 mingw 中即可正常链接

通过解决以上两个问题,已可以在 go 语言中使用 libtorch 库并实现自己的模型了

对对联

下面进入正题,我在 tnn 库中实现了一个小型的 GPT 模型来实现对对联:couplet,下面让我们来看一下最终效果

$ go run main.go evaluate --model model7M 晚风摇树树还挺
load embedding...
model loaded
inputs: [472 3 462 148 148 342 1516]
map[4.278747:[ 醉] 5.084207:[ 润] 8.868446:[ 晨 ]]
map[3.8447263:[ 花] 4.750472:[ 润] 8.635651:[ 露 ]]
map[5.46043:[ 花] 6.7003703:[ 露] 10.768249:[ 润 ]]
map[4.3850584:[ 露] 4.875666:[ 润] 9.896332:[ 花 ]]
map[3.6241615:[ 红] 5.611262:[ 润] 10.782802:[ 花 ]]
map[4.3855276:[ 花] 5.48069:[ 红] 9.480111:[ 更 ]]
map[3.7904112:[ 心] 4.269902:[ 花] 10.3220415:[ 红 ]]
晨露润花花更红

$ go run main.go evaluate --model model7M 投石向天跟命斗
load embedding...
model loaded
inputs: [1233 190 383 11 2623 620 490]
map[5.7068815:[ 门] 5.7826476:[ 问] 9.79136:[ 闭 ]]
map[3.0136497:[ 问] 3.1092193:[ 人] 8.903796:[ 门 ]]
map[3.021591:[ 还] 3.448888:[ 歌] 8.96453:[ 问 ]]
map[4.9368696:[ 地] 5.7390223:[ 时] 9.438878:[ 卷 ]]
map[3.5542138:[ 话] 3.858942:[ 时] 8.253393:[ 与 ]]
map[3.025545:[ 与] 3.2461479:[ 卷] 9.06726:[ 时 ]]
map[4.250452:[ 时] 4.712057:[ 舟] 10.401218:[ 争 ]]
闭门问卷与时争
Plain text

注意:该模型仅训练了开源数据集 couplet-dataset 中的前 1 万个样本

模型的参数结构如下:

+------------------------+---------+
|          NAME          |  COUNT  |
+------------------------+---------+
| transformer0_attention |    1872 |
| transformer0_dense     | 1256640 |
| transformer0_output    | 1254960 |
| transformer1_attention |    1872 |
| transformer1_dense     | 1256640 |
| transformer1_output    | 1254960 |
| output                 | 2488596 |
| total                  | 7515540 |
+------------------------+---------+

train 200, cost=2h15m7.877395694s, loss=3.665343e-02
Plain text

整个模型共有 751 万个参数,模型包含 2 个 transformer 模块,由于在训练时只使用了 8 个 float32 来对每一个字进行表征,因此 attention 层的参数量较少,其他参数配置如下:

const embeddingDim = 8 // 8 个 float32 表示一个字向量
const paddingSize = 70 // 最长为 34*2 ,因此 padding 长度必须大于 68
const heads = 4
const batchSize = 128
const epoch = 200
const lr = 0.001
const transformerSize = 2
Plain text

最后让我们来看看模型的泛化能力如何

$ go run main.go evaluate --model model7M 我是谁
load embedding...
model loaded
inputs: [85 62 191]
map[4.3809786:[ 雨] 4.9436274:[ 染] 7.105626:[ 绿 ]]
map[3.8163047:[ 水] 4.013789:[ 东] 4.088595:[ 得 ]]
map[4.872726:[ 唱] 5.4107614:[ 兰] 6.3983927:[ 发 ]]
绿得发

$ go run main.go evaluate --model ./model7M 我在哪
load embedding...
model loaded
inputs: [85 99 1151]
map[1.480957:[ 思] 2.002811:[ 得] 4.0260763:[ 寻 ]]
map[3.4100764:[ 女] 3.868993:[ 对] 4.448501:[ 得 ]]
map[2.2672489:[ 年] 2.3772364:[ 历] 4.946753:[ 谁 ]]
寻得谁
Plain text

效果不是很理想,可能还是跟训练的样本数量太少有关

另外还有一些示例可在 example 目录下找到,如使用 RNN 来学习如何画 sin 曲线等

最后是项目地址: