Skip to content

TensorBoard

TensorBoard 是 TensorFlow 生态中的可视化工具,也可以与 PyTorch 无缝集成。它提供了训练过程中各类指标的实时可视化,是深度学习实验中最基础的监控工具。


基本使用

安装

pip install tensorboard

PyTorch 集成

PyTorch 通过 torch.utils.tensorboard.SummaryWriter 原生支持 TensorBoard:

from torch.utils.tensorboard import SummaryWriter

# 创建 writer,日志保存到 runs/ 目录
writer = SummaryWriter('runs/experiment_1')

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer)
    val_loss, val_acc = evaluate(model, val_loader)

    # 记录标量
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Loss/val', val_loss, epoch)
    writer.add_scalar('Accuracy/val', val_acc, epoch)
    writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

writer.close()

启动 TensorBoard

# 基本启动
tensorboard --logdir=runs

# 指定端口
tensorboard --logdir=runs --port=6007

# 比较多个实验
tensorboard --logdir=runs  # runs/ 下的每个子目录自动成为一个实验

然后在浏览器中访问 http://localhost:6006


常用功能

记录标量 (Scalars)

最常用的功能,用于追踪 loss、accuracy、learning rate 等随训练步数变化的指标。

writer.add_scalar('tag', scalar_value, global_step)

# 同时记录多个标量到同一图表
writer.add_scalars('Loss', {
    'train': train_loss,
    'val': val_loss
}, epoch)

记录图像 (Images)

# 单张图像
writer.add_image('sample', img_tensor, epoch)

# 图像网格(如一个 batch 的样本)
from torchvision.utils import make_grid
grid = make_grid(images[:16], nrow=4, normalize=True)
writer.add_image('batch_samples', grid, epoch)

记录直方图 (Histograms)

用于观察权重和梯度的分布变化,诊断梯度消失/爆炸问题:

for name, param in model.named_parameters():
    writer.add_histogram(f'weights/{name}', param, epoch)
    if param.grad is not None:
        writer.add_histogram(f'grads/{name}', param.grad, epoch)

记录模型图 (Graph)

可视化模型的计算图结构:

dummy_input = torch.randn(1, 3, 224, 224)
writer.add_graph(model, dummy_input)

记录超参数 (HParams)

记录超参数与最终指标的对应关系,方便对比不同配置:

writer.add_hparams(
    hparam_dict={'lr': 1e-3, 'batch_size': 32, 'optimizer': 'adam'},
    metric_dict={'final_loss': 0.15, 'final_acc': 0.95}
)

实用技巧

  • 实验命名规范:使用有意义的目录名,如 runs/resnet50_lr1e-3_bs64_aug
  • 定期 flush:调用 writer.flush() 确保数据写入磁盘
  • 远程访问:通过 SSH 端口转发在本地查看远程服务器的 TensorBoard:ssh -L 6006:localhost:6006 user@server
  • 与其他工具对比:TensorBoard 适合简单的实验追踪;对于团队协作和大规模实验管理,建议使用 Weights & Biases 或 MLflow

参考


评论 #