In [1]:
# CELL 1: SETUP, CONFIG & DATA LOADING
import os
import time
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import rasterio
from rasterio.windows import from_bounds
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
import ee
import geemap
from huggingface_hub import hf_hub_download
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score, confusion_matrix
from google.colab import drive
# 1. Environment & Auth
drive.mount('/content/drive', force_remount=True)
os.system('pip install -q rasterio geopandas timm segmentation-models-pytorch huggingface_hub geedim')
try:
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
except:
ee.Authenticate()
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
# 2. Configuration
BATCH_SIZE = 4
EPOCHS = 50
LR = 1e-4
PATCH_SIZE = 224
S2_SCALE = 5000.0
S1_MIN, S1_MAX = -25.0, 0.0
ASSET_ID = 'projects/[REDACTED_FOR_SECURITY]/assets/Punjab_Mask_2024_NEW'
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Results/'
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)
TIME_WINDOWS = [
('2024-11-01', '2024-11-30'), # Sowing
('2025-02-15', '2025-03-15'), # Peak
('2025-04-01', '2025-04-15') # Harvest
]
# 3. Robust Data Loader
def get_satmae_data():
print("1. Ingesting Asset...")
mask_img = ee.Image(ASSET_ID)
roi_geom = mask_img.geometry()
mask_file = 'local_mask.tif'
if not os.path.exists(mask_file):
geemap.download_ee_image(mask_img, mask_file, region=roi_geom, scale=10, crs='EPSG:4326', overwrite=True)
print("2. Defining Subset (Anti-Crash)...")
with rasterio.open(mask_file) as src:
b = src.bounds
cx, cy = (b.left + b.right)/2, (b.bottom + b.top)/2
offset = 0.06
window = from_bounds(cx-offset, cy-offset, cx+offset, cy+offset, src.transform)
mask = src.read(1, window=window)
mask = np.where(mask > 0, 1.0, 0.0).astype(np.float32)
small_roi = ee.Geometry.Rectangle([cx-offset, cy-offset, cx+offset, cy+offset], proj=str(src.crs), geodesic=False)
target_h, target_w = mask.shape
stack = []
print("3. Stacking Multi-Temporal Bands...")
for i, (start, end) in enumerate(TIME_WINDOWS):
fname = f'time_{i}.tif'
if not os.path.exists(fname):
s2 = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED').filterBounds(small_roi).filterDate(start, end).median().select(['B2','B3','B4','B8','B11','B12'])
s1 = ee.ImageCollection('COPERNICUS/S1_GRD').filterBounds(small_roi).filterDate(start, end).mean().select(['VV','VH'])
fused = ee.Image.cat([s2, s1]).clip(small_roi)
geemap.download_ee_image(fused, fname, region=small_roi, scale=10, crs='EPSG:4326', overwrite=True)
with rasterio.open(fname) as src:
arr = src.read()
arr = np.transpose(arr, (1, 2, 0))
if arr.shape[:2] != (target_h, target_w):
arr = cv2.resize(arr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
s2_n = np.clip(arr[:,:,:6] / S2_SCALE, 0, 1)
s1_n = np.clip((arr[:,:,6:] - S1_MIN) / (S1_MAX - S1_MIN), 0, 1)
stack.append(np.concatenate([s2_n, s1_n], axis=2))
full_cube = np.stack(stack, axis=2)
print("4. Tiling...")
x_out, y_out = [], []
stride = PATCH_SIZE
for y in range(0, target_h, stride):
for x in range(0, target_w, stride):
img_p = full_cube[y:y+stride, x:x+stride]
mask_p = mask[y:y+stride, x:x+stride]
if img_p.shape[0] != PATCH_SIZE or img_p.shape[1] != PATCH_SIZE: continue
if np.min(img_p) < 0: continue
x_out.append(img_p)
y_out.append(mask_p)
if len(x_out) == 0: raise ValueError("No valid patches found.")
# SatMAE Format: (N, C, T, H, W)
X = np.array(x_out, dtype=np.float32).transpose(0, 4, 3, 1, 2)
y = np.array(y_out, dtype=np.float32)[:, None, :, :]
print(f" Data Ready. Shape: {X.shape}")
return torch.tensor(X), torch.tensor(y)
# 4. Prepare Loaders
X_data, y_data = get_satmae_data()
ds = TensorDataset(X_data, y_data)
tr_sz = int(0.85 * len(ds))
t_ds, v_ds = random_split(ds, [tr_sz, len(ds)-tr_sz])
train_loader = DataLoader(t_ds, BATCH_SIZE, shuffle=True)
val_loader = DataLoader(v_ds, BATCH_SIZE, shuffle=False)
# 5. Loss Function
class DiceLoss(nn.Module):
def __init__(self, smooth=1e-6):
super().__init__()
self.smooth = smooth
def forward(self, inputs, targets):
inputs = torch.sigmoid(inputs).view(-1)
targets = targets.view(-1)
inter = (inputs * targets).sum()
dice = (2. * inter + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)
return 1 - dice
Mounted at /content/drive 1. Ingesting Asset...
/usr/local/lib/python3.12/dist-packages/geemap/common.py:12471: FutureWarning: 'BaseImage' is deprecated and will be removed in a future release. Please use the 'ee.Image.gd' accessor instead. img = gd.download.BaseImage(image)
...tmae-2026/assets/Punjab_Mask_2024_NEW: 0%| |0/585 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 /usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'projects/satmae-2026/assets/Punjab_Mask_2024_NEW'. return STACClient().get(self.id)
2. Defining Subset (Anti-Crash)... 3. Stacking Multi-Temporal Bands...
0%| |0/48 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 /usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'None'. return STACClient().get(self.id)
0%| |0/48 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
0%| |0/48 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
4. Tiling... ✅ Data Ready. Shape: (25, 8, 3, 224, 224)
In [2]:
# CELL 2: MODEL DEFINITION
class SatMAEScratch(nn.Module):
def __init__(self, num_frames=3, in_chans=8, embed_dim=768):
super().__init__()
print(" Initializing SatMAE (Scratch)...")
# 1. 2D Patch Embedding (Treats time as batch initially)
self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=16, stride=16)
# 2. Embeddings
self.pos_embed = nn.Parameter(torch.zeros(1, 1, 196+1, embed_dim))
self.time_embed = nn.Parameter(torch.zeros(1, num_frames, 1, embed_dim))
self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))
# 3. Transformer
enc_layer = nn.TransformerEncoderLayer(embed_dim, 12, embed_dim*4, batch_first=True, norm_first=True)
self.blocks = nn.TransformerEncoder(enc_layer, 12)
self.norm = nn.LayerNorm(embed_dim)
# 4. Decoder
self.temp_agg = nn.Conv2d(embed_dim * num_frames, embed_dim, kernel_size=1)
self.decoder = nn.Sequential(
nn.Upsample(scale_factor=2), nn.Conv2d(768, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.GELU(),
nn.Upsample(scale_factor=2), nn.Conv2d(256, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.GELU(),
nn.Upsample(scale_factor=2), nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.GELU(),
nn.Upsample(scale_factor=2), nn.Conv2d(64, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.GELU(),
nn.Conv2d(32, 1, 1)
)
def forward(self, x):
# x: (B, 8, 3, 224, 224)
B, C, T, H, W = x.shape
x = x.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W)
x = self.patch_embed(x).flatten(2).transpose(1, 2).reshape(B, T, -1, 768)
x = x + self.time_embed
x = x.reshape(B, T*196, 768)
pos = self.pos_embed[:, :, 1:, :].expand(B, T, -1, -1).reshape(B, T*196, 768)
x = x + pos
cls = self.cls_token.expand(B, -1, -1, -1).reshape(B, 1, 768) + self.pos_embed[:, :, 0, :].expand(B, 1, 768)
x = torch.cat((cls, x), dim=1)
x = self.norm(self.blocks(x))[:, 1:, :] # Drop CLS
x = x.reshape(B, T, 196, 768).permute(0, 3, 1, 2).reshape(B, 768*3, 14, 14)
return self.decoder(self.temp_agg(x))
if torch.cuda.is_available(): device = torch.device('cuda')
else: device = torch.device('cpu')
model = SatMAEScratch().to(device)
print(" Model Created.")
Initializing SatMAE (Scratch)...
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True warnings.warn(
Model Created.
In [3]:
# CELL 3: TRAINING LOOP
optimizer = optim.AdamW(model.parameters(), lr=LR)
criterion = DiceLoss()
t_hist, v_hist = [], []
print(" Starting Scratch Training...")
for ep in range(EPOCHS):
model.train()
el = 0
for x, y in train_loader:
optimizer.zero_grad()
loss = criterion(model(x.to(device)), y.to(device))
loss.backward()
optimizer.step()
el += loss.item()
t_hist.append(el/len(train_loader))
model.eval()
vl = 0
with torch.no_grad():
for x, y in val_loader:
vl += criterion(model(x.to(device)), y.to(device)).item()
v_hist.append(vl/len(val_loader))
if (ep+1)%5==0: print(f"Ep {ep+1} | Train: {t_hist[-1]:.4f} | Val: {v_hist[-1]:.4f}")
MODEL_NAME = "SatMAE_Scratch"
Starting Scratch Training... Ep 5 | Train: 0.2466 | Val: 0.2127 Ep 10 | Train: 0.2279 | Val: 0.1786 Ep 15 | Train: 0.2081 | Val: 0.1800 Ep 20 | Train: 0.2062 | Val: 0.1668 Ep 25 | Train: 0.2001 | Val: 0.1595 Ep 30 | Train: 0.1884 | Val: 0.1624 Ep 35 | Train: 0.1788 | Val: 0.1650 Ep 40 | Train: 0.1782 | Val: 0.1654 Ep 45 | Train: 0.1712 | Val: 0.1496 Ep 50 | Train: 0.1726 | Val: 0.1472
In [4]:
# CELL 4: VISUALIZATION
print(f"Visualizing {MODEL_NAME}...")
plt.figure(figsize=(10, 4))
plt.plot(t_hist, label='Train')
plt.plot(v_hist, label='Val')
plt.title(f"{MODEL_NAME} Loss Curve")
plt.legend()
plt.show()
model.eval()
x_batch, y_batch = next(iter(val_loader))
with torch.no_grad():
preds = (torch.sigmoid(model(x_batch.to(device))) > 0.5).float().cpu()
fig, ax = plt.subplots(3, 3, figsize=(12, 12))
cols = ["Input (Peak)", "Ground Truth", "Prediction"]
for a, col in zip(ax[0], cols): a.set_title(col)
for i in range(3):
rgb = x_batch[i, [3,2,1], 1, :, :].permute(1, 2, 0).numpy()
rgb = np.clip(rgb * 3.5, 0, 1)
ax[i, 0].imshow(rgb)
ax[i, 1].imshow(y_batch[i, 0], cmap='gray')
ax[i, 2].imshow(preds[i, 0], cmap='gray')
iou = (preds[i]*y_batch[i]).sum() / (preds[i].sum() + y_batch[i].sum() - (preds[i]*y_batch[i]).sum() + 1e-6)
ax[i, 2].set_xlabel(f"IoU: {iou:.2f}")
plt.show()
Visualizing SatMAE_Scratch...
In [5]:
# CELL 5: METRICS & SAVE
def evaluate_and_save():
print(f"Evaluating {MODEL_NAME}...")
model.eval()
all_preds, all_targets = [], []
start = time.time()
with torch.no_grad():
for x, y in val_loader:
probs = torch.sigmoid(model(x.to(device)))
all_preds.extend((probs > 0.5).float().cpu().numpy().flatten())
all_targets.extend(y.cpu().numpy().flatten())
fps = len(ds) / (time.time() - start + 1e-6)
y_p = np.array(all_preds).astype(int)
y_t = np.array(all_targets).astype(int)
metrics = {
"IoU": round(jaccard_score(y_t, y_p, average='binary'), 4),
"F1": round(f1_score(y_t, y_p, average='binary'), 4),
"Pixel_Acc": round(accuracy_score(y_t, y_p), 4),
"FPS": round(fps, 2)
}
print(f" Results: {metrics}")
# Save Files
pd.DataFrame({'Epoch':range(1,len(t_hist)+1), 'Train':t_hist, 'Val':v_hist}).to_csv(f"{SAVE_DIR}{MODEL_NAME}_History.csv", index=False)
torch.save(model.state_dict(), f"{SAVE_DIR}{MODEL_NAME}_Weights.pth")
with open(f"{SAVE_DIR}{MODEL_NAME}_Report.json", 'w') as f: json.dump(metrics, f)
print(" Experiment Saved.")
evaluate_and_save()
Evaluating SatMAE_Scratch...
Results: {'IoU': np.float64(0.8638), 'F1': 0.9269, 'Pixel_Acc': 0.8929, 'FPS': 90.38}
Experiment Saved.
In [7]:
# ==========================================
# CELL: SAVE TO DRIVE UTILITY
# ==========================================
import os
import torch
import pandas as pd
import json
from google.colab import drive
# 1. Mount Drive
drive.mount('/content/drive', force_remount=True)
# 2. Configuration
# Change 'SatMAE_Scratch' to your current model name
MODEL_NAME = "SatMAE_Scratch"
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch/'
# Create directory if it doesn't exist
if not os.path.exists(SAVE_DIR):
os.makedirs(SAVE_DIR)
print(f" Created directory: {SAVE_DIR}")
def save_to_drive(model, t_hist, v_hist, metrics=None):
print(f" Saving {MODEL_NAME} to Google Drive...")
# A. Save Model Weights (.pth)
weights_path = f"{SAVE_DIR}{MODEL_NAME}_Weights.pth"
torch.save(model.state_dict(), weights_path)
print(f" Weights saved: {weights_path}")
# B. Save Training History (.csv)
history_path = f"{SAVE_DIR}{MODEL_NAME}_History.csv"
df = pd.DataFrame({'Epoch': range(1, len(t_hist) + 1), 'Train_Loss': t_hist, 'Val_Loss': v_hist})
df.to_csv(history_path, index=False)
print(f" History saved: {history_path}")
# C. Save Metrics (.json) [Optional]
if metrics:
metrics_path = f"{SAVE_DIR}{MODEL_NAME}_Report.json"
with open(metrics_path, 'w') as f:
json.dump(metrics, f, indent=4)
print(f" Metrics saved: {metrics_path}")
print(" Save Complete.")
# --- EXECUTE SAVE ---
# Ensure these variables exist from your training loop
if 'model' in globals() and 't_hist' in globals() and 'v_hist' in globals():
# If you have metrics from the evaluation cell, pass them here
# current_metrics = {'IoU': 0.75, 'F1': 0.80} # Example
current_metrics = metrics if 'metrics' in globals() else None
save_to_drive(model, t_hist, v_hist, current_metrics)
else:
print(" variables (model, t_hist, v_hist) not found. Run training first.")
Mounted at /content/drive
Created directory: /content/drive/MyDrive/SatMAE_Scratch/
Saving SatMAE_Scratch to Google Drive...
Weights saved: /content/drive/MyDrive/SatMAE_Scratch/SatMAE_Scratch_Weights.pth
History saved: /content/drive/MyDrive/SatMAE_Scratch/SatMAE_Scratch_History.csv
Save Complete.