Problem with ray train and Databricks Notebook (Strange dbutils error)

JavierS
New Contributor

Hi everyone,

I'm running some code to train a multimodal Hugging Face model with SFTTrainer and TorchTrainer so that it uses all GPU workers. When I execute trainer.fit(), I get a dbutils serialization error, even though I am not using dbutils directly in my code. When I try to restart the Ray cluster, I get the same dbutils error:

Exception: You cannot use dbutils within a spark job
File <command-1633567752829769>, line 1
----> 1 trainer.fit()

My code is as follows:

import ray
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster, MAX_NUM_WORKER_NODES

num_cpu_cores_per_worker = 8  # total CPUs present in each node
num_gpu_per_worker = 1  # total GPUs present in each node
use_gpu = True
ray_log_dir = "/local_disk0/ray_logs"

try:
    shutdown_ray_cluster()
except Exception:
    print("No Ray cluster is initiated")

# Start the ray cluster and follow the output link to open the Ray Dashboard - a vital observability tool for understanding your infrastructure and application.
setup_ray_cluster(
  num_worker_nodes=MAX_NUM_WORKER_NODES,
  num_cpus_per_node=num_cpu_cores_per_worker,
  num_gpus_per_node=num_gpu_per_worker,
  num_cpus_head_node=8,
  num_gpus_head_node=1,
  collect_log_to_path=ray_log_dir
)

ray.init(ignore_reinit_error=True)

import torch
from datasets import load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig

import ray.train.huggingface.transformers
from ray.train import ScalingConfig, RunConfig, CheckpointConfig
from ray.train.torch import TorchTrainer

def train_fn(config):
    ##########################
    # Load model and processor
    ##########################

    # BitsAndBytesConfig int-4 config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16, quantization_config=bnb_config)

    #######################################################
    # Create a data collator to encode text and image pairs
    #######################################################
    def collate_fn(examples):
        # Get the texts and images, and apply the chat template
        texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
        images = [example["images"] for example in examples]

        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        # The labels are the input_ids; mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        # Ignore the image token index in the loss computation (model specific)
        image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
        labels[labels == image_token_id] = -100
        batch["labels"] = labels

        return batch

    ##############
    # Load dataset
    ##############
    dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train[:1000]")
    dataset = dataset.train_test_split(test_size=0.2)
    # SFTTrainer expects Hugging Face (or torch) datasets, so pass the splits
    # directly instead of converting them to Ray Datasets
    dataset_train = dataset["train"]
    dataset_val = dataset["test"]

    ###################
    # Configure trainer
    ###################

    # LoRA config based on QLoRA paper & Sebastian Raschka experiment
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.05,
        r=8,
        bias="none",
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
    )
    # Hyperparameters come from train_loop_config, which Ray passes in as `config`
    training_args = SFTConfig(
        output_dir="my-awesome-llama",
        per_device_train_batch_size=config["per_device_train_batch_size"],
        per_device_eval_batch_size=config["per_device_eval_batch_size"],
        gradient_accumulation_steps=config["gradient_accumulation_steps"],
        learning_rate=config["learning_rate"],
        max_steps=config["max_steps"],
        save_steps=config["save_steps"],
        logging_steps=config["logging_steps"],
        gradient_checkpointing=True,
        bf16=True,
        remove_unused_columns=False,
        dataset_kwargs={"skip_prepare_dataset": True},  # important for collator
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=dataset_train,
        eval_dataset=dataset_val,
        peft_config=peft_config,
        tokenizer=processor.tokenizer,
    )

    # Train! Report metrics/checkpoints to Ray Train, then let Ray prepare the trainer
    callback = ray.train.huggingface.transformers.RayTrainReportCallback()
    trainer.add_callback(callback)
    trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)
    trainer.train()



if __name__ == "__main__":


    # Training configuration
    train_loop_config = {
        "per_device_train_batch_size": 1,
        "per_device_eval_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "learning_rate": 2e-4,
        "max_steps": 100,
        "save_steps": 10,
        "logging_steps": 10,
    }

    scaling_config = ScalingConfig(num_workers=4, use_gpu=True)

    run_config = RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=5,
            checkpoint_score_attribute="loss",
            checkpoint_score_order="min",
        ),
        storage_path="/local_disk0/train_logs/",
        name="RAY_TEST_ON_LLAMA_VISUAL",
    )
    trainer = TorchTrainer(
        train_loop_per_worker=train_fn,
        train_loop_config=train_loop_config,
        run_config=run_config,
        scaling_config=scaling_config
    )


    # train
    result = trainer.fit()
    print(f"Training result: {result}")

ACCEPTED SOLUTION

sarahbhord
Databricks Employee

JavierS - 

The dbutils serialization error occurs because dbutils is only available on the Databricks driver node; it cannot be pickled or transferred to Spark or Ray worker nodes. The error can appear even if your code never calls dbutils directly. If any import or dependency (including libraries or initialization scripts) references dbutils at the module/global level, that reference may be serialized along with your training function or its objects, and the error is raised when trainer.fit() ships that function to the workers.

Some tips: 

  • Relocate any imports, especially those that might transitively pull in dbutils, into your train_fn (the function passed to TorchTrainer). This keeps driver-only modules (like dbutils) out of the scope that Ray/TorchTrainer serializes to the workers.
  • Before launching distributed jobs, make sure that neither your code nor any script or module imported at global scope references dbutils.* in anything that will run outside the driver.
  • Retrieve dbutils-based values on the driver before launching the trainer, then pass them to the workers as standard variables (not as part of module-level imports/objects).
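The last tip can be illustrated without Ray or Databricks. In this sketch, `FakeDbutils` is a hypothetical stand-in that, like the real handle, refuses to be serialized; `can_ship` is a hypothetical helper that uses stdlib `pickle` as a rough proxy for the serialization Ray applies to `train_loop_config` before it reaches each worker. The scope and key names are placeholders.

```python
import pickle


class FakeDbutils:
    """Hypothetical stand-in for the driver-only dbutils handle."""

    def __reduce__(self):
        # Mimic the real handle: refuse to be serialized for the workers.
        raise RuntimeError("You cannot use dbutils within a spark job")

    def get_secret(self, scope: str, key: str) -> str:
        # Placeholder for a driver-side lookup such as a secret read.
        return f"token-for-{scope}/{key}"


def can_ship(obj) -> bool:
    """Return True if obj survives pickling, i.e. could be sent to a worker."""
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False


dbutils = FakeDbutils()

# Anti-pattern: the driver-only handle travels inside the worker config.
bad_config = {"learning_rate": 2e-4, "dbutils": dbutils}

# Recommended: resolve the value on the driver and ship only the plain string.
good_config = {
    "learning_rate": 2e-4,
    "hf_token": dbutils.get_secret("my-scope", "hf-token"),
}

print(can_ship(bad_config))   # False
print(can_ship(good_config))  # True
```

Inside train_fn(config), the worker would then read `config["hf_token"]` as an ordinary string, with any library imports placed inside the function as the first tip suggests.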

Let me know if this helps! 

Best, 

Sarah