Does Databricks support XLA compilation for TensorFlow models?
- Mark as New
- Bookmark
- Subscribe
- Mute
- Subscribe to RSS Feed
- Permalink
- Report Inappropriate Content
11-30-2021 10:36 AM
I am defining a Sequential Keras model using tensorflow.keras.
Runtime: Databricks ML 8.3
Cluster: Standard NC24 with 4 GPUs per node.
To enable XLA compilation, I set the following flag:
tf.config.optimizer.set_jit(True)
Here is the output when I try to train the model:
<command-4238178162238395> in train_distributed_tf(train_count, val_count, params)
18 metrics=['mean_absolute_error', 'mean_absolute_percentage_error'])
19
---> 20 history = model.fit(
21 distributed_train,
22 epochs=EPOCHS,
/databricks/python/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py in safe_patch_function(*args, **kwargs)
485
486 if patch_is_class:
--> 487 patch_function.call(call_original, *args, **kwargs)
488 else:
489 patch_function(call_original, *args, **kwargs)
/databricks/python/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py in call(cls, original, *args, **kwargs)
151 @classmethod
152 def call(cls, original, *args, **kwargs):
--> 153 return cls().__call__(original, *args, **kwargs)
154
155 def __call__(self, original, *args, **kwargs):
/databricks/python/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py in __call__(self, original, *args, **kwargs)
162 # Regardless of what happens during the `_on_exception` callback, reraise
163 # the original implementation exception once the callback completes
--> 164 raise e
165
166
/databricks/python/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py in __call__(self, original, *args, **kwargs)
155 def __call__(self, original, *args, **kwargs):
156 try:
--> 157 return self._patch_implementation(original, *args, **kwargs)
158 except (Exception, KeyboardInterrupt) as e:
159 try:
/databricks/python/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py in _patch_implementation(self, original, *args, **kwargs)
214 self.managed_run = try_mlflow_log(create_managed_run)
215
--> 216 result = super(PatchWithManagedRun, self)._patch_implementation(
217 original, *args, **kwargs
218 )
/databricks/python/lib/python3.8/site-packages/mlflow/tensorflow.py in _patch_implementation(self, original, inst, *args, **kwargs)
1086 _log_early_stop_callback_params(early_stop_callback)
1087
-> 1088 history = original(inst, *args, **kwargs)
1089
1090 _log_early_stop_callback_metrics(early_stop_callback, history, metrics_logger)
/databricks/python/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py in call_original(*og_args, **og_kwargs)
443 disable_warnings=False, reroute_warnings=False,
444 ):
--> 445 original_result = original(*og_args, **og_kwargs)
446
447 try_log_autologging_event(
/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1098 _r=1):
1099 callbacks.on_train_batch_begin(step)
-> 1100 tmp_logs = self.train_function(iterator)
1101 if data_handler.should_sync:
1102 context.async_wait()
/databricks/python/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
826 tracing_count = self.experimental_get_tracing_count()
827 with trace.Trace(self._name) as tm:
--> 828 result = self._call(*args, **kwds)
829 compiler = "xla" if self._experimental_compile else "nonXla"
830 new_tracing_count = self.experimental_get_tracing_count()
/databricks/python/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
886 # Lifting succeeded, so variables are initialized and we can run the
887 # stateless function.
--> 888 return self._stateless_fn(*args, **kwds)
889 else:
890 _, _, _, filtered_flat_args = \
/databricks/python/lib/python3.8/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
2940 (graph_function,
2941 filtered_flat_args) = self._maybe_define_function(args, kwargs)
-> 2942 return graph_function._call_flat(
2943 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
2944
/databricks/python/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1916 and executing_eagerly):
1917 # No tape is watching; skip to running the function.
-> 1918 return self._build_call_outputs(self._inference_function.call(
1919 ctx, args, cancellation_manager=cancellation_manager))
1920 forward_backward = self._select_forward_and_backward_functions(
/databricks/python/lib/python3.8/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
553 with _InterpolateFunctionError(self):
554 if cancellation_manager is None:
--> 555 outputs = execute.execute(
556 str(self.signature.name),
557 num_outputs=self._num_outputs,
/databricks/python/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
57 try:
58 ctx.ensure_initialized()
---> 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
InternalError: 5 root error(s) found.
(0) Internal: libdevice not found at ./libdevice.10.bc
[[{{node cluster_3_1/xla_compile}}]]
[[div_no_nan_33/ReadVariableOp_3/_318]]
(1) Internal: libdevice not found at ./libdevice.10.bc
[[{{node cluster_3_1/xla_compile}}]]
(2) Internal: libdevice not found at ./libdevice.10.bc
[[{{node cluster_3_1/xla_compile}}]]
[[div_no_nan/_825]]
(3) Internal: libdevice not found at ./libdevice.10.bc
[[{{node cluster_3_1/xla_compile}}]]
[[div_no_nan_26/AddN/_272]]
(4) Internal: libdevice not found at ./libdevice.10.bc
[[{{node cluster_3_1/xla_compile}}]]
[[div_no_nan/_821]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_2599244]
Function call stack:
train_function -> train_function -> train_function -> train_function -> train_function
- Labels:
-
Autoloader
-
Databricks Runtime