1. Are you using a model-parallel library such as FSDP or DeepSpeed? If not, every GPU will load the entire model weights.
2. If yes to 1: Unity Catalog Volumes are exposed on every node at /Volumes/<catalog>/<schema>/<volume>/..., so each worker can open its own shard files directly without going through the driver.
Example code would look like the following:
import os, torch

local_rank = int(os.environ.get("LOCAL_RANK", 0))
ckpt_dir = "/Volumes/<catalog>/<schema>/<volume>/checkpoints/epoch-10"

# Example DeepSpeed ZeRO-3 shard name pattern; adjust to your framework.
fname = f"mp_rank_{local_rank:02d}_model_states.pt"

# Deserialize to CPU first to avoid a transient GPU memory spike on the worker.
state = torch.load(os.path.join(ckpt_dir, fname), map_location="cpu")

# Then load into the (already constructed) module on this worker.
model.load_state_dict(state["module"], strict=False)
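If you are using FSDP rather than DeepSpeed, the same idea applies. Below is a minimal sketch, assuming the checkpoint was saved as a sharded state dict with torch.distributed.checkpoint (DCP), that model is already wrapped in FSDP on each worker, and that the checkpoint directory path shown is a placeholder you replace with your own Volume path.

import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

ckpt_dir = "/Volumes/<catalog>/<schema>/<volume>/checkpoints/epoch-10"  # placeholder path

# Each rank reads only the shards it owns, straight from the Volume path.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {"model": model.state_dict()}
    # Loads the saved shards in place into the sharded state_dict.
    dcp.load_state_dict(state_dict=state_dict,
                        storage_reader=dcp.FileSystemReader(ckpt_dir))
    model.load_state_dict(state_dict["model"])

Again, exact save/load APIs depend on how the checkpoint was written, so treat this as a sketch rather than a drop-in snippet.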
Please let me know if this solves your problem. Thanks!