How to utilize a GPU cluster for large HF models

dk_g
New Contributor

Hi,

I am using a GPU cluster (driver: 1 GPU, worker: 3 GPUs) and caching the model data in Unity Catalog, but when loading the model checkpoint shards it always uses driver memory and fails due to insufficient memory.

How can I use all of the cluster's GPUs when loading HF models?

Thanks

lin-yuan
Databricks Employee

1. Are you using a model-parallel library such as FSDP or DeepSpeed? If not, every GPU loads a full copy of the model weights.

2. If so, note that Unity Catalog Volumes are exposed on every node at /Volumes/<catalog>/<schema>/<volume>/..., so each worker process can read the checkpoint files directly instead of going through the driver.
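To see why point 1 matters, here is a rough weights-only estimate, assuming a hypothetical 7B-parameter model stored in bf16 (2 bytes per parameter). Without sharding, every process that loads the model holds a full copy; with full sharding (FSDP `FULL_SHARD` or DeepSpeed ZeRO-3) each GPU holds roughly 1/N of the weights. The function name is illustrative only, and the estimate deliberately ignores activations, gradients, optimizer states, the CUDA context, and the temporary gather buffers these libraries allocate.

```python
# Rough per-GPU memory for model *weights only* (hypothetical numbers).
# Ignores activations, gradients, optimizer states, CUDA context, and the
# temporary all-gather buffers that FSDP / ZeRO-3 use during forward/backward.


def per_gpu_weight_gb(num_params: int, bytes_per_param: int,
                      num_gpus: int, sharded: bool) -> float:
    """Return the approximate GB (1e9 bytes) of weights held by each GPU.

    Without a model-parallel library every GPU keeps a full copy;
    with full sharding each GPU keeps about 1/num_gpus of the weights.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    total_bytes = num_params * bytes_per_param
    per_gpu_bytes = total_bytes / num_gpus if sharded else total_bytes
    return per_gpu_bytes / 1e9


if __name__ == "__main__":
    params = 7_000_000_000  # hypothetical 7B-parameter model
    bf16 = 2                # bytes per parameter in bfloat16
    gpus = 3                # e.g. the 3 worker GPUs
    print(f"replicated: {per_gpu_weight_gb(params, bf16, gpus, sharded=False):.2f} GB per GPU")
    print(f"sharded:    {per_gpu_weight_gb(params, bf16, gpus, sharded=True):.2f} GB per GPU")
```

With these assumed numbers the script prints 14.00 GB per GPU replicated versus 4.67 GB per GPU sharded; real peak usage will be higher than either figure.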

An example looks like the following:

import os
import torch

# LOCAL_RANK is set per process by launchers such as torchrun.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
ckpt_dir = "/Volumes/<catalog>/<schema>/<volume>/checkpoints/epoch-10"

# Example DeepSpeed-style per-rank file name (mp_rank_00_model_states.pt, ...);
# adjust to your framework's checkpoint layout.
fname = f"mp_rank_{local_rank:02d}_model_states.pt"

# Deserialize to host (CPU) memory first to avoid large transient GPU memory spikes;
# this uses the RAM of the node running this process, not the driver's.
state = torch.load(os.path.join(ckpt_dir, fname), map_location="cpu")

# Then load into the module on this worker (`model` must already be constructed here).
model.load_state_dict(state["module"], strict=False)
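If the checkpoint is instead a Hugging Face sharded checkpoint (`model-0000X-of-0000N.safetensors` files), the index file `model.safetensors.index.json` maps every parameter name to the file that stores it under its `weight_map` key. The sketch below assumes that standard index layout and uses a hypothetical helper, `shard_files_for`, to show how a process that owns only some parameters (for example, certain layers under pipeline or tensor parallelism) can work out which files it needs and read only those from the Volume path. It only selects files; actually loading the tensors and placing them on devices is left to your framework.

```python
import json
import os
from typing import Iterable


def shard_files_for(ckpt_dir: str, param_names: Iterable[str],
                    index_name: str = "model.safetensors.index.json") -> list[str]:
    """Return sorted, de-duplicated shard file paths that hold `param_names`.

    Hypothetical helper: assumes the Hugging Face index layout
    {"metadata": {...}, "weight_map": {param_name: shard_file, ...}}.
    """
    with open(os.path.join(ckpt_dir, index_name)) as f:
        weight_map = json.load(f)["weight_map"]
    names = list(param_names)  # materialize so a generator is not consumed twice
    missing = [n for n in names if n not in weight_map]
    if missing:
        raise KeyError(f"parameters not found in index: {missing[:5]}")
    files = {weight_map[n] for n in names}  # several params often share one file
    return sorted(os.path.join(ckpt_dir, fn) for fn in files)


if __name__ == "__main__":
    import tempfile

    # Build a tiny fake index so the sketch runs anywhere; on Databricks, ckpt_dir
    # would be a /Volumes/<catalog>/<schema>/<volume>/... path visible on every node.
    with tempfile.TemporaryDirectory() as d:
        index = {"metadata": {"total_size": 0}, "weight_map": {
            "model.layers.0.mlp.weight": "model-00001-of-00002.safetensors",
            "model.layers.1.mlp.weight": "model-00002-of-00002.safetensors",
        }}
        with open(os.path.join(d, "model.safetensors.index.json"), "w") as f:
            json.dump(index, f)
        # Pretend this rank owns only layer 1, so it needs only the second file.
        print(shard_files_for(d, ["model.layers.1.mlp.weight"]))
```

Sorting the result makes the read order deterministic across ranks, which is convenient when debugging which process touched which file.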

Please let me know if this solves your problem. Thanks!