Matching SAS PROC SURVEYMEANS for quantiles in Databricks

Jake3
New Contributor III

Hi, I currently have the following code in Databricks that I am using to calculate survey estimates and quantiles. I want to match (or get as close as possible to) the SAS results from PROC SURVEYMEANS for quantiles (I am already able to match proportions). My current Databricks code gives standard errors that are lower than those from SAS (148 compared to 298, for example, for an estimate of 64300). I was hoping someone could show me how to change my code to better replicate the SAS results.

def _sas_round(value, unit):
    """Round *value* to the nearest multiple of *unit*, halves away from zero.

    Python's ``round`` rounds halves to even; the SAS macro this function
    mirrors formats with ``round(x, 100)``, which rounds halves away from zero.
    """
    # "+ 0.0" turns a possible -0.0 into 0.0 so it never formats as "-0".
    return float(np.sign(value) * np.floor(np.abs(value) / unit + 0.5) * unit) + 0.0


def _stratified_variance(u, strata_index, fpc_all):
    """Sum over strata of ``(1 - f_h) * n_h * s_h^2`` for the linearised values *u*.

    ``strata_index`` maps each stratum label to the row positions in that
    stratum; ``fpc_all`` holds the stratum population size per row (``inf``
    meaning no finite population correction). Strata with a single row are
    skipped because their sample variance is undefined.
    """
    total = 0.0
    for idx in strata_index.values():
        n_h = len(idx)
        if n_h <= 1:
            continue
        N_h = fpc_all[idx[0]]
        f_h = n_h / N_h if np.isfinite(N_h) else 0.0
        total += (1 - f_h) * n_h * np.var(u[idx], ddof=1)
    return total


