Data Engineering
Join discussions on data engineering best practices, architectures, and optimization strategies within the Databricks Community. Exchange insights and solutions with fellow data engineers.

Does Databricks support XLA compilation for TensorFlow models?

ray21
New Contributor II

I am defining a sequential Keras model using tensorflow.keras

Runtime: Databricks ML 8.3

Cluster: Standard NC24 with 4 GPUs per node.

To enable XLA compilation, I set the following flag:

tf.config.optimizer.set_jit(True)

Here is the output when I try to train the model:

<command-4238178162238395> in train_distributed_tf(train_count, val_count, params)

18 metrics=['mean_absolute_error', 'mean_absolute_percentage_error'])

19

---> 20 history = model.fit(

21 distributed_train,

22 epochs=EPOCHS,

/databricks/python/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py in safe_patch_function(*args, **kwargs)

485

486 if patch_is_class:

--> 487 patch_function.call(call_original, *args, **kwargs)

488 else:

489 patch_function(call_original, *args, **kwargs)

/databricks/python/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py in call(cls, original, *args, **kwargs)

151 @classmethod

152 def call(cls, original, *args, **kwargs):

--> 153 return cls().__call__(original, *args, **kwargs)

154

155 def __call__(self, original, *args, **kwargs):

/databricks/python/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py in __call__(self, original, *args, **kwargs)

162 # Regardless of what happens during the `_on_exception` callback, reraise

163 # the original implementation exception once the callback completes

--> 164 raise e

165

166

/databricks/python/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py in __call__(self, original, *args, **kwargs)

155 def __call__(self, original, *args, **kwargs):

156 try:

--> 157 return self._patch_implementation(original, *args, **kwargs)

158 except (Exception, KeyboardInterrupt) as e:

159 try:

/databricks/python/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py in _patch_implementation(self, original, *args, **kwargs)

214 self.managed_run = try_mlflow_log(create_managed_run)

215

--> 216 result = super(PatchWithManagedRun, self)._patch_implementation(

217 original, *args, **kwargs

218 )

/databricks/python/lib/python3.8/site-packages/mlflow/tensorflow.py in _patch_implementation(self, original, inst, *args, **kwargs)

1086 _log_early_stop_callback_params(early_stop_callback)

1087

-> 1088 history = original(inst, *args, **kwargs)

1089

1090 _log_early_stop_callback_metrics(early_stop_callback, history, metrics_logger)

/databricks/python/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py in call_original(*og_args, **og_kwargs)

443 disable_warnings=False, reroute_warnings=False,

444 ):

--> 445 original_result = original(*og_args, **og_kwargs)

446

447 try_log_autologging_event(

/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)

1098 _r=1):

1099 callbacks.on_train_batch_begin(step)

-> 1100 tmp_logs = self.train_function(iterator)

1101 if data_handler.should_sync:

1102 context.async_wait()

/databricks/python/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)

826 tracing_count = self.experimental_get_tracing_count()

827 with trace.Trace(self._name) as tm:

--> 828 result = self._call(*args, **kwds)

829 compiler = "xla" if self._experimental_compile else "nonXla"

830 new_tracing_count = self.experimental_get_tracing_count()

/databricks/python/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)

886 # Lifting succeeded, so variables are initialized and we can run the

887 # stateless function.

--> 888 return self._stateless_fn(*args, **kwds)

889 else:

890 _, _, _, filtered_flat_args = \

/databricks/python/lib/python3.8/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)

2940 (graph_function,

2941 filtered_flat_args) = self._maybe_define_function(args, kwargs)

-> 2942 return graph_function._call_flat(

2943 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access

2944

/databricks/python/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)

1916 and executing_eagerly):

1917 # No tape is watching; skip to running the function.

-> 1918 return self._build_call_outputs(self._inference_function.call(

1919 ctx, args, cancellation_manager=cancellation_manager))

1920 forward_backward = self._select_forward_and_backward_functions(

/databricks/python/lib/python3.8/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)

553 with _InterpolateFunctionError(self):

554 if cancellation_manager is None:

--> 555 outputs = execute.execute(

556 str(self.signature.name),

557 num_outputs=self._num_outputs,

/databricks/python/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)

57 try:

58 ctx.ensure_initialized()

---> 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,

60 inputs, attrs, num_outputs)

61 except core._NotOkStatusException as e:

InternalError: 5 root error(s) found.

(0) Internal: libdevice not found at ./libdevice.10.bc

[[{{node cluster_3_1/xla_compile}}]]

[[div_no_nan_33/ReadVariableOp_3/_318]]

(1) Internal: libdevice not found at ./libdevice.10.bc

[[{{node cluster_3_1/xla_compile}}]]

(2) Internal: libdevice not found at ./libdevice.10.bc

[[{{node cluster_3_1/xla_compile}}]]

[[div_no_nan/_825]]

(3) Internal: libdevice not found at ./libdevice.10.bc

[[{{node cluster_3_1/xla_compile}}]]

[[div_no_nan_26/AddN/_272]]

(4) Internal: libdevice not found at ./libdevice.10.bc

[[{{node cluster_3_1/xla_compile}}]]

[[div_no_nan/_821]]

0 successful operations.

0 derived errors ignored. [Op:__inference_train_function_2599244]

Function call stack:

train_function -> train_function -> train_function -> train_function -> train_function


mathan_pillai
Databricks Employee

Hi @Revanth Pentyala​ 

Can you please try a DBR 7.3 ML cluster?

It seems the cupti library was deprecated starting with DBR 7.6.

https://docs.databricks.com/release-notes/runtime/7.6ml.html#deprecations

It seems the cupti version (9) that ships with Ubuntu is not compatible with CUDA 11. The workaround would be to install a compatible cupti package (11) through an init script.

For now, however, you can try DBR 7.3 ML to see if it works there.

Thanks

Mathan

Hi @Revanth Pentyala​ ,

Did Mathan's response help you solve your issue? If it did, please mark it as "best" so it can be moved to the top and help others.

sean_owen
Databricks Employee

I don't think this is specific to Databricks, but rather to TensorFlow. See https://stackoverflow.com/questions/68614547/tensorflow-libdevice-not-found-why-is-it-not-found-in-t... for a possibly relevant solution.

I don't see evidence that this is related to libcupti.
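
For anyone landing here with the same "libdevice not found at ./libdevice.10.bc" error: the message suggests XLA is looking for libdevice relative to the current directory because it could not locate the CUDA install. A commonly suggested fix is to point XLA at the CUDA directory with the `--xla_gpu_cuda_data_dir` flag in the `XLA_FLAGS` environment variable. Below is a minimal sketch of that idea, not a confirmed fix for this cluster. The candidate paths and the helper names `find_cuda_data_dir` and `point_xla_at_cuda` are assumptions for illustration; check where CUDA actually lives on your runtime.

```python
import os
from pathlib import Path

# Location of libdevice relative to the root of a CUDA install.
LIBDEVICE_RELATIVE = Path("nvvm") / "libdevice" / "libdevice.10.bc"


def find_cuda_data_dir(candidates):
    """Return the first candidate directory that contains libdevice, else None."""
    for candidate in candidates:
        if (Path(candidate) / LIBDEVICE_RELATIVE).is_file():
            return str(candidate)
    return None


def point_xla_at_cuda(cuda_dir, env=os.environ):
    """Append --xla_gpu_cuda_data_dir to XLA_FLAGS, keeping any existing flags."""
    flag = f"--xla_gpu_cuda_data_dir={cuda_dir}"
    existing = env.get("XLA_FLAGS", "")
    env["XLA_FLAGS"] = f"{existing} {flag}".strip()
    return env["XLA_FLAGS"]


if __name__ == "__main__":
    # Hypothetical candidate paths (assumption): adjust for your cluster.
    candidates = ["/usr/local/cuda", "/usr/local/cuda-11.0"]
    cuda_dir = find_cuda_data_dir(candidates)
    if cuda_dir is None:
        print("libdevice.10.bc not found under any candidate directory")
    else:
        # Safest to do this before TensorFlow is imported in the process.
        print("XLA_FLAGS =", point_xla_at_cuda(cuda_dir))
```

The environment variable has to be visible to the process that runs the XLA compilation, so with distributed training it would likely need to be set on each worker (for example through the cluster's environment variables) rather than only in the driver notebook.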
