import copy
import gc
import time
from collections import defaultdict

import numpy as np
import torch
import wandb
from tqdm import tqdm

# CONFIG, criterion, train_loader, valid_loader, and the wandb `run` object
# are assumed to be defined earlier in the notebook.


def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    model.train()

    dataset_size = 0
    running_loss = 0.0

    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)

        batch_size = images.size(0)

        outputs = model(images, labels)
        loss = criterion(outputs, labels)
        # Scale the loss so gradients accumulate correctly over
        # CONFIG['n_accumulate'] mini-batches.
        loss = loss / CONFIG['n_accumulate']
        loss.backward()

        if (step + 1) % CONFIG['n_accumulate'] == 0:
            optimizer.step()
            optimizer.zero_grad()

            if scheduler is not None:
                scheduler.step()

        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        bar.set_postfix(Epoch=epoch, Train_Loss=epoch_loss,
                        LR=optimizer.param_groups[0]['lr'])
    gc.collect()

    return epoch_loss
@torch.inference_mode()
def valid_one_epoch(model, optimizer, dataloader, device, epoch):
    model.eval()

    dataset_size = 0
    running_loss = 0.0

    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)

        batch_size = images.size(0)

        outputs = model(images, labels)
        loss = criterion(outputs, labels)

        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        # `optimizer` is passed in only so the current LR can be shown in the bar.
        bar.set_postfix(Epoch=epoch, Valid_Loss=epoch_loss,
                        LR=optimizer.param_groups[0]['lr'])
    gc.collect()

    return epoch_loss
def run_training(model, optimizer, scheduler, device, num_epochs):
    # Log gradients and parameters to Weights & Biases.
    wandb.watch(model, log_freq=100)

    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))

    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_epoch_loss = np.inf
    history = defaultdict(list)

    for epoch in range(1, num_epochs + 1):
        gc.collect()

        train_epoch_loss = train_one_epoch(model, optimizer, scheduler,
                                           dataloader=train_loader,
                                           device=CONFIG['device'], epoch=epoch)

        val_epoch_loss = valid_one_epoch(model, optimizer, valid_loader,
                                         device=CONFIG['device'], epoch=epoch)

        history['Train Loss'].append(train_epoch_loss)
        history['Valid Loss'].append(val_epoch_loss)

        # Log the epoch metrics to wandb.
        wandb.log({"Train Loss": train_epoch_loss})
        wandb.log({"Valid Loss": val_epoch_loss})

        # Checkpoint whenever the validation loss improves.
        if val_epoch_loss <= best_epoch_loss:
            print(f"Validation Loss Improved ({best_epoch_loss} ---> {val_epoch_loss})")
            best_epoch_loss = val_epoch_loss
            run.summary["Best Loss"] = best_epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = "Loss{:.4f}_epoch{:.0f}.bin".format(best_epoch_loss, epoch)
            torch.save(model.state_dict(), PATH)
            print("Model Saved")

        print()

    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Loss: {:.4f}".format(best_epoch_loss))

    # Restore the best weights before returning the model.
    model.load_state_dict(best_model_wts)

    return model, history
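The helpers above rely on a handful of globals defined earlier in the notebook: CONFIG, criterion, train_loader, valid_loader, and the run object returned by wandb.init. Below is a minimal, hypothetical sketch of how they might be wired together before calling run_training. The project name, hyperparameter values, build_model helper, and data loaders are placeholders, not the originals; the only requirements are that the model accepts (images, labels) and that the loaders yield dicts with 'image' and 'label' keys, as the loops expect.

# Hypothetical wiring of the globals the training helpers use.
CONFIG = {
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'n_accumulate': 1,            # gradient-accumulation steps
    'epochs': 5,
    'learning_rate': 1e-4,
}

criterion = torch.nn.CrossEntropyLoss()
run = wandb.init(project='my-project', config=CONFIG)     # placeholder project name

model = build_model().to(CONFIG['device'])                # build_model() defined elsewhere;
                                                          # its forward takes (images, labels)
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500)

# train_loader / valid_loader are assumed to be PyTorch DataLoaders created earlier,
# yielding dicts with 'image' and 'label' tensors.
model, history = run_training(model, optimizer, scheduler,
                              device=CONFIG['device'],
                              num_epochs=CONFIG['epochs'])
run.finish()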