def taylor_row_percent5(
    df: Union[pd.DataFrame, SparkDataFrame],
    row_domains: List[str],
    measure: str,
    exclusions: Dict[str, List],
    weight_col: Optional[str] = None,
    strata_col: Optional[str] = None,
    fpc: Optional[Union[Dict[object, float], str]] = None,
    p_01_adjust: Optional[float] = 1.0,
    zcritical: float = 1.96,
    quantiles: Optional[Union[List[float], float]] = None,
    output_spark: bool = False,
    output_delta_table: Optional[str] = None,
):
    """Estimate weighted row percents or quantiles per domain with stratified SEs.

    Rows whose value in any ``exclusions`` column matches a listed value are
    flagged out of scope; they stay in the data (contributing zeros to the
    variance) but are excluded from every estimate.

    * ``quantiles is None``: for every combination of in-scope ``row_domains``
      levels and every in-scope level of ``measure``, emit the weighted row
      percent, its Taylor-linearised standard error, an optional adjusted
      estimate when the proportion is exactly 0 or 1, and formatted /
      suppressed string columns (``percent_new`` etc.).
    * ``quantiles`` given: for every domain combination and every requested
      quantile, emit one row with the weighted quantile, its standard error
      and formatted / suppressed string columns (``estimate_new`` etc.). The
      requested quantile is recorded in the ``quantile`` column. Empty domains
      produce a row of NaN estimates with ``"na"`` formatted values.

    Args:
        df: Input data; passed through ``_to_pandas`` before use.
        row_domains: Columns whose level combinations define the domains.
        measure: Analysis column (category for percents, numeric for quantiles).
        exclusions: Mapping column -> values that put a row out of scope.
            A NaN entry excludes rows where the column is missing.
        weight_col: Weight column; ``None`` means unit weights.
        strata_col: Stratum column; ``None`` means a single stratum and
            disables the finite population correction.
        fpc: Stratum population sizes, either a mapping stratum -> size or
            the name of a column holding the size; ``None`` means no
            finite population correction.
        p_01_adjust: Multiplier inside the adjusted standard error used when
            a proportion is exactly 0 or 1; ``None`` disables the adjustment.
        zcritical: Multiplier applied to the standard error to get the
            margin of error.
        quantiles: A single quantile or a list of quantiles; ``None`` selects
            row-percent mode.
        output_spark: Return a Spark DataFrame instead of pandas.
        output_delta_table: If set, the result is also written (overwrite) to
            this Delta table and a Spark DataFrame is returned.

    Returns:
        A pandas DataFrame, or a Spark DataFrame when ``output_spark`` or
        ``output_delta_table`` is set.
    """
    # Work on a copy so the helper columns added below never leak into the
    # caller's frame. (_to_pandas is defined elsewhere; presumably it returns
    # a pandas DataFrame - confirm.)
    df = _to_pandas(df).copy()

    if isinstance(quantiles, (float, int)):
        quantiles = [quantiles]

    # ---------------------------- Weight & Strata ----------------------------
    if weight_col is None:
        df["_w"] = 1.0
        weight_col = "_w"
    if strata_col is None:
        df["_strata"] = "all"
        strata_col = "_strata"
        fpc = None
    if fpc is None:
        df["_N_h"] = np.inf
        fpc_col = "_N_h"
    elif isinstance(fpc, dict):
        df["_N_h"] = df[strata_col].map(fpc)
        fpc_col = "_N_h"
    else:
        fpc_col = fpc

    # ---------------------------- Scope flag ----------------------------
    df["scope_fg"] = 1
    for var, excl_vals in exclusions.items():
        for val in excl_vals:
            if pd.isna(val):
                df.loc[df[var].isna(), "scope_fg"] = 0
            else:
                df.loc[df[var] == val, "scope_fg"] = 0

    # ---------------------------- Precompute arrays ----------------------------
    y_all = df[measure].to_numpy(dtype=float)
    w_all = df[weight_col].to_numpy(dtype=float)
    strata_all = df[strata_col].to_numpy()
    fpc_all = df[fpc_col].to_numpy()

    strata_index = {h: np.where(strata_all == h)[0] for h in np.unique(strata_all)}

    # ---------------------------- Domain & measure levels ----------------------------
    in_scope = df["scope_fg"] == 1
    domain_levels = {c: df.loc[in_scope, c].dropna().unique() for c in row_domains}
    # Measure levels are only needed for the row-percent output.
    measure_levels = (
        df.loc[in_scope, measure].dropna().unique() if quantiles is None else ()
    )
    results = []

    # ---------------------------- Main loop ----------------------------
    for dom_vals in product(*domain_levels.values()):

        mask_row = df["scope_fg"] == 1
        for c, v in zip(row_domains, dom_vals):
            mask_row &= (df[c] == v)
        mask_row_np = mask_row.to_numpy()
        n_row = int(mask_row_np.sum())

        # ------------------ QUANTILE OUTPUT ------------------
        if quantiles is not None:
            if n_row > 0:
                y = y_all[mask_row_np]
                w = w_all[mask_row_np]
                W_dom = w.sum()

            for q in quantiles:
                if n_row == 0:
                    q_hat = se_q = np.nan
                else:
                    # Helpers defined elsewhere: presumably the weighted q-th
                    # quantile and a finite-difference estimate of the weighted
                    # density at that point - confirm against their definitions.
                    q_hat = weighted_quantile(y, w, q)
                    f_hat = weighted_density_finite_diff(y, w, q_hat)

                    if np.isfinite(f_hat) and f_hat > 0:
                        # Linearised value of the quantile; rows outside the
                        # domain stay in their stratum with value zero.
                        # NOTE(review): this is a density-based linearisation.
                        # SAS PROC SURVEYMEANS documents a Woodruff-style
                        # method (invert a confidence interval on the estimated
                        # CDF) for quantile standard errors, which is a likely
                        # source of the SE gap versus SAS - confirm against the
                        # SAS/STAT documentation before relying on a match.
                        below = (y_all <= q_hat).astype(float)
                        u = np.zeros(len(df))
                        u[mask_row_np] = (w / W_dom) * (q - below[mask_row_np]) / f_hat
                        var_q = _stratified_variance(u, strata_index, fpc_all)
                        se_q = np.sqrt(var_q) if var_q > 0 else np.nan
                    else:
                        # A zero or undefined density makes the SE undefined.
                        se_q = np.nan

                moe = zcritical * se_q
                rse = se_q / q_hat if q_hat != 0 else np.nan

                # Suppression & formatting rules (strings throughout, so the
                # column never mixes numbers with "na"/"np").
                if n_row == 0:
                    estimate_new = moe_new = lower_new = upper_new = rse_new = "na"
                elif n_row <= 5 or np.isnan(se_q):
                    estimate_new = moe_new = lower_new = upper_new = rse_new = "np"
                else:
                    # Nearest 100, no decimals.
                    estimate_new = f"{_sas_round(q_hat, 100):.0f}"
                    moe_new = f"{_sas_round(moe, 100):.0f}"
                    lower_new = f"{_sas_round(max(q_hat - moe, 0), 100):.0f}"
                    upper_new = f"{_sas_round(q_hat + moe, 100):.0f}"
                    rse_new = f"{rse:.2f}"
                    # Flag estimates with a relative standard error >= 0.25.
                    if rse >= 0.25:
                        estimate_new += "*"

                # A fresh dict per quantile: reusing one dict would let each
                # quantile overwrite the previous one.
                row = dict(zip(row_domains, dom_vals))
                row.update({
                    "domain_frequency": n_row,
                    "quantile": q,
                    "Estimate": q_hat,
                    "StdErr": se_q,
                    "moe": moe,
                    "lower": q_hat - moe,
                    "upper": q_hat + moe,
                    "rse": rse,
                    "estimate_new": estimate_new,
                    "moe_new": moe_new,
                    "lower_new": lower_new,
                    "upper_new": upper_new,
                    "rse_new": rse_new,
                })
                results.append(row)
            continue  # Skip row percent loop

        # ------------------ ROW PERCENT OUTPUT ------------------
        R = mask_row_np.astype(float)
        W_row = w_all[mask_row_np].sum() if n_row > 0 else np.nan

        for measure_val in measure_levels:

            mask_cell = mask_row_np & (y_all == measure_val)
            n_cell = int(mask_cell.sum())

            if n_row == 0:
                p_hat = se_p = np.nan
            else:
                W_cell = w_all[mask_cell].sum()
                p_hat = W_cell / W_row if W_row > 0 else np.nan

                # Linearised value of the ratio W_cell / W_row.
                C = mask_cell.astype(float)
                u = (w_all / W_row) * (C - p_hat * R)
                var_p = _stratified_variance(u, strata_index, fpc_all)
                se_p = np.sqrt(var_p) if var_p > 0 else np.nan

            # Adjustment for proportions of exactly 0 or 1: add 2 to the cell
            # count and 4 to the domain count before computing the SE.
            adj_p = p_hat
            adj_se = se_p
            if (
                p_01_adjust is not None
                and not np.isnan(p_hat)
                and p_hat in (0.0, 1.0)
                and n_row > 0
            ):
                adj_p = (n_cell + 2) / (n_row + 4)
                adj_se = np.sqrt(adj_p * (1 - adj_p) / (n_row + 4) * p_01_adjust)

            rowpercent = p_hat * 100 if not np.isnan(p_hat) else np.nan
            rowstderr = se_p * 100 if not np.isnan(se_p) else np.nan
            adj_rowpercent = adj_p * 100 if not np.isnan(adj_p) else np.nan
            adj_rowstderr = adj_se * 100 if not np.isnan(adj_se) else np.nan
            moe = zcritical * adj_rowstderr if not np.isnan(adj_rowstderr) else np.nan
            lower = rowpercent - moe if not np.isnan(moe) else np.nan
            upper = rowpercent + moe if not np.isnan(moe) else np.nan
            if adj_rowpercent != 0 and not np.isnan(adj_rowpercent):
                rse = adj_rowstderr / adj_rowpercent * 100
            else:
                rse = np.nan

            percent_new = f"{rowpercent:4.1f}" if not np.isnan(rowpercent) else ""
            moe_new = f"{moe:4.1f}" if not np.isnan(moe) else ""
            lower_new = f"{max(lower,0):4.1f}" if not np.isnan(lower) else ""
            upper_new = f"{min(upper,100):4.1f}" if not np.isnan(upper) else ""
            rse_new = f"{rse:4.1f}" if not np.isnan(rse) else ""
            if not np.isnan(moe) and moe >= 10:
                percent_new += "*"
            if n_row == 0:
                percent_new = moe_new = lower_new = upper_new = rse_new = "na"
            elif n_row in range(1, 6):
                percent_new = moe_new = lower_new = upper_new = rse_new = "np"

            row = dict(zip(row_domains, dom_vals))
            row[measure] = measure_val
            row.update({
                "domain_frequency": n_row,
                "Frequency": n_cell,
                "RowPercent": rowpercent,
                "RowStdErr": rowstderr,
                "adj_rowpercent": adj_rowpercent,
                "adj_rowstderr": adj_rowstderr,
                "moe": moe,
                "lower": lower,
                "upper": upper,
                "rse": rse,
                "percent_new": percent_new,
                "moe_new": moe_new,
                "lower_new": lower_new,
                "upper_new": upper_new,
                "rse_new": rse_new,
            })
            results.append(row)

    result_pdf = pd.DataFrame(results)

    if not (output_spark or output_delta_table):
        return result_pdf

    # Build the Spark frame once and reuse it for both the write and the return.
    result_sdf = spark.createDataFrame(result_pdf)

    if output_delta_table:
        (
            result_sdf
            .write
            .format("delta")
            .mode("overwrite")
            .option("overwriteSchema", "true")
            .saveAsTable(output_delta_table)
        )

    return result_sdf

