Comparison of Deep Learning frameworks
This article shows how to implement a training run on the CIFAR10 dataset with different frameworks (FastAI, JAX, Keras, MXNet, PaddlePaddle, Pytorch and Pytorch-lightning) and compares them in terms of:
- ease of implementation (user-friendly code, ease of finding information online, etc.),
- time per epoch for the same model and the same training parameters,
- memory and GPU usage (analysed with the Pytorch profiler),
- accuracy obtained after the same training.
The code can be found here.
Table of contents
- Introduction
- Coding overview
- FastAI
- JAX
- Keras
- MXNet
- PaddlePaddle
- Pytorch
- Pytorch-lightning
- Training results
- Comparison summary
- How to analyse Memory/GPU usage with pytorch profiler
- Conclusion
1) Introduction
This article gives an overview of the differences between the best-known Python frameworks for Deep Learning.
The code is available on GitHub and can also be used as a working starting point for other training runs.
First, we'll look at the key points in terms of programming, then go into more detail on training time per epoch, convergence speed and memory/GPU usage.
2) Coding overview
2-1) FastAI
FastAI lives up to its name. It's extremely easy to code a training session: 2 lines to load the dataset with a train/test split and 3 lines to launch training and evaluate the model. It's certainly the most user-friendly framework in terms of programming.
It also offers many useful functions such as lr_find() to find a suitable learning rate. The documentation is very well done and the community seems quite large.
Example:
from fastai.vision.all import *  # DataLoaders, ImageDataLoaders, Learner, URLs, accuracy, untar_data

# CustomCNN is defined elsewhere in the repository


def get_data(batch_size=128) -> DataLoaders:
    """Get DataLoaders for CIFAR-10 dataset.
    Returns:
        CIFAR10 dataloader
    """
    path = untar_data(URLs.CIFAR)
    return ImageDataLoaders.from_folder(path, train="train", valid="test", bs=batch_size)


def run_training(
    dataloader: DataLoaders,
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with FASTAI frameworks.
    Returns:
        validation accuracy
    """
    model = CustomCNN()
    learn = Learner(dataloader, model, metrics=accuracy)
    learn.fit(epochs, lr=learning_rate)
    # validate() returns [valid_loss, metric_1, ...], so index 1 is the accuracy
    return learn.validate()[1]
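lr_find() runs a learning-rate range test: the learning rate is increased exponentially over a fixed number of mini-batches while the loss is recorded, and a value is then suggested from the resulting curve. The pure-Python sketch below only illustrates that idea. The range 1e-7 to 10 over 100 iterations mirrors fastai's documented defaults (worth checking against your version), while `lr_schedule`, `steepest_descent_lr` and the loss values in the usage example are hypothetical. fastai offers several suggestion heuristics; the steepest-slope rule here is just one simple possibility.

```python
import math


def lr_schedule(start_lr: float = 1e-7, end_lr: float = 10.0, num_it: int = 100) -> list:
    """Exponentially spaced learning rates, as used in an LR range test."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]


def steepest_descent_lr(lrs: list, losses: list) -> float:
    """Hypothetical rule: pick the LR where the loss drops fastest per unit of log(lr)."""
    best_i, best_slope = 0, math.inf
    for i in range(1, len(lrs)):
        slope = (losses[i] - losses[i - 1]) / (math.log(lrs[i]) - math.log(lrs[i - 1]))
        if slope < best_slope:
            best_i, best_slope = i, slope
    return lrs[best_i]


# Usage with a made-up loss curve (not real measurements)
lrs = lr_schedule(1e-5, 1e-1, num_it=5)  # 1e-5, 1e-4, 1e-3, 1e-2, 1e-1
losses = [2.30, 2.25, 1.90, 1.80, 3.50]
print(steepest_descent_lr(lrs, losses))  # about 1e-3
```

The exponential spacing gives as many samples per decade for tiny learning rates as for large ones, which is why the curve is usually plotted with a logarithmic x-axis.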
2-2) JAX
JAX, by contrast, required a lot more effort to develop this training code, in particular to get all the data into the right format. I am no JAX expert, but it also seemed necessary to implement the prediction evaluation, the cross-entropy loss and the training loop myself. All this can quickly lead to implementation errors, which makes the framework less accessible to newcomers. The community also seemed small to me, as I couldn't find answers to my debugging questions.
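Writing the cross-entropy loss by hand means choosing its reduction yourself. With one-hot targets, the loss of one sample is the negative sum over classes of target times log-probability, and the usual batch loss averages that over samples. Averaging over every element of the (batch, classes) array instead divides the result by the number of classes. A pure-Python sketch of both reductions, with plain lists standing in for jax arrays; `EPS` stands in for `tf.keras.backend.epsilon()` (default 1e-7) and the helper names are hypothetical:

```python
import math

EPS = 1e-7  # stands in for tf.keras.backend.epsilon()


def cross_entropy(probs: list, one_hot: list) -> float:
    """Sum over classes, then mean over the batch."""
    per_sample = [
        -sum(t * math.log(p + EPS) for p, t in zip(p_row, t_row))
        for p_row, t_row in zip(probs, one_hot)
    ]
    return sum(per_sample) / len(per_sample)


def mean_over_all_elements(probs: list, one_hot: list) -> float:
    """Same terms, but averaged over batch * num_classes elements."""
    terms = [
        -t * math.log(p + EPS)
        for p_row, t_row in zip(probs, one_hot)
        for p, t in zip(p_row, t_row)
    ]
    return sum(terms) / len(terms)


probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]  # softmax outputs for 2 samples, 3 classes
one_hot = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
print(round(cross_entropy(probs, one_hot), 4))           # 0.2899
print(round(mean_over_all_elements(probs, one_hot), 4))  # 0.0966, i.e. divided by 3 classes
```

With Adam, a constant scale on the loss largely cancels out in the update, so the choice matters mostly for the reported loss value rather than for training itself.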
Example:
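from typing import Union

import jax
import jax.numpy as jnp
import tensorflow as tf
from jax import value_and_grad
from jax.example_libraries import optimizers
from sklearn.metrics import accuracy_score
from tqdm import tqdm

# CustomCNN (exposing conv_init / conv_apply) and MakePredictions
# are helpers defined elsewhere in the repository.


def CrossEntropyLoss(
    weights: list,
    input_data: jax.Array,
    targets: jax.Array,
    model: CustomCNN,
) -> jax.Array:
    """Implementation of cross-entropy loss.
    Args:
        weights: list from _, _, opt_get_weights = optimizers.adam(lr), opt_get_weights(opt_state)
        input_data: data to predict
        targets: groundtruth targets in one hot encoding
        model: model with conv_apply var
    Returns:
        loss value
    """
    preds = model.conv_apply(weights, input_data)  # class probabilities
    log_preds = jnp.log(preds + tf.keras.backend.epsilon())
    # Sum over classes, then average over the batch
    return -jnp.mean(jnp.sum(targets * log_preds, axis=1))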
def run_training(
    dataloader: Union[tf.Tensor, tf.data.Dataset],
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with JAX frameworks.
    Returns:
        validation accuracy
    """
    model = CustomCNN()
    train_data, test_data = dataloader["train"], dataloader["test"]
    X_train, Y_train = train_data["image"], train_data["label"]
    X_test, Y_test = test_data["image"], test_data["label"]
    X_train, X_test, Y_train, Y_test = (
        jnp.array(X_train, dtype=jnp.float32),
        jnp.array(X_test, dtype=jnp.float32),
        jnp.array(Y_train, dtype=jnp.float32),
        jnp.array(Y_test, dtype=jnp.float32),
    )
    rng = jax.random.PRNGKey(123)
    weights = model.conv_init(rng, (18, 32, 32, 3))[1]
    opt_init, opt_update, opt_get_weights = optimizers.adam(learning_rate)
    opt_state = opt_init(weights)
    Y_train_one_hot = jax.nn.one_hot(Y_train, num_classes=10)
    n_samples = X_train.shape[0]
    step = 0  # global step counter, used by Adam for bias correction
    for i in range(epochs):
        progress_bar = tqdm(range(0, n_samples, batch_size), position=0, leave=True)
        progress_bar.set_description(f"Epoch {i+1}/{epochs}")
        losses = []
        for start in progress_bar:
            # Slicing past the end simply yields a shorter last batch
            X_batch = X_train[start : start + batch_size]
            Y_batch = Y_train_one_hot[start : start + batch_size]
            loss, gradients = value_and_grad(CrossEntropyLoss)(
                opt_get_weights(opt_state),
                X_batch,
                Y_batch,
                model=model,
            )
            ## Update weights: the step index must increase at every batch, not every epoch
            opt_state = opt_update(step, gradients, opt_state)
            step += 1
            losses.append(loss)
            progress_bar.set_postfix(train_loss=jnp.round(jnp.array(losses).mean(), decimals=3))
    test_preds = MakePredictions(
        opt_get_weights(opt_state),
        X_test,
        batch_size=batch_size,
        model=model,
    )
    ## Combine predictions of all batches
    test_preds = jnp.concatenate(test_preds).squeeze()
    test_preds = jnp.argmax(test_preds, axis=1)
    return accuracy_score(Y_test, test_preds)
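The first argument of `opt_update` is a step index: the Adam update in `jax.example_libraries.optimizers` uses it for bias correction, dividing the moment estimates by `1 - b1 ** (i + 1)` and `1 - b2 ** (i + 1)`. It should therefore be a global counter incremented at every batch; passing the epoch number freezes the correction at its first-step value for a whole epoch. The scalar sketch below mirrors that update rule (default `b1=0.9`, `b2=0.999`) with a constant gradient, to show the effect. `adam_step` and `update_sizes` are hypothetical helpers, not the library code.

```python
import math


def adam_step(i, g, state, lr=1e-4, b1=0.9, b2=0.999, eps=1e-8):
    """One scalar Adam update; i is the step index used for bias correction."""
    x, m, v = state
    m = (1 - b1) * g + b1 * m
    v = (1 - b2) * g * g + b2 * v
    m_hat = m / (1 - b1 ** (i + 1))
    v_hat = v / (1 - b2 ** (i + 1))
    return x - lr * m_hat / (math.sqrt(v_hat) + eps), m, v


def update_sizes(num_steps, frozen_index=None, lr=1e-4):
    """Size of each update for a constant gradient of 1.0."""
    state, sizes = (0.0, 0.0, 0.0), []
    for step in range(num_steps):
        i = step if frozen_index is None else frozen_index
        new_state = adam_step(i, 1.0, state, lr=lr)
        sizes.append(state[0] - new_state[0])
        state = new_state
    return sizes


# 391 batches of 128 images = one epoch over 50,000 CIFAR10 training images
global_counter = update_sizes(391)
epoch_index = update_sizes(391, frozen_index=0)  # epoch number, i.e. 0 all first epoch
print(global_counter[-1] / 1e-4)  # ~1.0: the step size stays at lr
print(epoch_index[-1] / 1e-4)     # ~0.56: the step size drifts away from lr
```

Both variants take exactly the same first step; they only diverge as the frozen correction stops matching the number of updates actually accumulated in the moment estimates.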
2-3) Keras
I think Keras needs no introduction: very easy to use, short and intuitive code, great customization possibilities and a very large community. It's hard to find fault with it. It is also easy to deploy, with plenty of tooling, which makes it a framework of choice for businesses.
Example:
import keras

# CustomCNN (a Keras model) is defined elsewhere in the repository


def get_data(batch_size=128) -> dict:
    """Get DataLoaders for CIFAR-10 dataset.
    Returns:
        CIFAR10 dataloader
    """
    # Load the CIFAR-10 dataset (uint8 images, integer labels)
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
    return {"train_data": (x_train, y_train), "test_data": (x_test, y_test)}


def run_training(
    dataloader: dict,
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with keras frameworks.
    Returns:
        validation accuracy
    """
    (x_train, y_train), (x_test, y_test) = dataloader["train_data"], dataloader["test_data"]
    # Compile the model
    model = CustomCNN()
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",  # integer labels, no one-hot encoding needed
        metrics=["accuracy"],
    )
    # Train the model
    model.fit(
        x_train,
        y_train,
        epochs=epochs,
        validation_data=(x_test, y_test),
        batch_size=batch_size,
    )
    # Evaluate the model: evaluate() returns [loss, accuracy]
    _, test_acc = model.evaluate(x_test, y_test)
    return test_acc
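Using `loss="sparse_categorical_crossentropy"` lets Keras consume the integer labels returned by `cifar10.load_data()` directly, whereas `categorical_crossentropy` expects one-hot vectors. Both compute the same value. A pure-Python check on made-up probabilities, with lists standing in for tensors and hypothetical helper names:

```python
import math


def sparse_ce(probs: list, labels: list) -> float:
    """Cross-entropy with integer class labels (what the sparse variant expects)."""
    return -sum(math.log(row[k]) for row, k in zip(probs, labels)) / len(labels)


def categorical_ce(probs: list, one_hot: list) -> float:
    """Cross-entropy with one-hot targets (what categorical_crossentropy expects)."""
    total = sum(
        sum(t * math.log(p) for p, t in zip(row, t_row)) for row, t_row in zip(probs, one_hot)
    )
    return -total / len(one_hot)


def to_one_hot(labels: list, num_classes: int) -> list:
    return [[1.0 if c == k else 0.0 for c in range(num_classes)] for k in labels]


probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
labels = [0, 1, 2]
print(abs(sparse_ce(probs, labels) - categorical_ce(probs, to_one_hot(labels, 3))) < 1e-12)  # True
```

The sparse form simply skips the multiplications by zero, which also saves building a one-hot copy of the labels.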
2-4) MXNet
MXNet is easy to use. It provides many built-in building blocks (data loaders, losses, optimizers, etc.), but the training loop has to be written by hand. As I didn't run into any particular problems, I didn't need to debug it, but there does seem to be an online community. The framework closest to it seems to me to be Pytorch.
Example:
import mxnet as mx
from mxnet import autograd, gluon, nd, ndarray
from mxnet.gluon.data.vision import datasets, transforms
from tqdm import tqdm

# CustomCNN (exposing a Gluon network as .net) is defined elsewhere in the repository.
# Device context, assumed to be the one the network was initialized on.
ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()


def get_data(batch_size=128) -> dict:
    """Get DataLoaders for CIFAR-10 dataset.
    Returns:
        CIFAR10 dataloader
    """
    # Load and transform the CIFAR10 data
    transform_train = transforms.Compose(
        [
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
        ],
    )
    transform_test = transforms.Compose(
        [
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
        ],
    )
    train_data = datasets.CIFAR10(train=True).transform_first(transform_train)
    test_data = datasets.CIFAR10(train=False).transform_first(transform_test)
    return {"train_data": train_data, "test_data": test_data}


def run_training(
    dataloader: dict,
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with MXNET frameworks.
    Returns:
        validation accuracy
    """
    net = CustomCNN().net
    train_data, test_data = dataloader["train_data"], dataloader["test_data"]
    # Define the loss function and optimizer
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), "adam", {"learning_rate": learning_rate})
    # Train the network (batch_size comes from the function argument)
    train_loader = gluon.data.DataLoader(train_data, batch_size, shuffle=True)
    test_loader = gluon.data.DataLoader(test_data, batch_size)
    for epoch in range(epochs):
        progress_bar = tqdm(train_loader, position=0, leave=True)
        progress_bar.set_description(f"Epoch {epoch+1}/{epochs}")
        train_loss, train_acc, n = 0.0, 0.0, 0
        for X, y in progress_bar:
            X, y = X.as_in_context(ctx), y.as_in_context(ctx)
            with autograd.record():
                y_hat = net(X)
                loss = softmax_cross_entropy(y_hat, y)
            loss.backward()
            # Normalize gradients by the actual batch size (the last batch is smaller)
            trainer.step(X.shape[0])
            current_loss = nd.sum(loss).asscalar()
            train_loss += current_loss
            current_acc = nd.sum(
                ndarray.cast(y_hat.argmax(axis=1), dtype="int32") == y,
            ).asscalar()
            train_acc += current_acc
            n += y.size
            progress_bar.set_postfix(train_loss=round(current_loss, 3), train_acc=current_acc)
    test_acc = 0.0
    for X, y in test_loader:
        X, y = X.as_in_context(ctx), y.as_in_context(ctx)
        y_hat = net(X)
        current_acc = nd.sum(
            ndarray.cast(y_hat.argmax(axis=1), dtype="int32") == y,
        ).asscalar()
        test_acc += current_acc
    test_acc /= len(test_data)
    return test_acc
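`gluon.Trainer.step(batch_size)` rescales the accumulated gradient by 1/batch_size: calling `loss.backward()` on the vector of per-sample losses accumulates the gradient of their sum, and dividing by the batch size turns that sum into a mean. With 50,000 training images and batches of 128, the last batch (kept by Gluon's `DataLoader` by default) holds only 80 images, so passing the actual batch size is safer than a constant. The arithmetic in plain Python, with hypothetical helpers and made-up per-sample gradients:

```python
def batch_sizes(num_samples: int, batch_size: int) -> list:
    """Sizes of the batches produced by a loader that keeps the last partial batch."""
    full, rest = divmod(num_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


def normalized_grad(per_sample_grads: list, step_batch_size: int) -> float:
    """What Trainer.step(step_batch_size) applies: summed gradient / step_batch_size."""
    return sum(per_sample_grads) / step_batch_size


sizes = batch_sizes(50_000, 128)
print(len(sizes), sizes[-1])  # 391 80

# Last batch: 80 per-sample gradients of 1.0 each (made-up values)
last = [1.0] * sizes[-1]
print(normalized_grad(last, 128))        # 0.625: the mean gradient is under-scaled
print(normalized_grad(last, len(last)))  # 1.0: the true mean
```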
2-5) PaddlePaddle
PaddlePaddle surprised me a little at first. It's probably because I'm more used to Keras or Pytorch, but the framework didn't seem very intuitive. It also didn't return many search results when I needed to debug. In particular, Colab crashed when I used paddle.fluid.CUDAPlace (GPU) instead of paddle.fluid.CPUPlace (CPU).
Example:
import paddle
from tqdm import tqdm

# Uses the legacy paddle.fluid static-graph API. custom_cnn, optimizer_program,
# train_test, place and feed_order are defined elsewhere in the repository.


def get_data(batch_size=128) -> dict:
    """Get DataLoaders for CIFAR-10 dataset.
    Returns:
        CIFAR10 dataloader
    """
    # Load and transform the CIFAR10 data
    # Each batch yields batch_size images
    buf_size = 50000
    # Reader for training (samples are shuffled through a buffer of buf_size)
    train_reader = paddle.batch(
        paddle.reader.shuffle(paddle.dataset.cifar.train10(), buf_size=buf_size),
        batch_size=batch_size,
    )
    # Reader for testing, on a separate data set
    test_reader = paddle.batch(paddle.dataset.cifar.test10(), batch_size=batch_size)
    return {"train_reader": train_reader, "test_reader": test_reader, "buf_size": buf_size}
...
def inference_program():
    # The image is 32 * 32 with RGB representation.
    data_shape = [3, 32, 32]
    images = paddle.fluid.layers.data(name="pixel", shape=data_shape, dtype="float32")
    predict = custom_cnn(images)
    return predict


def train_program():
    predict = inference_program()
    label = paddle.fluid.layers.data(name="label", shape=[1], dtype="int64")
    cost = paddle.fluid.layers.cross_entropy(input=predict, label=label)
    avg_cost = paddle.fluid.layers.mean(cost)
    accuracy = paddle.fluid.layers.accuracy(input=predict, label=label)
    return [avg_cost, accuracy, predict]


def run_training(
    dataloader: dict,
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with paddlepaddle frameworks.
    Returns:
        validation accuracy
    """
    ...
    main_program = paddle.fluid.default_main_program()
    startup_program = paddle.fluid.default_startup_program()
    avg_cost, acc, predict = train_program()
    # Test program
    test_program = main_program.clone(for_test=True)
    optimizer = optimizer_program(learning_rate=learning_rate)
    optimizer.minimize(avg_cost)
    exe = paddle.fluid.Executor(place)
    params_dirname = "image_classification_resnet.inference.model"
    feed_var_list_loop = [main_program.global_block().var(var_name) for var_name in feed_order]
    feeder = paddle.fluid.DataFeeder(feed_list=feed_var_list_loop, place=place)
    exe.run(startup_program)
    train_reader, test_reader, buf_size = (
        dataloader["train_reader"],
        dataloader["test_reader"],
        dataloader["buf_size"],
    )
    for pass_id in range(epochs):
        progress_bar = tqdm(position=0, leave=True, total=round(buf_size / batch_size))
        progress_bar.set_description(f"Epoch {pass_id+1}/{epochs}")
        # Calling train_reader() creates a fresh, reshuffled generator for each epoch
        for data_train in train_reader():
            avg_loss_value = exe.run(
                main_program,
                feed=feeder.feed(data_train),
                fetch_list=[avg_cost, acc],
            )
            progress_bar.set_postfix(train_loss=round(avg_loss_value[0].item(), 3))
            progress_bar.update()
    _, accuracy_test = train_test(test_program, test_reader, feed_order, place, avg_cost, acc)
    # Save parameters
    if params_dirname is not None:
        paddle.fluid.io.save_inference_model(params_dirname, ["pixel"], [predict], exe)
    return accuracy_test
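In this legacy API, a reader is a function that returns a generator of samples, and `paddle.reader.shuffle` and `paddle.batch` are decorators that wrap one reader into another. That is why the training loop calls `train_reader()` at every epoch. The generator-based sketch below mimics the composition in plain Python; `shuffle_reader`, `batch_reader` and `toy_reader` are hypothetical stand-ins, not Paddle's implementation.

```python
import random
from typing import Callable, Iterator, List

Reader = Callable[[], Iterator]


def shuffle_reader(reader: Reader, buf_size: int, seed: int = 0) -> Reader:
    """Fill a buffer of buf_size samples, shuffle it, yield it, and repeat."""
    def shuffled() -> Iterator:
        rng = random.Random(seed)
        buf = []
        for sample in reader():
            buf.append(sample)
            if len(buf) >= buf_size:
                rng.shuffle(buf)
                yield from buf
                buf = []
        rng.shuffle(buf)
        yield from buf
    return shuffled


def batch_reader(reader: Reader, batch_size: int) -> Reader:
    """Group consecutive samples into lists of batch_size (the last one may be shorter)."""
    def batched() -> Iterator[List]:
        batch = []
        for sample in reader():
            batch.append(sample)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch
    return batched


def toy_reader() -> Iterator[int]:
    """Toy sample reader: integers instead of (image, label) pairs."""
    yield from range(10)


train_reader = batch_reader(shuffle_reader(toy_reader, buf_size=10), batch_size=4)
batches = list(train_reader())
print([len(b) for b in batches])  # [4, 4, 2]
```

Because each decorator returns a new zero-argument function rather than a generator, the same `train_reader` object can be reused across epochs without being exhausted.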
2-6) Pytorch
Pytorch should also need no introduction. Like Keras, its code is fairly intuitive, the community is very large and customization is virtually unlimited. However, it may require more code to run a training session, which leaves more room for coding errors. On the plus side, many research papers use it, so many state-of-the-art models are directly available in Pytorch.
Example:
import torch
from sklearn.metrics import accuracy_score
from torch import nn, optim
from tqdm import tqdm

# Net is defined elsewhere in the repository
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def run_training(
    dataloader: dict,
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with pytorch frameworks.
    Returns:
        validation accuracy
    """
    trainloader, testloader = dataloader["trainloader"], dataloader["testloader"]
    classes_count = 10
    # Read a single batch to get the input shape (N, C, H, W)
    input_data, _ = next(iter(trainloader))
    input_size = input_data.shape
    net = Net(input_size[2:4], classes_count).to(device)
    # Define a loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    # Train the network
    for epoch in range(epochs):
        progress_bar = tqdm(trainloader, position=0, leave=True)
        progress_bar.set_description(f"Epoch {epoch+1}/{epochs}")
        running_loss = 0.0
        for input_data, labels in progress_bar:
            optimizer.zero_grad()
            outputs = net(input_data.to(device))
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            progress_bar.set_postfix(train_loss=round(loss.item(), 3))
    # Evaluate on the test set (moving predictions back to the CPU for scikit-learn)
    net.eval()
    predictions = []
    targets = []
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images.to(device))
            predictions.append(outputs.cpu())
            targets.append(labels)
    predictions = torch.argmax(torch.cat(predictions, dim=0), dim=1)
    targets = torch.cat(targets, dim=0)
    return accuracy_score(targets, predictions)
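`nn.CrossEntropyLoss` expects raw logits: it applies log-softmax internally and then takes the negative log-likelihood of the target class, so it is equivalent to `F.log_softmax` followed by `F.nll_loss`, the combination used in the Pytorch-lightning model. A pure-Python check of that identity on made-up logits, with lists standing in for tensors and hypothetical helper names:

```python
import math


def log_softmax(logits: list) -> list:
    # Subtract the max for numerical stability
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def nll_loss(log_probs: list, targets: list) -> float:
    """Mean negative log-likelihood of the target classes."""
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


def cross_entropy(logits: list, targets: list) -> float:
    """Cross-entropy computed directly from raw logits."""
    total = 0.0
    for row, t in zip(logits, targets):
        log_z = math.log(sum(math.exp(z) for z in row))
        total += log_z - row[t]
    return total / len(targets)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]
a = cross_entropy(logits, targets)
b = nll_loss([log_softmax(row) for row in logits], targets)
print(abs(a - b) < 1e-12)  # True
```

A practical consequence: a network whose last layer already applies softmax should not be paired with `nn.CrossEntropyLoss`, since softmax would then be applied twice.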
2-7) Pytorch-lightning
Pytorch-lightning is less well known, but I think it can be summed up simply: it is to Pytorch what Keras is to Tensorflow. It is a wrapper around Pytorch that allows more concise code thanks to many pre-implemented building blocks. Its community is growing and already quite large. Several well-known tools support it, including MLflow (an open-source platform for the machine learning lifecycle), which only needs a single autolog() call (mlflow.pytorch.autolog()) at the top of the script to work with Pytorch-lightning.
Example:
import os

import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.loggers import CSVLogger
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import CIFAR10


class CNNModel(L.LightningModule):
    def __init__(
        self,
        batch_size: int = 128,
        lr: float = 0.0001,
    ):
        super().__init__()
        # Set our init args as class attributes
        self.batch_size = batch_size
        self.lr = lr
        self.data_dir = os.environ.get("PATH_DATASETS", ".")
        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (3, 32, 32)
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ],
        )
        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, 64),
            nn.ReLU(),
            nn.Linear(64, self.num_classes),
        )
        self.train_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)

    def forward(self, x):
        x = self.model(x)
        # Returns log-probabilities, hence F.nll_loss in the steps below
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.train_accuracy.update(preds, y)
        # self.log sends scalars to the configured logger (a CSVLogger here)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_accuracy, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy.update(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_accuracy.update(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            dataset = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.train_db, self.val_db = random_split(dataset, [45000, 5000])
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.test_db = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_db, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_db, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_db, batch_size=self.batch_size)


def run_training(
    dataloader: dict,
    epochs: int = 3,
    batch_size: int = 128,
    learning_rate: float = 0.0001,
) -> float:
    """Run CIFAR10 training with pytorch_lightning frameworks.
    Returns:
        validation accuracy
    """
    model = CNNModel(batch_size=batch_size, lr=learning_rate)
    trainer = L.Trainer(
        accelerator="auto",
        devices=1,
        max_epochs=epochs,
        logger=CSVLogger(save_dir="logs/"),
    )
    trainer.fit(model)
    return trainer.test()[0]["test_acc"]
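The input size of the first `nn.Linear` layer in the model above, `32 * 8 * 8`, follows from the standard output-size formula floor((n + 2p - k) / s) + 1 (dilation 1): the 3x3 convolutions with padding 1 keep the spatial size, each 2x2 max-pooling with stride 2 halves it (32, then 16, then 8), and the last convolution outputs 32 channels. A small sketch of that arithmetic; `out_size` and `flattened_features` are hypothetical helpers:

```python
def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def flattened_features(image_size: int = 32) -> int:
    n = image_size
    n = out_size(n, kernel=3, padding=1)  # Conv2d(3, 16, 3, padding=1): 32
    n = out_size(n, kernel=2, stride=2)   # MaxPool2d(2, 2): 16
    n = out_size(n, kernel=3, padding=1)  # Conv2d(16, 32, 3, padding=1): 16
    n = out_size(n, kernel=2, stride=2)   # MaxPool2d(2, 2): 8
    return 32 * n * n                     # 32 channels after the last conv


print(flattened_features())  # 2048 == 32 * 8 * 8
```

The ReLU layers do not change shapes, so they can be ignored in this calculation. Computing the value this way, rather than hard-coding it, keeps the model consistent if the image size changes.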
3) Training results
3-1) FastAI
Epoch time: 1 min 15 s for the first epoch, then ~1 min 06 s; accuracy after 5 epochs: 46.4%
3-2) JAX
Epoch time: 3 min for the first epoch, then ~1 min 08 s; accuracy after 5 epochs: 30%
3-3) Keras
Epoch time: 13 s for the first epoch, then 2.5 s; accuracy after 5 epochs: 33.5%
3-4) MXNet
Epoch time: 43 s for the first epoch, then 30 s; accuracy after 5 epochs: 52.9%
3-5) PaddlePaddle
Epoch time: 2 min 18 s; accuracy after 5 epochs: 68.7%
3-6) Pytorch
Epoch time: 16 s; accuracy after 5 epochs: 45.6%
3-7) Pytorch-lightning
Epoch time: 20 s for the first epoch, then 11 s; accuracy after 5 epochs: 43.6%
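These figures mix a slower first epoch (presumably including one-off warm-up costs) with faster later epochs. One way to compare frameworks on a single number is to estimate the total time for 5 epochs, assuming every epoch after the first takes the later time. The sketch below only does that arithmetic on the reported values; the "~" approximations are taken at face value and `total_time` is a hypothetical helper.

```python
# (first epoch, later epochs) in seconds, from the results above;
# frameworks with a single reported time use it for every epoch.
epoch_times = {
    "FastAI": (75, 66),
    "JAX": (180, 68),
    "Keras": (13, 2.5),
    "MXNet": (43, 30),
    "PaddlePaddle": (138, 138),
    "Pytorch": (16, 16),
    "Pytorch-lightning": (20, 11),
}


def total_time(first: float, later: float, epochs: int = 5) -> float:
    """Estimated wall-clock time for a full training run."""
    return first + (epochs - 1) * later


ranking = sorted(epoch_times, key=lambda name: total_time(*epoch_times[name]))
for name in ranking:
    print(f"{name}: {total_time(*epoch_times[name]):.0f} s")
```

For long runs the later-epoch time dominates; for short runs such as these 5 epochs, the first epoch still weighs noticeably (for JAX, about 40% of the total).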
4) Comparison summary
All tests were run on a T4 GPU on Google Colab.
Here is a summary based on our experience. The best choices in terms of:
- ease of implementation: FastAI, Keras, Pytorch-lightning / Pytorch and MXNet,
- built-in features: FastAI, Keras and Pytorch-lightning / Pytorch,
- customization: Keras and Pytorch-lightning / Pytorch,
- community size: Keras and Pytorch-lightning / Pytorch,
- ease of deployment: Keras,
- number of state-of-the-art models available: Pytorch,
- training speed (fastest to slowest): Keras, Pytorch, Pytorch-lightning, MXNet, FastAI, PaddlePaddle, JAX.
5) How to analyse Memory/GPU usage with pytorch profiler
Analysing memory and GPU usage is made easy by the Pytorch profiler, which integrates with Tensorboard. Simply add the profiler to your script as shown in the official tutorial here:
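from torch import profiler

prof = profiler.profile(
    schedule=profiler.schedule(wait=0, warmup=0, active=1, repeat=1),
    on_trace_ready=profiler.tensorboard_trace_handler(path_to_save_profiling_log),
    record_shapes=True,
    profile_memory=True,  # needed for the memory view in Tensorboard
    with_stack=profiler_with_stack,
)
prof.start()
# ... training loop, calling prof.step() after each training step ...
prof.stop()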
Tensorboard then takes care of the display and even gives performance tips in the bottom section.
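`torch.profiler.schedule` decides, step by step, whether the profiler is idle, warming up or recording: after `skip_first` steps it cycles through `wait`, `warmup` and `active` steps, `repeat` times (0 meaning no limit), and at the last active step of each cycle the trace is handed to `on_trace_ready`. The profiler only advances through this schedule when `prof.step()` is called. The pure-Python sketch below reproduces that documented behaviour with strings named after `torch.profiler.ProfilerAction` values; `schedule_action` is a hypothetical illustration, not PyTorch's code.

```python
def schedule_action(step: int, wait: int, warmup: int, active: int,
                    repeat: int = 0, skip_first: int = 0) -> str:
    """Profiler action for a given step, following the schedule semantics."""
    if step < skip_first:
        return "NONE"
    step -= skip_first
    cycle = wait + warmup + active
    if repeat > 0 and step // cycle >= repeat:
        return "NONE"  # all requested cycles are done
    pos = step % cycle
    if pos < wait:
        return "NONE"
    if pos < wait + warmup:
        return "WARMUP"
    # The last active step also triggers on_trace_ready (e.g. the Tensorboard handler)
    return "RECORD_AND_SAVE" if pos == cycle - 1 else "RECORD"


# Settings from the snippet above: only the very first step is recorded
print([schedule_action(s, wait=0, warmup=0, active=1, repeat=1) for s in range(4)])
# ['RECORD_AND_SAVE', 'NONE', 'NONE', 'NONE']
print([schedule_action(s, wait=1, warmup=1, active=3, repeat=1) for s in range(6)])
# ['NONE', 'WARMUP', 'RECORD', 'RECORD', 'RECORD_AND_SAVE', 'NONE']
```

A non-zero `warmup` is generally advisable, since the first recorded steps can otherwise include one-off start-up overhead.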
6) Conclusion
Many frameworks are available these days. This article has only focused on Python frameworks, but others are in widespread use, for example for Deep Learning in Julia.
Each framework has its own characteristics, which make it stand out from the others in certain areas, so it's worth choosing one according to your needs.
We hope this article has helped you in your choice; please feel free to add to this review by giving your feedback in the comments.