Training directly with PyTorch Lightning
Using PyTorch Lightning
The TabNN models are implemented using PyTorch Lightning. Training follows the usual Lightning implementation pattern, as described here:
# Schematic Lightning workflow: the data, model and hyperparameter names are placeholders.
# define one dataloader each for the training, validation and test data
train_loader = DataLoader(x_train, y_train)
val_loader = DataLoader(x_val, y_val)
test_loader = DataLoader(x_test, y_test)
# define model using a PyTorch LightningModule
nn_model = MyModel(hyper_param1, hyper_param2, ...)
# train model using the PyTorch Lightning Trainer, passing the training and validation dataloaders
trainer = pl.Trainer()
trainer.fit(model=nn_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
# make predictions on the test dataloader using the Trainer
pred = trainer.predict(nn_model, dataloaders=test_loader)
In our use case, adapted to the Tabular NN Network, the implementation looks like this:
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from pytabkit.models.alg_interfaces.base import SplitIdxs, InterfaceResources
from pytabkit.models.data.data import DictDataset, TensorInfo
from pytabkit.models.sklearn.default_params import DefaultParams
from pytabkit.models.training.lightning_modules import TabNNModule
import lightning.pytorch as pl # or: import pytorch_lightning as pl
import numpy as np
import torch
# maximum number of training epochs passed to the Trainer
n_epochs = 200

# synthetic classification data with scikit-learn's default settings
X, y = make_classification()

# split the sample indices (not the data itself): 20% are held out for testing
idxs = np.arange(len(X))
trainval_idxs, test_idxs = train_test_split(idxs, test_size=0.2)

# draw several random train/validation splits (80/20) of the remaining indices
n_trainval_splits = 5
trainval_splits = [
    train_test_split(trainval_idxs, test_size=0.2)
    for _ in range(n_trainval_splits)
]
train_idxs_list = [train_part for train_part, _ in trainval_splits]
val_idxs_list = [val_part for _, val_part in trainval_splits]
# define datasets
tensors = {
    # continuous features as float32
    'x_cont': torch.as_tensor(X, dtype=torch.float32),
    # zero-width tensor: this example has no categorical features
    'x_cat': torch.zeros(len(X), 0),
    # labels as a column of integer class indices
    'y': torch.as_tensor(y, dtype=torch.long)[:, None],
}
tensor_infos = {
    'x_cont': TensorInfo(feat_shape=[X.shape[1]]),
    'x_cat': TensorInfo(cat_sizes=[]),
    # largest label + 1 (assumes labels are encoded as 0..K-1)
    'y': TensorInfo(cat_sizes=[np.max(y) + 1]),
}
ds = DictDataset(tensors=tensors, tensor_infos=tensor_infos)  # (1)

# stack the per-split index arrays so that each row holds one train/validation split
stacked_train_idxs = np.stack(train_idxs_list, axis=0)
stacked_val_idxs = np.stack(val_idxs_list, axis=0)
split_idxs = SplitIdxs(
    train_idxs=torch.as_tensor(stacked_train_idxs, dtype=torch.long),
    val_idxs=torch.as_tensor(stacked_val_idxs, dtype=torch.long),
    test_idxs=torch.as_tensor(test_idxs, dtype=torch.long),
    split_seed=0,
    # one seed per train/validation split
    sub_split_seeds=list(range(len(train_idxs_list))),
    split_id=0,
)
train_val_splitting_idxs_list = [split_idxs]

# the held-out test rows, used later for prediction
test_ds = ds.get_sub_dataset(torch.as_tensor(test_idxs, dtype=torch.long))
# Create assigned resources: 4 threads and no GPU devices.
# Alternative that assigns one GPU device:
# interface_resources = InterfaceResources(n_threads=4, gpu_devices=['cuda:0']) # (2)
interface_resources = InterfaceResources(n_threads=4, gpu_devices=[])  # (2)

# define the model using our LightningModule TabNNModule,
# configured with the RealMLP_TD_CLASS entry of DefaultParams
default_params = DefaultParams.RealMLP_TD_CLASS
nn_model = TabNNModule(**default_params)

# build and 'compile' the model using the data, now it is ready to use
nn_model.compile_model(ds, train_val_splitting_idxs_list, interface_resources)
# train the model using the PyTorch Lightning Trainer;
# the callbacks are created by the module itself
module_callbacks = nn_model.create_callbacks()
dummy_logger = pl.loggers.logger.DummyLogger()
trainer = pl.Trainer(
    callbacks=module_callbacks,
    max_epochs=n_epochs,
    enable_checkpointing=False,
    enable_progress_bar=False,
    num_sanity_val_steps=0,
    logger=dummy_logger,
)  # (3)

# the dataloaders come from the compiled module
trainer.fit(model=nn_model, train_dataloaders=nn_model.train_dl, val_dataloaders=nn_model.val_dl)

# make predictions using the Trainer
predict_dl = nn_model.get_predict_dataloader(test_ds)
pred = trainer.predict(model=nn_model, dataloaders=predict_dl)
(1) The NN models have special requirements for their dataloaders, therefore we first need to use `DictDataset` to create a dataset for both training and validation.
(2) We handle our resource management manually, not with Lightning, therefore we need to create an `InterfaceResources` object.
(3) We use the original `Trainer` class from Lightning. However, all of the parameters specified here are obligatory for the TabNNModule to work properly.