This is the SAS macro I am trying to replicate the results of:
/*
 * survey_estimates: weighted survey estimates by domain, with margins of
 * error, relative standard errors and formatted/suppressed text columns.
 *
 * Parameters:
 *   survey_data      input data set (read by the "data analysis" step)
 *   out_data         name of the output data set written by the clean-up step
 *   analysis_var     analysis variable (table / var statement)
 *   domain_varlist   space-delimited list of domain variables; may be blank
 *   scope_defn       condition; observations satisfying it get scope = 1
 *   stratum_var      variable used in the strata statement
 *   weight_var       variable used in the weight statement
 *   popcount_data    if non-blank, passed as total= to the survey procedure
 *   quantile_est     if non-blank, quantile(s) estimated with proc surveymeans;
 *                    if blank, row percents are estimated with proc surveyfreq
 *   rogs_suppression T selects the RoGS suppression rules (proportion branch)
 *   p_01_adjust      multiplier inside the adjusted standard error used when a
 *                    row percent is 0 or 100; blank disables the adjustment
 *   proportion       T formats results as proportions (7.3) instead of
 *                    percents (7.1)
 *   critical_value   multiplier applied to the standard error to get moe
 *   format_list      if non-blank, applied in a format statement on output
 *
 * Work data sets named analysis, out_data, outdata_reverse and dom_freq are
 * created and then deleted by the macro.
 */
