Creating a Middleware in Golang for JWT based Authentication


Aswin G (@agzuniverse)

CS undergrad student with a love for tech. Currently hacking with JavaScript, Python and Golang.

Golang has become a popular language over the past few years. It is known for its simplicity, its great out-of-the-box support for building web applications, and its support for concurrency-heavy processing. Similarly, JWTs (JSON Web Tokens) are an increasingly popular way of authenticating users. In this post I'll go over how to create an authentication middleware in Golang that restricts certain parts of your web app to authenticated users.

You can also apply the same pattern to build middlewares that do anything else you need.
If you aren't familiar with JWTs, https://jwt.io/ is a great resource to get familiar with them. Basically, a JWT is a token, usually included in the Authorization header of an HTTP request, that can be used to verify the user making the request. Most forms of token authentication require a database read to check that the token belongs to an active, authenticated user. A JWT instead carries a signature computed over its contents. If that signature verifies against the expected key, the token was issued by someone holding the key and has not been corrupted or manipulated since. You can also decide what data to encode in the JWT body, so a verified JWT also gives you useful data, such as a user's username or email. Keep in mind that the body is only encoded, not encrypted: anyone holding the token can read it, so don't put secrets in it.
The scope of this article is limited to creating a middleware in Golang that checks the validity of the JWT in an incoming request. Generating JWTs is a separate process, and I won't describe it here.

Setting up the web app

I'm using a simple Golang web app that provides a single API endpoint, /ping, which responds with "pong". The code below uses only the standard library; no third-party libraries are needed.
package main

import (
	"log"
	"net/http"
)
func pong(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusOK)
	w.Write([]byte("pong"))
}

func main() {
	http.Handle("/ping", http.HandlerFunc(pong))
	log.Fatal(http.ListenAndServe(":8080", nil))
}
Here's a quick breakdown of what this does:
The main function registers pong as the handler for the /ping endpoint, then starts an HTTP server on port 8080. The function pong simply responds with status 200 ("OK") and the text "pong".
Handle requires its second argument to be a value that implements the Handler interface, so we wrap pong in HandlerFunc, an adapter type that does exactly that. More on this later.
So now, if you run this with go run main.go (assuming the file is named main.go) and send a GET request to http://localhost:8080/ping (opening this link in your browser is an easy way), you'll get back the text pong.
Now I want to protect the /ping endpoint so that only incoming requests carrying a valid JWT get the response. If the JWT is missing or invalid, the app should return HTTP status code 401 Unauthorized.

Creating a middleware

A middleware for our HTTP server is a function that takes a value implementing the http.Handler interface and returns a new value that also implements http.Handler.
This concept should be familiar to anyone who has used JavaScript closures, Python decorators or functional programming of some kind. Essentially, a middleware takes a handler and returns a new handler that wraps it with additional behavior.
But wait: the handler function we just wrote, pong, does not implement this interface, so we have to take care of that. If you look in Go's built-in net/http package, you'll find that the Handler interface only requires a ServeHTTP(ResponseWriter, *Request) method. In fact, we don't even have to implement this ourselves. The http package provides a handy adapter type called HandlerFunc. Any function that accepts a ResponseWriter and a *Request can be converted to HandlerFunc, and the result implements the Handler interface because HandlerFunc's ServeHTTP method simply calls the function.
At this point it's easy to confuse these four names in the http package, so let me clarify them:
http.Handler - An interface with a single method, ServeHTTP.
http.HandlerFunc - A function type used to convert pong (or any other function with the same signature) into a value that implements the Handler interface.
http.HandleFunc - A function that registers a plain handler function for an endpoint and converts it to a Handler for you. We haven't used it here.
http.Handle - Like HandleFunc, except it does NOT do the conversion: its second argument must already implement the Handler interface.
Let's create the structure of a basic middleware:
func middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Do stuff
		next.ServeHTTP(w, r)
	})
}
As you can see, we receive the request and can do anything we want with it before calling the handler we created earlier, which is passed to our middleware as next. We use HandlerFunc again to convert the anonymous function we return into a value that implements the Handler interface.
Now your middleware can do whatever you want with the incoming request. What I want to do is validate the JWT in the Authorization header of the request, so I'll start with the following changes:
		authHeader := strings.Split(r.Header.Get("Authorization"), "Bearer ")
		if len(authHeader) != 2 {
			fmt.Println("Malformed token")
			w.WriteHeader(http.StatusUnauthorized)
			w.Write([]byte("Malformed Token"))
		}
This extracts the JWT from the Authorization header. If the header is missing or isn't in the expected format, the middleware responds with 401 Unauthorized and never calls our handler function.
The expected format of the header is Bearer <token>. Note that you can pass the JWT in the request any way you want, but this Bearer scheme (defined in RFC 6750) is the widely accepted way of doing it.
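Splitting on "Bearer " is a little permissive. A header such as "xBearer abc" also yields two parts, and "Bearer " with nothing after it yields an empty token. jwt.Parse would still reject an empty or garbage token, so this is about clearer errors, not a security hole. If you want stricter parsing, a minimal sketch might look like the one below. bearerToken is a hypothetical helper, not part of jwt-go or net/http. It compares the scheme name case-insensitively, as RFC 7235 specifies for authentication schemes.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// bearerToken is a hypothetical helper that extracts the token from an
// Authorization header of the form "Bearer <token>". The scheme name is
// compared case-insensitively, and an empty token is rejected.
func bearerToken(header string) (string, error) {
	const prefix = "Bearer "
	if len(header) < len(prefix) || !strings.EqualFold(header[:len(prefix)], prefix) {
		return "", errors.New(`authorization header must start with "Bearer "`)
	}
	token := strings.TrimSpace(header[len(prefix):])
	if token == "" {
		return "", errors.New("empty bearer token")
	}
	return token, nil
}

func main() {
	headers := []string{"Bearer abc.def.ghi", "bearer abc.def.ghi", "Bearer ", "Token abc", "xBearer abc", ""}
	for _, h := range headers {
		tok, err := bearerToken(h)
		fmt.Printf("%q -> %q, %v\n", h, tok, err)
	}
}
```

In the middleware, bearerToken(r.Header.Get("Authorization")) could replace the strings.Split call, with a non-nil error leading to the same 401 response.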
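If the token is present, we need to decode and verify it. For this I'm using the jwt-go library, which can be installed with:
go get github.com/dgrijalva/jwt-go
Note that this repository has since been archived. Its community-maintained fork, github.com/golang-jwt/jwt, is largely API-compatible (including the Parse, MapClaims and SigningMethodHMAC names used below), so switching is mostly a matter of changing the import path.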
else {
			jwtToken := authHeader[1]
			token, err := jwt.Parse(jwtToken, func(token *jwt.Token) (interface{}, error) {
				// Only hand out the key for HMAC-signed tokens.
				if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
					return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
				}
				return []byte(SECRETKEY), nil
			})

			// Check err before using token: for a malformed token jwt.Parse
			// can return a nil token, and reading token.Claims would panic.
			if err == nil && token.Valid {
				claims := token.Claims.(jwt.MapClaims) // jwt.Parse always stores MapClaims
				ctx := context.WithValue(r.Context(), "props", claims)
				// Access context values in handlers like this
				// props, _ := r.Context().Value("props").(jwt.MapClaims)
				next.ServeHTTP(w, r.WithContext(ctx))
			} else {
				fmt.Println(err)
				w.WriteHeader(http.StatusUnauthorized)
				w.Write([]byte("Unauthorized"))
			}
		}
We use jwt.Parse to decode and verify our token. Its second argument is a key function, which jwt-go calls with the parsed but not yet verified token. Our key function checks that the token's signing method is HMAC and then returns the secret key used to verify the signature. The check matters because JWTs can be signed in many ways, including asymmetric algorithms that use a public/private key pair. Here I'm using a single shared secret key, which must be the same key the issuer used to sign the JWT. If the token was signed some other way, its signing method won't be HMAC, so we reject it instead of verifying it with an algorithm we didn't expect.
Note that you have to replace SECRETKEY with your own secret key, which should be a string.
Then we obtain the claims from the token. The claims are whatever values were encoded in the JWT. If parsing or verification fails, the token is malformed, has been tampered with, or (when it carries time-based claims such as exp) has expired, and we return an unauthorized status.
If verification succeeded, we create a new context, ctx, that holds these claims, and attach it to the request with WithContext. This lets us access the claims in our handler function (or in other middlewares, if we chain several), as shown in the commented-out lines. "props" here is just a key I chose. You can use any key, as long as you use the same one when reading data from the context. A typical use case is to get the user ID from the claims in the handler function and use it to fetch that user's information from the database.
Finally, we change our earlier code to wrap the handler function in the middleware:
http.Handle("/ping", middleware(http.HandlerFunc(pong)))
The complete code looks like this:
package main

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"strings"

	"github.com/dgrijalva/jwt-go"
)

// SECRETKEY must match the key used to sign the JWTs. Replace this
// placeholder with your own secret, ideally loaded from configuration.
const SECRETKEY = "replace-with-your-secret"

func middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Expect a header of the form "Authorization: Bearer <token>".
		authHeader := strings.Split(r.Header.Get("Authorization"), "Bearer ")
		if len(authHeader) != 2 {
			fmt.Println("Malformed token")
			w.WriteHeader(http.StatusUnauthorized)
			w.Write([]byte("Malformed Token"))
		} else {
			jwtToken := authHeader[1]
			token, err := jwt.Parse(jwtToken, func(token *jwt.Token) (interface{}, error) {
				// Only hand out the key for HMAC-signed tokens.
				if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
					return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
				}
				return []byte(SECRETKEY), nil
			})

			// Check err before using token: for a malformed token jwt.Parse
			// can return a nil token, and reading token.Claims would panic.
			if err == nil && token.Valid {
				claims := token.Claims.(jwt.MapClaims) // jwt.Parse always stores MapClaims
				ctx := context.WithValue(r.Context(), "props", claims)
				// Access context values in handlers like this
				// props, _ := r.Context().Value("props").(jwt.MapClaims)
				next.ServeHTTP(w, r.WithContext(ctx))
			} else {
				fmt.Println(err)
				w.WriteHeader(http.StatusUnauthorized)
				w.Write([]byte("Unauthorized"))
			}
		}
	})
}

func pong(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusOK)
	w.Write([]byte("pong"))
}

func main() {
	http.Handle("/ping", middleware(http.HandlerFunc(pong)))
	log.Fatal(http.ListenAndServe(":8080", nil))
}
You can use the same approach to create multiple middlewares and apply them to the endpoints that need them.
I hope you found this post helpful.
You can find me on Twitter and LinkedIn.
