import numpy as np
import pandas as pd
import torch
import matplotlib
import matplotlib.pyplot as plt
[docs]
def colorbar_list(AIA_wavelengths, HMI_components):
    """Return the colormap name for every requested AIA/HMI channel.

    Args:
        AIA_wavelengths: Iterable of AIA channel labels such as ``"193A"``.
        HMI_components: Iterable of HMI component labels (``"Bx"``, ``"By"``
            or ``"Bz"``).

    Returns:
        A list of colormap names: one entry per AIA wavelength in the order
        given, followed by one entry per HMI component in the order given.

    Raises:
        KeyError: If a label is not one of the channels listed below.
    """
    # Every HMI component shares the same magnetogram colormap.
    colormap_for = {
        "94A": "sdoaia94",
        "131A": "sdoaia131",
        "171A": "sdoaia171",
        "193A": "sdoaia193",
        "211A": "sdoaia211",
        "304A": "sdoaia304",
        "335A": "sdoaia335",
        "1600A": "sdoaia1600",
        "1700A": "sdoaia1700",
        "4500A": "sdoaia4500",
        "Bx": "hmimag",
        "By": "hmimag",
        "Bz": "hmimag",
    }
    aia_maps = [colormap_for[channel] for channel in AIA_wavelengths]
    hmi_maps = [colormap_for[channel] for channel in HMI_components]
    return aia_maps + hmi_maps
[docs]
def wsa_to_image(lat_sh, lon_sh, r_sun=200, ctr_pxl=256):
    """Convert a WSA footpoint position to SDOML image pixel coordinates.

    The pixel position is computed as
    ``x = r_sun * cos(lat) * sin(lon) + ctr_pxl`` and
    ``y = r_sun * sin(lat) + ctr_pxl``, rounded to the nearest integer.

    Args:
        lat_sh: Stonyhurst latitude in degrees (scalar or NumPy array).
        lon_sh: Stonyhurst longitude in degrees (scalar or NumPy array).
        r_sun: Radius of the Sun on the image, in pixels.
        ctr_pxl: Pixel value of the centre of each image axis.

    Returns:
        Tuple ``(x_pxl, y_pxl)`` of integer pixel coordinates.
    """
    deg_to_rad = np.pi / 180.0
    lat_rad = deg_to_rad * lat_sh
    lon_rad = deg_to_rad * lon_sh

    horizontal = r_sun * np.cos(lat_rad) * np.sin(lon_rad) + ctr_pxl
    vertical = r_sun * np.sin(lat_rad) + ctr_pxl

    return np.round(horizontal).astype(int), np.round(vertical).astype(int)
[docs]
def _as_selection(value):
    """Return a row selector usable with ``.loc``: wrap scalars in a list."""
    # Strings are iterable but represent a single label (e.g. one timestamp).
    if isinstance(value, (str, bytes)):
        return [value]
    try:
        iter(value)
    except TypeError:
        return [value]
    # ``.loc`` treats a tuple as a multi-axis key, so hand it a list instead.
    return list(value) if isinstance(value, tuple) else value