%macro survey_estimates(survey_data, out_data, analysis_var, domain_varlist, scope_defn, stratum_var, weight_var, popcount_data = , quantile_est = , rogs_suppression = F, p_01_adjust = 1.0, proportion = , critical_value = 1.96, format_list = );
/*** create a macro variable called domain_varlist_ast that includes domain variable names separated by asterixes (for use in table statement) ***/
%if %sysevalf(%superq(domain_varlist) ne , boolean) %then
%do;
 
data _null_;
space_delim = "&domain_varlist";
num_domain_vars = countw("&domain_varlist", ' ');
call symputx("num_domain_vars", num_domain_vars); /* create macro variable called num_domain_vars that contains number of domain variables */
run;
 
data _null_;
length domain_varlist_ast $ 200;
domain_varlist_ast = "";
 
do i = 1 to &num_domain_vars;
domain_i = scan("&domain_varlist", i, ' ');
domain_varlist_ast = cats(domain_varlist_ast, cats(domain_i, " * "));
end;
 
call symput("domain_varlist_ast", domain_varlist_ast); /* create macro variable containing string required to produce estimation table in proc surveyfreq */
run;
 
%end;
%else %let domain_varlist_ast =;
 
/* flag in-scope observations; out-of-scope rows are kept so the survey procedures still see them */
data analysis;
set &survey_data;
 
if &scope_defn then
scope = 1;
else scope = 0;
run;
 
%if %sysevalf(%superq(quantile_est) eq , boolean) %then
%do;
/* estimate proportion */
/*** calculate estimates ***/
%if %sysevalf(%superq(popcount_data) ne , boolean) %then
%do;
 
proc surveyfreq data = analysis total = &popcount_data missing; /* finite population correction */
%end;
%else
%do;
 
proc surveyfreq data = analysis missing; /* no finite population correction */
%end;
 
