def create_balanced_dataset(df):
    # Count the instances of "spam"
    num_spam = df[df["Label"] == "spam"].shape[0]
    # Randomly sample "ham" instances to match the number of "spam" instances
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
    # Combine the ham subset with all "spam" instances
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
    return balanced_df
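As a quick usage sketch (the DataFrame variable `df` and the 0/1 label mapping below are assumptions, not part of the snippet above), the balanced DataFrame can then have its string labels converted to the integer class IDs the classifier expects:

balanced_df = create_balanced_dataset(df)
print(balanced_df["Label"].value_counts())  # "ham" and "spam" should now have equal counts

# Map the string labels to integer class IDs (0 = ham, 1 = spam)
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})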
# Pre-tokenize texts
self.encoded_texts = [
    tokenizer.encode(text) for text in self.data["Text"]
]
if max_length is None:
    self.max_length = self._longest_encoded_length()
else:
    self.max_length = max_length
    # Truncate sequences if they are longer than max_length
    self.encoded_texts = [
        encoded_text[:self.max_length]
        for encoded_text in self.encoded_texts
    ]
# Pad sequences to the longest sequence
self.encoded_texts = [
    encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
    for encoded_text in self.encoded_texts
]
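The three snippets above are fragments of a PyTorch Dataset subclass. A minimal sketch of how they might fit together, assuming a SpamDataset class that reads a CSV file and pads with the GPT-2 <|endoftext|> token ID 50256:

import pandas as pd
import torch
from torch.utils.data import Dataset

class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)
        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]
        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences that exceed max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]
        # Pad every sequence to the same length
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        return (
            torch.tensor(self.encoded_texts[index], dtype=torch.long),
            torch.tensor(self.data.iloc[index]["Label"], dtype=torch.long),
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        return max(len(encoded_text) for encoded_text in self.encoded_texts)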
if num_batches is None:
    num_batches = len(data_loader)
else:
    num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
    if i < num_batches:
        input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
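These two fragments compute classification accuracy from the logits of the last output token. Pieced together, the full helper might look like this (the name calc_accuracy_loader is an assumption, chosen to mirror calc_loss_loader below):

def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)
            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]  # Logits of last output token
            predicted_labels = torch.argmax(logits, dim=-1)
            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples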
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss
def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches
        # in the data loader if num_batches exceeds that number
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
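A small sanity check one might run before any training, assuming train_loader, val_loader, model, and device are already set up (the loader names and the choice of 5 batches are assumptions):

with torch.no_grad():  # No gradients needed; we are only measuring the initial loss
    train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)
    val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)
print(f"Initial training loss: {train_loss:.3f}")
print(f"Initial validation loss: {val_loss:.3f}")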
4. Start Fine-Tuning
We only fine-tune a subset of the layers, so most of them need to be frozen first:
for param in model.parameters():
    param.requires_grad = False
for param in model.trf_blocks[-1].parameters():
    param.requires_grad = True
for param in model.final_norm.parameters():
    param.requires_grad = True
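A quick verification sketch (not part of the original code) to see how many parameters remain trainable after the freeze above:

# Count how many parameters will actually be updated during fine-tuning
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} of {total_params:,}")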
# Overall the same as `train_model_simple` in chapter 5
def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
                            num_epochs, eval_freq, eval_iter):
    # Initialize lists to track losses and examples seen
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1
    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
            global_step += 1
            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
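The loop above calls an evaluate_model helper that is not shown in this excerpt. A minimal sketch of it, built on calc_loss_loader, might look as follows; after each epoch the full training function also computes and prints training and validation accuracy, which is what produces the per-epoch accuracy lines in the log below.

def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()  # Disable dropout for a stable loss estimate
    with torch.no_grad():  # No gradients needed during evaluation
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()  # Restore training mode
    return train_loss, val_loss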
Ep 1 (Step 000000): Train loss 2.153, Val loss 2.392
Ep 1 (Step 000050): Train loss 0.617, Val loss 0.637
Ep 1 (Step 000100): Train loss 0.523, Val loss 0.557
Training accuracy: 70.00% | Validation accuracy: 72.50%
Ep 2 (Step 000150): Train loss 0.561, Val loss 0.489
Ep 2 (Step 000200): Train loss 0.419, Val loss 0.397
Ep 2 (Step 000250): Train loss 0.409, Val loss 0.353
Training accuracy: 82.50% | Validation accuracy: 85.00%
Ep 3 (Step 000300): Train loss 0.333, Val loss 0.320
Ep 3 (Step 000350): Train loss 0.340, Val loss 0.306
Training accuracy: 90.00% | Validation accuracy: 90.00%
Ep 4 (Step 000400): Train loss 0.136, Val loss 0.200
Ep 4 (Step 000450): Train loss 0.153, Val loss 0.132
Ep 4 (Step 000500): Train loss 0.222, Val loss 0.137
Training accuracy: 100.00% | Validation accuracy: 97.50%
Ep 5 (Step 000550): Train loss 0.207, Val loss 0.143
Ep 5 (Step 000600): Train loss 0.083, Val loss 0.074
Training accuracy: 100.00% | Validation accuracy: 97.50%
Training completed in 26.73 minutes.
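For context, a sketch of how a run like the one logged above might be launched. The optimizer settings and the assumption that train_classifier_simple returns the tracked losses, accuracies, and examples_seen are not confirmed by this excerpt; eval_freq=50 and num_epochs=5 are inferred from the log:

import time

start_time = time.time()
torch.manual_seed(123)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
num_epochs = 5

train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=50, eval_iter=5
)

print(f"Training completed in {(time.time() - start_time) / 60:.2f} minutes.")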
# Plot training and validation loss against epochs
ax1.plot(epochs_seen, train_values, label=f"Training {label}")
ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")
ax1.set_xlabel("Epochs")
ax1.set_ylabel(label.capitalize())
ax1.legend()
# Create a second x-axis for examples seen
ax2 = ax1.twiny()  # Second x-axis that shares the same y-axis
ax2.plot(examples_seen, train_values, alpha=0)  # Invisible plot for aligning ticks
ax2.set_xlabel("Examples seen")
fig.tight_layout()  # Adjust layout to make room
plt.savefig(f"{label}-plot.pdf")
plt.show()
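The plotting lines above form the body of a helper. Assuming they are wrapped in a function such as plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss") that begins with fig, ax1 = plt.subplots(figsize=(5, 3)), the loss curves might be plotted like this (the function name and variable names are assumptions):

# Build matching x-axes for the two scales (epochs and examples seen)
epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))
plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses, label="loss")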