def plot_sdoml(
    datamodule,
    condition=None,
    dataset_idx=None,
    times=None,
    n_samples=None,
    wavelengths=[
        "131A",
        "1600A",
        "1700A",
        "171A",
        "193A",
        "211A",
        "304A",
        "335A",
        "94A",
    ],
    components=["Bx", "By", "Bz"],
    wsa_footpoint=False,
    title=None,
):
    """Plot AIA and HMI images for a selection of aligned samples.

    Each selected sample becomes one column of the figure; the rows are the
    AIA channels of ``datamodule.wavelengths`` followed by the HMI components
    of ``datamodule.components``. Samples whose ``lon_footpoint`` has an
    absolute value of 85 or more are always dropped.

    Args:
        datamodule: Object exposing ``aligndata`` (a DataFrame with the
            columns ``year``, ``lon_footpoint``, ``lat_footpoint`` and
            ``idx_<channel>``), ``aia_data``, ``hmi_data``, ``wavelengths``
            and ``components``.
        condition: Boolean mask over the rows of
            ``datamodule.aligndata.reset_index()``.
        dataset_idx: Row number, or iterable of row numbers, into
            ``datamodule.aligndata.reset_index()``.
        times: Index label, or iterable of index labels, of
            ``datamodule.aligndata``.
        n_samples: If given, randomly keep at most this many samples.
        wavelengths: Labels drawn on the AIA rows; also sets the number of
            AIA rows in the grid and the figure height.
        components: Labels drawn on the HMI rows; also sets the number of
            HMI rows in the grid.
        wsa_footpoint: If true, mark the footpoint position on every image.
        title: Optional figure title.

    Returns:
        Tuple ``(fig, axes)``. ``axes`` is a 2-D array of shape
        ``(rows, samples)``, or a single ``Axes`` carrying a "No sample"
        message when the selection is empty.

    Raises:
        TypeError: Unless exactly one of ``condition``, ``dataset_idx`` and
            ``times`` is given.
    """
    selectors_given = sum(
        selector is not None for selector in (condition, dataset_idx, times)
    )
    if selectors_given != 1:
        raise TypeError("Must specify only one of condition, dataset_idx, times.")

    if times is not None:
        # Copy so that adding a column never writes into a view of aligndata.
        subset = datamodule.aligndata.loc[_as_selection(times), :].copy()
        subset["time_sdo_loc_est"] = subset.index
        subset = subset.reset_index(drop=True)
        subset = subset[subset["lon_footpoint"].abs() < 85]
    else:
        indexed_data = datamodule.aligndata.reset_index()
        if dataset_idx is not None:
            # After reset_index the labels are 0..n-1, so .loc selects rows by
            # number; plain [] indexing would have selected columns instead.
            subset = indexed_data.loc[_as_selection(dataset_idx)]
            subset = subset[subset["lon_footpoint"].abs() < 85]
        else:
            subset = indexed_data[
                condition & (indexed_data["lon_footpoint"].abs() < 85)
            ]

    if n_samples is not None and len(subset) > 0:
        # Never request more rows than exist; DataFrame.sample would raise.
        subset = subset.sample(min(n_samples, len(subset)))

    if len(subset) == 0:
        # Nothing to draw: return a placeholder for any selector, with a
        # non-zero figure width.
        fig, axes = plt.subplots(1, 1, figsize=(2, 2 * len(wavelengths)))
        axes.axis("off")
        axes.text(
            0.5,
            0.5,
            "No sample",
            transform=axes.transAxes,
            ha="center",
            va="center",
        )
        if title is not None:
            fig.suptitle(title)
        return (fig, axes)

    # squeeze=False keeps ``axes`` two-dimensional even for a single sample.
    fig, axes = plt.subplots(
        len(wavelengths) + len(components),
        len(subset),
        figsize=(2 * len(subset), 2 * len(wavelengths)),
        squeeze=False,
    )
    if title is not None:
        fig.suptitle(title)

    colormaps = colorbar_list(datamodule.wavelengths, datamodule.components)
    n_aia = len(datamodule.wavelengths)
    # NOTE(review): images come from datamodule.wavelengths/components while
    # the grid size and row labels come from the ``wavelengths``/``components``
    # arguments; the two are assumed to match in length and order.
    for idx, i in enumerate(subset.index):
        year_key = subset.loc[i, "year"].astype(int).astype(str)
        imgs = [
            datamodule.aia_data[year_key][wavelength][
                subset.loc[i, "idx_" + wavelength].astype(int)
            ]
            for wavelength in datamodule.wavelengths
        ]
        imgs.extend(
            datamodule.hmi_data[year_key][component][
                subset.loc[i, "idx_" + component].astype(int)
            ]
            for component in datamodule.components
        )

        time = subset.loc[i, "time_sdo_loc_est"]
        if wsa_footpoint:
            wsa_lon = subset.loc[i, "lon_footpoint"]
            wsa_lat = subset.loc[i, "lat_footpoint"]
            x_foot, y_foot = wsa_to_image(wsa_lat, wsa_lon)

        plot_title = time.strftime("%Y-%m-%d\n%H:%M:%S")
        axes[0, idx].set_title(f"{plot_title}")

        for jdx, img in enumerate(imgs):  # Plot the images down the column
            ax = axes[jdx, idx]
            if jdx < n_aia:  # AIA data: logarithmic intensity scale
                ax.imshow(img, cmap=matplotlib.colormaps[colormaps[jdx]], norm="log")
                channel_label = wavelengths[jdx]
            else:  # HMI data: colour scale symmetric about zero
                dynamic_range = np.max(np.abs(img))
                ax.imshow(
                    img,
                    cmap=matplotlib.colormaps[colormaps[jdx]],
                    vmin=-1 * dynamic_range,
                    vmax=dynamic_range,
                )
                channel_label = components[jdx - len(wavelengths)]
            # Label position is fixed in pixels (lower-left of a 512-px image,
            # presumably — TODO confirm image size).
            ax.text(5, 482, channel_label, color="white", size="x-small")
            if wsa_footpoint:
                ax.plot(
                    x_foot, y_foot, marker="x", color="k", markersize=5, alpha=0.3
                )
            ax.axis("off")
    return (fig, axes)
