Generative AI
Explore discussions on generative artificial intelligence techniques and applications within the Databricks Community. Share ideas, challenges, and breakthroughs in this cutting-edge field.

How to utilize a clustered GPU setup for large HF models

dk_g
New Contributor

Hi,

I am using a clustered GPU setup (driver: 1 GPU, workers: 3 GPUs) and caching model data in Unity Catalog, but while loading model checkpoint shards it always uses driver memory and fails due to insufficient memory.

How can I use all of the cluster's GPUs while loading HF models?

Thanks

1 REPLY

lin-yuan
Databricks Employee

1. Are you using a model-parallel library such as FSDP or DeepSpeed? If not, every GPU will load the entire model's weights.

2. If yes to 1, Unity Catalog Volumes are exposed on every node at /Volumes/<catalog>/<schema>/<volume>/..., so workers can open checkpoint files themselves without going through the driver.

Example code would look like the following:

import os, torch
local_rank = int(os.environ.get("LOCAL_RANK", 0))
ckpt_dir = "/Volumes/<catalog>/<schema>/<volume>/checkpoints/epoch-10"

# Example DeepSpeed ZeRO-3 shard name pattern; adjust to your framework.
fname = f"mp_rank_{local_rank:02}_model_states.pt"

# Always deserialize to CPU first to avoid big transient spikes in driver/GPU memory
state = torch.load(os.path.join(ckpt_dir, fname), map_location="cpu")
# Then load into the module on this worker (assumes `model` is already constructed here)
model.load_state_dict(state["module"], strict=False)
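As a small addition, each worker can fail fast if its shard is missing from the Volume, rather than dying partway through a distributed load. The helpers `shard_path_for_rank` and `check_all_shards` below are hypothetical, and the shard-name pattern just mirrors the illustrative DeepSpeed-style pattern in the snippet above; adjust both to your framework:

```python
import os

def shard_path_for_rank(ckpt_dir: str, local_rank: int) -> str:
    """Build this rank's shard path and fail fast if the file is missing."""
    # Illustrative DeepSpeed-style shard name; adjust to your framework.
    fname = f"mp_rank_{local_rank:02}_model_states.pt"
    path = os.path.join(ckpt_dir, fname)
    if not os.path.exists(path):
        raise FileNotFoundError(f"rank {local_rank}: missing shard {path}")
    return path

def check_all_shards(ckpt_dir: str, world_size: int) -> list:
    """Verify every rank's shard exists before any worker starts loading."""
    return [shard_path_for_rank(ckpt_dir, r) for r in range(world_size)]
```

Each rank can call `shard_path_for_rank(ckpt_dir, local_rank)` before its `torch.load`, and rank 0 can run `check_all_shards` once up front so a missing file surfaces immediately on every worker.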

Please let me know if this solves your problem. Thanks!
