How to Build a Question and Answer Chatbot with Amazon Kendra and AWS Fargate

by Yi Ai, May 31st, 2020

Amazon announced the general availability of Amazon Kendra a few weeks ago. Kendra is a highly accurate, easy-to-use enterprise search service powered by machine learning.

In this post, I will build a question-and-answer chatbot using React with Amplify, a WebSocket API in Amazon API Gateway, AWS Fargate, and Amazon Kendra. The solution provides a conversational interface that lets users ask questions and get relevant answers quickly.

What we’ll cover in this post:

  1. Create an Amazon Kendra index and extract questions and answers from a semi-structured document into a Kendra FAQ.
  2. Deploy a WebSocket API in API Gateway to process the question and answer messages.
  3. Create a React application and use AWS Amplify to connect to and interact with the chatbot through the WebSocket.
  4. Create a service in AWS Fargate that lets our bot call Kendra’s API to find an answer and send it back to the user.

The following diagram shows the architecture of the above steps:

Why does AWS Fargate matter?

A simpler approach is to query Amazon Kendra from a Lambda function, without Fargate. However, with AWS Fargate or EC2 we can extend the chatbot with custom AI models to make the bot feel more human. For example, we could build a chatbot based on Hugging Face’s state-of-the-art conversational AI model and query Kendra only for specific data.

If your program is a long-running compute job that needs several GB of memory and higher performance, Fargate is probably the better option; a Lambda invocation, by contrast, is capped at 15 minutes.

Prerequisites

  • Set up an AWS account
  • Install the latest AWS CLI
  • Install the Amplify CLI
  • Basic understanding of React
  • Basic understanding of Docker
  • Basic understanding of CloudFormation
  • Install or update the Serverless Framework to the latest version
  • Install jq (optional)

Now, let’s get started!

Creating an Amazon Kendra index

Let’s create a Kendra index. Kendra supports unstructured and semi-structured documents, such as FAQs stored in S3; we will use an FAQ in our case.

First, let’s download a QnA dataset and upload it to S3. We can use the Microsoft Research WikiQA Corpus for our chatbot.

Once the dataset is downloaded, let’s transform it into a CSV format that Kendra supports, with one question and its answer per row.

Use the following script to transform the dataset and upload the resulting CSV file to the existing S3 bucket my-kendra-index:

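import boto3
import pandas as pd


def create_faq_format(input_path):
    """Return {"Question", "Answer"} dicts for rows labelled as correct answers."""
    faq_list = []

    with open(input_path, encoding="utf-8") as f:
        lines = [line.strip('\n') for line in f]
    # Each line is assumed to be tab-separated: question, candidate sentence, label.
    # The first two lines are skipped, as in the original script.
    for line in lines[2:]:
        fields = line.split('\t')
        # Skip blank or malformed lines; keep only sentences labelled "1" (correct)
        if len(fields) >= 3 and fields[2] == "1":
            faq_list.append({"Question": fields[0], "Answer": fields[1]})

    return faq_list


qa_list = create_faq_format("WikiQACodePackage/data/wiki/WikiQASent-train.txt")
df = pd.DataFrame(qa_list, columns=["Question", "Answer"])
# Keep only the first correct answer for each question
df = df.drop_duplicates(subset='Question', keep="first")
df.to_csv('faq.csv', index=False)

# Upload the CSV to the existing bucket under the faq/ prefix
s3 = boto3.resource('s3')
s3.meta.client.upload_file("faq.csv", 'my-kendra-index', 'faq/faq.csv')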

We are now ready to create a Kendra index. To do so, complete the following steps:

  1. On the Amazon Kendra console, choose Launch Amazon Kendra.
  2. Choose Create index and enter an index name, such as my-aq-index.
  3. For IAM role, choose Create a new role to allow Amazon Kendra to access CloudWatch Logs.
  4. Create the index.

After the Kendra index has been created, we can add our FAQ document:

  1. Choose Add FAQ from the Amazon Kendra console.
  2. For S3, browse S3 to find your bucket and select the FAQ CSV file; here we use s3://my-kendra-index/faq/faq.csv.
  3. For IAM role, choose Create a new role to allow Amazon Kendra to access the FAQ content object in the S3 bucket.
  4. Add the FAQ.
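If you prefer scripting to console clicks, the two procedures above roughly correspond to Kendra’s CreateIndex and CreateFaq APIs. The sketch below assumes you already have IAM role ARNs with the permissions described in the steps; the role ARNs, the FAQ name and the choice of DEVELOPER_EDITION are placeholders or assumptions, not values from this post. The boto3 client is passed in so the helpers can be exercised without AWS credentials.

```python
import time


def create_index(kendra, name, role_arn, edition="DEVELOPER_EDITION"):
    """Start creating a Kendra index and return its id (it begins in CREATING)."""
    response = kendra.create_index(Name=name, RoleArn=role_arn, Edition=edition)
    return response["Id"]


def wait_until_active(kendra, index_id, poll_seconds=60, sleep=time.sleep):
    """Poll DescribeIndex until the index is ACTIVE; raise if creation failed."""
    while True:
        status = kendra.describe_index(Id=index_id)["Status"]
        if status == "ACTIVE":
            return
        if status == "FAILED":
            raise RuntimeError(f"Kendra index {index_id} failed to create")
        sleep(poll_seconds)


def add_faq(kendra, index_id, bucket, key, role_arn, name="wikiqa-faq"):
    """Attach the FAQ file stored at s3://bucket/key to the index; return the FAQ id."""
    response = kendra.create_faq(
        IndexId=index_id,
        Name=name,
        S3Path={"Bucket": bucket, "Key": key},
        RoleArn=role_arn,
    )
    return response["Id"]
```

With a real client this might look like `kendra = boto3.client("kendra", region_name="us-east-1")`, then `index_id = create_index(kendra, "my-aq-index", INDEX_ROLE_ARN)`, `wait_until_active(kendra, index_id)` and `add_faq(kendra, index_id, "my-kendra-index", "faq/faq.csv", FAQ_ROLE_ARN)`, where both role ARNs are your own. Index creation can take a while, hence the polling helper.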

Now that we have a working Kendra index, let’s move on to the next step.

Deploying a WebSocket API in API Gateway to process the QnA messages

In this section we will 1) build a WebSocket API in Amazon API Gateway, 2) create Lambda functions to manage the WebSocket routes ($connect, $disconnect, sendMessage), and 3) create a DynamoDB table to store WebSocket connection IDs and user names.

We will use the Serverless Framework to build and deploy all the required resources. Let’s create a new Serverless project and add the following config to serverless.yml:

service: serverless-chat

provider:
  name: aws
  runtime: nodejs12.x
  stackName: ${self:service}-${self:provider.stage}
  stage: ${opt:stage, 'dev'}
  region: ${opt:region, 'ap-southeast-2'}
  tags:
    project: chatbot
  iamRoleStatements:
    - Effect: Allow
      Action:
        - "execute-api:ManageConnections"
      Resource:
        - "arn:aws:execute-api:*:*:**/@connections/*"
    - Effect: Allow
      Action:
        - "dynamodb:PutItem"
        - "dynamodb:GetItem"
        - "dynamodb:UpdateItem"
        - "dynamodb:DeleteItem"
        - "dynamodb:Query"
        - "dynamodb:Scan"
      Resource:
        - Fn::GetAtt: [ChatConnectionsTable, Arn]
        - Fn::Join:
            - "/"
            - - Fn::GetAtt: [ChatConnectionsTable, Arn]
              - "*"
  environment:
    CHATCONNECTION_TABLE:
      Ref: ChatConnectionsTable
  websocketApiName: websocket-chat-${self:provider.stage}
  websocketApiRouteSelectionExpression: $request.body.action

functions:
  connectionManager:
    handler: handler.connectionManager
    events:
      - websocket:
          route: $connect
          authorizer:
            name: "authFunc"
            identitySource:
              - "route.request.querystring.token"
      - websocket:
          route: $disconnect
  authFunc:
    handler: handler.authFunc
    environment:
      APP_CLIENT_ID: ${ssm:/chatbot/${self:provider.stage}/app_client_id}
      USER_POOL_ID: ${ssm:/chatbot/${self:provider.stage}/user_pool_id}

  defaultMessages:
    handler: handler.defaultMessage
    events:
      - websocket:
          route: $default
  sendMessage:
    handler: handler.sendMessage
    events:
      - websocket:
          route: sendMessage

resources:
  Resources:
    ChatConnectionsTable: 
      Type: AWS::DynamoDB::Table
      Properties:
        AttributeDefinitions:
          - AttributeName: connectionId
            AttributeType: S
          - AttributeName: userid
            AttributeType: S
        KeySchema:
          - AttributeName: connectionId
            KeyType: HASH
        ProvisionedThroughput:
          ReadCapacityUnits: 5
          WriteCapacityUnits: 5
        GlobalSecondaryIndexes:
          - IndexName: userid_index
            KeySchema:
              - AttributeName: userid
                KeyType: HASH
            Projection:
              ProjectionType: ALL
            ProvisionedThroughput:
              ReadCapacityUnits: "5"
              WriteCapacityUnits: "5"
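The websocketApiRouteSelectionExpression of $request.body.action tells API Gateway to read the action field of each incoming JSON message and use it to pick a route; a message that matches no route goes to $default. The Python sketch below models that dispatch for illustration only (API Gateway does this for you, and both function names are hypothetical). The frame shape in make_send_message is inferred from handler.js, which parses data as a second JSON string and reads its to field, and from the bot code later in this post, which reads data.text; it is an assumption, not the React app’s actual code.

```python
import json

ROUTES = {"sendMessage"}  # custom routes declared in serverless.yml


def select_route(raw_message, routes=ROUTES):
    """Mimic $request.body.action: use "action" if it names a route, else $default."""
    try:
        body = json.loads(raw_message)
    except json.JSONDecodeError:
        return "$default"
    action = body.get("action") if isinstance(body, dict) else None
    return action if action in routes else "$default"


def make_send_message(text, author, to="ROBOT"):
    """Build a sendMessage frame; "data" is itself a JSON-encoded string."""
    inner = {"data": {"text": text}, "type": "text", "author": author, "to": to}
    return json.dumps({"action": "sendMessage", "data": json.dumps(inner)})
```

The double encoding matters: sendMessage forwards the inner data string unchanged with postToConnection, so the receiver gets exactly the inner JSON, while the outer object only carries the route name.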

Note that the Cognito App Client Id (/chatbot/dev/app_client_id) and Cognito User Pool Id (/chatbot/dev/user_pool_id) referenced in serverless.yml have not been created yet. We only reference the Cognito details as SSM parameters here; in the next step we will create a Cognito User Pool using the Amplify CLI, and then we can update the related SSM parameters from the AWS Systems Manager console.

Once serverless.yml has been modified, update handler.js to create the Lambda functions for the WebSocket routes: $connect (with a custom authorizer), $disconnect, and sendMessage:

"use strict";

const AWS = require("aws-sdk");
const DDB = new AWS.DynamoDB({ apiVersion: "2012-08-10" });
const jose = require("node-jose");
const fetch = require("node-fetch");
const KEYS_URL = `https://cognito-idp.${process.env.AWS_REGION}.amazonaws.com/${process.env.USER_POOL_ID}/.well-known/jwks.json`;
const successfullResponse = {
  statusCode: 200,
  body: "Connected",
};

module.exports.connectionManager = async (event, context, callback) => {
  if (event.requestContext.eventType === "CONNECT") {
    try {
      // Store the connection id against the username passed in the query string
      await addConnection(
        event.requestContext.connectionId,
        event.queryStringParameters.username
      );
      callback(null, successfullResponse);
    } catch (error) {
      callback(null, {
        statusCode: 500,
        body: "Failed to connect: " + JSON.stringify(error),
      });
    }
  } else if (event.requestContext.eventType === "DISCONNECT") {
    try {
      // Remove the connection so no more messages are routed to it
      await deleteConnection(event.requestContext.connectionId);
      callback(null, successfullResponse);
    } catch (error) {
      callback(null, {
        statusCode: 500,
        body: "Failed to disconnect: " + JSON.stringify(error),
      });
    }
  }
};

module.exports.defaultMessage = (event, context, callback) => {
  callback(null);
};

module.exports.sendMessage = async (event, context, callback) => {
  let connectionData;
  try {
    const { body } = event;
    const messageBodyObj = JSON.parse(body);
    const params = {
      IndexName: "userid_index",
      KeyConditionExpression: "userid = :u",
      ExpressionAttributeValues: {
        ":u": {
          S: JSON.parse(messageBodyObj.data).to || "ROBOT",
        },
      },
      TableName: process.env.CHATCONNECTION_TABLE,
    };
    connectionData = await DDB.query(params).promise();
  } catch (err) {
    console.log(err);
    return { statusCode: 500 };
  }
  const postCalls = connectionData.Items.map(async ({ connectionId }) => {
    try {
      return await send(event, connectionId.S);
    } catch (err) {
      if (err.statusCode === 410) {
        return await deleteConnection(connectionId.S);
      }
      console.log(JSON.stringify(err));
      throw err;
    }
  });

  try {
    await Promise.all(postCalls);
  } catch (err) {
    console.log(err);
    // Return early so the callback is not invoked a second time
    return callback(null, { statusCode: 500, body: JSON.stringify(err) });
  }
  callback(null, successfullResponse);
};

const send = (event, connectionId) => {
  const postData = JSON.parse(event.body).data;
  const apigwManagementApi = new AWS.ApiGatewayManagementApi({
    apiVersion: "2018-11-29",
    endpoint:
      event.requestContext.domainName + "/" + event.requestContext.stage,
  });
  return apigwManagementApi
    .postToConnection({ ConnectionId: connectionId, Data: postData })
    .promise();
};

const addConnection = (connectionId, userid) => {
  const putParams = {
    TableName: process.env.CHATCONNECTION_TABLE,
    Item: {
      connectionId: { S: connectionId },
      userid: { S: userid },
    },
  };

  return DDB.putItem(putParams).promise();
};

const deleteConnection = (connectionId) => {
  const deleteParams = {
    TableName: process.env.CHATCONNECTION_TABLE,
    Key: {
      connectionId: { S: connectionId },
    },
  };

  return DDB.deleteItem(deleteParams).promise();
};

module.exports.authFunc = async (event, context, callback) => {
  const {
    queryStringParameters: { token },
    methodArn,
  } = event;

  let policy;

  try {
    policy = await authCognitoToken(token, methodArn);
    callback(null, policy);
  } catch (error) {
    console.log(error);
    callback("Signature verification failed");
  }
};

const authCognitoToken = async (token, methodArn) => {
  if (!token) throw new Error("Unauthorized");
  const app_client_id = process.env.APP_CLIENT_ID;
  const sections = token.split(".");
  let authHeader = jose.util.base64url.decode(sections[0]);
  authHeader = JSON.parse(authHeader);
  const kid = authHeader.kid;
  const rawRes = await fetch(KEYS_URL);
  if (rawRes.ok) {
    const response = await rawRes.json();
    const keys = response["keys"];
    // Find the user pool public key whose kid matches the token header
    const foundKey = keys.find((key) => {
      return kid === key.kid;
    });

    if (!foundKey) {
      // callback is not in scope here; throwing makes authFunc reject the request
      throw new Error("Public key not found in jwks.json");
    }

    const jwkRes = await jose.JWK.asKey(foundKey);
    const verifyRes = await jose.JWS.createVerify(jwkRes).verify(token);
    const claims = JSON.parse(verifyRes.payload);

    const current_ts = Math.floor(new Date() / 1000);
    if (current_ts > claims.exp) {
      throw new Error("Token is expired");
    }

    if (claims.client_id != app_client_id) {
      throw new Error("Token was not issued for this audience");
    } else {
      return generatePolicy("me", "Allow", methodArn);
    }
  }
  throw new Error("Keys url is invalid");
};

const generatePolicy = function (principalId, effect, resource) {
  var authResponse = {};
  authResponse.principalId = principalId;
  if (effect && resource) {
    var policyDocument = {};
    policyDocument.Version = "2012-10-17";
    policyDocument.Statement = [];
    var statementOne = {};
    statementOne.Action = "execute-api:Invoke";
    statementOne.Effect = effect;
    statementOne.Resource = resource;
    policyDocument.Statement[0] = statementOne;
    authResponse.policyDocument = policyDocument;
  }
  return authResponse;
};

const generateAllow = function (principalId, resource) {
  return generatePolicy(principalId, "Allow", resource);
};

const generateDeny = function (principalId, resource) {
  return generatePolicy(principalId, "Deny", resource);
};
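Once the JWT signature has been verified against the matching key from jwks.json, authCognitoToken makes two claim checks: exp must not be in the past, and client_id must equal our app client. Here is a Python sketch of just those checks and of the policy shape generatePolicy returns, to make the authorizer’s contract explicit. Signature verification is deliberately left out (use a JOSE library for that), and check_claims and generate_policy are hypothetical names.

```python
import time


def check_claims(claims, app_client_id, now=None):
    """Raise ValueError unless an already-verified access token's claims are acceptable."""
    now = int(time.time()) if now is None else now
    if now > claims["exp"]:
        raise ValueError("Token is expired")
    if claims.get("client_id") != app_client_id:
        raise ValueError("Token was not issued for this audience")


def generate_policy(principal_id, effect, resource):
    """Same IAM policy document shape that generatePolicy builds in handler.js."""
    return {
        "principalId": principal_id,
        "policyDocument": {
            "Version": "2012-10-17",
            "Statement": [
                {"Action": "execute-api:Invoke", "Effect": effect, "Resource": resource}
            ],
        },
    }
```

API Gateway caches nothing here unless configured to; each $connect runs the authorizer, and an Allow policy on the route’s methodArn is what lets the connection through.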

Run the following command to deploy the WebSocket API:

$sls deploy --stage dev --region YOUR_REGION
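After deployment, the $connect route expects two query-string parameters: token, which authFunc verifies, and username, which connectionManager stores as userid in DynamoDB. A small sketch for building that URL follows; connection_url is a hypothetical helper, the endpoint in the usage note is a placeholder for the one sls deploy prints, and URL-encoding the values is an addition over plain string formatting.

```python
from urllib.parse import urlencode


def connection_url(ws_endpoint, token, username):
    """Append the token and username query parameters that $connect reads."""
    query = urlencode({"token": token, "username": username})
    separator = "&" if "?" in ws_endpoint else "?"
    return f"{ws_endpoint}{separator}{query}"
```

For a quick manual test you could pass the result to a WebSocket client such as wscat (`wscat -c "<url>"`), using an access token issued by the Cognito user pool created in the next section.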

Building a React application with AWS Amplify

In this section, we will build a web app with authentication using React and AWS Amplify.

The complete project is in the GitHub repo, where you can find the following folders in the project directory:

  • amplify/.config/, and amplify/backend/.
  • project-config.json in .config/ folder.
  • backend-config.json in backend/ folder.
  • CloudFormation files in the backend/ folder.

Let’s download the source code and re-initialise the existing Amplify project by running:

$amplify init

then push changes:

$amplify push

and deploy:

$amplify publish

We will get the web application URL after the project has been deployed:

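Now log in to the Amazon Cognito console, where you will see that a Cognito User Pool has been created. Copy the User Pool Id and App Client Id into the SSM parameters referenced in serverless.yml in the WebSocket API section. Because the Serverless Framework resolves ${ssm:...} variables at deploy time, run sls deploy again afterwards so the authorizer picks up the new values.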
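The copy into SSM can also be scripted with the Systems Manager PutParameter API. In this sketch the parameter names follow the /chatbot/<stage>/... pattern from serverless.yml, while the ids in the usage note are placeholders; store_cognito_ids is a hypothetical helper that takes an SSM client so it can be tried without AWS.

```python
def chatbot_parameter_names(stage):
    """SSM parameter names referenced by serverless.yml for a given stage."""
    return {
        "user_pool_id": f"/chatbot/{stage}/user_pool_id",
        "app_client_id": f"/chatbot/{stage}/app_client_id",
    }


def store_cognito_ids(ssm, stage, user_pool_id, app_client_id):
    """Create or overwrite both Cognito ids as plain String parameters."""
    names = chatbot_parameter_names(stage)
    values = {"user_pool_id": user_pool_id, "app_client_id": app_client_id}
    for key, name in names.items():
        ssm.put_parameter(Name=name, Value=values[key], Type="String", Overwrite=True)
    return names
```

Usage might be `store_cognito_ids(boto3.client("ssm"), "dev", "<user pool id>", "<app client id>")`. Overwrite=True makes the call safe to repeat if the pool is ever recreated.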

Now we’re ready for the final step!

Creating a chatbot service in AWS Fargate 

In this section, we will create a bot task running in a chatbot service in AWS Fargate. The chatbot task connects to the WebSocket API. When a user asks a question, the bot queries the Kendra index, Kendra surfaces a relevant answer, and the bot sends that answer back to the user who asked the question.

To deploy the Fargate service, perform the following steps:

  1. Download the chatbot script and Dockerfile here.
  2. Build the Docker image, tag it for an Amazon ECR repository, and push the image to ECR. For more details, please refer to the official AWS tutorial.
  3. Download the CloudFormation templates and bash scripts here.
  4. The Fargate launch type requires the awsvpc network mode, so we need to deploy a VPC and security groups:
     $bash create-infra.sh -d dev
  5. Create the task definition:
     $bash create-task.sh -d dev
  6. Deploy the chatbot service in AWS Fargate:
     $bash deploy-service.sh -d dev
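Once deploy-service.sh finishes, one way to confirm that the bot task is actually running is the ECS DescribeServices API. This is a sketch only: the cluster and service names come from the CloudFormation templates, which aren’t shown in this post, so the names in the usage note are placeholders, and service_is_steady is a hypothetical helper.

```python
def service_is_steady(ecs, cluster, service):
    """Return True once every desired task of the ECS service is running."""
    response = ecs.describe_services(cluster=cluster, services=[service])
    if not response["services"]:
        # DescribeServices reports unknown services under "failures" instead
        raise LookupError(f"service {service!r} not found in cluster {cluster!r}")
    svc = response["services"][0]
    return svc["desiredCount"] > 0 and svc["runningCount"] == svc["desiredCount"]
```

For example, `service_is_steady(boto3.client("ecs"), "<cluster>", "<service>")`. boto3 also ships a built-in waiter for the same purpose, `ecs.get_waiter("services_stable").wait(cluster=..., services=[...])`, which blocks until the service stabilises.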

The main logic can be found here:

import json
import logging
import os
import ssl

import boto3
import websocket  # from the websocket-client package

# basicConfig attaches a handler so INFO messages are actually printed
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
# SSM parameters are read from ap-southeast-2; Kendra is queried in us-east-1
ssm = boto3.client('ssm', region_name="ap-southeast-2")
KENDRA_INDEX_SSM = ssm.get_parameter(Name=os.environ["KENDRA_INDEX_KEY"])
kendra_index_id = KENDRA_INDEX_SSM["Parameter"]["Value"]
kendra = boto3.client('kendra', region_name='us-east-1')
# Robot credentials and Cognito ids used to obtain an access token
ROBOT_USER_SSM = ssm.get_parameter(Name=os.environ["ROBOT_USER_SSM"])
user_name = ROBOT_USER_SSM["Parameter"]["Value"]
ROBOT_PASS_SSM = ssm.get_parameter(
    Name=os.environ["ROBOT_PASS_SSM"], WithDecryption=True)
password = ROBOT_PASS_SSM["Parameter"]["Value"]
USER_POOL_SSM = ssm.get_parameter(Name=os.environ["USER_POOL_SSM"])
user_pool = USER_POOL_SSM["Parameter"]["Value"]
APP_CLIENT_SSM = ssm.get_parameter(Name=os.environ["APP_CLIENT_SSM"])
app_client = APP_CLIENT_SSM["Parameter"]["Value"]
WS_URL_SSM = ssm.get_parameter(Name=os.environ["WS_URL_KEY"])


def on_message(ws, message):
    message_obj = json.loads(message)
    result = get_answer(message_obj["data"]["text"])
    if len(result["ResultItems"]) > 0:
        logger.debug(result["ResultItems"][0]["DocumentExcerpt"]["Text"])
        answer_text = result["ResultItems"][0]["DocumentExcerpt"]["Text"]
    else:
        answer_text = "Sorry, I could not find an answer."

    ws.send(json.dumps({
            "action": "sendMessage",
            "data": json.dumps({"data": answer_text,
                                "type": "text",
                                "author": "ROBOT",
                                "to": message_obj["author"]})
            }))


def authenticate_and_get_token(username, password,
                               user_pool_id, app_client_id):
    client = boto3.client('cognito-idp')

    resp = client.admin_initiate_auth(
        UserPoolId=user_pool_id,
        ClientId=app_client_id,
        AuthFlow='ADMIN_NO_SRP_AUTH',
        AuthParameters={
            "USERNAME": username,
            "PASSWORD": password
        }
    )

    return resp['AuthenticationResult']['AccessToken']


def on_error(ws, error):
    logger.error(error)


def on_close(ws, *args):
    # websocket-client 1.x also passes close_status_code and close_msg
    logger.info("### closed ###")


def on_open(ws):
    logger.info("connected")


def get_answer(text):
    response = kendra.query(
        IndexId=kendra_index_id,
        QueryText=text,
        QueryResultTypeFilter='QUESTION_ANSWER',
    )
    return response


if __name__ == '__main__':
    access_token = authenticate_and_get_token(
        user_name, password, user_pool, app_client)

    ws_url = "{}?token={}&username=ROBOT".format(
        WS_URL_SSM["Parameter"]["Value"], access_token)
    websocket.enableTrace(False)
    ws = websocket.WebSocketApp(ws_url, on_message=on_message,
                                on_error=on_error,
                                on_close=on_close)
    ws.on_open = on_open
    # Keep the default TLS certificate verification for the API Gateway endpoint
    ws.run_forever()
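on_message does two separable things: it picks the top Kendra result and wraps the reply so that handler.js can route it back to the person who asked. Pulling that logic into pure functions, as in the sketch below, makes it testable without a live WebSocket or Kendra index. pick_answer and build_reply are hypothetical names, and the response shape is just the subset of the Kendra query response that on_message reads.

```python
import json

FALLBACK = "Sorry, I could not find an answer."


def pick_answer(kendra_response):
    """Return the excerpt of the top result, or a fallback when Kendra found nothing."""
    items = kendra_response.get("ResultItems", [])
    if not items:
        return FALLBACK
    return items[0]["DocumentExcerpt"]["Text"]


def build_reply(answer_text, asker):
    """Wrap the answer in a sendMessage frame addressed back to the asker."""
    inner = {"data": answer_text, "type": "text", "author": "ROBOT", "to": asker}
    return json.dumps({"action": "sendMessage", "data": json.dumps(inner)})
```

Inside on_message this would reduce to `ws.send(build_reply(pick_answer(get_answer(text)), message_obj["author"]))`.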

(Optional) Extending Chatbot with ConvAI model

To extend the chatbot with a ConvAI model, you can try the sample script below. Note that you will need to put more effort into training the model and packaging it in the Docker image or on EFS.

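import json
import logging
import random
import warnings
import datetime

import boto3
import flask
import torch
import torch.nn.functional as F
# from requests_aws4auth import AWS4Auth
from simpletransformers.conv_ai import ConvAIModel
# get_dataset, random and warnings are used by interact() and sample_sequence()
from simpletransformers.conv_ai.conv_ai_utils import get_dataset
from flask import request, Response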

app = flask.Flask(__name__)
region = 'ap-southeast-2'
# ssm = boto3.client('ssm', region_name=region)
# credentials = boto3.Session().get_credentials()
# awsauth = AWS4Auth(credentials.access_key, credentials.secret_key,
#                    region, service, session_token=credentials.token)
dynamodb = boto3.client('dynamodb')
polly_client = boto3.Session(region_name=region).client('polly')
s3 = boto3.resource('s3')
BUCKET_NAME = "aiyi.demo.textract"
TABLE_NAME = "serverless-chat-dev-ChatHistoryTable-M0BPSVMQJBFX"
SPECIAL_TOKENS = ["<bos>", "<eos>", "<speaker1>", "<speaker2>", "<pad>"]
history = []
convAimodel = ConvAIModel("gpt", "model", use_cuda=False)
character = [
    "i like computers .",
    "i like reading books .",
    "i like talking to chatbots .",
    "i love listening to classical music ."
]


def text_2_speech(userid, response_msg):
    response = polly_client.synthesize_speech(VoiceId='Joanna',
                                              OutputFormat='mp3',
                                              Text=response_msg)
    object_key = "{}/{}/speech.mp3".format(userid,
                                           int(datetime.datetime.utcnow().timestamp()))
    object = s3.Object(
        BUCKET_NAME, object_key)
    object.put(Body=response['AudioStream'].read())
    return object_key


def get_chat_histories(userid):
    response = dynamodb.get_item(TableName=TABLE_NAME, Key={
        'userid': {
            'S': userid
        }})

    if 'Item' in response:
        return json.loads(response["Item"]["history"]["S"])
    return {"history": []}


def save_chat_histories(userid, history):
    return dynamodb.put_item(TableName=TABLE_NAME, Item={'userid': {'S': userid}, 'history': {'S': history}})


def sample_sequence(aiCls, personality, history, tokenizer, model, args, current_output=None):
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    if current_output is None:
        current_output = []

    for i in range(args["max_length"]):
        instance = aiCls.build_input_from_segments(
            personality, history, current_output, tokenizer, with_eos=False)

        input_ids = torch.tensor(
            instance["input_ids"], device=aiCls.device).unsqueeze(0)
        token_type_ids = torch.tensor(
            instance["token_type_ids"], device=aiCls.device).unsqueeze(0)

        logits = model(input_ids, token_type_ids=token_type_ids)
        if isinstance(logits, tuple):  # for gpt2 and maybe others
            logits = logits[0]
        logits = logits[0, -1, :] / args["temperature"]
        logits = aiCls.top_filtering(
            logits, top_k=args["top_k"], top_p=args["top_p"])
        probs = F.softmax(logits, dim=-1)

        prev = torch.topk(probs, 1)[
            1] if args["no_sample"] else torch.multinomial(probs, 1)
        if i < args["min_length"] and prev.item() in special_tokens_ids:
            while prev.item() in special_tokens_ids:
                if probs.max().item() == 1:
                    warnings.warn(
                        "Warning: model generating special token with probability 1.")
                    break  # avoid infinitely looping over special token
                prev = torch.multinomial(probs, num_samples=1)

        if prev.item() in special_tokens_ids:
            break
        current_output.append(prev.item())

    return current_output


def interact(raw_text, model, personality, userid, history):
    """
    Generate a reply to raw_text and persist the updated chat history.
    Args:
        personality: A list of sentences that the model will use to build a personality.
    Returns:
        The decoded reply text.
    """
    args = model.args
    tokenizer = model.tokenizer
    process_count = model.args["process_count"]

    model._move_model_to_device()

    if not personality:
        dataset = get_dataset(
            tokenizer,
            None,
            args["cache_dir"],
            process_count=process_count,
            proxies=model.__dict__.get("proxies", None),
            interact=True,
        )
        personalities = [dialog["personality"]
                         for dataset in dataset.values() for dialog in dataset]
        personality = random.choice(personalities)
    else:
        personality = [tokenizer.encode(s.lower()) for s in personality]

    history.append(tokenizer.encode(raw_text))
    with torch.no_grad():
        out_ids = sample_sequence(
            model, personality, history, tokenizer, model.model, args)
    history.append(out_ids)
    history = history[-(2 * args["max_history"] + 1):]
    out_text = tokenizer.decode(out_ids, skip_special_tokens=True)
    save_chat_histories(userid, json.dumps({"history": history}))
    return out_text


@app.route('/message-received', methods=['POST'])
def process_chat_message():
    userid = request.form.get('userid')
    if not userid:
        # A missing userid is a client error
        return Response("userid is required", status=400)

    try:
        message = request.form['message']
        history = get_chat_histories(userid)["history"]
        response_msg = interact(message, convAimodel,
                                character, userid, history)
        # Synthesize the reply with Polly and return the S3 key of the audio
        audio_key = text_2_speech(userid, response_msg)
        return Response(
            json.dumps({"message": response_msg, "audio": audio_key}),
            status=200, mimetype='application/json')
    except Exception as ex:
        logging.exception(ex)
        # Python 3 exceptions have no .message attribute
        return Response(str(ex), status=500)


if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=8000)
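The history handling in interact is easy to miss: each turn appends the user’s message and the bot’s reply (both as token-id lists), then keeps only the most recent 2 * max_history + 1 entries before saving the JSON to DynamoDB. The sketch below isolates that bookkeeping so the window is easy to reason about; record_turn and trim_history are hypothetical names, and small integers stand in for real token ids.

```python
import json


def trim_history(history, max_history):
    """Keep the most recent 2 * max_history + 1 entries, as interact() does."""
    return history[-(2 * max_history + 1):]


def record_turn(history, user_ids, bot_ids, max_history):
    """Append one exchange, trim it, and return (history, JSON string to store)."""
    history = trim_history(history + [user_ids, bot_ids], max_history)
    return history, json.dumps({"history": history})
```

With max_history = 1, two exchanges [1], [2] and [3], [4] leave [[2], [3], [4]]: the latest exchange plus the previous bot reply, which bounds both the model’s context length and the DynamoDB item size.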

Once the service has been deployed, we should be able to ask the bot questions. Let’s visit the React application and try it out live!

As you can see, the results above show that even when only keywords are used, the system can respond with the correct answer.

If you would like to learn more about Amazon Kendra, there is an official tutorial on building a chatbot using Amazon Lex and Kendra; for details, please refer to this link.

I hope you have found this article useful. The source code for this post can be found in my GitHub repo.