[docs]
def find_images_labels(
    imgs, timestamps, targets, preds, position, class_id, ch_id, max_samples=3
):
    """Pick example samples of one class, split by prediction correctness.

    Args:
        imgs: Image tensor indexed as ``imgs[:, ch_id, 0, :, :]`` (so it is
            five-dimensional with samples on the first axis).
        timestamps: Per-sample timestamps tensor.
        targets: Per-sample ground-truth class tensor.
        preds: Per-sample predicted class tensor.
        position: Per-sample position tensor.
        class_id: Ground-truth class to filter for.
        ch_id: Index on the second axis of ``imgs`` to extract.
        max_samples: Maximum number of samples to return per group
            (default: 3). Larger groups are subsampled at random.

    Returns:
        Tuple ``(correct, incorrect)`` of dictionaries with the keys
        ``"imgs"``, ``"timestamps"``, ``"targets"``, ``"preds"`` and
        ``"position"``. Values are NumPy arrays restricted to the chosen
        samples; every value is an empty array when the group has no samples.
    """
    belongs_to_class = targets == class_id
    hit_mask = torch.logical_and(targets == preds, belongs_to_class)
    miss_mask = torch.logical_and(targets != preds, belongs_to_class)
    hit_candidates = torch.nonzero(hit_mask, as_tuple=True)[0]
    miss_candidates = torch.nonzero(miss_mask, as_tuple=True)[0]

    def _subsample(candidates):
        """Keep all candidates, or a random ``max_samples`` of them."""
        total = len(candidates)
        if total == 0:
            return torch.tensor([], dtype=torch.long, device=candidates.device)
        if total > max_samples:
            keep = torch.randperm(total)[:max_samples]
            return candidates[keep]
        return candidates

    # Correct group is sampled first, then the incorrect group.
    picked_hits = _subsample(hit_candidates)
    picked_misses = _subsample(miss_candidates)

    fields = {
        "imgs": imgs[:, ch_id, 0, :, :],
        "timestamps": timestamps,
        "targets": targets,
        "preds": preds,
        "position": position,
    }

    def _gather(picked):
        """Index every field with ``picked`` and convert to NumPy."""
        if len(picked) == 0:
            return {name: np.array([]) for name in fields}
        return {
            name: data[picked].detach().cpu().numpy()
            for name, data in fields.items()
        }

    return _gather(picked_hits), _gather(picked_misses)
[docs]
def find_images_labels_embed(
    imgs, timestamps, targets, preds, position, class_id, max_samples=3
):
    """Pick example samples of one class, split by prediction correctness.

    Unlike ``find_images_labels`` this variant indexes ``imgs`` only along
    its first axis, without selecting a channel.

    Args:
        imgs: Per-sample tensor, indexed along its first axis.
        timestamps: Per-sample timestamps tensor.
        targets: Per-sample ground-truth class tensor.
        preds: Per-sample predicted class tensor.
        position: Per-sample position tensor.
        class_id: Ground-truth class to filter for.
        max_samples: Maximum number of samples to return per group
            (default: 3). Larger groups are subsampled at random.

    Returns:
        Tuple ``(correct, incorrect)`` of dictionaries with the keys
        ``"imgs"``, ``"timestamps"``, ``"targets"``, ``"preds"`` and
        ``"position"``. Values are NumPy arrays restricted to the chosen
        samples; every value is an empty array when the group has no samples.
    """
    belongs_to_class = targets == class_id
    hit_mask = torch.logical_and(targets == preds, belongs_to_class)
    miss_mask = torch.logical_and(targets != preds, belongs_to_class)
    hit_candidates = torch.nonzero(hit_mask, as_tuple=True)[0]
    miss_candidates = torch.nonzero(miss_mask, as_tuple=True)[0]

    def _subsample(candidates):
        """Keep all candidates, or a random ``max_samples`` of them."""
        total = len(candidates)
        if total == 0:
            return torch.tensor([], dtype=torch.long, device=candidates.device)
        if total > max_samples:
            keep = torch.randperm(total)[:max_samples]
            return candidates[keep]
        return candidates

    # Correct group is sampled first, then the incorrect group.
    picked_hits = _subsample(hit_candidates)
    picked_misses = _subsample(miss_candidates)

    fields = {
        "imgs": imgs,
        "timestamps": timestamps,
        "targets": targets,
        "preds": preds,
        "position": position,
    }

    def _gather(picked):
        """Index every field with ``picked`` and convert to NumPy."""
        if len(picked) == 0:
            return {name: np.array([]) for name in fields}
        return {
            name: data[picked].detach().cpu().numpy()
            for name, data in fields.items()
        }

    return _gather(picked_hits), _gather(picked_misses)
