hiryucodes
Databricks Employee
Databricks Employee

The only way I could make this work was to define the custom data source classes in the same file as the DLT pipeline, like this:

class MyDataSource(DataSource):
    """Custom Spark Python data source exposed under the short name ``mydatasource``.

    Once registered with ``spark.dataSource.register(MyDataSource)`` it can be
    read with ``spark.read.format("mydatasource")``.
    """

    @classmethod
    def name(cls):
        # Short name that spark.read.format(...) resolves to this class.
        short_name = "mydatasource"
        return short_name

    def schema(self):
        # `my_schema` is defined outside this snippet.
        return my_schema

    def reader(self, schema: StructType):
        # Forward the resolved schema and the caller-supplied options to the reader.
        options = self.options
        return MyDataSourceReader(schema, options)


class MyDataSourceReader(DataSourceReader):
    """Read records from a paginated JSON REST API.

    Recognised options (strings, as passed via ``spark.read.option``):
        api_key: sent as the ``X-Api-Key`` request header.
        timeout: per-request timeout in seconds (default ``600``).
        params:  JSON-encoded object of extra query parameters.

    Raises:
        ValueError: if ``params`` is not valid JSON or is not a JSON object.
    """

    # Records requested per page. A page of any other length ends pagination.
    PAGE_SIZE = 1000

    def __init__(self, schema: StructType, options: dict):
        self.schema = schema
        self.options = options
        self.api_key = self.options.get("api_key")
        # NOTE(review): placeholder endpoint (single slash after "https:") — replace with the real URL.
        self.base_url = "https:/my/api/url"
        self.timeout = int(self.options.get("timeout", 600))

        json_params = self.options.get("params")
        if json_params:
            try:
                self.params = json.loads(json_params)
            except json.JSONDecodeError as err:
                raise ValueError("Invalid JSON format for params") from err
            # Query parameters must be a key/value mapping. Pagination keys are added to it.
            if not isinstance(self.params, dict):
                raise ValueError("params must be a JSON object")
        else:
            self.params = {}

    def read(self, partition):
        """Yield one tuple per record, across every page returned by the API.

        Each record's values are emitted in the record's key order.
        NOTE(review): this assumes the API's key order matches the declared schema — confirm.
        """
        headers = {"Accept": "application/json", "X-Api-Key": self.api_key}

        # Work on a copy so pagination keys never leak into self.params between reads.
        params = dict(self.params)
        params["pageSize"] = self.PAGE_SIZE

        page = 1
        while True:
            params["page"] = page
            response = requests.get(self.base_url, headers=headers, params=params, timeout=self.timeout)
            response.raise_for_status()
            data = response.json()

            # Stream each page's rows instead of buffering every page in memory.
            for item in data:
                yield tuple(item.values())

            # Only a full page means more data may follow.
            if len(data) != params["pageSize"]:
                break
            page += 1

    def partitions(self):
        from pyspark.sql.datasource import InputPartition

        # A single partition: all pages are fetched sequentially by one task.
        return [InputPartition(0)]

@dlt.table(
    name="my_table",
    table_properties={"quality": "bronze"},
)
def dlt_ingestion():
    """Define the bronze ``my_table`` DLT table from the custom ``mydatasource`` reader.

    The API key comes from the ``API_KEY`` environment variable. The query
    parameters come from ``build_query_params()``, which is defined outside this
    snippet. Both are passed to the reader as string options; the parameters are
    JSON-encoded.
    """
    api_key = os.getenv("API_KEY")
    params = build_query_params()

    # Register the custom source so "mydatasource" resolves in this session.
    spark.dataSource.register(MyDataSource)
    return (
        spark.read.format("mydatasource")
        .option("api_key", api_key)
        .option("params", json.dumps(params))
        .load()
    )

This is not practical, because I also want to use the custom data source in other DLT pipelines, and I can't do that without duplicating the code in each one.