import os, PIL
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch
from torch import nn
import torchvision
from torchvision import transforms
import datetime
import wandb
from argparse import Namespace
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = Namespace(
    project_name='wandb_demo',
    batch_size=512,
    hidden_layer_width=64,
    dropout_p=0.1,
    lr=1e-4,
    optim_type='Adam',
    epochs=15,
    ckpt_path='checkpoint.pt',
)
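# A note on the Namespace trick (illustration only, values taken from the
# config above): hyperparameters are read with dot access, and config.__dict__
# yields a plain dict -- exactly what wandb.init(config=...) expects later.
#
#   config.lr        -> 0.0001
#   config.__dict__  -> {'project_name': 'wandb_demo', 'batch_size': 512, ...}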
def create_dataloaders(config):
    transform = transforms.Compose([transforms.ToTensor()])
    ds_train = torchvision.datasets.MNIST(root="./mnist/", train=True, download=True, transform=transform)
    ds_val = torchvision.datasets.MNIST(root="./mnist/", train=False, download=True, transform=transform)

    # Keep every 5th training sample so the demo trains quickly
    ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))
    dl_train = DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True,
                          num_workers=2, drop_last=True)
    dl_val = DataLoader(ds_val, batch_size=config.batch_size, shuffle=False,
                        num_workers=2, drop_last=True)
    return dl_train, dl_val
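# Hypothetical sanity check, not part of the original script: peek at one
# batch to confirm what the loaders yield. With batch_size=512 and
# drop_last=True, every training batch is a full [512, 1, 28, 28] tensor.
_dl_train, _dl_val = create_dataloaders(config)
_features, _labels = next(iter(_dl_train))
print(_features.shape)  # torch.Size([512, 1, 28, 28])
print(_labels.shape)    # torch.Size([512])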
def create_net(config):
    net = nn.Sequential()
    net.add_module("conv1", nn.Conv2d(in_channels=1, out_channels=config.hidden_layer_width, kernel_size=3))
    net.add_module("pool1", nn.MaxPool2d(kernel_size=2, stride=2))
    net.add_module("conv2", nn.Conv2d(in_channels=config.hidden_layer_width,
                                      out_channels=config.hidden_layer_width, kernel_size=5))
    net.add_module("pool2", nn.MaxPool2d(kernel_size=2, stride=2))
    net.add_module("dropout", nn.Dropout2d(p=config.dropout_p))
    net.add_module("adaptive_pool", nn.AdaptiveMaxPool2d((1, 1)))
    net.add_module("flatten", nn.Flatten())
    net.add_module("linear1", nn.Linear(config.hidden_layer_width, config.hidden_layer_width))
    net.add_module("relu", nn.ReLU())
    net.add_module("linear2", nn.Linear(config.hidden_layer_width, 10))
    net.to(device)
    return net
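# Hypothetical shape check, not in the original tutorial. Tracing a 28x28
# MNIST image through the stack: conv1 (3x3) -> 26x26, pool1 -> 13x13,
# conv2 (5x5) -> 9x9, pool2 -> 4x4, adaptive_pool -> 1x1, flatten ->
# hidden_layer_width features, linear2 -> 10 class logits.
_demo_net = create_net(config)
_dummy = torch.randn(4, 1, 28, 28, device=device)
print(_demo_net(_dummy).shape)  # expected: torch.Size([4, 10])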
def train_epoch(model, dl_train, optimizer):
    model.train()
    for step, batch in enumerate(dl_train):
        features, labels = batch
        features, labels = features.to(device), labels.to(device)

        preds = model(features)
        loss = nn.CrossEntropyLoss()(preds, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return model
def eval_epoch(model, dl_val):
    model.eval()
    accurate = 0
    num_elems = 0
    for batch in dl_val:
        features, labels = batch
        features, labels = features.to(device), labels.to(device)
        with torch.no_grad():
            preds = model(features)
        predictions = preds.argmax(dim=-1)
        accurate_preds = (predictions == labels)
        num_elems += accurate_preds.shape[0]
        accurate += accurate_preds.long().sum()

    val_acc = accurate.item() / num_elems
    return val_acc
def train(config=config):
    dl_train, dl_val = create_dataloaders(config)
    model = create_net(config)
    optimizer = getattr(torch.optim, config.optim_type)(params=model.parameters(), lr=config.lr)

    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    wandb.init(project=config.project_name, config=config.__dict__, name=nowtime, save_code=True)
    model.run_id = wandb.run.id
    model.best_metric = -1.0

    for epoch in range(1, config.epochs + 1):
        model = train_epoch(model, dl_train, optimizer)
        val_acc = eval_epoch(model, dl_val)

        # Save a checkpoint whenever the validation accuracy improves
        if val_acc > model.best_metric:
            model.best_metric = val_acc
            torch.save(model.state_dict(), config.ckpt_path)

        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print(f"epoch [{epoch}] @ {nowtime} --> val_acc = {100 * val_acc:.2f}%")
        wandb.log({'epoch': epoch, 'val_acc': val_acc, 'best_val_acc': model.best_metric})

    wandb.finish()
    return model
""" 三步走 1. wandb init 规定好 项目名称 超参数 run的这一下的名称 是否保存代码 2. wandb.log 定义好要记录的东西 来绘制折线图(一般来说 acc 和 最好的acc 是肯定要记录的 ) 3. wandb.finish 打完收工
"""
model = train(config)
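# Hypothetical follow-up, not in the original: rebuild the network and load
# the best checkpoint that train() saved, e.g. for inference.
best_model = create_net(config)
best_model.load_state_dict(torch.load(config.ckpt_path, map_location=device))
best_model.eval()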