Hi team,
I’m processing ~5,000 EMR notes with a Databricks notebook. The job reads from `crc_lakehouse.bronze.emr_notes`, runs SciSpaCy UMLS entity extraction plus a fine-tuned BERT sentiment model per partition, and builds a DataFrame (`df_entities`) with JSON fields for keywords, sentiment, and note sections.
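For context, each output row of `df_entities` has roughly the shape below. This is a purely illustrative, hypothetical example (made-up values, not real note data); the JSON columns are stored as strings:
```
# Hypothetical example of one df_entities row (illustrative values only)
example_row = {
    "note_id": "note-0001",
    "keywords_extracted": '[{"extracted_text": "hypertension", "cui": "C0020538", '
                          '"match_type": "Exact", "confidence_score": 0.91, "mention_count": 2}]',
    "sentiment_analysis": '{"sentiment": "neutral", "confidence": 0.82, '
                          '"scores": {"negative": 0.06, "neutral": 0.82, "positive": 0.12}}',
    "note_section": '{"original_note": "<Draft.js JSON>", "plain_text": "Patient reports ..."}',
}
```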
Cluster details:
- Driver: 64 GB RAM, 8 cores
- Workers: 64 GB RAM, 8-core CPU each, autoscaling 1–4 nodes
Symptoms:
1. `df_entities.show()` returns results, so the transformation succeeds at least for the rows that `show()` materializes.
2. As soon as I try to `write`/save the DataFrame to another table (for example, `write.format("delta").mode("append").saveAsTable(...)`, or the `mode("overwrite")` call in the notebook below), the job fails with what looks like an OOM / executor memory error (full trace at the end of this post), and no rows are persisted. The exact show-vs-save cells are summarized right after this list.
3. Only ~5,000 rows are involved, so I expected this to fit comfortably on a cluster of this size.
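To make symptom 2 concrete, these are the two cells I'm comparing (both appear in the full notebook below); the preview succeeds and the save fails:
```
# Cell A: previewing works
df_entities.show(truncate=False)

# Cell B: saving fails with the executor error in the trace at the end of this post
df_entities.write.mode("overwrite").saveAsTable("crc_lakehouse.silver.notes_nlp_processed")
# also tried: df_entities.write.format("delta").mode("append").saveAsTable(...)  # same failure
```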
What I’ve checked:
- Repartitioned input to 16 partitions.
- Verified no skew in the source table.
- Tried caching, disabling broadcast joins, and lowering the `show()` row count; none of these change the failure when writing (see the sketch after this list).
- No custom memory configs; using defaults for this cluster size.
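Concretely, the attempts above amounted to roughly the following (a sketch from memory, not the exact cell contents):
```
# Sketch of the mitigations tried (none changed the write failure)
df_yesterday = df_yesterday.repartition(16)                   # spread input across executors
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")  # "disabling broadcast"
df_entities.cache()
df_entities.count()          # materialize the cache
df_entities.show(5)          # smaller preview still works
```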
Could you help identify why the write stage is exhausting memory despite the modest dataset? Are there best practices for running SciSpaCy + transformer sentiment inside `mapPartitions` on this configuration so the output can be saved?
Full notebook code (`pipelines/pipelines/script_01_2025-11-14 09_22_15.py`):
```
# Databricks notebook source
import json
from pyspark.sql.functions import col, trim, length, to_date, current_date, date_sub
SOURCE_TABLE = "crc_lakehouse.bronze.emr_notes"
# Fetch only yesterday's non-empty notes
df_yesterday = (
spark.read.table(SOURCE_TABLE)
.select("notes_id", "contentResolved", "ingestion_time")
.filter(to_date(col("ingestion_time")) == date_sub(current_date(), 1))
.filter(col("contentResolved").isNotNull() & (length(trim(col("contentResolved"))) > 0))
.select("notes_id", "contentResolved")
)
# Repartition so the ~5k notes are spread across executor cores (16 partitions)
df_yesterday = df_yesterday.repartition(16)
# COMMAND ----------
df_yesterday.count()
# COMMAND ----------
# MAGIC %run ../helpers/umls_utils
# COMMAND ----------
# Global variables to hold models per executor
global_nlp = None
global_linker = None
global_tokenizer = None
global_sentiment_model = None
def get_pipeline():
"""Load NLP and sentiment models once per executor"""
global global_nlp, global_linker, global_tokenizer, global_sentiment_model
if global_nlp is None or global_linker is None:
import spacy
from scispacy.abbreviation import AbbreviationDetector
import time
from scispacy.linking import EntityLinker
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load SciSpacy NLP
t0 = time.time()
nlp = spacy.load("en_core_sci_lg")
if "abbreviation_detector" not in nlp.pipe_names:
try:
nlp.add_pipe("abbreviation_detector")
except Exception:
nlp.add_pipe(AbbreviationDetector(nlp))
if "scispacy_linker" not in nlp.pipe_names:
nlp.add_pipe(
"scispacy_linker",
config={"resolve_abbreviations": True, "linker_name": "umls"},
)
linker = nlp.get_pipe("scispacy_linker")
print(f"✅ NLP loaded in executor in {time.time()-t0:.2f}s")
# Load sentiment model
MODEL_PATH = "../training/fine_tuned_bert_sentiment_v2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
sentiment_model.eval()
print("✅ Sentiment model loaded in executor")
# Save globally for this executor
global_nlp = nlp
global_linker = linker
global_tokenizer = tokenizer
global_sentiment_model = sentiment_model
return global_nlp, global_linker, global_tokenizer, global_sentiment_model
# COMMAND ----------
def process_partition_rows(iterator):
nlp, linker, tokenizer, sentiment_model = get_pipeline()
import torch
from rapidfuzz import fuzz, process
import time
MIN_SCORE = 0.80
ALLOWED_GROUPS = {"Disorders", "Drugs", "Anatomy", "Procedures", "Physiology"}
SENTIMENT_LABELS = ["negative", "neutral", "positive"]
def analyze_sentiment(text):
"""Sentiment for a single text"""
inputs = tokenizer(
text, return_tensors="pt", truncation=True, max_length=128, padding=True
)
with torch.no_grad():
outputs = sentiment_model(**inputs)
probs = torch.softmax(outputs.logits, dim=1)[0]
pred_class = torch.argmax(probs).item()
return {
"sentiment": SENTIMENT_LABELS[pred_class],
"confidence": float(probs[pred_class]),
"scores": {
"negative": float(probs[0]),
"neutral": float(probs[1]),
"positive": float(probs[2])
}
}
def get_match_type_with_aliases(entity_text, concept):
t = (entity_text or "").strip().lower()
canon = (getattr(concept, "canonical_name", "") or "").strip().lower()
aliases = [a.strip().lower() for a in (getattr(concept, "aliases", []) or [])]
if t == canon:
return "Exact"
if t in aliases:
return "Synonym"
score_canon = fuzz.token_set_ratio(t, canon) if canon else 0
best_alias_score = 0
if aliases:
match = process.extractOne(t, aliases, scorer=fuzz.token_set_ratio)
best_alias_score = match[1] if match else 0
score = max(score_canon, best_alias_score)
return "Fuzzy" if score >= 85 else "Synonym"
def draftjs_to_text(draftjs_json):
"""Convert Draft.js JSON content to plain text"""
if not draftjs_json:
return ""
if isinstance(draftjs_json, str):
try:
data = json.loads(draftjs_json)
except Exception:
return draftjs_json
else:
data = draftjs_json
blocks = data.get("blocks", [])
text = "\n".join(block.get("text", "") for block in blocks)
return text
for row in iterator:
notes_id = row.notes_id
original_draftjs = row.contentResolved
plain_text = draftjs_to_text(original_draftjs)
# Prepare note_section dictionary
note_section = {
"original_note": original_draftjs,
"plain_text": plain_text
}
entities_list = []
# UMLS entity extraction
if plain_text:
doc = nlp(plain_text)
all_entities = []
for ent in doc.ents:
if not ent._.kb_ents:
continue
umls_id, score = ent._.kb_ents[0]
if score < MIN_SCORE:
continue
concept = linker.kb.cui_to_entity.get(umls_id)
if concept is None:
continue
tui_codes = list(getattr(concept, "types", []) or [])
semantic_groups = format_semantic_types(tui_codes, format_type="group")
semantic_type_names = format_semantic_types(tui_codes, format_type="full")
group_set = {g.strip() for g in semantic_groups.split(",") if g.strip()}
if ALLOWED_GROUPS and group_set.isdisjoint(ALLOWED_GROUPS):
continue
match_type = get_match_type_with_aliases(ent.text, concept)
status = "allow" if match_type in ["Exact", "Synonym"] else "not-allow"
definition = getattr(concept, "definition", "N/A") or "N/A"
aliases_count = len(getattr(concept, "aliases", []))
all_entities.append({
"extracted_text": ent.text,
"entity_label": ent.label_,
"start_char": ent.start_char,
"end_char": ent.end_char,
"cui": umls_id,
"canonical_name": concept.canonical_name,
"match_type": match_type,
"status": status,
"confidence_score": round(score, 4),
"category": semantic_groups,
"detailed_types": semantic_type_names,
"tui_codes": ", ".join(tui_codes),
"definition": definition[:200] if len(definition) > 200 else definition,
"aliases_count": aliases_count,
})
# Deduplicate and count mentions
text_counts = {}
for e in all_entities:
t = e["extracted_text"].lower().strip()
text_counts[t] = text_counts.get(t, 0) + 1
seen = set()
for e in all_entities:
t = e["extracted_text"].lower().strip()
if t not in seen:
seen.add(t)
e["mention_count"] = text_counts[t]
entities_list.append(e)
# Sentiment analysis
sentiment_result = analyze_sentiment(plain_text) if plain_text else None
# Yield all four fields
yield (notes_id, entities_list, sentiment_result, note_section)
# COMMAND ----------
import json
from pyspark.sql import Row
from pyspark.sql.functions import current_timestamp
from pyspark.sql.types import StructType, StructField, StringType, TimestampType
# Call the partition processor to create the RDD
processed_rdd = df_yesterday.rdd.mapPartitions(process_partition_rows)
# Convert RDD to DataFrame with JSON fields
df_entities = processed_rdd.map(lambda x: Row(
note_id=x[0] if x[0] is not None else "",
keywords_extracted=json.dumps(x[1]) if x[1] else "[]",
sentiment_analysis=json.dumps(x[2]) if x[2] else "{}",
note_section=json.dumps(x[3]) if x[3] else "{}"
)).toDF()
# Define schema explicitly
schema = StructType([
StructField("note_id", StringType(), True),
StructField("keywords_extracted", StringType(), True),
StructField("sentiment_analysis", StringType(), True),
StructField("note_section", StringType(), True)
])
df_entities = spark.createDataFrame(df_entities.rdd, schema)
# Add processing timestamp
df_entities = df_entities.withColumn("processed_timestamp", current_timestamp())
# Preview (this works):
# df_entities.show(truncate=False)
# Save to the silver table (this is the step that fails):
df_entities.write.mode("overwrite").saveAsTable(
"crc_lakehouse.silver.notes_nlp_processed"
)
```
Full error trace:
```
Py4JJavaError: An error occurred while calling o625.saveAsTable.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 20.0 failed 4 times, most recent failure: Lost task 3.3 in stage 20.0 (TID 123) (10.139.64.8 executor 5): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/databricks/spark/python/pyspark/worker.py", line 1980, in main
process()
File "/databricks/spark/python/pyspark/worker.py", line 1972, in process
serializer.dump_stream(out_iter, outfile)
File "/databricks/spark/python/pyspark/serializers.py", line 356, in dump_stream
vs = list(itertools.islice(iterator, batch))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.ipykernel/2187/command-4748090034621861-3307002280", line 2, in process_partition_rows
File "/root/.ipykernel/2187/command-4748090034621856-3519326438", line 30, in get_pipeline
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.11/site-packages/spacy/language.py", line 821, in add_pipe
pipe_component = self.create_pipe(
^^^^^^^^^^^^^^^^^
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.11/site-packages/spacy/language.py", line 709, in create_pipe
resolved = registry.resolve(cfg, validate=validate)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/databricks/python/lib/python3.11/site-packages/confection/__init__.py", line 759, in resolve
resolved, _ = cls._make(
^^^^^^^^^^
File "/databricks/python/lib/python3.11/site-packages/confection/__init__.py", line 808, in _make
filled, _, resolved = cls._fill(
^^^^^^^^^^
File "/databricks/python/lib/python3.11/site-packages/confection/__init__.py", line 880, in _fill
getter_result = getter(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.11/site-packages/scispacy/linking.py", line 85, in __init__
self.candidate_generator = candidate_generator or CandidateGenerator(
^^^^^^^^^^^^^^^^^^^
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.11/site-packages/scispacy/candidate_generation.py", line 221, in __init__
self.ann_index = ann_index or load_approximate_nearest_neighbours_index(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.11/site-packages/scispacy/candidate_generation.py", line 141, in load_approximate_nearest_neighbours_index
ann_index.loadIndex(cached_path(linker_paths.ann_index))
RuntimeError: basic_ios::clear: iostream error
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:604)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:1063)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:1048)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:558)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
... (full Spark stack trace continues)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.$anonfun$failJobAndIndependentStages$1(DAGScheduler.scala:4043)
... (truncated for brevity) ...
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  ... (identical Python traceback as above) ...
RuntimeError: basic_ios::clear: iostream error
```
Any guidance would be appreciated.