Training directly with PyTorch Lightning

Using PyTorch Lightning

The TabNN models are implemented with PyTorch Lightning and therefore follow the standard Lightning training workflow (see the PyTorch Lightning documentation):

import lightning.pytorch as pl  # or: import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

# define the dataloaders; DataLoader expects a Dataset (its second positional
# argument is batch_size), so the tensors are wrapped in a TensorDataset first
train_loader = DataLoader(TensorDataset(x_train, y_train))
val_loader = DataLoader(TensorDataset(x_val, y_val))
test_loader = DataLoader(TensorDataset(x_test, y_test))

# define the model as a PyTorch Lightning LightningModule
nn_model = MyModel(hyper_param1, hyper_param2, ...)

# train the model using the PyTorch Lightning Trainer
trainer = pl.Trainer()
trainer.fit(model=nn_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# make predictions using the Trainer
pred = trainer.predict(nn_model, dataloaders=test_loader)

Adapted to our tabular neural network (TabNN) models, the implementation looks like this:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from pytabkit.models.alg_interfaces.base import SplitIdxs, InterfaceResources
from pytabkit.models.data.data import DictDataset, TensorInfo
from pytabkit.models.sklearn.default_params import DefaultParams
from pytabkit.models.training.lightning_modules import TabNNModule

import lightning.pytorch as pl  # or: import pytorch_lightning as pl
import numpy as np
import torch

# NOTE(review): presumably this should match the number of epochs the model's
# hyperparameters (e.g. its learning-rate schedule) are configured for — confirm.
n_epochs = 200

# toy classification dataset; random_state makes the example reproducible
X, y = make_classification(random_state=0)

# hold out 20% of the samples as a test set (seed 0 == split_seed below)
idxs = np.arange(len(X))
trainval_idxs, test_idxs = train_test_split(idxs, test_size=0.2, random_state=0)

# draw several random train/validation splits of the remaining samples;
# split i uses seed i, matching sub_split_seeds in SplitIdxs below
n_trainval_splits = 5
train_idxs_list = []
val_idxs_list = []
for i in range(n_trainval_splits):
    train_idxs, val_idxs = train_test_split(trainval_idxs, test_size=0.2, random_state=i)
    train_idxs_list.append(train_idxs)
    val_idxs_list.append(val_idxs)

# define the dataset: continuous features, (empty) categorical features and labels
ds = DictDataset(
    tensors={
        'x_cont': torch.as_tensor(X, dtype=torch.float32),
        # no categorical features here; stored as long like 'y', the other
        # tensor described via cat_sizes
        'x_cat': torch.zeros(len(X), 0, dtype=torch.long),
        'y': torch.as_tensor(y, dtype=torch.long)[:, None],
    },
    tensor_infos={
        'x_cont': TensorInfo(feat_shape=[X.shape[1]]),
        'x_cat': TensorInfo(cat_sizes=[]),
        # number of classes, as a plain Python int
        'y': TensorInfo(cat_sizes=[int(np.max(y)) + 1]),
    },
)  # (1)

# one SplitIdxs bundling the test indices with all train/val sub-splits;
# np.stack turns the per-split index arrays into (n_trainval_splits, n) tensors
train_val_splitting_idxs_list = [
    SplitIdxs(train_idxs=torch.as_tensor(np.stack(train_idxs_list, axis=0), dtype=torch.long),
              val_idxs=torch.as_tensor(np.stack(val_idxs_list, axis=0), dtype=torch.long),
              test_idxs=torch.as_tensor(test_idxs, dtype=torch.long),
              split_seed=0, sub_split_seeds=list(range(len(train_idxs_list))), split_id=0)]

test_ds = ds.get_sub_dataset(torch.as_tensor(test_idxs, dtype=torch.long))

# Create the assigned resources (CPU only here); to use a GPU instead:
# interface_resources = InterfaceResources(n_threads=4, gpu_devices=['cuda:0'])  # (2)
interface_resources = InterfaceResources(n_threads=4, gpu_devices=[])  # (2)

# define the model using our LightningModule TabNNModule
nn_model = TabNNModule(**DefaultParams.RealMLP_TD_CLASS)
# build and 'compile' the model using the data, now it is ready to use
nn_model.compile_model(ds, train_val_splitting_idxs_list, interface_resources)

# train the model using the PyTorch Lightning Trainer
trainer = pl.Trainer(
    callbacks=nn_model.create_callbacks(),
    max_epochs=n_epochs,
    enable_checkpointing=False,
    enable_progress_bar=False,
    num_sanity_val_steps=0,
    logger=pl.loggers.logger.DummyLogger(),
)  # (3)

# the dataloaders are created by compile_model and exposed on the module
trainer.fit(
    model=nn_model,
    train_dataloaders=nn_model.train_dl,
    val_dataloaders=nn_model.val_dl
)
# make predictions using the Trainer
pred = trainer.predict(
    model=nn_model,
    dataloaders=nn_model.get_predict_dataloader(test_ds)
)
  1. The NN models have special requirements for their dataloaders. Therefore, we first wrap the data in a DictDataset, from which the training, validation, and test subsets are created.

  2. We handle resource management (CPU threads and GPU devices) ourselves rather than through Lightning, so we need to create an InterfaceResources object and pass it to compile_model.

  3. We use the standard Trainer class from Lightning. However, all of the parameters specified here are required for the TabNNModule to work properly.