custom-resnet50d / modeling_resnet.py
thinh-huynh-re's picture
Upload model
257c098
raw
history blame contribute delete
No virus
2.51 kB
from typing import Dict
import timm
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
from torch import Tensor, nn
from transformers import PreTrainedModel
from .configuration_resnet import ResnetConfig
# Translates ResnetConfig.block_type ("basic" / "bottleneck") into the timm
# residual-block class that is handed to ``ResNet``.
BLOCK_MAPPING = dict(basic=BasicBlock, bottleneck=Bottleneck)


class ResnetModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a timm ``ResNet`` backbone.

    ``config_class`` only has to be set if the model is going to be registered
    with the ``transformers`` auto classes.
    """

    config_class = ResnetConfig

    def __init__(self, config: ResnetConfig):
        """Build the timm ``ResNet`` described by ``config``.

        Raises:
            KeyError: if ``config.block_type`` is not a key of ``BLOCK_MAPPING``.
        """
        super().__init__(config)
        block_cls = BLOCK_MAPPING[config.block_type]
        # Everything except the block class and per-stage layer counts is passed
        # to timm by keyword.
        resnet_kwargs = dict(
            num_classes=config.num_classes,
            in_chans=config.input_channels,
            cardinality=config.cardinality,
            base_width=config.base_width,
            stem_width=config.stem_width,
            stem_type=config.stem_type,
            avg_down=config.avg_down,
        )
        self.model = ResNet(block_cls, config.layers, **resnet_kwargs)

    def forward(self, tensor: Tensor) -> Tensor:
        """Return ``self.model.forward_features(tensor)``.

        In timm's ResNet this is presumably the feature map before global
        pooling and the ``fc`` classifier, i.e. (batch, channels, H, W) —
        confirm against the installed timm version.
        """
        features = self.model.forward_features(tensor)
        return features


class ResnetModelForImageClassification(PreTrainedModel):
    """Image-classification model built on :class:`ResnetModel`.

    ``config_class`` only has to be set if the model is going to be registered
    with the ``transformers`` auto classes.
    """

    config_class = ResnetConfig

    def __init__(self, config: ResnetConfig):
        super().__init__(config)
        # Keep the ResnetModel wrapper so parameter names stay ``model.model.*``.
        self.model = ResnetModel(config)

    def forward(self, tensor: Tensor, labels=None) -> Dict[str, Tensor]:
        """Classify ``tensor`` and, when ``labels`` is given, compute the loss.

        Returning a dictionary that includes ``"loss"`` when labels are passed
        makes the model directly usable with ``transformers.Trainer``. Any
        other output format is fine with a custom training loop or another
        training library.

        Args:
            tensor: input batch fed to the wrapped timm ``ResNet``.
            labels: optional class targets passed to ``cross_entropy``.

        Returns:
            ``{"logits": logits}``, plus ``"loss"`` when ``labels`` is not None.
        """
        # ResnetModel.forward returns only ``forward_features`` (no pooling /
        # fc head), which are not class scores. Run the complete timm ResNet
        # so the output really is per-class logits.
        logits = self.model.model(tensor)
        if labels is None:
            return {"logits": logits}
        # ``torch.nn`` has no ``cross_entropy``; the function lives in
        # ``torch.nn.functional``.
        loss = nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}


if __name__ == "__main__":
    # ``ResnetConfig`` is imported relatively (``from .configuration_resnet``),
    # so this must be run as a module inside its package
    # (``python -m <package>.modeling_resnet``), not as a plain script.
    resnet50d_config = ResnetConfig.from_pretrained("custom-resnet")
    resnet50d = ResnetModelForImageClassification(resnet50d_config)

    # Load pretrained weights from timm.
    pretrained_model: nn.Module = timm.create_model("resnet50d", pretrained=True)
    # ``resnet50d.model`` is the ResnetModel wrapper, whose state-dict keys are
    # prefixed with ``model.``. The bare timm ResNet is ``resnet50d.model.model``,
    # and only its keys line up with the timm checkpoint.
    # NOTE(review): a strict load also requires the config (block type, layers,
    # stem, num_classes) to reproduce timm's resnet50d exactly — confirm the
    # "custom-resnet" config does.
    resnet50d.model.model.load_state_dict(pretrained_model.state_dict())