import datetime
from argparse import ArgumentParser

import torch
from lightning import Trainer
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.loggers import TensorBoardLogger

from src.trainer import ViTLightningModule


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
    parser.add_argument('--tag', action='store', type=str,
                        help='Extra suffix to put on the artefact dir name')
    parser.add_argument('--debug', action='store_true',
                        help="Dummy training cycle for testing purposes")
    parser.add_argument('--convert-checkpoint', action='store', type=str,
                        help='Convert a checkpoint from training to pickle-independent '
                             'predictor-compatible directory')
    return parser.parse_args()


def _convert_checkpoint(checkpoint_path):
    """Load a training checkpoint and re-save it through the model's exporter.

    Args:
        checkpoint_path: Path of the training checkpoint to convert.
    """
    print("Converting checkpoint", checkpoint_path)

    # NOTE(review): torch.load unpickles the file, so only run this on
    # checkpoints from a trusted source.
    # The raw load is used solely to show the top-level checkpoint keys.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    print(list(checkpoint.keys()))

    # NOTE(review): the hparams file name and the output path below are
    # hard-coded placeholders -- confirm whether they should become CLI options.
    model = ViTLightningModule.load_from_checkpoint(
        checkpoint_path,
        map_location="cpu",
        hparams_file="tmp_ckpt_deleteme.yaml",
    )
    model.save_checkpoint_dk("tmp_checkp_path_deleteme")
    print("Saved checkpoint. Done.")


def _train(tag, fast_dev_run):
    """Run a training session logged under a timestamped TensorBoard version.

    Args:
        tag: Optional suffix appended to the artefact directory name, or None.
        fast_dev_run: Passed both to the model constructor and to the Trainer.
    """
    print("Start training")
    model = ViTLightningModule(fast_dev_run)

    # Artefact directory name: "<timestamp>" or "<timestamp>_<tag>".
    datetime_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    tag_suffix = f"_{tag}" if tag is not None else ""
    art_dir_name = f"{datetime_str}{tag_suffix}"

    logger = TensorBoardLogger(save_dir=".", name="lightning_logs",
                               version=art_dir_name)
    trainer = Trainer(
        logger=logger,
        benchmark=True,
        devices="auto",
        accelerator="auto",
        # NOTE(review): -1 presumably disables the epoch limit -- confirm
        # against the Lightning Trainer documentation.
        max_epochs=-1,
        callbacks=[
            ModelSummary(max_depth=-1),
        ],
        fast_dev_run=fast_dev_run,
        log_every_n_steps=10,
    )
    # The dataloaders are taken from attributes of the model itself.
    trainer.fit(
        model,
        train_dataloaders=model._train_dataloader,
        val_dataloaders=model._val_dataloader,
    )
    print("Training done")


def main():
    """
    Neural network trainer entry point.

    Parses the command line, then either converts an existing checkpoint
    (when ``--convert-checkpoint`` is given) or starts a training run.
    """
    args = _parse_args()

    torch.set_float32_matmul_precision('high')  # for V100/A100

    if args.convert_checkpoint is not None:
        _convert_checkpoint(args.convert_checkpoint)
    else:
        # ``--debug`` is a store_true flag, so it is already a bool.
        _train(tag=args.tag, fast_dev_run=args.debug)


if __name__ == "__main__":
    main()