[docs]
def plot_images_grid(correct_data, plt_style, step=None):
    """Draw a 4x3 grid of example images, one row per class.

    Every panel shows an image with the footpoint and PSP positions marked
    and the timestamp, target and prediction in the title. Panels without a
    sample show a "No sample" message instead.

    Args:
        correct_data: Mapping with the keys ``"imgs"``, ``"targets"``,
            ``"preds"``, ``"position"`` and ``"timestamps"``; each value is
            indexed by class id (0 to 3). Each position entry is read as
            ``[lon_psp, lat_psp, lat_footpoint, lon_footpoint]``.
        plt_style: Matplotlib style passed to ``plt.style.use`` (this changes
            the global style).
        step: Unused.

    Returns:
        The created matplotlib figure.
    """
    plt.style.use(plt_style)
    aia_cmap = matplotlib.colormaps["sdoaia193"]
    fig, axes = plt.subplots(4, 3, figsize=(8, 12))

    def _mark_empty(panel, class_label):
        """Hide the panel frame and write the placeholder message."""
        panel.axis("off")
        # Text is placed in axes-relative coordinates (0-1 scale).
        panel.text(
            0.5,
            0.5,
            f"Class: {class_label}\nNo sample",
            transform=panel.transAxes,
            ha="center",
            va="center",
        )

    for class_id in range(4):
        panel_row = axes[class_id]
        images = correct_data["imgs"][class_id]
        if len(images) == 0:
            for panel in panel_row:
                _mark_empty(panel, class_id)
            continue

        truths = correct_data["targets"][class_id]
        predictions = correct_data["preds"][class_id]
        locations = correct_data["position"][class_id]
        raw_stamps = correct_data["timestamps"][class_id]
        stamp_strings = (
            pd.to_datetime(raw_stamps).strftime("%Y-%m-%d %H:%M:%S").to_list()
        )

        # A 2-D array is treated as a single image.
        available = images.shape[0] if images.ndim > 2 else 1

        for column in range(3):
            panel = panel_row[column]
            if column >= available:
                _mark_empty(panel, class_id)
                continue

            # With a single available image, ``column`` can only be 0 here.
            picture = images[column]
            truth = truths[column]
            prediction = predictions[column]
            location = locations[column]
            stamp = stamp_strings[column]

            psp_lon, psp_lat = location[0], location[1]
            foot_lat, foot_lon = location[2], location[3]
            foot_x, foot_y = wsa_to_image(foot_lat, foot_lon)
            psp_x, psp_y = wsa_to_image(psp_lat, psp_lon)

            panel.imshow(picture, cmap=aia_cmap)
            panel.plot(
                foot_x,
                foot_y,
                marker="x",
                color="white",
                markersize=5,
                label="footpoint",
            )
            panel.plot(
                psp_x, psp_y, marker="x", color="grey", markersize=5, label="PSP"
            )
            panel.legend(
                loc="upper center", ncols=2, fontsize="small", labelcolor="white"
            )
            panel.set_title(f"{stamp}\nTarget: {truth}, Pred: {prediction}")
            panel.axis("off")
    return fig
