Automating a Machine Learning Workflow using Google BigQuery and Amazon Managed Apache Airflow

By Yi Ai (@yi)

Amazon announced the availability of Amazon Managed Workflows for Apache Airflow (MWAA), a fully managed service that makes it easy to run Apache Airflow on AWS and to build data processing workflows in the cloud.

Apache Airflow is an open-source tool used to programmatically author, schedule, and monitor sequences of processes and tasks referred to as “workflows”.
Amazon Personalize enables developers to build applications with the same machine learning (ML) technology used by Amazon.com for real-time personalized recommendations — no ML expertise required.

This article shows how we can build and manage an ML workflow using Google BigQuery, Amazon MWAA, and Amazon Personalize. We’ll build a session-based recommender system to predict the most popular items for an e-commerce website based on the traffic data of the product pages tracked by Google Analytics.

BigQuery is Google's enterprise data warehouse. It enables fast SQL queries over large datasets using the processing power of Google's infrastructure.

High-Level Solution

We'll start by extracting and transforming the data, then build and train a solution version (a trained Amazon Personalize recommendation model), and finally deploy it as a campaign.

These tasks will be plugged into a workflow that can be orchestrated and automated through Apache Airflow integration with Amazon Personalize and Google BigQuery.

The diagram below represents the workflow we’ll implement for building the recommender system:

The workflow consists of the following tasks:

Data Preparation

  • Export session and hit data from a Google Analytics 360 account to BigQuery, use SQL to query the Analytics data into a Pandas data frame in the format Personalize expects, and then write the data frame as a CSV file directly to S3.

Amazon Personalize Solution

  • Create a Personalize dataset group if it doesn’t exist.
  • Create an Interaction schema for our data if the schema doesn’t exist.
  • Create an ‘Interactions’ dataset type if it doesn’t exist.
  • Attach an Amazon S3 bucket policy that gives Amazon Personalize read access to the bucket.
  • Create a Personalize role that has the right permissions if it doesn’t exist.
  • Create your Dataset import jobs.
  • Create / Update Solution.
  • Create/ Update Campaign.

Before implementing the solution, you have to create an Airflow environment using Amazon MWAA. The extra packages below should be included in requirements.txt when creating the environment. Please don't include the BigQuery client in this requirements.txt; we will install it in the next step:

boto >= 2.49.0
httplib2
awswrangler
google-api-python-client

When the new Airflow environment is ready, attach the Personalize policy to the IAM role of your environment by running the following CLI command:

$ aws iam put-role-policy --role-name AmazonMWAA-MyAirflowEnvironment-em53Wv --policy-name AirflowPersonalizePolicy --policy-document file://airflowPersonalizePolicy.json

airflowPersonalizePolicy.json

{
    "Version": "2012-10-17",
    "Id": "AirflowPersonalizePolicy",
    "Statement": [
        {
            "Sid": "PersonalizeAccessPolicy",
            "Effect": "Allow",
            "Action": [
                "personalize:*"
            ],
            "Resource": [
                "*"
            ]
        },
        {
            "Sid": "S3BucketAccessPolicy",
            "Effect": "Allow",
            "Action": [
                "s3:PutObject",
                "s3:PutBucketPolicy"
            ],
            "Resource": [
                "arn:aws:s3:::airflow-demo-personalise",
                "arn:aws:s3:::airflow-demo-personalise/*"
            ]
        },
        {
            "Sid": "IAMPolicy",
            "Effect": "Allow",
            "Action": [
                "iam:CreateRole",
                "iam:AttachRolePolicy"
            ],
            "Resource": [
                "*"
            ]
        },
        {
            "Effect": "Allow",
            "Action": [
                "iam:GetRole",
                "iam:PassRole"
            ],
            "Resource": "arn:aws:iam::*:role/PersonalizeS3Role-*"
        }
    ]
}
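
The policy document above hard-codes the bucket name. As a minimal sketch, assuming you would rather generate airflowPersonalizePolicy.json from a script, the helper below rebuilds the same statements for any bucket. The function name build_airflow_personalize_policy is hypothetical and not part of the original project.

```python
import json


def build_airflow_personalize_policy(bucket_name):
    """Return the policy document shown above for the given S3 bucket.

    Hypothetical helper for illustration; the statements mirror the JSON above.
    """
    return {
        "Version": "2012-10-17",
        "Id": "AirflowPersonalizePolicy",
        "Statement": [
            {
                "Sid": "PersonalizeAccessPolicy",
                "Effect": "Allow",
                "Action": ["personalize:*"],
                "Resource": ["*"],
            },
            {
                "Sid": "S3BucketAccessPolicy",
                "Effect": "Allow",
                "Action": ["s3:PutObject", "s3:PutBucketPolicy"],
                "Resource": [
                    f"arn:aws:s3:::{bucket_name}",
                    f"arn:aws:s3:::{bucket_name}/*",
                ],
            },
            {
                "Sid": "IAMPolicy",
                "Effect": "Allow",
                "Action": ["iam:CreateRole", "iam:AttachRolePolicy"],
                "Resource": ["*"],
            },
            {
                "Effect": "Allow",
                "Action": ["iam:GetRole", "iam:PassRole"],
                "Resource": "arn:aws:iam::*:role/PersonalizeS3Role-*",
            },
        ],
    }


# Serialize the document; write this string to airflowPersonalizePolicy.json.
policy_json = json.dumps(
    build_airflow_personalize_policy("airflow-demo-personalise"), indent=4
)
```

Writing policy_json to airflowPersonalizePolicy.json gives a file that can be passed to the CLI command above.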

Installing Google BigQuery Client

MWAA currently doesn't support the Google Cloud BigQuery client (google-cloud-bigquery) and pandas-gbq with grpc > 1.20.

We can't install the BigQuery client through requirements.txt. If you put the above dependencies into requirements.txt, pip won't install any of the dependencies listed there, and you will hit the error No module named "httplib2" when running the DAG.

To resolve this issue, we can:

  1. Package the required Google libraries on a local computer and upload the archive to S3.
  2. Download the archive to the Airflow workers when the BigQuery export task starts.
  3. Dynamically import the required modules from the extracted path.

I created a bash script and a requirements.txt for the above steps. Run the following command:

$ bash setup.sh

setup.sh

#!/bin/bash

virtualenv -p python3.7 venv

source venv/bin/activate

pip install -r requirements.txt 

cd venv/lib/python3.7

zip -r site-packages.zip site-packages/

mv site-packages.zip ../../../site-packages.zip

cd ../../../

deactivate

aws s3 cp site-packages.zip s3://airflow-demo-personalise/

requirements.txt

cachetools==4.2.0
certifi==2020.12.5
cffi==1.14.4
chardet==3.0.4
google-api-core==1.24.0
google-auth==1.24.0
google-auth-oauthlib==0.4.2
google-cloud-bigquery==2.6.1
google-cloud-bigquery-storage==2.1.0
google-cloud-core==1.5.0
google-crc32c==1.1.0
google-resumable-media==1.2.0
googleapis-common-protos==1.52.0
grpcio==1.34.0
idna==2.10
libcst==0.3.15
mypy-extensions==0.4.3
numpy==1.19.4
oauthlib==3.1.0
pandas==1.1.5
pandas-gbq==0.14.1
proto-plus==1.13.0
protobuf==3.14.0
pyarrow==2.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
pydata-google-auth==1.1.0
python-dateutil==2.8.1
pytz==2020.4
PyYAML==5.3.1
requests==2.25.0
requests-oauthlib==1.3.0
rsa==4.6
six==1.15.0
typing-extensions==3.7.4.3
typing-inspect==0.6.0
urllib3==1.26.2

Then copy the following code to the DAG task to import google modules dynamically:

s3 = boto3.resource('s3')
logger.info(
    "Download google bigquery and google client dependencies from S3")
s3.Bucket(BUCKET_NAME).download_file(LIB_KEY, '/tmp/site-packages.zip')

with zipfile.ZipFile('/tmp/site-packages.zip', 'r') as zip_ref:
    zip_ref.extractall('/tmp/python3.7/site-packages')

sys.path.insert(1, "/tmp/python3.7/site-packages/site-packages/")

# Import google bigquery and google client dependencies
import pyarrow
from google.cloud import bigquery
from airflow.contrib.hooks.bigquery_hook import BigQueryHook
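
To see the mechanism in isolation, here is a minimal local sketch that needs neither S3 nor Airflow. It builds a tiny stand-in archive with the same top-level site-packages/ folder that setup.sh produces, extracts it, and imports a module from it. The names import_from_zip and demo_dependency are made up for illustration.

```python
import importlib
import os
import sys
import tempfile
import zipfile


def import_from_zip(zip_path, extract_dir, module_name):
    """Extract an archive of packages and import one module from it."""
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_dir)
    # The archive contains a top-level "site-packages/" folder, as in setup.sh.
    site_packages = os.path.join(extract_dir, "site-packages")
    if site_packages not in sys.path:
        sys.path.insert(1, site_packages)
    # Make sure the import system notices the freshly created directory.
    importlib.invalidate_caches()
    return importlib.import_module(module_name)


# Build a tiny stand-in archive so the sketch runs without S3.
work_dir = tempfile.mkdtemp()
zip_path = os.path.join(work_dir, "site-packages.zip")
with zipfile.ZipFile(zip_path, "w") as zf:
    zf.writestr("site-packages/demo_dependency.py", "ANSWER = 42\n")

demo = import_from_zip(
    zip_path, os.path.join(work_dir, "extracted"), "demo_dependency"
)
print(demo.ANSWER)
```

The archive is extracted rather than added to sys.path directly because Python cannot import compiled extension modules from inside a zip file, and packages such as grpcio and pyarrow ship compiled extensions.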

Next, we will create a Google Cloud connection in the Airflow UI.

Now, we will be able to use Google BigQuery in Amazon Managed Airflow workers; let’s begin to create workflow tasks.

Data Preparation

First, export session and hit data from a Google Analytics 360 account to BigQuery, then use SQL to query the Analytics data into a Pandas data frame in the format Personalize expects.

To prepare an interaction dataset for Personalize, we need to extract the following fields from the Google Analytics data in BigQuery:

  • USER_ID: In this example, we don't have user data for the e-commerce website, and there is no user interaction data from the website database. However, we can use the client ID provided by Google Analytics. The client ID (cid) is a unique identifier for a browser–device pair that helps Google Analytics link user actions on a site. By default, Google Analytics determines unique users using this parameter. The client ID format is a randomly generated 31-bit integer followed by a dot (".") followed by the current time in seconds, so we only need the 31-bit integer before the dot. BigQuery supports regular expressions, so we can use REGEXP_EXTRACT(USER_ID, r'(\d+)\.') AS USER_ID in the query to extract that integer as a session-based user ID.
  • ITEM_ID: Google Analytics provides the page location (page_location), so we can keep only product pages with the WHERE clause page_location LIKE '%/product/%' and use the product slug as the item ID.
  • TIMESTAMP: Timestamp data must be in UNIX epoch time format. Use TIMESTAMP_TRUNC(TIMESTAMP_MICROS(event_timestamp), MINUTE) to truncate the Analytics event_timestamp to the minute, then UNIX_SECONDS to convert it to epoch seconds.
  • device.category AS DEVICE.
  • geo.country AS LOCATION.
  • event_name AS EVENT_NAME.
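
The field extraction described above can be checked locally before running anything in BigQuery. The sketch below mimics the USER_ID, ITEM_ID, and TIMESTAMP expressions for a single event using Python's re module. The slug pattern is the one used in the article's query, while the function name to_interaction and all sample values are made up for illustration.

```python
import re

# Same patterns as the REGEXP_EXTRACT calls in the BigQuery query.
CLIENT_ID_RE = re.compile(r"(\d+)\.")
PRODUCT_SLUG_RE = re.compile(r"product/([^?&#]*)")


def to_interaction(user_pseudo_id, page_location, event_timestamp_micros):
    """Mimic the query's USER_ID, ITEM_ID and TIMESTAMP expressions."""
    user_match = CLIENT_ID_RE.search(user_pseudo_id)
    item_match = PRODUCT_SLUG_RE.search(page_location)
    # TIMESTAMP_TRUNC(..., MINUTE) followed by UNIX_SECONDS(...):
    # drop the microseconds, then round down to the start of the minute.
    seconds = event_timestamp_micros // 1_000_000
    truncated = seconds - seconds % 60
    return {
        "USER_ID": user_match.group(1) if user_match else None,
        "ITEM_ID": item_match.group(1) if item_match else None,
        "TIMESTAMP": truncated,
    }


row = to_interaction(
    "1234567890.1609459200",  # made-up client ID
    "https://shop.example.com/product/blue-widget?utm_source=mail",  # made-up URL
    1609459275123456,  # made-up event_timestamp in microseconds
)
print(row)
```

For this sample the user ID is the integer before the dot, the item ID is the slug up to the query string, and the timestamp is rounded down to the minute.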

The BigQuery SQL query is as follows:

BQ_SQL = """
SELECT  REGEXP_EXTRACT(USER_ID, r'(\d+)\.') AS USER_ID, UNIX_SECONDS(EVENT_DATE) AS TIMESTAMP, REGEXP_EXTRACT(page_location,r'product/([^?&#]*)') as ITEM_ID, LOCATION, DEVICE, EVENT_NAME AS EVENT_TYPE
FROM
(
    SELECT user_pseudo_id AS USER_ID, (SELECT value.string_value FROM UNNEST(event_params) 
    WHERE key = "page_location") as page_location, TIMESTAMP_TRUNC(TIMESTAMP_MICROS(event_timestamp), 
    MINUTE) AS EVENT_DATE, device.category AS DEVICE, geo.country AS LOCATION, event_name AS EVENT_NAME 
    FROM `lively-metrics-295911.analytics_254171871.events_intraday_*`
    WHERE
    _TABLE_SUFFIX = FORMAT_DATE('%Y%m%d', DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY))
)
WHERE  page_location LIKE '%/product/%'
GROUP BY USER_ID, EVENT_DATE, page_location, LOCATION, DEVICE,EVENT_NAME
ORDER BY EVENT_DATE DESC
"""

Next, write the data frame as a CSV file directly to S3 using AWS Data Wrangler.

The following PythonOperator snippet in the DAG defines the BigQuery-to-S3 task.

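# Airflow 1.10-style import paths, matching the contrib hook used below.
import logging
import sys
import zipfile

import awswrangler as wr
import boto3
from airflow.operators.python_operator import PythonOperator

logger = logging.getLogger(__name__)

# BUCKET_NAME, LIB_KEY, OUTPUT_PATH and BQ_SQL are module-level constants
# defined elsewhere in the DAG file.


def bq_to_s3():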
    s3 = boto3.resource('s3')
    logger.info(
        "Download google bigquery and google client dependencies from S3")
    s3.Bucket(BUCKET_NAME).download_file(LIB_KEY, '/tmp/site-packages.zip')

    with zipfile.ZipFile('/tmp/site-packages.zip', 'r') as zip_ref:
        zip_ref.extractall('/tmp/python3.7/site-packages')

    sys.path.insert(1, "/tmp/python3.7/site-packages/site-packages/")

    # Import google bigquery and google client dependencies
    import pyarrow
    from google.cloud import bigquery
    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    bq_hook = BigQueryHook(
        bigquery_conn_id="bigquery_default", use_legacy_sql=False)

    bq_client = bigquery.Client(project=bq_hook._get_field(
        "project"), credentials=bq_hook._get_credentials())

    events_df = bq_client.query(BQ_SQL).result().to_dataframe(
        create_bqstorage_client=False)

    logger.info(
        f'google analytics events dataframe head - {events_df.head()}')

    wr.s3.to_csv(events_df, OUTPUT_PATH, index=False)
    
    
t_export_bq_to_s3 = PythonOperator(task_id='export_bq_to_s3',
                               python_callable=bq_to_s3,
                               dag=dag,
                               retries=1)
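
Personalize expects the CSV header names to match the field names of the dataset schema, which is the Interactions schema created in create_schema later in this article. As a small sanity check, and only a sketch, the hypothetical helper below reports schema fields that are missing from a CSV header. The sample row is made up.

```python
import csv
import io

# Field names taken from the Interactions schema used in create_schema.
SCHEMA_FIELDS = ["USER_ID", "ITEM_ID", "TIMESTAMP", "LOCATION", "DEVICE", "EVENT_TYPE"]


def missing_schema_fields(csv_text, schema_fields=SCHEMA_FIELDS):
    """Return the schema fields that are absent from the CSV header row."""
    header = next(csv.reader(io.StringIO(csv_text)))
    return [name for name in schema_fields if name not in header]


# Made-up sample in the column order produced by the query.
sample = (
    "USER_ID,TIMESTAMP,ITEM_ID,LOCATION,DEVICE,EVENT_TYPE\n"
    "1234567890,1609459260,blue-widget,Australia,mobile,page_view\n"
)
print(missing_schema_fields(sample))
```

A check like this could run on the data frame columns inside bq_to_s3 before the upload, so that a renamed column fails early instead of failing the dataset import job.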

Creating a Recommendation Model With Amazon Personalize

In this section, we will build a Personalize solution to identify the most popular items for an e-commerce website integrated with Google Analytics.

We will use the Popularity-Count recipe to train our model. Although Personalize supports importing interactions incrementally, we will retrain the model based on daily interaction data to get more relevant recommendations.
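
To build intuition for what a popularity-based recommender returns, here is a rough sketch that ranks items by how many interaction rows mention them. It only illustrates the idea. It is not how Amazon Personalize implements the recipe, and the managed recipe may define popularity differently. The function name and sample data are made up.

```python
from collections import Counter


def most_popular_items(interactions, top_k=3):
    """Rank items by how many interaction rows mention them."""
    counts = Counter(item_id for _user_id, item_id in interactions)
    return [item_id for item_id, _count in counts.most_common(top_k)]


# Made-up (USER_ID, ITEM_ID) pairs.
interactions = [
    ("u1", "blue-widget"),
    ("u2", "blue-widget"),
    ("u3", "red-widget"),
    ("u1", "green-widget"),
    ("u2", "red-widget"),
    ("u4", "blue-widget"),
]
print(most_popular_items(interactions, top_k=2))
```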

What we’ll cover:

  • check_s3_for_key (S3KeySensor): Check if the dataset CSV file exists.
  • t_check_dataset_group (BranchPythonOperator): Check if the Personalize dataset group exists. If it does not exist, trigger t_init_personalize; otherwise trigger t_skip_init_personalize.
  • t_init_personalize (DummyOperator): Trigger the parallel tasks that run when the dataset group doesn't exist (t_create_dataset_group, t_create_schema, t_put_bucket_policies, t_create_iam_role).
  • t_create_dataset_group (PythonOperator): Create a Personalize dataset group if it doesn't exist.
  • t_create_schema (PythonOperator): Create an Interaction schema for our data if the schema doesn't exist.
  • t_put_bucket_policies (PythonOperator): Attach an Amazon S3 bucket policy that gives Amazon Personalize read access to the bucket.
  • t_create_iam_role (PythonOperator): Create a Personalize role that has the right permissions if it doesn't exist.
  • t_create_dataset_type (PythonOperator): Create an 'Interactions' dataset type if it doesn't exist.
  • t_skip_init_personalize (DummyOperator): Downstream task of the BranchPythonOperator task when the dataset group already exists.
  • t_create_import_dataset_job (PythonOperator): Create the dataset import job.
  • t_update_solution (PythonOperator): Create or update the solution.
  • t_update_campagin (PythonOperator): Create or update the campaign.

The Python callables and operators behind these tasks are defined as follows:

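# Airflow 1.10-style import paths.
import json
import logging
import time
from datetime import timedelta

import boto3
from airflow import DAG
from airflow.exceptions import AirflowFailException
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator, PythonOperator
from airflow.utils.dates import days_ago
from airflow.utils.trigger_rule import TriggerRule
from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)
personalize = boto3.client("personalize")
iam = boto3.client("iam")

# DATASET_GROUP_NAME, INTERCATION_SCHEMA_NAME, SOLUTION_NAME, CAMPAIGN_NAME,
# BUCKET_NAME, OUTPUT_PATH and suffix are constants defined elsewhere in the DAG file.


def create_dataset_group(**kwargs):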
    create_dg_response = personalize.create_dataset_group(
        name=DATASET_GROUP_NAME
    )
    dataset_group_arn = create_dg_response["datasetGroupArn"]

    status = None
    max_time = time.time() + 2*60*60  # 2 hours
    while time.time() < max_time:
        describe_dataset_group_response = personalize.describe_dataset_group(
            datasetGroupArn=dataset_group_arn
        )
        status = describe_dataset_group_response["datasetGroup"]["status"]
        logger.info(f"DatasetGroup: {status}")

        if status == "ACTIVE" or status == "CREATE FAILED":
            break

        time.sleep(20)
    if status == "ACTIVE":
        kwargs['ti'].xcom_push(key="dataset_group_arn",
                               value=dataset_group_arn)
    if status == "CREATE FAILED":
        raise AirflowFailException(
            f"DatasetGroup {DATASET_GROUP_NAME} create failed")


def check_dataset_group(**kwargs):
    dg_response = personalize.list_dataset_groups(
        maxResults=100
    )

    demo_dg = next((datasetGroup for datasetGroup in dg_response["datasetGroups"]
                    if datasetGroup["name"] == DATASET_GROUP_NAME), False)

    if not demo_dg:
        return "init_personalize"
    else:
        kwargs['ti'].xcom_push(key="dataset_group_arn",
                               value=demo_dg["datasetGroupArn"])
        return "skip_init_personalize"


def create_schema():
    schema_response = personalize.list_schemas(
        maxResults=100
    )

    interaction_schema = next((schema for schema in schema_response["schemas"]
                               if schema["name"] == INTERCATION_SCHEMA_NAME), False)
    if not interaction_schema:
        create_schema_response = personalize.create_schema(
            name=INTERCATION_SCHEMA_NAME,
            schema=json.dumps({
                "type": "record",
                "name": "Interactions",
                "namespace": "com.amazonaws.personalize.schema",
                "fields": [
                    {
                        "name": "USER_ID",
                        "type": "string"
                    },
                    {
                        "name": "ITEM_ID",
                        "type": "string"
                    },
                    {
                        "name": "TIMESTAMP",
                        "type": "long"
                    },
                    {
                        "name": "LOCATION",
                        "type": "string",
                        "categorical": True
                    },
                    {
                        "name": "DEVICE",
                        "type": "string",
                        "categorical": True
                    },
                    {
                        "name": "EVENT_TYPE",
                        "type": "string"
                    }
                ]
            }))
        logger.info(json.dumps(create_schema_response, indent=2))
        schema_arn = create_schema_response["schemaArn"]
        return schema_arn

    return interaction_schema["schemaArn"]


def put_bucket_policies():
    s3 = boto3.client("s3")
    policy = {
        "Version": "2012-10-17",
        "Id": "PersonalizeS3BucketAccessPolicy",
        "Statement": [
            {
                "Sid": "PersonalizeS3BucketAccessPolicy",
                "Effect": "Allow",
                "Principal": {
                    "Service": "personalize.amazonaws.com"
                },
                "Action": [
                    "s3:GetObject",
                    "s3:ListBucket"
                ],
                "Resource": [
                    "arn:aws:s3:::{}".format(BUCKET_NAME),
                    "arn:aws:s3:::{}/*".format(BUCKET_NAME)
                ]
            }
        ]
    }

    s3.put_bucket_policy(Bucket=BUCKET_NAME, Policy=json.dumps(policy))


def create_iam_role(**kwargs):
    role_name = f"PersonalizeS3Role-{suffix}"
    assume_role_policy_document = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {
                    "Service": "personalize.amazonaws.com"
                },
                "Action": "sts:AssumeRole"
            }
        ]
    }
    try:
        create_role_response = iam.create_role(
            RoleName=role_name,
            AssumeRolePolicyDocument=json.dumps(assume_role_policy_document)
        )

        iam.attach_role_policy(
            RoleName=role_name,
            PolicyArn="arn:aws:iam::aws:policy/AmazonS3ReadOnlyAccess"
        )

        role_arn = create_role_response["Role"]["Arn"]

        # sometimes need to wait a bit for the role to be created
        time.sleep(30)
        return role_arn
    except ClientError as e:
        if e.response['Error']['Code'] == 'EntityAlreadyExists':
            role_arn = iam.get_role(RoleName=role_name)['Role']['Arn']
            time.sleep(30)
            return role_arn
        else:
            raise AirflowFailException(f"PersonalizeS3Role create failed")


def create_dataset_type(**kwargs):
    ti = kwargs['ti']
    schema_arn = ti.xcom_pull(key="return_value", task_ids='create_schema')
    print(schema_arn)
    dataset_group_arn = ti.xcom_pull(key="dataset_group_arn",
                                     task_ids='create_dataset_group')
    dataset_type = "INTERACTIONS"
    create_dataset_response = personalize.create_dataset(
        datasetType=dataset_type,
        datasetGroupArn=dataset_group_arn,
        schemaArn=schema_arn,
        name=f"DEMO-metadata-dataset-interactions-{suffix}"
    )

    interactions_dataset_arn = create_dataset_response['datasetArn']
    logger.info(json.dumps(create_dataset_response, indent=2))
    return interactions_dataset_arn


def import_dataset(**kwargs):
    ti = kwargs['ti']
    interactions_dataset_arn = ti.xcom_pull(key="return_value",
                                            task_ids='create_dataset_type')
    role_arn = ti.xcom_pull(key="return_value",
                            task_ids='create_iam_role')
    create_dataset_import_job_response = personalize.create_dataset_import_job(
        jobName="DEMO-dataset-import-job-"+suffix,
        datasetArn=interactions_dataset_arn,
        dataSource={
            "dataLocation": OUTPUT_PATH
        },
        roleArn=role_arn
    )

    dataset_import_job_arn = create_dataset_import_job_response['datasetImportJobArn']
    logger.info(json.dumps(create_dataset_import_job_response, indent=2))

    status = None
    max_time = time.time() + 2*60*60  # 2 hours

    while time.time() < max_time:
        describe_dataset_import_job_response = personalize.describe_dataset_import_job(
            datasetImportJobArn=dataset_import_job_arn
        )

        dataset_import_job = describe_dataset_import_job_response["datasetImportJob"]
        if "latestDatasetImportJobRun" not in dataset_import_job:
            status = dataset_import_job["status"]
            logger.info("DatasetImportJob: {}".format(status))
        else:
            status = dataset_import_job["latestDatasetImportJobRun"]["status"]
            logger.info("LatestDatasetImportJobRun: {}".format(status))

        if status == "ACTIVE" or status == "CREATE FAILED":
            break

        time.sleep(60)

    if status == "ACTIVE":
        return dataset_import_job_arn
    if status == "CREATE FAILED":
        raise AirflowFailException(
            f"Dataset import job create failed")


def update_solution(**kwargs):
    recipe_arn = "arn:aws:personalize:::recipe/aws-popularity-count"
    ti = kwargs['ti']
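    # The ARN is pushed by create_dataset_group on a first run and by
    # check_dataset_group when the dataset group already exists.
    dataset_group_arn = (
        ti.xcom_pull(key="dataset_group_arn", task_ids='create_dataset_group')
        or ti.xcom_pull(key="dataset_group_arn", task_ids='check_dataset_group')
    )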
    list_solutions_response = personalize.list_solutions(
        datasetGroupArn=dataset_group_arn,
        maxResults=100
    )

    demo_solution = next((solution for solution in list_solutions_response["solutions"]
                          if solution["name"] == SOLUTION_NAME), False)

    if not demo_solution:
        create_solution_response = personalize.create_solution(
            name=SOLUTION_NAME,
            datasetGroupArn=dataset_group_arn,
            recipeArn=recipe_arn
        )

        solution_arn = create_solution_response['solutionArn']
        logger.info(json.dumps(create_solution_response, indent=2))
    else:
        solution_arn = demo_solution["solutionArn"]

    kwargs['ti'].xcom_push(key="solution_arn",
                               value=solution_arn)
    create_solution_version_response = personalize.create_solution_version(
        solutionArn=solution_arn,
        trainingMode='FULL'
    )

    solution_version_arn = create_solution_version_response['solutionVersionArn']

    status = None
    max_time = time.time() + 2*60*60  # 2 hours
    while time.time() < max_time:
        describe_solution_version_response = personalize.describe_solution_version(
            solutionVersionArn=solution_version_arn
        )
        status = describe_solution_version_response["solutionVersion"]["status"]
        logger.info(f"SolutionVersion: {status}")

        if status == "ACTIVE" or status == "CREATE FAILED":
            break

        time.sleep(60)

    if status == "ACTIVE":
        return solution_version_arn
    if status == "CREATE FAILED":
        raise AirflowFailException(
            f"Solution version create failed")


def update_campagin(**kwargs):
    ti = kwargs['ti']
    solution_version_arn = ti.xcom_pull(key="return_value",
                                        task_ids='update_solution')
    solution_arn = ti.xcom_pull(key="solution_arn",
                                task_ids='update_solution')

    list_campagins_response = personalize.list_campaigns(
        solutionArn=solution_arn,
        maxResults=100
    )

    demo_campaign = next((campaign for campaign in list_campagins_response["campaigns"]
                          if campaign["name"] == CAMPAIGN_NAME), False)
    if not demo_campaign:
        create_campaign_response = personalize.create_campaign(
            name=CAMPAIGN_NAME,
            solutionVersionArn=solution_version_arn,
            minProvisionedTPS=2,
        )

        campaign_arn = create_campaign_response['campaignArn']
        logger.info(json.dumps(create_campaign_response, indent=2))
    else:
        campaign_arn = demo_campaign["campaignArn"]
        personalize.update_campaign(
            campaignArn=campaign_arn,
            solutionVersionArn=solution_version_arn,
            minProvisionedTPS=2
        )

    status = None
    max_time = time.time() + 2*60*60  # 2 hours
    while time.time() < max_time:
        describe_campaign_response = personalize.describe_campaign(
            campaignArn=campaign_arn
        )
        status = describe_campaign_response["campaign"]["status"]
        print("Campaign: {}".format(status))

        if status == "ACTIVE" or status == "CREATE FAILED":
            break

        time.sleep(60)

    if status == "ACTIVE":
        return campaign_arn
    if status == "CREATE FAILED":
        raise AirflowFailException(
            f"Campaign create/update failed")


default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': days_ago(1),
    'email': ['yi.ai@afox.mobi'],
    'email_on_failure': False,
    'email_on_retry': False,
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}

dag = DAG(
    'ml-pipeline',
    default_args=default_args,
    description='A simple ML data pipeline DAG',
    schedule_interval='@daily',
)

t_check_dataset_group = BranchPythonOperator(
    task_id='check_dataset_group',
    provide_context=True,
    python_callable=check_dataset_group,
    retries=1,
    dag=dag,
)

t_init_personalize = DummyOperator(
    task_id="init_personalize",
    trigger_rule=TriggerRule.ALL_SUCCESS,
    dag=dag,
)

t_create_dataset_group = PythonOperator(
    task_id='create_dataset_group',
    provide_context=True,
    python_callable=create_dataset_group,
    retries=1,
    dag=dag,
)

t_create_schema = PythonOperator(
    task_id='create_schema',
    python_callable=create_schema,
    retries=1,
    dag=dag,
)

t_put_bucket_policies = PythonOperator(
    task_id='put_bucket_policies',
    python_callable=put_bucket_policies,
    retries=1,
    dag=dag,
)

t_create_iam_role = PythonOperator(
    task_id='create_iam_role',
    provide_context=True,
    python_callable=create_iam_role,
    retries=1,
    dag=dag,
)

t_create_dataset_type = PythonOperator(
    task_id='create_dataset_type',
    provide_context=True,
    python_callable=create_dataset_type,
    trigger_rule=TriggerRule.ALL_SUCCESS,
    retries=1,
    dag=dag,
)

t_create_import_dataset_job = PythonOperator(
    task_id='import_dataset',
    provide_context=True,
    python_callable=import_dataset,
    retries=1,
    dag=dag,
)

t_skip_init_personalize = DummyOperator(
    task_id="skip_init_personalize",
    trigger_rule=TriggerRule.NONE_FAILED,
    dag=dag,
)

t_init_personalize_done = DummyOperator(
    task_id="init_personalize_done",
    trigger_rule=TriggerRule.NONE_FAILED,
    dag=dag,
)


t_update_solution = PythonOperator(
    task_id='update_solution',
    provide_context=True,
    python_callable=update_solution,
    trigger_rule=TriggerRule.ALL_SUCCESS,
    retries=1,
    dag=dag,
)

t_update_campagin = PythonOperator(
    task_id='update_campagin',
    provide_context=True,
    python_callable=update_campagin,
    trigger_rule=TriggerRule.ALL_SUCCESS,
    retries=1,
    dag=dag,
)
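
The four polling loops above share one pattern: describe the resource, stop on ACTIVE or CREATE FAILED, otherwise sleep and retry for up to two hours. As a sketch of how that pattern could be factored out, the hypothetical helper below takes any zero-argument callable that returns the current status string. Unlike the loops above, it also raises when the deadline passes instead of returning silently. The sleep and clock parameters are injectable only so the example runs instantly.

```python
import time


def wait_for_status(describe, timeout_seconds=2 * 60 * 60, poll_seconds=60,
                    sleep=time.sleep, clock=time.time):
    """Poll describe() until it returns "ACTIVE" or "CREATE FAILED".

    describe is any zero-argument callable returning the current status string.
    Raises RuntimeError on failure and TimeoutError if the deadline passes.
    """
    deadline = clock() + timeout_seconds
    while clock() < deadline:
        status = describe()
        if status == "ACTIVE":
            return status
        if status == "CREATE FAILED":
            raise RuntimeError("resource creation failed")
        sleep(poll_seconds)
    raise TimeoutError("resource did not become ACTIVE in time")


# Stand-in for a personalize.describe_* call: a canned sequence of statuses.
statuses = iter(["CREATE PENDING", "CREATE IN_PROGRESS", "ACTIVE"])
result = wait_for_status(lambda: next(statuses), poll_seconds=0,
                         sleep=lambda _seconds: None)
print(result)
```

In the DAG, describe could be a lambda wrapping a call such as personalize.describe_campaign(campaignArn=campaign_arn)["campaign"]["status"], with the exceptions replaced by AirflowFailException.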

In the next section, we’ll see how all these tasks are stitched together to form a workflow in an Airflow DAG.

Defining DAG

The tasks in the sections above are created using operators such as PythonOperator, which runs generic Python code on demand or at a scheduled interval.

Now let's define the DAG and its parameters. A DAG is simply a Python script that contains a set of tasks and their dependencies.

default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': days_ago(1),
    'email': ['yi.ai@afox.mobi'],
    'email_on_failure': False,
    'email_on_retry': False,
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}


dag = DAG(
    'ml-pipeline',
    default_args=default_args,
    description='A simple ML data pipeline DAG',
    schedule_interval='@daily',
)

Next, specify task dependencies:

t_export_bq_to_s3 >> check_s3_for_key >> t_check_dataset_group
t_check_dataset_group >> t_init_personalize
t_check_dataset_group >> t_skip_init_personalize >> t_init_personalize_done
t_init_personalize >> [
    t_create_dataset_group,
    t_create_schema,
    t_put_bucket_policies,
    t_create_iam_role
] >> t_create_dataset_type
t_create_dataset_type >> t_init_personalize_done
t_init_personalize_done >> t_create_import_dataset_job >> t_update_solution
t_update_solution >> t_update_campagin
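
To check the dependency statements without an Airflow installation, the sketch below copies the edges into a plain dictionary and computes one valid execution order with Kahn's algorithm. The node names are the task names used above with the t_ prefix dropped, and topological_order is a hypothetical helper, not an Airflow API.

```python
from collections import deque

# Edges copied from the dependency statements above.
DOWNSTREAM = {
    "export_bq_to_s3": ["check_s3_for_key"],
    "check_s3_for_key": ["check_dataset_group"],
    "check_dataset_group": ["init_personalize", "skip_init_personalize"],
    "init_personalize": ["create_dataset_group", "create_schema",
                         "put_bucket_policies", "create_iam_role"],
    "create_dataset_group": ["create_dataset_type"],
    "create_schema": ["create_dataset_type"],
    "put_bucket_policies": ["create_dataset_type"],
    "create_iam_role": ["create_dataset_type"],
    "create_dataset_type": ["init_personalize_done"],
    "skip_init_personalize": ["init_personalize_done"],
    "init_personalize_done": ["import_dataset"],
    "import_dataset": ["update_solution"],
    "update_solution": ["update_campagin"],
    "update_campagin": [],
}


def topological_order(downstream):
    """Kahn's algorithm: return one valid execution order of the tasks."""
    indegree = {task: 0 for task in downstream}
    for children in downstream.values():
        for child in children:
            indegree[child] += 1
    ready = deque(task for task, degree in indegree.items() if degree == 0)
    order = []
    while ready:
        task = ready.popleft()
        order.append(task)
        for child in downstream[task]:
            indegree[child] -= 1
            if indegree[child] == 0:
                ready.append(child)
    if len(order) != len(downstream):
        raise ValueError("dependency cycle detected")
    return order


order = topological_order(DOWNSTREAM)
print(order)
```

At run time only one of the two branches executes, because the branch operator skips either init_personalize or skip_init_personalize. This is why init_personalize_done uses TriggerRule.NONE_FAILED rather than the default rule.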

After triggering the DAG on-demand or on a schedule, we can monitor DAGs and task executions and directly interact with them through Airflow UI. 

In the Airflow UI, we can see a graph view of the DAG to have a clear representation of how tasks are executed:

Conclusion

In this article, I showed how we can build an ML workflow using MWAA and BigQuery. You can extend the workflow by customizing the DAGs, for example by extending the dataset with daily data merged over a certain period (weekly, monthly, etc.), creating parallel tasks that use different recipes, and retraining models on a schedule or on a trigger from an S3 key sensor.

I hope you have found this article useful. You can find the complete project in my GitHub repo.
