from __future__ import annotations
import torch
import torch.nn.functional as F
def learn_lawmm_mask(
    inversed_steps_latents: torch.Tensor,
    sample_steps_latents: torch.Tensor,
    signature_fft: torch.Tensor,
    sample_steps=(5, 10, 15),
    threshold: float = 0.91,
    num_learning_epoch: int = 50,
    lr: float = 0.02,
    show_progress: bool = False,
) -> torch.Tensor:
    """
    Learn the LAWMM binary mask (frequency-domain adaptive mask).

    Mask logits ``p`` are optimized with Adam so that, at the selected time
    steps, latents whose FFT bins are blended toward ``signature_fft`` stay
    close to the corresponding inversion latents.

    Args:
        inversed_steps_latents: [T_inv, B, C, H, W] inversion latents.
        sample_steps_latents: [T_sam, B, C, H, W] sampling latents.
        signature_fft: complex tensor, broadcastable to [B, C, H, W]
            (accepted shapes: [H,W], [C,H,W], or [B,C,H,W]).
        sample_steps: which time steps to use as constraints (1-based in
            original impl); must be non-empty.
        threshold: sigmoid(p) > threshold -> True.
        num_learning_epoch: optimization steps for mask logits p.
        lr: Adam learning rate for p.
        show_progress: if True, show a tqdm progress bar (if installed).

    Returns:
        final_mask: bool tensor with shape [B, C, H, W].

    Raises:
        ValueError: on bad latent rank, empty/out-of-range sample_steps,
            or a signature_fft shape that does not broadcast to the latents.
        TypeError: if signature_fft is not complex.
    """
    if inversed_steps_latents.ndim != 5 or sample_steps_latents.ndim != 5:
        raise ValueError("Expected latents with shape [T, B, C, H, W].")
    if len(sample_steps) == 0:
        # Without constraints the loss stays a plain float and backward() would fail.
        raise ValueError("sample_steps must be non-empty.")
    device = inversed_steps_latents.device
    T_inv, B, C, H, W = inversed_steps_latents.shape
    # ---- Step indices (match original logic) ----
    # sample_steps are 1-based into the sampling trajectory; the inversion
    # trajectory is indexed from its far end ((T_inv - 1) - step).
    step_idx_x = torch.tensor([int(i) - 1 for i in sample_steps], device=device, dtype=torch.long)
    step_idx_y = (T_inv - 1) - torch.tensor([int(i) for i in sample_steps], device=device, dtype=torch.long)
    if (step_idx_x.min() < 0) or (step_idx_x.max() >= sample_steps_latents.shape[0]):
        raise ValueError("sample_steps out of range for sample_steps_latents.")
    if (step_idx_y.min() < 0) or (step_idx_y.max() >= inversed_steps_latents.shape[0]):
        raise ValueError("sample_steps out of range for inversed_steps_latents.")
    # ---- IMPORTANT: FFT safety => do all FFT math in float32/complex64 ----
    x_list = sample_steps_latents.index_select(0, step_idx_x).detach().float()  # [K,B,C,H,W] float32
    y = inversed_steps_latents.index_select(0, step_idx_y).detach().float()     # [K,B,C,H,W] float32
    # ---- Broadcast signature_fft to [B,C,H,W] and ensure complex64 ----
    sig = signature_fft
    if not torch.is_complex(sig):
        raise TypeError("signature_fft must be a complex tensor (torch.complex64/128).")
    if sig.ndim == 2:
        sig = sig[None, None, :, :]  # [1,1,H,W]
    elif sig.ndim == 3:
        sig = sig[None, :, :, :]     # [1,C,H,W]
    elif sig.ndim != 4:
        raise ValueError("signature_fft must have shape [H,W] or [C,H,W] or [B,C,H,W].")
    if sig.shape[-2:] != (H, W):
        raise ValueError(f"signature_fft spatial size {sig.shape[-2:]} != latents {(H,W)}")
    if sig.shape[1] not in (1, C):
        raise ValueError(f"signature_fft channel dim must be 1 or {C}, got {sig.shape[1]}")
    if sig.shape[0] not in (1, B):
        raise ValueError(f"signature_fft batch dim must be 1 or {B}, got {sig.shape[0]}")
    sig = sig.to(device=device)
    # Force to complex64 for stable FFT ops
    if sig.dtype != torch.complex64:
        sig = sig.to(torch.complex64)
    sig = sig.expand(B, C, H, W).detach()  # [B,C,H,W] complex64
    # ---- Learn mask logits p in float32 (Adam + stability) ----
    p = torch.rand((B, C, H, W), device=device, dtype=torch.float32, requires_grad=True)
    opt = torch.optim.Adam([p], lr=float(lr))
    it = range(int(num_learning_epoch))
    if show_progress:
        try:
            from tqdm import tqdm
            it = tqdm(it, total=int(num_learning_epoch))
        except Exception:
            # tqdm is optional; silently fall back to a plain range.
            pass
    for _ in it:
        loss = 0.0
        mask_soft = torch.sigmoid(p)  # float32, soft relaxation of the binary mask
        for k in range(x_list.shape[0]):
            x = x_list[k]  # float32 [B,C,H,W]
            # FFT -> complex64
            X = torch.fft.fft2(x)
            # Blend signature into the masked frequency bins.
            X_prime = mask_soft * sig + (1.0 - mask_soft) * X
            x_prime = torch.fft.ifft2(X_prime).real  # float32
            # Match original style constraint:
            loss = loss + F.mse_loss(x_prime + (sig * mask_soft).real, y[k])
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
    # Hard-threshold the learned soft mask into a boolean mask.
    final_mask = (torch.sigmoid(p) > float(threshold))
    return final_mask
