sarahbhord
Databricks Employee

Here are some suggestions: 

1. Update conda.yaml. Replace the current config with this pinned version:

name: chatts-env
channels:
  - conda-forge
dependencies:
  - python=3.10  # 3.12 may cause compatibility issues with these pins
  - pip
  - pip:
      - mlflow==2.21.3
      - torch==2.2.1  # Align with CUDA 12.1
      - transformers==4.40.0  # Includes Qwen2 support
      - accelerate==0.29.0  # Required for device_map="auto"
      - bitsandbytes==0.43.0  # For 8-/4-bit quantization
      - xformers==0.0.25  # Memory-efficient attention
variables:
  MLFLOW_HUGGINGFACE_USE_DEVICE_MAP: "true"  # MLflow only applies the strategy below when this is enabled
  MLFLOW_HUGGINGFACE_DEVICE_MAP_STRATEGY: auto  # Not "sequential"
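To confirm the serving environment actually picked up these pins, a quick check like the following can be run inside it (for example from a notebook built from the same conda.yaml). It is a minimal sketch that only compares exact versions of the pip pins listed above; `PINS` and `check_pins` are hypothetical helpers, not part of MLflow.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the pip section of the conda.yaml above.
PINS = {
    "mlflow": "2.21.3",
    "torch": "2.2.1",
    "transformers": "4.40.0",
    "accelerate": "0.29.0",
    "bitsandbytes": "0.43.0",
    "xformers": "0.0.25",
}


def check_pins(pins):
    """Return {package: problem} for every pin not satisfied in this environment."""
    problems = {}
    for name, wanted in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        # Compare the public version only, so "2.2.1+cu121" still matches "2.2.1".
        if installed.split("+")[0] != wanted:
            problems[name] = f"installed {installed}, expected {wanted}"
    return problems


if __name__ == "__main__":
    for pkg, problem in check_pins(PINS).items():
        print(f"{pkg}: {problem}")
```

An empty result means every pinned package is present at the expected version.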

2. Model loading fixes. In your MLflow model's inference script, load the model with device_map="auto" so its layers are distributed across all available GPUs:

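from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def load_model(model_path):
    # Load the tokenizer from the same checkpoint as the model
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",          # Let accelerate spread layers across all visible GPUs
        torch_dtype=torch.float16,  # Half precision: 2 bytes per parameter instead of 4
        trust_remote_code=True,     # Needed when the checkpoint ships custom model code
        low_cpu_mem_usage=True,     # Avoid holding a full extra copy of the weights in CPU RAM
    )
    return model, tokenizer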
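If device_map="auto" packs the GPUs so tightly that generation then runs out of memory, from_pretrained also accepts a max_memory dict that caps what each device may hold. The sketch below builds one while reserving headroom on every GPU for activations and the KV cache; `make_max_memory` is a hypothetical helper and the 2 GiB default headroom is an assumption to tune.

```python
def make_max_memory(gpu_total_gib, headroom_gib=2.0, cpu_gib=None):
    """Build a max_memory dict for from_pretrained(device_map="auto", max_memory=...).

    gpu_total_gib: total memory of each visible GPU in GiB, indexed by device id.
    headroom_gib:  memory left free on each GPU for activations and the KV cache.
    cpu_gib:       optional CPU RAM budget for layers that don't fit on the GPUs.
    """
    max_memory = {}
    for device_id, total in enumerate(gpu_total_gib):
        usable = total - headroom_gib
        if usable <= 0:
            raise ValueError(
                f"GPU {device_id}: headroom {headroom_gib} GiB >= total {total} GiB"
            )
        # Round down to whole GiB so the budget never exceeds the physical memory.
        max_memory[device_id] = f"{int(usable)}GiB"
    if cpu_gib is not None:
        max_memory["cpu"] = f"{int(cpu_gib)}GiB"
    return max_memory
```

In the inference script the totals could come from `[torch.cuda.get_device_properties(i).total_memory / 2**30 for i in range(torch.cuda.device_count())]`, and the result is passed as `max_memory=...` alongside `device_map="auto"`.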

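3. Serving endpoint configuration. Use a payload like this when creating the endpoint so the model is served on GPU instances with enough total memory. Note that device_map="auto" splits the model's layers across GPUs (each GPU holds a slice of the layers); it is not tensor parallelism.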

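{
  "name": "qwen-chat-endpoint",
  "config": {
    "served_entities": [{
      "entity_name": "catalog.schema.model_name",
      "entity_version": "1",
      "workload_type": "GPU_LARGE",
      "workload_size": "Large",
      "environment_vars": {
        "HF_HOME": "/dbfs/huggingface"
      }
    }]
  }
}

Notes on the fields (JSON itself doesn't allow comments):
  • workload_type selects the GPU instance type. GPU_LARGE is A100-based, but the number of GPUs and the memory per GPU depend on your cloud, so check the Databricks model serving docs and pick a type whose total GPU memory fits the model.
  • workload_size sets how much concurrency (how many replicas) the endpoint provisions. It does not change how many GPUs each replica gets.
  • The task is not an endpoint field; declare it when logging the model, e.g. mlflow.transformers.log_model(..., task="llm/v1/completions").
  • A MAX_JOBS variable isn't needed: it only caps parallel compilation when building extensions from source and does not affect model loading.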
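The same create call can be sketched with only the Python standard library, writing the payload as a Python dict. It targets the Databricks REST route `POST /api/2.0/serving-endpoints`; the workspace host and token are placeholders, `build_endpoint_request` is a hypothetical helper, the task and MAX_JOBS entries are left out, and the request is built but not sent.

```python
import json
import urllib.request


def build_endpoint_request(host, token, payload):
    """Return an unsent POST request that creates a model serving endpoint."""
    url = host.rstrip("/") + "/api/2.0/serving-endpoints"
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


PAYLOAD = {
    "name": "qwen-chat-endpoint",
    "config": {
        "served_entities": [{
            "entity_name": "catalog.schema.model_name",
            "entity_version": "1",
            "workload_type": "GPU_LARGE",  # GPU instance type
            "workload_size": "Large",      # Concurrency, not GPU count
            "environment_vars": {"HF_HOME": "/dbfs/huggingface"},
        }]
    },
}

# Usage with placeholder values:
#   req = build_endpoint_request("https://<workspace-host>", "<token>", PAYLOAD)
#   with urllib.request.urlopen(req) as resp:
#       print(resp.read().decode())
```

Keeping the payload in code makes it easy to bump entity_version or switch workload_type without hand-editing JSON.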

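4. Other adjustments:

  • Avoid T4 GPUs: they have only 16 GB each. Use A100 instances (40 GB or 80 GB per GPU) instead.
  • Quantize further: if 16-bit weights still don't fit, load in 8-bit with the bitsandbytes package pinned above. Recent transformers versions prefer quantization_config=BitsAndBytesConfig(load_in_8bit=True) in from_pretrained over the older load_in_8bit=True argument.
  • Check layer splitting: if device_map="auto" fails, build the device map yourself with accelerate's infer_auto_device_map and pass no_split_module_classes for Qwen's decoder layer class (Qwen2DecoderLayer for Qwen2/Qwen2.5-based checkpoints; the model's _no_split_modules attribute lists it) so no single layer is split across GPUs.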
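To judge whether quantization is needed, and what the no-split constraint means in practice, here is a back-of-the-envelope sketch. `weight_gib` estimates weight memory at 16, 8 or 4 bits per parameter, and `plan_layers` greedily places whole decoder layers on GPUs in order without ever splitting one, which is the constraint no_split_module_classes expresses. Both helpers are hypothetical and ignore embeddings, activations and the KV cache; accelerate's real `infer_auto_device_map(model, max_memory=..., no_split_module_classes=["Qwen2DecoderLayer"])` works from the actual module sizes.

```python
def weight_gib(n_params, bits_per_param):
    """Approximate memory for the weights alone, in GiB."""
    return n_params * bits_per_param / 8 / 2**30


def plan_layers(n_layers, layer_gib, gpu_free_gib):
    """Assign whole decoder layers to GPUs in order, never splitting a layer.

    n_layers:     number of decoder layers.
    layer_gib:    memory per layer in GiB (assumed equal for all layers).
    gpu_free_gib: free memory per GPU in GiB, indexed by device id.
    Returns {layer_index: gpu_id}; raises MemoryError if the layers don't fit.
    """
    plan = {}
    gpu = 0
    used = 0.0
    for layer in range(n_layers):
        # Move to the next GPU whenever this layer would overflow the current one.
        while gpu < len(gpu_free_gib) and used + layer_gib > gpu_free_gib[gpu]:
            gpu += 1
            used = 0.0
        if gpu == len(gpu_free_gib):
            raise MemoryError(f"layer {layer} does not fit on any remaining GPU")
        plan[layer] = gpu
        used += layer_gib
    return plan
```

For example, a hypothetical 14B-parameter model needs about 26 GiB for fp16 weights alone (14e9 × 2 bytes ≈ 26.1 GiB), which already exceeds a single 16 GB T4, while 8-bit roughly halves that and 4-bit roughly quarters it.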

 

If the error persists, share the full CUDA OOM log to debug layer-specific memory issues.