Hi @abueno,
As I understand it, the logic you want to implement is:
1. For every consecutive, non-overlapping pair of columns (DX_00, DX_01), (DX_02, DX_03), ..., (DX_10, DX_11):
- First column (DX_i): checked against dx_list1
- Second column (DX_{i+1}): checked against dx_list2
2. The condition for each pair is:
- col('DX_i').isin(dx_list1) OR col('DX_{i+1}').isin(dx_list2)
3. All pairwise conditions are combined using AND.
In other words, you want to keep only the rows where, for every pair of columns, at least one of the two columns matches its corresponding list; a row that fails any single pair is excluded.
If the logic is different, please clarify. The example below is based on this logic.
# Sample data. Records 1-5 satisfy every pair condition and should be kept;
# records 6-10 each fail at least one pair and should be filtered out.
data = [
    # ID, MEMB_KEY, PROV_KEY, CL#, DOS, DX_00 to DX_11
    # Records that should be included
    (1, 'M001', 'P001', 'CL001', '2022-01-01',
     'J984', 'A123', 'J984', 'A123', 'J984', 'A123',
     'J984', 'A123', 'J984', 'A123', 'J984', 'A123'),
    (2, 'M002', 'P002', 'CL002', '2022-01-02',
     'A123', 'J445', 'A123', 'J445', 'A123', 'J445',
     'A123', 'J445', 'A123', 'J445', 'A123', 'J445'),
    (3, 'M003', 'P003', 'CL003', '2022-01-03',
     'J984', 'J454', 'J984', 'J454', 'J984', 'J454',
     'J984', 'J454', 'J984', 'J454', 'J984', 'J454'),
    (4, 'M004', 'P004', 'CL004', '2022-01-04',
     'J465', 'J76', 'J465', 'J76', 'J465', 'J76',
     'J465', 'J76', 'J465', 'J76', 'J465', 'J76'),
    (5, 'M005', 'P005', 'CL005', '2022-01-05',
     'J754', 'J765', 'J754', 'J765', 'J754', 'J765',
     'J754', 'J765', 'J754', 'J765', 'J754', 'J765'),
    # Records that should be excluded
    (6, 'M006', 'P006', 'CL006', '2022-01-06',
     'A123', 'A123', 'A123', 'A123', 'A123', 'A123',
     'A123', 'A123', 'A123', 'A123', 'A123', 'A123'),
    (7, 'M007', 'P007', 'CL007', '2022-01-07',
     'J64', 'J487', 'A123', 'A123', 'A123', 'A123',
     'A123', 'A123', 'A123', 'A123', 'A123', 'A123'),
    (8, 'M008', 'P008', 'CL008', '2022-01-08',
     'A123', 'A123', 'J64', 'A123', 'A123', 'A123',
     'A123', 'A123', 'A123', 'A123', 'A123', 'A123'),
    (9, 'M009', 'P009', 'CL009', '2022-01-09',
     'J466', 'J443', 'J466', 'A123', 'A123', 'J443',
     'A123', 'A123', 'A123', 'A123', 'A123', 'A123'),
    (10, 'M010', 'P010', 'CL010', '2022-01-10',
     'J445', 'A123', 'A123', 'J445', 'A123', 'A123',
     'J445', 'A123', 'A123', 'A123', 'J445', 'A123'),
]

# Five key columns followed by the twelve diagnosis columns DX_00 .. DX_11.
columns = ["ID", "MEMB_KEY", "PROV_KEY", "CL#", "DOS"] + \
    [f"DX_{i:02d}" for i in range(12)]

# Assumes an active SparkSession is bound to `spark` (e.g. a notebook/shell).
df = spark.createDataFrame(data, columns)
# Codes accepted in the first column of a pair (dx_list1) and in the
# second column of a pair (dx_list2).
dx_list1 = ['J984', 'J466', 'J754', 'J64', 'J465', 'J445']
dx_list2 = ['J454', 'J445', 'J443', 'J76', 'J487', 'J765']

# Diagnosis column names, zero-padded to two digits: DX_00 .. DX_11.
dx_columns = [f'DX_{n:02d}' for n in range(12)]

# Non-overlapping adjacent pairs (DX_00, DX_01), (DX_02, DX_03), ..., (DX_10, DX_11):
# even-indexed names are the first member of each pair, odd-indexed the second.
dx_pairs = list(zip(dx_columns[0::2], dx_columns[1::2]))
from pyspark.sql.functions import col
from functools import reduce
import operator
# One condition per column pair: the pair passes if its first column is in
# dx_list1 OR its second column is in dx_list2.
pair_conditions = [
    col(dx_i).isin(dx_list1) | col(dx_j).isin(dx_list2)
    for dx_i, dx_j in dx_pairs
]

# A row is kept only if every pair passes, so AND all pair conditions together.
# Note: isin() on a NULL column yields NULL; when neither column of a pair
# matches, the pair evaluates to NULL/false and filter() drops the row.
final_condition = reduce(operator.and_, pair_conditions)
filtered_df = df.filter(final_condition)

# Show the surviving rows with their DX columns for inspection.
filtered_df.select("ID", *dx_columns).show(truncate=False)