[docs]
def plot_ecliptic(
    img,
    model,
    time,
    x_range=np.linspace(-213, 213, 30),
    y_range=np.linspace(-213, 213, 30),
    z_range=np.linspace(0, 0, 1),
    r_mean=np.float32(8.918359e07),
    r_std=np.float32(3.0130634e07),
    inner_mask=10,
    outer_mask=220,
):
    """Plot the model's predicted class on a grid of spacecraft positions.

    The same image is paired with every point of the ``x``/``y``/``z`` grid,
    the model is evaluated once on the whole batch, and the arg-max class of
    the first ``z`` slice is drawn as an image.

    Args:
        img: CPU tensor whose first axis holds channels; the first 9 channels
            are used. It is indexed as a four-dimensional tensor.
        model: Callable taking ``(images, positions, radial_distance)`` and
            returning per-class scores with classes on dimension 1.
        time: Timestamp used in the title; anything ``pd.to_datetime`` accepts.
        x_range: Grid coordinates along x (axis label: Stonyhurst X in R_S).
        y_range: Grid coordinates along y (axis label: Stonyhurst Y in R_S).
        z_range: Grid coordinates along z; only the first slice is plotted.
        r_mean: Mean subtracted from the radial distance before it is passed
            to the model.
        r_std: Standard deviation dividing the radial distance.
        inner_mask: Radius of the grey disc drawn at the origin.
        outer_mask: Grid points at or beyond this cylindrical radius are blanked.

    Returns:
        The created matplotlib figure.
    """
    rad_to_deg = 360.0 / (2 * np.pi)
    x_grid, y_grid, z_grid = np.meshgrid(x_range, y_range, z_range)
    rho_grid = np.sqrt(x_grid**2 + y_grid**2)

    psp_lat_grid = np.arctan2(z_grid, rho_grid) * rad_to_deg
    psp_lon_grid = np.arctan2(y_grid, x_grid) * rad_to_deg
    # PSP radial position (normalised below before being passed to the model)
    psp_r_grid = np.sqrt(x_grid**2 + y_grid**2 + z_grid**2)
    # The footpoint is placed at the same latitude/longitude as the spacecraft.
    wsa_lat_grid = psp_lat_grid
    wsa_lon_grid = psp_lon_grid

    # One row per grid point, columns in order: PSP lon, PSP lat, WSA lat,
    # WSA lon. Stacking along axis 1 keeps each quantity in its own column;
    # reshaping a (4, N) array to (N, 4) would interleave the values instead.
    pos_arr = np.stack(
        [
            psp_lon_grid.ravel(),
            psp_lat_grid.ravel(),
            wsa_lat_grid.ravel(),
            wsa_lon_grid.ravel(),
        ],
        axis=1,
    )
    n_points = pos_arr.shape[0]

    # Same image for each grid point; the model only uses the first 9
    # channels. Allocated directly in float32 to avoid a float64 intermediate.
    imgs_arr = np.zeros(
        (n_points, 9, img.shape[1], img.shape[2], img.shape[3]), dtype=np.float32
    )
    imgs_arr[:] = img[:9, :, :, :].numpy()

    pos_disk_grid = torch.from_numpy(pos_arr).to(torch.float32)
    # The factor 695_700 presumably converts solar radii to km so that the
    # value matches the units of r_mean/r_std — TODO confirm.
    r_disk_grid = torch.from_numpy(
        (psp_r_grid.ravel() * 695_700 - r_mean) / r_std
    ).to(torch.float32)
    imgs_disk_grid = torch.from_numpy(imgs_arr)

    # Run on the device the model lives on; fall back to the CUDA-if-available
    # choice when the model does not expose parameters.
    try:
        dev = next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Make the predictions in the ecliptic
    with torch.no_grad():
        y_hat_disk_grid = model(
            imgs_disk_grid.to(dev), pos_disk_grid.to(dev), r_disk_grid.to(dev)
        )
    preds_disk_grid = torch.argmax(y_hat_disk_grid, dim=1)
    preds_disk_grid_arr = preds_disk_grid.detach().cpu().numpy()

    # Plot predictions across the ecliptic
    fig, axes = plt.subplots(1, 1, figsize=(8, 8))
    timestr = pd.to_datetime(time).strftime("%Y-%m-%d %H:%M:%S")
    # Predictions are in ravel order of the grid, so reshaping restores it.
    preds_disk_grid_arr_masked = np.where(
        rho_grid < outer_mask,
        preds_disk_grid_arr.reshape(x_grid.shape),
        np.nan,
    )

    # Plot classes as image in ecliptic
    labels = ["Ejecta", "Coronal Hole", "Sector Reversal", "Streamer Belt"]
    class_image = axes.imshow(
        preds_disk_grid_arr_masked[:, :, 0],
        cmap="viridis",
        vmin=0,
        vmax=3,
        origin="lower",
        extent=(x_range[0], x_range[-1], y_range[0], y_range[-1]),
        interpolation="gaussian",
    )
    cbar = fig.colorbar(class_image, ax=axes, label="SW Classes")
    cbar.set_ticks(ticks=[0, 1, 2, 3], labels=labels)

    # Add solar system objects/masks
    # NOTE(review): longitude is computed as arctan2(y, x), so longitude 0 lies
    # on the +x axis, yet Earth is drawn on the +y axis — confirm the intended
    # orientation.
    axes.scatter([0], [213], marker="o", color="k", label="Earth")
    sun_mask = plt.Circle((0, 0), inner_mask, color="gray")
    axes.add_patch(sun_mask)

    # Adjust the axes
    axes.set_xlim(-230, 230)
    axes.set_ylim(-230, 230)
    axes.set_aspect("equal")
    axes.set_title(f"{timestr}\nSolar Wind Types in Ecliptic")
    axes.set_xlabel(r"Stonyhurst X ($R_S$)")
    axes.set_ylabel(r"Stonyhurst Y ($R_S$)")
    return fig
