
FeatureEngineeringClient loses timestamp_keys after write_table

VincentP
New Contributor III

I am trying to use the FeatureEngineeringClient to set up a feature store table with a time series component. After creating the table with a timeseries column, the key exists, but it is removed as soon as data is written to the table. As a result, no point-in-time joins can be performed after adding data.

Minimal example:

from databricks.feature_engineering import FeatureEngineeringClient
from pyspark.sql import types

fe = FeatureEngineeringClient()
# table_name is assumed to be a full "catalog.schema.table" name;
# the actual name isn't shown here.

query_prd_schema = types.StructType([
    types.StructField("timestamp", types.TimestampType(), True),
    types.StructField("search_query", types.StringType(), True),
    types.StructField("prd_id", types.StringType(), True),
    types.StructField("start_agg_window", types.DateType(), True),
    types.StructField("end_agg_window", types.DateType(), True),
    types.StructField("a2c_rate", types.FloatType(), True),
])

fe.create_table(
    name=table_name,
    primary_keys=["search_query", "prd_id", "timestamp"],
    timeseries_columns=["timestamp"],
    schema=query_prd_schema,
    description="Features aggregated on query, prd level",
)

At this point, querying the table's metadata still returns a list containing the timestamp column.
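The check itself isn't shown in this post; a plausible version, assuming the documented fe.get_table API and the FeatureTable.timestamp_keys attribute, is:

# Right after create_table this prints a non-empty list, e.g. ['timestamp'].
print(fe.get_table(name=table_name).timestamp_keys)

But after writing data with the following call: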

fe.write_table(
    name=table_name,
    df=a2c_rates,
    mode="merge"
)

The same query now returns an empty list. As a result, calling create_training_set with a FeatureLookup that has a timestamp_lookup_key fails, since there is no longer a timeseries key for the point-in-time join.
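For reference, the failing lookup looks roughly like this; labels_df and the label column are hypothetical placeholders, while the feature name and keys come from the table above:

from databricks.feature_engineering import FeatureLookup

# Point-in-time lookup: for each row of the (hypothetical) labels_df,
# fetch the latest a2c_rate whose timestamp is <= the row's timestamp.
feature_lookup = FeatureLookup(
    table_name=table_name,
    feature_names=["a2c_rate"],
    lookup_key=["search_query", "prd_id"],
    timestamp_lookup_key="timestamp",
)

training_set = fe.create_training_set(
    df=labels_df,  # hypothetical DataFrame with the keys, a timestamp and a label
    feature_lookups=[feature_lookup],
    label="label",  # hypothetical label column
)

Once timestamp_keys is emptied, this point-in-time join is exactly what breaks.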

The a2c_rates DataFrame has the following types (the output of a2c_rates.dtypes):

[('timestamp', 'timestamp'),
 ('search_query', 'string'),
 ('prd_id', 'string'),
 ('a2c_rate', 'float'),
 ('start_agg_window', 'date'),
 ('end_agg_window', 'date')]

So that should be fine, right?

Code to produce the a2c_rates table:

# Start with GA data for one day
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql.functions import window

preprocessed = spark.createDataFrame(
    pd.DataFrame(
        {
            'hitDate': [
                '2023-01-01', '2023-01-01', '2023-01-01', '2023-01-01', '2023-01-01',
                '2023-01-01', '2023-01-01', '2023-01-01', '2023-01-01', '2023-01-01',
                '2023-01-01', '2023-01-01', '2023-01-01', '2023-01-01', '2023-01-01',
                ],
            'emda_id': [
                'emda1', 'emda1', 'emda1', 'emda1', 'emda1',
                'emda2', 'emda2', 'emda2', 'emda2', 'emda2',
                'emda3', 'emda3', 'emda3', 'emda3', 'emda3',
                ],
            'interaction_type': [
                'search', 'impressed', 'impressed', 'impressed', 'added_to_basket',
                'search', 'impressed', 'impressed', 'impressed', 'added_to_basket',
                'search', 'impressed', 'impressed', 'impressed', 'added_to_basket',
                ], 
            'search_query': [
                'pizza', 'pizza', 'pizza', 'pizza', 'pizza',
                'pizza', 'pizza', 'pizza', 'pizza', 'pizza',
                'pizza', 'pizza', 'pizza', 'pizza', 'pizza',
                ],
            'prd_id': [
                None, '001', '002', '003', '001', 
                None, '001', '002', '003', '002', 
                None, '001', '002', '003', '001', 
            ],
        }
    )
)

agg_impressions = (preprocessed
     .filter(F.col('interaction_type') == 'impressed')
     .groupBy('prd_id', 'search_query', window("hitDate", "1 day"))
     .count()
     .withColumnRenamed('count', 'impr_count')
     .withColumn('start_agg_window', F.to_date(F.col('window.start')))
     .withColumn('end_agg_window', F.date_sub(F.to_date(F.col('window.end')), 1))
)

agg_adds = (preprocessed
            .filter(F.col('interaction_type') == 'added_to_basket')
            .groupBy('search_query', 'prd_id', window("hitDate", "1 day"))
            .count()
            .withColumnRenamed('count', 'add_count')
            .withColumn('start_agg_window', F.to_date(F.col('window.start')))
            .withColumn('end_agg_window', F.date_sub(F.to_date(F.col('window.end')), 1))
)

a2c_rates = (agg_impressions
             .join(agg_adds, how='left', on=['search_query', 'prd_id', 'start_agg_window', 'end_agg_window'])
             .withColumn('a2c_rate', F.col('add_count') / F.col('impr_count'))
             .withColumn('a2c_rate', F.col('a2c_rate').cast('float'))
             .fillna(0, subset=['a2c_rate'])
             .withColumn('timestamp', (F.unix_timestamp(F.col('end_agg_window')) + (60*60*24) + 60*60*7).cast('timestamp'))  # stamp features at 07:00 on the day after the aggregation window closes
             .select('timestamp', 'search_query', 'prd_id', 'a2c_rate', 'start_agg_window', 'end_agg_window')
)

a2c_rates = a2c_rates.withColumn('timestamp', F.to_timestamp(F.col('timestamp')))

Accepted Solution

VincentP
New Contributor III

I needed to update the cluster to Databricks Runtime 13.3; after upgrading, write_table no longer drops the timestamp key.
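Since the fix is a runtime upgrade, it may help to confirm what a notebook is actually running; a minimal check, assuming the DATABRICKS_RUNTIME_VERSION environment variable that Databricks sets in notebook sessions:

import os

# Prints e.g. '13.3' on a Databricks Runtime 13.3 cluster.
print(os.environ.get("DATABRICKS_RUNTIME_VERSION"))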