/* scope is crossed first so the crosstabs output can be filtered to scope eq 1 */
table scope * &domain_varlist_ast &analysis_var / row /*nototal*/;
strata &stratum_var;
weight &weight_var;
ods output crosstabs = out_data (where = (scope eq 1));
run;
 
/*** add column with total number of respondents in each domain ***/
/* reverse order of observations */
data outdata_reverse;
do i = out_nobs to 1 by -1;
set out_data nobs = out_nobs point = i;
output;
end;
 
stop;
run;
 
/* create new variable with number of respondents */
/* NOTE(review): presumably _skipline eq "1" marks the total row that closes each domain in the crosstabs output, so in reversed order its frequency is retained onto the rows of that domain - confirm */
data outdata_reverse;
retain domain_frequency 0;
set outdata_reverse;
 
if _skipline eq "1" then
domain_frequency = frequency;
run;
 
/* reverse back order of observations */
data out_data;
do i = out_nobs to 1 by -1;
set outdata_reverse nobs = out_nobs point = i;
output;
end;
 
stop;
run;
 
/*** correct standard errors for estimates of proportion equal to 0/1 ***/
data out_data;
set out_data;
adj_rowpercent = rowpercent;
adj_rowstderr = rowstderr;
 
if frequency eq 0 and domain_frequency gt 0 then rowpercent = 0; /* "true" zero */
 
%if %sysevalf(%superq(p_01_adjust) ne , boolean) %then
%do;
if rowpercent in (0, 100) then
do;
adj_p = (frequency + 2) / (domain_frequency + 4); /* adjusted estimate of proportion */
adj_rowpercent = adj_p * 100; /* adjusted percent */
adj_rowstderr = sqrt( adj_p * (1 - adj_p) / (domain_frequency + 4) * &p_01_adjust) * 100; /* adjusted standard error */
end;
 
drop adj_p;
%end;
run;
 
/*** apply formats and suppression rules to estimates ***/
data out_data;
set out_data;
 
if not missing(rowpercent) then
do;
moe = &critical_value * adj_rowstderr; /* use adjusted values if requested */
lower = rowpercent - moe;
upper = rowpercent + moe;
rse = adj_rowstderr / adj_rowpercent * 100; /* use adjusted values if requested, and express as a percent */
end;
else
do;
moe = .;
lower = .;
upper = .;
rse = .;
end;
 
%if %sysevalf(%superq(proportion) eq T, boolean) %then
%do;
percent_new = put(rowpercent/100, 7.3); /* uses unadjusted percent for estimates */
moe_new = put(moe/100, 7.3);
lower_new = put(max(lower, 0)/100, 7.3);
upper_new = put(min(upper, 100)/100, 7.3);
rse_new = put(rse/100, 7.3);
%end;
%else
%do;
percent_new = put(rowpercent, 7.1); /* uses unadjusted percent for estimates */
moe_new = put(moe, 7.1);
lower_new = put(max(lower, 0), 7.1);
upper_new = put(min(upper, 100), 7.1);
rse_new = put(rse, 7.1);
%end;
 
%if %sysevalf(&rogs_suppression eq T, boolean) %then
%do;
/* suppression rules used for RoGS */
%if %sysevalf(%superq(p_01_adjust) eq , boolean) %then
%do;
if rowpercent in (0, 100) then
do;
moe_new = "0";
lower_new = percent_new;
upper_new = percent_new;
rse_new = "0";
end;
%end;
 
if domain_frequency eq 0 then
do;
percent_new = "na";
moe_new = "na";
lower_new = "na";
upper_new = "na";
rse_new = "na";
end;
else if domain_frequency in (1:5) | missing(adj_rowstderr) then
do;
 
percent_new = "np";
moe_new = "np";
lower_new = "np";
upper_new = "np";
rse_new = "np";
end;
%end;
%else
%do;
/* suppression rules used for SOS pubs and VA */
%if %sysevalf(%superq(p_01_adjust) eq , boolean) %then
%do;
if rowpercent in (0, 100) then
do;
moe_new = "na";
lower_new = "na";
upper_new = "na";
end;
%end;
 