[docs]
def plot_disk_distribution(
    img,
    model,
    time,
    wsa_lat=np.linspace(-75, 75, 30),
    wsa_lon=np.linspace(-90, 90, 30),
    labels=["Ejecta", "Coronal Hole", "Sector Reversal", "Streamer Belt"],
    colors=["#191923", "#0E79B2", "#bf1363", "#F39237"],
):
    """Plot the predicted class for a grid of footpoints on the solar disk.

    The same image is paired with every footpoint of the latitude/longitude
    grid, the model is evaluated once on the whole batch, and each footpoint
    is drawn on top of channel 4 of the image, coloured by its arg-max class.

    Args:
        img: CPU tensor whose first axis holds channels; the first 9 channels
            are used. It is indexed as a four-dimensional tensor.
        model: Callable taking ``(images, positions, radial_distance)`` and
            returning per-class scores with classes on dimension 1.
        time: Timestamp used in the title; anything ``pd.to_datetime`` accepts.
        wsa_lat: Footpoint latitudes of the grid, in degrees.
        wsa_lon: Footpoint longitudes of the grid, in degrees.
        labels: Legend label per class id.
        colors: Marker colour per class id.

    Returns:
        The created matplotlib figure.
    """
    wsa_lat = np.asarray(wsa_lat, dtype=float)
    wsa_lon = np.asarray(wsa_lon, dtype=float)

    psp_lat = 0 * wsa_lat  # PSP latitude is held at zero
    # PSP longitude trails the footpoint longitude by 5 degrees, wrapped into
    # [-180, 180).
    psp_lon = (wsa_lon + 180 - 5) % 360 - 180
    psp_r = 0  # Scaled value: 0 = mean, 1 = one standard deviation above it

    wsa_lat_grid, wsa_lon_grid = np.meshgrid(wsa_lat, wsa_lon)
    psp_lat_grid, psp_lon_grid = np.meshgrid(psp_lat, psp_lon)

    # One row per grid point, columns in order: PSP lon, PSP lat, WSA lat,
    # WSA lon. Stacking along axis 1 keeps each quantity in its own column;
    # reshaping a (4, N) array to (N, 4) would interleave the values instead.
    pos_arr = np.stack(
        [
            psp_lon_grid.ravel(),
            psp_lat_grid.ravel(),
            wsa_lat_grid.ravel(),
            wsa_lon_grid.ravel(),
        ],
        axis=1,
    )
    n_points = pos_arr.shape[0]

    # Same image for each grid point; the model only uses the first 9
    # channels. Allocated directly in float32 to avoid a float64 intermediate.
    imgs_arr = np.zeros(
        (n_points, 9, img.shape[1], img.shape[2], img.shape[3]), dtype=np.float32
    )
    imgs_arr[:] = img[:9, :, :, :].numpy()

    pos_disk_grid = torch.from_numpy(pos_arr).to(torch.float32)
    r_disk_grid = torch.from_numpy(psp_r * np.ones(n_points)).to(torch.float32)
    imgs_disk_grid = torch.from_numpy(imgs_arr)

    # Run on the device the model lives on; fall back to the CUDA-if-available
    # choice when the model does not expose parameters.
    try:
        dev = next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Make the predictions on the solar disk
    with torch.no_grad():
        y_hat_disk_grid = model(
            imgs_disk_grid.to(dev), pos_disk_grid.to(dev), r_disk_grid.to(dev)
        )
    preds_disk_grid = torch.argmax(y_hat_disk_grid, dim=1)
    preds_disk_grid_arr = preds_disk_grid.detach().cpu().numpy()

    sdoaia193 = matplotlib.colormaps["sdoaia193"]
    fig, axes = plt.subplots(1, 1, figsize=(8, 8))
    timestr = pd.to_datetime(time).strftime("%Y-%m-%d %H:%M:%S")
    x_foot, y_foot = wsa_to_image(wsa_lat_grid.ravel(), wsa_lon_grid.ravel())

    # Plot the underlying solar disk (channel 4 is drawn with the 193 A
    # colormap — presumably the 193 A channel, TODO confirm channel order).
    axes.imshow(imgs_arr[0, 4, 0, :, :], cmap=sdoaia193)

    # Plot each class in turn; the class id is the position in labels/colors.
    for class_id, (label, color) in enumerate(zip(labels, colors)):
        in_class = preds_disk_grid_arr == class_id
        axes.scatter(x_foot[in_class], y_foot[in_class], s=5, color=color, label=label)

    axes.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    axes.set_title(f"{timestr}\nSolar Disk Class Origin Distribution")
    axes.axis("off")
    return fig
