Generative Augmentation in Sparse Data Regimes

A Controlled Factorial Study in Low-Resource Chinese Character Classification

Author: Joseph Catanzarite
Course: EN.705.603 Introduction to Generative AI, Spring 2026
Johns Hopkins University — Whiting School of Engineering


Questions

  1. Augmentation efficacy. Does augmenting a small real training set with synthetic images from a Conditional Wasserstein Generative Adversarial Network with Gradient Penalty (C-WGAN-GP) produce a statistically significant change in downstream Convolutional Neural Network (CNN) classification accuracy, relative to training on real data alone?
  2. Critic-filter value. Does selecting synthetic samples by the trained critic's realism score (a top-k "quality filter", Stage 1 of a planned two-stage filter) improve augmentation over an unfiltered random draw from the same generator outputs?

Experimental Design — controlled 2 × 2 × 3 factorial

  • Dataset: Chinese-MNIST, a 15-class analogue of the MNIST (Modified National Institute of Standards and Technology) handwritten-digit benchmark: 64×64 grayscale, 1,000 images/class (15,000 total)
  • Two scarcity regimes: 25 and 50 real images/class (deep and mild scarcity)
  • Generator: C-WGAN-GP trained once on the 50/class real set
  • Selection rule: critic-filtered (top-k by score) vs unfiltered (random) — both drawn from a single shared 800/class critic-scored pool
  • Augmentation ratios: 1:1 and 4:1 synthetic-to-real
  • Conditions: a real-only baseline at each regime, plus selection × ratio at each regime — ten conditions in all
  • Statistics: paired t-test across K=5 seeds (p<0.05); plus four isolation contrasts (filtered − unfiltered at fixed count and ratio) under Holm–Bonferroni correction
  • Metric: top-1 accuracy on a fixed held-out real test set
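The Statistics bullet names two standard procedures, and both are short enough to sketch. The version below is illustrative. It assumes per-seed accuracies are stored as two equal-length lists paired by seed index. `paired_t` and `holm_bonferroni` are hypothetical helper names, not functions from this notebook.

- `paired_t` returns only the t statistic. The two-sided p-value comes from a t distribution with K − 1 degrees of freedom, and `scipy.stats.ttest_rel` computes both.
- `holm_bonferroni` applies Holm's step-down rule at family-wise level α.

```python
import math
from statistics import mean, stdev

def paired_t(a, b):
    """Paired t statistic for per-seed accuracies a[i], b[i] (same seed i)."""
    if len(a) != len(b) or len(a) < 2:
        raise ValueError("need two equal-length lists with at least 2 pairs")
    d = [x - y for x, y in zip(a, b)]          # per-seed differences
    sd = stdev(d)
    if sd == 0:
        raise ValueError("differences have zero variance")
    return mean(d) / (sd / math.sqrt(len(d)))  # mean difference / standard error

def holm_bonferroni(pvals, alpha=0.05):
    """Holm step-down procedure: one reject flag per p-value, in input order."""
    m = len(pvals)
    order = sorted(range(m), key=lambda i: pvals[i])
    reject = [False] * m
    for rank, i in enumerate(order):
        # The (rank+1)-th smallest p-value is compared with alpha / (m - rank).
        if pvals[i] <= alpha / (m - rank):
            reject[i] = True
        else:
            break  # first failure stops the procedure; larger p-values also fail
    return reject
```

Take four hypothetical isolation-contrast p-values, [0.03, 0.40, 0.61, 0.72]. The smallest is below 0.05 but above 0.05/4 = 0.0125, so Holm rejects nothing. A contrast can therefore look significant on its own and still fail the correction.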

Key result

Whether augmentation helps is set by data scarcity, not filtering.

- At 50/class the baseline is near-saturated, and every augmented condition degrades accuracy.
- At 25/class every augmented condition helps, by up to +7.2 points (p<0.01). That best gain recovers ~43% of the accuracy lost to halving the real data.
- Holding count and ratio fixed, Stage 1 critic filtering shows no statistically detectable effect. All four isolation contrasts are ≤ 0.55 points: three are non-significant, and the fourth is nominally significant but does not survive Holm–Bonferroni correction.

An earlier diagonal design (unfiltered at 50/class, filtered at 25/class) confounded the two factors and made filtering look like the cause. This factorial removes the confound. Stage 2 of the filter, perceptual screening with SSIM (structural similarity) and LPIPS (learned perceptual image patch similarity), was not implemented here and remains planned future work.
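The "~43% recovered" figure is a ratio of two accuracy gaps. Let $A_{n}$ be the mean real-only baseline accuracy with $n$ real images/class, and let $A^{\text{aug}}_{25}$ be the best augmented condition at 25/class. The implied baseline gap below is back-calculated from the two rounded figures quoted above. It is not read from the results table.

```latex
\text{recovered} = \frac{A^{\text{aug}}_{25} - A_{25}}{A_{50} - A_{25}} \approx 0.43,
\qquad
A^{\text{aug}}_{25} - A_{25} = 7.2 \text{ points}
\;\Rightarrow\;
A_{50} - A_{25} \approx \frac{7.2}{0.43} \approx 16.7 \text{ points}.
```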

Medical Imaging Proxy Rationale

Low-resource image classification is a central challenge in medical imaging, where labeled data is expensive and scarce. This study uses Chinese-MNIST as a controlled proxy to evaluate whether generative augmentation can improve classifier performance under data scarcity conditions analogous to those encountered in clinical imaging datasets.


PHASE 1: Environment Setup & Data Pipeline


Cell 1.1 — Environment setup. Mounts Google Drive at /content/drive, defines all project paths (data, checkpoints, generated images, outputs) under a single base directory so artifacts persist across Colab sessions, creates any missing directories, installs the Kaggle CLI (command-line interface), and verifies that PyTorch sees a CUDA (Compute Unified Device Architecture) GPU (Graphics Processing Unit). Everything downstream assumes these paths and the GPU exist, so this cell must run first in every session.

In [1]:
# ============================================================
# CELL 1.1 — Mount Drive & Install Dependencies
# ============================================================
from google.colab import drive
drive.mount('/content/drive')

import os

BASE_DIR   = '/content/drive/MyDrive/Chinese_MNIST_Data_Augmentation'
DATA_DIR   = os.path.join(BASE_DIR, 'data')
IMAGES_DIR = os.path.join(DATA_DIR, 'data', 'data')
CSV_PATH   = os.path.join(DATA_DIR, 'chinese_mnist.csv')
CKPT_DIR   = os.path.join(BASE_DIR, 'checkpoints')
GEN_DIR    = os.path.join(BASE_DIR, 'generated')
OUT_DIR    = os.path.join(BASE_DIR, 'outputs')

for d in [DATA_DIR, CKPT_DIR, GEN_DIR, OUT_DIR]:
    os.makedirs(d, exist_ok=True)

!pip install -q kaggle

import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
Mounted at /content/drive
PyTorch: 2.11.0+cu128
CUDA: True
GPU: NVIDIA L4

Cell 1.2 — Dataset acquisition (idempotent). Downloads the Chinese-MNIST dataset from Kaggle into Drive: 15,000 JPEG (Joint Photographic Experts Group) images plus a CSV (comma-separated values) index file. The download runs only if the CSV is not already present, so re-runs skip it entirely. The assertions act as a contract: the CSV must exist and exactly 15,000 images must be found, failing fast if the download or extraction layout is wrong.

In [2]:
# ============================================================
# CELL 1.2 — Download Dataset (skips if already in Drive)
# ============================================================
import os

# Kaggle API token — required only if the dataset is not already in Drive.
# Paste your own token below, or (preferred) store it in Colab Secrets
# (key icon in the left sidebar, name it KAGGLE_TOKEN) and use the two
# commented lines instead.
os.environ['KAGGLE_TOKEN'] = 'PASTE_YOUR_KAGGLE_TOKEN_HERE'
# from google.colab import userdata
# os.environ['KAGGLE_TOKEN'] = userdata.get('KAGGLE_TOKEN')

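if os.path.exists(CSV_PATH):
    print('Dataset already in Drive — skipping download')
else:
    print('Downloading...')
    rc = os.system(f'kaggle datasets download -d gpreda/chinese-mnist -p "{DATA_DIR}" --unzip')
    assert rc == 0, f'kaggle download failed (exit status {rc})'
    print('Done')

# Check the CSV and image directory BEFORE counting files, so a failed
# download stops with a clear message instead of a FileNotFoundError.
assert os.path.exists(CSV_PATH), 'CSV not found'
assert os.path.isdir(IMAGES_DIR), f'Image directory not found: {IMAGES_DIR}'
n_images = len([f for f in os.listdir(IMAGES_DIR) if f.endswith('.jpg')])
assert n_images == 15000, f'Expected 15000 images, found {n_images}'
print('✓ CSV found')
print(f'✓ Images: {n_images}')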
Dataset already in Drive — skipping download
✓ CSV found
✓ Images: 15000

Cell 1.3 — Configuration, data pipeline, and splits. Declares the single CONFIG dictionary holding every hyperparameter (image size, scarcity levels, WGAN-GP and CNN training settings, seeds), a set_seed utility for reproducibility, and two dataset classes: ChineseMNIST (reads images from disk via the CSV index) and SyntheticDataset (wraps generated tensors with the same (image, int_label) interface so real and synthetic data can be concatenated). It then builds a stratified, seeded train/val/test split (50/50/100 images per class) and the corresponding DataLoaders. Because the split is seeded once, every experiment in the notebook sees identical val/test data.

In [3]:
# ============================================================
# CELL 1.3 — Global CONFIG, Imports, Dataset, DataLoaders
# ============================================================
import numpy as np
import pandas as pd
import random
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torchvision.transforms as transforms

# ---- CONFIG — all hyperparameters here ----
CONFIG = {
    'img_size':        64,
    'n_classes':       15,
    'n_channels':      1,
    'train_per_class': 50,      # simulated data scarcity
    'test_per_class':  100,
    'val_per_class':   50,
    'latent_dim':      128,
    'wgan_epochs':     500,
    'wgan_batch_size': 64,
    'wgan_lr':         1e-4,
    'n_critic':        5,
    'lambda_gp':       10,
    'wgan_b1':         0.0,
    'wgan_b2':         0.9,
    'cnn_epochs':      30,
    'cnn_batch_size':  64,
    'cnn_lr':          1e-3,
    'dropout':         0.4,
    'n_seeds':         5,
    'n_generate_1to1': 50,      # 1:1 synthetic-to-real ratio
    'n_generate_4to1': 200,     # 4:1 synthetic-to-real ratio
    'base_seed':       42,
    'images_dir':      IMAGES_DIR,
    'csv_path':        CSV_PATH,
    'ckpt_dir':        CKPT_DIR,
    'generated_dir':   GEN_DIR,
    'outputs_dir':     OUT_DIR,
}

# ---- Utilities ----
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

set_seed(CONFIG['base_seed'])
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {DEVICE}')

# ---- Dataset ----
class ChineseMNIST(Dataset):
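    def __init__(self, csv_path, images_dir, indices, transform=None):
        full_df = pd.read_csv(csv_path)
        # Build the label map from the FULL index so every subset (train, val,
        # test, 25/class) maps the same 'value' to the same integer label,
        # even if a subset were ever missing a class.
        self.label_map = {v: i for i, v in enumerate(sorted(full_df['value'].unique()))}
        self.df = full_df.iloc[indices].reset_index(drop=True)
        self.images_dir = images_dir
        self.transform = transform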

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        fname = f"input_{row['suite_id']}_{row['sample_id']}_{row['code']}.jpg"
        img = Image.open(os.path.join(self.images_dir, fname)).convert('L')
        if self.transform:
            img = self.transform(img)
        return img, self.label_map[row['value']]

class SyntheticDataset(Dataset):
    def __init__(self, imgs_tensor, labels_tensor):
        self.imgs   = imgs_tensor
        self.labels = labels_tensor

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        return self.imgs[idx], int(self.labels[idx])

# ---- Transforms ----
transform = transforms.Compose([
    transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# ---- Stratified split ----
def make_stratified_split(csv_path, n_train, n_val, n_test, seed):
    df = pd.read_csv(csv_path)
    rng = np.random.RandomState(seed)
    train_idx, val_idx, test_idx = [], [], []
    for val in sorted(df['value'].unique()):
        class_idx = df[df['value'] == val].index.tolist()
        rng.shuffle(class_idx)
        train_idx.extend(class_idx[:n_train])
        val_idx.extend(class_idx[n_train:n_train+n_val])
        test_idx.extend(class_idx[n_train+n_val:n_train+n_val+n_test])
    return train_idx, val_idx, test_idx

train_idx, val_idx, test_idx = make_stratified_split(
    CSV_PATH, CONFIG['train_per_class'], CONFIG['val_per_class'],
    CONFIG['test_per_class'], CONFIG['base_seed']
)

train_dataset = ChineseMNIST(CSV_PATH, IMAGES_DIR, train_idx, transform)
val_dataset   = ChineseMNIST(CSV_PATH, IMAGES_DIR, val_idx,   transform)
test_dataset  = ChineseMNIST(CSV_PATH, IMAGES_DIR, test_idx,  transform)

assert len(train_dataset) == CONFIG['train_per_class'] * CONFIG['n_classes']
assert len(test_dataset)  == CONFIG['test_per_class']  * CONFIG['n_classes']

train_loader = DataLoader(train_dataset, batch_size=CONFIG['wgan_batch_size'],
                          shuffle=True,  num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_dataset,   batch_size=CONFIG['cnn_batch_size'],
                          shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_dataset,  batch_size=CONFIG['cnn_batch_size'],
                          shuffle=False, num_workers=2, pin_memory=True)

print(f'✓ Train: {len(train_dataset)} images ({CONFIG["train_per_class"]}/class)')
print(f'✓ Val:   {len(val_dataset)} images')
print(f'✓ Test:  {len(test_dataset)} images')
Device: cuda
✓ Train: 750 images (50/class)
✓ Val:   750 images
✓ Test:  1500 images

Cell 1.4 — Data sanity check (Checkpoint 1). Renders one real training image per class in a 3×5 grid and saves the figure to the outputs directory. This is a human-in-the-loop gate. It lets a reviewer confirm visually that file-name decoding, grayscale conversion, resizing, and label mapping all behave correctly before any compute is spent on training.

In [4]:
# ============================================================
# CELL 1.4 — CHECKPOINT 1: Visualize Real Samples
# ============================================================
df_meta = pd.read_csv(CONFIG['csv_path'])
class_names = [str(row['character']) for _, row in
               df_meta.drop_duplicates('value').sort_values('value').iterrows()]

fig, axes = plt.subplots(3, 5, figsize=(12, 8))
fig.suptitle('CHECKPOINT 1: Real samples (one per class)', fontsize=13, fontweight='bold')
for i, ax in enumerate(axes.flat):
    img, label = train_dataset[i * CONFIG['train_per_class']]
    ax.imshow(img.squeeze().numpy() * 0.5 + 0.5, cmap='gray')
    ax.set_title(f'Class {label}', fontsize=9)
    ax.axis('off')
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, 'checkpoint1_real_samples.png'), dpi=100)
plt.show()
print('\n✅ CHECKPOINT 1 PASSED — Real images look correct')
✅ CHECKPOINT 1 PASSED — Real images look correct

PHASE 2: Model Definitions


Cell 2.1 — Downstream classifier architecture. Defines CNNClassifier, the model whose test accuracy is the study's outcome metric: three Conv→BatchNorm→ReLU→MaxPool blocks (64×64 → 8×8 spatial) followed by a dropout-regularized fully-connected head producing 15 logits. The identical architecture and hyperparameters are used in every experimental condition, so any accuracy difference is attributable to the training data, not the model. A shape assertion on a dummy batch verifies the forward pass.
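The 8×8 spatial size follows from PyTorch's output-size rule for Conv2d and MaxPool2d, and so does the 128 * 8 * 8 input to the first Linear layer. The pure-Python trace below assumes dilation 1 and the default floor rounding. `out_size` is a hypothetical helper, not part of the notebook.

```python
def out_size(size, kernel, stride=1, padding=0):
    # Output size of nn.Conv2d / nn.MaxPool2d with dilation=1 and floor rounding.
    return (size + 2 * padding - kernel) // stride + 1

size = 64
trace = [size]
for _ in range(3):                                          # three feature blocks
    size = out_size(size, kernel=3, stride=1, padding=1)    # 3x3 conv, padding=1: size kept
    size = out_size(size, kernel=2, stride=2)               # MaxPool2d(2): size halved
    trace.append(size)

flat = 128 * size * size    # 128 channels after the last block
print(trace, flat)          # [64, 32, 16, 8] 8192
```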

In [5]:
# ============================================================
# CELL 2.1 — CNN Classifier Architecture
# ============================================================
class CNNClassifier(nn.Module):
    def __init__(self, n_classes=15, dropout=0.4):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 256), nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, n_classes)
        )

    def forward(self, x):
        return self.classifier(self.features(x))

_m = CNNClassifier().to(DEVICE)
_x = torch.zeros(4, 1, 64, 64).to(DEVICE)
assert _m(_x).shape == (4, 15)
print('✓ CNNClassifier shape check passed')
del _m, _x
✓ CNNClassifier shape check passed

Cell 2.2 — Generative model architectures. Defines the conditional generator and critic for the WGAN-GP.

The Generator concatenates a 128-d noise vector with a learned 15-d class embedding and projects the result to a 256×8×8 tensor. It then upsamples through three transposed convolutions (8 → 16 → 32 → 64) to a 64×64 image bounded by Tanh.

The Critic embeds the class label as a 64×64 spatial map and stacks it as a second input channel. It then downsamples through four strided convolutions (64 → 4) to a single unbounded realism score. There is no sigmoid, as the Wasserstein formulation requires. InstanceNorm replaces BatchNorm because the gradient penalty is defined per sample, and batch normalization would couple the samples within a batch. This follows the WGAN-GP recommendation against batch-normalized critics.

Shape assertions verify both forward passes.
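The same arithmetic explains the generative shapes. With kernel 4, stride 2 and padding 1, PyTorch's ConvTranspose2d output-size rule doubles the spatial size, and the Conv2d rule halves it. So three upsampling layers take the 8×8 projection to 64×64, and four critic layers take 64×64 down to the 4×4 map that feeds `nn.Linear(512 * 4 * 4, 1)`. This is a pure-Python sketch assuming dilation 1 and no output_padding; the helper names are hypothetical.

```python
def conv_out(size, k, s, p):
    # nn.Conv2d output size (dilation=1)
    return (size + 2 * p - k) // s + 1

def conv_transpose_out(size, k, s, p, output_padding=0):
    # nn.ConvTranspose2d output size (dilation=1)
    return (size - 1) * s - 2 * p + (k - 1) + output_padding + 1

g_trace = [8]                       # Generator: Linear -> Unflatten to 256 x 8 x 8
for _ in range(3):
    g_trace.append(conv_transpose_out(g_trace[-1], k=4, s=2, p=1))

c_trace = [64]                      # Critic: image + label map stacked as 2 x 64 x 64
for _ in range(4):
    c_trace.append(conv_out(c_trace[-1], k=4, s=2, p=1))

critic_linear_in = 512 * c_trace[-1] ** 2
print(g_trace, c_trace, critic_linear_in)   # [8, 16, 32, 64] [64, 32, 16, 8, 4] 8192
```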

In [6]:
# ============================================================
# CELL 2.2 — Generator & Critic Architectures
# ============================================================
class Generator(nn.Module):
    def __init__(self, latent_dim, n_classes, n_channels=1):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + n_classes, 256 * 8 * 8),
            nn.Unflatten(1, (256, 8, 8)),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64,  4, 2, 1), nn.BatchNorm2d(64),  nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, n_channels, 4, 2, 1), nn.Tanh()
        )

    def forward(self, noise, labels):
        return self.model(torch.cat([noise, self.label_emb(labels)], dim=1))


class Critic(nn.Module):
    def __init__(self, n_classes, n_channels=1):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, 64 * 64)
        self.model = nn.Sequential(
            nn.Conv2d(n_channels + 1, 64,  4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64,  128, 4, 2, 1), nn.InstanceNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1), nn.InstanceNorm2d(256), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1), nn.InstanceNorm2d(512), nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(), nn.Linear(512 * 4 * 4, 1)
        )

    def forward(self, img, labels):
        label_map = self.label_emb(labels).view(labels.size(0), 1, 64, 64)
        return self.model(torch.cat([img, label_map], dim=1))


_G = Generator(CONFIG['latent_dim'], CONFIG['n_classes']).to(DEVICE)
_C = Critic(CONFIG['n_classes']).to(DEVICE)
_n = torch.randn(4, CONFIG['latent_dim']).to(DEVICE)
_l = torch.randint(0, CONFIG['n_classes'], (4,)).to(DEVICE)
assert _G(_n, _l).shape == (4, 1, 64, 64)
assert _C(_G(_n, _l), _l).shape == (4, 1)
print('✓ Generator and Critic shape checks passed')
del _G, _C, _n, _l
✓ Generator and Critic shape checks passed

Cell 2.3 — All training and generation procedures. Four functions are used by every later cell:

- compute_gradient_penalty implements the WGAN-GP Lipschitz penalty on critic gradients, evaluated at points interpolated between real and generated images.
- train_wgan runs the adversarial loop: 5 critic steps per generator step, each reusing the current real mini-batch; Adam with lr=1e-4, β1=0, β2=0.9; checkpoints to Drive every 25 epochs.
- train_cnn trains a fresh seeded classifier for 30 epochs with cosine learning-rate annealing. It returns the held-out test accuracy together with the trained model.
- generate_synthetic samples class-conditional images from a trained generator into a SyntheticDataset.

Centralizing these guarantees that all conditions share identical training mechanics.
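For reference, the critic and generator losses that compute_gradient_penalty and train_wgan compute are written out below.

- $x$ is a real image with label $y$, and $\tilde{x} = G(z, y)$ is a generated one.
- Expectations are approximated by mini-batch means in the code.
- Higher critic scores mean "more real".

```latex
\mathcal{L}_C = \mathbb{E}\big[C(\tilde{x}, y)\big] - \mathbb{E}\big[C(x, y)\big]
  + \lambda\,\mathbb{E}\Big[\big(\lVert \nabla_{\hat{x}} C(\hat{x}, y) \rVert_2 - 1\big)^2\Big],
\qquad
\hat{x} = \alpha x + (1-\alpha)\,\tilde{x},\quad \alpha \sim \mathcal{U}[0,1],

\mathcal{L}_G = -\,\mathbb{E}\big[C(G(z, y), y)\big],
\qquad z \sim \mathcal{N}(0, I_{128}),\quad \lambda = 10.
```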

In [7]:
# ============================================================
# CELL 2.3 — Training Functions
# ============================================================

def compute_gradient_penalty(critic, real, fake, labels, device):
    B = real.size(0)
    alpha = torch.rand(B, 1, 1, 1, device=device)
    interpolated = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    d_interp = critic(interpolated, labels)
    gradients = torch.autograd.grad(
        outputs=d_interp, inputs=interpolated,
        grad_outputs=torch.ones(d_interp.size(), device=device),
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    return ((gradients.view(B, -1).norm(2, dim=1) - 1) ** 2).mean()


def train_wgan(train_loader, config, device):
    set_seed(config['base_seed'])
    G = Generator(config['latent_dim'], config['n_classes']).to(device)
    C = Critic(config['n_classes']).to(device)
    opt_G = torch.optim.Adam(G.parameters(), lr=config['wgan_lr'],
                              betas=(config['wgan_b1'], config['wgan_b2']))
    opt_C = torch.optim.Adam(C.parameters(), lr=config['wgan_lr'],
                              betas=(config['wgan_b1'], config['wgan_b2']))
    critic_losses, gen_losses = [], []

    for epoch in range(config['wgan_epochs']):
        ep_c, ep_g, nc, ng = 0.0, 0.0, 0, 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            B = imgs.size(0)
            for _ in range(config['n_critic']):
                noise = torch.randn(B, config['latent_dim'], device=device)
                fake  = G(noise, labels).detach()
                gp    = compute_gradient_penalty(C, imgs, fake, labels, device)
                c_loss = -C(imgs, labels).mean() + C(fake, labels).mean() + config['lambda_gp'] * gp
                opt_C.zero_grad(); c_loss.backward(); opt_C.step()
                ep_c += c_loss.item(); nc += 1
            noise  = torch.randn(B, config['latent_dim'], device=device)
            fake   = G(noise, labels)
            g_loss = -C(fake, labels).mean()
            opt_G.zero_grad(); g_loss.backward(); opt_G.step()
            ep_g += g_loss.item(); ng += 1

        critic_losses.append(ep_c / nc)
        gen_losses.append(ep_g / ng)
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{config["wgan_epochs"]}] '
                  f'Critic: {ep_c/nc:.4f}  Generator: {ep_g/ng:.4f}')
        if (epoch + 1) % 25 == 0:
            ckpt = os.path.join(config['ckpt_dir'], f'wgan_epoch{epoch+1}.pt')
            torch.save({'G': G.state_dict(), 'C': C.state_dict()}, ckpt)
            print(f'  → Checkpoint saved: {ckpt}')
    return G, C, critic_losses, gen_losses


def train_cnn(train_loader, test_loader, seed, tag=''):
    set_seed(seed)
    model     = CNNClassifier(CONFIG['n_classes'], CONFIG['dropout']).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['cnn_lr'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['cnn_epochs'])
    criterion = nn.CrossEntropyLoss()
    for epoch in range(CONFIG['cnn_epochs']):
        model.train()
        running_loss = 0.0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(imgs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            print(f'  [{tag}] Epoch {epoch+1}/{CONFIG["cnn_epochs"]} '
                  f'loss={running_loss/len(train_loader):.4f}')
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            correct += (model(imgs).argmax(dim=1) == labels).sum().item()
            total   += labels.size(0)
    return correct / total, model


def generate_synthetic(G, n_per_class, n_classes, latent_dim, device):
    G.eval()
    all_imgs, all_labels = [], []
    with torch.no_grad():
        for c in range(n_classes):
            noise  = torch.randn(n_per_class, latent_dim, device=device)
            labels = torch.full((n_per_class,), c, dtype=torch.long, device=device)
            all_imgs.append(G(noise, labels).cpu())
            all_labels.extend([c] * n_per_class)
    return SyntheticDataset(
        torch.cat(all_imgs, dim=0),
        torch.tensor(all_labels, dtype=torch.long)
    )

print('✓ All training functions defined')
✓ All training functions defined

PHASE 3: C-WGAN-GP Training

Skip to Cell 3B if the checkpoint already exists in Drive


Cell 3A — WGAN-GP training (skip-if-trained). Trains the conditional WGAN-GP for 500 epochs on the real training set — but only if the final checkpoint is absent from Drive. When the checkpoint exists (the normal case after the initial overnight run), the cell prints a skip notice, making full Run-All executions cheap and reproducible.

In [8]:
# ============================================================
# CELL 3A — Train C-WGAN-GP from scratch (skip if checkpoint exists)
# ============================================================
FINAL_CKPT = os.path.join(CKPT_DIR, 'wgan_500ep_final.pt')

if os.path.exists(FINAL_CKPT):
    print(f'Checkpoint found at {FINAL_CKPT}')
    print('Skip this cell and run Cell 3B instead')
else:
    print('No checkpoint found — training from scratch...')
    CONFIG['wgan_epochs'] = 500
    G, C, critic_losses, gen_losses = train_wgan(train_loader, CONFIG, DEVICE)
    torch.save({'G': G.state_dict(), 'C': C.state_dict()}, FINAL_CKPT)
    print(f'\n✓ Model saved: {FINAL_CKPT}')
Checkpoint found at /content/drive/MyDrive/Chinese_MNIST_Data_Augmentation/checkpoints/wgan_500ep_final.pt
Skip this cell and run Cell 3B instead

Cell 3B — Model restoration. Instantiates fresh Generator and Critic objects and loads the 500-epoch weights from the Drive checkpoint. This decouples every downstream experiment from the cost of training: any new session reaches an identical trained generator in seconds.

In [9]:
# ============================================================
# CELL 3B — Load C-WGAN-GP from checkpoint
# ============================================================
FINAL_CKPT = os.path.join(CKPT_DIR, 'wgan_500ep_final.pt')
ckpt = torch.load(FINAL_CKPT, map_location=DEVICE)
G = Generator(CONFIG['latent_dim'], CONFIG['n_classes']).to(DEVICE)
C = Critic(CONFIG['n_classes']).to(DEVICE)
G.load_state_dict(ckpt['G'])
C.load_state_dict(ckpt['C'])
print(f'✓ Loaded generator and critic from {FINAL_CKPT}')
✓ Loaded generator and critic from /content/drive/MyDrive/Chinese_MNIST_Data_Augmentation/checkpoints/wgan_500ep_final.pt

Cell 3C — Generation quality check (Checkpoint 3). Samples one synthetic image per class from the restored generator and renders the 3×5 grid. A second human gate: downstream augmentation experiments are only meaningful if the generator demonstrably produces class-conditional character structure rather than noise.

In [10]:
# ============================================================
# CELL 3C — CHECKPOINT 3: Visual Inspection of Generated Samples
# ============================================================
G.eval()
fig, axes = plt.subplots(3, 5, figsize=(12, 8))
fig.suptitle('CHECKPOINT 3: Generated samples (one per class)', fontsize=13, fontweight='bold')
with torch.no_grad():
    for i, ax in enumerate(axes.flat):
        noise = torch.randn(1, CONFIG['latent_dim']).to(DEVICE)
        label = torch.tensor([i]).to(DEVICE)
        img   = G(noise, label).squeeze().cpu().numpy() * 0.5 + 0.5
        ax.imshow(img, cmap='gray')
        ax.set_title(f'Class {i}', fontsize=9)
        ax.axis('off')
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, 'checkpoint3_generated.png'), dpi=100)
plt.show()
print('\n✅ CHECKPOINT 3 — Do images resemble Chinese characters?')
print('If yes: proceed to Phase 4')
print('If noise: re-train with more epochs')
✅ CHECKPOINT 3 — Do images resemble Chinese characters?
If yes: proceed to Phase 4
If noise: re-train with more epochs

PHASE 4: Controlled Factorial — Scarcity × Critic Filtering × Ratio


Unified factorial from one shared synthetic pool

This section builds the experiment around a single design principle: the synthetic images are a fixed artifact, and "filtered" vs "unfiltered" are two selection rules over the same raw generator outputs.

  • Cell B generates one raw pool per class (800/class) and scores every image with the critic once. The filtered block is the critic's top-scoring 200/class (a score-sorted prefix); the unfiltered block is a random 200/class drawn from the same pool.
  • Ratios nest by taking a prefix of each 200/class block, so "more synthetic" means "the same images plus more." Because each condition takes the top-n of the filtered block, the effective filter is stricter than a quartile — from roughly the top 3% of the pool (1:1 at 25 real/class) to the top 25% (4:1 at 50 real/class).
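  • Cell C runs all ten distinct conditions of the 2 × 2 × 3 factorial live from that one pool: both baselines, and unfiltered/filtered × {1:1, 4:1} at both 25 and 50 real/class. At the real-only ratio the two selection rules coincide, so the 12 nominal cells reduce to ten.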
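The ten conditions and the effective selectivity quoted above can be enumerated directly. This small sketch uses only the pool size, regimes and ratios stated in this section; the variable names are illustrative and are not taken from Cell C. For filtered rows the printed share is the critic's top fraction of the pool. For unfiltered rows it is the share drawn at random.

```python
POOL = 800            # raw generations per class (POOL_PER_CLASS in Cell B)
REGIMES = (25, 50)    # real images per class
RATIOS = (0, 1, 4)    # synthetic-to-real ratio; 0 = real-only baseline

conditions = []
for n_real in REGIMES:
    for ratio in RATIOS:
        # With no synthetic images the selection rule is moot, so the
        # 2 x 2 x 3 grid collapses to one baseline per regime.
        rules = ("baseline",) if ratio == 0 else ("unfiltered", "filtered")
        for rule in rules:
            n_syn = ratio * n_real                  # synthetic images per class
            conditions.append((n_real, rule, ratio, n_syn, n_syn / POOL))

for n_real, rule, ratio, n_syn, share in conditions:
    print(f"{n_real} real/class | {rule:<10} | {ratio}:1 | "
          f"{n_syn:>3} syn/class = {share:.1%} of the pool")
```

The filtered rows run from 25/800 (about 3%) to 200/800 (25%). No condition needs more than the 200/class kept in each block.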

Because filtered and unfiltered now draw from the same raw generations, at identical counts and identical ratios, the four isolation contrasts in Cell 5.1 (filtered − unfiltered at fixed real-count and ratio) attribute any difference to the critic selection alone. This supersedes the earlier hardcoded Experiment-1 constants and the separate-pool generation; everything is one coherent run.

Executed on a Colab L4. All ten conditions were run live (50 CNN trainings: 10 conditions × 5 seeds); the tables and plot below are from that run. No numbers are hardcoded; every cell fills in when run.

Cell A — Deep-scarcity training set. Builds the 25-real-images-per-class training set by slicing the first 25 indices of each class block from the existing 50/class training split. Because it is a strict subset, the validation and test sets remain byte-identical to the 50/class regime, keeping the two regimes comparable on the same held-out data. An assertion verifies the expected 375-image count.

In [11]:
# ============================================================
# CELL A — Extreme scarcity: 25 real images/class
# Built as a SUBSET of the existing 50/class train set,
# so val/test sets remain IDENTICAL (clean comparison).
# ============================================================
N_SCARCE = 25

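scarce_idx = []
n_train = CONFIG['train_per_class']   # train_idx is laid out in contiguous per-class blocks
for c in range(CONFIG['n_classes']):
    class_block = train_idx[c * n_train : (c + 1) * n_train]
    scarce_idx.extend(class_block[:N_SCARCE])   # first 25 of each class's 50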

train_dataset_scarce = ChineseMNIST(CSV_PATH, IMAGES_DIR, scarce_idx, transform)
assert len(train_dataset_scarce) == N_SCARCE * CONFIG['n_classes']

train_loader_scarce = DataLoader(train_dataset_scarce,
                                 batch_size=CONFIG['cnn_batch_size'],
                                 shuffle=True, num_workers=2, pin_memory=True)

print(f'✓ Scarce train set: {len(train_dataset_scarce)} images ({N_SCARCE}/class)')
print(f'✓ Test set unchanged: {len(test_dataset)} images')
✓ Scarce train set: 375 images (25/class)
✓ Test set unchanged: 1500 images

Cell B — One raw pool, scored once; two selection rules. Generates 800 raw samples/class from the generator and scores all of them with the critic a single time. It then derives two blocks from that same pool:

- the filtered block: the top-scoring 200/class, score-sorted;
- the unfiltered block: a random 200/class.

Also defines make_aug_loader_gen, which concatenates the real images with a nested prefix of a synthetic block. The filtered condition therefore trains on the strict top-k by critic score: the top ~3–25% of the pool, depending on count and ratio.

Note: the unfiltered random draw can include some of the same high-scoring images as the filtered block. This is intentional and keeps the comparison unbiased. Excluding the high-scoring images would depress the unfiltered condition and inflate the filter's apparent effect.
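Both the nesting claim and the size of the intended overlap can be checked with plain arithmetic. This sketch reproduces only the index arithmetic of make_aug_loader_gen, with no images. It then uses the hypergeometric mean to estimate how many of the top-n critic-scored images a uniform random n-subset of the 800-image pool shares, per class, assuming no tied critic scores. `prefix_indices` and `expected_overlap` are hypothetical helpers.

```python
KEEP = 200            # synthetic images kept per class in each block (Cell B)
POOL_PER_CLASS = 800  # raw generations per class (Cell B)
N_CLASSES = 15

def prefix_indices(n_syn_per_class, block_per_class, n_classes):
    """Index arithmetic of make_aug_loader_gen: first n images of each class block."""
    idx = []
    for c in range(n_classes):
        start = c * block_per_class
        idx.extend(range(start, start + n_syn_per_class))
    return idx

def expected_overlap(n, pool):
    """Hypergeometric mean: a uniform random n-subset of `pool` items is
    expected to contain n * n / pool of the top-n items."""
    return n * n / pool

one_to_one = set(prefix_indices(25, KEEP, N_CLASSES))    # 1:1 at 25 real/class
four_to_one = set(prefix_indices(100, KEEP, N_CLASSES))  # 4:1 at 25 real/class
print("1:1 images are a subset of 4:1 images:", one_to_one <= four_to_one)

for n in (25, 50, 100, 200):
    print(f"n={n:>3}: expected shared images/class = {expected_overlap(n, POOL_PER_CLASS)}")
```

Under these assumptions the overlap is well under one image per class at 1:1 with 25 real/class. It rises to about 50 images per class at 4:1 with 50 real/class, where both blocks are used in full.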

In [12]:
# ============================================================
# CELL B - One raw pool, scored once; two selection rules
# filtered vs unfiltered = two selections over the SAME raw
# generator outputs, so the contrast isolates the critic alone.
# ============================================================
POOL_PER_CLASS = 800   # raw pool; top 25% -> 200 filtered/class
KEEP           = 200   # max synthetic/class needed (50/class at 4:1)

def make_scored_pool(G, C, pool_per_class, n_classes, latent_dim, device, seed):
    # Generate one raw pool/class from G; score every image with C once.
    set_seed(seed)
    G.eval(); C.eval()
    imgs_all, score_all = [], []
    with torch.no_grad():
        for c in range(n_classes):
            noise  = torch.randn(pool_per_class, latent_dim, device=device)
            labels = torch.full((pool_per_class,), c, dtype=torch.long, device=device)
            imgs   = G(noise, labels)
            scores = C(imgs, labels).squeeze(1)          # higher = more realistic
            imgs_all.append(imgs.cpu()); score_all.append(scores.cpu())
    return torch.cat(imgs_all), torch.cat(score_all)

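def make_aug_loader_gen(real_ds, syn_ds, n_syn_per_class, block_per_class):
    # Concatenate real_ds with the first n_syn_per_class images of each class
    # block in syn_ds (blocks are block_per_class long, in class order), so a
    # larger ratio reuses the smaller ratio's images plus more.
    # Returns (DataLoader, total number of training images).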
    idx = []
    for c in range(CONFIG['n_classes']):
        start = c * block_per_class
        idx.extend(range(start, start + n_syn_per_class))
    sub = torch.utils.data.Subset(syn_ds, idx)
    ds  = ConcatDataset([real_ds, sub])
    return DataLoader(ds, batch_size=CONFIG['cnn_batch_size'],
                      shuffle=True, num_workers=2, pin_memory=True), len(ds)

# 1) Generate + score ONE pool
raw_imgs, raw_scores = make_scored_pool(
    G, C, POOL_PER_CLASS, CONFIG['n_classes'], CONFIG['latent_dim'], DEVICE,
    seed=CONFIG['base_seed'])
print(f'Raw pool: {len(raw_imgs)} images ({POOL_PER_CLASS}/class), scored once')

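# 2) Derive two selection rules over the SAME pool, KEEP images per class each.
#    Filtered   = top-KEEP by critic score, stored best-first, so the first n of a
#                 class block are its top-n (the 1:1 subsets nest inside the 4:1 ones).
#    Unfiltered = KEEP drawn at random from the same pool (no critic involved).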
rng = np.random.RandomState(CONFIG['base_seed'])
filt_i, filt_l, unf_i, unf_l = [], [], [], []
for cls in range(CONFIG['n_classes']):
    lo = cls * POOL_PER_CLASS
    block_imgs   = raw_imgs[lo:lo + POOL_PER_CLASS]
    block_scores = raw_scores[lo:lo + POOL_PER_CLASS]
    top  = torch.argsort(block_scores, descending=True)[:KEEP]
    rand = torch.from_numpy(rng.permutation(POOL_PER_CLASS)[:KEEP])
    filt_i.append(block_imgs[top]);  filt_l += [cls] * KEEP
    unf_i.append(block_imgs[rand]);  unf_l += [cls] * KEEP

syn_filtered   = SyntheticDataset(torch.cat(filt_i), torch.tensor(filt_l, dtype=torch.long))
syn_unfiltered = SyntheticDataset(torch.cat(unf_i), torch.tensor(unf_l, dtype=torch.long))
print(f'Filtered set:   {len(syn_filtered)} images ({KEEP}/class, critic top-k by score)')
print(f'Unfiltered set: {len(syn_unfiltered)} images ({KEEP}/class, random from same pool)')
Raw pool: 12000 images (800/class), scored once
Filtered set:   3000 images (200/class, critic top-k by score)
Unfiltered set: 3000 images (200/class, random from same pool)
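The two selection rules in Cell B can be illustrated without a GAN. The snippet below is a minimal stand-alone sketch, assuming a toy class block of 800 critic scores (Gaussian random numbers standing in for real critic outputs); the helper names select_filtered, select_unfiltered, and mean_score are hypothetical and do not appear in the notebook. It checks the property Cell B relies on: because the filtered block is stored best-first, its first n entries are exactly the top-n by score, so each 1:1 selection is a subset of the corresponding 4:1 selection.

```python
import random

def select_filtered(scores, keep):
    """Indices of the `keep` highest scores, best first (critic top-k)."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:keep]

def select_unfiltered(n_pool, keep, rng):
    """`keep` indices drawn uniformly without replacement (no critic involved)."""
    perm = list(range(n_pool))
    rng.shuffle(perm)
    return perm[:keep]

def mean_score(scores, idx):
    """Average critic score over the selected indices."""
    return sum(scores[i] for i in idx) / len(idx)

# Toy stand-in for ONE class block: 800 hypothetical critic scores.
rng = random.Random(0)
POOL, KEEP = 800, 200
scores = [rng.gauss(0.0, 1.0) for _ in range(POOL)]

filt = select_filtered(scores, KEEP)
unf = select_unfiltered(POOL, KEEP, rng)

# Best-first storage: the first 50 of the filtered block ARE the top 50,
# so a 1:1 selection is a subset of the 4:1 selection.
assert filt[:50] == select_filtered(scores, 50)

print(f'mean critic score: filtered {mean_score(scores, filt):.3f}, '
      f'unfiltered {mean_score(scores, unf):.3f}')
```

Cell B applies the same two rules per class with torch.argsort(..., descending=True) and rng.permutation; the stdlib calls here are only stand-ins for those.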

Cell B-check — visual confirmation. Top row: the highest-critic-scored (filtered) sample for each of five classes (0, 3, 6, 9, 12). Bottom row: one random (unfiltered) sample of the same class from the same pool. The filtered samples should look cleaner; if they do not, the critic is not separating quality, and the filtered-vs-unfiltered contrast downstream has little to detect.

# ============================================================
# CELL B-check - filtered (top) vs unfiltered (bottom) samples
# ============================================================
show_classes = [0, 3, 6, 9, 12]
fig, axes = plt.subplots(2, len(show_classes), figsize=(2.2 * len(show_classes), 4.6))
for j, cls in enumerate(show_classes):
    f_img, _ = syn_filtered[cls * KEEP + 0]    # top-scoring
    u_img, _ = syn_unfiltered[cls * KEEP + 0]  # a random one
    axes[0, j].imshow(f_img.squeeze().numpy() * 0.5 + 0.5, cmap='gray')
    axes[0, j].set_title(f'class {cls}', fontsize=9); axes[0, j].axis('off')
    axes[1, j].imshow(u_img.squeeze().numpy() * 0.5 + 0.5, cmap='gray')
    axes[1, j].axis('off')
fig.text(0.01, 0.72, 'filtered',   rotation=90, va='center', fontsize=10)
fig.text(0.01, 0.28, 'unfiltered', rotation=90, va='center', fontsize=10)
fig.suptitle('Same raw pool, two selection rules', fontsize=12)
plt.tight_layout(rect=[0.03, 0, 1, 1])
plt.savefig(os.path.join(OUT_DIR, 'rev11_pool_check.png'), dpi=120)
plt.show()
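Because both rules draw from the same 800-image class block, the filtered and unfiltered subsets are not disjoint. A uniform random draw of m images from a block of N contains, on average, m·k/N of the top-k images (the mean of a hypergeometric distribution). The sketch below is a back-of-the-envelope check, not part of the original pipeline: it assumes N = 800 per class (POOL_PER_CLASS) and that each augmented condition uses the same number n of filtered and unfiltered images per class, and it confirms the closed form with a small Monte Carlo simulation. The helper names are hypothetical.

```python
import random

POOL = 800  # raw images per class (POOL_PER_CLASS in Cell B)

def expected_overlap(n_top, n_rand, pool=POOL):
    """Hypergeometric mean: expected number of a uniform random n_rand-subset
    that fall inside a fixed n_top-subset of the pool."""
    return n_rand * n_top / pool

def simulated_overlap(n_top, n_rand, trials, rng, pool=POOL):
    """Monte Carlo estimate of the same quantity."""
    top = set(range(n_top))  # without loss of generality, the top items are 0..n_top-1
    total = 0
    for _ in range(trials):
        total += len(top.intersection(rng.sample(range(pool), n_rand)))
    return total / trials

# Synthetic images per class in each augmented condition (same n for both rules).
per_condition = {'@50 1:1': 50, '@50 4:1': 200, '@25 1:1': 25, '@25 4:1': 100}
rng = random.Random(0)
for name, n in per_condition.items():
    exp = expected_overlap(n, n)
    sim = simulated_overlap(n, n, trials=2000, rng=rng)
    print(f'{name}: {n}/class, expected shared {exp:.2f} ({exp / n:.1%}), simulated {sim:.2f}')
```

At 4:1 with 50 real/class, about a quarter of the synthetic images (50 of 200 per class) are expected to appear in both sets, which by itself limits how different the two training sets in that contrast can be.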

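Cell C — Full factorial execution. Trains and evaluates a CNN for each of the ten conditions, drawing all synthetic data from the shared pool, and stores the per-seed accuracies in results. Every condition uses the same five seeds (base_seed through base_seed + 4) and the same held-out test set, so run i of any condition can be matched with run i of any other; the paired t-tests downstream rely on this matching. Total cost: 10 conditions × 5 seeds = 50 CNN trainings.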
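make_aug_loader_gen takes the first n_syn_per_class images from each class block of a synthetic set, so the size of every training set follows from simple arithmetic. A minimal sketch, assuming exactly 15 classes and that train_dataset and train_dataset_scarce hold exactly 50 and 25 real images per class, as stated in the experimental design; synthetic_indices and train_size are hypothetical helpers that mirror the loader's index pattern.

```python
N_CLASSES, KEEP = 15, 200  # classes; synthetic images per class block (Cell B)

def synthetic_indices(n_syn_per_class, block=KEEP, n_classes=N_CLASSES):
    """Same index pattern as make_aug_loader_gen: the first n of every class block."""
    idx = []
    for c in range(n_classes):
        idx.extend(range(c * block, c * block + n_syn_per_class))
    return idx

def train_size(real_per_class, ratio):
    """Real + synthetic training images for a synthetic:real ratio (0 = real only)."""
    n_syn = ratio * real_per_class
    assert n_syn <= KEEP, 'a class block holds only KEEP synthetic images'
    return N_CLASSES * real_per_class + len(synthetic_indices(n_syn))

for real in (50, 25):
    for ratio in (0, 1, 4):
        name = 'real only' if ratio == 0 else f'{ratio}:1'
        print(f'{real} real/class, {name:<9} -> {train_size(real, ratio)} training images')
```

Under these assumptions the 25/class 4:1 condition trains on 1,875 images, more than the 750 of the 50/class baseline, yet it reaches lower accuracy in Cell 5.1 (0.7901 vs 0.8849), suggesting that synthetic images do not substitute one-for-one for real ones here.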

# ============================================================
# CELL C (rev11) - Full factorial execution (10 conditions, ONE pool)
# 50 CNN trainings (10 conditions x 5 seeds), all from the same raw pool.
# ============================================================
results = {}

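def run_condition(key, loader, label):
    # Train one CNN per seed (base_seed + i), evaluate it on test_loader, and
    # store the per-seed accuracies in results[key] for the paired tests.
    print(f'=== {label} ===')
    accs = []
    for i in range(CONFIG['n_seeds']):
        acc, _ = train_cnn(loader, test_loader, seed=CONFIG['base_seed'] + i, tag=f'{key} s{i+1}')
        accs.append(acc)
        print(f'  -> Seed {i+1}: {acc:.4f}')
    results[key] = accs
    # +/- is np.std with its default ddof=0 (population SD over the seeds).
    print(f'  {label}: {np.mean(accs):.4f} +/- {np.std(accs):.4f}\n')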

# Baselines (real only)
run_condition('baseline_50', train_loader,        'Baseline 50/class')
run_condition('baseline_25', train_loader_scarce, 'Baseline 25/class')

# Augmented conditions: (results key, real dataset, synthetic set,
# synthetic images/class, label). Synthetic images/class = ratio x real/class;
# every synthetic set is stored in class blocks of KEEP images.
aug_specs = [
    # 50 real/class: 1:1 -> 50, 4:1 -> 200 synthetic/class
    ('unfilt_50_1to1', train_dataset,        syn_unfiltered,  50, 'Unfiltered 1:1 @50'),
    ('unfilt_50_4to1', train_dataset,        syn_unfiltered, 200, 'Unfiltered 4:1 @50'),
    ('filt_50_1to1',   train_dataset,        syn_filtered,    50, 'Filtered 1:1 @50'),
    ('filt_50_4to1',   train_dataset,        syn_filtered,   200, 'Filtered 4:1 @50'),
    # 25 real/class: 1:1 -> 25, 4:1 -> 100 synthetic/class
    ('unfilt_25_1to1', train_dataset_scarce, syn_unfiltered,  25, 'Unfiltered 1:1 @25'),
    ('unfilt_25_4to1', train_dataset_scarce, syn_unfiltered, 100, 'Unfiltered 4:1 @25'),
    ('filt_25_1to1',   train_dataset_scarce, syn_filtered,    25, 'Filtered 1:1 @25'),
    ('filt_25_4to1',   train_dataset_scarce, syn_filtered,   100, 'Filtered 4:1 @25'),
]
for key, real_ds, syn_ds, n_syn, label in aug_specs:
    aug_loader, _ = make_aug_loader_gen(real_ds, syn_ds, n_syn, KEEP)
    run_condition(key, aug_loader, label)

print('\nAll conditions complete.')
=== Baseline 50/class ===
  [baseline_50 s1] Epoch 10/30 loss=0.4733
  [baseline_50 s1] Epoch 20/30 loss=0.1218
  [baseline_50 s1] Epoch 30/30 loss=0.0867
  -> Seed 1: 0.8907
  [baseline_50 s2] Epoch 10/30 loss=0.5400
  [baseline_50 s2] Epoch 20/30 loss=0.1765
  [baseline_50 s2] Epoch 30/30 loss=0.1150
  -> Seed 2: 0.8933
  [baseline_50 s3] Epoch 10/30 loss=0.9371
  [baseline_50 s3] Epoch 20/30 loss=0.3433
  [baseline_50 s3] Epoch 30/30 loss=0.2763
  -> Seed 3: 0.8820
  [baseline_50 s4] Epoch 10/30 loss=0.8691
  [baseline_50 s4] Epoch 20/30 loss=0.3554
  [baseline_50 s4] Epoch 30/30 loss=0.2586
  -> Seed 4: 0.8713
  [baseline_50 s5] Epoch 10/30 loss=0.7191
  [baseline_50 s5] Epoch 20/30 loss=0.2068
  [baseline_50 s5] Epoch 30/30 loss=0.1570
  -> Seed 5: 0.8873
  Baseline 50/class: 0.8849 +/- 0.0078

=== Baseline 25/class ===
  [baseline_25 s1] Epoch 10/30 loss=1.1889
  [baseline_25 s1] Epoch 20/30 loss=0.4814
  [baseline_25 s1] Epoch 30/30 loss=0.3193
  -> Seed 1: 0.7140
  [baseline_25 s2] Epoch 10/30 loss=0.9380
  [baseline_25 s2] Epoch 20/30 loss=0.4288
  [baseline_25 s2] Epoch 30/30 loss=0.3054
  -> Seed 2: 0.7647
  [baseline_25 s3] Epoch 10/30 loss=1.3329
  [baseline_25 s3] Epoch 20/30 loss=0.5279
  [baseline_25 s3] Epoch 30/30 loss=0.4185
  -> Seed 3: 0.7113
  [baseline_25 s4] Epoch 10/30 loss=1.3607
  [baseline_25 s4] Epoch 20/30 loss=0.5606
  [baseline_25 s4] Epoch 30/30 loss=0.4142
  -> Seed 4: 0.7127
  [baseline_25 s5] Epoch 10/30 loss=1.3868
  [baseline_25 s5] Epoch 20/30 loss=0.6796
  [baseline_25 s5] Epoch 30/30 loss=0.5963
  -> Seed 5: 0.6860
  Baseline 25/class: 0.7177 +/- 0.0257

=== Unfiltered 1:1 @50 ===
  [unfilt_50_1to1 s1] Epoch 10/30 loss=0.3409
  [unfilt_50_1to1 s1] Epoch 20/30 loss=0.0965
  [unfilt_50_1to1 s1] Epoch 30/30 loss=0.0605
  -> Seed 1: 0.8680
  [unfilt_50_1to1 s2] Epoch 10/30 loss=0.2277
  [unfilt_50_1to1 s2] Epoch 20/30 loss=0.0649
  [unfilt_50_1to1 s2] Epoch 30/30 loss=0.0461
  -> Seed 2: 0.8860
  [unfilt_50_1to1 s3] Epoch 10/30 loss=0.4278
  [unfilt_50_1to1 s3] Epoch 20/30 loss=0.1838
  [unfilt_50_1to1 s3] Epoch 30/30 loss=0.1005
  -> Seed 3: 0.8540
  [unfilt_50_1to1 s4] Epoch 10/30 loss=0.3158
  [unfilt_50_1to1 s4] Epoch 20/30 loss=0.0997
  [unfilt_50_1to1 s4] Epoch 30/30 loss=0.0675
  -> Seed 4: 0.8720
  [unfilt_50_1to1 s5] Epoch 10/30 loss=0.2906
  [unfilt_50_1to1 s5] Epoch 20/30 loss=0.0783
  [unfilt_50_1to1 s5] Epoch 30/30 loss=0.0465
  -> Seed 5: 0.8727
  Unfiltered 1:1 @50: 0.8705 +/- 0.0103

=== Unfiltered 4:1 @50 ===
  [unfilt_50_4to1 s1] Epoch 10/30 loss=0.0997
  [unfilt_50_4to1 s1] Epoch 20/30 loss=0.0238
  [unfilt_50_4to1 s1] Epoch 30/30 loss=0.0180
  -> Seed 1: 0.8620
  [unfilt_50_4to1 s2] Epoch 10/30 loss=0.1271
  [unfilt_50_4to1 s2] Epoch 20/30 loss=0.0330
  [unfilt_50_4to1 s2] Epoch 30/30 loss=0.0187
  -> Seed 2: 0.8667
  [unfilt_50_4to1 s3] Epoch 10/30 loss=0.0979
  [unfilt_50_4to1 s3] Epoch 20/30 loss=0.0250
  [unfilt_50_4to1 s3] Epoch 30/30 loss=0.0121
  -> Seed 3: 0.8520
  [unfilt_50_4to1 s4] Epoch 10/30 loss=0.0905
  [unfilt_50_4to1 s4] Epoch 20/30 loss=0.0167
  [unfilt_50_4to1 s4] Epoch 30/30 loss=0.0094
  -> Seed 4: 0.8653
  [unfilt_50_4to1 s5] Epoch 10/30 loss=0.0915
  [unfilt_50_4to1 s5] Epoch 20/30 loss=0.0278
  [unfilt_50_4to1 s5] Epoch 30/30 loss=0.0145
  -> Seed 5: 0.8713
  Unfiltered 4:1 @50: 0.8635 +/- 0.0065

=== Filtered 1:1 @50 ===
  [filt_50_1to1 s1] Epoch 10/30 loss=0.2693
  [filt_50_1to1 s1] Epoch 20/30 loss=0.0851
  [filt_50_1to1 s1] Epoch 30/30 loss=0.0548
  -> Seed 1: 0.8687
  [filt_50_1to1 s2] Epoch 10/30 loss=0.1742
  [filt_50_1to1 s2] Epoch 20/30 loss=0.0448
  [filt_50_1to1 s2] Epoch 30/30 loss=0.0272
  -> Seed 2: 0.8960
  [filt_50_1to1 s3] Epoch 10/30 loss=0.2647
  [filt_50_1to1 s3] Epoch 20/30 loss=0.0932
  [filt_50_1to1 s3] Epoch 30/30 loss=0.0519
  -> Seed 3: 0.8573
  [filt_50_1to1 s4] Epoch 10/30 loss=0.2335
  [filt_50_1to1 s4] Epoch 20/30 loss=0.0580
  [filt_50_1to1 s4] Epoch 30/30 loss=0.0377
  -> Seed 4: 0.8767
  [filt_50_1to1 s5] Epoch 10/30 loss=0.2067
  [filt_50_1to1 s5] Epoch 20/30 loss=0.0592
  [filt_50_1to1 s5] Epoch 30/30 loss=0.0319
  -> Seed 5: 0.8813
  Filtered 1:1 @50: 0.8760 +/- 0.0129

=== Filtered 4:1 @50 ===
  [filt_50_4to1 s1] Epoch 10/30 loss=0.0853
  [filt_50_4to1 s1] Epoch 20/30 loss=0.0177
  [filt_50_4to1 s1] Epoch 30/30 loss=0.0104
  -> Seed 1: 0.8673
  [filt_50_4to1 s2] Epoch 10/30 loss=0.0851
  [filt_50_4to1 s2] Epoch 20/30 loss=0.0209
  [filt_50_4to1 s2] Epoch 30/30 loss=0.0148
  -> Seed 2: 0.8773
  [filt_50_4to1 s3] Epoch 10/30 loss=0.0778
  [filt_50_4to1 s3] Epoch 20/30 loss=0.0226
  [filt_50_4to1 s3] Epoch 30/30 loss=0.0105
  -> Seed 3: 0.8613
  [filt_50_4to1 s4] Epoch 10/30 loss=0.0678
  [filt_50_4to1 s4] Epoch 20/30 loss=0.0131
  [filt_50_4to1 s4] Epoch 30/30 loss=0.0070
  -> Seed 4: 0.8573
  [filt_50_4to1 s5] Epoch 10/30 loss=0.1098
  [filt_50_4to1 s5] Epoch 20/30 loss=0.0208
  [filt_50_4to1 s5] Epoch 30/30 loss=0.0133
  -> Seed 5: 0.8607
  Filtered 4:1 @50: 0.8648 +/- 0.0070

=== Unfiltered 1:1 @25 ===
  [unfilt_25_1to1 s1] Epoch 10/30 loss=0.2884
  [unfilt_25_1to1 s1] Epoch 20/30 loss=0.0594
  [unfilt_25_1to1 s1] Epoch 30/30 loss=0.0473
  -> Seed 1: 0.7860
  [unfilt_25_1to1 s2] Epoch 10/30 loss=0.6548
  [unfilt_25_1to1 s2] Epoch 20/30 loss=0.2062
  [unfilt_25_1to1 s2] Epoch 30/30 loss=0.1586
  -> Seed 2: 0.7860
  [unfilt_25_1to1 s3] Epoch 10/30 loss=0.6164
  [unfilt_25_1to1 s3] Epoch 20/30 loss=0.2212
  [unfilt_25_1to1 s3] Epoch 30/30 loss=0.1448
  -> Seed 3: 0.7753
  [unfilt_25_1to1 s4] Epoch 10/30 loss=0.4434
  [unfilt_25_1to1 s4] Epoch 20/30 loss=0.1270
  [unfilt_25_1to1 s4] Epoch 30/30 loss=0.0916
  -> Seed 4: 0.7640
  [unfilt_25_1to1 s5] Epoch 10/30 loss=0.6515
  [unfilt_25_1to1 s5] Epoch 20/30 loss=0.2248
  [unfilt_25_1to1 s5] Epoch 30/30 loss=0.1784
  -> Seed 5: 0.7647
  Unfiltered 1:1 @25: 0.7752 +/- 0.0097

=== Unfiltered 4:1 @25 ===
  [unfilt_25_4to1 s1] Epoch 10/30 loss=0.1207
  [unfilt_25_4to1 s1] Epoch 20/30 loss=0.0344
  [unfilt_25_4to1 s1] Epoch 30/30 loss=0.0193
  -> Seed 1: 0.7907
  [unfilt_25_4to1 s2] Epoch 10/30 loss=0.1184
  [unfilt_25_4to1 s2] Epoch 20/30 loss=0.0288
  [unfilt_25_4to1 s2] Epoch 30/30 loss=0.0194
  -> Seed 2: 0.8040
  [unfilt_25_4to1 s3] Epoch 10/30 loss=0.1516
  [unfilt_25_4to1 s3] Epoch 20/30 loss=0.0441
  [unfilt_25_4to1 s3] Epoch 30/30 loss=0.0219
  -> Seed 3: 0.7680
  [unfilt_25_4to1 s4] Epoch 10/30 loss=0.1713
  [unfilt_25_4to1 s4] Epoch 20/30 loss=0.0430
  [unfilt_25_4to1 s4] Epoch 30/30 loss=0.0270
  -> Seed 4: 0.7880
  [unfilt_25_4to1 s5] Epoch 10/30 loss=0.1450
  [unfilt_25_4to1 s5] Epoch 20/30 loss=0.0348
  [unfilt_25_4to1 s5] Epoch 30/30 loss=0.0215
  -> Seed 5: 0.7887
  Unfiltered 4:1 @25: 0.7879 +/- 0.0115

=== Filtered 1:1 @25 ===
  [filt_25_1to1 s1] Epoch 10/30 loss=0.2645
  [filt_25_1to1 s1] Epoch 20/30 loss=0.0764
  [filt_25_1to1 s1] Epoch 30/30 loss=0.0560
  -> Seed 1: 0.7833
  [filt_25_1to1 s2] Epoch 10/30 loss=0.3955
  [filt_25_1to1 s2] Epoch 20/30 loss=0.1276
  [filt_25_1to1 s2] Epoch 30/30 loss=0.0952
  -> Seed 2: 0.7833
  [filt_25_1to1 s3] Epoch 10/30 loss=0.4028
  [filt_25_1to1 s3] Epoch 20/30 loss=0.1326
  [filt_25_1to1 s3] Epoch 30/30 loss=0.0857
  -> Seed 3: 0.7660
  [filt_25_1to1 s4] Epoch 10/30 loss=0.4151
  [filt_25_1to1 s4] Epoch 20/30 loss=0.1129
  [filt_25_1to1 s4] Epoch 30/30 loss=0.0644
  -> Seed 4: 0.7847
  [filt_25_1to1 s5] Epoch 10/30 loss=0.4703
  [filt_25_1to1 s5] Epoch 20/30 loss=0.1783
  [filt_25_1to1 s5] Epoch 30/30 loss=0.1156
  -> Seed 5: 0.7767
  Filtered 1:1 @25: 0.7788 +/- 0.0070

=== Filtered 4:1 @25 ===
  [filt_25_4to1 s1] Epoch 10/30 loss=0.0881
  [filt_25_4to1 s1] Epoch 20/30 loss=0.0222
  [filt_25_4to1 s1] Epoch 30/30 loss=0.0128
  -> Seed 1: 0.7980
  [filt_25_4to1 s2] Epoch 10/30 loss=0.0634
  [filt_25_4to1 s2] Epoch 20/30 loss=0.0145
  [filt_25_4to1 s2] Epoch 30/30 loss=0.0109
  -> Seed 2: 0.7933
  [filt_25_4to1 s3] Epoch 10/30 loss=0.1022
  [filt_25_4to1 s3] Epoch 20/30 loss=0.0312
  [filt_25_4to1 s3] Epoch 30/30 loss=0.0177
  -> Seed 3: 0.7793
  [filt_25_4to1 s4] Epoch 10/30 loss=0.0955
  [filt_25_4to1 s4] Epoch 20/30 loss=0.0230
  [filt_25_4to1 s4] Epoch 30/30 loss=0.0133
  -> Seed 4: 0.7973
  [filt_25_4to1 s5] Epoch 10/30 loss=0.1218
  [filt_25_4to1 s5] Epoch 20/30 loss=0.0280
  [filt_25_4to1 s5] Epoch 30/30 loss=0.0171
  -> Seed 5: 0.7827
  Filtered 4:1 @25: 0.7901 +/- 0.0077


All conditions complete.
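The "+/-" figures in the log come from np.std, which defaults to ddof=0 (population standard deviation). With only five seeds, the sample standard deviation (ddof=1) is larger by a factor of sqrt(5/4) ≈ 1.118. The stdlib sketch below recomputes the Baseline 25/class summary from the five per-seed accuracies printed above; those values are rounded to four decimals, so in other conditions the last reported digit could differ slightly.

```python
import math
import statistics

# Per-seed accuracies for 'Baseline 25/class', copied from the log above.
baseline_25 = [0.7140, 0.7647, 0.7113, 0.7127, 0.6860]

mean = statistics.fmean(baseline_25)
pop_sd = statistics.pstdev(baseline_25)    # what np.std(accs) reports (ddof=0)
sample_sd = statistics.stdev(baseline_25)  # equivalent to np.std(accs, ddof=1)

print(f'mean {mean:.4f}, population SD {pop_sd:.4f}, sample SD {sample_sd:.4f}')
print(f'ratio {sample_sd / pop_sd:.4f} vs sqrt(5/4) = {math.sqrt(5 / 4):.4f}')
```

Either convention is fine for a descriptive "+/-", but it is worth stating which one is used, since the ddof=1 value is the one that enters the t statistic.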

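Cell 5.1 — Full factorial table + critic-filter isolation. The table lists all ten conditions and compares each augmented row with its same-regime baseline by a paired t-test over the five seeds. The isolation block reports filtered − unfiltered at fixed real count and ratio: because the pool, count, and ratio are all held constant, a significant difference there would point to the critic filter as the cause. The '*' flags in both blocks mark uncorrected p < 0.05. In this run the four isolation contrasts are all at most 0.55 percentage points; the one flagged contrast (50/class 1:1, p = 0.0335) does not survive Holm–Bonferroni correction, whose first threshold is 0.05/4 = 0.0125. There is therefore no detectable Stage 1 filter effect.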
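The isolation contrasts use stats.ttest_rel, which pairs runs by seed. A stdlib sketch of why that matters: it computes the paired t statistic and an unpaired Welch t statistic on the 50/class 1:1 per-seed accuracies from the Cell C log. The values are rounded to four decimals, so the statistics differ slightly from what scipy reports on the unrounded accuracies; paired_t and welch_t are hypothetical helper names.

```python
import math
import statistics

# Per-seed accuracies copied from the Cell C log (seeds 1..5), 50 real/class, 1:1.
filt   = [0.8687, 0.8960, 0.8573, 0.8767, 0.8813]
unfilt = [0.8680, 0.8860, 0.8540, 0.8720, 0.8727]

def paired_t(a, b):
    """t statistic of the per-seed differences a_i - b_i (df = n - 1)."""
    d = [x - y for x, y in zip(a, b)]
    return statistics.fmean(d) / (statistics.stdev(d) / math.sqrt(len(d)))

def welch_t(a, b):
    """Unpaired Welch t statistic, ignoring which runs share a seed."""
    se = math.sqrt(statistics.variance(a) / len(a) + statistics.variance(b) / len(b))
    return (statistics.fmean(a) - statistics.fmean(b)) / se

print(f'paired t = {paired_t(filt, unfilt):.2f}')  # large: shared seed effects cancel
print(f'Welch t  = {welch_t(filt, unfilt):.2f}')   # small: seed effects stay in the noise
```

Both conditions peak at seed 2 and bottom out at seed 3, so most of the seed-to-seed spread is shared; differencing by seed removes it, and the same mean gain of about 0.55 points looks far stronger paired than unpaired.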

# ============================================================
# CELL 5.1 - Full factorial table & significance tests (rev11)
# ============================================================
from scipy import stats

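def fmt(name, accs, base=None):
    # One table row: mean and SD (np.std, ddof=0) over seeds; if a baseline is
    # given, also the mean difference and a paired t-test p-value (matched by seed).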
    m, s = np.mean(accs), np.std(accs)
    if base is None:
        print(f'{name:<30} {m:>8.4f} {s:>8.4f} {"-":>9} {"-":>9}')
    else:
        t, p = stats.ttest_rel(accs, base)
        d = m - np.mean(base); flag = '*' if p < 0.05 else ' '
        print(f'{name:<30} {m:>8.4f} {s:>8.4f} {d:>+9.4f} {p:>8.4f}{flag}')

print('=' * 70)
print('FULL FACTORIAL (one shared raw pool) - real-count x filter x ratio')
print('=' * 70)
print(f'{"Condition":<30} {"Mean":>8} {"Std":>8} {"Delta":>9} {"p-value":>9}')
print('-' * 70)
print('- 50 real/class -')
fmt('Baseline (50)',  results['baseline_50'])
fmt('Unfiltered 1:1', results['unfilt_50_1to1'], results['baseline_50'])
fmt('Unfiltered 4:1', results['unfilt_50_4to1'], results['baseline_50'])
fmt('Filtered 1:1',   results['filt_50_1to1'],   results['baseline_50'])
fmt('Filtered 4:1',   results['filt_50_4to1'],   results['baseline_50'])
print('- 25 real/class -')
fmt('Baseline (25)',  results['baseline_25'])
fmt('Unfiltered 1:1', results['unfilt_25_1to1'], results['baseline_25'])
fmt('Unfiltered 4:1', results['unfilt_25_4to1'], results['baseline_25'])
fmt('Filtered 1:1',   results['filt_25_1to1'],   results['baseline_25'])
fmt('Filtered 4:1',   results['filt_25_4to1'],   results['baseline_25'])
print('=' * 70)
print('* significant at p < 0.05 (paired t-test vs same-regime baseline)')

print('\n' + '=' * 70)
print('CRITIC-FILTER ISOLATION (filtered - unfiltered; real-count AND ratio fixed)')
print('=' * 70)
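def contrast(name, a, b):
    # Paired (by seed) filtered - unfiltered difference. '*' marks UNCORRECTED
    # p < 0.05; no Holm-Bonferroni adjustment is applied in this cell.
    d = np.array(a) - np.array(b)
    t, p = stats.ttest_rel(a, b)
    flag = '*' if p < 0.05 else ' '
    print(f'{name:<24} {np.mean(d):>+8.4f}  (paired p = {p:.4f}){flag}')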
contrast('50/class 1:1', results['filt_50_1to1'], results['unfilt_50_1to1'])
contrast('50/class 4:1', results['filt_50_4to1'], results['unfilt_50_4to1'])
contrast('25/class 1:1', results['filt_25_1to1'], results['unfilt_25_1to1'])
contrast('25/class 4:1', results['filt_25_4to1'], results['unfilt_25_4to1'])
print('=' * 70)
print('Same pool, same count, same ratio -> these isolate the critic filter.')
======================================================================
FULL FACTORIAL (one shared raw pool) - real-count x filter x ratio
======================================================================
Condition                          Mean      Std     Delta   p-value
----------------------------------------------------------------------
- 50 real/class -
Baseline (50)                    0.8849   0.0078         -         -
Unfiltered 1:1                   0.8705   0.0103   -0.0144   0.0490*
Unfiltered 4:1                   0.8635   0.0065   -0.0215   0.0094*
Filtered 1:1                     0.8760   0.0129   -0.0089   0.2221 
Filtered 4:1                     0.8648   0.0070   -0.0201   0.0010*
- 25 real/class -
Baseline (25)                    0.7177   0.0257         -         -
Unfiltered 1:1                   0.7752   0.0097   +0.0575   0.0047*
Unfiltered 4:1                   0.7879   0.0115   +0.0701   0.0027*
Filtered 1:1                     0.7788   0.0070   +0.0611   0.0071*
Filtered 4:1                     0.7901   0.0077   +0.0724   0.0036*
======================================================================
* significant at p < 0.05 (paired t-test vs same-regime baseline)

======================================================================
CRITIC-FILTER ISOLATION (filtered - unfiltered; real-count AND ratio fixed)
======================================================================
50/class 1:1              +0.0055  (paired p = 0.0335)*
50/class 4:1              +0.0013  (paired p = 0.7800) 
25/class 1:1              +0.0036  (paired p = 0.5493) 
25/class 4:1              +0.0023  (paired p = 0.6362) 
======================================================================
Same pool, same count, same ratio -> these isolate the critic filter.
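The '*' flags in the isolation block are uncorrected; Cell 5.1 does not itself apply the Holm–Bonferroni correction named in the design. The step-down procedure is short enough to sketch directly. The function holm below is a minimal stdlib implementation written for illustration (not a notebook function), applied to the four isolation p-values as printed.

```python
def holm(pvals, alpha=0.05):
    """Holm-Bonferroni step-down test.

    Returns (reject, adjusted) in input order. Sorted ascending, the i-th
    smallest p (0-based) is compared with alpha / (m - i); testing stops at
    the first failure, so every larger p is retained as well.
    """
    m = len(pvals)
    order = sorted(range(m), key=lambda i: pvals[i])
    reject, adjusted = [False] * m, [0.0] * m
    running_max, still_rejecting = 0.0, True
    for rank, i in enumerate(order):
        # Adjusted p-values are kept monotone with a running maximum.
        running_max = max(running_max, min(1.0, (m - rank) * pvals[i]))
        adjusted[i] = running_max
        still_rejecting = still_rejecting and pvals[i] <= alpha / (m - rank)
        reject[i] = still_rejecting
    return reject, adjusted

# Isolation p-values as printed by Cell 5.1 (filtered - unfiltered).
contrasts = ['50/class 1:1', '50/class 4:1', '25/class 1:1', '25/class 4:1']
pvals = [0.0335, 0.7800, 0.5493, 0.6362]
reject, adjusted = holm(pvals)
for name, p, adj, r in zip(contrasts, pvals, adjusted, reject):
    print(f'{name:<14} raw p = {p:.4f}  Holm-adjusted p = {adj:.4f}  reject: {r}')
```

The smallest p (0.0335) would need to clear 0.05/4 = 0.0125, so no isolation contrast survives, matching the conclusion stated above the table. If statsmodels is installed, statsmodels.stats.multitest.multipletests(pvals, method='holm') should give the same decisions.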

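Cell 5.2 — Full factorial plot. Bars show the mean test accuracy over the five seeds, and error bars show ±1 standard deviation (np.std, ddof=0). Blue = baseline, red = unfiltered, green = filtered; the dotted vertical line separates the 50/class block (left) from the 25/class block (right). The y-axis starts 0.05 below the lowest mean rather than at zero, which visually magnifies the differences between bars.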

# ============================================================
# CELL 5.2 - Full factorial plot (rev11)
# ============================================================
names = ['Base\n50','Unf 1:1\n50','Unf 4:1\n50','Filt 1:1\n50','Filt 4:1\n50',
         'Base\n25','Unf 1:1\n25','Unf 4:1\n25','Filt 1:1\n25','Filt 4:1\n25']
keys  = ['baseline_50','unfilt_50_1to1','unfilt_50_4to1','filt_50_1to1','filt_50_4to1',
         'baseline_25','unfilt_25_1to1','unfilt_25_4to1','filt_25_1to1','filt_25_4to1']
data  = [results[k] for k in keys]
means = [np.mean(d) for d in data]; stds = [np.std(d) for d in data]
colors = ['#4C72B0','#C44E52','#C44E52','#55A868','#55A868',
          '#4C72B0','#C44E52','#C44E52','#55A868','#55A868']
fig, ax = plt.subplots(figsize=(13, 6))
bars = ax.bar(names, means, yerr=stds, capsize=4, color=colors, alpha=0.85)
for b, m in zip(bars, means):
    ax.text(b.get_x() + b.get_width()/2, b.get_height() + 0.003, f'{m:.3f}',
            ha='center', va='bottom', fontsize=8)
ax.set_ylabel('Test Accuracy')
ax.set_title('Full factorial (one shared pool): blue=baseline, red=unfiltered, green=filtered')
ax.axvline(4.5, color='gray', ls=':', lw=1, alpha=0.6)
ax.set_ylim(min(means) - 0.05, max(means) + 0.04)
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, 'rev11_full_factorial.png'), dpi=150)
plt.show()
print('Saved rev11 full-factorial plot')
Saved rev11 full-factorial plot
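The key-result claim that augmentation at 25/class recovers about 43% of the accuracy lost by halving the real data can be checked from the table means. A minimal sketch: the recovered fraction is (augmented − baseline_25) / (baseline_50 − baseline_25), computed here from the four-decimal means printed in Cell 5.1, so a difference can be off by one unit in the last digit relative to the table's Delta column.

```python
# Mean test accuracies as printed in the Cell 5.1 table.
baseline_50, baseline_25 = 0.8849, 0.7177
augmented_25 = {
    'Unfiltered 1:1': 0.7752,
    'Unfiltered 4:1': 0.7879,
    'Filtered 1:1':   0.7788,
    'Filtered 4:1':   0.7901,
}

gap = baseline_50 - baseline_25  # accuracy lost going from 50 to 25 real/class

def recovered(acc):
    """Fraction of the 50 -> 25 real/class accuracy gap closed by augmentation."""
    return (acc - baseline_25) / gap

for name, acc in augmented_25.items():
    print(f'{name:<15} +{acc - baseline_25:.4f} ({recovered(acc):.1%} of the {gap:.4f} gap)')
```

Filtered 4:1 closes about 43% of the gap, consistent with the key result; unfiltered 4:1 closes about 42%, another view of how small the filter effect is.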