def apply_watermark_fft(
    last_inverse_latent: torch.Tensor,
    watermarking_mask: torch.Tensor,
    signature_fft: torch.Tensor
) -> torch.Tensor:
    """
    Replace FFT bins selected by watermarking_mask with signature_fft bins.

    Args:
        last_inverse_latent: [B,C,H,W] real latent to watermark.
        watermarking_mask: bool (or bool-castable) mask broadcastable to
            [B,C,H,W]; True bins are overwritten by the signature.
        signature_fft: complex tensor broadcastable to [B,C,H,W]
            (accepted shapes: [H,W], [C,H,W], or [B,C,H,W]).

    Returns:
        Real float32 tensor [B,C,H,W]: ifft2 of the spliced spectrum.

    Raises:
        ValueError: on bad latent rank or a signature spatial-size mismatch.
        TypeError: if signature_fft is not complex.
    """
    if last_inverse_latent.ndim != 4:
        raise ValueError("last_inverse_latent must be [B,C,H,W].")
    if not torch.is_complex(signature_fft):
        raise TypeError("signature_fft must be complex.")
    B, C, H, W = last_inverse_latent.shape
    device = last_inverse_latent.device
    # Do all FFT math in float32/complex64 for stability, matching learn_lawmm_mask.
    x = last_inverse_latent.float()
    X = torch.fft.fft2(x)
    sig = signature_fft
    if sig.ndim == 2:
        sig = sig[None, None, :, :]  # [1,1,H,W]
    elif sig.ndim == 3:
        sig = sig[None, :, :, :]     # [1,C,H,W]
    elif sig.ndim != 4:
        raise ValueError("signature_fft must have shape [H,W] or [C,H,W] or [B,C,H,W].")
    if sig.shape[-2:] != (H, W):
        raise ValueError(f"signature_fft spatial size {sig.shape[-2:]} != latents {(H,W)}")
    sig = sig.to(device=device)
    if sig.dtype != torch.complex64:
        sig = sig.to(torch.complex64)
    sig = sig.expand(B, C, H, W)
    mask = watermarking_mask
    if mask.dtype != torch.bool:
        mask = mask.bool()
    if mask.shape != (B, C, H, W):
        mask = mask.expand(B, C, H, W)
    # Splice: clone so the caller's spectrum view is never mutated in place.
    X = X.clone()
    X[mask] = sig[mask].clone()
    return torch.fft.ifft2(X).real