if domain_frequency eq 0 then
do;
percent_new = "na";
moe_new = "na";
lower_new = "na";
upper_new = "na";
rse_new = "na";
end;
else if domain_frequency in (1:5) then
do;
percent_new = "np";
moe_new = "np";
lower_new = "np";
upper_new = "np";
rse_new = "np";
end;
else if moe ge 10 then
percent_new = cats(percent_new, "*");
%end;
run;
 
/*** clean-up ***/
data &out_data;
set out_data;
 
/* apply formats */
%if %sysevalf(%superq(format_list) ne , boolean) %then
%do;
format &format_list.;
%end;
 
/* trim leading blanks */
array trim_vars {*} percent_new moe_new lower_new upper_new rse_new;
 
do i = 1 to dim(trim_vars);
trim_vars{i} = left(trim_vars{i});
end;
 
/* NOTE(review): adj_p is listed here but is dropped in the 0/1 correction step when p_01_adjust is non-blank and never created otherwise - confirm whether it should be in this list */
keep &analysis_var &domain_varlist adj_p adj_rowpercent adj_rowstderr rowpercent percent_new rse rse_new rowstderr moe moe_new lower lower_new upper upper_new frequency domain_frequency;
run;
 
proc datasets;
delete analysis out_data outdata_reverse;
quit;
 
%end;
%else
%do;
/* estimate median/quantile */
/*** calculate estimates ***/
%if %sysevalf(%superq(popcount_data) ne , boolean) %then
%do;
 
proc surveymeans data = analysis total = &popcount_data quantile = (&quantile_est.) missing nobs;
%end;
%else
%do;
 
proc surveymeans data = analysis quantile = (&quantile_est.) missing nobs;
%end;
 
/* scope is crossed last so the domain quantile output can be filtered to scope eq 1 */
domain &domain_varlist_ast scope;
var &analysis_var;
strata &stratum_var;
weight &weight_var;
ods output domainquantiles = out_data (where = (scope eq 1));
run;
 
/* add column with total number of respondents in each domain */
proc freq data = analysis noprint;
tables &domain_varlist_ast scope/ out = dom_freq (where = (scope eq 1));
run;
 
proc sort data = out_data;
by &domain_varlist scope;
 
proc sort data = dom_freq;
by &domain_varlist scope;
run;
 
data out_data;
merge out_data dom_freq (drop = percent rename = (count = domain_frequency));
by groupformat &domain_varlist scope;
run;
 
/*** apply formats and suppression rules to estimates ***/
data out_data;
set out_data;
 
if not missing(estimate) then
do;
moe = &critical_value * stderr;
lower = estimate - moe;
upper = estimate + moe;
rse = stderr / estimate; /* expressed as a ratio here, not as a percent */
end;
else
do;
moe = .;
lower = .;
upper = .;
rse = .;
end;
 
/* formatted values are rounded to the nearest 100 */
estimate_new = put(round(estimate, 100), 12.0);
moe_new = put(round(moe, 100), 12.0);
upper_new = put(round(upper, 100), 12.0);
lower_new = put(round(max(lower, 0), 100), 12.0);
rse_new = put(rse, 12.2);
 
if domain_frequency eq 0 then
do;
estimate_new = "na";
moe_new = "na";
lower_new = "na";
upper_new = "na";
rse_new = "na";
end;
else if domain_frequency in (1:5) | missing(stderr)then
do;
 
estimate_new = "np";
moe_new = "np";
lower_new = "np";
upper_new = "np";
rse_new = "np";
end;
else if rse ge 0.25 then
estimate_new = cats(estimate_new, "*");
run;
 
/*** clean-up ***/
data &out_data;
set out_data;
 
/* apply formats */
%if %sysevalf(%superq(format_list) ne , boolean) %then
%do;
format &format_list.;
%end;
 
/* trim leading blanks */
array trim_vars {*} estimate_new moe_new lower_new upper_new rse_new;
 
do i = 1 to dim(trim_vars);
trim_vars{i} = left(trim_vars{i});
end;
 
/* NOTE(review): adj_p, adj_rowpercent and adj_rowstderr are never created in the quantile branch - confirm whether they belong in this list */
keep &domain_varlist adj_p adj_rowpercent adj_rowstderr estimate estimate_new rse rse_new stderr moe moe_new lower lower_new upper upper_new domain_frequency;
run;
 
proc datasets;
delete analysis out_data dom_freq;
quit;
 
%end;
%mend survey_estimates;