Author: Joseph Catanzarite
Course: EN.705.603 Introduction to Generative AI, Spring 2026
Johns Hopkins University — Whiting School of Engineering
Whether augmentation helps is set by data scarcity, not filtering. At 50/class the baseline is near-saturated and every augmented condition degrades accuracy; at 25/class every augmented condition helps — up to +7.2 points (p<0.01), recovering ~43% of the accuracy lost to halving the real data. Holding count and ratio fixed, Stage 1 critic filtering shows no statistically detectable effect (all four isolation contrasts ≤ 0.55 points; three non-significant, the fourth failing Holm–Bonferroni). An earlier diagonal design — unfiltered at 50/class, filtered at 25/class — confounded the two factors and made filtering look like the cause; this factorial removes the confound. Stage 2 of the filter (SSIM/LPIPS perceptual screening) was not implemented here and remains planned future work.
Low-resource image classification is a central challenge in medical imaging, where labeled data is expensive and scarce. This study uses Chinese-MNIST as a controlled proxy to evaluate whether generative augmentation can improve classifier performance under data scarcity conditions analogous to those encountered in clinical imaging datasets.
Cell 1.1 — Environment setup. Mounts Google Drive at /content/drive, defines all project paths (data, checkpoints, generated images, outputs) under a single base directory so artifacts persist across Colab sessions, creates any missing directories, installs the Kaggle CLI (command-line interface), and verifies that PyTorch sees a CUDA (Compute Unified Device Architecture) GPU (Graphics Processing Unit). Everything downstream assumes these paths and the GPU exist, so this cell must run first in every session.
# ============================================================
# CELL 1.1 — Mount Drive & Install Dependencies
# ============================================================
from google.colab import drive
drive.mount('/content/drive')
import os
BASE_DIR = '/content/drive/MyDrive/Chinese_MNIST_Data_Augmentation'
DATA_DIR = os.path.join(BASE_DIR, 'data')
IMAGES_DIR = os.path.join(DATA_DIR, 'data', 'data')
CSV_PATH = os.path.join(DATA_DIR, 'chinese_mnist.csv')
CKPT_DIR = os.path.join(BASE_DIR, 'checkpoints')
GEN_DIR = os.path.join(BASE_DIR, 'generated')
OUT_DIR = os.path.join(BASE_DIR, 'outputs')
for d in [DATA_DIR, CKPT_DIR, GEN_DIR, OUT_DIR]:
    os.makedirs(d, exist_ok=True)
!pip install -q kaggle
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
Mounted at /content/drive
PyTorch: 2.11.0+cu128
CUDA: True
GPU: NVIDIA L4
Cell 1.2 — Dataset acquisition (idempotent). Downloads the Chinese-MNIST dataset (15,000 JPEG (Joint Photographic Experts Group) images plus an index CSV (comma-separated values) file) from Kaggle into Drive, but only if the CSV is not already present, so re-runs skip the download entirely. Two assertions act as a contract: the CSV must exist and exactly 15,000 images must be found, so the cell fails fast if the extraction layout is wrong.
# ============================================================
# CELL 1.2 — Download Dataset (skips if already in Drive)
# ============================================================
import os
# Kaggle API token — required only if the dataset is not already in Drive.
# Paste your own token below, or (preferred) store it in Colab Secrets
# (key icon in the left sidebar, name it KAGGLE_TOKEN) and use the two
# commented lines instead.
os.environ['KAGGLE_TOKEN'] = 'PASTE_YOUR_KAGGLE_TOKEN_HERE'
# from google.colab import userdata
# os.environ['KAGGLE_TOKEN'] = userdata.get('KAGGLE_TOKEN')
if os.path.exists(CSV_PATH):
    print('Dataset already in Drive — skipping download')
else:
    print('Downloading...')
    os.system(f'kaggle datasets download -d gpreda/chinese-mnist -p "{DATA_DIR}" --unzip')
    print('Done')
n_images = len([f for f in os.listdir(IMAGES_DIR) if f.endswith('.jpg')])
assert os.path.exists(CSV_PATH), 'CSV not found'
assert n_images == 15000, f'Expected 15000 images, found {n_images}'
print(f'✓ CSV found')
print(f'✓ Images: {n_images}')
Dataset already in Drive — skipping download
✓ CSV found
✓ Images: 15000
Cell 1.3 — Configuration, data pipeline, and splits. Declares the single CONFIG dictionary holding every hyperparameter (image size, scarcity levels, WGAN-GP and CNN training settings, seeds), a set_seed utility for reproducibility, and two dataset classes: ChineseMNIST (reads images from disk via the CSV index) and SyntheticDataset (wraps generated tensors with the same (image, int_label) interface so real and synthetic data can be concatenated). It then builds a stratified, seeded train/val/test split (50/50/100 images per class) and the corresponding DataLoaders. Because the split is seeded once, every experiment in the notebook sees identical val/test data.
# ============================================================
# CELL 1.3 — Global CONFIG, Imports, Dataset, DataLoaders
# ============================================================
import numpy as np
import pandas as pd
import random
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torchvision.transforms as transforms
# ---- CONFIG — all hyperparameters here ----
CONFIG = {
'img_size': 64,
'n_classes': 15,
'n_channels': 1,
'train_per_class': 50, # simulated data scarcity
'test_per_class': 100,
'val_per_class': 50,
'latent_dim': 128,
'wgan_epochs': 500,
'wgan_batch_size': 64,
'wgan_lr': 1e-4,
'n_critic': 5,
'lambda_gp': 10,
'wgan_b1': 0.0,
'wgan_b2': 0.9,
'cnn_epochs': 30,
'cnn_batch_size': 64,
'cnn_lr': 1e-3,
'dropout': 0.4,
'n_seeds': 5,
'n_generate_1to1': 50, # 1:1 synthetic-to-real ratio
'n_generate_4to1': 200, # 4:1 synthetic-to-real ratio
'base_seed': 42,
'images_dir': IMAGES_DIR,
'csv_path': CSV_PATH,
'ckpt_dir': CKPT_DIR,
'generated_dir': GEN_DIR,
'outputs_dir': OUT_DIR,
}
# ---- Utilities ----
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
set_seed(CONFIG['base_seed'])
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {DEVICE}')
# ---- Dataset ----
class ChineseMNIST(Dataset):
    def __init__(self, csv_path, images_dir, indices, transform=None):
        self.df = pd.read_csv(csv_path).iloc[indices].reset_index(drop=True)
        self.images_dir = images_dir
        self.transform = transform
        unique_vals = sorted(self.df['value'].unique())
        self.label_map = {v: i for i, v in enumerate(unique_vals)}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        fname = f"input_{row['suite_id']}_{row['sample_id']}_{row['code']}.jpg"
        img = Image.open(os.path.join(self.images_dir, fname)).convert('L')
        if self.transform:
            img = self.transform(img)
        return img, self.label_map[row['value']]

class SyntheticDataset(Dataset):
    def __init__(self, imgs_tensor, labels_tensor):
        self.imgs = imgs_tensor
        self.labels = labels_tensor

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        return self.imgs[idx], int(self.labels[idx])
# ---- Transforms ----
transform = transforms.Compose([
transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
# ---- Stratified split ----
def make_stratified_split(csv_path, n_train, n_val, n_test, seed):
    df = pd.read_csv(csv_path)
    rng = np.random.RandomState(seed)
    train_idx, val_idx, test_idx = [], [], []
    for val in sorted(df['value'].unique()):
        class_idx = df[df['value'] == val].index.tolist()
        rng.shuffle(class_idx)
        train_idx.extend(class_idx[:n_train])
        val_idx.extend(class_idx[n_train:n_train+n_val])
        test_idx.extend(class_idx[n_train+n_val:n_train+n_val+n_test])
    return train_idx, val_idx, test_idx
train_idx, val_idx, test_idx = make_stratified_split(
CSV_PATH, CONFIG['train_per_class'], CONFIG['val_per_class'],
CONFIG['test_per_class'], CONFIG['base_seed']
)
train_dataset = ChineseMNIST(CSV_PATH, IMAGES_DIR, train_idx, transform)
val_dataset = ChineseMNIST(CSV_PATH, IMAGES_DIR, val_idx, transform)
test_dataset = ChineseMNIST(CSV_PATH, IMAGES_DIR, test_idx, transform)
assert len(train_dataset) == CONFIG['train_per_class'] * CONFIG['n_classes']
assert len(test_dataset) == CONFIG['test_per_class'] * CONFIG['n_classes']
train_loader = DataLoader(train_dataset, batch_size=CONFIG['wgan_batch_size'],
shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['cnn_batch_size'],
shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=CONFIG['cnn_batch_size'],
shuffle=False, num_workers=2, pin_memory=True)
print(f'✓ Train: {len(train_dataset)} images ({CONFIG["train_per_class"]}/class)')
print(f'✓ Val: {len(val_dataset)} images')
print(f'✓ Test: {len(test_dataset)} images')
Device: cuda
✓ Train: 750 images (50/class)
✓ Val: 750 images
✓ Test: 1500 images
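The split logic can be checked in isolation. The sketch below is a dependency-free illustration, not the notebook's function: it replaces the CSV with a hypothetical balanced label list (`labels`, 15 classes × 1,000 samples, an assumption matching the 15,000-image total) and uses Python's `random` module instead of NumPy's `RandomState`, so the indices it draws differ from the real run. It only demonstrates the properties the experiments rely on: the three index sets are disjoint and class-balanced, and the same seed reproduces them exactly.

```python
import random

def stratified_split(labels, n_train, n_val, n_test, seed):
    """Shuffle each class once, then slice consecutive train/val/test blocks."""
    rng = random.Random(seed)
    train, val, test = [], [], []
    for cls in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(idx)
        train += idx[:n_train]
        val += idx[n_train:n_train + n_val]
        test += idx[n_train + n_val:n_train + n_val + n_test]
    return train, val, test

# Hypothetical stand-in for the CSV index: 15 classes x 1,000 samples each.
labels = [cls for cls in range(15) for _ in range(1000)]

train, val, test = stratified_split(labels, 50, 50, 100, seed=42)
again = stratified_split(labels, 50, 50, 100, seed=42)

print(len(train), len(val), len(test))                 # sizes of the three sets
print(set(train).isdisjoint(val),
      set(train).isdisjoint(test),
      set(val).isdisjoint(test))                       # no index is shared
print((train, val, test) == again)                     # same seed, same split
```

Because each class is sliced from one shuffled list, the train indices come out grouped in per-class blocks of 50, which is the layout Cell A later depends on when it takes the first 25 of each block.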
Cell 1.4 — Data sanity check (Checkpoint 1). Renders one real training image per class in a 3×5 grid and saves the figure to the outputs directory. This is a human-in-the-loop gate: it confirms the file-name decoding, grayscale conversion, resizing, and label mapping all behave before any compute is spent on training.
# ============================================================
# CELL 1.4 — CHECKPOINT 1: Visualize Real Samples
# ============================================================
df_meta = pd.read_csv(CONFIG['csv_path'])
class_names = [str(row['character']) for _, row in
df_meta.drop_duplicates('value').sort_values('value').iterrows()]
fig, axes = plt.subplots(3, 5, figsize=(12, 8))
fig.suptitle('CHECKPOINT 1: Real samples (one per class)', fontsize=13, fontweight='bold')
for i, ax in enumerate(axes.flat):
    img, label = train_dataset[i * CONFIG['train_per_class']]
    ax.imshow(img.squeeze().numpy() * 0.5 + 0.5, cmap='gray')
    ax.set_title(f'Class {label}', fontsize=9)
    ax.axis('off')
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, 'checkpoint1_real_samples.png'), dpi=100)
plt.show()
print('\n✅ CHECKPOINT 1 PASSED — Real images look correct')
✅ CHECKPOINT 1 PASSED — Real images look correct
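The `* 0.5 + 0.5` in the plotting call undoes the `Normalize([0.5], [0.5])` transform from Cell 1.3, which maps pixel values from [0, 1] to [-1, 1] (the same range the generator's Tanh produces). A minimal sketch of that round trip in plain Python, with hypothetical helper names `normalize` and `denormalize` that do not appear in the notebook:

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-pixel arithmetic of Normalize([0.5], [0.5]): (x - mean) / std."""
    return (x - mean) / std

def denormalize(z, mean=0.5, std=0.5):
    """Inverse mapping used before imshow: z * std + mean."""
    return z * std + mean

pixels = [0.0, 0.25, 0.5, 1.0]            # ToTensor() output range is [0, 1]
normalized = [normalize(p) for p in pixels]
restored = [denormalize(z) for z in normalized]

print(normalized)  # [-1.0, -0.5, 0.0, 1.0]
print(restored)    # [0.0, 0.25, 0.5, 1.0]
```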
Cell 2.1 — Downstream classifier architecture. Defines CNNClassifier, the model whose test accuracy is the study's outcome metric: three Conv→BatchNorm→ReLU→MaxPool blocks (64×64 → 8×8 spatial) followed by a dropout-regularized fully-connected head producing 15 logits. The identical architecture and hyperparameters are used in every experimental condition, so any accuracy difference is attributable to the training data, not the model. A shape assertion on a dummy batch verifies the forward pass.
# ============================================================
# CELL 2.1 — CNN Classifier Architecture
# ============================================================
class CNNClassifier(nn.Module):
    def __init__(self, n_classes=15, dropout=0.4):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 256), nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, n_classes)
        )

    def forward(self, x):
        return self.classifier(self.features(x))
_m = CNNClassifier().to(DEVICE)
_x = torch.zeros(4, 1, 64, 64).to(DEVICE)
assert _m(_x).shape == (4, 15)
print('✓ CNNClassifier shape check passed')
del _m, _x
✓ CNNClassifier shape check passed
Cell 2.2 — Generative model architectures. Defines the conditional generator and critic for the WGAN-GP. The Generator concatenates a 128-d noise vector with a learned class embedding and upsamples through three transposed convolutions to a 64×64 image bounded by Tanh. The Critic embeds the class label as a 64×64 spatial map, stacks it as a second input channel, and downsamples through four strided convolutions to a single unbounded realism score — no sigmoid, as required by the Wasserstein formulation; InstanceNorm replaces BatchNorm per WGAN-GP convention. Shape assertions verify both forward passes.
# ============================================================
# CELL 2.2 — Generator & Critic Architectures
# ============================================================
class Generator(nn.Module):
    def __init__(self, latent_dim, n_classes, n_channels=1):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + n_classes, 256 * 8 * 8),
            nn.Unflatten(1, (256, 8, 8)),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, n_channels, 4, 2, 1), nn.Tanh()
        )

    def forward(self, noise, labels):
        return self.model(torch.cat([noise, self.label_emb(labels)], dim=1))

class Critic(nn.Module):
    def __init__(self, n_classes, n_channels=1):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, 64 * 64)
        self.model = nn.Sequential(
            nn.Conv2d(n_channels + 1, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.InstanceNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1), nn.InstanceNorm2d(256), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1), nn.InstanceNorm2d(512), nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(), nn.Linear(512 * 4 * 4, 1)
        )

    def forward(self, img, labels):
        label_map = self.label_emb(labels).view(labels.size(0), 1, 64, 64)
        return self.model(torch.cat([img, label_map], dim=1))
_G = Generator(CONFIG['latent_dim'], CONFIG['n_classes']).to(DEVICE)
_C = Critic(CONFIG['n_classes']).to(DEVICE)
_n = torch.randn(4, CONFIG['latent_dim']).to(DEVICE)
_l = torch.randint(0, CONFIG['n_classes'], (4,)).to(DEVICE)
assert _G(_n, _l).shape == (4, 1, 64, 64)
assert _C(_G(_n, _l), _l).shape == (4, 1)
print('✓ Generator and Critic shape checks passed')
del _G, _C, _n, _l
✓ Generator and Critic shape checks passed
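The spatial sizes quoted for the three networks (64×64 → 8×8 in the classifier, 8×8 → 64×64 in the generator, 64×64 → 4×4 in the critic) follow from the standard output-size formulas for convolution, pooling, and transposed convolution with dilation 1 and no output padding. The sketch below recomputes them in plain Python as a cross-check on the `nn.Linear` input sizes; the helper names `conv_out` and `conv_transpose_out` are illustrative and not part of the notebook.

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of Conv2d / MaxPool2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding):
    """Output side length of ConvTranspose2d (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel

# Classifier: three blocks of Conv2d(3, stride 1, pad 1) followed by MaxPool2d(2).
cnn_sizes = [64]
for _ in range(3):
    after_conv = conv_out(cnn_sizes[-1], kernel=3, stride=1, padding=1)
    cnn_sizes.append(conv_out(after_conv, kernel=2, stride=2, padding=0))

# Generator: 8x8 seed, three ConvTranspose2d(kernel 4, stride 2, pad 1).
gen_sizes = [8]
for _ in range(3):
    gen_sizes.append(conv_transpose_out(gen_sizes[-1], kernel=4, stride=2, padding=1))

# Critic: 64x64 input, four Conv2d(kernel 4, stride 2, pad 1).
critic_sizes = [64]
for _ in range(4):
    critic_sizes.append(conv_out(critic_sizes[-1], kernel=4, stride=2, padding=1))

print(cnn_sizes)      # [64, 32, 16, 8]
print(gen_sizes)      # [8, 16, 32, 64]
print(critic_sizes)   # [64, 32, 16, 8, 4]
print(128 * cnn_sizes[-1] ** 2, 512 * critic_sizes[-1] ** 2)  # 8192 8192
```

Both flattened sizes come to 8,192, matching `nn.Linear(128 * 8 * 8, 256)` in the classifier and `nn.Linear(512 * 4 * 4, 1)` in the critic.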
Cell 2.3 — All training and generation procedures. Four functions used by every later cell: compute_gradient_penalty implements the WGAN-GP Lipschitz penalty on critic gradients at interpolated points; train_wgan runs the adversarial loop (5 critic steps per generator step, Adam with β1=0, checkpoints to Drive every 25 epochs); train_cnn trains a fresh seeded classifier for 30 epochs with cosine learning-rate annealing and returns held-out test accuracy; generate_synthetic samples class-conditional images from a trained generator into a SyntheticDataset. Centralizing these guarantees all conditions share identical training mechanics.
# ============================================================
# CELL 2.3 — Training Functions
# ============================================================
def compute_gradient_penalty(critic, real, fake, labels, device):
    B = real.size(0)
    alpha = torch.rand(B, 1, 1, 1, device=device)
    interpolated = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    d_interp = critic(interpolated, labels)
    gradients = torch.autograd.grad(
        outputs=d_interp, inputs=interpolated,
        grad_outputs=torch.ones(d_interp.size(), device=device),
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    return ((gradients.view(B, -1).norm(2, dim=1) - 1) ** 2).mean()
def train_wgan(train_loader, config, device):
    set_seed(config['base_seed'])
    G = Generator(config['latent_dim'], config['n_classes']).to(device)
    C = Critic(config['n_classes']).to(device)
    opt_G = torch.optim.Adam(G.parameters(), lr=config['wgan_lr'],
                             betas=(config['wgan_b1'], config['wgan_b2']))
    opt_C = torch.optim.Adam(C.parameters(), lr=config['wgan_lr'],
                             betas=(config['wgan_b1'], config['wgan_b2']))
    critic_losses, gen_losses = [], []
    for epoch in range(config['wgan_epochs']):
        ep_c, ep_g, nc, ng = 0.0, 0.0, 0, 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            B = imgs.size(0)
            for _ in range(config['n_critic']):
                noise = torch.randn(B, config['latent_dim'], device=device)
                fake = G(noise, labels).detach()
                gp = compute_gradient_penalty(C, imgs, fake, labels, device)
                c_loss = -C(imgs, labels).mean() + C(fake, labels).mean() + config['lambda_gp'] * gp
                opt_C.zero_grad(); c_loss.backward(); opt_C.step()
                ep_c += c_loss.item(); nc += 1
            noise = torch.randn(B, config['latent_dim'], device=device)
            fake = G(noise, labels)
            g_loss = -C(fake, labels).mean()
            opt_G.zero_grad(); g_loss.backward(); opt_G.step()
            ep_g += g_loss.item(); ng += 1
        critic_losses.append(ep_c / nc)
        gen_losses.append(ep_g / ng)
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{config["wgan_epochs"]}] '
                  f'Critic: {ep_c/nc:.4f} Generator: {ep_g/ng:.4f}')
        if (epoch + 1) % 25 == 0:
            ckpt = os.path.join(config['ckpt_dir'], f'wgan_epoch{epoch+1}.pt')
            torch.save({'G': G.state_dict(), 'C': C.state_dict()}, ckpt)
            print(f' → Checkpoint saved: {ckpt}')
    return G, C, critic_losses, gen_losses
def train_cnn(train_loader, test_loader, seed, tag=''):
    set_seed(seed)
    model = CNNClassifier(CONFIG['n_classes'], CONFIG['dropout']).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['cnn_lr'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['cnn_epochs'])
    criterion = nn.CrossEntropyLoss()
    for epoch in range(CONFIG['cnn_epochs']):
        model.train()
        running_loss = 0.0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(imgs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            print(f' [{tag}] Epoch {epoch+1}/{CONFIG["cnn_epochs"]} '
                  f'loss={running_loss/len(train_loader):.4f}')
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            correct += (model(imgs).argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct / total, model
def generate_synthetic(G, n_per_class, n_classes, latent_dim, device):
    G.eval()
    all_imgs, all_labels = [], []
    with torch.no_grad():
        for c in range(n_classes):
            noise = torch.randn(n_per_class, latent_dim, device=device)
            labels = torch.full((n_per_class,), c, dtype=torch.long, device=device)
            all_imgs.append(G(noise, labels).cpu())
            all_labels.extend([c] * n_per_class)
    return SyntheticDataset(
        torch.cat(all_imgs, dim=0),
        torch.tensor(all_labels, dtype=torch.long)
    )
print('✓ All training functions defined')
✓ All training functions defined
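What `compute_gradient_penalty` measures is easiest to see on a critic whose input gradient is known in closed form. For a linear critic C(x) = w·x + b, the gradient with respect to x is w at every point, so the penalty reduces to (‖w‖₂ − 1)² no matter where the interpolates fall. The sketch below checks this in plain Python, using central finite differences as a stand-in for autograd; the linear critic and every name in it are hypothetical illustrations, not the notebook's networks.

```python
import math
import random

def critic(x, w, b=0.0):
    """Hypothetical linear critic: C(x) = w . x + b."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b

def input_gradient(f, x, eps=1e-6):
    """Central finite-difference gradient of f at x (stand-in for autograd)."""
    grad = []
    for i in range(len(x)):
        up = x[:i] + [x[i] + eps] + x[i + 1:]
        down = x[:i] + [x[i] - eps] + x[i + 1:]
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad

def gradient_penalty(f, real, fake, rng):
    """Mean of (||grad f(x_hat)||_2 - 1)^2 over per-sample random interpolates."""
    total = 0.0
    for x, x_tilde in zip(real, fake):
        alpha = rng.random()                      # one alpha per sample
        x_hat = [alpha * a + (1 - alpha) * b for a, b in zip(x, x_tilde)]
        norm = math.sqrt(sum(g * g for g in input_gradient(f, x_hat)))
        total += (norm - 1.0) ** 2
    return total / len(real)

rng = random.Random(0)
real = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(8)]
fake = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(8)]

w_unit = [0.5, 0.5, 0.5, 0.5]   # ||w|| = 1, so the penalty should be close to 0
w_big = [1.5, 1.5, 1.5, 1.5]    # ||w|| = 3, so the penalty should be close to (3 - 1)^2 = 4

gp_unit = gradient_penalty(lambda x: critic(x, w_unit), real, fake, rng)
gp_big = gradient_penalty(lambda x: critic(x, w_big), real, fake, rng)
print(f'unit-norm critic: {gp_unit:.6f}   norm-3 critic: {gp_big:.6f}')
```

In the training loop this term is weighted by `lambda_gp` (10 in CONFIG) and added to the critic loss, which pushes the critic toward unit gradient norm along the lines between real and generated samples.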
Cell 3A — WGAN-GP training (skip-if-trained). Trains the conditional WGAN-GP for 500 epochs on the real training set — but only if the final checkpoint is absent from Drive. When the checkpoint exists (the normal case after the initial overnight run), the cell prints a skip notice, making full Run-All executions cheap and reproducible.
# ============================================================
# CELL 3A — Train C-WGAN-GP from scratch (skip if checkpoint exists)
# ============================================================
FINAL_CKPT = os.path.join(CKPT_DIR, 'wgan_500ep_final.pt')
if os.path.exists(FINAL_CKPT):
    print(f'Checkpoint found at {FINAL_CKPT}')
    print('Skip this cell and run Cell 3B instead')
else:
    print('No checkpoint found — training from scratch...')
    CONFIG['wgan_epochs'] = 500
    G, C, critic_losses, gen_losses = train_wgan(train_loader, CONFIG, DEVICE)
    torch.save({'G': G.state_dict(), 'C': C.state_dict()}, FINAL_CKPT)
    print(f'\n✓ Model saved: {FINAL_CKPT}')
Checkpoint found at /content/drive/MyDrive/Chinese_MNIST_Data_Augmentation/checkpoints/wgan_500ep_final.pt
Skip this cell and run Cell 3B instead
Cell 3B — Model restoration. Instantiates fresh Generator and Critic objects and loads the 500-epoch weights from the Drive checkpoint. This decouples every downstream experiment from the cost of training: any new session reaches an identical trained generator in seconds.
# ============================================================
# CELL 3B — Load C-WGAN-GP from checkpoint
# ============================================================
FINAL_CKPT = os.path.join(CKPT_DIR, 'wgan_500ep_final.pt')
ckpt = torch.load(FINAL_CKPT, map_location=DEVICE)
G = Generator(CONFIG['latent_dim'], CONFIG['n_classes']).to(DEVICE)
C = Critic(CONFIG['n_classes']).to(DEVICE)
G.load_state_dict(ckpt['G'])
C.load_state_dict(ckpt['C'])
print(f'✓ Loaded generator and critic from {FINAL_CKPT}')
✓ Loaded generator and critic from /content/drive/MyDrive/Chinese_MNIST_Data_Augmentation/checkpoints/wgan_500ep_final.pt
Cell 3C — Generation quality check (Checkpoint 3). Samples one synthetic image per class from the restored generator and renders the 3×5 grid. A second human gate: downstream augmentation experiments are only meaningful if the generator demonstrably produces class-conditional character structure rather than noise.
# ============================================================
# CELL 3C — CHECKPOINT 3: Visual Inspection of Generated Samples
# ============================================================
G.eval()
fig, axes = plt.subplots(3, 5, figsize=(12, 8))
fig.suptitle('CHECKPOINT 3: Generated samples (one per class)', fontsize=13, fontweight='bold')
with torch.no_grad():
    for i, ax in enumerate(axes.flat):
        noise = torch.randn(1, CONFIG['latent_dim']).to(DEVICE)
        label = torch.tensor([i]).to(DEVICE)
        img = G(noise, label).squeeze().cpu().numpy() * 0.5 + 0.5
        ax.imshow(img, cmap='gray')
        ax.set_title(f'Class {i}', fontsize=9)
        ax.axis('off')
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, 'checkpoint3_generated.png'), dpi=100)
plt.show()
print('\n✅ CHECKPOINT 3 — Do images resemble Chinese characters?')
print('If yes: proceed to Phase 4')
print('If noise: re-train with more epochs')
✅ CHECKPOINT 3 — Do images resemble Chinese characters?
If yes: proceed to Phase 4
If noise: re-train with more epochs
This section builds the experiment around a single design principle: the synthetic images are a fixed artifact, and "filtered" vs "unfiltered" are two selection rules over the same raw generator outputs.
Because filtered and unfiltered now share identical raw generations, identical counts, and identical ratios, the four isolation contrasts in Cell 5.1 (filtered − unfiltered at fixed real-count and ratio) attribute any difference to the critic selection alone. This supersedes the earlier hardcoded Experiment-1 constants and the separate-pool generation; everything is one coherent run.
Executed on a Colab L4. All ten conditions were run live (≈50 CNN trainings: 10 conditions × 5 seeds); the tables and plot below are from that run. No numbers are hardcoded — every cell fills in when run.
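The isolation contrasts are paired comparisons: the same five seeds are used in every condition, so each seed contributes one filtered-minus-unfiltered difference. The sketch below shows the arithmetic for one such contrast (1:1 ratio at 50 real images per class), using the per-seed accuracies printed in the Cell C log, together with a generic Holm–Bonferroni step-down. It is an illustration in plain Python and not the notebook's Cell 5.1: the function names are hypothetical, the p-value is omitted because the standard library has no Student-t distribution (`scipy.stats.ttest_rel` would supply it), and the p-values passed to `holm_reject` are made-up inputs that only exercise the procedure.

```python
import math

def paired_t(a, b):
    """Return (mean difference, t statistic, degrees of freedom) for paired a - b."""
    d = [x - y for x, y in zip(a, b)]
    n = len(d)
    mean = sum(d) / n
    var = sum((v - mean) ** 2 for v in d) / (n - 1)    # sample variance of the differences
    return mean, mean / math.sqrt(var / n), n - 1

def holm_reject(p_values, alpha=0.05):
    """Holm-Bonferroni step-down: one reject flag per hypothesis, in input order."""
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])
    reject = [False] * m
    for rank, i in enumerate(order):
        if p_values[i] > alpha / (m - rank):
            break                                      # first failure stops the procedure
        reject[i] = True
    return reject

# Per-seed test accuracies copied from the Cell C log (seeds 1-5).
filt_50_1to1 = [0.8687, 0.8960, 0.8573, 0.8767, 0.8813]
unfilt_50_1to1 = [0.8680, 0.8860, 0.8540, 0.8720, 0.8727]

mean_diff, t_stat, dof = paired_t(filt_50_1to1, unfilt_50_1to1)
print(f'mean diff = {100 * mean_diff:.2f} points, t = {t_stat:.2f}, df = {dof}')

# Illustrative p-values only (NOT results from this study).
example_p = [0.03, 0.20, 0.45, 0.60]
print(holm_reject(example_p))
```

With four contrasts, the smallest p-value has to clear 0.05 / 4 = 0.0125 before anything is rejected, which is how a contrast can pass an uncorrected 0.05 threshold and still fail after correction.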
Cell A — Deep-scarcity training set. Builds the 25-real-images-per-class training set by slicing the first 25 indices of each class block from the existing 50/class training split. Because it is a strict subset, the validation and test sets remain byte-identical to the 50/class regime, keeping the two regimes comparable on the same held-out data. An assertion verifies the expected 375-image count.
# ============================================================
# CELL A — Extreme scarcity: 25 real images/class
# Built as a SUBSET of the existing 50/class train set,
# so val/test sets remain IDENTICAL (clean comparison).
# ============================================================
N_SCARCE = 25
scarce_idx = []
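for c in range(CONFIG['n_classes']):
    # train_idx is laid out in per-class blocks of CONFIG['train_per_class'] indices
    start = c * CONFIG['train_per_class']
    class_block = train_idx[start : start + CONFIG['train_per_class']]
    scarce_idx.extend(class_block[:N_SCARCE])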
train_dataset_scarce = ChineseMNIST(CSV_PATH, IMAGES_DIR, scarce_idx, transform)
assert len(train_dataset_scarce) == N_SCARCE * CONFIG['n_classes']
train_loader_scarce = DataLoader(train_dataset_scarce,
batch_size=CONFIG['cnn_batch_size'],
shuffle=True, num_workers=2, pin_memory=True)
print(f'✓ Scarce train set: {len(train_dataset_scarce)} images ({N_SCARCE}/class)')
print(f'✓ Test set unchanged: {len(test_dataset)} images')
✓ Scarce train set: 375 images (25/class)
✓ Test set unchanged: 1500 images
Cell B — One raw pool, scored once; two selection rules. Generates 800 raw samples/class from the generator, scores all with the critic a single time, then derives the filtered block (the top-scoring 200/class, score-sorted) and the unfiltered block (a random 200/class) from that same pool. Defines make_aug_loader_gen, which concatenates real images with a nested prefix of a synthetic block — so the filtered condition trains on the strict top-k by critic score (the top ~3–25% of the pool, depending on count and ratio). Note: the unfiltered random draw may overlap the filtered top scores — that is intentional and unbiased; excluding the high-scoring images would depress the unfiltered condition and inflate the filter's apparent effect.
# ============================================================
# CELL B - One raw pool, scored once; two selection rules
# filtered vs unfiltered = two selections over the SAME raw
# generator outputs, so the contrast isolates the critic alone.
# ============================================================
POOL_PER_CLASS = 800 # raw pool; top 25% -> 200 filtered/class
KEEP = 200 # max synthetic/class needed (50/class at 4:1)
def make_scored_pool(G, C, pool_per_class, n_classes, latent_dim, device, seed):
    # Generate one raw pool/class from G; score every image with C once.
    set_seed(seed)
    G.eval(); C.eval()
    imgs_all, score_all = [], []
    with torch.no_grad():
        for c in range(n_classes):
            noise = torch.randn(pool_per_class, latent_dim, device=device)
            labels = torch.full((pool_per_class,), c, dtype=torch.long, device=device)
            imgs = G(noise, labels)
            scores = C(imgs, labels).squeeze(1)  # higher = more realistic
            imgs_all.append(imgs.cpu()); score_all.append(scores.cpu())
    return torch.cat(imgs_all), torch.cat(score_all)

def make_aug_loader_gen(real_ds, syn_ds, n_syn_per_class, block_per_class):
    idx = []
    for c in range(CONFIG['n_classes']):
        start = c * block_per_class
        idx.extend(range(start, start + n_syn_per_class))
    sub = torch.utils.data.Subset(syn_ds, idx)
    ds = ConcatDataset([real_ds, sub])
    return DataLoader(ds, batch_size=CONFIG['cnn_batch_size'],
                      shuffle=True, num_workers=2, pin_memory=True), len(ds)
# 1) Generate + score ONE pool
raw_imgs, raw_scores = make_scored_pool(
G, C, POOL_PER_CLASS, CONFIG['n_classes'], CONFIG['latent_dim'], DEVICE,
seed=CONFIG['base_seed'])
print(f'Raw pool: {len(raw_imgs)} images ({POOL_PER_CLASS}/class), scored once')
# 2) Derive two selection rules over the SAME pool, KEEP/class each.
# Filtered = top-KEEP by critic score (best first -> nests by ratio).
# Unfiltered= random-KEEP from the same pool (no selection).
rng = np.random.RandomState(CONFIG['base_seed'])
filt_i, filt_l, unf_i, unf_l = [], [], [], []
for cls in range(CONFIG['n_classes']):
    lo = cls * POOL_PER_CLASS
    block_imgs = raw_imgs[lo:lo + POOL_PER_CLASS]
    block_scores = raw_scores[lo:lo + POOL_PER_CLASS]
    top = torch.argsort(block_scores, descending=True)[:KEEP]
    rand = torch.from_numpy(rng.permutation(POOL_PER_CLASS)[:KEEP])
    filt_i.append(block_imgs[top]); filt_l += [cls] * KEEP
    unf_i.append(block_imgs[rand]); unf_l += [cls] * KEEP
syn_filtered = SyntheticDataset(torch.cat(filt_i), torch.tensor(filt_l, dtype=torch.long))
syn_unfiltered = SyntheticDataset(torch.cat(unf_i), torch.tensor(unf_l, dtype=torch.long))
print(f'Filtered set: {len(syn_filtered)} images ({KEEP}/class, critic top-k by score)')
print(f'Unfiltered set: {len(syn_unfiltered)} images ({KEEP}/class, random from same pool)')
Raw pool: 12000 images (800/class), scored once
Filtered set: 3000 images (200/class, critic top-k by score)
Unfiltered set: 3000 images (200/class, random from same pool)
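The two selection rules and the nested-prefix property can be illustrated for a single class without the models. In the sketch below the critic scores are replaced by hypothetical Gaussian draws, and `select_blocks` is an illustrative name rather than a notebook function. Because the filtered block is sorted best-first, taking its first n entries always yields the top n of the whole pool, so the filtered conditions at 25, 50, 100, and 200 synthetic images per class keep between roughly 3% and 25% of the 800-image pool.

```python
import random

def select_blocks(scores, keep, rng):
    """Return (top-`keep` indices, best first) and (`keep` random indices) for one pool."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    filtered = ranked[:keep]
    unfiltered = rng.sample(range(len(scores)), keep)
    return filtered, unfiltered

rng = random.Random(42)
pool, keep = 800, 200
scores = [rng.gauss(0.0, 1.0) for _ in range(pool)]   # hypothetical critic scores, one class
filtered, unfiltered = select_blocks(scores, keep, rng)

# Nested prefixes: the first n entries of the filtered block are the top n of the pool.
for n in (25, 50, 100, 200):
    prefix = set(filtered[:n])
    cutoff = min(scores[i] for i in prefix)
    best_left_out = max(scores[i] for i in range(pool) if i not in prefix)
    print(f'n={n:3d}  share of pool={n / pool:.1%}  prefix is top-n: {cutoff >= best_left_out}')

# The random block is drawn from the same pool, so it can overlap the top block.
overlap = len(set(filtered) & set(unfiltered))
print(f'random draw shares {overlap} of {keep} images with the top-{keep} block')
```

The overlap is expected: the unfiltered draw samples the whole pool, top-scoring images included, which is what keeps it an unbiased reference for the filtered block.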
Cell B-check — visual confirmation. Top row: the highest-critic-scored (filtered) sample per class. Bottom row: a random (unfiltered) sample from the same pool. The filtered samples should look cleaner; if they don't, the critic is not separating quality and the downstream contrast won't mean much.
# ============================================================
# CELL B-check - filtered (top) vs unfiltered (bottom) samples
# ============================================================
show_classes = [0, 3, 6, 9, 12]
fig, axes = plt.subplots(2, len(show_classes), figsize=(2.2 * len(show_classes), 4.6))
for j, cls in enumerate(show_classes):
    f_img, _ = syn_filtered[cls * KEEP + 0]    # top-scoring
    u_img, _ = syn_unfiltered[cls * KEEP + 0]  # a random one
    axes[0, j].imshow(f_img.squeeze().numpy() * 0.5 + 0.5, cmap='gray')
    axes[0, j].set_title(f'class {cls}', fontsize=9); axes[0, j].axis('off')
    axes[1, j].imshow(u_img.squeeze().numpy() * 0.5 + 0.5, cmap='gray')
    axes[1, j].axis('off')
fig.text(0.01, 0.72, 'filtered', rotation=90, va='center', fontsize=10)
fig.text(0.01, 0.28, 'unfiltered', rotation=90, va='center', fontsize=10)
fig.suptitle('Same raw pool, two selection rules', fontsize=12)
plt.tight_layout(rect=[0.03, 0, 1, 1])
plt.savefig(os.path.join(OUT_DIR, 'rev11_pool_check.png'), dpi=120)
plt.show()
Cell C — Full factorial execution. Runs all ten conditions live from the shared pool, storing per-seed accuracies in results. Every condition uses the same five seeds and the same test set, so per-seed accuracies can be paired across conditions in the downstream t-tests. This amounts to about 50 CNN trainings (10 conditions × 5 seeds).
# ============================================================
# CELL C (rev11) - Full factorial execution (10 conditions, ONE pool)
# ~50 CNN trainings (10 x 5 seeds). All from the same raw pool.
# ============================================================
results = {}
def run_condition(key, loader, label):
    print(f'=== {label} ===')
    accs = []
    for i in range(CONFIG['n_seeds']):
        acc, _ = train_cnn(loader, test_loader, seed=CONFIG['base_seed'] + i, tag=f'{key} s{i+1}')
        accs.append(acc); print(f' -> Seed {i+1}: {acc:.4f}')
    results[key] = accs
    print(f' {label}: {np.mean(accs):.4f} +/- {np.std(accs):.4f}\n')
# Baselines (real only)
run_condition('baseline_50', train_loader, 'Baseline 50/class')
run_condition('baseline_25', train_loader_scarce, 'Baseline 25/class')
# 50 real/class + augmentation (block size = KEEP)
lo, _ = make_aug_loader_gen(train_dataset, syn_unfiltered, 50, KEEP); run_condition('unfilt_50_1to1', lo, 'Unfiltered 1:1 @50')
lo, _ = make_aug_loader_gen(train_dataset, syn_unfiltered, 200, KEEP); run_condition('unfilt_50_4to1', lo, 'Unfiltered 4:1 @50')
lo, _ = make_aug_loader_gen(train_dataset, syn_filtered, 50, KEEP); run_condition('filt_50_1to1', lo, 'Filtered 1:1 @50')
lo, _ = make_aug_loader_gen(train_dataset, syn_filtered, 200, KEEP); run_condition('filt_50_4to1', lo, 'Filtered 4:1 @50')
# 25 real/class + augmentation
lo, _ = make_aug_loader_gen(train_dataset_scarce, syn_unfiltered, 25, KEEP); run_condition('unfilt_25_1to1', lo, 'Unfiltered 1:1 @25')
lo, _ = make_aug_loader_gen(train_dataset_scarce, syn_unfiltered, 100, KEEP); run_condition('unfilt_25_4to1', lo, 'Unfiltered 4:1 @25')
lo, _ = make_aug_loader_gen(train_dataset_scarce, syn_filtered, 25, KEEP); run_condition('filt_25_1to1', lo, 'Filtered 1:1 @25')
lo, _ = make_aug_loader_gen(train_dataset_scarce, syn_filtered, 100, KEEP); run_condition('filt_25_4to1', lo, 'Filtered 4:1 @25')
print('\nAll conditions complete.')
=== Baseline 50/class ===
[baseline_50 s1] Epoch 10/30 loss=0.4733
[baseline_50 s1] Epoch 20/30 loss=0.1218
[baseline_50 s1] Epoch 30/30 loss=0.0867
-> Seed 1: 0.8907
[baseline_50 s2] Epoch 10/30 loss=0.5400
[baseline_50 s2] Epoch 20/30 loss=0.1765
[baseline_50 s2] Epoch 30/30 loss=0.1150
-> Seed 2: 0.8933
[baseline_50 s3] Epoch 10/30 loss=0.9371
[baseline_50 s3] Epoch 20/30 loss=0.3433
[baseline_50 s3] Epoch 30/30 loss=0.2763
-> Seed 3: 0.8820
[baseline_50 s4] Epoch 10/30 loss=0.8691
[baseline_50 s4] Epoch 20/30 loss=0.3554
[baseline_50 s4] Epoch 30/30 loss=0.2586
-> Seed 4: 0.8713
[baseline_50 s5] Epoch 10/30 loss=0.7191
[baseline_50 s5] Epoch 20/30 loss=0.2068
[baseline_50 s5] Epoch 30/30 loss=0.1570
-> Seed 5: 0.8873
Baseline 50/class: 0.8849 +/- 0.0078

=== Baseline 25/class ===
[baseline_25 s1] Epoch 10/30 loss=1.1889
[baseline_25 s1] Epoch 20/30 loss=0.4814
[baseline_25 s1] Epoch 30/30 loss=0.3193
-> Seed 1: 0.7140
[baseline_25 s2] Epoch 10/30 loss=0.9380
[baseline_25 s2] Epoch 20/30 loss=0.4288
[baseline_25 s2] Epoch 30/30 loss=0.3054
-> Seed 2: 0.7647
[baseline_25 s3] Epoch 10/30 loss=1.3329
[baseline_25 s3] Epoch 20/30 loss=0.5279
[baseline_25 s3] Epoch 30/30 loss=0.4185
-> Seed 3: 0.7113
[baseline_25 s4] Epoch 10/30 loss=1.3607
[baseline_25 s4] Epoch 20/30 loss=0.5606
[baseline_25 s4] Epoch 30/30 loss=0.4142
-> Seed 4: 0.7127
[baseline_25 s5] Epoch 10/30 loss=1.3868
[baseline_25 s5] Epoch 20/30 loss=0.6796
[baseline_25 s5] Epoch 30/30 loss=0.5963
-> Seed 5: 0.6860
Baseline 25/class: 0.7177 +/- 0.0257

=== Unfiltered 1:1 @50 ===
[unfilt_50_1to1 s1] Epoch 10/30 loss=0.3409
[unfilt_50_1to1 s1] Epoch 20/30 loss=0.0965
[unfilt_50_1to1 s1] Epoch 30/30 loss=0.0605
-> Seed 1: 0.8680
[unfilt_50_1to1 s2] Epoch 10/30 loss=0.2277
[unfilt_50_1to1 s2] Epoch 20/30 loss=0.0649
[unfilt_50_1to1 s2] Epoch 30/30 loss=0.0461
-> Seed 2: 0.8860
[unfilt_50_1to1 s3] Epoch 10/30 loss=0.4278
[unfilt_50_1to1 s3] Epoch 20/30 loss=0.1838
[unfilt_50_1to1 s3] Epoch 30/30 loss=0.1005
-> Seed 3: 0.8540
[unfilt_50_1to1 s4] Epoch 10/30 loss=0.3158
[unfilt_50_1to1 s4] Epoch 20/30 loss=0.0997
[unfilt_50_1to1 s4] Epoch 30/30 loss=0.0675
-> Seed 4: 0.8720
[unfilt_50_1to1 s5] Epoch 10/30 loss=0.2906
[unfilt_50_1to1 s5] Epoch 20/30 loss=0.0783
[unfilt_50_1to1 s5] Epoch 30/30 loss=0.0465
-> Seed 5: 0.8727
Unfiltered 1:1 @50: 0.8705 +/- 0.0103

=== Unfiltered 4:1 @50 ===
[unfilt_50_4to1 s1] Epoch 10/30 loss=0.0997
[unfilt_50_4to1 s1] Epoch 20/30 loss=0.0238
[unfilt_50_4to1 s1] Epoch 30/30 loss=0.0180
-> Seed 1: 0.8620
[unfilt_50_4to1 s2] Epoch 10/30 loss=0.1271
[unfilt_50_4to1 s2] Epoch 20/30 loss=0.0330
[unfilt_50_4to1 s2] Epoch 30/30 loss=0.0187
-> Seed 2: 0.8667
[unfilt_50_4to1 s3] Epoch 10/30 loss=0.0979
[unfilt_50_4to1 s3] Epoch 20/30 loss=0.0250
[unfilt_50_4to1 s3] Epoch 30/30 loss=0.0121
-> Seed 3: 0.8520
[unfilt_50_4to1 s4] Epoch 10/30 loss=0.0905
[unfilt_50_4to1 s4] Epoch 20/30 loss=0.0167
[unfilt_50_4to1 s4] Epoch 30/30 loss=0.0094
-> Seed 4: 0.8653
[unfilt_50_4to1 s5] Epoch 10/30 loss=0.0915
[unfilt_50_4to1 s5] Epoch 20/30 loss=0.0278
[unfilt_50_4to1 s5] Epoch 30/30 loss=0.0145
-> Seed 5: 0.8713
Unfiltered 4:1 @50: 0.8635 +/- 0.0065

=== Filtered 1:1 @50 ===
[filt_50_1to1 s1] Epoch 10/30 loss=0.2693
[filt_50_1to1 s1] Epoch 20/30 loss=0.0851
[filt_50_1to1 s1] Epoch 30/30 loss=0.0548
-> Seed 1: 0.8687
[filt_50_1to1 s2] Epoch 10/30 loss=0.1742
[filt_50_1to1 s2] Epoch 20/30 loss=0.0448
[filt_50_1to1 s2] Epoch 30/30 loss=0.0272
-> Seed 2: 0.8960
[filt_50_1to1 s3] Epoch 10/30 loss=0.2647
[filt_50_1to1 s3] Epoch 20/30 loss=0.0932
[filt_50_1to1 s3] Epoch 30/30 loss=0.0519
-> Seed 3: 0.8573
[filt_50_1to1 s4] Epoch 10/30 loss=0.2335
[filt_50_1to1 s4] Epoch 20/30 loss=0.0580
[filt_50_1to1 s4] Epoch 30/30 loss=0.0377
-> Seed 4: 0.8767
[filt_50_1to1 s5] Epoch 10/30 loss=0.2067
[filt_50_1to1 s5] Epoch 20/30 loss=0.0592
[filt_50_1to1 s5] Epoch 30/30 loss=0.0319
-> Seed 5: 0.8813
Filtered 1:1 @50: 0.8760 +/- 0.0129

=== Filtered 4:1 @50 ===
[filt_50_4to1 s1] Epoch 10/30 loss=0.0853
[filt_50_4to1 s1] Epoch 20/30 loss=0.0177
[filt_50_4to1 s1] Epoch 30/30 loss=0.0104
-> Seed 1: 0.8673
[filt_50_4to1 s2] Epoch 10/30 loss=0.0851
[filt_50_4to1 s2] Epoch 20/30 loss=0.0209
[filt_50_4to1 s2] Epoch 30/30 loss=0.0148
-> Seed 2: 0.8773
[filt_50_4to1 s3] Epoch 10/30 loss=0.0778
[filt_50_4to1 s3] Epoch 20/30 loss=0.0226
[filt_50_4to1 s3] Epoch 30/30 loss=0.0105
-> Seed 3: 0.8613
[filt_50_4to1 s4] Epoch 10/30 loss=0.0678
[filt_50_4to1 s4] Epoch 20/30 loss=0.0131
[filt_50_4to1 s4] Epoch 30/30 loss=0.0070
-> Seed 4: 0.8573
[filt_50_4to1 s5] Epoch 10/30 loss=0.1098
[filt_50_4to1 s5] Epoch 20/30 loss=0.0208
[filt_50_4to1 s5] Epoch 30/30 loss=0.0133
-> Seed 5: 0.8607
Filtered 4:1 @50: 0.8648 +/- 0.0070

=== Unfiltered 1:1 @25 ===
[unfilt_25_1to1 s1] Epoch 10/30 loss=0.2884
[unfilt_25_1to1 s1] Epoch 20/30 loss=0.0594
[unfilt_25_1to1 s1] Epoch 30/30 loss=0.0473
-> Seed 1: 0.7860
[unfilt_25_1to1 s2] Epoch 10/30 loss=0.6548
[unfilt_25_1to1 s2] Epoch 20/30 loss=0.2062
[unfilt_25_1to1 s2] Epoch 30/30 loss=0.1586
-> Seed 2: 0.7860
[unfilt_25_1to1 s3] Epoch 10/30 loss=0.6164
[unfilt_25_1to1 s3] Epoch 20/30 loss=0.2212
[unfilt_25_1to1 s3] Epoch 30/30 loss=0.1448
-> Seed 3: 0.7753
[unfilt_25_1to1 s4] Epoch 10/30 loss=0.4434
[unfilt_25_1to1 s4] Epoch 20/30 loss=0.1270
[unfilt_25_1to1 s4] Epoch 30/30 loss=0.0916
-> Seed 4: 0.7640
[unfilt_25_1to1 s5] Epoch 10/30 loss=0.6515
[unfilt_25_1to1 s5] Epoch 20/30 loss=0.2248
[unfilt_25_1to1 s5] Epoch 30/30 loss=0.1784
-> Seed 5: 0.7647
Unfiltered 1:1 @25: 0.7752 +/- 0.0097

=== Unfiltered 4:1 @25 ===
[unfilt_25_4to1 s1] Epoch 10/30 loss=0.1207
[unfilt_25_4to1 s1] Epoch 20/30 loss=0.0344
[unfilt_25_4to1 s1] Epoch 30/30 loss=0.0193
-> Seed 1: 0.7907
[unfilt_25_4to1 s2] Epoch 10/30 loss=0.1184
[unfilt_25_4to1 s2] Epoch 20/30 loss=0.0288
[unfilt_25_4to1 s2] Epoch 30/30 loss=0.0194
-> Seed 2: 0.8040
[unfilt_25_4to1 s3] Epoch 10/30 loss=0.1516
[unfilt_25_4to1 s3] Epoch 20/30 loss=0.0441
[unfilt_25_4to1 s3] Epoch 30/30 loss=0.0219
-> Seed 3: 0.7680
[unfilt_25_4to1 s4] Epoch 10/30 loss=0.1713
[unfilt_25_4to1 s4] Epoch 20/30 loss=0.0430
[unfilt_25_4to1 s4] Epoch 30/30 loss=0.0270
-> Seed 4: 0.7880
[unfilt_25_4to1 s5] Epoch 10/30 loss=0.1450
[unfilt_25_4to1 s5] Epoch 20/30 loss=0.0348
[unfilt_25_4to1 s5] Epoch 30/30 loss=0.0215
-> Seed 5: 0.7887
Unfiltered 4:1 @25: 0.7879 +/- 0.0115

=== Filtered 1:1 @25 ===
[filt_25_1to1 s1] Epoch 10/30 loss=0.2645
[filt_25_1to1 s1] Epoch 20/30 loss=0.0764
[filt_25_1to1 s1] Epoch 30/30 loss=0.0560
-> Seed 1: 0.7833
[filt_25_1to1 s2] Epoch 10/30 loss=0.3955
[filt_25_1to1 s2] Epoch 20/30 loss=0.1276
[filt_25_1to1 s2] Epoch 30/30 loss=0.0952
-> Seed 2: 0.7833
[filt_25_1to1 s3] Epoch 10/30 loss=0.4028
[filt_25_1to1 s3] Epoch 20/30 loss=0.1326
[filt_25_1to1 s3] Epoch 30/30 loss=0.0857
-> Seed 3: 0.7660
[filt_25_1to1 s4] Epoch 10/30 loss=0.4151
[filt_25_1to1 s4] Epoch 20/30 loss=0.1129
[filt_25_1to1 s4] Epoch 30/30 loss=0.0644
-> Seed 4: 0.7847
[filt_25_1to1 s5] Epoch 10/30 loss=0.4703
[filt_25_1to1 s5] Epoch 20/30 loss=0.1783
[filt_25_1to1 s5] Epoch 30/30 loss=0.1156
-> Seed 5: 0.7767
Filtered 1:1 @25: 0.7788 +/- 0.0070

=== Filtered 4:1 @25 ===
[filt_25_4to1 s1] Epoch 10/30 loss=0.0881
[filt_25_4to1 s1] Epoch 20/30 loss=0.0222
[filt_25_4to1 s1] Epoch 30/30 loss=0.0128
-> Seed 1: 0.7980
[filt_25_4to1 s2] Epoch 10/30 loss=0.0634
[filt_25_4to1 s2] Epoch 20/30 loss=0.0145
[filt_25_4to1 s2] Epoch 30/30 loss=0.0109
-> Seed 2: 0.7933
[filt_25_4to1 s3] Epoch 10/30 loss=0.1022
[filt_25_4to1 s3] Epoch 20/30 loss=0.0312
[filt_25_4to1 s3] Epoch 30/30 loss=0.0177
-> Seed 3: 0.7793
[filt_25_4to1 s4] Epoch 10/30 loss=0.0955
[filt_25_4to1 s4] Epoch 20/30 loss=0.0230
[filt_25_4to1 s4] Epoch 30/30 loss=0.0133
-> Seed 4: 0.7973
[filt_25_4to1 s5] Epoch 10/30 loss=0.1218
[filt_25_4to1 s5] Epoch 20/30 loss=0.0280
[filt_25_4to1 s5] Epoch 30/30 loss=0.0171
-> Seed 5: 0.7827
Filtered 4:1 @25: 0.7901 +/- 0.0077

All conditions complete.
Cell 5.1 — Full factorial table + critic-filter isolation. All ten conditions are listed, with each augmented row tested against its same-regime baseline by a paired t-test across seeds. The isolation block reports filtered − unfiltered at fixed real-count and ratio: with the pool, count, and ratio all held constant, any significant value would isolate the critic filter as the cause. The `*` flag in the printed output marks uncorrected p < 0.05. (In this run all four contrasts are ≤ 0.55 points. The 50/class 1:1 contrast is flagged at p = 0.0335 but does not survive Holm–Bonferroni across the four contrasts, so there is no detectable Stage 1 filter effect.)
# ============================================================
# CELL 5.1 - Full factorial table & significance tests (rev11)
# ============================================================
from scipy import stats
def fmt(name, accs, base=None):
    m, s = np.mean(accs), np.std(accs)
    if base is None:
        print(f'{name:<30} {m:>8.4f} {s:>8.4f} {"-":>9} {"-":>9}')
    else:
        t, p = stats.ttest_rel(accs, base)
        d = m - np.mean(base)
        flag = '*' if p < 0.05 else ' '
        print(f'{name:<30} {m:>8.4f} {s:>8.4f} {d:>+9.4f} {p:>8.4f}{flag}')
print('=' * 70)
print('FULL FACTORIAL (one shared raw pool) - real-count x filter x ratio')
print('=' * 70)
print(f'{"Condition":<30} {"Mean":>8} {"Std":>8} {"Delta":>9} {"p-value":>9}')
print('-' * 70)
print('- 50 real/class -')
fmt('Baseline (50)', results['baseline_50'])
fmt('Unfiltered 1:1', results['unfilt_50_1to1'], results['baseline_50'])
fmt('Unfiltered 4:1', results['unfilt_50_4to1'], results['baseline_50'])
fmt('Filtered 1:1', results['filt_50_1to1'], results['baseline_50'])
fmt('Filtered 4:1', results['filt_50_4to1'], results['baseline_50'])
print('- 25 real/class -')
fmt('Baseline (25)', results['baseline_25'])
fmt('Unfiltered 1:1', results['unfilt_25_1to1'], results['baseline_25'])
fmt('Unfiltered 4:1', results['unfilt_25_4to1'], results['baseline_25'])
fmt('Filtered 1:1', results['filt_25_1to1'], results['baseline_25'])
fmt('Filtered 4:1', results['filt_25_4to1'], results['baseline_25'])
print('=' * 70)
print('* significant at p < 0.05 (paired t-test vs same-regime baseline)')
print('\n' + '=' * 70)
print('CRITIC-FILTER ISOLATION (filtered - unfiltered; real-count AND ratio fixed)')
print('=' * 70)
def contrast(name, a, b):
    d = np.array(a) - np.array(b)
    t, p = stats.ttest_rel(a, b)
    flag = '*' if p < 0.05 else ' '
    print(f'{name:<24} {np.mean(d):>+8.4f} (paired p = {p:.4f}){flag}')
contrast('50/class 1:1', results['filt_50_1to1'], results['unfilt_50_1to1'])
contrast('50/class 4:1', results['filt_50_4to1'], results['unfilt_50_4to1'])
contrast('25/class 1:1', results['filt_25_1to1'], results['unfilt_25_1to1'])
contrast('25/class 4:1', results['filt_25_4to1'], results['unfilt_25_4to1'])
print('=' * 70)
print('Same pool, same count, same ratio -> these isolate the critic filter.')
======================================================================
FULL FACTORIAL (one shared raw pool) - real-count x filter x ratio
======================================================================
Condition                          Mean      Std     Delta   p-value
----------------------------------------------------------------------
- 50 real/class -
Baseline (50)                    0.8849   0.0078         -         -
Unfiltered 1:1                   0.8705   0.0103   -0.0144   0.0490*
Unfiltered 4:1                   0.8635   0.0065   -0.0215   0.0094*
Filtered 1:1                     0.8760   0.0129   -0.0089   0.2221
Filtered 4:1                     0.8648   0.0070   -0.0201   0.0010*
- 25 real/class -
Baseline (25)                    0.7177   0.0257         -         -
Unfiltered 1:1                   0.7752   0.0097   +0.0575   0.0047*
Unfiltered 4:1                   0.7879   0.0115   +0.0701   0.0027*
Filtered 1:1                     0.7788   0.0070   +0.0611   0.0071*
Filtered 4:1                     0.7901   0.0077   +0.0724   0.0036*
======================================================================
* significant at p < 0.05 (paired t-test vs same-regime baseline)

======================================================================
CRITIC-FILTER ISOLATION (filtered - unfiltered; real-count AND ratio fixed)
======================================================================
50/class 1:1              +0.0055 (paired p = 0.0335)*
50/class 4:1              +0.0013 (paired p = 0.7800)
25/class 1:1              +0.0036 (paired p = 0.5493)
25/class 4:1              +0.0023 (paired p = 0.6362)
======================================================================
Same pool, same count, same ratio -> these isolate the critic filter.
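The isolation block flags contrasts at an uncorrected p < 0.05, while the conclusion rests on a Holm–Bonferroni correction that the cell itself does not compute. The sketch below shows one way that correction could be applied to the four isolation p-values printed above. The helper `holm_bonferroni` and the `isolation_p` dictionary are illustrative names, not part of the notebook, and the p-values are copied by hand from the output rather than recomputed.

```python
# Illustrative helper (not in the original notebook): Holm-Bonferroni
# step-down adjustment for a small family of p-values.
def holm_bonferroni(pvals):
    """Return Holm-adjusted p-values in the same order as the input."""
    m = len(pvals)
    # Indices of the p-values, smallest first.
    order = sorted(range(m), key=lambda i: pvals[i])
    adjusted = [0.0] * m
    running_max = 0.0
    for rank, i in enumerate(order):
        # The k-th smallest p-value is scaled by (m - k + 1); the running
        # maximum keeps the adjusted values monotone, and 1.0 is the cap.
        running_max = max(running_max, (m - rank) * pvals[i])
        adjusted[i] = min(1.0, running_max)
    return adjusted


# Raw paired p-values copied from the critic-filter isolation output.
isolation_p = {
    '50/class 1:1': 0.0335,
    '50/class 4:1': 0.7800,
    '25/class 1:1': 0.5493,
    '25/class 4:1': 0.6362,
}

adjusted_p = holm_bonferroni(list(isolation_p.values()))
for (name, raw), adj in zip(isolation_p.items(), adjusted_p):
    flag = '*' if adj < 0.05 else ' '
    print(f'{name:<24} raw p = {raw:.4f}   Holm p = {adj:.4f}{flag}')
```

With four contrasts, the smallest raw p-value (0.0335) is multiplied by 4, giving an adjusted value of 0.134, so no contrast stays below 0.05 after correction.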
Cell 5.2 — Full factorial plot. Bars show mean test accuracy across the five seeds, with error bars of ±1 standard deviation. Blue = baseline, red = unfiltered, green = filtered; the dotted line separates the 50/class block (left) from the 25/class block (right).
# ============================================================
# CELL 5.2 - Full factorial plot (rev11)
# ============================================================
names = ['Base\n50','Unf 1:1\n50','Unf 4:1\n50','Filt 1:1\n50','Filt 4:1\n50',
         'Base\n25','Unf 1:1\n25','Unf 4:1\n25','Filt 1:1\n25','Filt 4:1\n25']
keys = ['baseline_50','unfilt_50_1to1','unfilt_50_4to1','filt_50_1to1','filt_50_4to1',
        'baseline_25','unfilt_25_1to1','unfilt_25_4to1','filt_25_1to1','filt_25_4to1']
data = [results[k] for k in keys]
means = [np.mean(d) for d in data]
stds = [np.std(d) for d in data]
colors = ['#4C72B0','#C44E52','#C44E52','#55A868','#55A868',
          '#4C72B0','#C44E52','#C44E52','#55A868','#55A868']
fig, ax = plt.subplots(figsize=(13, 6))
bars = ax.bar(names, means, yerr=stds, capsize=4, color=colors, alpha=0.85)
for b, m in zip(bars, means):
    ax.text(b.get_x() + b.get_width()/2, b.get_height() + 0.003, f'{m:.3f}',
            ha='center', va='bottom', fontsize=8)
ax.set_ylabel('Test Accuracy')
ax.set_title('Full factorial (one shared pool): blue=baseline, red=unfiltered, green=filtered')
ax.axvline(4.5, color='gray', ls=':', lw=1, alpha=0.6)
ax.set_ylim(min(means) - 0.05, max(means) + 0.04)
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, 'rev11_full_factorial.png'), dpi=150)
plt.show()
print('Saved rev11 full-factorial plot')
Saved rev11 full-factorial plot
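The table's Std column and the plot's error bars both come from `np.std`, whose default (`ddof=0`) is the population standard deviation. With only five seeds per condition, the sample standard deviation (`ddof=1`) is noticeably larger. The minimal sketch below uses the standard library on the five Baseline 50/class seed accuracies copied from the training log to show the size of that gap; the variable names are illustrative.

```python
import math
import statistics

# Seed accuracies for Baseline 50/class, copied from the training log.
baseline_50 = [0.8907, 0.8933, 0.8820, 0.8713, 0.8873]

# Population SD divides by n (matches np.std's default ddof=0);
# sample SD divides by n - 1 (equivalent to np.std(..., ddof=1)).
pop_sd = statistics.pstdev(baseline_50)
sample_sd = statistics.stdev(baseline_50)

print(f'population SD (ddof=0): {pop_sd:.4f}')
print(f'sample SD     (ddof=1): {sample_sd:.4f}')
print(f'ratio: {sample_sd / pop_sd:.3f}  (sqrt(n/(n-1)) = {math.sqrt(5 / 4):.3f})')
```

For n = 5 the ratio is sqrt(5/4) ≈ 1.118, so every error bar would be about 12% longer under the sample convention. The paired t-tests are unaffected, since `stats.ttest_rel` does not use these values.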