PyTorch-Traps
1. 意外的条件判断
众所周知 在Python的条件判断语句下 只有下面几个值判断为 False :
[]/()""/''0/0.0None
但是可能不清楚 tensor([0]) 也能隐式转换为 False
| |
2. 复杂的广播规则
PyTorch 的广播机制可以很方便实现矩阵操作 不过会有一些广播上的意外情况发生 具体实现自行查看PyTorch源码
试想象下面一些矩阵加法的结果
[2, 3, 4]+[2, 1, 4][2, 3, 4]+[2, 3][2, 3, 4]+[3, 4][2, 1, 3]+[2, 3]
| |
答案揭晓为:
[2, 3, 4]+[2, 1, 4]=[2, 3, 4]相当于广播在dim=1的地方[2, 3, 4]+[2, 3]=RuntimeErrorThe size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 2[2, 3, 4]+[3, 4]=[2, 3, 4]相当于增加一个dim=0 再在dim=0出广播[2, 1, 3]+[2, 3]=[2, 2, 3]
| |
3. 保持单进程日志读写
在分布式训练时 注意日志读写与模型存储只使用单一进程来完成 一般使用 rank==0 的进程