Databricks Pyspark filter several columns with similar criteria
10-11-2024 06:03 PM - edited 10-11-2024 06:04 PM
I am querying a table from the Databricks catalog in which I have to filter several columns with the same criteria. Below is what I have created so far. I have 10 columns that I have to filter with one set of criteria (dx_list1) and another 10 that I have to filter with a second set of criteria (dx_list2).
I have started doing this: col("DX_00").isin(dx_list1) | col("DX_01").isin(dx_list2), and was planning to go all the way to DX_19.
I am wondering if there is a more efficient way to get the same results, or is this as good as it gets? Thank you.
Code below:
dx_list1 = ['J984', 'J466', 'J754', 'J64', 'J465', 'J445']
dx_list2 = ['J454','J445','J443','J76','J487','J765','J765']
from pyspark.sql.functions import col

test = (
    spark.table("claim")
    .select("ID", "MEMB_KEY", "PROV_KEY", "CL#", "DOS",
            "DX_00", "DX_01", "DX_02", "DX_03", "DX_04", "DX_05",
            "DX_06", "DX_07", "DX_08", "DX_09", "DX_10")
    .filter(
        (col("DX_00").isin(dx_list1) | col("DX_01").isin(dx_list2)) &
        (col("DX_02").isin(dx_list1) | col("DX_03").isin(dx_list2)) &
        (col("DX_04").isin(dx_list1) | col("DX_05").isin(dx_list2))
    )
)
display(test)
10-13-2024 03:18 AM - edited 10-13-2024 03:19 AM
Hi @abueno,
As I understand it, the logic you want to implement is:
1. For every pair of columns:
- First Column (DX_i): Must be in dx_list1
- Second Column (DX_{i+1}): Must be in dx_list2
2. The condition for each pair is:
- col('DX_i').isin(dx_list1) OR col('DX_{i+1}').isin(dx_list2)
3. All pairwise conditions are combined using AND
In other words, you want to filter rows where for each pair of columns, at least one of the columns satisfies the corresponding criteria, and all pairs must satisfy this condition.
If the logic is different, please clarify. The below example is based on this logic.
# Sample data with records designed to be filtered out
data = [
# ID, MEMB_KEY, PROV_KEY, CL#, DOS, DX_00 to DX_11
# Records that should be included
(1, 'M001', 'P001', 'CL001', '2022-01-01',
'J984', 'A123', 'J984', 'A123', 'J984', 'A123',
'J984', 'A123', 'J984', 'A123', 'J984', 'A123'),
(2, 'M002', 'P002', 'CL002', '2022-01-02',
'A123', 'J445', 'A123', 'J445', 'A123', 'J445',
'A123', 'J445', 'A123', 'J445', 'A123', 'J445'),
(3, 'M003', 'P003', 'CL003', '2022-01-03',
'J984', 'J454', 'J984', 'J454', 'J984', 'J454',
'J984', 'J454', 'J984', 'J454', 'J984', 'J454'),
(4, 'M004', 'P004', 'CL004', '2022-01-04',
'J465', 'J76', 'J465', 'J76', 'J465', 'J76',
'J465', 'J76', 'J465', 'J76', 'J465', 'J76'),
(5, 'M005', 'P005', 'CL005', '2022-01-05',
'J754', 'J765', 'J754', 'J765', 'J754', 'J765',
'J754', 'J765', 'J754', 'J765', 'J754', 'J765'),
# Records that should be excluded
(6, 'M006', 'P006', 'CL006', '2022-01-06',
'A123', 'A123', 'A123', 'A123', 'A123', 'A123',
'A123', 'A123', 'A123', 'A123', 'A123', 'A123'),
(7, 'M007', 'P007', 'CL007', '2022-01-07',
'J64', 'J487', 'A123', 'A123', 'A123', 'A123',
'A123', 'A123', 'A123', 'A123', 'A123', 'A123'),
(8, 'M008', 'P008', 'CL008', '2022-01-08',
'A123', 'A123', 'J64', 'A123', 'A123', 'A123',
'A123', 'A123', 'A123', 'A123', 'A123', 'A123'),
(9, 'M009', 'P009', 'CL009', '2022-01-09',
'J466', 'J443', 'J466', 'A123', 'A123', 'J443',
'A123', 'A123', 'A123', 'A123', 'A123', 'A123'),
(10, 'M010', 'P010', 'CL010', '2022-01-10',
'J445', 'A123', 'A123', 'J445', 'A123', 'A123',
'J445', 'A123', 'A123', 'A123', 'J445', 'A123'),
]
columns = ["ID", "MEMB_KEY", "PROV_KEY", "CL#", "DOS"] + \
[f'DX_{str(i).zfill(2)}' for i in range(0, 12)]
df = spark.createDataFrame(data, columns)
# Filtering lists
dx_list1 = ['J984', 'J466', 'J754', 'J64', 'J465', 'J445']
dx_list2 = ['J454', 'J445', 'J443', 'J76', 'J487', 'J765']
# Generate DX column names
dx_columns = [f'DX_{str(i).zfill(2)}' for i in range(0, 12)] # DX_00 to DX_11
# Create pairs of columns: (DX_00, DX_01), (DX_02, DX_03), ..., (DX_10, DX_11)
dx_pairs = [(dx_columns[i], dx_columns[i+1]) for i in range(0, len(dx_columns)-1, 2)]
from pyspark.sql.functions import col
from functools import reduce
import operator
# Build list of conditions
pair_conditions = [
(col(dx_i).isin(dx_list1) | col(dx_j).isin(dx_list2))
for dx_i, dx_j in dx_pairs
]
# Combine conditions with AND
final_condition = reduce(operator.and_, pair_conditions)
filtered_df = df.filter(final_condition)
# Show the results
filtered_df.select("ID", *dx_columns).show(truncate=False)
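With the sample data above, only IDs 1 through 5 should remain after the filter, since every column pair in those rows satisfies its condition.
If this logic matches your intent, the same pattern extends directly to your claim table. Below is a minimal sketch under these assumptions: your real columns run DX_00 through DX_19, the even-indexed columns are checked against dx_list1 and the odd-indexed ones against dx_list2, and you want OR within each pair and AND across pairs. Adjust the operators if your combination is different.
from functools import reduce
import operator
from pyspark.sql.functions import col

# Criteria lists from your post
dx_list1 = ['J984', 'J466', 'J754', 'J64', 'J465', 'J445']
dx_list2 = ['J454', 'J445', 'J443', 'J76', 'J487', 'J765']

# Assumed column layout: DX_00 .. DX_19, paired as (DX_00, DX_01), ..., (DX_18, DX_19)
dx_columns = [f'DX_{str(i).zfill(2)}' for i in range(0, 20)]
dx_pairs = [(dx_columns[i], dx_columns[i + 1]) for i in range(0, len(dx_columns) - 1, 2)]

# OR within each pair, AND across all pairs
pair_conditions = [
    col(dx_i).isin(dx_list1) | col(dx_j).isin(dx_list2)
    for dx_i, dx_j in dx_pairs
]
final_condition = reduce(operator.and_, pair_conditions)

test = (
    spark.table("claim")
    .select("ID", "MEMB_KEY", "PROV_KEY", "CL#", "DOS", *dx_columns)
    .filter(final_condition)
)
display(test)
This keeps the filter definition in one place: adding or removing DX columns only changes the range used to build dx_columns, not the filter expression itself.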