[docs]
def plot_tsne(
    embs,
    position,
    r_distance,
    labels,
    perplexities=[5, 30, 50, 100, 300],
    classes=["Ejecta", "Coronal Hole", "Sector Reversal", "Streamer Belt"],
    colors=["#191923", "#0E79B2", "#bf1363", "#F39237"],
):
    """Plot 2-D t-SNE projections of the embeddings at several perplexities.

    The embeddings, positions and radial distances are concatenated along the
    last axis into one feature vector per sample before being projected.

    Args:
        embs: Tensor of embeddings with samples on the first axis.
        position: Position tensor; expanded along the first axis when its
            size there differs from the batch size.
        r_distance: Radial-distance tensor, one value (or row) per sample.
        labels: Tensor of integer class ids, one per sample.
        perplexities: t-SNE perplexity for each panel.
        classes: Display name per class id.
        colors: Marker colour per class id.

    Returns:
        The created matplotlib figure, with one panel per perplexity.
    """
    from sklearn import manifold
    from matplotlib.ticker import NullFormatter

    # Get everything into one tensor
    batch_size = embs.size(0)
    r_tensor = r_distance.to(dtype=embs.dtype, device=embs.device)
    position = position.to(embs.device)

    # Ensure position encoding matches batch size
    if position.size(0) != batch_size:
        position = position.expand(batch_size, -1)

    # reshape covers both the 1-D and the already 2-D radial-distance case.
    combined = torch.cat(
        [embs, position, r_tensor.reshape(batch_size, -1)], dim=-1
    )

    # Convert to NumPy for sklearn; detach/cpu so that CUDA or grad-tracking
    # tensors are accepted as well.
    X = combined.detach().cpu().numpy()
    y = labels.detach().cpu().numpy()

    # Generate one t-SNE projection per perplexity and keep each result;
    # the plotting loop below reads them back by position.
    projections = [
        manifold.TSNE(
            n_components=2,
            init="random",
            random_state=0,
            perplexity=perplexity,
            max_iter=300,
        ).fit_transform(X)
        for perplexity in perplexities
    ]

    # squeeze=False keeps the axes indexable when only one perplexity is given.
    fig, axes = plt.subplots(1, len(perplexities), figsize=(25, 5), squeeze=False)
    for ax, perplexity, projection in zip(axes[0], perplexities, projections):
        ax.set_title("Perplexity=%d" % perplexity)
        for class_id, (class_name, color) in enumerate(zip(classes, colors)):
            in_class = y == class_id
            ax.scatter(
                projection[in_class, 0],
                projection[in_class, 1],
                color=color,
                label=class_name,
                alpha=0.25,
            )
        ax.xaxis.set_major_formatter(NullFormatter())
        ax.yaxis.set_major_formatter(NullFormatter())
        ax.axis("tight")
    return fig