Comparison of Deep Learning frameworks

This article shows how to train a model on the CIFAR-10 dataset with different frameworks (FastAI, JAX, Keras, MXNet, PaddlePaddle, Pytorch and Pytorch-lightning) in order to compare them in terms of:

  • ease of implementation (user-friendly coding, ease of finding information online, etc.),
  • time per epoch for the same model and the same training parameters,
  • memory and GPU usage (measured with pytorch-profiler),
  • accuracy obtained after the same training.

The code can be found here.

Table of contents

  1. Introduction
  2. Coding overview
    1. FastAI
    2. JAX
    3. Keras
    4. MXNet
    5. PaddlePaddle
    6. Pytorch
    7. Pytorch-lightning
  3. Training results
  4. Comparison summary
  5. How to analyse Memory/GPU usage with pytorch profiler
  6. Conclusion

1) Introduction

This article gives an overview of the differences between the best-known Python frameworks for Deep Learning.

The code is available on GitHub and can also be used as a working starting point for launching any kind of training.
First, we’ll look at the key points in terms of programming, then go into more detail on training time per epoch, convergence speed and memory/GPU usage.

2) Coding overview

2-1) FastAI

FastAI lives up to its name. Coding a training session is extremely easy: two lines to load the dataset with its train/test split and three lines to launch the training and evaluate the model. It is certainly the most user-friendly framework in terms of programming.
It also offers many useful functions, such as lr_find() to find a suitable learning rate. The documentation is very well done and the community seems quite large.

Example:

def get_data(batch_size=128) -> DataLoaders:
    """Get DataLoaders for CIFAR-10 dataset.

    Returns:
       CIFAR10 dataloader
    """
    path = untar_data(URLs.CIFAR)
    return ImageDataLoaders.from_folder(path, train="train", valid="test", bs=batch_size)

def run_training(
    dataloader: DataLoaders,
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with FASTAI frameworks.

    Returns:
        validation accuracy
    """
    model = CustomCNN()
    learn = Learner(dataloader, model, metrics=accuracy)

    learn.fit(epochs, lr=learning_rate)

    return learn.validate()[1]

2-2) JAX

JAX, on the other hand, required a lot more effort to develop this training code, in particular to get all the data into the right format. I am no JAX expert, but it also seemed necessary to implement by hand the functions for evaluating predictions, the cross-entropy loss and the training loop. All this can lead to implementation errors quite quickly, which makes the framework less approachable for the uninitiated. The community also didn’t seem very large to me, as I couldn’t find answers to my debugging questions.

Example:

def CrossEntropyLoss(
    weights: list,
    input_data: jax.Array,
    targets: jax.Array,
    model: CustomCNN,
) -> jax.Array:
    """Cross-entropy loss computed from one-hot targets.

    Args:
        weights: list from _, _, opt_get_weights = optimizers.adam(lr), opt_get_weights(opt_state)
        input_data: data to predict
        targets: groundtruth targets in one hot encoding
        model: model with conv_apply var

    Returns:
        loss value
    """
    preds = model.conv_apply(weights, input_data)
    log_preds = jnp.log(preds + tf.keras.backend.epsilon())
    return -jnp.mean(targets * log_preds)


def run_training(
    dataloader: Union[tf.Tensor, tf.data.Dataset],
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with JAX frameworks.

    Returns:
        validation accuracy
    """
    model = CustomCNN()
    train_data, test_data = dataloader["train"], dataloader["test"]
    X_train, Y_train = train_data["image"], train_data["label"]
    X_test, Y_test = test_data["image"], test_data["label"]

    X_train, X_test, Y_train, Y_test = (
        jnp.array(X_train, dtype=jnp.float32),
        jnp.array(X_test, dtype=jnp.float32),
        jnp.array(Y_train, dtype=jnp.float32),
        jnp.array(Y_test, dtype=jnp.float32),
    )
    rng = jax.random.PRNGKey(123)
    weights = model.conv_init(rng, (18, 32, 32, 3))[1]
    opt_init, opt_update, opt_get_weights = optimizers.adam(learning_rate)
    opt_state = opt_init(weights)
    Y_train_one_hot = jax.nn.one_hot(Y_train, num_classes=10)

    for i in range(epochs):
        batches = jnp.arange((X_train.shape[0] // batch_size) + 1)
        progress_bar = tqdm(batches, position=0, leave=True)

        losses = []
        for batch in batches:
            if batch != batches[-1]:
                start, end = int(batch * batch_size), int(batch * batch_size + batch_size)
            else:
                start, end = int(batch * batch_size), None

            X_batch, Y_batch = X_train[start:end], Y_train_one_hot[start:end]

            loss, gradients = value_and_grad(CrossEntropyLoss)(
                opt_get_weights(opt_state),
                X_batch,
                Y_batch,
                model=model,
            )

            ## Update weights (the optimizer expects a global step index, not the epoch index)
            opt_state = opt_update(i * len(batches) + int(batch), gradients, opt_state)

            losses.append(loss)

            progress_bar.set_description(f"Epoch {i+1}/{epochs}")
            progress_bar.set_postfix(train_loss=jnp.round(jnp.array(losses).mean(), decimals=3))
            progress_bar.update()

    test_preds = MakePredictions(
        opt_get_weights(opt_state),
        X_test,
        batch_size=batch_size,
        model=model,
    )

    ## Combine predictions of all batches
    test_preds = jnp.concatenate(test_preds).squeeze()

    test_preds = jnp.argmax(test_preds, axis=1)
    return accuracy_score(Y_test, test_preds)
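
One subtlety of a hand-written loss like the one above is its scale: taking the mean over a (batch, classes) array divides by the number of classes as well as by the batch size. The sketch below uses NumPy as a stand-in for jax.numpy to compare this with the usual per-sample cross-entropy. The function names and the toy values are assumptions made for this illustration, not part of the article's code.

```python
import numpy as np

EPS = 1e-7  # small constant to avoid log(0), same role as the epsilon in the loss above


def cross_entropy_mean_all(probs: np.ndarray, one_hot: np.ndarray) -> float:
    """Average over every (sample, class) entry of the 2-D array."""
    return float(-np.mean(one_hot * np.log(probs + EPS)))


def cross_entropy_per_sample(probs: np.ndarray, one_hot: np.ndarray) -> float:
    """Sum over classes first, then average over samples (the usual definition)."""
    return float(-np.mean(np.sum(one_hot * np.log(probs + EPS), axis=1)))


# Toy batch: 2 samples, 3 classes (hypothetical values)
probs = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
one_hot = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

loss_all = cross_entropy_mean_all(probs, one_hot)
loss_std = cross_entropy_per_sample(probs, one_hot)
num_classes = probs.shape[1]
print(loss_all, loss_std)
```

The two values differ only by a constant factor (the number of classes), so they share the same minimiser, but the loss values are not directly comparable with those printed by frameworks that use the per-sample definition.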

2-3) Keras

I think Keras needs no introduction: it is very easy to use, the code is short and intuitive, the customization possibilities are extensive and the community is very large. It’s hard to find fault with it. It is also easy to deploy, with lots of tooling available, which makes it a framework of choice for businesses.

Example:

def get_data(batch_size=128) -> dict:
    """Get DataLoaders for CIFAR-10 dataset.

    Returns:
       CIFAR10 dataloader
    """
    # Load the CIFAR-10 dataset
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
    return {"train_data": (x_train, y_train), "test_data": (x_test, y_test)}

def run_training(
    dataloader: dict,
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with keras frameworks.

    Returns:
        validation accuracy
    """
    (x_train, y_train), (x_test, y_test) = dataloader["train_data"], dataloader["test_data"]

    # Compile the model
    model = CustomCNN()
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    # Train the model
    model.fit(
        x_train,
        y_train,
        epochs=epochs,
        validation_data=(x_test, y_test),
        batch_size=batch_size,
    )

    # Evaluate the model
    _, test_acc = model.evaluate(x_test, y_test)
    return test_acc

2-4) MXNet

MXNet is easy to use. It comes with many built-in building blocks for creating your own dataloader, importing losses, setting up an optimizer, etc. The training loop, however, seems to have to be written by hand. As I didn’t run into any particular problems, I didn’t need to debug it, but there does seem to be an online community. The framework closest to it seems to me to be Pytorch.

Example:

def get_data(batch_size=128) -> dict:
    """Get DataLoaders for CIFAR-10 dataset.

    Returns:
       CIFAR10 dataloader
    """
    # Load and transform the CIFAR10 data
    transform_train = transforms.Compose(
        [
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
        ],
    )
    transform_test = transforms.Compose(
        [
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
        ],
    )
    train_data = datasets.CIFAR10(train=True).transform_first(transform_train)
    test_data = datasets.CIFAR10(train=False).transform_first(transform_test)
    return {"train_data": train_data, "test_data": test_data}

def run_training(
    dataloader: dict,
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with MXNET frameworks.

    Returns:
        validation accuracy
    """
    net = CustomCNN().net
    train_data, test_data = dataloader["train_data"], dataloader["test_data"]
    # Define the loss function and optimizer
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), "adam", {"learning_rate": learning_rate})

    # Train the network (batch_size comes from the function argument)
    train_loader = gluon.data.DataLoader(train_data, batch_size, shuffle=True)
    test_loader = gluon.data.DataLoader(test_data, batch_size)
    for epoch in range(epochs):
        progress_bar = tqdm(train_loader, position=0, leave=True)
        train_loss, train_acc, n = 0.0, 0.0, 0
        for X, y in train_loader:
            X, y = X.as_in_context(ctx), y.as_in_context(ctx)
            with autograd.record():
                y_hat = net(X)
                loss = softmax_cross_entropy(y_hat, y)
            loss.backward()
            trainer.step(batch_size)
            current_loss = nd.sum(loss).asscalar()
            train_loss += current_loss
            current_acc = nd.sum(
                ndarray.cast(y_hat.argmax(axis=1), dtype="int32") == y,
            ).asscalar()
            train_acc += current_acc
            n += y.size

            progress_bar.set_description(f"Epoch {epoch+1}/{epochs}")
            progress_bar.set_postfix(train_loss=round(current_loss, 3), train_acc=round(train_acc / n, 3))
            progress_bar.update()

    test_acc = 0.0
    for X, y in test_loader:
        X, y = X.as_in_context(ctx), y.as_in_context(ctx)
        y_hat = net(X)
        current_acc = nd.sum(
            ndarray.cast(y_hat.argmax(axis=1), dtype="int32") == y,
        ).asscalar()
        test_acc += current_acc

    test_acc /= len(test_data)
    return test_acc
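
The accuracy bookkeeping in the two loops above (count the correct predictions, then divide by the number of samples seen) can be isolated in a small helper. This is a framework-agnostic sketch in plain Python: the class name RunningAccuracy and its methods are hypothetical and not part of MXNet.

```python
from typing import Sequence


class RunningAccuracy:
    """Accumulate correct/total counts across batches (hypothetical helper)."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, predicted: Sequence[int], target: Sequence[int]) -> None:
        """Add one batch of predicted and true class indices."""
        if len(predicted) != len(target):
            raise ValueError("predicted and target must have the same length")
        self.correct += sum(int(p == t) for p, t in zip(predicted, target))
        self.total += len(target)

    def value(self) -> float:
        """Fraction of correct predictions so far (0.0 before any update)."""
        return self.correct / self.total if self.total else 0.0


meter = RunningAccuracy()
meter.update([3, 1, 0, 2], [3, 1, 1, 2])  # 3 correct out of 4
meter.update([5, 5], [5, 4])              # 1 correct out of 2
print(f"accuracy = {meter.value():.3f}")
```

Reporting the ratio rather than the raw count of correct predictions keeps the number comparable across batch sizes.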

2-5) PaddlePaddle

PaddlePaddle surprised me a little at first. It’s probably because I’m more used to Keras or Pytorch, but the framework didn’t seem very intuitive. Searching online also returned few useful results when I needed to debug. In particular, Colab seems to crash when paddle.fluid.CUDAPlace is used instead of paddle.fluid.CPUPlace in order to run on a GPU.

Example:

def get_data(batch_size=128) -> dict:
    """Get DataLoaders for CIFAR-10 dataset.

    Returns:
       CIFAR10 dataloader
    """
    # Load and transform the CIFAR10 data
    # Each batch will yield 128 images
    buf_size = 50000
    # Reader for training
    train_reader = paddle.batch(
        paddle.reader.shuffle(paddle.dataset.cifar.train10(), buf_size=buf_size),
        batch_size=batch_size,
    )

    # Reader for testing. A separated data set for testing.
    test_reader = paddle.batch(paddle.dataset.cifar.test10(), batch_size=batch_size)
    return {"train_reader": train_reader, "test_reader": test_reader, "buf_size": buf_size}

...

def inference_program():
    # The image is 32 * 32 with RGB representation.
    data_shape = [3, 32, 32]
    images = paddle.fluid.layers.data(name="pixel", shape=data_shape, dtype="float32")

    predict = custom_cnn(images)
    # predict = custom_cnn(images) # un-comment to use vgg net
    return predict


def train_program():
    predict = inference_program()

    label = paddle.fluid.layers.data(name="label", shape=[1], dtype="int64")
    cost = paddle.fluid.layers.cross_entropy(input=predict, label=label)
    avg_cost = paddle.fluid.layers.mean(cost)
    accuracy = paddle.fluid.layers.accuracy(input=predict, label=label)
    return [avg_cost, accuracy, predict]

def run_training(
    dataloader: dict,
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with paddlepaddle frameworks.

    Returns:
        validation accuracy
    """
    ...

    main_program = paddle.fluid.default_main_program()
    star_program = paddle.fluid.default_startup_program()

    avg_cost, acc, predict = train_program()

    # Test program
    test_program = main_program.clone(for_test=True)

    optimizer = optimizer_program(learning_rate=learning_rate)
    optimizer.minimize(avg_cost)

    exe = paddle.fluid.Executor(place)
    params_dirname = "image_classification_resnet.inference.model"

    feed_var_list_loop = [main_program.global_block().var(var_name) for var_name in feed_order]
    feeder = paddle.fluid.DataFeeder(feed_list=feed_var_list_loop, place=place)
    exe.run(star_program)

    train_reader, test_reader, buf_size = (
        dataloader["train_reader"],
        dataloader["test_reader"],
        dataloader["buf_size"],
    )
    for pass_id in range(epochs):
        progress_bar = tqdm(
            train_reader(),
            position=0,
            leave=True,
            total=round(buf_size / batch_size),
        )
        for data_train in train_reader():
            avg_loss_value = exe.run(
                main_program,
                feed=feeder.feed(data_train),
                fetch_list=[avg_cost, acc],
            )

            progress_bar.set_description(f"Epoch {pass_id+1}/{epochs}")
            progress_bar.set_postfix(train_loss=round(avg_loss_value[0], 3))
            progress_bar.update()

        _, accuracy_test = train_test(test_program, test_reader, feed_order, place, avg_cost, acc)

        # save parameters
        if params_dirname is not None:
            paddle.fluid.io.save_inference_model(params_dirname, ["pixel"], [predict], exe)
    return accuracy_test
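
The progress bar above estimates the number of batches per epoch with round(buf_size / batch_size). A small plain-Python sketch of the same arithmetic, assuming the reader also yields the final partial batch (the helper name num_batches and its drop_last flag are hypothetical):

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches per epoch, with or without the final partial batch."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


cifar_batches = num_batches(50_000, 128)                  # CIFAR-10 training set
cifar_full_batches = num_batches(50_000, 128, drop_last=True)
print(cifar_batches, cifar_full_batches)

# Rounding can under-count when the remainder is small
print(round(1_000 / 300), num_batches(1_000, 300))
```

For CIFAR-10 with a batch size of 128, rounding and rounding up give the same total, so the progress bar above is correct for these settings.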

2-6) Pytorch

Pytorch should also need no introduction. Like Keras, its code is fairly intuitive, the community is very large and it can be customized almost without limit. However, it may require more code to run a training, which also leaves more room for coding errors. On the plus side, many research papers use it, so many state-of-the-art models are directly available in Pytorch.

Example:

def run_training(
    dataloader: dict,
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with pytorch frameworks.

    Returns:
        validation accuracy
    """
    trainloader, testloader = dataloader["trainloader"], dataloader["testloader"]
    classes_count = 10
    # Read a single batch to infer the input size instead of iterating over the whole loader
    input_data, _ = next(iter(trainloader))
    input_size = input_data.shape

    net = Net(input_size[2:4], classes_count).to(device)

    # Define a loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    # Train the network
    for epoch in range(epochs):
        progress_bar = tqdm(trainloader, position=0, leave=True)
        running_loss = 0.0
        for data in trainloader:
            input_data, labels = data

            optimizer.zero_grad()

            outputs = net(input_data.to(device))
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            progress_bar.set_description(f"Epoch {epoch+1}/{epochs}")
            progress_bar.set_postfix(train_loss=round(loss.item(), 3))
            progress_bar.update()

    predictions = []
    targets = []
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images.to(device))
            predictions.append(outputs.cpu())  # back to CPU so accuracy_score can convert to NumPy
            targets.append(labels)

    predictions = torch.argmax(torch.cat(predictions, dim=0), dim=1)
    targets = torch.cat(targets, dim=0)
    return accuracy_score(targets, predictions)

2-7) Pytorch-lightning

Pytorch-lightning is less well known, but I think it can be summed up simply: it is to Pytorch what Keras is to Tensorflow. It serves as a wrapper around Pytorch, giving access to more concise code thanks to many pre-implemented building blocks. Its community seems to be growing and is already quite large. Several well-known tools have decided to support it, including MLFlow (an open source platform for the machine learning lifecycle), which just needs an “autolog()” call at the top of the script to work with pytorch-lightning.

Example:

class CNNModel(L.LightningModule):
    def __init__(
        self,
        batch_size: int = 128,
        lr: float = 0.0001,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.lr = lr
        # Set our init args as class attributes
        self.data_dir = os.environ.get("PATH_DATASETS", ".")

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (3, 32, 32)
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ],
        )

        num_classes = 10
        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

        self.train_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.train_accuracy.update(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_accuracy, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy.update(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_accuracy.update(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            dataset = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.train_db, self.val_db = random_split(dataset, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.test_db = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_db, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_db, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_db, batch_size=self.batch_size)


def run_training(
    dataloader: dict,
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with pytorch_lightning frameworks.

    Returns:
        validation accuracy
    """
    model = CNNModel(batch_size=batch_size, lr=learning_rate)
    trainer = L.Trainer(
        accelerator="auto",
        devices=1,
        max_epochs=epochs,
        logger=CSVLogger(save_dir="logs/"),
    )
    trainer.fit(model)
    return trainer.test()[0]["test_acc"]
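
The 32 * 8 * 8 input size of the first nn.Linear layer follows from the layer stack: the padded 3x3 convolutions keep the 32x32 resolution and each 2x2 max-pooling halves it. A small plain-Python check using the standard output-size formula for convolution and pooling layers (the helper name out_size is an assumption made for this illustration):

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv/pool layer: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                     # CIFAR-10 images are 32x32
size = out_size(size, kernel=3, padding=1)    # Conv2d(3, 16, 3, padding=1)  -> 32
size = out_size(size, kernel=2, stride=2)     # MaxPool2d(2, 2)              -> 16
size = out_size(size, kernel=3, padding=1)    # Conv2d(16, 32, 3, padding=1) -> 16
size = out_size(size, kernel=2, stride=2)     # MaxPool2d(2, 2)              -> 8

channels = 32                                 # output channels of the last convolution
flattened = channels * size * size
print(flattened)  # 2048 == 32 * 8 * 8
```

Changing the image resolution or the number of pooling layers would require updating this value by hand in the model definition.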

3) Training results

3-1) FastAI

Epoch time: 1 min 15 s then ~1 min 06 s; accuracy after 5 epochs: 46.4%

3-2) JAX

Epoch time: 3 min then ~1 min 08 s; accuracy after 5 epochs: 30%

3-3) Keras

Epoch time: 13 s then 2.5 s; accuracy after 5 epochs: 33.5%

3-4) MXNet

Epoch time: 43 s then 30 s; accuracy after 5 epochs: 52.9%

3-5) PaddlePaddle

Epoch time: 2 min 18 s; accuracy after 5 epochs: 68.7%

3-6) Pytorch

Epoch time: 16 s; accuracy after 5 epochs: 45.6%

3-7) Pytorch-lightning

Epoch time: 20 s then 11 s; accuracy after 5 epochs: 43.6%
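
Figures of this kind (one time for the first epoch, another for the following ones) can be collected with a small wall-clock harness such as the sketch below. It is written in plain Python: time_epochs, summarise and the train_one_epoch callable are hypothetical names, and the dummy workload only stands in for a real training epoch. On a GPU, the device may also need to be synchronised before the clock is read, which this sketch does not do.

```python
import time
from statistics import mean
from typing import Callable, List, Tuple


def time_epochs(train_one_epoch: Callable[[], None], epochs: int) -> List[float]:
    """Run `train_one_epoch` `epochs` times and return each duration in seconds."""
    durations = []
    for _ in range(epochs):
        start = time.perf_counter()
        train_one_epoch()
        durations.append(time.perf_counter() - start)
    return durations


def summarise(durations: List[float]) -> Tuple[float, float]:
    """Return (first epoch time, mean time of the remaining epochs)."""
    first, rest = durations[0], durations[1:]
    return first, (mean(rest) if rest else first)


def dummy_epoch() -> None:
    """Placeholder workload standing in for one real training epoch."""
    sum(i * i for i in range(10_000))


durations = time_epochs(dummy_epoch, epochs=5)
first, steady = summarise(durations)
print(f"first epoch: {first:.4f}s, following epochs: {steady:.4f}s")
```

Keeping the first epoch separate avoids mixing one-off costs such as compilation or data caching into the per-epoch figure.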

4) Comparison summary

Tested with GPU T4 on Google Colab

Here is a summary based on our experience. The best choices in terms of:

  • Ease of implementation: FastAI, Keras, Pytorch lightning / Pytorch and MXNet
  • Built-in features: FastAI, Keras and Pytorch lightning / Pytorch
  • Customization: Keras and Pytorch lightning / Pytorch
  • Community size: Keras and Pytorch lightning / Pytorch
  • Ease of deployment: Keras
  • Number of state-of-the-art models available: Pytorch
  • Speed of training, fastest first (this order matches the first-epoch times reported above): Keras, Pytorch, Pytorch Lightning, MXNet, FastAI, PaddlePaddle, JAX
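
The speed ranking can be reproduced from the epoch times reported in section 3. The sketch below copies those figures into a plain Python dictionary (converted to seconds) and sorts them. Where only one time was reported, it is reused for both values, which is an assumption of this illustration.

```python
# Epoch times from section 3, in seconds: (first epoch, later epochs).
# PaddlePaddle and Pytorch report a single figure, reused here for both values.
epoch_times = {
    "FastAI": (75.0, 66.0),
    "JAX": (180.0, 68.0),
    "Keras": (13.0, 2.5),
    "MXNet": (43.0, 30.0),
    "PaddlePaddle": (138.0, 138.0),
    "Pytorch": (16.0, 16.0),
    "Pytorch-lightning": (20.0, 11.0),
}

by_first_epoch = sorted(epoch_times, key=lambda name: epoch_times[name][0])
by_later_epochs = sorted(epoch_times, key=lambda name: epoch_times[name][1])

print("Fastest first epoch :", ", ".join(by_first_epoch))
print("Fastest later epochs:", ", ".join(by_later_epochs))
```

Sorting by first-epoch time gives the order listed above. Sorting by the later epochs swaps Pytorch with Pytorch-lightning and PaddlePaddle with JAX.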

5) How to analyse Memory/GPU usage with pytorch profiler

Analysing memory and GPU usage is made easy by the Pytorch-profiler tool, which can be integrated into Tensorboard. Simply add the profiler to your script, as shown in the official tutorial here:

prof = profiler.profile(
    schedule=profiler.schedule(wait=0, warmup=0, active=1, repeat=1),
    on_trace_ready=profiler.tensorboard_trace_handler(path_to_save_profiling_log),
    record_shapes=True,
    with_stack=profiler_with_stack,
)
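
The call above only creates the profiler: it still has to be started and stepped once per batch. Here is a minimal sketch of how it might wrap a training loop, assuming torch is installed. The tiny model, the random data and the profiler_logs directory are placeholders, and profile_memory=True is added here so that memory usage is recorded as well.

```python
from pathlib import Path

import torch
from torch import nn, profiler

log_dir = Path("profiler_logs")  # placeholder output directory for the traces

# Tiny stand-in model and random data, only so that the sketch runs on its own
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
batches = [(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))) for _ in range(6)]

with profiler.profile(
    # skip 1 step, warm up for 1 step, then record 2 steps, once
    schedule=profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
    on_trace_ready=profiler.tensorboard_trace_handler(str(log_dir)),
    record_shapes=True,
    profile_memory=True,
    with_stack=False,
) as prof:
    for inputs, labels in batches:
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        prof.step()  # tell the profiler that one training step has finished

# List the trace files written by the handler
print(sorted(p.name for p in log_dir.iterdir()))
```

The traces can then be opened with tensorboard --logdir profiler_logs, provided the torch-tb-profiler plugin is installed.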

Tensorboard will then take care of the display and even gives tips in the bottom section.

6) Conclusion

Many frameworks are available these days. This article has only focused on Python frameworks, but Deep Learning is also widely developed in other languages, such as Julia.


Each framework retains its own characteristics, which makes it stand out from the others in certain fields. So it’s always a good idea to know which one to choose, depending on your needs.


We hope this article has helped you in your choice. Please feel free to add to this review by sharing your feedback in the comments.