optimizing my databricks code

Jake3
New Contributor III

I have the following code in Databricks running on serverless compute, and I want to know how to make it more efficient and faster without changing the results (the standard errors change slightly whenever I try to improve it):

# Databricks Serverless - Taylor Row Percent
# ============================================================
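# Runs on Databricks Serverless. A Spark input is collected into local
# memory with .toPandas() and all estimation is done in plain pandas/NumPy
# (no pandas-on-Spark or pandas UDFs are involved).
# The core estimation logic is kept in pandas because Taylor
# linearisation and survey-variance math has no native Spark equivalent.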
#
# Usage
# -----
# Input  : pass a pandas DataFrame directly, OR a Spark DataFrame
#          (it will be converted automatically via .toPandas()).
# Output : pandas DataFrame (call display() or wrap in spark.createDataFrame()
#          if you need a Spark / Delta result).
# ============================================================

from __future__ import annotations

import numpy as np
import pandas as pd
from itertools import product
from typing import Dict, List, Optional, Union

# Databricks / PySpark imports (available on every Serverless cluster)
from pyspark.sql import DataFrame as SparkDataFrame


# ------------------------------------------------------------------
# Helper: accept either a Spark or pandas DataFrame
# ------------------------------------------------------------------
def _to_pandas(df: Union[pd.DataFrame, SparkDataFrame]) -> pd.DataFrame:
    """Return a pandas DataFrame that the caller may modify freely.

    A Spark DataFrame is materialised with ``toPandas()``. Any other input
    is duplicated with its own ``copy()`` method, so the object passed in
    by the caller is left untouched.
    """
    is_spark_input = isinstance(df, SparkDataFrame)
    pandas_df = df.toPandas() if is_spark_input else df.copy()
    return pandas_df


# ------------------------------------------------------------------
# Main function
# ------------------------------------------------------------------
def taylor_row_percent(
    df: Union[pd.DataFrame, SparkDataFrame],
    row_domains: List[str],
    measure: str,
    exclusions: Dict[str, List],
    weight_col: Optional[str] = None,
    strata_col: Optional[str] = None,
    fpc: Optional[Union[Dict[object, float], str]] = None,
    p_01_adjust: Optional[float] = 1.0,
    zcritical: float = 1.96,
    output_spark: bool = False,          # set True to get a Spark DataFrame back
    output_delta_table: Optional[str] = None,  # e.g. "catalog.schema.table"
) -> Union[pd.DataFrame, SparkDataFrame]:
    """
    Taylor-linearisation row-percent estimator.

    For every combination of in-scope ``row_domains`` levels and in-scope
    ``measure`` levels, estimates the weighted share of the domain that
    falls in that measure level, with a linearised standard error summed
    over strata.

    Parameters
    ----------
    df : pandas or Spark DataFrame
        Input records. Converted/copied by ``_to_pandas`` first, so the
        caller's object is not modified.
    row_domains : columns that define the row grouping
    measure : column whose value distribution is estimated
    exclusions : {col: [values_to_exclude]}  – NaN supported
        Records matching any listed value get ``scope_fg = 0`` and are left
        out of the domain/cell counts and weight totals. They still count
        towards each stratum's sample size in the variance.
    weight_col : survey weight column (None → unweighted, weight 1.0)
    strata_col : stratification column (None → single stratum; ``fpc`` is
        then ignored)
    fpc : finite-population correction – dict {stratum: N} or column name.
        Non-finite N (including strata missing from the dict) means no
        correction for that stratum.
    p_01_adjust : scaling factor applied to the adjusted variance that is
        used when the estimated proportion is exactly 0 or 1
        (adjusted p = (n_cell + 2) / (n_row + 4)); None disables the
        adjustment.
    zcritical : z-value for confidence interval (default 1.96)
    output_spark : if True, return a Spark DataFrame instead of pandas
    output_delta_table : if set, write result to this Delta table
        (overwrite) and return the result as a Spark DataFrame

    Returns
    -------
    pandas DataFrame, or Spark DataFrame when ``output_spark`` or
    ``output_delta_table`` is set. One row per domain × measure
    combination with the numeric estimates and the formatted
    (``*_new``) strings.
    """

    # ------------------------------------------------------------------
    # 0. Normalise input
    # ------------------------------------------------------------------
    df = _to_pandas(df)

    # ------------------------------------------------------------------
    # 1. Weight handling
    # ------------------------------------------------------------------
    if weight_col is None:
        df["_w"] = 1.0
        weight_col = "_w"

    # ------------------------------------------------------------------
    # 2. Strata handling
    # ------------------------------------------------------------------
    if strata_col is None:
        df["_strata"] = "all"
        strata_col = "_strata"
        fpc = None

    # ------------------------------------------------------------------
    # 3. FPC handling
    # ------------------------------------------------------------------
    if fpc is None:
        df["_N_h"] = np.inf
        fpc_col = "_N_h"
    elif isinstance(fpc, dict):
        df["_N_h"] = df[strata_col].map(fpc)
        fpc_col = "_N_h"
    else:
        fpc_col = fpc

    # ------------------------------------------------------------------
    # 4. Scope flag  (mirrors SAS exclusion logic)
    # ------------------------------------------------------------------
    df["scope_fg"] = 1
    for var, excl_vals in exclusions.items():
        for val in excl_vals:
            if pd.isna(val):
                df.loc[df[var].isna(), "scope_fg"] = 0
            else:
                df.loc[df[var] == val, "scope_fg"] = 0

    in_scope = df["scope_fg"] == 1

    # ------------------------------------------------------------------
    # 5. Domain & measure levels  (in-scope records only)
    # ------------------------------------------------------------------
    domain_levels = {
        c: df.loc[in_scope, c].dropna().unique()
        for c in row_domains
    }
    measure_levels = df.loc[in_scope, measure].dropna().unique()

    # ------------------------------------------------------------------
    # 6. Loop-invariant pieces, computed once instead of once per cell
    # ------------------------------------------------------------------
    weights = df[weight_col]

    # One boolean mask per measure level; reused for every domain.
    measure_values = df[measure]
    measure_masks = [(val, measure_values == val) for val in measure_levels]

    # Stratum membership never changes between cells, so record each
    # stratum's row positions and its (1 - f_h) * n_h factor up front.
    # Grouping by the strata column keeps the same stratum order (and the
    # same dropping of NaN strata) as grouping the full frame by that
    # column, so the variance terms are accumulated in the same order.
    positions = pd.Series(np.arange(len(df)), index=df.index)
    fpc_values = df[fpc_col]
    strata_plan = []
    for _, stratum_pos in positions.groupby(df[strata_col]):
        idx = stratum_pos.to_numpy()
        n_h = len(idx)
        if n_h <= 1:
            # A single-record stratum contributes nothing to the variance.
            continue
        N_h = fpc_values.iloc[idx[0]]
        f_h = n_h / N_h if np.isfinite(N_h) else 0.0
        strata_plan.append((idx, (1 - f_h) * n_h))

    results = []

    # ------------------------------------------------------------------
    # 7. Main estimation loop
    #    Outer loop: domain combinations; inner loop: measure levels.
    #    This yields rows in the same order as
    #    product(*domain_levels, measure_levels).
    # ------------------------------------------------------------------
    for dom_vals in product(*domain_levels.values()):

        # Domain mask (scope == 1 only); shared by all measure levels.
        mask_row = in_scope
        for c, v in zip(row_domains, dom_vals):
            mask_row = mask_row & (df[c] == v)

        n_row = int(mask_row.sum())

        if n_row > 0:
            W_row = weights[mask_row].sum()
            row_ind = mask_row.astype(float)
            scaled_w = weights / W_row

        for measure_val, is_measure in measure_masks:

            mask_cell = mask_row & is_measure
            n_cell = int(mask_cell.sum())

            # ---- domain empty ------------------------------------
            if n_row == 0:
                p_hat = np.nan
                se_p  = np.nan

            else:
                W_cell = weights[mask_cell].sum()
                p_hat = W_cell / W_row if W_row > 0 else np.nan

                # Taylor linearisation: u = (w / W_row) * (C - p_hat * R)
                u = scaled_w * (mask_cell.astype(float) - p_hat * row_ind)

                # Series.var on each stratum's rows keeps the NaN skipping
                # and summation of a direct per-stratum variance; only the
                # repeated frame copy and regrouping are avoided.
                var_p = 0.0
                for idx, scale in strata_plan:
                    var_p += scale * u.iloc[idx].var(ddof=1)

                se_p = np.sqrt(var_p) if var_p > 0 else np.nan

            # ---- Agresti-Coull  (SAS exact logic) ----------------
            adj_p  = p_hat
            adj_se = se_p

            if (
                p_01_adjust is not None
                and not np.isnan(p_hat)
                and p_hat in (0.0, 1.0)
                and n_row > 0
            ):
                adj_p  = (n_cell + 2) / (n_row + 4)
                adj_se = np.sqrt(
                    adj_p * (1 - adj_p) / (n_row + 4) * p_01_adjust
                )

            # ---- Final statistics --------------------------------
            rowpercent     = p_hat  * 100 if not np.isnan(p_hat)  else np.nan
            rowstderr      = se_p   * 100 if not np.isnan(se_p)   else np.nan
            adj_rowpercent = adj_p  * 100 if not np.isnan(adj_p)  else np.nan
            adj_rowstderr  = adj_se * 100 if not np.isnan(adj_se) else np.nan

            # The interval is centred on the unadjusted percent but uses
            # the adjusted standard error for its width.
            moe   = zcritical * adj_rowstderr if not np.isnan(adj_rowstderr) else np.nan
            lower = rowpercent - moe          if not np.isnan(moe)           else np.nan
            upper = rowpercent + moe          if not np.isnan(moe)           else np.nan

            if np.isnan(adj_rowpercent) or adj_rowpercent == 0:
                rse = np.nan
            else:
                rse = (adj_rowstderr / adj_rowpercent) * 100

            # ---- Formatting  (SAS style) -------------------------
            percent_new = f"{rowpercent:4.1f}" if not np.isnan(rowpercent) else ""
            moe_new     = f"{moe:4.1f}"        if not np.isnan(moe)        else ""
            lower_new   = f"{max(lower, 0):4.1f}"   if not np.isnan(lower) else ""
            upper_new   = f"{min(upper, 100):4.1f}" if not np.isnan(upper) else ""
            rse_new     = f"{rse:4.1f}"        if not np.isnan(rse)        else ""

            # Star rule
            if not np.isnan(moe) and moe >= 10:
                percent_new += "*"

            # ---- Suppression  (SOS rules) ------------------------
            if n_row == 0:
                percent_new = moe_new = lower_new = upper_new = rse_new = "na"
            elif 1 <= n_row <= 5:
                percent_new = moe_new = lower_new = upper_new = rse_new = "np"

            # ---- Store row ---------------------------------------
            row = dict(zip(row_domains, dom_vals))
            row[measure] = measure_val
            row.update({
                "domain_frequency": n_row,
                "Frequency":        n_cell,
                "RowPercent":       rowpercent,
                "RowStdErr":        rowstderr,
                "adj_rowpercent":   adj_rowpercent,
                "adj_rowstderr":    adj_rowstderr,
                "moe":              moe,
                "lower":            lower,
                "upper":            upper,
                "rse":              rse,
                "percent_new":      percent_new,
                "moe_new":          moe_new,
                "lower_new":        lower_new,
                "upper_new":        upper_new,
                "rse_new":          rse_new,
            })
            results.append(row)

    # ------------------------------------------------------------------
    # 8. Build output
    # ------------------------------------------------------------------
    result_pdf = pd.DataFrame(results)

    if not (output_spark or output_delta_table):
        return result_pdf

    # Convert once and reuse for both the Delta write and the return value.
    result_sdf = spark.createDataFrame(result_pdf)  # noqa: F821  # spark injected by Databricks

    # Optional: write to Delta table
    if output_delta_table:
        (
            result_sdf
                 .write
                 .format("delta")
                 .mode("overwrite")
                 .option("overwriteSchema", "true")
                 .saveAsTable(output_delta_table)
        )
        print(f"Results written to Delta table: {output_delta_table}")

    return result_sdf