PyTorch-Traps
1. 意外的条件判断
众所周知 在Python的条件判断语句下 只有下面几个值判断为 False
:
[]
/()
""
/''
0
/0.0
None
但是可能不清楚 tensor([0])
也能隐式转换为 False
|
|
2. 复杂的广播规则
PyTorch 的广播机制可以很方便实现矩阵操作 不过会有一些广播上的意外情况发生 具体实现自行查看PyTorch源码
试想象下面一些矩阵加法的结果
[2, 3, 4]+[2, 1, 4]
[2, 3, 4]+[2, 3]
[2, 3, 4]+[3, 4]
[2, 1, 3]+[2, 3]
|
|
答案揭晓为:
[2, 3, 4]+[2, 1, 4]=[2, 3, 4]
相当于广播在dim=1的地方[2, 3, 4]+[2, 3]=RuntimeError
The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 2[2, 3, 4]+[3, 4]=[2, 3, 4]
相当于增加一个dim=0 再在dim=0出广播[2, 1, 3]+[2, 3]=[2, 2, 3]
|
|
3. 保持单进程日志读写
在分布式训练时 注意日志读写与模型存储只使用单一进程来完成 一般使用 rank==0
的进程